In [1]:
IN_COLAB = 'google.colab' in str(get_ipython())

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    %cd /content/drive/MyDrive/Github/Abstract-generator/bumbleBERT/notebooks
    
    %%capture
    !pip install feedparser tokenizers transformers;

In [2]:
import copy
import os, torch, time, math, sys, re, csv
import numpy as np

sys.path.append('..' + os.sep )
from src import default
import src.data.dataset_class as dsc
import src.data.dataloader_class as dlc

from src.model.transformer_hf import TransformerModel
from src.model.generate_text import gen_some_text
from src.model.train_evaluate import train, evaluate
#from src.model.transformer import make_gpt_model # imports don't work

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Parameters

In [3]:
# ARCHITECTURE
maxLen       = 40    # maximum sentence length
vocabSize    = None  # None if you want to let tokenizer do its thing
                     # NOTE(review): not passed to the tokenizer in this notebook; the model
                     # built for training uses dataset.vocabSize instead.
emsize       = 512   # embedding dimension
nhid         = 2048  # the dimension of the feedforward network model in torch.nn.TransformerEncoder
nlayers      = 12    # the number of torch.nn.TransformerEncoderLayer in torch.nn.TransformerEncoder
nhead        = 8     # the number of heads in the multiheadattention models
dropout      = 0.2   # the dropout value

# TRAINING
batchSize    = 10    # 32
valBatchSize = 10    # 32, not used right now.
epochs       = 10    # The number of epochs

TRAIN = True  # False: skip training and load a saved checkpoint instead

# REPRODUCIBILITY: seed the torch and numpy RNGs so repeated runs are comparable.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

### Format Dataset

We use a custom dataset class: an iterable, callable structure that returns one sample from our dataset at a time. All preprocessing is defined within this custom dataset.

In [4]:
# create dataset: build only the source that is actually used
# ('arxiv' -> dsc.ArxivDataset, 'wikitext' -> dsc.WikiTextDataset)
datasetName = 'wikitext'
datasetClasses = {'arxiv' : dsc.ArxivDataset, 'wikitext' : dsc.WikiTextDataset}
dataset = datasetClasses[datasetName]()

#train tokenizer (or use one already trained)
tknzrType = 'BPE'
tknzrTrain = True   # False: use a tokenizer that was already trained
tknzrFast = True

_ = dataset.tokenizer(tknzrTrain, tknzrType, tknzrFast=tknzrFast)






### Creating DataLoaders

Training is done on batches, so we need a way to extract groupings of the data in the appropriate format for our transformer model.
Note that for the transformers we are training, the dataloaders output both `src` (`x[:-1]`) and `tgt` (`x[1:]`).
The batch layout differs between our transformer models: for the HuggingFace model it is (maxLen x batch_size), whereas I believe the Annotated Transformer uses (batch_size x maxLen).

I created a custom Dataloader class that handles splitting the dataset and returns a separate dataloader for each split.

NOTE: Do not use the tokenizer before training if you use `num_workers > 0`!
The fast tokenizer does not play nicely with process forking if it has already been used before the DataLoader workers are forked; see:
https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning

In [5]:
# Split `dataset` and build one DataLoader per split; the training cell below reads
# `dataloader.train`, `dataloader.valid` and `dataloader.dsetValid` from this object.
dataloader = dlc.CustomDataloader(dataset, batchSize, maxLen)

### Selecting model

Here we choose which model to use for training. For now, I've selected the black-box Transformer from HuggingFace because the collate_fn I've written produces the correct input shape for it... however, this can easily be changed!

In [6]:
# transformer from huggingface
# TODO : Change to the Annotated Transformer if I want
model = TransformerModel(dataset.vocabSize, emsize, nhead, nhid, nlayers, dropout).to(device)

# criterion
# TODO: padding positions currently count towards the loss (no ignore_index). Once the
# tokenizer's pad-token id is available here, pass ignore_index=<pad id>. The generated
# text further down drifting into long runs of <pad> may be a symptom of this.
criterion = torch.nn.CrossEntropyLoss()

# optimizer hyper-parameters (a single param group each); pick one optimizer below
paramsAdam  = [{'params' : model.parameters(), 'lr' : 1e-3, 'betas' : (0.9, 0.999), 'eps' : 1e-08, 'weight_decay' : 0.0}]
paramsAdamW = [{'params' : model.parameters(), 'lr' : 5e-5, 'betas' : (0.9, 0.999), 'eps' : 1e-08, 'weight_decay' : 0.0}]
paramsSGD   = [{'params' : model.parameters(), 'lr' : 0.5, 'momentum' : 0.0, 'dampening' : 0.0, 'weight_decay' : 0.0}]

#optimizer = torch.optim.SGD( paramsSGD )
#optimizer = torch.optim.Adam( paramsAdam )
optimizer = torch.optim.AdamW( paramsAdamW )

# scheduler: StepLR multiplies the lr by `gamma` every `step_size` calls to scheduler.step().
# The training loop below calls scheduler.step() once per epoch, so this DOES decay the lr
# by 5% per epoch. Use gamma=1.0 for a constant learning rate.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

### Training

Training loop!

In [7]:
# The fast tokenizer has already been used (trained) before the DataLoader forks its
# workers; disabling its internal parallelism suppresses the resulting warning:
# https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
# (doesn't seem to affect the timing though)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

if TRAIN:
    best_val_loss = float("inf")
    best_model = None
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        # NOTE(review): `scheduler` is also passed to train(); confirm train() does not step
        # it as well, otherwise the lr would decay twice per epoch.
        train( model, dataloader.train, device, dataset.vocabSize, epoch, optimizer, scheduler, criterion, maxLen)
        val_loss = evaluate(model, dataloader.valid, device, dataset.vocabSize, criterion, maxLen, len(dataloader.dsetValid))

        # perplexity = exp(loss). math.exp raises OverflowError for losses above ~709, which
        # would abort the whole run, so report inf instead.
        # NOTE(review): the valid loss is far above the per-batch train loss printed by train()
        # (e.g. ~25 vs ~7), hence the enormous ppl. Possibly evaluate() normalises by a
        # different count (len(dsetValid) samples) than train() does -- TODO check.
        try:
            val_ppl = math.exp(val_loss)
        except OverflowError:
            val_ppl = float("inf")

        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                         val_loss, val_ppl))
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot the weights. `best_model = model` would only alias the model that keeps
            # training, so the "best" checkpoint would silently be the final one. The snapshot
            # is moved to the CPU so the GPU does not keep a second copy.
            best_model = copy.deepcopy(model).to('cpu')

        scheduler.step()

    if best_model is None:  # no val_loss ever beat inf (e.g. every epoch gave NaN)
        best_model = copy.deepcopy(model).to('cpu')

    # save final and best models (two methods each)
    os.makedirs(default.MODEL_DIR, exist_ok=True)
    modelFull = os.path.join(default.MODEL_DIR, f'{dataset.name}_epoch{epochs}.pth')
    modelWeights = os.path.join(default.MODEL_DIR, f'{dataset.name}_weights_epoch{epochs}.pth')
    modelFullBest = os.path.join(default.MODEL_DIR, f'{dataset.name}_epoch{epochs}_best.pth')
    modelWeightsBest = os.path.join(default.MODEL_DIR, f'{dataset.name}_weights_epoch{epochs}_best.pth')
    # approach 1: save model (class) entirely (uses pickle)
    torch.save(model, modelFull)
    torch.save(best_model, modelFullBest)
    # approach 2: save model weights only
    torch.save(model.state_dict(), modelWeights)
    torch.save(best_model.state_dict(), modelWeightsBest)

| epoch   1 |   200/  486 batches | lr 0.00 | ms/batch 29.12 | loss  7.71 | ppl  2236.18
| epoch   1 |   400/  486 batches | lr 0.00 | ms/batch 28.81 | loss  6.85 | ppl   947.28
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 14.57s | valid loss 25.62 | valid ppl 133193578879.72
-----------------------------------------------------------------------------------------
| epoch   2 |   200/  486 batches | lr 0.00 | ms/batch 29.12 | loss  6.48 | ppl   650.65
| epoch   2 |   400/  486 batches | lr 0.00 | ms/batch 29.28 | loss  6.38 | ppl   591.71
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 14.72s | valid loss 24.90 | valid ppl 64927638480.80
-----------------------------------------------------------------------------------------
| epoch   3 |   200/  486 batches | lr 0.00 | ms/batch 30.70 | loss  6.23 | ppl   509.03
| epoch   3 |   400/  486 batches 

### Text Generation

Here I've simply taken the code Matt uses to generate text.

In [8]:
if not TRAIN:
    # Load a previously saved checkpoint instead of training.
    # NOTE(review): the prefix must match the dataset.name used when the checkpoint was
    # saved; this run builds the dataset selected above -- confirm before use.
    customFilename = 'arxiv_10000'
    customEpochs = 10
    modelFull = os.path.join(default.MODEL_DIR, f'{customFilename}_epoch{customEpochs}_best.pth')
    modelWeights = os.path.join(default.MODEL_DIR, f'{customFilename}_weights_epoch{customEpochs}_best.pth')

    # approach 1: load model (class) entirely (uses pickle -- only load trusted files).
    # On PyTorch >= 2.6 torch.load defaults to weights_only=True, so a whole pickled model
    # needs weights_only=False there.
    modelFullLoad = torch.load(modelFull, map_location=device)

    # approach 2: rebuild the architecture, then load the weights into it. It must be built
    # with the same vocabulary size used for training (dataset.vocabSize); the `vocabSize`
    # parameter may be None.
    # NOTE(review): if the tokenizer was retrained (tknzrTrain=True), dataset.vocabSize may
    # differ from the checkpoint's and load_state_dict will report a size mismatch.
    modelLoad = TransformerModel(dataset.vocabSize, emsize, nhead, nhid, nlayers, dropout).to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    modelLoad.load_state_dict(torch.load(modelWeights, map_location=device))
    # load_state_dict returns missing/unexpected-key diagnostics, not the model itself.
    modelWeightsLoad = modelLoad

In [11]:
# Text generation example.
# Use the best checkpoint from training. `best_model` only exists when the training cell
# ran with TRAIN = True, so otherwise fall back to the fully loaded checkpoint.
genModel = best_model if TRAIN else modelFullLoad
genModel.to('cpu')  # generation below runs on the CPU
genModel.eval()     # disable dropout for inference

prompt = 'The dog ran'
ngen = 100
decode_style = 'sample_topp'  # 'greedy' or 'sample_topp'
generated_text = gen_some_text(
    genModel, dataset.transform, 'cpu', maxLen, text_prompt=prompt, tokens_to_gen=ngen, vis=False,
    decode_style=decode_style)
print("Text prompt:\n", prompt)
print("Number of tokens to generate:", ngen)
print("Generated_text:\n", generated_text)

# TODO: compare other decoding strategies (beam search, top-k, ...)
# see: https://huggingface.co/blog/how-to-generate

<s>  The  dog  ran
[  215   191   329 ...  2179 16284   176]
[0.1631927  0.21704185 0.2625346  ... 0.999936   0.999936   0.999936  ]
topp_indices 330 [  215   191   329   244   299   213   295   184   186   210   250   310
   266   205    84   281   292   553   804   327   263   363   251   973
   584   919   444  1170   765   590   548   311  1374   735  1670   788
   535   258   750  1116   505   195  2281   958   572   691  1394  6928
   528  1294   668  2288  3184   567  1094  1052   376  1461  2234   627
   276   460   821  1162  1302   445   800  1841   697  1267  2621   378
  1306  2669  3603  2996   921   272  3144  2678   645   224  2675  1417
  2079   918   940   631   699  1633  1551  2088   652  2617   862  1122
   939  1766  1708  1973   414  1657  1629   187  1783   214  1087  2622
  2116  5128  1375  3031  2894  4482  3207  1845  2692  6598  2625  1671
  2613   318   976  3051  1960  3018  5743  1276  2814  2172   943  1239
  2948   591  2687  6156   308   748   952  450

[  244  1417   191 ... 22221 17823 19831]
[0.15564361 0.1986399  0.24094766 ... 0.99993294 0.99993294 0.99993294]
topp_indices 449 [  244  1417   191   940   363   266  1267   817   215   327   444   310
   553   735   251   590   376  2288  1657  1394   184  1783  1291   210
  1315   505   788  1094  1629   329  1374  1863  2617  1551  1116  1843
   186  3050  1656  1810   276  2622   509  5139  1370  2234  2325   821
   213  2173  2745  1685   750  1735   804   295  1336   973  1030  3869
   919  4319  2893  2239   263  1162   292  1302  2625   272  3205  3509
  1973  1784  2624   958  3031  3764  1608   584  1766   205  1670   838
  1294   528  2272  1135   299  2415  1617  2681  1480  1838  1968  1183
  1136   281   411  1525   699   918   645   548   567  2313  1170  5105
  2041  2841  1064  2414  4128  3184  3753  3267  2077  2470  5361  1365
  1952  2172   199  2754  4983  2563  2678  1708  2409  1820   691  2965
   451   955  2600  5987  3035  3222   195   577  7028  6949  4492

<s>  The  dog  ran  and  a  large  practice ,  the  CDO  made  a  new  body  and  only  not  armed .  The  storm  to  withdraw ,  he  was  unable  to  his
[ 1215  1199   451 ... 20162 15122 28503]
[0.09025329 0.12354585 0.15295626 ... 0.99994165 0.99994165 0.99994165]
topp_indices 1837 [ 1215  1199   451 ... 14110 14885  3186]
tensor([-2.1317, -7.9609,  0.2840,  ..., -2.1276, -3.5113, -1.9682],
       grad_fn=<SelectBackward0>)
The dog ran  and  a  large  practice ,  the  CDO  made  a  new  body  and  only  not  armed .  The  storm  to  withdraw ,  he  was  unable  to  his  own
<s>  The  dog  ran  and  a  large  practice ,  the  CDO  made  a  new  body  and  only  not  armed .  The  storm  to  withdraw ,  he  was  unable  to  his  own
[  195   923  1184 ...  9516 20162  4300]
[0.03394539 0.05365226 0.07267669 ... 0.99993527 0.99993527 0.99993527]
topp_indices 2345 [  195   923  1184 ... 13209  3437  6339]
tensor([-4.8607, -8.1718, -1.4840,  ..., -1.7933, -3.6342, -2.9913],
       grad_

[0.17044033 0.33247066 0.48855588 ... 0.9999412  0.9999412  0.9999412 ]
topp_indices 124 [   2  184  509  215  318  186  411  735  210  907  631  251 1657  614
 2600  444  308  572  292  445 3060  281 1336 2288 1771  765  498  529
  548 2231  955 3008 1545  788 1968  263  295  652  460 2883  363  191
 2272 2965 2470 1692  451 1291 1688 2701 1766  311 1346 1135 1347  250
  266  276 2415  570  587  299  749 1525  272 1838 1784 1752 1608 1375
 1808 1656 1461 3525 1491 3041 1863 1267 3097 1116 1564 3787 1783 2325
 2049 1953 4474  821 1629 2019  906  758 2625 5127  327 1670 1394 2238
  195 4647 5363 1963  282 3143  449 2806 1365  918 4918 1820 1058  750
  590 1792  329 3487  676 1512 1417 3207 4287  213  952 1302]
tensor([-3.6741, -6.7810,  7.6986,  ..., -1.6797, -3.0773, -2.5033],
       grad_fn=<SelectBackward0>)
The dog ran  and  a  large  practice ,  the  CDO  made  a  new  body  and  only  not  armed .  The  storm  to  withdraw ,  he  was  unable  to  his  own  claim ,  and  had <\s>
<

[    0  1647   219 ... 27178 11544  3022]
[0.99993527 0.9999468  0.9999498  ... 0.99996984 0.99996984 0.99996984]
topp_indices 1 [0]
tensor([16.9213, -1.5015,  4.8685,  ..., -0.7182, -2.8083,  0.6255],
       grad_fn=<SelectBackward0>)
The dog ran  and  a  large  practice ,  the  CDO  made  a  new  body  and  only  not  armed .  The  storm  to  withdraw ,  he  was  unable  to  his  own  claim ,  and  had <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
,  the  CDO  made  a  new  body  and  only  not  armed .  The  storm  to  withdraw ,  he  was  unable  to  his  own  claim ,  and  had <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
[    0  1647   219 ... 27178 11544  3022]
[0.9998144  0.9998511  0.9998569  ... 0.99993867 0.99993867 0.99993867]
topp_indices 1 [0]
tensor([16.0489, -1.6456,  5.3609,  ..., -1.0997, -2.7015,  0.6966],
       grad_fn=<SelectBackward0>)
The dog ran  and  a  large  practice ,  the  CDO  made  a  new  bod

[    0  1647  1892 ... 11544 23301  3022]
[0.9867565  0.9911699  0.99161446 ... 0.99990684 0.99990684 0.99990684]
topp_indices 1 [0]
tensor([12.4789, -2.5782,  5.0832,  ..., -1.6825, -2.7035,  0.3756],
       grad_fn=<SelectBackward0>)
The dog ran  and  a  large  practice ,  the  CDO  made  a  new  body  and  only  not  armed .  The  storm  to  withdraw ,  he  was  unable  to  his  own  claim ,  and  had <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
 The  storm  to  withdraw ,  he  was  unable  to  his  own  claim ,  and  had <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
[    0  1647  1892 ...  9474 22728  3022]
[0.9736541  0.97994244 0.98112565 ... 0.99993247 0.99993247 0.99993247]
topp_indices 1 [0]
tensor([11.7567, -2.6714,  5.0053,  ..., -2.1229, -2.6645,  0.3833],
       grad_fn

[    0  1647  1892 ... 10694  9549 11281]
[0.4385211 0.5018875 0.5288134 ... 0.9999927 0.9999927 0.9999927]
topp_indices 1568 [    0  1647  1892 ...  9161 21347 20378]
tensor([ 8.3414, -3.5234,  5.2279,  ..., -2.7976, -2.2112,  0.0709],
       grad_fn=<SelectBackward0>)
The dog ran  and  a  large  practice ,  the  CDO  made  a  new  body  and  only  not  armed .  The  storm  to  withdraw ,  he  was  unable  to  his  own  claim ,  and  had <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
,  and  had <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
[    0  1647   258 ... 10694  9549 11281]
[0.28024545 0.3614427  0.39583275 ... 0.9999967  0.9999967  0

[    0     2  1647 ...  9549 11281 10694]
[0.46572715 0.5466121  0.5608666  ... 0.9999881  0.9999881  0.9999881 ]
topp_indices 1381 [    0     2  1647 ... 14807  3326  2824]
tensor([ 8.2912, -2.9620,  6.8324,  ..., -2.8125, -2.4777, -1.1242],
       grad_fn=<SelectBackward0>)
The dog ran  and  a  large  practice ,  the  CDO  made  a  new  body  and  only  not  armed .  The  storm  to  withdraw ,  he  was  unable  to  his  own  claim ,  and  had <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>  rapidly  percentage <\s>  bombarding
<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>  rapidly  percentage <\s>  bombarding
[    2   251   186 ... 112

[    0     2   258 ...  9549 11281 10694]
[0.5718715 0.598783  0.6167609 ... 0.9999884 0.9999884 0.9999884]
topp_indices 876 [    0     2   258  1647  1892  2312   205  1495   251 16852   284   833
  3785   295   979   691  1832   527  3621   213   191   414   186  1266
  1244   572  1237  2627  2628  1893   836  8865  1520  1430   195  1183
  1044   210  1076  4153  2030  1690  3139   393   425   262  2631  8811
  2405   949  3337  1606  4682   329   458  1308   485 11380  2617   567
  2419   832  3582  3537  1335   528  5415  1415  2243  8609  1840  1499
  4843   184  2067  3039  3093   444  2264  3276  8283  1109  4383   225
   311  1450  1042   649  4009  4920  4008 13986  1417  3202   930   285
 14898   756   609 23063 11662   611 14754   977   215  3744  5800 16512
  1233   535 17256   821 11608  1572  2466 10162  2513  4101  1310  3501
  7460   244  1823   758  1780  4684 23373 26552  2054  1129   498   878
  3435  6374  1740 28673  9190   714  3879   281 11549 19544  1824  3374

[    0     2   258 ...  9549 11281 10694]
[0.68548924 0.7060409  0.7177128  ... 0.99998397 0.99998397 0.99998397]
topp_indices 351 [    0     2   258  1647  1892  2312   251   205 16852   284  1495  3785
   833   295   691   979   527   186  1832   191  3621   213  1244   414
   572  1266  8865  1237  2627  1893  2628  1520  1430   195  1044  1183
   836  1076  2030  4153   210   393  3139  2631  8811  1690  1606   949
   262  3337   425  2405 11380  4682   485   458   832  2617   184  1308
  3582  2419  1335   329  1499  1415   567  3537  5415  2243  2067  3093
  3039   528  8609  4843  1109  8283  2264  1840   311   444   225  3276
  1450  3202  4383 13986  4920  1417  1042  4009   649  4008  5800   930
   611 17256 11662 16512   756 23063   609   977  1233   821  4101  3744
   535 14898 14754  7460 11608  1823  2513   215  1780  1572   498   244
   878  2466 10162 23373 26552  3501  1740   285  1310 11549  3435   281
  1129   758  6374  4684  9190  2054  1824 19544  6089  5767 13316

[    0     2  1647 ...  9549 11281 10694]
[0.6164069  0.6401452  0.6542462  ... 0.99998415 0.99998415 0.99998415]
topp_indices 671 [    0     2  1647   258  1892   251  2312   284 16852   205  1495   833
   295  3785   691   186   191   979   527   213  1832  3621  1244  1266
   572  8865   414  1520  1237   195  1893  1044  2030  1183  2627  1430
  2628  1076  4153   210   836   393  1606  8811  1690  3139  3337  2631
   425   485   262   184 11380  4682   949   832  2419  2617   458  3582
  1308  5415  1499  3093  3039  1335  2405  4383  3537  1415  4843  1450
   528   329  4009  3202   567  1840   311  2243   930   444  2067  2264
  8283  8609   611  4920  3276 13986   225  1109  1417 17256  5800 14754
  4101   535  4008 11662   649   244   609   821 16512  1042   977  3744
   756 23063 11549  1233   215  2466 14898   498  7460   285  1740  1780
  3501 26552  1823 10162  2513  1129  1824 23373   878 11608  2054  1572
   758  3435   281   310  9190  6374  1310  5767  8215  6089  6788