Can the model learn to add two 2-digit numbers from a shuffled dataset?

The "add_two_digits" notebook sharply divided the data into a training dataset (first addend in 0..89) and a validation dataset (first addend in 90..99). This hurts generalization, because the training data is not representative of the whole data.

Here we shuffle the data before splitting it into datasets, hoping that the model now trains on a sample that better represents the entire data.
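
To make the idea concrete, here is a minimal sketch of a shuffle-then-split in plain Python. It is only an illustration: `shuffled_split` is a hypothetical helper, not part of gptbench, and gptbench's own `pre_shuffle` option may well work differently.

```python
import random

def shuffled_split(items, train_fraction=0.9, seed=0):
    """Shuffle a copy of items, then cut it into (train, validation) lists."""
    rng = random.Random(seed)      # local generator, so the split is reproducible
    shuffled = list(items)         # copy: leave the caller's list untouched
    rng.shuffle(shuffled)
    cut = round(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]

# All sums of two numbers in 0..99, in the same 'a+b=c' form as the data file
lines = [f"{a}+{b}={a + b}" for a in range(100) for b in range(100)]
train, val = shuffled_split(lines)

print(len(train), len(val))  # 9000 1000

# After shuffling, validation first addends are no longer confined to 90..99
first_addends = {int(line.split("+")[0]) for line in val}
print(min(first_addends), max(first_addends))
```

Because the cut happens after the shuffle, both parts are random samples of the same distribution, while still sharing no lines.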

In [1]:
from gptbench import Train, empty_config

We'll load the data file '../data/add2.txt' used in the add_two_digits notebook, which can be created by running this script in the ../dataprep/ folder:
```
python prepare_addition.py ../data/add2.txt 2 --sep="\n"
```
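
For readers without the script at hand, the file's contents can be approximated with a few lines of Python. This is a sketch based only on the 'a+b=c' line format seen in the data file, not the actual prepare_addition.py; `addition_lines` is a hypothetical name.

```python
def addition_lines(digits=2):
    """Return every 'a+b=c' line for operands with up to `digits` digits."""
    limit = 10 ** digits
    return [f"{a}+{b}={a + b}" for a in range(limit) for b in range(limit)]

lines = addition_lines(2)
text = "\n".join(lines)      # one sum per line, as with --sep="\n"

print(len(lines))            # 10000
print(lines[0], lines[-1])   # 0+0=0 99+99=198
```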


In [2]:
# Open it and print the first and last 100 chars
with open('../data/add2.txt', 'r', newline=None) as f:
    data = f.read()
print("first:", data[:100])
print("last:", data[-100:])

first: 0+0=0
0+1=1
0+2=2
0+3=3
0+4=4
0+5=5
0+6=6
0+7=7
0+8=8
0+9=9
0+10=10
0+11=11
0+12=12
0+13=13
0+14=14

last: 99+90=189
99+91=190
99+92=191
99+93=192
99+94=193
99+95=194
99+96=195
99+97=196
99+98=197
99+99=198



In [3]:
# We'll load these data samples into two CharLineDatasets,
# taking care to shuffle the data before splitting train and validation data

In [4]:
# create the Train object - we'll name this model add2_shuffled
ben = Train('add2_shuffled', seed=0xADD2B055)

# set training log periods to avoid cluttering the training output
ben.set_train_log_periods(sample_period=500, dot_period=1, loss_period=0)

# set train and validation datasets
ben.set_datasets(class_name='charline', # id for the PaddedLineCharDataset class
                 train_path='../data/add2.txt', 
                 train_split=0.9,
                 pre_shuffle=True)

# set config settings that will override the default values
cfg = empty_config()
cfg.model.set(n_layer=6, n_head=6, n_embd=90, block_size=16) # our model parameters - block_size is big enough for aa+bb=ccc
cfg.sample.set(top=1, max_batch_size=256) # note the top_k(1) - always pick the best item
cfg.trainer.set(batch_size=128)

# and init a new model with config
ben.init_new(cfg)

Initializing new model add2_shuffled
Dataset train_path: ../data/add2.txt, val_path: None, train_split: 0.9, vocab_size: 13
Model params: 0.59M
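
As a quick sanity check on block_size=16, the sketch below measures the longest line and the set of characters in the data. It rebuilds the lines in memory instead of reading the file, and uses no gptbench calls.

```python
# Rebuild the 'a+b=c' lines for all operands in 0..99
lines = [f"{a}+{b}={a + b}" for a in range(100) for b in range(100)]

max_len = max(len(line) for line in lines)
chars = sorted(set("".join(lines)))

print(max_len)      # 9, e.g. '99+99=198', well within block_size=16
print(len(chars))   # 12: the ten digits plus '+' and '='
```

The logged vocab_size of 13 is one more than these 12 characters. That would fit one extra token, perhaps for padding or line ends, but this is a guess and not something confirmed here.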


In [5]:
# both train and validation datasets use shuffled data from the add2.txt source file
'train:',ben.train_dataset.get_data()[:10], 'validation:', ben.val_dataset.get_data()[:10]

('train:',
 ['25+47=72',
  '57+16=73',
  '3+59=62',
  '24+18=42',
  '53+3=56',
  '2+3=5',
  '28+67=95',
  '72+13=85',
  '54+52=106',
  '26+21=47'],
 'validation:',
 ['18+25=43',
  '9+72=81',
  '64+75=139',
  '74+21=95',
  '54+37=91',
  '18+74=92',
  '42+11=53',
  '48+57=105',
  '31+41=72',
  '5+38=43'])

In [6]:
# Let's train for 10000 batch iterations.
# Each dot means a batch was trained.
# Train and validation losses are evaluated every 100 iterations (or iters).
# Also, every 500 iters a random sample is taken.
ben.train(iter_count=10000)

Training
.Iter 1 (0.014 epoch): loss train=2.1426, val=2.1427, eval->2.1427
==> Saving model at iter=1, eval loss->2.1427 
...................................................................................................
Iter 100 (1.422 epoch): loss train=1.0557, val=1.0553, eval->1.0553
==> Saving model at iter=100, eval loss->1.0553 
....................................................................................................
Iter 200 (2.845 epoch): loss train=0.8535, val=0.8551, eval->0.8551
==> Saving model at iter=200, eval loss->0.8551 
....................................................................................................
Iter 300 (4.267 epoch): loss train=0.7834, val=0.7834, eval->0.7834
==> Saving model at iter=300, eval loss->0.7834 
....................................................................................................
Iter 400 (5.690 epoch): loss train=0.7356, val=0.7349, eval->0.7349
==> Saving model at iter=400, eval loss->0.7349 
......

In [18]:
# The current state loss info - the last evaluated losses for train and validation dataset:
ben.state

{'n_samples': 1164800,
 'train_loss': 0.43667128682136536,
 'val_loss': 0.4380030333995819,
 'eval_loss': 0.4380030333995819}

In [17]:
# The last saved checkpoint info is:
ben.last_saved_state

{'n_samples': 1164800,
 'train_loss': 0.43667128682136536,
 'val_loss': 0.4380030333995819,
 'eval_loss': 0.4380030333995819}

In [19]:
# They are the same, so there's no point in loading the last saved checkpoint. Otherwise we would do:
# ben.load()
# ben.state

In [20]:
# take a few samples:
ben.sample('1+1=')
ben.sample('34+7=')
ben.sample('78+99=')

1+1=2
34+7=41
78+99=177


In [11]:
# Much better now - all three are correct.
# Let's measure the accuracy on the entire training dataset - this should be mostly memorization,
# as the model trained on this data
train_ds = ben.train_dataset

# split each aa+bb=ccc into a prompt 'aa+bb=' and an answer 'ccc'
q, a = train_ds.get_data_split(0, len(train_ds), sep='=', sep_included=-1)

print(q[:3])
print(a[:3])

['96+30=', '91+85=', '75+11=']
['126', '176', '86']


In [12]:
# Measure the accuracy - how good was the memorization? 
# This may take a while...
ben.measure_accuracy(q,a)

1.0
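
The accuracy figure can be read as the fraction of prompts whose generated answer matches the expected one exactly. Below is a minimal sketch of such a measurement, assuming exact string matching. It is not gptbench's measure_accuracy implementation: `split_prompt_answer`, `exact_match_accuracy` and `oracle` are hypothetical names, and `oracle` stands in for the model by doing real integer addition.

```python
from typing import Callable, List, Tuple

def split_prompt_answer(line: str, sep: str = "=") -> Tuple[str, str]:
    """Split 'aa+bb=ccc' into ('aa+bb=', 'ccc'), keeping sep in the prompt."""
    head, _, tail = line.partition(sep)
    return head + sep, tail

def exact_match_accuracy(prompts: List[str], answers: List[str],
                         answer_fn: Callable[[str], str]) -> float:
    """Fraction of prompts for which answer_fn returns exactly the expected answer."""
    if not prompts:
        return 0.0
    hits = sum(1 for q, a in zip(prompts, answers) if answer_fn(q) == a)
    return hits / len(prompts)

def oracle(prompt: str) -> str:
    """Stand-in for the model: parse 'aa+bb=' and add the two numbers."""
    left, right = prompt.rstrip("=").split("+")
    return str(int(left) + int(right))

pairs = [split_prompt_answer(s) for s in ["96+30=126", "91+85=176", "75+11=86"]]
q = [prompt for prompt, _ in pairs]
a = [answer for _, answer in pairs]

print(q)                                    # ['96+30=', '91+85=', '75+11=']
print(a)                                    # ['126', '176', '86']
print(exact_match_accuracy(q, a, oracle))   # 1.0
```

Exact matching is strict: an answer with a single wrong digit counts as a miss.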

In [13]:
# Perfect accuracy!
# What about the accuracy of the validation dataset, on which the model never trained?
val_ds = ben.val_dataset

# split each aa+bb=ccc into a prompt 'aa+bb=' and an answer 'ccc'
q, a = val_ds.get_data_split(0, len(val_ds), sep='=', sep_included=-1)

print(q[:3])
print(a[:3])

['31+19=', '80+54=', '96+68=']
['50', '134', '164']


In [14]:
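# Unlike in the add_two_digits notebook, the validation dataset is now a random 10% of all sums,
# not just the ones starting with 90+..99+...
# The model never trained on these exact sums, although it may have seen the reversed form
# of many of them, for example 19+31=50 for the validation sum 31+19=50.
# Did it generalize, perhaps by learning the commutative property of addition?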
ben.measure_accuracy(q,a)

1.0
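
One caveat when reading this result: with a random 90/10 split, most validation sums have their operand-swapped twin in the training data. The sketch below estimates how many, using Python's own `random` module instead of gptbench's shuffle, so the exact figure will differ from the split used in this notebook. `mirrored` is a hypothetical helper.

```python
import random

# Rebuild all sums and make a reproducible 90/10 shuffled split
lines = [f"{a}+{b}={a + b}" for a in range(100) for b in range(100)]
rng = random.Random(0)
rng.shuffle(lines)
train, val = lines[:9000], lines[9000:]
train_set = set(train)

def mirrored(line):
    """Swap the operands: '31+19=50' becomes '19+31=50'."""
    prompt, answer = line.split("=")
    left, right = prompt.split("+")
    return f"{right}+{left}={answer}"

# Count validation sums whose swapped form was available during training
seen_mirrored = sum(1 for line in val if mirrored(line) in train_set)
fraction = seen_mirrored / len(val)
print(f"{fraction:.1%} of validation sums have their mirrored form in training")
```

A high fraction here means that perfect validation accuracy alone cannot tell whether the model learned addition as such or also leaned on seeing the swapped operands.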

In [15]:
# Also perfect accuracy - this means it's generalizing beyond the training data, at least for two digits.
# What about three digit sums?
ben.sample('101+120=')
ben.sample('131+17=')
ben.sample('990+9=')

101+120=
131+17=2
990+9=188


This model doesn't work for three digits!

Perhaps a new project: three-digit addition?