In [1]:
# Enable IPython's autoreload so edits to imported modules (e.g. the project's src/ package)
# are picked up before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Detect whether this notebook runs on Google Colab: get_ipython() returns the active
# IPython shell object, whose string form is expected to mention 'google.colab' there.
IN_COLAB = 'google.colab' in str(get_ipython())

if IN_COLAB:
    # Mount Google Drive and cd into this repository's notebooks/ folder on Drive, so the
    # working directory matches a local run started from notebooks/.
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    %cd /content/drive/MyDrive/Github/Abstract-generator/notebooks

In [3]:
%%capture
if IN_COLAB:
  !pip install feedparser tokenizers transformers scipy==1.7.1;

In [4]:
import copy
import os, torch, time, math, sys, re, csv
import numpy as np

# Repository root = parent of the current working directory.
# NOTE: assumes the kernel's cwd is the repo's notebooks/ folder (the Colab branch cds there).
PACKAGE_ROOT = os.path.dirname(os.path.abspath(''))
print(PACKAGE_ROOT)
# Make the `src` package importable. Guarded so that re-running this cell (common with
# autoreload) does not keep appending duplicate entries to sys.path.
if PACKAGE_ROOT not in sys.path:
    sys.path.append(PACKAGE_ROOT)

from src import settings
import src.data.dataset_class as dsc
import src.data.dataloader_class as dlc

from src.model.transformer_torch import TransformerModel
from src.model.generate_text import gen_some_text

from src.model.train_evaluate import train_version_jeremy as train
from src.model.train_evaluate import evaluate_version_jeremy as evaluate

# TODO: `from src.model.transformer import make_gpt_model` is disabled because that import does not work yet.

# Prefer the GPU whenever torch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

/home/jbrothschild/Documents/HLML/Abstract-generator


### Parameters

In [5]:
# ARCHITECTURE
# TODO : Make a class that sets all this, validates it's in use
max_len_sentence = 40    # maximum sentence length
vocab_size       = None  # None if you want to let tokenizer do its thing
emsize           = 512   # embedding dimension
nhid             = 2048  # the dimension of the feedforward network model in torch.nn.TransformerEncoder
nlayers          = 12    # the number of torch.nn.TransformerEncoderLayer in torch.nn.TransformerEncoder
nhead            = 8     # the number of heads in the multiheadattention models
dropout          = 0.2   # the dropout value
batch_size       = 10    # 32
val_batch_size   = 10    # 32, not used right now.
epochs           = 50    # the number of epochs

TRAIN = True  # set to False to skip the (long) training cell

# Seed the RNGs so weight initialisation, dropout (and presumably the train/valid split
# done by the dataloader class) are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA
np.random.seed(SEED)

# Multi-head attention splits the embedding across heads, so emsize must divide evenly.
assert emsize % nhead == 0, f"emsize ({emsize}) must be divisible by nhead ({nhead})"

### Format Dataset

We use a custom dataset class: an iterable and callable structure that returns a sample from our dataset. All preprocessing is defined within this custom dataset.

In [6]:
# Create the dataset (ArxivDataset; swap in WikiTextDataset to use WikiText instead).
dataset = dsc.ArxivDataset()
#dataset = dsc.WikiTextDataset()

# Tokenizer settings:
#   tknzr_type       -- tokenizer algorithm; 'BPE' = byte-pair encoding
#   flag_tknzr_train -- True: train a new tokenizer (False presumably reuses one already trained)
#   flag_tknzr_fast  -- presumably selects the HuggingFace "fast" tokenizer -- TODO confirm in dataset_class
tknzr_type = 'BPE'
flag_tknzr_train = True
flag_tknzr_fast = True

# The return value is discarded; the tokenizer presumably stays attached to `dataset`
# (later cells read dataset.vocab_size) -- verify in src/data/dataset_class.py.
_ = dataset.tokenizer(flag_tknzr_train, tknzr_type, flag_tknzr_fast=flag_tknzr_fast)



### Creating DataLoaders

Training is done on batches, so we need a way to extract groupings of the data in the appropriate format for our transformer model.
Note that for the transformers we are training, the dataloaders output both `src` (`x[:-1]`) and `tgt` (`x[1:]`).
How batches are collated varies between our transformer models: for HuggingFace it is ( max_len x batch_size ), whereas I think the Annotated Transformer uses ( batch_size x max_len ).

I created a custom Dataloader class that wraps the splitting of the dataset and returns a separate dataloader for each split.

NOTE: Do not use the tokenizer before training if you use num_workers > 0!
The fast tokenizer does not play nicely with forking if it is used before the data-loading worker processes are forked:
https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning

In [7]:
# Passed to CustomDataloader; presumably makes batches also carry a padding mask so padded
# positions can be ignored by attention -- TODO confirm in src/data/dataloader_class.py.
flag_padding_mask = True

# Splits `dataset` and builds one DataLoader per split (used later as dataloader.train and
# dataloader.valid), batching `batch_size` sequences of at most `max_len_sentence` tokens.
dataloader = dlc.CustomDataloader(dataset, batch_size, max_len_sentence, flag_padding_mask=flag_padding_mask) 

### Selecting model

Here we choose which model to use for training. For now, I've selected the black-box Transformer from HuggingFace, because the collate_fn I've written produces inputs of the shape that model expects... however, this can easily be changed!

In [8]:
# Model: TransformerModel from src/model/transformer_torch.py.
# TODO : Change to the Annotated Transformer if I want
model = TransformerModel(dataset.vocab_size, emsize, nhead, nhid, nlayers, dropout).to(device)

# Loss. NOTE(review): CrossEntropyLoss() is built without ignore_index, so padded target
# positions presumably contribute to the loss even though flag_padding_mask is set.
# TODO: pass ignore_index=<pad token id> once the pad id is available from the dataset's
# tokenizer (the old lookup `tknzr.get_vocab()["<pad>"]` uses a `tknzr` not defined above).
criterion = torch.nn.CrossEntropyLoss()

# Optimizer presets: each is a single param group holding all model parameters; exactly
# one of them is used below. TODO : why these parameters?
paramsAdam  = [{'params' : model.parameters(), 'lr' : 1e-3, 'betas' : (0.9, 0.999), 'eps' : 1e-08, 'weight_decay' : 0.0}]
paramsAdamW = [{'params' : model.parameters(), 'lr' : 5e-5, 'betas' : (0.9, 0.999), 'eps' : 1e-08, 'weight_decay' : 0.0}]
paramsSGD   = [{'params' : model.parameters(), 'lr' : 0.5, 'momentum' : 0.0, 'dampening' : 0.0, 'weight_decay' : 0.0}]

optimizer = torch.optim.SGD( paramsSGD )
#optimizer = torch.optim.Adam( paramsAdam )
#optimizer = torch.optim.AdamW( paramsAdamW )

# Scheduler: multiply the learning rate by gamma=0.95 every step_size=1 call to
# scheduler.step(). The training cell steps it once per epoch (assuming train() does not
# also step it), so the LR decays by 5% per epoch. Use gamma=1.0 for a constant LR.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

### Training

Training loop!

In [9]:
# Suppress the HuggingFace `tokenizers` fork warning. The fast tokenizer has already been
# used (trained) above, and DataLoader workers with num_workers > 0 fork the process;
# setting this before the fork disables tokenizer parallelism in the children instead of
# printing a warning. It doesn't seem to affect the timing.
# https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

if TRAIN:
    best_val_loss = float("inf")
    best_model = None
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(model, dataloader.train, device, dataset.vocab_size, epoch, optimizer, scheduler, criterion, max_len_sentence)
        val_loss = evaluate(model, dataloader.valid, device, dataset.vocab_size, criterion, max_len_sentence)

        # Perplexity is exp(cross-entropy) (assuming evaluate() returns the mean per-token
        # loss). A near-uniform untrained model scores about ln(vocab_size) nats per token,
        # so with a vocabulary of tens of thousands of tokens, perplexities in the thousands
        # early in training are expected. math.exp raises OverflowError for losses above
        # ~709 (e.g. a diverging run), so report inf instead of crashing the loop.
        try:
            val_ppl = math.exp(val_loss)
        except OverflowError:
            val_ppl = float("inf")

        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                         val_loss, val_ppl))
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot the weights. A plain `best_model = model` only aliases the live
            # model, which keeps training, so the "best" checkpoint would silently be the
            # final one.
            best_model = copy.deepcopy(model)

        scheduler.step()

    # Save the final and the best model, each in two formats.
    os.makedirs(settings.DIR_MODELS, exist_ok=True)
    model_full = os.path.join(settings.DIR_MODELS, f'{dataset.name}_epoch{epochs}.pth')
    model_weights = os.path.join(settings.DIR_MODELS, f'{dataset.name}_weights_epoch{epochs}.pth')
    model_full_best = os.path.join(settings.DIR_MODELS, f'{dataset.name}_epoch{epochs}_best.pth')
    model_weights_best = os.path.join(settings.DIR_MODELS, f'{dataset.name}_weights_epoch{epochs}_best.pth')
    # approach 1: save model (class) entirely (uses pickle; loading needs the `src` package importable)
    torch.save(model, model_full)
    # approach 2: save model weights only (state_dict)
    torch.save(model.state_dict(), model_weights)
    if best_model is not None:
        torch.save(best_model, model_full_best)
        torch.save(best_model.state_dict(), model_weights_best)
    else:
        # No epoch ran, or every validation loss was NaN (NaN < inf is False).
        print('No best model recorded (no epochs run or all validation losses were NaN); '
              'skipping *_best.pth files.')

| epoch   1 |   200/  700 batches | lr 0.50 | ms/batch 959.29 | loss  7.82 | ppl  2500.78
| epoch   1 |   400/  700 batches | lr 0.50 | ms/batch 971.15 | loss  7.07 | ppl  1172.85
| epoch   1 |   600/  700 batches | lr 0.50 | ms/batch 967.55 | loss  6.90 | ppl   993.80
39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [ 3098,   213,   213,  3061,  2987,   189,   213,   213,   213,   189],
        [  146,  1383,  1091,   130,   130,  1741,   566,  1930,   714,  1662],
        [  374,   121,   514,   533,   692,  1479,   105,   146,   107,  1902],
        [  202,   133,   130,  4550,   592,   162,   613,  3446,   611,   122],
        [  109,  4251,   904,   187,   202,   107,   311,  2879,   122,   508],
        [  130,  1313,  1039,   416,  9425,  1197, 22456,   171,  1159,   130],
        [  138,  1508,   122,   280,   205,  6280,   909,   107,  2783,  2957],
        [  205,  8243,   280,   130,  1172,   122,   130,  4073,   175,  2044],
       

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [16159,   920,   189, 13041,   189,  1089,   213,   189,   339,   213],
        [ 3676,  3676,  1192,   122,   247,   130,  1013,   957,  1389,  1042],
        [  122,   786,  2389,   105,   420,   663,   107,   420,  1032,  5007],
        [ 2595,  2062,   171,  3497,   121,  2484,  4113,   391,   143,   541],
        [  665,   151,   105,  7010,  3181,   162,  9604,   122,   246,   122],
        [ 1690,   107,  6365,   130,  6078,   105,   122,  1486,  1181, 11329],
        [  146,   134,   476,  5071,  3173,  1798,   134,   130,   107,   210],
        [  415,  2342, 10520,  4240,   187,  3441,   130,   134,  2315,   121],
        [  130,  2237,   151,   391,  4579,   121,  1800, 12124,   409,  4396],
        [  555,   143,  6191,  3063,   391,  2642,   467,  1291,   210,  2318],
        [  134,   107,   105,   900,   162,   199,   121,   334,   146,   166],
        [ 1362,  2947, 17348,  4726, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  339,   213, 14679,   189,  5246,   213,   213, 19667,   189,  2284],
        [  107,  1672,  4379,   957,  2276,  1013,   553,   130,  3053,   107],
        [ 1659,   146,  4898,   130,   162,   107,   107, 19109,  1400,  1729],
        [ 1033,  1730,  2032,   420,   133,   431,  4119,  1705,   146,   122],
        [  863,   979,   295,   134, 10767,   356,   122,  4994,  3053,   266],
        [ 5665,   105,   723,   383,   266,   374,   107,   202,   102,   723],
        [ 1129,   748,   122,   122,   130,   307,  3424, 21781,  1369,   246],
        [ 3922,  4113,   107,   105,  6660,   879,  7977,  1001,  1251,  1013],
        [  508,   613,   266,   469,   662,   122,  1404,  5316,  1273,   107],
        [  134,   143,  3849,   130,   121,   107,   412,  2480,   211,  3898],
        [  597,  3158,   151,   533,  4934,  2086,    67,  3264,   910,   409],
        [  787,   130,   555,   280, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  213, 15920,   213,   249, 19441,  3061,  5425,   213,   189,   213],
        [  876,   130,   510,   979,   162,   130,   122,  1930,   631,  2548],
        [  105,  5509,  1269,   105,   107,   134,   508,  2879,   671,   107],
        [ 4162,  1419,   383,  8499,  1405,   483,   130,   171,   202,  1450],
        [ 1218,   130,   130,   311,  3542,  1049,  7486,   107,  4212,  5139],
        [  542,  1319,  1053,  3618,   146,   268,  2044, 16966,   205,  3064],
        [  171,  1742,   542,  1686, 11295,   146,  3000,  2551,   122,   122],
        [  134,  1162,   202,   246,  2972,   293,   187,   130,   133,   508],
        [  280,   202,  2405,  5219,   121,  1049,   134,   474,  1935,   134],
        [  997,  2697,   205,   199,   107,   268,  3081,  1106,   166,  1033],
        [  121,   205,   683,   107,  2074,  1339,   146,   151,   125,   787],
        [ 2439,   574,   723,   631, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  213,  1334,   213,  8662,   189,   213,  2993,   213,  6689,   213],
        [ 1013,   107,   510,   900,  4139,  1396,   105,   566,  2623,   335],
        [  107,  3826,   778,  5887,   959,   107,   945,   107,  1362,   199],
        [  356,   130,   431,   510,   122,   706,   130,   451,   510,   107],
        [ 1463,  2347,   356,  1754,  8582,   656,   307,   122,   133,  5213],
        [  122,   130,  2124,  7530,  5748,   398,  1545,   786,  5414, 10159],
        [  107,  1789,  1033,   225,   202,   143,   122,   130,  8160,  2310],
        [  247,   925,  1162,  2812,  4002,   166,   107,  1213,   121,  4597],
        [  391,   143,   146,   322,   205,   186,  1400,  1385,  1304,  7290],
        [  122,   246,  1105,   130,   210,   202,   130,  1039,   122,   208],
        [ 1806,   553,  3740,   533,   146,   370,  3053,   122,   641,   105],
        [17062,   107,   151,   941, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [ 1445,  1955,   256,  1828,   189, 21957,   256,   213,  1334, 17789],
        [ 1250,   105,  4610,  1013,  2245,   130,   748,  1943,   107,   247],
        [25720,  3471,  2901,  1761,  1102,  4592,   613,  1013,  2986,   597],
        [  162,  2177,   122,   134,   122, 10479,   151,   107,  1868,   787],
        [  105, 11124,   107,   474,   105,  1234,  9551,  1971,  1874,   105],
        [ 5141,   635,   247,   143,  3412,   105,   175,   122,   202,  5766],
        [16831,   246,   146,   311,   146,   900,   130,   105,  7797,  2157],
        [  202,  6491,   374,   162,   105,   130,  1165,   508,   205,   162],
        [ 8843,   107,   541,  6701,  4173,  1356,  2984,   130,  1601,  1126],
        [ 2129, 19527,   122,   121,   368,   559,   901,   134,   817,   208],
        [  205,   541, 12190,  1592,   121,  3636,  3190,  3385,  5848,   107],
        [ 2721,   122,    23,   962, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  256, 13294,   189,   213,   256,  1334,  9481,  5134,   213,   213],
        [  748,   143,  3756,   553,   542,   505,   538,  4730,  1042,  1672],
        [ 1732,   121,   151,   107,   171,   130,   925,   621,   967,   105],
        [  162,   311,   597,   611,   133,  2000,   957,   107,   130,  2317],
        [ 1383,   107,   266,   122,   134,   723,   420,  2317,   481,   788],
        [  171,   741,   162,   134,  5762,   817,  3700,  1248,  6969,   171],
        [ 3462,   166,   105,  1945,   122,   107,   247,   122,   146,  2770],
        [ 1077,   126,  1798,   175,  3046,  1836,  2792,  2159,  1745,  1644],
        [ 4642,   166,  3441,  1108,  2423,   122,   660,  4094,  8467,   634],
        [  731,  1680,   122,  2437,   162,  4685,   322,  1322,   512,   151],
        [  789,   210,   107,  1793,  1383,  7097,   202,  1754,  3525,   280],
        [  900, 12027,   322,  4830, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  213,   710,  1334, 10415,   213,   213,  1082,   189,   213,   189],
        [28153,   574,   107,   228,   553,   335,  1463,  5678,  1943,  5165],
        [  133,   643,   973,   280,   134,   199,   122,  2319,  1013,   130],
        [ 4767,   105,   130,   130,   597,   105,   107,   122,   107,  1089],
        [  122,  1107,  1300,   718,   121,   682,   134,   105,   474,  5617],
        [  635,  4702,  2614,   146, 15886,   130,   656,   900,   122, 10644],
        [  102,   151,   130,   280, 11305,  5075,   398,   162,   133,   216],
        [ 4003,   876,  2530,   130,  1715,  1314,   121,  1850,  3080,   107],
        [18068,   431,   613,  1648,  5707,   134,   107,   151,   134, 10150],
        [  288,   130,   143,   334,   151,  4485,   523,  2990,   476,   122],
        [  716,  1625,   246,  4829,   322,   323,   130,   105,  4050, 25218],
        [ 3391,   134,  1396,   162, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [12986,  2318, 23844,   189, 12203,   189,   256,   189,  3371, 13041],
        [12263,   166, 23193,   134,  2167,   553,  1950,  2151,   930,   122],
        [12571,   464,  3025,   130,  1217,   175,   753,   731,   175,  2162],
        [ 2501,    23,   187,   151,  9610,   107,   122,   122,   334,  4844],
        [ 1177,   166,  4120,   130,   105,   247,  2701,   210,  1381,  1125],
        [  202,   162,   143,  8472,   322,   420,   941,   121,   121,   211],
        [15946,   105,  3906,  1821,   130,   122,  1929,   105,   107,  6579],
        [  205,  1165,  5232,   229,   533,  3324,   247,  6103,   772,   171],
        [  171,  1685,   151,   134,  1443,   225,   417,  2579, 17807,   105],
        [  508,   187, 21562,  1702,   648,  4312,   199,   420,   863, 13451],
        [  134,   105,   202,   162,   225,   130,   211,   162,   510,   745],
        [ 1778,  1644,  1006,   105, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [29264,   339,  3583,   213,   189,  4703,  4375,   189,   516,   213],
        [ 2753,   343,   121,   566,  6366,   542,   143,   247,   162,   979],
        [ 1349,  1313,   105,   133,   806,  2560,   392,  1125,   105,   107],
        [  121,   143,   900,  2186,   143,   171, 10692,   122,  1313,  3756],
        [ 3454,   246,   211,  3918,  2759,   107, 25032,  1427,  1452,   151],
        [ 2640,  2225,  9883,   122,   247,  1161,   151,    67,   892,  3516],
        [  211,   105,  1580,   107,  1680,   122,   553,   362,  1669,   134],
        [ 2455,  2170,   199,  6337,   481,   107, 12320,  3301,   897,  1362],
        [  171,   122,   857,  1419,   202,   134,  2546,   130,   123,   187],
        [  807,  4770,  1025,  1458, 11115,   130,   731,   856,   189,   415],
        [ 1030,   107,   187,   216,   205,   663,   122,   306,   967,   130],
        [  123,  2197,   890,   311, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  189,   189,  1082, 12714,   213,   213,   189,  1089,   213,   189],
        [  545,  2023,  3200,  1451,  1672,   876,  1765,  1927,  3733,  1743],
        [  122,  1472,  3134,   611,   199,   105,  1629,   121,   107,  2580],
        [  133,   122,  1545,   121,   105,   748,   122,   107,  2339,   122],
        [  826,   107,   171,  1659,  1860,   834,   202,  1710,  1835,   105],
        [  134,  5611,   107,   134,  3866,   151,  3398,   491,   122,   665],
        [  130,  1717,   957,   597,   420,   134,   143,  3464,   107,   523],
        [  134,   130,   134,   787,   323,   597,  2635, 22652,  3001,   162],
        [  295,  1400,   420,   508,  1370,   121,   205,   187,   134,   122],
        [  121,   317,   211,   130,   121,  2984,   625,  1710,   130,  1147],
        [  105,   202,   701,  1744,  7977,   592,    67,   334,  1518, 10030],
        [  334,   972,   151,  3190, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [ 3098,   189,   213,   213,   256,   189, 21347,  1082,   189,  2986],
        [  417,  2155,  1672,   566,  9381,  2665,   130,  8851,   538,  1277],
        [  146, 10224,   105,   105,   806,   146,  5524,   146,  1789,  1066],
        [  597,   794,  1590,  1753,   162,  4839,  5000, 12832,   122,   121],
        [ 1903,   122,   171,  1421,  1355,   122, 20207,  1385,   107,  1857],
        [  121,   508,   105,   122,   660,  7335,   211,  1732,   125,   134],
        [ 1685,  1321,  6089,  1380,   107,   130,   122,   162,   783,   941],
        [  334,  1710,   280,  2167,  2456,  1202,  2602,  1383,   483,   162],
        [ 1381,  4468,   268,   210,   134,  2753,   171,   171,  1400,  3990],
        [  211,   661,   491,   931,   803,   162,   107,   107,   351,  7339],
        [  910,  3453,   293,   169,   122,  3347,  4993,  3602,   162,  1452],
        [ 1943,   134,   473,  6546, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  339, 20274,  1528,   189, 15636,   189,   256,   213,  1334,  1089],
        [  343,   151,  1952,  2286,   409,   611,   374,  1594,   225,   146],
        [ 1042,  2197,   122,   122,  1844,   122,   307,   107,   105,  1108],
        [  143, 12266, 11956,  1844,   134,   134,  1269,  1504,   351,  1495],
        [  246,   143,  1005,   134,   130,   130,   151,   122,   107,  1549],
        [  566,  1366,   130,   130,  1800,   134,  1202,   105,  1400,   187],
        [  105,   107,  1319,  2102,   603,   924,   130,  7553,  1590,  1083],
        [ 3754,   826,  5823,  1199,   121,   175,  2766,   166,   246,  2902],
        [ 7210,  5654,  1162,   202,   105,   107,  1620,   102,  1843,   121],
        [ 1836,  9849,   246, 17620,  6225,   431,  5373,  3978,  1592,  2638],
        [ 8430,   143,   979,   205,  2984,   130,  3581,  1251,  2919,   211],
        [  107,   122,   107,   121, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  189,   189,   967,   189,   213,   213,   213,   339,   213,  5640],
        [  247,  1313,   130,  1201,   876,   566,   566,   343,  6425,   130],
        [  391,  3973,   986,   130,   105,   107,   105,  1313,   107,  1625],
        [  122, 13868,  1162,   469,   748,   774,   334,   143,  1706,   879],
        [  107,  2074,   162,   130,  1107,  1898,   542,   383,   122,   162],
        [ 1685,  8374,   133,   533,  3704,   171,   122,  1053,   107,  1391],
        [ 1291,   151,  1119,  1115,   171,   107,  2101,   542,   957,   121],
        [12124, 10481,  2600,   317,  4249,   166,   122,   723,   130,   202],
        [  334,  2541,   171,  2677,   105,   373,   105,   211,   420,  1591],
        [ 2155, 16651,   107, 25672,  4368,   166,  2407,   778,  2263,   143],
        [  130,  2643,  2901,    23,  2447,  6860,  5413,   151,   146,  5405],
        [ 2347,   123,   122,   187, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  213,   213,   339,   213,   189,   189,  6616,   189,   213,   189],
        [ 1013,  1960,  3698,  2823,   247,  1032,   636,   368,   553,  1330],
        [  107,  1943,  3260,   107,  1125,   162,  4836,   122,   105,   122],
        [ 1945,   146,  3264,  2560,   143,  7557,   510,   134,   469,   105],
        [  611,  1730,   508,   322,   555,   151,  1754, 11129,   130,  1753],
        [  122,   107,   130,   130,   146,   107,   854,   175,   533,   280],
        [ 2761,   247,  1321,   481,   280,  2782,   107, 14812,   143,   130],
        [  210,   391,   307,   351,   731,   122,   686, 11706,   322,  1439],
        [  121,   146,   130,   122,   122,   107,  2592,  5589,   481,   293],
        [  107,  4433,   368,  1109,   107,   134,  2623, 15505,   351,    40],
        [  953,  1789,  4660,   151,  6078,   130,   210,   146,   187,   134],
        [  867,   122,   143,  1396, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  213,  6612,   213,  7195,   213,  9223, 10644, 12572,  1089,  1082],
        [  977,   134,  1042,  2753,  1013,   290, 12941,   122,   731,   505],
        [ 1742,  5032,   133,   202,   107,   122,   171,   175,   121,  2000],
        [ 1162,   892,  4951,   109,  1164,   603,  3425,   130,   107,   143],
        [  151,   208,  1114,   123,   122,   122,   130,  3595,  1309,  1007],
        [  553,   105,   130,   186,  1307,   105, 15854,   210,   122,   130],
        [  247,  3169,  4522,  2034,   121,  1051, 11203,   187,   786,   533],
        [ 1125,   665,  5494,  7335,  5490,   134, 20886,  1837,   481,   334],
        [  122,  1363,  1479,   143, 17645,   523,  2826,  6202,   874,  2407],
        [ 5752,   162,   202,  6471,   146,   187,  5990,   211,  2208,  1706],
        [  146,   105, 11517,   146,   979,   508,  8713,   910,  2542,   574],
        [  568,   594,   205,   890, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [ 5081,   339,   189,   189,   189,   213,   339,   189, 14182,   213],
        [  307,   107,  4211,   134,  1688,  1042,   918,  7236, 29141,   553],
        [ 1742,  1314,   943,   731,   409,   774,  1817,   122,   122,  1740],
        [  146,   134,   280,   121,  3957,   505,  1565,   107,  3278,  1349],
        [  134,  2101,   130,  2208,    23,  2000,   288,   134,   210, 23655],
        [ 6774,   143,   618,   941,   146,  1051,  1717,   631,   162,   130],
        [ 1281,   107,   467,   892,  2318,   768,  2638,   671,   469,   955],
        [  105,  5141,   121,   208,    23,   723,   143,   162,   122,   247],
        [ 7513,  5177,  7158,  2595,  9559,   122,   107,   469,   107,   417],
        [ 4716,   523,   323,  2359,   105,   585,  1310,   122,  6578,  2238],
        [  151,  2101,   169,   665,   322,  4160,   122,   107,  5667,   121],
        [ 2381,  1601,  4161,  1690, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [ 5165,   189,  2284,   516, 10477,  5744,   213,  3098,   710, 10907],
        [ 3913,   469,   107,   162,  2189,   130,   566,   909,   162,   208],
        [ 1362,   130,  4773,   105,   151,  4703,   133,   130,   853,  1389],
        [ 3063,  3104,   692,  9727,   107,   542,  1104,  1393,   199,   930],
        [  105,   973,   143,   167,   134,  3048,   122,   611,   107,   175],
        [  322,   130,   247,   175,  4070,   107,   107,  5858,   772,  1944],
        [  130,   266,   682,   107,   130,   541,  3093,   323,   768,   280],
        [  533,  2478,  3230, 15830,   266,   122,   247,   169,  1638,  1439],
        [  134,   151,   122,  6723,   959,   107,   146,   107,   171,  3398],
        [  317,   107,  2538,   175,   146,  1777,   597,  3308,   105,  2423],
        [  211,   268,   134,   871,   151,   146,   541,  2304,   317,   143],
        [ 4086,    87,  3823,  1313, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [12635,   256,   189,   213,   339,   213,   256,   213,   213, 19442],
        [  130,  4984,   553,  1396,   343,  1355, 11300,   335,  1042,   121],
        [ 2639,   122,   122,  2425,  2119,   107,  1051,   199,   774,  5501],
        [ 6747,   107,   343,  4673,   246,   391,  2091,   107,   130,   134],
        [ 3017,   102,  1313,   121,   714,   122,   171,  9343,  1750,  3695],
        [ 4607,   130,  3593,  1168,  1007,  3536,   107,   122,   723,   211],
        [  648,  1251,   199,   122,  1119,   130,   134,   107,   199,  2271],
        [  225,   351,   134,   943,  1598,   420,   130,   280,  5275,  1244],
        [ 1465,   122,   574,   134,   122,  3033,   952,   130,  2682,  2759],
        [  166,   105,  1339,   130,  4219,  1243,   146,   146,   146,   823],
        [  464,  1740,  2590,   423,   121,  1199,  1108,  2351,   247,   130],
        [  362,   130,   955,   467, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [ 5374,   213,   189,   213,   213,  3583, 14784,   213,   213,   339],
        [  208,  2758,  3264,   510,   335,   199,   799,  1013,  1672,   343],
        [ 1389,   105, 13761,  1871,   199,   211,  1244,   107,   322,  1313],
        [ 2602,  7650,   551,  3157,   883,  2116,   187,   901,  1204,   107],
        [  121,  2486,  2249,  1419,   210,   151,  4639,  1793,   334,   280],
        [ 1842,   122, 10714,   130,   323,   105,  1515,   146,  2317,  2792],
        [  541,   107,  8716,  1319,  4259,   508,   162,   864,  1981,   122],
        [  122,  1105,   130,  1742,  6256,  1803,  1119,   122,   171,   107],
        [  857,   491,  6606,  1162,  3780,   692,   171,  1349,  6694,   957],
        [ 4958,  1077,  1536,   202,   121,   121,  1842,   417,  1493,   420],
        [ 1488,  2162,   464,  2697,   953,   105,   754,   122,  4254,   146],
        [  143,   555,    22,   205, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  213,  1089,   249, 19180,  4809, 10165,   710,  3577,   213,  8908],
        [  553,  5984,   566,   122,   754,   335,   162,   107,  1262,   202],
        [  107,   211,   105,   657,  3228,   199,   853,  1836,   105,  2073],
        [  280,  4252,   847,  2835,  1720,   745,   199,   122,   613,   205],
        [ 4110,   143,   351,   122,  4876,  1580,   121,   107,   151, 16023],
        [  143,  6103,   122,  1913,   143,   121,   909,  9442,  1396,  5228],
        [  687,  2116,   105,   121,   247,   107,   130,  5288,   187,   208],
        [  123,   134,   334,   107,   391,  1705,   134,  2498,   334,  2595],
        [  109,  1362,  3635,   811,   916,  2477,   592,   660,  1077,   789],
        [ 2034,  2759,  3941,   122, 29535,  6937,   334,   392,   731,   130],
        [  107,   682,   171,  1159,   216,  2990,   968,   162,  1039,  2752],
        [  209,   130,   311,  2219, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  213,  4652,  1955,   213,   189,   256,   213, 17331,   213, 17759],
        [  553,   130,   107,   553,  2386,  2170,  2875,   122,   335,   146],
        [ 1380, 13835, 17082,   423,   122,   151,   107,   107,   107,  2382],
        [  107, 24240,   117,   130,   799,  1091,  3093,  1192,  2202,   807],
        [  134,  4467,   542,  3920,  8452,   508,   731,   130,   122,  4308],
        [  130,   611,   143,   134,   323,   130,   122,  2389,  1626,  2009],
        [  134,  2303,   246,  1033,   169,   134,  1307,   171, 10884,  8644],
        [  603,   208,   553,   121,  4776,  4457,   121,   107,   121,  1390],
        [ 1747,   374,   105,  1685,  1764,  1816,  1959,   662,   107,  1775],
        [  107,   962,  1107,   334,   208,  3092,   187,   122,  6186,   146],
        [  555,   143,  1172,   733,  5686,   121,   415,  1291,  3762, 15601],
        [ 1858,  2963,   134,  1715, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  213,   710,  4375,   189,   213,   256, 10304,  4421,  5617,   213],
        [ 1672,   162,   143,  2008,   876,  4413,   130,   481,   130,   553],
        [  105,   853,   392,   409,   107,   542,  1422,   391,   745,   107],
        [ 1753,   199,   574,   134,  2317,   162,  4599,   122,   134,   134],
        [  134,   121,   643,   130,   542,  1606,  1878,   107,  4308,   130],
        [ 1030,  2195,   701,   134,   122,   151,   122,  3655,   323,   423],
        [  788,   416,   199,   295,   134,  1297,  1771, 10906,   169,   295],
        [  143,   374,  3204,   146,   280,  1592,  4398,  5000,   778,   121],
        [  311,   307,  6542,   134,  2499,  2835,   175,   130,   151,   107],
        [  246,   107,  8072,   130,   121,   122,   900,   955,  2299,   857],
        [ 7151,  4077,   202,   423,   202,   247, 17467,  1544,   415,   918],
        [ 3665,   134,  8133,   467, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [ 5374,   189, 16048,   339,   213,   339, 10055,  5246,   339, 14532],
        [  208,  7092,  2643,   107,   510,   343,  3000,  2849,  2540,  2258],
        [ 1389,  3607,   143,  1135,   778,  1032,   211,   551,   122,  1689],
        [  942,  6528,   648,   122,   107,   143,   107,  5350,   210,   130],
        [  130,   685,   225,  1389,   883,   322,  1456,   166,   288,  1253],
        [  514,    74,  6528,  4871,  2486,  2151,   122,   122,   551,   491],
        [  134,   210,   146,   121,   122,  1164,  5025,  5752,  1721,  2086],
        [ 5139,   211,  8590, 11057,  5948,   864,  3223,   146,   473,   162],
        [  930,  6204,   143,  4421,  2385,   518,  6815, 18244,   151,   105],
        [  143,   151,   146,   143,  2085,   807,   634,  5010,   109, 12328],
        [  246,  2007,   754,   246,   225,   211,   151,  4904,   166,   130],
        [ 3446,  1171,  2019,   714, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  189,   189,  2735,  1334,  5191, 10636,  2635,   213,   189,   189],
        [ 1400,  1349,   175,   774,   134,   134,  7728,   553,  2912,  3138],
        [  351,   420,   107,   130,  1495,  1695,   166,   107,   122,  2901],
        [  121,   731,  1754,  1750,  1069,   202,   464,  6449,   105,   122],
        [  107,   122,  1383,   383,   151, 14809,    22,  3745,   334, 16017],
        [ 4886,   918,   481,  1053,  2017,   205,   166,   122,  3903,  1129],
        [  836,   134,   351,   542,  1108,  3321,   162,   107,   317,  2826],
        [  162,   146,   143,   723,  1495,   634,   105,   295,   162, 14870],
        [  910,   134,   107,   143,  1281,   151,  1754,   409,  1126,   107],
        [  208,   130,   247,   246,  2812, 12217,  3655,  2102,   143,  1706],
        [ 1741,   423,  1765,  1126,   872,   122,  1202,  1307,   187,   122],
        [  130,   592,  1629,   107, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [20180,   516, 28574,   920,  5298,   189,  1089,   213,   213,   213],
        [  122,  1516, 13986,   134,   107,  2245,   731,   510,  1042,  2119],
        [ 1209,   122,  1489,   130,  6234,   368,   121,  1126,   175,   107],
        [  210,  2173,   130,   953, 21195,   122,  3204,   107,   105,   682],
        [  574, 13050, 13732,  2038,  4114,   134,  2439,   247,  2831,  4090],
        [  643,   938, 21386,   143,   143,   130,   162,   417,  1898,   121],
        [ 1754,  7521,   202,   107,   392,   134,  1394,   122,   122,   107],
        [ 3655,  1896, 23161,   416,   574,   603,   208,   107,   107,   542],
        [  121,   216,   205,  2825,   643,   146, 11225,  6528,  1194,   122],
        [  770,   107,  2067,   122,  3824,   280,   409, 23626,  1934,  1165],
        [  143,  9268,  3321,  1946,   689,   130,   133, 10008,   122,   146],
        [ 1325,  8561,   288,   280, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [15183,   189,  8704,   213,   213, 29006,  2131,   189,   189,  7158],
        [  130,   823,   130,   714,  1960, 20822,  1439,  1389,   924,  3456],
        [  904,   130,  1659,  1164,   322,  4517,   210,  2218,   122,   211],
        [ 4547,   753,  5834,   122,   474,   955, 16849,   122,   107,  2812],
        [  122,   768,   122,  1740,  1981, 17616,   105,   105,   210,  6521],
        [  266,   864, 15350,  1837,   199,    87,  1168,  1356,   307,   151],
        [  811,   146,  2531,  2193,  3560,   146,   660,   151,   130,  4490],
        [  122,   107,   121,   130,  3473,  8757,   107,   900,  2219, 13069],
        [ 2980,  2047,   105,   955,  2527,   935,   665,   559,   288,   146],
        [  210,   130,  1077,  1217,   247,  2299,  1135,   202,   105,  2682],
        [  121,  2258,   523,   208,  2086,   134,   162, 17615,  2169,   121],
        [ 2642,   603,  5706,   107, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [ 2131,  9103,   189, 13713,   189,  5881,   213,  1334,  7277,   256],
        [  105,  1159,   247,   129,  5004,   107,   714,   499,   731,   901],
        [ 1007,   897,  2151,   247,   122,  4806,   247,   523,   122,  2071],
        [  130,   162,   391,  1504,   343,   122,   597,   266,  1069,   121],
        [  533,   133,   122,   211,  3619,  4094,   787,  8842,  1715,   105],
        [  134,  1119,   105, 10188,   162,   146,   105,   146,   211,  2133],
        [  861,   266,  1806,  1684,   151,  1877,   508,   107,  3775,   134],
        [  162,  1704,   536,   151,   566,  1066,   130,  4435,  2657,   861],
        [ 5982,   788,   704,  3843,   107,   121,   134, 10707,   208,   938],
        [  151,   171,  2639,  1903,  6758,  4607,   334,   368,   107,  1754],
        [  105,  2167,   162,   648,   122,   162,   733,   143,   322,  3655],
        [ 1169,   210,  1126,   225, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [20258,   213,   189,   213,  4375,  5199,  8436,   189,  1082, 24789],
        [  208,   977,  4458,  2983,   143,   130,  1319,  2102,  1747,   134],
        [ 1389,   799,   122,   107,   105,  1319,  1742,  1244,   122,  1495],
        [ 5559,  3093,  1033,   209,  6443,  2023,  1162,   122,   107,   121],
        [ 1815,   130,   322,   130,  2646,   856,   202,   107,   134,  2439],
        [  122,  5388,   130,  1719,   122,   130,  2697,  7866,   130,  2758],
        [ 2116,  1162,   692,   134,  1253,   986,   205,   900,   423,  8901],
        [ 3104,   151,   592,   130,   574,   474,  2421,  2702,   295,   334],
        [  130,   561,   202,   134,   643,   288,   107,  2695,   175,  8747],
        [ 1144,   280,  9425,   768,   689,   884,  1444,   130,  1164,  1991],
        [ 2529,  2026,   205,   151,   121,   130,   122, 11514,   811,   122],
        [  967,   122,   171,   107, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  213,   189,   516,   213,   189,   339,   213,   213,   213,   256],
        [  566,  3023,  3619,  2936,  1771,   343,   714,  1930,   566,   665],
        [  105,   122,   162,   105,   585,  1313,   105,   133,   105,   130],
        [ 1048,   134,  8088,  1107,   122,   246,   351,   826,   553,  1422],
        [ 2619,   897,   151,  1601,   107,   510,  1590,   247,   122,  5040],
        [ 4060,   661,   107,   151,   134,   910,   171,  1590,   107,   745],
        [  171,   134,  2782,  2875,   861,   107,   105,   171,   247,   374],
        [  107,  2003,   122,  1178,   175,   368,  4247,  3235,   391,   307],
        [ 1545,   216,   107,   356,   105,   122,  1366,   806,   143,   162],
        [  122,  1169,  2838,   247,   322,  1556,   620,  1400,   423,  1681],
        [  107,   431,  6426,   541,   130,   374,   107,   130,  2282,   151],
        [  620,   266,   121,   208, 

39 tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  213,  1465,  6100,   213,   256, 10280,   189, 16015,   256,   213],
        [ 1042,  7360, 26079,   553,  1385,  2023,  1089,   107,   803,  1960],
        [  133,   122,   202,   107,  2600,   856, 13751,   134,   122,   107],
        [ 2023,   671,  7027,   541,   171,   130,  6032,  2052,  2797,   368],
        [  134,  3780,   205,   122, 12071,   986,   202,   886,   541,   122],
        [ 2016,   208,   162,  1666,  2804,  1135,  8779,   208,   122,   134],
        [  553,  9068,   105,  3780,  2120,  2051,   205,   134,   770,   130],
        [  122,   210,  1659,   121,   134,  3560,  1369,  1778,   146,   423],
        [ 7669,   574,   134,   322,  4138,  3158,   225,   121,  4919,   467],
        [ 5331,   823,  4699,   130,   121,  2023,  1383,  1745,   211,   175],
        [  661,   643,   143,  2057,   134,  1314,  7521,  6963, 13641,  1742],
        [  671,  5477,  2826,   134, 

-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 692.74s | valid loss  6.82 | valid ppl   916.16
-----------------------------------------------------------------------------------------


KeyboardInterrupt: 

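### Text Generation

This section reuses Matt's text-generation code. If `TRAIN` is `False`, the first cell reloads a saved checkpoint from `settings.DIR_MODELS` instead of relying on the training run above.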

In [None]:
if not TRAIN:
    # Reload a previously saved checkpoint instead of using a freshly trained model.
    custom_filename = 'arxiv_10000'
    custom_epochs = 10
    model_full = os.path.join(settings.DIR_MODELS, f'{custom_filename}_epoch{custom_epochs}_best.pth')
    model_weights = os.path.join(settings.DIR_MODELS, f'{custom_filename}_weights_epoch{custom_epochs}_best.pth')

    # Approach 1: load the entire pickled model object. torch.load unpickles
    # arbitrary Python objects, so only use this on checkpoint files you trust.
    model_full_load = torch.load(model_full, map_location=device)

    # Approach 2: rebuild the architecture from the hyperparameters and load only
    # the weights. map_location lets a checkpoint saved on GPU load on a CPU-only
    # machine (and vice versa).
    # NOTE(review): vocab_size must match the checkpoint's vocabulary; it is None
    # by default unless a later cell sets it from the tokenizer — confirm.
    model_load = TransformerModel(vocab_size, emsize, nhead, nhid, nlayers, dropout).to(device)
    # load_state_dict returns the missing/unexpected-keys report, not a model.
    model_weights_load = model_load.load_state_dict(torch.load(model_weights, map_location=device))
    model_load.eval()  # disable dropout: this model is used for inference only

In [None]:
# Generate a continuation of `prompt`.
# With TRAIN=True, use `best_model` from the training run; otherwise use the
# weights-only model reloaded in the previous cell (`model_load`).
gen_model = best_model if TRAIN else model_load
gen_device = 'cpu'        # generation runs on CPU
gen_model.to(gen_device)  # move the model actually used below (nn.Module.to is in place)
gen_model.eval()          # disable dropout during generation

prompt = 'The dog ran'
ngen = 100                    # number of new tokens to generate
decode_style = 'sample_topp'  # 'greedy' or 'sample_topp'
generated_text = gen_some_text(
    gen_model, dataset.transform, gen_device, max_len_sentence,
    text_prompt=prompt, tokens_to_gen=ngen, vis=False,
    decode_style=decode_style)
print("Text prompt:\n", prompt)
print("Number of tokens to generate:", ngen)
print("Generated_text:\n", generated_text)

# TODO: try further decoding strategies (e.g. beam search, top-k sampling);
# see: https://huggingface.co/blog/how-to-generate