In [1]:
from datasets import load_dataset, DatasetDict

In [2]:
from imdb_dataset import ImdbDataset

In [3]:
from transformers import RobertaTokenizer, PreTrainedTokenizer

In [4]:
from torch.utils.data.dataloader import DataLoader

In [5]:
import torch
from torch import Tensor

In [6]:
from torch.utils.data.dataset import Dataset

In [7]:
import pytorch_lightning as pl

In [8]:
import wandb
from pytorch_lightning.loggers import WandbLogger

In [9]:
run = wandb.init(
    project="demo",
    name="imdb-demo",
    tags=["demo"],
    config={
        "batch_size": 8,
        "lr": 0.0001,
        "max_epochs": 5,
        "dropout": 0.1,
        "gemma": 0.99,
        "lr_step": 1,
        "embed_dim": 512,
        "num_heads": 8,
        "num_layers": 6,
        "max_len": 2048,
        "hidden_dim": 2048,
        "num_workers": 8,
    }
)
logger = WandbLogger(run)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011120540277746234, max=1.0…

In [10]:
tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-large", cache_dir="./data/")

In [11]:
wandb.define_metric("second_per_batch", summary="mean")
wandb.define_metric("val_acc", summary="max")

<wandb.sdk.wandb_metric.Metric at 0x16b6d6750>

In [12]:
import torch.nn as nn
import numpy as np

In [13]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [14]:
class ImdbModel(pl.LightningModule):
    
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        embed_dim = 512,
        hidden_dim = 2048,
        max_len = 1024,
        num_heads = 8,
        num_layers = 8,
        dropout = 0.1,
    ):
        super(ImdbModel, self).__init__()
        self.tokenizer = tokenizer
        self.embed = nn.Sequential(
            nn.Embedding(tokenizer.vocab_size, embed_dim),
            PositionalEncoding(embed_dim, dropout, max_len),
        )
        self.encs = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dropout=dropout,
            )
            for _ in range(num_layers)
        )
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2),
        )
        
        self.loss_fn = nn.CrossEntropyLoss()
    
    def forward(self, x):
        x = self.embed(x)
        for enc in self.encs:
            x = enc(x)
        x = x.mean(dim=1)
        return self.classifier(x)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=wandb.config.lr)
        sched = torch.optim.lr_scheduler.StepLR(optim, step_size=wandb.config.lr_step, gamma=wandb.config.gemma)
        return [optim], [sched]

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("val_loss", loss)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_acc", acc)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("test_loss", loss)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("test_acc", acc)
        return loss

In [15]:
from typing import Any
from pytorch_lightning import LightningModule, Trainer
from time import time

class SpeedCounterCallback(pl.Callback):
    
    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        self.batch_start_time = time()
    
    def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None:
        self.batch_end_time = time()
        wandb.log({"second_per_batch": self.batch_end_time - self.batch_start_time})

In [16]:
if __name__ == "__main__":
    trainer = pl.Trainer(
        max_epochs=wandb.config.max_epochs,
        logger=logger,
        callbacks=[
            SpeedCounterCallback(),
            pl.callbacks.LearningRateMonitor(logging_interval="step"),
        ],
        accelerator="gpu",
        strategy="deepspeed_stage_1"
    )
    model = ImdbModel(
        tokenizer,
        run.config.embed_dim,
        run.config.hidden_dim,
        run.config.max_len,
        run.config.num_heads,
        run.config.num_layers,
        run.config.dropout,
    )
    from imdb_dataset import ImdbDataModule
    data_module = ImdbDataModule(
        tokenizer,
        run.config.batch_size,
        run.config.num_workers,
        run.config.max_len
    )
    trainer.fit(model, data_module)

ValueError: You set `strategy=deepspeed_stage_1` but strategies from the DDP family are not supported on the MPS accelerator. Either explicitly set `accelerator='cpu'` or change the strategy.