In [1]:
"""
create a joint dataset with shakespeare + addition + algorithmic reasoning 
from folder shakespeare / addition_bal / algo_reasoning, respectively
if no files, run corresponding prepare.py or prepare.ipynb file
"""

import pickle
import requests
import numpy as np
import os


In [2]:
# Load the three pre-generated corpora (algorithmic-reasoning addition, plain
# addition, Shakespeare text) and make sure the output directory exists.

out_dir = 'shakespeare_add_ar_mixed'
# exist_ok=True replaces the separate "does it exist?" check, so there is no
# window between checking and creating the directory.
os.makedirs(out_dir, exist_ok=True)

addition_ar_path = 'algo_reasoning/add_examples_algorithmic_3000.txt'
addition_add_path = 'addition_bal/add_examples.txt'
text_file_path = 'shakespeare/input.txt'

# An explicit encoding keeps the decoded characters (and therefore the
# vocabulary built from them) independent of the platform's default locale.
with open(addition_ar_path, 'r', encoding='utf-8') as f:
    data_ar = f.read()
print(len(data_ar))

with open(addition_add_path, 'r', encoding='utf-8') as f:
    data_add = f.read()
print(len(data_add))

with open(text_file_path, 'r', encoding='utf-8') as f:
    data_text = f.read()
print(len(data_text))


813510
120027
1115394


In [3]:
# Characters used by the two arithmetic corpora (algorithmic-reasoning + plain addition).
chars = sorted(set(data_ar) | set(data_add))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")


all the unique characters: 
 +,-./0123456789:<=>ACDEINT[]acdeghinprstu
vocab size: 43


In [4]:
# Characters used by the Shakespeare text on its own.
chars = sorted(set(data_text))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

all the unique characters: 
 !&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 64


In [5]:
# Joint vocabulary: every character that occurs in any of the three corpora.
data_all = data_text + data_ar + data_add
chars = sorted(set(data_all))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# Lookup tables between characters and integer ids; an id is the character's
# position in the sorted `chars` list.
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))
def encode(s):
    """Encoder: take a string, output a list of integers (one `stoi` id per character).

    Raises KeyError if `s` contains a character that is not a key of `stoi`.
    """
    return [stoi[c] for c in s]
def decode(l):
    """Decoder: take a list of integers, output a string (inverse of `encode`).

    Raises KeyError if `l` contains an id that is not a key of `itos`.
    """
    # The join result must be returned; without `return` this function always
    # evaluated to None.
    return ''.join(itos[i] for i in l)



all the unique characters: 
 !&'+,-./0123456789:;<=>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz
vocab size: 80


In [6]:
# Persist the vocabulary alongside the datasets so text can be encoded/decoded later.
meta = dict(vocab_size=vocab_size, itos=itos, stoi=stoi)
with open(f'{out_dir}/meta.pkl', 'wb') as meta_file:
    pickle.dump(meta, meta_file)

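# NOTE: the reference figures below appear to describe a Shakespeare-only
# character dataset; they do not describe the joint vocabulary saved above.
# length of dataset in characters:  1115394
# all the unique characters:
#  !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
# vocab size: 65
# train has 1003854 tokens
# val has 111540 tokens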

In [7]:
# 90/10 train/val split of the algorithmic-reasoning corpus, cut at a character offset
n_ar = len(data_ar)
print(n_ar)
split_at = int(n_ar*0.9)
train_data, val_data = data_ar[:split_at], data_ar[split_at:]

# map every character to its vocabulary id
train_ids, val_ids = encode(train_data), encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# dump the ids as raw uint16 .bin files
train_ids = np.asarray(train_ids, dtype=np.uint16)
val_ids = np.asarray(val_ids, dtype=np.uint16)
train_ids.tofile(f'{out_dir}/train_ar.bin')
val_ids.tofile(f'{out_dir}/val_ar.bin')

# distinct ids that occur in each split
print(set(train_ids))
print(set(val_ids))

813510
train has 732,159 tokens
val has 81,351 tokens
{0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 26, 28, 29, 30, 34, 39, 45, 52, 53, 54, 56, 57, 58, 60, 61, 62, 67, 69, 71, 72, 73, 74}
{0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 26, 28, 29, 30, 34, 39, 45, 52, 53, 54, 56, 57, 58, 60, 61, 62, 67, 69, 71, 72, 73, 74}


In [8]:
# 90/10 train/val split of the plain-addition corpus, cut at a character offset
n_add = len(data_add)
print(n_add)
split_at = int(n_add*0.9)
train_data, val_data = data_add[:split_at], data_add[split_at:]

# map every character to its vocabulary id
train_ids, val_ids = encode(train_data), encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# dump the ids as raw uint16 .bin files
train_ids = np.asarray(train_ids, dtype=np.uint16)
val_ids = np.asarray(val_ids, dtype=np.uint16)
train_ids.tofile(f'{out_dir}/train_add.bin')
val_ids.tofile(f'{out_dir}/val_add.bin')

# distinct ids that occur in each split
print(set(train_ids))
print(set(val_ids))

120027
train has 108,024 tokens
val has 12,003 tokens
{0, 5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 23}
{0, 5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 23}


In [9]:
# 90/10 train/val split of the Shakespeare text, cut at a character offset
n_text = len(data_text)
print(n_text)
split_at = int(n_text*0.9)
train_data, val_data = data_text[:split_at], data_text[split_at:]

# map every character to its vocabulary id (kept as plain Python lists)
train_ids, val_ids = encode(train_data), encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# dump uint16 copies of the ids as raw .bin files
train_ids2 = np.asarray(train_ids, dtype=np.uint16)
val_ids2 = np.asarray(val_ids, dtype=np.uint16)
train_ids2.tofile(f'{out_dir}/train_text.bin')
val_ids2.tofile(f'{out_dir}/val_text.bin')

# distinct ids that occur in each split
print(set(train_ids))
print(set(val_ids))

1115394
train has 1,003,854 tokens
val has 111,540 tokens
{0, 1, 2, 3, 4, 6, 7, 8, 13, 20, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}
{0, 1, 2, 4, 6, 7, 8, 20, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 50, 51, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}


In [10]:
# Break each corpus into individual examples.
# data_ar: split on the 'Input:\n' delimiter (discarding whatever precedes the
# first one) and put the delimiter back in front of every example.
ar_stripped = ['Input:\n' + example for example in data_ar.split('Input:\n')[1:]]
# data_add: one example per line; the piece after the final '\n' is discarded.
add_stripped = data_add.split('\n')[:-1]
# data_text: blank-line-separated segments; the piece after the final '\n\n' is
# discarded.
# NOTE(review): if the text file does not end with a blank line, that discarded
# piece is real text -- confirm this is intended.
text_stripped = data_text.split('\n\n')[:-1]

# show the first and then the last example of every corpus
for position in (0, -1):
    print(ar_stripped[position])
    print(add_stripped[position])
    print(text_stripped[position])


num_line_ar, num_line_add, num_line_text = map(len, (ar_stripped, add_stripped, text_stripped))

print(num_line_ar, num_line_add, num_line_text)

Input:
140+201
Target:
<scratch>
[1,4,0] has 3 digits.
[2,0,1] has 3 digits.
[1,4,0] + [2,0,1] , A=[] , C=0 , 0+1+0=1 , A->1 , C->0
[1,4] + [2,0] , A=[1] , C=0 , 4+0+0=4 , A->4 , C->0
[1] + [2] , A=[4,1] , C=0 , 1+2+0=3 , A->3 , C->0
[] + [] , A=[3,4,1] C=0 , END
</scratch>
3 4 1

50+148=198
First Citizen:
Before we proceed any further, hear me speak.
Input:
913+115
Target:
<scratch>
[9,1,3] has 3 digits.
[1,1,5] has 3 digits.
[9,1,3] + [1,1,5] , A=[] , C=0 , 3+5+0=8 , A->8 , C->0
[9,1] + [1,1] , A=[8] , C=0 , 1+1+0=2 , A->2 , C->0
[9] + [1] , A=[2,8] , C=0 , 9+1+0=10 , A->0 , C->1
[] + [] , A=[0,2,8] C=1 , END
</scratch>
1 0 2 8

793+436=1229
SEBASTIAN:
I do; and surely
It is a sleepy language and thou speak'st
Out of thy sleep. What is it thou didst say?
This is a strange repose, to be asleep
With eyes wide open; standing, speaking, moving,
And yet so fast asleep.
3000 10000 7221


In [11]:
# Interleave the corpora in rounds: up to 72 Shakespeare segments, then up to
# 100 addition lines, then up to 30 algorithmic-reasoning examples. Rounds
# repeat until all three corpora are used up; a corpus that runs out early
# simply contributes nothing to later rounds.
# (shakespeare) 72 - (add) 100  - (ar) 30 - 72 -100 - 30 - ...  - ... 21

# Pieces are collected in a list and joined once at the end: growing one large
# string with repeated `+` may re-copy the accumulated text on every step.
mixed_pieces = []
add_count = 0
ar_count = 0
shakespeare_count = 0

while True:
    for i in range(72):
        if shakespeare_count < num_line_text:
            # the '\n\n' separator removed by the earlier split is restored here
            mixed_pieces.append(text_stripped[shakespeare_count] + '\n\n')
            shakespeare_count += 1

    for j in range(100):
        if add_count < num_line_add:
            mixed_pieces.append(add_stripped[add_count] + '\n')
            add_count += 1
    # blank line after the addition batch, unless the addition corpus is used up
    if add_count < num_line_add:
        mixed_pieces.append('\n')

    for k in range(30):
        if ar_count < num_line_ar:
            mixed_pieces.append(ar_stripped[ar_count])
            ar_count += 1
    # extra newline after the algorithmic batch, unless that corpus is used up
    if ar_count < num_line_ar:
        mixed_pieces.append('\n')

    if shakespeare_count >= num_line_text and add_count >= num_line_add and ar_count >= num_line_ar:
        break

data_all = ''.join(mixed_pieces)

# total number of examples placed into the mixed corpus
print(add_count + ar_count + shakespeare_count)

20221


In [12]:
# Preview only the beginning of the mixed corpus: printing all of data_all
# would put the entire dataset into the cell output.
print(data_all[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [13]:
# Save the interleaved corpus as plain text in the output directory.
save_file_path = f'{out_dir}/mixed_examples.txt'
# Plain 'w' is sufficient because nothing is read back from the handle; the
# explicit encoding keeps the file independent of the platform's default locale.
with open(save_file_path, 'w', encoding='utf-8') as f:
    f.write(data_all)

In [14]:
# 90/10 train/val split of the interleaved corpus, cut at a character offset
n = len(data_all)
print(n)
split_at = int(n*0.9)
train_data_all, val_data_all = data_all[:split_at], data_all[split_at:]

# map every character to its vocabulary id
train_ids_all, val_ids_all = encode(train_data_all), encode(val_data_all)
print(f"train has {len(train_ids_all):,} tokens")
print(f"val has {len(val_ids_all):,} tokens")

# dump the ids as raw uint16 .bin files
train_ids_all = np.asarray(train_ids_all, dtype=np.uint16)
val_ids_all = np.asarray(val_ids_all, dtype=np.uint16)
train_ids_all.tofile(f'{out_dir}/train_all.bin')
val_ids_all.tofile(f'{out_dir}/val_all.bin')

# distinct ids that occur in each split
print(set(train_ids_all))
print(set(val_ids_all))

2049015
train has 1,844,113 tokens
val has 204,902 tokens
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}
{0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}


In [15]:
# Sweep the number of plain-addition examples (num_samples below) while the
# algorithmic-reasoning corpus stays fixed at the 3000-example file.
# Uses the existing text/addition data files.

out_dir = 'shakespeare_add_ar_mixed'
num_samples = [500, 2000, 5000, 10000, 20000, 40000] # [20000] #[2000, 5000]

# The algorithmic-reasoning and Shakespeare inputs do not depend on num_sample,
# so they are read and split once instead of once per sweep value.
addition_ar_path = 'algo_reasoning/add_examples_algorithmic_3000.txt'
text_file_path = 'shakespeare/input.txt'

with open(addition_ar_path, 'r', encoding='utf-8') as f:
    data_ar = f.read()
with open(text_file_path, 'r', encoding='utf-8') as f:
    data_text = f.read()

# split data_ar with delimiter 'Input:\n' and then add it back in front of each element
ar_stripped = ['Input:\n' + s for s in data_ar.split('Input:\n')[1:]]
text_stripped = data_text.split('\n\n')[:-1]
num_line_ar = len(ar_stripped)
num_line_text = len(text_stripped)

for num_sample in num_samples:
    addition_add_path = f'addition_bal/add_examples_{num_sample}.txt'
    with open(addition_add_path, 'r', encoding='utf-8') as f:
        data_add = f.read()

    # corpus sizes in characters, in the order: ar, add, text
    print(len(data_ar))
    print(len(data_add))
    print(len(data_text))

    add_stripped = data_add.split('\n')[:-1]
    num_line_add = len(add_stripped)

    print(num_line_ar, num_line_add, num_line_text)

    # Each round: up to 72 Shakespeare segments, then 1/100 of the addition
    # lines, then 1/100 of the algorithmic-reasoning examples.
    # max(1, ...) guards against a corpus with fewer than 100 examples: a
    # per-round count of 0 would never advance that corpus' counter, so the
    # `while True` loop below would never reach its break.
    num_add_per_mixing = max(1, num_line_add // 100)
    num_ar_per_mixing = max(1, num_line_ar // 100)

    # Collect pieces in a list and join once; repeated `+` on one growing
    # string may re-copy the accumulated text on every step.
    mixed_pieces = []
    add_count = 0
    ar_count = 0
    shakespeare_count = 0

    while True:
        for i in range(72):
            if shakespeare_count < num_line_text:
                mixed_pieces.append(text_stripped[shakespeare_count] + '\n\n')
                shakespeare_count += 1

        for j in range(num_add_per_mixing):
            if add_count < num_line_add:
                mixed_pieces.append(add_stripped[add_count] + '\n')
                add_count += 1
        # blank line after the addition batch, unless the addition corpus is used up
        if add_count < num_line_add:
            mixed_pieces.append('\n')

        for k in range(num_ar_per_mixing):
            if ar_count < num_line_ar:
                mixed_pieces.append(ar_stripped[ar_count])
                ar_count += 1
        # extra newline after the algorithmic batch, unless that corpus is used up
        if ar_count < num_line_ar:
            mixed_pieces.append('\n')

        if shakespeare_count >= num_line_text and add_count >= num_line_add and ar_count >= num_line_ar:
            break

    data_all = ''.join(mixed_pieces)

    # total number of examples placed into this mixed corpus
    print(add_count + ar_count + shakespeare_count)

    # 'w' is sufficient because nothing is read back from the handle
    save_file_path = f'{out_dir}/mixed_examples_ar3000_add{num_sample}.txt'
    with open(save_file_path, 'w', encoding='utf-8') as f:
        f.write(data_all)

    # 90/10 train/val split, cut at a character offset
    n = len(data_all)
    print(n)
    train_data_all = data_all[:int(n*0.9)]
    val_data_all = data_all[int(n*0.9):]

    # encode both to integers
    train_ids_all = encode(train_data_all)
    val_ids_all = encode(val_data_all)
    print(f"train has {len(train_ids_all):,} tokens")
    print(f"val has {len(val_ids_all):,} tokens")

    # export to bin files
    train_ids_all = np.array(train_ids_all, dtype=np.uint16)
    val_ids_all = np.array(val_ids_all, dtype=np.uint16)
    train_ids_all.tofile(f'{out_dir}/train_all_ar3000_add{num_sample}.bin')
    val_ids_all.tofile(f'{out_dir}/val_all_ar3000_add{num_sample}.bin')

    # distinct ids that occur in each split
    print(set(train_ids_all))
    print(set(val_ids_all))

813510
5463
1115394
3000 500 7221
10721
1934463
train has 1,741,016 tokens
val has 193,447 tokens
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}
{0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}
813510
23573
1115394
3000 2000 7221
12221
1952573
train has 1,757,315 tokens
val has 195,258 tokens
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 

In [16]:
# Sweep the number of algorithmic-reasoning examples (num_samples below) while
# the plain-addition corpus stays fixed at 'addition_bal/add_examples.txt'
# (output files are tagged "add10000").

out_dir = 'shakespeare_add_ar_mixed'
num_samples = [500, 1000, 2000, 3000, 5000]

# The plain-addition and Shakespeare inputs do not depend on num_sample, so
# they are read and split once instead of once per sweep value.
addition_add_path = 'addition_bal/add_examples.txt'
text_file_path = 'shakespeare/input.txt'

with open(addition_add_path, 'r', encoding='utf-8') as f:
    data_add = f.read()
with open(text_file_path, 'r', encoding='utf-8') as f:
    data_text = f.read()

add_stripped = data_add.split('\n')[:-1]
text_stripped = data_text.split('\n\n')[:-1]
num_line_add = len(add_stripped)
num_line_text = len(text_stripped)

for num_sample in num_samples:
    addition_ar_path = f'algo_reasoning/add_examples_algorithmic_{num_sample}.txt'
    with open(addition_ar_path, 'r', encoding='utf-8') as f:
        data_ar = f.read()

    # corpus sizes in characters, in the order: ar, add, text
    print(len(data_ar))
    print(len(data_add))
    print(len(data_text))

    # split data_ar with delimiter 'Input:\n' and then add it back in front of each element
    ar_stripped = ['Input:\n' + s for s in data_ar.split('Input:\n')[1:]]
    num_line_ar = len(ar_stripped)

    print(num_line_ar, num_line_add, num_line_text)

    # Each round: up to 72 Shakespeare segments, then 1/100 of the addition
    # lines, then 1/100 of the algorithmic-reasoning examples.
    # max(1, ...) guards against a corpus with fewer than 100 examples: a
    # per-round count of 0 would never advance that corpus' counter, so the
    # `while True` loop below would never reach its break.
    num_add_per_mixing = max(1, num_line_add // 100)
    num_ar_per_mixing = max(1, num_line_ar // 100)

    # Collect pieces in a list and join once; repeated `+` on one growing
    # string may re-copy the accumulated text on every step.
    mixed_pieces = []
    add_count = 0
    ar_count = 0
    shakespeare_count = 0

    while True:
        for i in range(72):
            if shakespeare_count < num_line_text:
                mixed_pieces.append(text_stripped[shakespeare_count] + '\n\n')
                shakespeare_count += 1

        for j in range(num_add_per_mixing):
            if add_count < num_line_add:
                mixed_pieces.append(add_stripped[add_count] + '\n')
                add_count += 1
        # blank line after the addition batch, unless the addition corpus is used up
        if add_count < num_line_add:
            mixed_pieces.append('\n')

        for k in range(num_ar_per_mixing):
            if ar_count < num_line_ar:
                mixed_pieces.append(ar_stripped[ar_count])
                ar_count += 1
        # extra newline after the algorithmic batch, unless that corpus is used up
        if ar_count < num_line_ar:
            mixed_pieces.append('\n')

        if shakespeare_count >= num_line_text and add_count >= num_line_add and ar_count >= num_line_ar:
            break

    data_all = ''.join(mixed_pieces)

    # total number of examples placed into this mixed corpus
    print(add_count + ar_count + shakespeare_count)

    # 'w' is sufficient because nothing is read back from the handle
    save_file_path = f'{out_dir}/mixed_examples_ar{num_sample}_add10000.txt'
    with open(save_file_path, 'w', encoding='utf-8') as f:
        f.write(data_all)

    # 90/10 train/val split, cut at a character offset
    n = len(data_all)
    print(n)
    train_data_all = data_all[:int(n*0.9)]
    val_data_all = data_all[int(n*0.9):]

    # encode both to integers
    train_ids_all = encode(train_data_all)
    val_ids_all = encode(val_data_all)
    print(f"train has {len(train_ids_all):,} tokens")
    print(f"val has {len(val_ids_all):,} tokens")

    # export to bin files
    train_ids_all = np.array(train_ids_all, dtype=np.uint16)
    val_ids_all = np.array(val_ids_all, dtype=np.uint16)
    train_ids_all.tofile(f'{out_dir}/train_all_ar{num_sample}_add10000.bin')
    val_ids_all.tofile(f'{out_dir}/val_all_ar{num_sample}_add10000.bin')

    # distinct ids that occur in each split
    print(set(train_ids_all))
    print(set(val_ids_all))

125028
120027
1115394
500 10000 7221
17721
1360533
train has 1,224,479 tokens
val has 136,054 tokens
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}
{0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}
262490
120027
1115394
1000 10000 7221
18221
1497995
train has 1,348,195 tokens
val has 149,800 tokens
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,