Can a relatively small model complete number sequences like 1, 2, 3?

In [1]:
from gptbench import Train, empty_config

In [2]:
ben = Train('intseq100k', seed=0xadaba5e)

# set training log periods to avoid cluttering the training output
ben.set_train_log_periods(sample_period=500, dot_period=1, loss_period=0)

# set train and validation datasets
ben.set_datasets(class_name='char', 
                 train_path='../data/intseq100k.txt', 
                 train_split=0.8)

# set config settings
cfg = empty_config()
cfg.model.set(n_layer=6, n_head=6, n_embd=90, block_size=64)
cfg.trainer.set(batch_size=128)
cfg.sample.set(top=1, max_batch_size=256) # top=1 means top_k(1): always pick the most likely next character (greedy decoding)

# Initialize a new model with this config. Set force_new to False to resume from a previous checkpoint with this name, if one exists.
force_new = True
if ben.can_load() and not force_new:
    ben.load(cfg)
else:
    ben.init_new(cfg)

Initializing new model intseq100k
Dataset train_path: ../data/intseq100k.txt, val_path: None, train_split: 0.8, vocab_size: 11
Model params: 0.60M
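The "Model params: 0.60M" figure can be sanity-checked by hand. The sketch below assumes a minGPT-style decoder (learned token and position embeddings, pre-LayerNorm blocks with biased linear layers, a 4x MLP, a final LayerNorm and an untied output head); these are assumptions, not details read from gptbench's source. Note that n_head does not change the count, since the 6 heads just split the 90 embedding dimensions (15 per head).

```python
# Rough parameter count for an ASSUMED minGPT-style decoder-only transformer.
def gpt_param_count(n_layer: int, n_embd: int, vocab_size: int,
                    block_size: int, include_head: bool = True) -> int:
    d = n_embd
    wte = vocab_size * d          # token embedding table
    wpe = block_size * d          # position embedding table
    ln = 2 * d                    # one LayerNorm: weight + bias
    attn_qkv = d * 3 * d + 3 * d  # fused query/key/value projection (with bias)
    attn_proj = d * d + d         # attention output projection (with bias)
    mlp_fc = d * 4 * d + 4 * d    # MLP expansion layer (with bias)
    mlp_proj = 4 * d * d + d      # MLP contraction layer (with bias)
    per_block = 2 * ln + attn_qkv + attn_proj + mlp_fc + mlp_proj
    total = wte + wpe + n_layer * per_block + ln  # + final LayerNorm
    if include_head:
        total += d * vocab_size   # bias-free projection to vocabulary logits
    return total

n = gpt_param_count(n_layer=6, n_embd=90, vocab_size=11, block_size=64)
print(f"{n} parameters = {n / 1e6:.2f}M")  # 598140 parameters = 0.60M
```

Whether or not the output head is counted, the total rounds to 0.60M (597150 without it).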


In [3]:
# sequences in train and validation datasets:
'Train:', ben.train_dataset.encdec(0), 'Val', ben.val_dataset.encdec(0)

('Train:',
 '0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24',
 'Val',
 '0370 80371 80372 80373 80374 80375 80376 80377 80378 80379 80380')
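The validation sample starts with "0370", in the middle of a number. That is consistent with the file being split by characters rather than by numbers. The sketch below reconstructs this under two assumptions inferred from the samples (not confirmed by gptbench): intseq100k.txt holds the integers 0..99999 separated by single spaces, and train_split=0.8 keeps the first 80% of characters for training. encdec(0) appears to return the first block_size=64 characters.

```python
# ASSUMED contents of intseq100k.txt: integers 0..99999 separated by single spaces.
text = " ".join(str(i) for i in range(100_000))

# ASSUMED split: first 80% of the characters for training, the rest for validation.
split = int(len(text) * 0.8)
train_text, val_text = text[:split], text[split:]

block_size = 64
print(len(text), split)               # 588889 471111
print(repr(train_text[:block_size]))  # '0 1 2 ... 23 24'
print(repr(val_text[:block_size]))    # '0370 80371 ... 80380'
```

Under these assumptions both 64-character windows match the notebook output exactly, including the cut through 80370.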

In [4]:
# vocabulary used in both datasets:
ben.val_dataset.get_vocab_items()

[' ', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
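With class_name='char', every character is a token, so the vocabulary is just the space plus the ten digits and each number is spelled out digit by digit. A minimal character-level tokenizer in that spirit (a sketch; gptbench's own implementation and token ids may differ):

```python
# ASSUMED dataset contents; only the set of characters matters here.
text = " ".join(str(i) for i in range(100_000))

vocab = sorted(set(text))                     # [' ', '0', '1', ..., '9']
stoi = {ch: i for i, ch in enumerate(vocab)}  # character -> token id
itos = {i: ch for ch, i in stoi.items()}      # token id -> character

def encode(s: str) -> list[int]:
    return [stoi[ch] for ch in s]

def decode(ids: list[int]) -> str:
    return "".join(itos[i] for i in ids)

print(len(vocab))         # 11
print(encode("719 720"))  # [8, 2, 10, 0, 8, 3, 1]
```

Because numbers are not single tokens, the model has to learn increments such as 599 -> 600 one character at a time.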

In [5]:
# let's train:
ben.train(iter_count=5000)

Training
Iters per epoch: 3680
Iter 0 (0.000 epoch): loss train=2.3681, val=2.4267, eval->2.4267
==> Saving model at iter=0, eval loss->2.4267 
Sampling:   3 3 4333 3 323 436 33133137323 4323 33233 3233 33 33132333 3 31313131313131313131313131313131313131
CUDA max memory used: 746.01M
...................................................................................................
Iter 100 (0.027 epoch): loss train=1.9029, val=2.0760, eval->2.0760
==> Saving model at iter=100, eval loss->2.0760 
...................................................................................................
Iter 200 (0.054 epoch): loss train=1.5966, val=1.8549, eval->1.8549
==> Saving model at iter=200, eval loss->1.8549 
...................................................................................................
Iter 300 (0.082 epoch): loss train=1.2658, val=1.4586, eval->1.4586
==> Saving model at iter=300, eval loss->1.4586 
...............................................................
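Two numbers in this log can be checked with simple arithmetic. An untrained model that predicts roughly uniformly over 11 characters has an expected cross-entropy of ln(11), close to the iter-0 losses. The iterations per epoch follow if one epoch is assumed to be one pass over the training characters at batch_size=128 per iteration; the dataset contents and split below are the same assumptions as before (integers 0..99999, space-separated, 80% character split).

```python
import math

# ASSUMED dataset contents and character split.
text = " ".join(str(i) for i in range(100_000))
train_chars = int(len(text) * 0.8)  # 471111

vocab_size = 11
batch_size = 128

# Loss of a uniform prediction over the vocabulary.
uniform_loss = math.log(vocab_size)
print(f"{uniform_loss:.4f}")  # 2.3979

# ASSUMPTION: one epoch = train_chars samples, batch_size samples per iteration.
iters_per_epoch = train_chars // batch_size
print(iters_per_epoch)                 # 3680
print(f"{100 / iters_per_epoch:.3f}")  # 0.027 (epoch fraction at iter 100)
print(f"{300 / iters_per_epoch:.3f}")  # 0.082 (epoch fraction at iter 300)
```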

In [None]:
# In the later sampling lines of the training log (truncated above), we can see that it completes sequences quite well...

In [18]:
# What are the current state and the last saved (lowest eval loss) state?
ben.state, ben.last_saved_state

({'n_samples': 563072,
  'train_loss': 0.17054523527622223,
  'val_loss': 0.20620623230934143,
  'eval_loss': 0.20620623230934143},
 {'n_samples': 230400,
  'train_loss': 0.1777949184179306,
  'val_loss': 0.19670288264751434,
  'eval_loss': 0.19670288264751434})
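At batch_size=128, the current n_samples of 563072 corresponds to iteration 4399, and the saved 230400 to iteration 1800. The later state has a worse eval loss (0.2062 vs 0.1967), which fits the "==> Saving model ... eval loss->" log lines: a checkpoint appears to be written only when the eval loss improves. A minimal sketch of that policy (the class and its save callback are hypothetical, not gptbench's code), fed with loss values taken from this notebook:

```python
from typing import Callable

class BestCheckpointTracker:
    """Hypothetical: call save_fn only when the eval loss beats the best so far."""

    def __init__(self, save_fn: Callable[[int, float], None]):
        self.best_loss = float("inf")
        self.save_fn = save_fn

    def update(self, it: int, eval_loss: float) -> bool:
        if eval_loss < self.best_loss:
            self.best_loss = eval_loss
            self.save_fn(it, eval_loss)
            return True
        return False

print(563072 // 128, 230400 // 128)  # 4399 1800

saved = []
tracker = BestCheckpointTracker(lambda it, loss: saved.append((it, loss)))
# (iteration, eval loss) pairs from the log and the two states above
for it, loss in [(0, 2.4267), (100, 2.0760), (200, 1.8549), (300, 1.4586),
                 (1800, 0.1967), (4399, 0.2062)]:
    tracker.update(it, loss)
print(saved[-1])  # (1800, 0.1967): the worse loss at iteration 4399 is not saved
```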

In [19]:
# let's load the last saved checkpoint, which has the best (lowest) eval loss:
ben.load()

Loading checkpoint from ./checkpoints/intseq100k/
Checkpoint: iter=1800 (0.489 epoch), loss train=0.1778 val=0.1967 eval->0.1967
Dataset train_path: ../data/intseq100k.txt, val_path: None, train_split: 0.8, vocab_size: 11
Model params: 0.60M


In [6]:
# let's try completing some sequences:
ben.sample('85000 85001 ')

85000 85001 85002 85003 85004 85005 85006 85007 85008 85009 85010 85011 85012 85013 85014 85015 85016 85017 8501


In [7]:
ben.sample('3019 3020 3021 ')

3019 3020 3021 3022 3023 3024 3025 3026 3027 3028 3029 3030 3031 3032 3033 3034 3035 3036 3037 3038 3039 3040 3041 


In [14]:
ben.sample('719 720 ')

719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 


In [15]:
ben.sample('719 ')

719 77720 77721 77722 77723 77724 77725 77726 77727 77728 77729 77730 77731 77732 77733 77734 77735 7773


In [16]:
# Above, it completed 77719 rather than 719: without a preceding space the prompt is ambiguous, so the leading space is important
ben.sample(' 719 ')

 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 
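Counting contexts in the training text shows why the leading space matters. Assuming again that the data is the integers 0..99999 separated by spaces with an 80% character split, the prompt "719 " is the ending of 80 different training numbers, 70 of them five-digit, while " 719 " matches only the number 719 itself.

```python
# ASSUMED dataset contents and character split.
text = " ".join(str(i) for i in range(100_000))
train_text = text[: int(len(text) * 0.8)]

print(train_text.count("719 "))   # 80: any number ending in ...719
print(train_text.count(" 719 "))  # 1: only the number 719

nums = train_text.split(" ")[:-1]  # drop the last, truncated number
ending_719 = [n for n in nums if n.endswith("719")]
print(sum(len(n) == 5 for n in ending_719))  # 70 of them are five-digit numbers
```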


In [18]:
ben.sample(' 56 57 58 ')

 56 57 58 59 60 61 62 63 64 65 66 77 68 77 77 78 779 70 71 72 73 7 74 73 75 74 76 77 77 78 79 80 80 81 81 82 8


In [19]:
ben.sample(' 56 57 ')

 56 57 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 500 601 602 


In [20]:
ben.sample(' 56 57 ')

 56 57 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 500 601 602 
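Running the same prompt twice gives the same completion because top=1 reduces top-k sampling to always taking the highest-scoring token. A plain-Python sketch of top-k sampling (the standard technique; whether gptbench implements it exactly this way is an assumption):

```python
import math
import random

def sample_top_k(logits: list[float], k: int, rng: random.Random) -> int:
    """Sample a token id from the k highest logits, softmax-weighted."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - m) for i in top]  # numerically stable softmax
    return rng.choices(top, weights=weights, k=1)[0]

logits = [0.1, 2.5, -1.0, 2.4, 0.0]
rng = random.Random(0)
print({sample_top_k(logits, k=1, rng=rng) for _ in range(100)})  # {1}: always the argmax
print({sample_top_k(logits, k=2, rng=rng) for _ in range(100)})  # only ids 1 and 3 possible
```

With k=1 the random generator is never decisive, so sampling is fully deterministic; larger k would make repeated runs differ.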


In [21]:
ben.sample(' 1 2 3 4 5 6 7 8 ')

 1 2 3 4 5 6 7 8 40 10 41 5 5 5 6 5 6 6 7 7 6 1 7 1 1 1 1 2 1 2 1 1 2 2 1 2 2 1 2 2 2 3 1 2 3 2 4 4 2 5 3 2 4 3 3 4 4
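The completions above are judged by eye. A small helper (hypothetical, not part of gptbench) can score a sample by counting how many generated numbers correctly continue the +1 sequence before the first mistake. It assumes the prompt ends with a space, as all prompts here do, and ignores a trailing number that may have been cut off.

```python
def count_correct(prompt: str, sample: str) -> int:
    """Count generated numbers that correctly continue the prompt's +1 sequence."""
    generated = sample[len(prompt):]
    complete = generated.split(" ")[:-1]  # last item may be truncated or empty
    expected = int(prompt.split()[-1]) + 1
    correct = 0
    for tok in complete:
        if tok != str(expected):
            break  # stop at the first wrong number
        correct += 1
        expected += 1
    return correct

sample_56 = (" 56 57 58 59 60 61 62 63 64 65 66 77 68 77 77 78 779 70 71 72 73 7 "
             "74 73 75 74 76 77 77 78 79 80 80 81 81 82 8")
print(count_correct(" 56 57 58 ", sample_56))                 # 8: 59..66, then 77
print(count_correct(" 56 57 ", " 56 57 578 579 580 581"))     # 0: 578 instead of 58
```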


The model seems to struggle with sequences of small (one- or two-digit) numbers, perhaps because there are fewer such samples in the training data?
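The hypothesis can be checked by counting. Assuming the same dataset layout (integers 0..99999 separated by spaces, first 80% of characters for training), every number below 100 appears exactly once, all within the first 290 characters, while five-digit numbers make up most of the training text.

```python
from collections import Counter

# ASSUMED dataset contents and character split.
text = " ".join(str(i) for i in range(100_000))
train_text = text[: int(len(text) * 0.8)]
nums = train_text.split(" ")[:-1]  # drop the last, truncated number

by_len = Counter(len(n) for n in nums)
print(sorted(by_len.items()))  # [(1, 10), (2, 90), (3, 900), (4, 9000), (5, 70370)]

# Characters (number + following space) taken by numbers below 100.
small_chars = sum(len(n) + 1 for n in nums if int(n) < 100)
print(small_chars, f"{small_chars / len(train_text):.4%}")
```

With block_size=64, only the few hundred training windows that start near the beginning of the file can contain these small numbers.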

Next: try sequences of odd or even numbers. Prime number sequences?
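For these follow-up experiments, datasets in the same space-separated format could be generated as sketched below. The format mirrors what the intseq100k samples suggest; the dictionary keys are illustrative, and each string could be written to a text file and passed to ben.set_datasets(class_name='char', train_path=...) as before.

```python
def primes_below(n: int) -> list[int]:
    """Sieve of Eratosthenes: all primes p with p < n (n >= 2)."""
    is_prime = bytearray([1]) * n
    is_prime[:2] = b"\x00\x00"  # 0 and 1 are not prime
    for p in range(2, int(n ** 0.5) + 1):
        if is_prime[p]:
            # mark multiples of p, starting at p*p, as composite
            is_prime[p * p::p] = bytes(len(range(p * p, n, p)))
    return [i for i in range(n) if is_prime[i]]

N = 100_000
datasets = {
    "even": " ".join(map(str, range(0, N, 2))),
    "odd": " ".join(map(str, range(1, N, 2))),
    "prime": " ".join(map(str, primes_below(N))),
}
for name, txt in datasets.items():
    print(name, len(txt), txt[:30])
```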