In [1]:
import random
random.seed(42)
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.autograd import Variable as V
import torchtext
from torchtext import data
from pytorch_lightning_lm.data_module import QuotesDataModule
from pytorch_lightning_lm.metrics import Perplexity
from pytorch_lightning.loggers import WandbLogger
from argparse import ArgumentParser



## Transformers

In [14]:
from pytorch_lightning_lm.model import TransformerModel

In [8]:
parser = ArgumentParser()
parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")

# add PROGRAM level args
parser.add_argument('--project-name', type=str, default='transformer_lms')
parser.add_argument('--experiment-tag', type=str, default='Transformer_LM')
parser.add_argument('--use-cuda', type=bool, default=True)
parser.add_argument('--use-wandb', type=bool, default=True)
parser.add_argument('--log-gradients', type=bool, default=True)
parser.add_argument('--unk-cutoff', type=int, default=1)

# add model specific args
# parser = LitModel.add_model_specific_args(parser)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--accumulate-grad-batches', type=int, default=4)
parser.add_argument('--bptt', type=int, default=16)
parser.add_argument('--nhid', type=int, default=64)
parser.add_argument('--nhead', type=int, default=12)
parser.add_argument('--nlayers', type=int, default=2)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--weight-decay', type=float, default=0)
parser.add_argument('--pretrained-vector', type=str, default="fasttext.simple.300d")

# add all the available trainer options to argparse
parser.add_argument('--max_epochs', type=int, default=25)
parser.add_argument('--fast_dev_run', type=bool, default=False)
# ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
# parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()

In [9]:
# Run on the GPU only when CUDA is available and --use-cuda is enabled; otherwise use the CPU.
device = torch.device('cuda' if (torch.cuda.is_available() and args.use_cuda) else 'cpu')

# The run name is the experiment tag followed by the key hyperparameters, joined with underscores.
experiment_name = "_".join(
    str(part)
    for part in (
        args.experiment_tag,
        args.batch_size,
        args.bptt,
        args.nhead,
        args.nhid,
        args.nlayers,
    )
)
print(experiment_name)

Transformer_LM_64_16_12_64_2


In [10]:
# Paths of the three dataset splits, keyed by the data module's keyword names.
split_files = dict(
    train_file="data/quotesdb/funny_quotes.train.txt",
    valid_file="data/quotesdb/funny_quotes.val.txt",
    test_file="data/quotesdb/funny_quotes.test.txt",
)

# Build the data module from the split files and the parsed command-line options.
dm = QuotesDataModule(
    tokenizer=None,
    unk_limit=args.unk_cutoff,
    batch_size=args.batch_size,
    bptt=args.bptt,
    pretrained_vectors=args.pretrained_vector,
    **split_files,
)


Field class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.


Example class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.



In [15]:
# --- Vocabulary and pretrained embeddings ------------------------------------
vocab = dm.vocab
weight_matrix = vocab.vectors
# The two dimensions of the embedding matrix are passed to the model as ntoken and ninp.
ntoken, ninp = weight_matrix.shape

pad_idx = vocab.stoi["<pad>"]

# --- Model -------------------------------------------------------------------
ppl = Perplexity(pad_idx)
model = TransformerModel(
    ntoken=ntoken,
    ninp=ninp,
    nhead=args.nhead,
    nhid=args.nhid,
    nlayers=args.nlayers,
    batch_size=args.batch_size,
    dropout=args.dropout,
    device_type=device.type,
    weight_decay=args.weight_decay,
    lr=args.lr,
    pretrained_vectors=weight_matrix,
    metric=ppl,
    log_grad_norm=True,
)

# --- Logger ------------------------------------------------------------------
if args.fast_dev_run:
    # Decide this first so a smoke-test run never creates a W&B run that would
    # immediately be thrown away. False disables logging in the Trainer.
    logger = False
elif args.use_wandb:
    wandb_logger = WandbLogger(name=experiment_name, project=args.project_name)
    if args.log_gradients:
        wandb_logger.watch(model, log='gradients', log_freq=100)
    logger = wandb_logger
    # Record the settings that are not passed to the model constructor above.
    logger.log_hyperparams({
        "bptt": args.bptt,
        "pretrained_vector": args.pretrained_vector,
        "unk_limit": args.unk_cutoff,
    })
else:
    # True asks the Trainer to use its default logger.
    logger = True

# --- Trainer -----------------------------------------------------------------
# No `monitor` is given, so the monitored quantity is the library default.
early_stop_callback = pl.callbacks.EarlyStopping(
    min_delta=0.01,
    patience=2,
    verbose=False,
    mode='min',
)

trainer = pl.Trainer(
    gpus=1 if device.type == 'cuda' else 0,
    max_epochs=args.max_epochs,
    min_epochs=5,
    logger=logger,
    # The learning-rate finder is skipped for fast_dev_run smoke tests.
    auto_lr_find=not args.fast_dev_run,
    fast_dev_run=args.fast_dev_run,
    accumulate_grad_batches=args.accumulate_grad_batches,
    early_stop_callback=early_stop_callback,
)

# --- Train, save, evaluate ---------------------------------------------------
trainer.fit(model, datamodule=dm)
if not args.fast_dev_run:
    # NOTE(review): both writes assume the models/ directory already exists.
    trainer.save_checkpoint(f"models/{experiment_name}.ckpt")
    torch.save(dm.vocab, f"models/{experiment_name}_vocab.sav")
    trainer.auto_lr_find = False
    test_eval = trainer.test(model, datamodule=dm)
    test_metrics = {
        "test_ppl": test_eval[0]['test_ppl'],
        "test_loss": test_eval[0]['test_loss'],
    }
    # Log through the Trainer's logger object: the local `logger` is the bool
    # True when W&B is disabled, and a bool has no log_metrics method.
    active_logger = trainer.logger
    if active_logger is not None:
        active_logger.log_metrics(test_metrics)

wandb: Wandb version 0.9.6 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type               | Params
-----------------------------------------------------------
0 | criterion           | CrossEntropyLoss   | 0     
1 | metric              | Perplexity         | 0     
2 | pos_encoder         | PositionalEncoding | 0     
3 | transformer_encoder | TransformerEncoder | 802 K 
4 | encoder             | Embedding          | 13 M  
5 | decoder             | Linear             | 13 M  
6 | drop                | Dropout            | 0     


HBox(children=(FloatProgress(value=0.0, description='Finding best initial lr', style=ProgressStyle(description…

Learning rate set to 0.002754228703338169

  | Name                | Type               | Params
-----------------------------------------------------------
0 | criterion           | CrossEntropyLoss   | 0     
1 | metric              | Perplexity         | 0     
2 | pos_encoder         | PositionalEncoding | 0     
3 | transformer_encoder | TransformerEncoder | 802 K 
4 | encoder             | Embedding          | 13 M  
5 | decoder             | Linear             | 13 M  
6 | drop                | Dropout            | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

AttributeError: 'Run' object has no attribute 'Image'