In [1]:
IN_COLAB = 'google.colab' in str(get_ipython())

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    %cd /content/drive/MyDrive/Github/Abstract-generator/bumbleBERT/notebooks
    
    %%capture
    !pip install feedparser tokenizers transformers;

In [2]:
import os, torch, time, math, sys, re, csv
import numpy as np

sys.path.append('..' + os.sep )
from src import default
import src.data.dataset_class as dsc
import src.data.dataloader_class as dlc

from src.model.transformer_hf import TransformerModel
from src.model.generate_text import gen_some_text
from src.model.train_evaluate import train, evaluate
#from src.model.transformer import make_gpt_model # imports don't work

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Parameters

In [3]:
# ARCHITECTURE
maxLen     = 40 # maximum sentence length
vocabSize  = None # None if you want to let tokenizer do its thing
emsize     = 512 # embedding dimension
nhid       = 2048 # the dimension of the feedforward network model in torch.nn.TransformerEncoder
nlayers    = 12 # the number of torch.nn.TransformerEncoderLayer in torch.nn.TransformerEncoder
nhead      = 8 # the number of heads in the multiheadattention models
dropout    = 0.2 # the dropout value
batchSize = 10 #32
valBatchSize = 10 #32, not used right now.
epochs     = 50  # The number of epochs

TRAIN = True

### Format Dataset

Uses a custom dataset class, which is an iterable and callable structure that returns a sample from our dataset. Within this custom dataset, can determine all preprocessing.

In [4]:
# create dataset
dataset = dsc.ArxivDataset()
dataset = dsc.WikiTextDataset()

#train tokenizer (or use one already trained)
tknzrType = 'BPE'
tknzrTrain = True
tknzrFast = True

_ = dataset.tokenizer(tknzrTrain, tknzrType, tknzrFast=tknzrFast)






### Creating DataLoaders

Training is done on batches, so we need a way to extract groupings of the data in the appropriate format for our transformer model.
Note that for transformers which we are training, dataloaders outputs both src (x[:-1] and tgt ([1:]).
The collation of batches for different transformer models we have vary. For HuggingFace it's ( maxLen x batch_size ) whereas I think that the Annotated Transformer has ( batch_size x maxLen ).

I created a custom Dataloader class that wraps splitting the dataset and also outputs different dataloaders for each.

NOTE : Do not use the tokenizer before the training if you use num_workers>0!
FastTokenizer does not play nicely with forking if you use it before the forking of your data:
https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning

In [5]:
dataloader = dlc.CustomDataloader(dataset, batchSize, maxLen)

### Selecting model

Here we choose which model we shall use for training. For now, I've selected the black box Transformer from HuggingFace because the collate_fn I've written gives the correct input size force it... however this can easily be changed! 

In [6]:
# transformer from huggingface
# TODO : Change to the Annotated Transformer if I want
model = TransformerModel(dataset.vocabSize, emsize, nhead, nhid, nlayers, dropout).to(device)

# criterion
criterion = torch.nn.CrossEntropyLoss()#ignore_index=tknzr.get_vocab()["<pad>"])

# optimizer
paramsAdam  = [{'params' : model.parameters(), 'lr' : 1e-3, 'betas' : (0.9, 0.999), 'eps' : 1e-08, 'weight_decay' : 0.0}]
paramsAdamW = [{'params' : model.parameters(), 'lr' : 5e-5, 'betas' : (0.9, 0.999), 'eps' : 1e-08, 'weight_decay' : 0.0}]
paramsSGD   = [{'params' : model.parameters(), 'lr' : 0.5, 'momentum' : 0.0, 'dampening' : 0.0, 'weight_decay' : 0.0}]

#optimizer = torch.optim.SGD( paramsSGD )
#optimizer = torch.optim.Adam( paramsAdam )
optimizer = torch.optim.AdamW( paramsAdamW )

# scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95) # 1.0 to signify no decay rate

### Training

Training loop!

In [7]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # fasttokenizer should not be used before forking. Something
                                                # to figure out. What this does is suppress some warning messages 
                                                # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
                                                # doesn't seem to affect the timing though
if TRAIN:
    best_val_loss = float("inf")
    best_model = None
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train( model, dataloader.train, device, dataset.vocabSize, epoch, optimizer, scheduler, criterion, maxLen)
        val_loss = evaluate(model, dataloader.valid, device, dataset.vocabSize, criterion, maxLen, len(dataloader.dsetValid))
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                         val_loss, math.exp(val_loss)))
                                         # Why is math.exp so large????
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = model

        scheduler.step()

    # save best model (two methods)
    modelFull = default.MODEL_DIR + os.sep + f'{dataset.name}_epoch{epochs}.pth'
    modelWeights = default.MODEL_DIR + os.sep + f'{dataset.name}_weights_epoch{epochs}.pth'
    modelFullBest = default.MODEL_DIR + os.sep + f'{dataset.name}_epoch{epochs}_best.pth'
    modelWeightsBest = default.MODEL_DIR + os.sep + f'{dataset.name}_weights_epoch{epochs}_best.pth'
    # approach 1: save model (class) entirely (uses pickle)
    torch.save(model, modelFull)
    torch.save(best_model, modelFullBest)
    # approach 2: save model weights
    torch.save(best_model.state_dict(), modelWeightsBest)

| epoch   1 |   200/  486 batches | lr 0.00 | ms/batch 29.17 | loss  7.73 | ppl  2275.87
| epoch   1 |   400/  486 batches | lr 0.00 | ms/batch 28.87 | loss  6.88 | ppl   974.42
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 14.72s | valid loss 25.60 | valid ppl 131361991014.71
-----------------------------------------------------------------------------------------
| epoch   2 |   200/  486 batches | lr 0.00 | ms/batch 29.09 | loss  6.50 | ppl   663.50
| epoch   2 |   400/  486 batches | lr 0.00 | ms/batch 28.25 | loss  6.39 | ppl   596.02
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 14.51s | valid loss 24.91 | valid ppl 66039371377.62
-----------------------------------------------------------------------------------------
| epoch   3 |   200/  486 batches | lr 0.00 | ms/batch 28.98 | loss  6.23 | ppl   507.75
| epoch   3 |   400/  486 batches 

| epoch  20 |   200/  486 batches | lr 0.00 | ms/batch 29.81 | loss  4.42 | ppl    82.72
| epoch  20 |   400/  486 batches | lr 0.00 | ms/batch 29.14 | loss  4.42 | ppl    83.14
-----------------------------------------------------------------------------------------
| end of epoch  20 | time: 14.87s | valid loss 23.24 | valid ppl 12367966010.66
-----------------------------------------------------------------------------------------
| epoch  21 |   200/  486 batches | lr 0.00 | ms/batch 28.84 | loss  4.37 | ppl    78.66
| epoch  21 |   400/  486 batches | lr 0.00 | ms/batch 28.45 | loss  4.36 | ppl    77.95
-----------------------------------------------------------------------------------------
| end of epoch  21 | time: 14.47s | valid loss 23.24 | valid ppl 12328640751.53
-----------------------------------------------------------------------------------------
| epoch  22 |   200/  486 batches | lr 0.00 | ms/batch 30.06 | loss  4.32 | ppl    75.03
| epoch  22 |   400/  486 batches |

| epoch  39 |   200/  486 batches | lr 0.00 | ms/batch 29.78 | loss  3.69 | ppl    40.04
| epoch  39 |   400/  486 batches | lr 0.00 | ms/batch 29.35 | loss  3.71 | ppl    40.74
-----------------------------------------------------------------------------------------
| end of epoch  39 | time: 14.86s | valid loss 23.56 | valid ppl 17056109788.47
-----------------------------------------------------------------------------------------
| epoch  40 |   200/  486 batches | lr 0.00 | ms/batch 29.57 | loss  3.68 | ppl    39.84
| epoch  40 |   400/  486 batches | lr 0.00 | ms/batch 29.23 | loss  3.68 | ppl    39.57
-----------------------------------------------------------------------------------------
| end of epoch  40 | time: 14.87s | valid loss 23.60 | valid ppl 17728510077.19
-----------------------------------------------------------------------------------------
| epoch  41 |   200/  486 batches | lr 0.00 | ms/batch 29.40 | loss  3.66 | ppl    38.94
| epoch  41 |   400/  486 batches |

### Text Generation

Here I've simply taken the code Matt uses to generate text.

In [8]:
if not TRAIN:
    customFilename = 'arxiv_10000'
    customEpochs = 10
    modelFull = default.MODEL_DIR + os.sep + f'{customFilename}_epoch{customEpochs}_best.pth'
    modelWeights = default.MODEL_DIR + os.sep + f'{customFilename}_weights_epoch{customEpochs}_best.pth'
    
    # approach 1: load model (class) entirely (uses pickle)
    modelFullLoad = torch.load(modelFull, map_location=device)

    # approach 2: load model weights, need to have some parameter or something 
    modelLoad = TransformerModel(vocabSize, emsize, nhead, nhid, nlayers, dropout).to(device)
    modelWeightsLoad = modelLoad.load_state_dict( torch.load(modelWeights) )

In [10]:
# inspect both models
#print('model_A info...\n', modelFullLoad)
#print('\nmodel_B info...\n', modelWeightsLoad)

#print('model_A == model_B:', modelFullLoad == modelWeightsLoad)
#model = modelFullLoad
# Text generation example

#model = modelLoad
prompt = 'The dog ran'
ngen = 100
decode_style = 'greedy' #greedy, sample_topp
model.to('cpu')
generated_text = gen_some_text(
    best_model, dataset.transform, 'cpu', maxLen, text_prompt=prompt, tokens_to_gen=ngen, vis=False,
    decode_style=decode_style)
print("Text prompt:\n", prompt)
print("Number of tokens to generate:", ngen)
print("Generated_text:\n", generated_text)

# TODO: alternative generation
# currently 'greedy method'
# see: https://huggingface.co/blog/how-to-generate

<s>  The  dog  ran
tensor([-5.4660, -6.1401, -3.6607,  ..., -1.0406, -1.7796, -2.9135],
       grad_fn=<SelectBackward0>)
The dog ran  into
<s>  The  dog  ran  into
tensor([-6.5964, -9.0464, -3.9420,  ..., -0.4227, -0.9724, -2.2983],
       grad_fn=<SelectBackward0>)
The dog ran  into  a
<s>  The  dog  ran  into  a
tensor([-2.7266, -6.0550, -2.4167,  ..., -1.7753, -0.9441, -1.2870],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical
<s>  The  dog  ran  into  a  tropical
tensor([-1.9276, -4.3682,  0.6975,  ..., -0.2348,  0.2135, -2.1654],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical  depression
<s>  The  dog  ran  into  a  tropical  depression
tensor([ -6.6767, -11.7690,   1.0595,  ...,  -2.4775,  -5.8395,  -1.9622],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical  depression  on
<s>  The  dog  ran  into  a  tropical  depression  on
tensor([-3.6805, -9.0180,  2.2344,  ..., -3.1880, -0.6575, -3.7623],
       grad_fn=<SelectBackward0>)
Th

tensor([24.5543, -0.8966,  5.2844,  ..., -0.1265, -0.2876,  1.1082],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical  depression  on  August  19 ,  and  the  JMA  upgraded  the  storm  to  tropical  cyclones  on  August  31 .  The  system  moved  across  the  east  of  the  Philippines , <\s> <pad> <pad>
<s>  The  dog  ran  into  a  tropical  depression  on  August  19 ,  and  the  JMA  upgraded  the  storm  to  tropical  cyclones  on  August  31 .  The  system  moved  across  the  east  of  the  Philippines , <\s> <pad> <pad>
tensor([24.4588, -0.9913,  5.4639,  ..., -0.2446, -0.2223,  1.0862],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical  depression  on  August  19 ,  and  the  JMA  upgraded  the  storm  to  tropical  cyclones  on  August  31 .  The  system  moved  across  the  east  of  the  Philippines , <\s> <pad> <pad> <pad>
<s>  The  dog  ran  into  a  tropical  depression  on  August  19 ,  and  the  JMA  upgraded  the  storm  to  tropical  c

tensor([14.7975, -2.3566, -0.7450,  ..., -2.5153, -0.4684,  0.7997],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical  depression  on  August  19 ,  and  the  JMA  upgraded  the  storm  to  tropical  cyclones  on  August  31 .  The  system  moved  across  the  east  of  the  Philippines , <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
 tropical  cyclones  on  August  31 .  The  system  moved  across  the  east  of  the  Philippines , <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
tensor([14.2055, -2.5421, -0.7973,  ..., -2.6626, -0.2428,  0.6202],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical  depression  on  August  19 ,  and  the  JMA  upgraded  the  storm  to  tropical  cyclones  on  August  31 .  The  system  moved  across  the  east  of  the  Philippines , <

tensor([ 5.4398, -4.1931, -0.2123,  ..., -3.6752, -1.8246, -0.1923],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical  depression  on  August  19 ,  and  the  JMA  upgraded  the  storm  to  tropical  cyclones  on  August  31 .  The  system  moved  across  the  east  of  the  Philippines , <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>  D
<\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>  D
tensor([-2.8735, -7.3338,  1.5734,  ..., -4.0192, -2.8315, -1.4398],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical  depression  on  August  19 ,  and  the  JMA  upgraded  the  storm  to  tropical  cyclon

tensor([-1.3355, -6.8782,  7.5075,  ..., -4.4148, -2.1113, -2.1413],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical  depression  on  August  19 ,  and  the  JMA  upgraded  the  storm  to  tropical  cyclones  on  August  31 .  The  system  moved  across  the  east  of  the  Philippines , <\s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>  D  and  the  east  of  the  Naktong  and  the  east  of  the  Naktong  and  the <\s>
<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>  D  and  the  east  of  the  Naktong  and  the  east  of  the  Naktong  and  the <\s>
tensor([13.0030, -2.6790,  4.1362,  ..., -1.8144, -2.2800,  0.5950],
       grad_fn=<SelectBackward0>)
The dog ran  into  a  tropical  depression