In [1]:
IN_COLAB = 'google.colab' in str(get_ipython())

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    %cd /content/drive/MyDrive/Github/Abstract-generator/bumbleBERT/notebooks
    
    %%capture
    !pip install feedparser tokenizers transformers;

In [2]:
import os, torch, time, math, sys, re, csv
import numpy as np

sys.path.append('..' + os.sep )
from src import default

from src.data import download as dl, tokenization as tkn, custom_dataset as cd

from torch.utils.data import DataLoader
from src.model.transformer_hf import TransformerModel
from src.model.batching import CustomBatch
from src.model.generate_text import gen_some_text
from src.model.train_evaluate import train, evaluate
#from src.model.transformer import make_gpt_model # imports don't work

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Parameters

In [3]:
# ARCHITECTURE
maxLen     = 35 # maximum sentence length
vocabSize  = None # None if you want to let tokenizer do its thing
emsize     = 512 # embedding dimension
nhid       = 2048 # the dimension of the feedforward network model in torch.nn.TransformerEncoder
nlayers    = 12 # the number of torch.nn.TransformerEncoderLayer in torch.nn.TransformerEncoder
nhead      = 8 # the number of heads in the multiheadattention models
dropout    = 0.2 # the dropout value
batch_size = 10 #32
val_batch_size = 10 #32
epochs     = 10  # The number of epochs

# TOKENIZER
tknzerType = 'BPE' # type of tokenizing algorithm
trainTokenizer = True # whether to train a new tokenizer or use one already trained

TRAIN = True

### Download Dataset

In [4]:
# download data
nbrResults = 10**4 # number of data samples to download
filename = f'arxiv_{nbrResults}'
extension = '.csv'
filename += extension

filepath = default.RAW_DATA_DIR + os.sep + filename

if not os.path.exists(filepath):
    dl.arxiv_api( default.RAW_DATA_DIR, filename, max_results=nbrResults ) # TODO : CHANGE SO THAT NOT CONSTANTLY LOADING DATA
print(f'>> Using {filename} for training <<')

>> Using arxiv_10000.csv for training <<


### Format Dataset

Uses a custom dataset class, which is an iterable and callable structure that returns a sample from our dataset. Within this custom dataset, can determine all preprocessing.

In [5]:
# create dataset
dataset = cd.ArxivDataset(filepath)

### Training Tokenizer

Training of a custom tokenizer. Many options possible here, check the tokenizer training functions to try out various strategies. If he tokenizer for the dataset has already been trained, no need to run this again.

In [6]:
trainTokenizer: _ = tkn.train_custom_tokenizer(tknzerType, dataset, filename
                                            , default.TOK_DIR
                                            , vocabSize
                                            , **default.special_token_lst)






### Loading Tokenizer and Splitting Datasets

For some reason, torch tokenizers are not callable as trained. This is confusing, but c'est la vie! Instead, need to load it from file it was saved in using the PreTrainedTokenizerFast class (__call__) implemented in here. Once that's done, you can add this tokenizer as a transform to your dataset! Useful.

We also split the dataset here into training, testing and validation datasets.

In [7]:
tknzrFile = default.TOK_DIR + os.sep + filename + '_' + tknzerType + '.json'

# load PreTrainedTokenizerFast, for __call__. __call__ not implemented in
# the base Tokenizer class... that sounds silly, but it is what it is
tknzr = tkn.load_tokenizer(tknzrFile, **default.special_token_lst)

if vocabSize is None: vocabSize = tknzr.vocab_size

# set tknzr as the transform
dataset.set_transform( tknzr )

# separate dataset into train, test valid TODO : make into a function
fracTrain, fracTest, fracVal = ( 0.7, 0.2, 0.1)
trainTestVal = [ np.floor(fracTrain*len(dataset))\
                    , np.floor(fracTest*len(dataset))\
                    , len(dataset) - ( np.floor( fracTrain*len(dataset) ) +
                    np.floor( fracTest*len(dataset) ) )
                    ]

trainDataset, testDataset, valDataset =\
        torch.utils.data.random_split(dataset, [int(x) for x in trainTestVal]
                                , generator=torch.Generator().manual_seed(42) )


### Creating DataLoaders

Training is done on batches, so we need a way to extract groupings of the data in the appropriate format for our transformer model.
Note that for transformers which we are training, dataloaders outputs both src (x[:-1] and tgt ([1:]).
The collation of batches for different transformer models we have vary. For HuggingFace it's ( maxLen x batch_size ) whereas I think that the Annotated Transformer has ( batch_size x maxLen ).

NOTE : Do not use the tokenizer before the training if you use num_workers>0!
FastTokenizer does not play nicely with forking if you use it before the forking of your data:
https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning

In [8]:
# create dataloaders
# uses collate function to transform batch to correct dimensions
def collate_wrapper(batch):
    return CustomBatch(batch, dim=0, maxLenModel=maxLen, padValue=tknzr.get_vocab()["<pad>"])

# dataloader for training
trainDataLoader = DataLoader(trainDataset, batch_size=batch_size, shuffle=True
                                        , num_workers=2
                                        , collate_fn=collate_wrapper
                                        , pin_memory=True
                                        )
# dataloader for validation
valDataLoader = DataLoader(valDataset, batch_size=val_batch_size, shuffle=True
                                        , num_workers=2
                                        , collate_fn=collate_wrapper
                                        , pin_memory=True
                                        )

### Selecting model

Here we choose which model we shall use for training. For now, I've selected the black box Transformer from HuggingFace because the collate_fn I've written gives the correct input size force it... however this can easily be changed! 

In [10]:
# transformer from huggingface
# TODO : Change to the Annotated Transformer if I want
model = TransformerModel(vocabSize, emsize, nhead, nhid, nlayers, dropout).to(device)

# criterion
criterion = torch.nn.CrossEntropyLoss()#ignore_index=tknzr.get_vocab()["<pad>"])

# optimizer
# learning rate Matt used with Adam is 0.5
paramsAdam  = [{'params' : model.parameters(), 'lr' : 1e-3, 'betas' : (0.9, 0.999), 'eps' : 1e-08, 'weight_decay' : 0.0}]
paramsAdamW = [{'params' : model.parameters(), 'lr' : 5e-5, 'betas' : (0.9, 0.999), 'eps' : 1e-08, 'weight_decay' : 0.0}]
paramsSGD   = [{'params' : model.parameters(), 'lr' : 0.5, 'momentum' : 0.0, 'dampening' : 0.0, 'weight_decay' : 0.0}]

#optimizer = torch.optim.SGD( paramsSGD )
#optimizer = torch.optim.Adam( paramsAdam )
optimizer = torch.optim.AdamW( paramsAdamW )

# scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

### Training

Training loop!

In [12]:
if TRAIN:
    best_val_loss = float("inf")
    best_model = None
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train( model, maxLen, trainDataLoader, device, vocabSize, epoch, optimizer, scheduler, criterion)
        val_loss = evaluate(model, maxLen, valDataLoader, len(valDataset), device, vocabSize, criterion)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                         val_loss, math.exp(val_loss)))
                                         # Why is math.exp so large????
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = model

        scheduler.step()


    # save best model (two methods)
    modelFull = default.MODEL_DIR + os.sep + f'{filename}_epoch{epochs}.pth'
    modelWeights = default.MODEL_DIR + os.sep + f'{filename}_weights_epoch{epochs}.pth'
    modelFullBest = default.MODEL_DIR + os.sep + f'{filename}_epoch{epochs}_best.pth'
    modelWeightsBest = default.MODEL_DIR + os.sep + f'{filename}_weights_epoch{epochs}_best.pth'
    # approach 1: save model (class) entirely (uses pickle)
    torch.save(model, modelFull)
    torch.save(best_model, modelFullBest)
    # approach 2: save model weights
    torch.save(best_model.state_dict(), modelWeightsBest)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
torch.Size([15, 10]) torch.Size([15, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34

torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([15, 10]) torch.Size([15, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([29, 10]) torch.Size([29, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([14, 10]) torch.Size([14, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([21, 10]) torch.Size([21, 10])
| epoch   1 |   200/  700 batches | lr 0.00 | ms/batch 28.93 | loss   nan | ppl      nan
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([28, 10]) torch.Si

torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([24, 10]) torch.Size([24, 10])
| epoch   1 |   400/  700 batches | lr 0.00 | ms/batch 26.55 | loss   nan | 

torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([20, 10]) torch.Size([2

-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 19.85s | valid loss   nan | valid ppl      nan
-----------------------------------------------------------------------------------------
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([29, 10]) torch.Size([29, 10])
torch.Size([

torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([32, 10]) torch.Size([3

torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([24, 10]) torch.Size([2

torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([24, 10]) torch.Size([2

-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 19.55s | valid loss   nan | valid ppl      nan
-----------------------------------------------------------------------------------------
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([

torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([34, 10]) torch.Size([34, 10])
| epoch   3 |   200/  700 batches | lr 0.00 | ms/batch 28.29 | loss   nan | 

torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([29, 10]) torch.Size([29, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([23, 10]) torch.Size([2

torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([20, 10]) torch.Size([2

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([34

torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([20, 10]) torch.Size([20, 10])
| epoch   4 |   200/  700 batches | lr 0.00 | ms/batch 28.87 | loss   nan | ppl      nan
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([28, 10]) torch.Si

torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([15, 10]) torch.Size([15, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([14, 10]) torch.Size([14, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([29, 10]) torch.Size([29, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([29, 10]) torch.Size([29, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([34, 10]) torch.Size([3

torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([15, 10]) torch.Size([1

-----------------------------------------------------------------------------------------
| end of epoch   4 | time: 20.47s | valid loss   nan | valid ppl      nan
-----------------------------------------------------------------------------------------
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([

torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([22, 10]) torch.Size([2

torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([29, 10]) torch.Size([29, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([15, 10]) torch.Size([15, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([18, 10]) torch.Size([1

torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([15, 10]) torch.Size([15, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([22, 10]) torch.Size([2

-----------------------------------------------------------------------------------------
| end of epoch   5 | time: 20.03s | valid loss   nan | valid ppl      nan
-----------------------------------------------------------------------------------------
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([

torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([15, 10]) torch.Size([15, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([26, 10]) torch.Size([26, 10])
| epoch   6 |   200/  700 batches | lr 0.00 | ms/batch 27.45 | loss   nan | ppl      nan
torch.Size([19, 10]) torch.Si

torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([31, 10]) torch.Size([31, 10])
| epoch   6 |   400/  700 batches 

torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([15, 10]) torch.Size([15, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([26, 10]) torch.Size([2

length of src : 20
20
length of src : 24
24
length of src : 19
19
length of src : 29
29
length of src : 18
18
length of src : 19
19
length of src : 29
29
length of src : 18
18
length of src : 31
31
length of src : 18
18
length of src : 21
21
length of src : 23
23
length of src : 22
22
length of src : 18
18
-----------------------------------------------------------------------------------------
| end of epoch   6 | time: 19.78s | valid loss   nan | valid ppl      nan
-----------------------------------------------------------------------------------------
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid 

torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([28, 10]) torch.Size([2

torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([34, 10]) torch.Size([3

torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([34, 10]) torch.Size([3

length of src : 29
29
length of src : 19
19
length of src : 19
19
length of src : 29
29
length of src : 27
27
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
length of src : 18
18
length of src : 27
27
length of src : 30
30
length of src : 34
34
length of src : 33
33
length of src : 21
21
length of src : 20
20
length of src : 26
26
length of src : 20
20
length of src : 19
19
length of src : 21
21
length of src : 18
18
length of src : 28
28
length of src : 21
21
length of src : 19
19
length of src : 19
19
length of src : 24
24
length of src : 22
22
length of src : 28
28
length of src : 22
22
length of src : 34
34
length of src : 34
34
length of src : 25
25
length of src : 28
28
-------------------------------------------------------------------------------

torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([20, 10]) torch.Size([2

torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([19, 10]) torch.Size([1

torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([15, 10]) torch.Size([15, 10])
torch.Size([15, 10]) torch.Size([15, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([25, 10]) torch.Size([2

length of src : 28
28
length of src : 19
19
length of src : 20
20
length of src : 28
28
length of src : 21
21
length of src : 34
34
length of src : 23
23
length of src : 28
28
length of src : 31
31
length of src : 27
27
length of src : 27
27
length of src : 19
19
length of src : 20
20
length of src : 29
29
length of src : 21
21
length of src : 32
32
length of src : 18
18
length of src : 26
26
length of src : 27
27
length of src : 27
27
length of src : 20
20
length of src : 18
18
length of src : 18
18
length of src : 20
20
length of src : 34
34
length of src : 18
18
length of src : 20
20
length of src : 23
23
length of src : 20
20
length of src : 24
24
length of src : 22
22
length of src : 28
28
length of src : 34
34
length of src : 28
28
length of src : 25
25
length of src : 27
27
length of src : 34
34
length of src : 17
17
length of src : 34
34
length of src : 20
20
length of src : 27
27
length of src : 23
23
length of src : 33
33
length of src : 15
15
length of src : 17
17
length of 

torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([34, 10]) torch.Size([3

torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([29, 10]) torch.Size([29, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([29, 10]) torch.Size([29, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([29, 10]) torch.Size([29, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([16, 10]) torch.Size([1

torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([34, 10]) torch.Size([3

length of src : 23
23
length of src : 29
29
length of src : 26
26
length of src : 25
25
length of src : 21
21
length of src : 24
24
length of src : 34
34
length of src : 34
34
length of src : 28
28
length of src : 18
18
length of src : 25
25
length of src : 20
20
length of src : 27
27
length of src : 28
28
length of src : 19
19
length of src : 25
25
length of src : 18
18
length of src : 20
20
length of src : 27
27
length of src : 19
19
length of src : 17
17
length of src : 19
19
length of src : 22
22
length of src : 32
32
length of src : 21
21
length of src : 19
19
length of src : 21
21
length of src : 20
20
length of src : 29
29
length of src : 20
20
length of src : 23
23
length of src : 26
26
length of src : 23
23
length of src : 17
17
length of src : 32
32
length of src : 22
22
length of src : 23
23
length of src : 18
18
length of src : 34
34
length of src : 21
21
length of src : 23
23
length of src : 30
30
length of src : 22
22
length of src : 24
24
length of src : 34
34
length of 

torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([33, 10]) torch.Size([33, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([19, 10]) torch.Size([19, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([24, 10]) torch.Size([2

torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([31, 10]) torch.Size([31, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([16, 10]) torch.Size([16, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([28, 10]) torch.Size([2

torch.Size([18, 10]) torch.Size([18, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([21, 10]) torch.Size([21, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([27, 10]) torch.Size([27, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([30, 10]) torch.Size([30, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([20, 10]) torch.Size([20, 10])
torch.Size([24, 10]) torch.Size([24, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([22, 10]) torch.Size([22, 10])
torch.Size([23, 10]) torch.Size([23, 10])
torch.Size([25, 10]) torch.Size([25, 10])
torch.Size([34, 10]) torch.Size([34, 10])
torch.Size([28, 10]) torch.Size([28, 10])
torch.Size([26, 10]) torch.Size([26, 10])
torch.Size([17, 10]) torch.Size([17, 10])
torch.Size([18, 10]) torch.Size([1

length of src : 20
20
length of src : 20
20
length of src : 18
18
length of src : 27
27
length of src : 24
24
length of src : 15
15
length of src : 34
34
length of src : 19
19
length of src : 20
20
length of src : 21
21
length of src : 19
19
length of src : 34
34
length of src : 18
18
length of src : 16
16
length of src : 28
28
length of src : 29
29
length of src : 20
20
length of src : 34
34
length of src : 28
28
length of src : 23
23
length of src : 21
21
length of src : 20
20
length of src : 20
20
length of src : 19
19
length of src : 34
34
length of src : 29
29
length of src : 24
24
length of src : 25
25
length of src : 27
27
length of src : 25
25
length of src : 30
30
length of src : 19
19
length of src : 31
31
length of src : 26
26
length of src : 27
27
length of src : 23
23
length of src : 24
24
length of src : 24
24
length of src : 20
20
length of src : 30
30
length of src : 28
28
length of src : 28
28
length of src : 28
28
length of src : 18
18
length of src : 21
21
length of 

FileNotFoundError: [Errno 2] No such file or directory: '/home/jrothschild/Projects/Abstract-generator/bumbleBERT/data/models/arxiv_10000.csv_epoch10.pth'

### Text Generation

Here I've simply taken the code Matt uses to generate text.

In [None]:
if not TRAIN:
    customFilename = 'arxiv_10000'
    customEpochs = 10
    modelFull = default.MODEL_DIR + os.sep + f'{customFilename}_epoch{customEpochs}_best.pth'
    modelWeights = default.MODEL_DIR + os.sep + f'{customFilename}_weights_epoch{customEpochs}_best.pth'
    
    # approach 1: load model (class) entirely (uses pickle)
    modelFullLoad = torch.load(modelFull, map_location=device)

    # approach 2: load model weights, need to have some parameter or something 
    modelLoad = TransformerModel(vocabSize, emsize, nhead, nhid, nlayers, dropout).to(device)
    modelWeightsLoad = modelLoad.load_state_dict( torch.load(modelWeights) )

In [13]:
# inspect both models
#print('model_A info...\n', modelFullLoad)
#print('\nmodel_B info...\n', modelWeightsLoad)

#print('model_A == model_B:', modelFullLoad == modelWeightsLoad)
#model = modelFullLoad
# Text generation example

#model = modelLoad
prompt = 'Electrons are protons. Protons are also electrons.'
ngen = 100
decode_style = 'greedy'
model.to('cpu')
generated_text = gen_some_text(
    model, tknzr, 'cpu', maxLen, text_prompt=prompt, tokens_to_gen=ngen, vis=False,
    decode_style=decode_style)
print("Text prompt:\n", prompt)
print("Number of tokens to generate:", ngen)
print("Generated_text:\n", generated_text)

# TODO: alternative generation
# currently 'greedy method'
# see: https://huggingface.co/blog/how-to-generate

<s>  Electrons  are  protons .  Protons  are  also  electrons .
tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad>
<s>  Electrons  are  protons .  Protons  are  also  electrons . <pad>
tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad>
<s>  Electrons  are  protons .  Protons  are  also  electrons . <pad> <pad>
tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad>
<s>  Electrons  are  protons .  Protons  are  also  electrons . <pad> <pad> <pad>
tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad> <pad>
<s>  Electrons  are  protons .  Protons  are  also  electrons . <pad> <pad> <pad> <pad>
tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward

tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
 protons .  Protons  are  also  electrons . <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
.  Protons  are  also  electrons . <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pa

tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<pad> <pad

tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad

tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad

tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also electrons. <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
tensor([nan, nan, nan,  ..., nan, nan, nan], grad_fn=<SelectBackward0>)
Electrons are protons. Protons are also ele

In [None]:
tknzr.decode([107])