Train a transformer model to convert Roman numerals to decimal numbers, e.g.:

LVII=57

https://en.wikipedia.org/wiki/Roman_numerals


In [1]:
from gptbench import Train, empty_config, LogFlag

In [2]:
# Flip to True to continue from a saved checkpoint (when one exists) instead
# of training a fresh model. Left False so this cell always starts from scratch.
RESUME_FROM_CHECKPOINT = False

# Fraction of the data file used for training; the remainder becomes the
# validation split. -1 because numbers start at 1
TRAIN_SPLIT = (9000 - 1) / 10000

# 'roman2dec' is the model name used when saving and loading checkpoints
ben = Train('roman2dec')

# set datasets
ben.set_datasets('padlinechar',
                 train_path='../data/roman2decimal10000.txt',
                 train_split=TRAIN_SPLIT)

# set config settings
cfg = empty_config()
cfg.train.log_period = 0
cfg.model.set(n_layer=8, n_head=8, n_embd=96, block_size=26)
cfg.sample.set(top=1, max_batch_size=256)  # top_k(1) - always pick the best item
cfg.train.set(sample_period=-5)
cfg.trainer.set(batch_size=128)

# and init a new model with config (or resume, when explicitly enabled above)
if ben.can_resume() and RESUME_FROM_CHECKPOINT:
    ben.init_resume(cfg)
else:
    ben.init_new(cfg)


New random seed 2050019466
Initializing new model roman2dec
Dataset train_path: ../data/roman2decimal10000.txt, val_path: None, train_split: 0.8999, vocab_size: 19
Model params: 0.90M


In [3]:
# Peek at the first 70 characters of the validation split's source text
# to check where the held-out data begins.
ben.val_dataset.get_src_data()[:70]

'MMMMMMMMM=9000\nMMMMMMMMMI=9001\nMMMMMMMMMII=9002\nMMMMMMMMMIII=9003\nMMMM'

In [4]:
# Train with an iteration budget of 5000; the log below reports train/val loss
# and each checkpoint save.
ben.train(iter_count=5000)

Training
Batches per epoch: 70
iter 0 (0.000 epoch): loss train=2.4947, val=2.6246, eval->2.6246
==> Saving model at iter=0, eval loss->2.6246 
9X
.CUDA max memory used: 380.12M
...................................................................................................iter 100 (1.422 epoch): loss train=1.1859, val=1.3579, eval->1.3579
==> Saving model at iter=100, eval loss->1.3579 
....................................................................................................iter 200 (2.845 epoch): loss train=0.8872, val=1.0967, eval->1.0967
==> Saving model at iter=200, eval loss->1.0967 
....................................................................................................iter 300 (4.267 epoch): loss train=0.7197, val=0.9320, eval->0.9320
==> Saving model at iter=300, eval loss->0.9320 
....................................................................................................iter 400 (5.690 epoch): loss train=0.6164, val=0.8905, eval->0.8905
==> 

KeyboardInterrupt: 

In [34]:
# Exact-match accuracy on the validation split.
# NOTE(review): this cell's recorded execution count (34) is out of sequence and
# its saved output shows decimal questions with Roman answers, which does not
# match this notebook's Roman->decimal data - re-run after Restart & Run All.
ds = ben.val_dataset
# Cut every sample at '=' into a question and its expected answer.
q, a = ds.sample_split(0, len(ds), sep='=', sep_included=-1)

errs = []  # one "question: expected != generated" entry per miss


def test(q, a, g):
    """Score one sample: 1.0 if generated text g equals expected answer a, else 0.0.

    Every miss is also recorded in the module-level errs list.
    """
    if a == g:
        return 1.0
    errs.append(f"{q}: {a} != {g}")
    return 0.0


print(ben.measure_accuracy(q, a, test_fn=test))
print(len(errs), errs[:20])

0.0
['13000=: MMMMMMMMMMMMM != MMMMMMMMMMMM', '13001=: MMMMMMMMMMMMMI != MMMMMMMMMMMMI', '13002=: MMMMMMMMMMMMMII != MMMMMMMMMMMMII', '13003=: MMMMMMMMMMMMMIII != MMMMMMMMMMMMIII', '13004=: MMMMMMMMMMMMMIV != MMMMMMMMMMMMIV', '13005=: MMMMMMMMMMMMMV != MMMMMMMMMMMMV', '13006=: MMMMMMMMMMMMMVI != MMMMMMMMMMMMVI', '13007=: MMMMMMMMMMMMMVII != MMMMMMMMMMMMVII', '13008=: MMMMMMMMMMMMMVIII != MMMMMMMMMMMMVIII', '13009=: MMMMMMMMMMMMMIX != MMMMMMMMMMMMIX', '13010=: MMMMMMMMMMMMMX != MMMMMMMMMMMMCC', '13011=: MMMMMMMMMMMMMXI != MMMMMMMMMMMMCCI', '13012=: MMMMMMMMMMMMMXII != MMMMMMMMMMMMCCII', '13013=: MMMMMMMMMMMMMXIII != MMMMMMMMMMMMXIII', '13014=: MMMMMMMMMMMMMXIV != MMMMMMMMMMMMXIV', '13015=: MMMMMMMMMMMMMXV != MMMMMMMMMMMMXV', '13016=: MMMMMMMMMMMMMXVI != MMMMMMMMMMMMCCVI', '13017=: MMMMMMMMMMMMMXVII != MMMMMMMMMMMMCCVII', '13018=: MMMMMMMMMMMMMXVIII != MMMMMMMMMMMMXVIII', '13019=: MMMMMMMMMMMMMXIX != MMMMMMMMMMMMCCIX']


In [5]:
# Exact-match accuracy on the training split.
ds = ben.train_dataset
# Cut every sample at '=' into a question and its expected answer.
q, a = ds.sample_split(0, len(ds), sep='=', sep_included=-1)

errs = []  # one "question: expected != generated" entry per miss


def test(q, a, g):
    """Score one sample: 1.0 if generated text g equals expected answer a, else 0.0.

    Every miss is also recorded in the module-level errs list.
    """
    if a == g:
        return 1.0
    errs.append(f"{q}: {a} != {g}")
    return 0.0


print(ben.measure_accuracy(q, a, test_fn=test))
print(len(errs), errs[:20])

0.9994443827091899
5 ['II=: 2 != 1', 'III=: 3 != 1', 'IV=: 4 != 1', 'IX=: 9 != 10', 'MM=: 2000 != 1000']


In [6]:
# Show the current training state (sample count and latest train/val/eval losses).
ben.state

{'n_samples': 434176,
 'train_loss': 0.33498454093933105,
 'val_loss': 1.0654536485671997,
 'eval_loss': 1.0654536485671997}

In [7]:
# Reload the model from its saved checkpoint using the same config; the log
# below reports which iteration and losses that checkpoint corresponds to.
ben.init_resume(cfg)


New random seed 766280474
Loading checkpoint from ./models/roman2dec/
Checkpoint: iter=700 (9.957 epoch), loss train=0.4330 val=0.7836 eval->0.7836
Dataset train_path: ../data/roman2decimal10000.txt, val_path: None, train_split: 0.8999, vocab_size: 19
Model params: 0.90M


In [8]:
# Exact-match accuracy of the reloaded checkpoint on the validation split.
ds = ben.val_dataset
# Cut every sample at '=' into a question and its expected answer.
q, a = ds.sample_split(0, len(ds), sep='=', sep_included=-1)

errs = []  # one "question: expected != generated" entry per miss


def test(q, a, g):
    """Score one sample: 1.0 if generated text g equals expected answer a, else 0.0.

    Every miss is also recorded in the module-level errs list.
    """
    if a == g:
        return 1.0
    errs.append(f"{q}: {a} != {g}")
    return 0.0


print(ben.measure_accuracy(q, a, test_fn=test))
print(len(errs), errs[:20])

0.0
1001 ['MMMMMMMMM=: 9000 != 8000', 'MMMMMMMMMI=: 9001 != 8009', 'MMMMMMMMMII=: 9002 != 8002', 'MMMMMMMMMIII=: 9003 != 8003', 'MMMMMMMMMIV=: 9004 != 8004', 'MMMMMMMMMV=: 9005 != 8005', 'MMMMMMMMMVI=: 9006 != 8006', 'MMMMMMMMMVII=: 9007 != 8007', 'MMMMMMMMMVIII=: 9008 != 8008', 'MMMMMMMMMIX=: 9009 != 801', 'MMMMMMMMMX=: 9010 != 8010', 'MMMMMMMMMXI=: 9011 != 8019', 'MMMMMMMMMXII=: 9012 != 8013', 'MMMMMMMMMXIII=: 9013 != 8013', 'MMMMMMMMMXIV=: 9014 != 8014', 'MMMMMMMMMXV=: 9015 != 8015', 'MMMMMMMMMXVI=: 9016 != 8016', 'MMMMMMMMMXVII=: 9017 != 8017', 'MMMMMMMMMXVIII=: 9018 != 8018', 'MMMMMMMMMXIX=: 9019 != 8099']


In [9]:
# Exact-match accuracy of the reloaded checkpoint on the training split.
ds = ben.train_dataset
# Cut every sample at '=' into a question and its expected answer.
q, a = ds.sample_split(0, len(ds), sep='=', sep_included=-1)

errs = []  # one "question: expected != generated" entry per miss


def test(q, a, g):
    """Score one sample: 1.0 if generated text g equals expected answer a, else 0.0.

    Every miss is also recorded in the module-level errs list.
    """
    if a == g:
        return 1.0
    errs.append(f"{q}: {a} != {g}")
    return 0.0


print(ben.measure_accuracy(q, a, test_fn=test))
print(len(errs), errs[:20])

0.7003000333370375
2697 ['I=: 1 != 10', 'II=: 2 != 9', 'IV=: 4 != 104', 'V=: 5 != 50', 'VI=: 6 != 106', 'VII=: 7 != 107', 'VIII=: 8 != 1', 'IX=: 9 != 109', 'XII=: 12 != 41', 'XIII=: 13 != 43', 'XIV=: 14 != 41', 'XV=: 15 != 40', 'XVI=: 16 != 41', 'XVII=: 17 != 43', 'XVIII=: 18 != 48', 'XIX=: 19 != 99', 'XX=: 20 != 19', 'XXI=: 21 != 49', 'XXII=: 22 != 42', 'XXIII=: 23 != 43']
