In [1]:
import torch as t
import pytorch_lightning as pl
import wandb

from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from fancy_einsum import einsum
from dataclasses import dataclass
from tqdm.notebook import tqdm_notebook
from einops import rearrange, reduce, repeat
from IPython.display import display
from typing import Callable
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar

import sys

# Make the sibling ../common directory importable (presumably where the local
# helper modules imported further down live -- confirm). The membership check
# keeps the cell idempotent: re-running it no longer appends a duplicate entry
# to sys.path each time.
if '../common' not in sys.path:
    sys.path.append('../common')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Display the version string of the Python interpreter backing this kernel.
sys.version

'3.10.8 (main, Nov  4 2022, 13:48:29) [GCC 11.2.0]'

In [3]:
import general_modules as cm
import transformer_modules as tm
from transformer_modules import TransformerConfig
from nlp_modules import WordsDataset, WordsTokenizer
import sample_methods as s


In [4]:
# Hyperparameters for the transformer trained in this notebook.
config = TransformerConfig(
    vocab_size=34543,  # NOTE(review): presumably the tokenizer's vocabulary size -- confirm
    max_seq_len=128,   # also handed to WordsTokenizer / WordsDataset when the data is built
    hidden_size=256,
    num_layers=12,
    num_heads=8,
    dropout=0.1,
)

In [5]:
# The tokenizer and the dataset both use the model's maximum sequence length.
tokenizer = WordsTokenizer(config.max_seq_len)
words_ds = WordsDataset(
    seq_len=config.max_seq_len,
    filename='100-0.txt',
    tokenizer=tokenizer,
    truncate=1.0,  # NOTE(review): presumably the fraction of the corpus to keep -- confirm
)

# Shuffled mini-batches for training, loaded by 8 worker processes.
trainloader = DataLoader(
    dataset=words_ds,
    batch_size=256,
    shuffle=True,
    num_workers=8,
)

In [6]:
class SHKTrainModule(pl.LightningModule):
    """Lightning wrapper that trains a `tm.DecoderOnlyTransformer` with cross-entropy loss.

    Each batch is an `(x, y)` pair. The model output for `x` is flattened from
    three axes to two and the targets `y` from two axes to one before being
    passed to `nn.CrossEntropyLoss`, so every position contributes one
    classification term to the (mean-reduced) loss.

    Args:
        config: Configuration object forwarded to `tm.DecoderOnlyTransformer`.
    """

    def __init__(self, config):
        super().__init__()
        # Key parameters
        # Records the __init__ arguments (i.e. `config`) as hyperparameters.
        # The ignore list names arguments this __init__ does not take, so it
        # has no effect here; it is kept unchanged for compatibility.
        self.save_hyperparameters(ignore=["model", "data"])
        self.model = tm.DecoderOnlyTransformer(config).train()

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run the wrapped transformer on `x` and return its output unchanged."""
        return self.model(x)

    def _compute_loss(self, batch):
        """Return the cross-entropy loss for one `(x, y)` batch.

        The rearrange patterns require 3-axis logits and 2-axis targets
        (presumably batch, sequence and vocabulary -- the axis names come from
        the patterns themselves). Both are flattened over the first two axes so
        that `nn.CrossEntropyLoss` sees one row of logits per target.
        """
        x, y = batch
        logits = self(x)
        logits = rearrange(logits, 'B S V -> (B S) V')
        y = rearrange(y, 'B S -> (B S)')
        return self.criterion(logits, y)

    def evaluate(self, batch, stage=None):
        """Compute the loss for `batch`, log it as `{stage}_loss`, and return it.

        Args:
            batch: An `(x, y)` pair as produced by the dataloader.
            stage: Optional prefix for the logged metric (e.g. "val", "test").
                When falsy, nothing is logged.

        Returns:
            The scalar loss tensor.
        """
        loss = self._compute_loss(batch)
        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute, log (as "train_loss") and return the loss for one training batch."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log "val_loss" for one validation batch and return it."""
        return self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log "test_loss" for one test batch and return it."""
        return self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters, using AdamW's default settings."""
        optimizer = t.optim.AdamW(self.parameters())
        return {"optimizer": optimizer}

In [7]:
# Build the Lightning training module around the transformer config.
pl_model = SHKTrainModule(config=config)

In [8]:
# Destination for the trained transformer weights (written after fit() below).
MODEL_FILENAME = "./w1d3_transformer_shakespeare.pt"

# Comment out if not using wandb
wandb_logger = WandbLogger(
    project="shk-transformer",
    save_dir="training/logs/",
    log_model=True)
#wandb_logger.watch(pl_model, log="all")

# Single-epoch run on one GPU; the progress bar refreshes every 10 batches.
trainer = pl.Trainer(
    max_epochs=1,
    accelerator='gpu',
    devices=1,
    logger=wandb_logger, # Comment out if not using wandb
    default_root_dir="training/checkpoints/",
    callbacks=[TQDMProgressBar(refresh_rate=10)])
trainer.fit(pl_model, train_dataloaders=trainloader)

# Persist the trained weights locally. Previously MODEL_FILENAME was defined
# but nothing was ever written to it, so the result of a multi-hour run was
# only recoverable from Lightning/wandb checkpoints. Only the inner
# transformer's state dict is saved, so it can be reloaded without Lightning.
print(f"Saving model to: {MODEL_FILENAME}")
t.save(pl_model.model.state_dict(), MODEL_FILENAME)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcurt-tigges[0m ([33marena-ldn[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | model     | DecoderOnlyTransformer | 18.3 M
1 | criterion | CrossEntropyLoss       | 0     
-----------------------------------------------------
18.3 M    Trainable params
0         Non-trainable params
18.3 M    Total params
73.283    Total estimated model params size (MB)


Epoch 0: 100%|██████████| 7765/7765 [2:56:36<00:00,  1.36s/it, loss=1.89, v_num=hwj4]  

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 7765/7765 [2:56:36<00:00,  1.36s/it, loss=1.89, v_num=hwj4]


In [11]:
# `s` (sample_methods) is already imported with the other local modules near
# the top of the notebook, so it is not re-imported here.
initial_text = "turn down for what"

# Generate with dropout disabled and without tracking gradients. The sampling
# helper may already do this internally; making it explicit here is harmless
# and guarantees the module is not left in training mode after fit().
pl_model.eval()
with t.no_grad():
    text_output = s.sample_tokens(pl_model, tokenizer, initial_text, max_tokens_generated=100, temperature=1.0, top_k=10)

print(text_output)

# Sample output from an earlier run, kept for comparison:
# turn down for what you do you think,
# That take the last, of many, which is so much I
# As this blows along than my life thou say’st, which makes thy hand,
# Thou wilt be given, or more
# Entitled in thy great world’s fresh blood will,
# To answer th’ alluring countenance, beauty

turn down for what thou art;
And for my good will I keep the throne
Through yon hideous sea.

KING RICHARD.
Uncle, how wilt thou do for a man?
Thy life is better than mine is in words;
Thy tongue is better than thy life is dear,
Therefore my heart, thy life for a
