In [1]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Looking in indexes: https://download.pytorch.org/whl/cu118


In [2]:
!pip3 install datasets wandb einops torchdata==0.7.1 portalocker==2.10.0



Models

LSTM

In [3]:
from functools import cached_property

import torch
from torch import nn


class CustomLSTMModel(nn.Module):
    """Multi-layer LSTM over token ids, built from hand-written ``LSTMLayer`` cells.

    The recurrence works on column-major tensors of shape
    ``(features, batch_size)``; the public input and output are batch-first.

    Args:
        vocab_size: Number of rows of the embedding table.
        embed_dim: Width of the token embeddings (input size of the first layer).
        hidden_dim: Width of the hidden and cell state of every layer.
        num_layers: Number of stacked ``LSTMLayer`` cells.
        output_dim: Number of rows of the read-out matrix ``Wy``.
        finetune: If True, ``forward`` reads out only the last time step (one
            prediction per sequence); otherwise it reads out every time step
            (one prediction per token).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, output_dim, finetune=True):
        super(CustomLSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.hidden_dim = hidden_dim
        self.input_dim = embed_dim
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.finetune = finetune
        self.output_dim = output_dim

        # The first layer consumes embeddings, every later layer consumes the
        # hidden state of the layer below it.
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            input_dim = embed_dim if i == 0 else hidden_dim
            self.layers.append(LSTMLayer(input_dim, hidden_dim))

        # Read-out weights and bias (bias starts at zero, weights use Xavier).
        self.Wy = nn.Parameter(torch.empty(output_dim, hidden_dim))
        self.by = nn.Parameter(torch.zeros(output_dim, 1))
        nn.init.xavier_uniform_(self.Wy)

    def _project(self, h_top):
        """Map the top layer's hidden state (hidden_dim, batch) to (batch, rows of Wy)."""
        return torch.matmul(h_top.permute(1, 0), self.Wy.t()) + self.by.t()

    def forward(self, texts):
        """Run the stacked LSTM over a batch of token ids.

        Args:
            texts: Integer tensor of shape (batch_size, seq_len).

        Returns:
            ``finetune=True``: tensor of shape (batch_size, rows of Wy), with
            dim 1 squeezed away when it has size 1.
            ``finetune=False``: tensor of shape (batch_size, seq_len, rows of Wy).
        """
        batch_size, seq_len = texts.size()
        embedded = self.embedding(texts).permute(1, 2, 0)  # (seq_len, embed_dim, batch_size)

        # Zero initial states, created with the embedding's dtype and device.
        h = [embedded.new_zeros(self.hidden_dim, batch_size) for _ in range(self.num_layers)]
        c = [embedded.new_zeros(self.hidden_dim, batch_size) for _ in range(self.num_layers)]

        per_step = []
        for t in range(seq_len):
            x = embedded[t]  # (embed_dim, batch_size)
            for i, layer in enumerate(self.layers):
                h[i], c[i] = layer(x, h[i], c[i])
                x = h[i]
            if not self.finetune:
                per_step.append(self._project(h[-1]))  # (batch_size, rows of Wy)

        if self.finetune:
            return self._project(h[-1]).squeeze(1)

        # The output width follows Wy rather than vocab_size, so the per-token
        # branch also works when output_dim != vocab_size.
        if not per_step:
            return embedded.new_zeros(batch_size, 0, self.Wy.size(0))
        return torch.stack(per_step, dim=1)

    @cached_property
    def count_parameters(self):
        # cached_property: computed on first access and then stored on the
        # instance, so it is not refreshed if parameters are replaced later.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class LSTMLayer(nn.Module):
    """One LSTM cell step on column-major tensors of shape (features, batch_size).

    Each gate owns a weight of shape (hidden_dim, input_dim + hidden_dim),
    applied to the hidden state stacked on top of the input, and a bias of
    shape (hidden_dim, 1).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        stacked_dim = input_dim + hidden_dim
        # Gates in registration order: forget, input, output, candidate.
        # Weights are Xavier-initialised, biases start at zero.
        for gate in ("f", "i", "o", "c"):
            weight = nn.Parameter(torch.empty(hidden_dim, stacked_dim))
            nn.init.xavier_uniform_(weight)
            setattr(self, "W" + gate, weight)
            setattr(self, "b" + gate, nn.Parameter(torch.zeros(hidden_dim, 1)))

    def sigmoid(self, x):
        """Element-wise logistic function."""
        return torch.sigmoid(x)

    def tanh(self, x):
        """Element-wise hyperbolic tangent."""
        return torch.tanh(x)

    def forward(self, x, h, c):
        """Advance the cell by one time step.

        Args:
            x: Input of shape (input_dim, batch_size).
            h: Previous hidden state of shape (hidden_dim, batch_size).
            c: Previous cell state of shape (hidden_dim, batch_size).

        Returns:
            Tuple ``(h, c)`` with the new hidden and cell state, both of
            shape (hidden_dim, batch_size).
        """
        stacked = torch.cat((h, x), dim=0)  # (hidden_dim + input_dim, batch_size)

        def affine(weight, bias):
            # (hidden_dim, batch_size); the bias broadcasts over the batch axis.
            return torch.matmul(weight, stacked) + bias

        forget_gate = self.sigmoid(affine(self.Wf, self.bf))
        input_gate = self.sigmoid(affine(self.Wi, self.bi))
        candidate = self.tanh(affine(self.Wc, self.bc))
        output_gate = self.sigmoid(affine(self.Wo, self.bo))

        next_c = forget_gate * c + input_gate * candidate
        next_h = output_gate * self.tanh(next_c)
        return next_h, next_c


S4

In [4]:
import math
from functools import cached_property

import torch
from einops import rearrange, repeat
from torch import nn


class DropoutNd(nn.Module):
    """Dropout whose mask can be shared ("tied") across all trailing axes."""

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        p: probability of zeroing an element; must satisfy 0 <= p < 1
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        transposed: if True the input layout is (batch, dim, lengths...);
            otherwise the feature axis is last and is moved to position 1
            for the duration of the masking
        """
        super().__init__()
        if p >= 1 or p < 0:
            raise ValueError("dropout probability has to be in [0, 1), but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # Not used by forward(), which samples its mask with torch.rand.
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        # Evaluation mode is the identity.
        if not self.training:
            return X

        feature_last = not self.transposed
        if feature_last:
            X = rearrange(X, 'b ... d -> b d ...')

        if self.tie:
            # One keep/drop decision per (batch, dim), broadcast over the rest.
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2)
        else:
            mask_shape = X.shape

        # The mask is sampled directly on X's device.
        keep = torch.rand(*mask_shape, device=X.device) < 1. - self.p
        # Rescale the kept values by 1 / (1 - p).
        X = X * keep * (1.0 / (1 - self.p))

        if feature_last:
            X = rearrange(X, 'b d ... -> b ... d')
        return X


class S4DKernel(nn.Module):
    """Generate convolution kernel from diagonal SSM parameters."""

    def __init__(self, hidden_dim, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        """
        hidden_dim: number of independent channels H
        N: state size; N // 2 complex values are stored per channel
        dt_min, dt_max: range from which the per-channel step size is drawn
            (uniformly in log space)
        lr: forwarded to ``register`` for log_dt, log_A_real and A_imag
        """
        super().__init__()
        H = hidden_dim
        half_state = N // 2

        # Step sizes: uniform in [log(dt_min), log(dt_max)).
        log_lo, log_hi = math.log(dt_min), math.log(dt_max)
        log_dt = torch.rand(H) * (log_hi - log_lo) + log_lo

        # Complex C is stored as a real tensor with a trailing (re, im) axis.
        C = torch.randn(H, half_state, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(C))
        self.register("log_dt", log_dt, lr)

        # Diagonal A starts at -0.5 + i * pi * n for n = 0 .. N/2 - 1.
        log_A_real = torch.log(0.5 * torch.ones(H, half_state))
        A_imag = math.pi * repeat(torch.arange(half_state), 'n -> h n', h=H)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """
        returns: real tensor of shape (H, L), the kernel for each channel
        """
        # Materialize parameters
        step = torch.exp(self.log_dt)  # (H)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag  # (H N)
        dtA = A * step.unsqueeze(-1)  # (H N)

        # Scale C by (exp(dt * A) - 1) / A.
        C = torch.view_as_complex(self.C) * (torch.exp(dtA) - 1.) / A  # (H N)

        # Vandermonde multiplication: exp(dt * A * l) for l = 0 .. L - 1.
        positions = torch.arange(L, device=A.device)
        powers = torch.exp(dtA.unsqueeze(-1) * positions)  # (H N L)
        return 2 * torch.einsum('hn, hnl -> hl', C, powers).real

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay"""
        # lr == 0.0 means "frozen": keep the tensor as a buffer, not a parameter.
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))

        # Per-parameter optimizer hints are attached as a plain attribute.
        optim = {"weight_decay": 0.0}
        if lr is not None:
            optim["lr"] = lr
        getattr(self, name)._optim = optim


class S4D(nn.Module):
    """Diagonal state-space layer: long convolution with an ``S4DKernel`` plus a
    per-channel skip term, followed by GELU, dropout and a gated 1x1 convolution.
    """

    def __init__(self, hidden_dim, d_state=64, dropout=0.0, transposed=True, **kernel_args):
        """
        hidden_dim: number of channels H
        d_state: state size passed to the kernel as N
        dropout: dropout probability; 0.0 disables dropout
        transposed: if True inputs are (B, H, L), otherwise (B, L, H)
        kernel_args: forwarded unchanged to ``S4DKernel``
        """
        super().__init__()

        self.h = hidden_dim
        self.n = d_state
        self.output_dim = self.h
        self.transposed = transposed

        # Per-channel weight of the skip (D) term.
        self.D = nn.Parameter(torch.randn(self.h))

        # SSM Kernel
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)

        # Pointwise non-linearity and (optional) dropout.
        self.activation = nn.GELU()
        if dropout > 0.0:
            self.dropout = DropoutNd(dropout)
        else:
            self.dropout = nn.Identity()

        # Position-wise output transform to mix features: expand to 2H
        # channels, then let GLU gate them back down to H.
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2 * self.h, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs):  # absorbs return_output and transformer src mask
        """ Input and output shape (B, H, L) """
        length_first = not self.transposed
        if length_first:
            u = u.transpose(-1, -2)
        seq_len = u.size(-1)
        # FFT length 2L, so the circular convolution does not wrap around.
        fft_len = 2 * seq_len

        # Compute SSM Kernel
        k = self.kernel(L=seq_len)  # (H L)

        # Convolution via the frequency domain.
        k_f = torch.fft.rfft(k, n=fft_len)
        u_f = torch.fft.rfft(u, n=fft_len)
        y = torch.fft.irfft(u_f * k_f, n=fft_len)[..., :seq_len]  # (B H L)

        # D term of the state space equation - essentially a skip connection.
        y = y + u * self.D.unsqueeze(-1)

        y = self.output_linear(self.dropout(self.activation(y)))
        if length_first:
            y = y.transpose(-1, -2)
        # The second element is a placeholder; no state is returned.
        return y, None


class S4Model(nn.Module):
    """Token-level sequence model: embedding -> linear encoder -> residual S4D
    blocks -> linear decoder.

    Args:
        embed_dim: Width of the token embeddings.
        vocab_size: Number of rows of the embedding table.
        output_dim: Output width of the decoder.
        hidden_dim: Number of channels carried through the S4D blocks.
        num_layers: Number of residual S4D blocks.
        dropout: Dropout probability used inside and after every block.
        lr: Forwarded to each ``S4D`` block as a kernel argument.
        prenorm: If True apply LayerNorm before each block, otherwise after
            the residual connection.
        finetune: If True ``forward`` mean-pools over the sequence and returns
            one prediction per sequence; otherwise one prediction per token.
    """

    def __init__(
            self,
            embed_dim,
            vocab_size,
            output_dim,
            hidden_dim=256,
            num_layers=4,
            dropout=0.1,
            lr=0.001,
            prenorm=False,
            finetune=True
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.prenorm = prenorm

        # Linear encoder lifting token embeddings to the hidden width.
        self.encoder = nn.Linear(embed_dim, hidden_dim)

        # Stack S4 layers as residual blocks
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(num_layers):
            self.s4_layers.append(
                S4D(hidden_dim, dropout=dropout, transposed=True, lr=lr)
            )
            self.norms.append(nn.LayerNorm(hidden_dim))
            self.dropouts.append(nn.Dropout(dropout))

        # Linear decoder
        self.decoder = nn.Linear(hidden_dim, output_dim)
        self.finetune = finetune

    def forward(self, x):
        """Run the model on a batch of token ids.

        Args:
            x: Integer tensor of token ids with shape (B, L).

        Returns:
            ``finetune=True``: tensor of shape (B, output_dim), with the last
            axis squeezed away when ``output_dim == 1``.
            ``finetune=False``: batch-first tensor of shape (B, L, output_dim)
            (last axis squeezed when ``output_dim == 1``), so that flattening
            it lines up with batch-major flattened targets.
        """
        x = self.embedding(x)  # (B, L) -> (B, L, embed_dim)

        x = self.encoder(x)  # (B, L, embed_dim) -> (B, L, hidden_dim)

        x = x.transpose(-1, -2)  # (B, L, hidden_dim) -> (B, hidden_dim, L)
        for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts):
            # Each iteration of this loop will map (B, hidden_dim, L) -> (B, hidden_dim, L)

            z = x
            if self.prenorm:
                # Prenorm: LayerNorm acts on the last axis, hence the transposes.
                z = norm(z.transpose(-1, -2)).transpose(-1, -2)

            # Apply S4 block: we ignore the state input and output
            z, _ = layer(z)

            # Dropout on the output of the S4 block
            z = dropout(z)

            # Residual connection
            x = z + x

            if not self.prenorm:
                # Postnorm
                x = norm(x.transpose(-1, -2)).transpose(-1, -2)

        x = x.transpose(-1, -2)  # (B, hidden_dim, L) -> (B, L, hidden_dim)

        if self.finetune:
            # Pooling: average pooling over the sequence length
            x = x.mean(dim=1)
            # Decode the outputs
            x = self.decoder(x)  # (B, hidden_dim) -> (B, output_dim)
            return x.squeeze(-1)

        # Per-token predictions stay batch-first. The previous
        # permute(1, 0, 2) returned (L, B, output_dim), which no longer lined
        # up with batch-major targets once flattened for batch sizes > 1.
        x = self.decoder(x)  # (B, L, hidden_dim) -> (B, L, output_dim)
        return x.squeeze(-1)

    @cached_property
    def count_parameters(self):
        # cached_property: computed on first access and then stored on the
        # instance, so it is not refreshed if parameters are replaced later.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


Transformer

In [5]:
import math
from functools import cached_property

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table added to a sequence-first input.

    The table is stored in the buffer ``pe`` with shape
    (max_len, 1, embed_dim): even feature columns hold sines, odd columns
    hold cosines. ``forward`` slices the table by ``x.size(0)``, so axis 0 of
    the input is treated as the position axis and the size-1 middle axis
    broadcasts over axis 1.
    """

    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        self.embed_dim = embed_dim

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # One frequency per sin/cos pair: 10000 ** (-2i / embed_dim).
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        angles = positions * div_term  # (max_len, ceil(embed_dim / 2))

        table = torch.zeros(max_len, embed_dim)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Buffer, not parameter: saved in the state dict but never trained.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the first ``x.size(0)`` rows of the table."""
        return x + self.pe[:x.size(0)]


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention on batch-first inputs.

    Args:
        embed_dim: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # A single projection produces queries, keys and values for all heads.
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.o_proj = nn.Linear(embed_dim, embed_dim)
        self.scale = 1 / math.sqrt(self.head_dim)

    def forward(self, x, mask=None):
        """Apply self-attention.

        Args:
            x: Tensor of shape (batch_size, seq_length, embed_dim).
            mask: Optional tensor broadcastable to
                (batch_size, num_heads, seq_length, seq_length). Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch_size, seq_length, embed_dim).
        """
        batch_size, seq_length, embed_dim = x.size()

        qkv = self.qkv_proj(x)  # (batch_size, seq_length, embed_dim * 3)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_length, 3 * head_dim)

        q, k, v = qkv.chunk(3, dim=-1)  # each will be (batch_size, num_heads, seq_length, head_dim)

        # (batch_size, num_heads, seq_length, seq_length)
        attn_scores = torch.einsum('bnqd,bnkd->bnqk', q, k) * self.scale

        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_probs = nn.functional.softmax(attn_scores, dim=-1)  # (batch_size, num_heads, seq_length, seq_length)

        # The key axis of the probabilities must be contracted against the
        # matching position axis of the values. Using two different index
        # letters would sum both operands independently and return the plain
        # sum of all values, ignoring the attention weights and the mask.
        attn_output = torch.einsum('bnqk,bnkd->bnqd', attn_probs, v)  # (batch_size, num_heads, seq_length, head_dim)

        # Move heads next to head_dim before merging them. Reshaping straight
        # from (batch, heads, seq, head_dim) would interleave positions and
        # heads.
        attn_output = attn_output.permute(0, 2, 1, 3)  # (batch_size, seq_length, num_heads, head_dim)
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim)  # (batch_size, seq_length, embed_dim)

        output = self.o_proj(attn_output)  # (batch_size, seq_length, embed_dim)
        return output


class TransformerBlock(nn.Module):
    """Post-norm transformer layer.

    Self-attention and a two-layer ReLU feed-forward network are each followed
    by dropout, a residual connection and LayerNorm, in that order.
    """

    def __init__(self, embed_dim, num_heads, hidden_dim, dropout):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        self.layernorm1 = nn.LayerNorm(embed_dim)

        # Feed-forward network: widen to hidden_dim, ReLU, project back.
        widen = nn.Linear(embed_dim, hidden_dim)
        narrow = nn.Linear(hidden_dim, embed_dim)
        self.feedforward = nn.Sequential(widen, nn.ReLU(), narrow)

        self.layernorm2 = nn.LayerNorm(embed_dim)
        # The same dropout module is used after both sub-layers.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the block; ``mask`` is passed through to the attention module."""
        attended = self.dropout(self.attention(x, mask))
        x = self.layernorm1(x + attended)

        transformed = self.dropout(self.feedforward(x))
        return self.layernorm2(x + transformed)


class CustomTransformerModel(nn.Module):
    """Causal transformer encoder over token ids with a linear output head.

    Args:
        vocab_size: Number of rows of the embedding table.
        embed_dim: Width of the token embeddings and of every block.
        num_heads: Attention heads per block.
        num_layers: Number of ``TransformerBlock`` layers.
        hidden_dim: Width of the feed-forward layer inside each block.
        output_dim: Output width of the final linear layer.
        dropout: Dropout probability inside each block.
        finetune: If True ``forward`` mean-pools over the sequence and returns
            one prediction per sequence; otherwise one prediction per token.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, hidden_dim, output_dim, dropout=0.1,
                 finetune=True):
        super(CustomTransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, hidden_dim, dropout)
            for _ in range(num_layers)])
        self.fc_out = nn.Linear(embed_dim, output_dim)
        self.finetune = finetune

    def generate_square_subsequent_mask(self, sz):
        """Return an int32 (sz, sz) lower-triangular mask.

        Entry [q, k] is 1 when query position q may attend to key position k
        (k <= q) and 0 otherwise.
        """
        return torch.tril(torch.ones(sz, sz, dtype=torch.int32))

    def forward(self, x):
        """Run the model on a batch of token ids.

        Args:
            x: Integer tensor of token ids with shape (batch_size, seq_len).
                seq_len must not exceed the length of the positional table.

        Returns:
            ``finetune=True``: (batch_size, output_dim), with dim 1 squeezed
            away when it has size 1.
            ``finetune=False``: (batch_size, seq_len, output_dim), with dim 2
            squeezed away when it has size 1.
        """
        _, seq_len = x.size()
        mask = self.generate_square_subsequent_mask(seq_len).to(x.device)
        x = self.embedding(x)  # (batch_size, seq_len, embed_dim)

        # The positional encoding indexes its table by axis 0 of its input.
        # Hand it a sequence-first view so every token receives the encoding
        # of its position in the sequence rather than of its row in the batch.
        x = self.positional_encoding(x.transpose(0, 1)).transpose(0, 1)

        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, mask)
        if self.finetune:
            x = x.mean(dim=1)  # average over the sequence axis
            x = self.fc_out(x)
            return x.squeeze(1)
        else:
            x = self.fc_out(x)
            return x.squeeze(2)

    @cached_property
    def count_parameters(self):
        # cached_property: computed on first access and then stored on the
        # instance, so it is not refreshed if parameters are replaced later.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


Logger

In [6]:
import logging
import sys

# Configure the root logger once for the whole notebook: records at INFO and
# above are written to stdout as "<timestamp> - <logger> - <level> - <message>".
logging.basicConfig(
    handlers=[
        logging.StreamHandler(sys.stdout)  # Ensure logs are directed to stdout
    ],
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def setup_logger(name=__name__):
    """Return the ``logging`` logger registered under ``name``.

    The default is the ``__name__`` of the module that defines this function.
    """
    return logging.getLogger(name)


Datasets

IMDB Dataset

In [7]:
import torch
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.datasets import IMDB


class IMDBDataset:
    """Static helpers that tokenize the torchtext IMDB split and batch it.

    The tokenizer and its trainer are class attributes created when the class
    body runs, so all helpers share the same tokenizer object.
    """

    # Byte-level BPE tokenizer; it has no vocabulary until
    # get_tokenizer_and_vocab() trains it.
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()

    # Trainer settings used by get_tokenizer_and_vocab().
    trainer = trainers.BpeTrainer(
        special_tokens=["<unk>", "<pad>", "<s>", "</s>"], vocab_size=10000
    )

    @staticmethod
    def yield_texts(data_iter):
        """Yield only the text of each ``(label, text)`` pair."""
        for _label, review in data_iter:
            yield review

    @staticmethod
    def get_tokenizer_and_vocab():
        """Train the shared tokenizer on the IMDB training split and return it."""
        corpus = IMDBDataset.yield_texts(IMDB(split="train"))
        IMDBDataset.tokenizer.train_from_iterator(corpus, IMDBDataset.trainer)
        return IMDBDataset.tokenizer

    @staticmethod
    def text_pipeline(text, tokenizer):
        """Encode ``text`` with ``tokenizer`` and return the list of token ids."""
        encoding = tokenizer.encode(text)
        return encoding.ids

    @staticmethod
    def label_pipeline(label):
        """Shift the raw label down by one.

        NOTE(review): presumably the raw labels are 1/2 and this maps them to
        0/1 - confirm against the installed torchtext version.
        """
        return label - 1

    @staticmethod
    def collate_batch(batch, tokenizer):
        """Turn a list of ``(label, text)`` samples into tensors.

        Returns:
            Tuple ``(labels, texts)``: an int64 tensor of shifted labels and
            an int64 tensor of token ids padded with the id of "<pad>" to the
            longest sample, batch dimension first.
        """
        processed = [
            (
                IMDBDataset.label_pipeline(raw_label),
                torch.tensor(IMDBDataset.text_pipeline(raw_text, tokenizer), dtype=torch.int64),
            )
            for raw_label, raw_text in batch
        ]
        labels = [pair[0] for pair in processed]
        sequences = [pair[1] for pair in processed]

        pad_id = tokenizer.token_to_id("<pad>")
        padded = pad_sequence(sequences, padding_value=pad_id, batch_first=True)
        return torch.tensor(labels, dtype=torch.int64), padded

    @staticmethod
    def get_dataloaders(batch_size):
        """Build shuffled train and test loaders over the IMDB splits.

        Both splits are materialised into lists and are collated with the
        class-level tokenizer as it is at the time of this call.
        """
        tokenizer = IMDBDataset.tokenizer
        train_iter, test_iter = IMDB(split="train"), IMDB(split="test")

        def collate(samples):
            return IMDBDataset.collate_batch(samples, tokenizer)

        def make_loader(samples):
            return DataLoader(
                samples,
                batch_size=batch_size,
                shuffle=True,
                collate_fn=collate,
            )

        train_dataloader = make_loader(list(train_iter))
        test_dataloader = make_loader(list(test_iter))
        return train_dataloader, test_dataloader




Wikitext Dataset

In [8]:
import os

import torch
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

# Logger used by WikiTextDataset when it loads or trains its tokenizer

logger = setup_logger(__name__)


class WikiTextDataset(Dataset):
    """Map-style dataset over the non-empty lines of WikiText-103 (raw).

    Tokenisation happens in ``collate_batch`` using the class-level tokenizer,
    which ``get_tokenizer_and_vocab`` loads from disk or trains.
    """

    # Byte-level BPE tokenizer shared by the static helpers.
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()

    # Trainer settings used when no saved tokenizer file exists.
    trainer = trainers.BpeTrainer(
        special_tokens=["<unk>", "<pad>", "<s>", "</s>"], vocab_size=10000
    )

    def __init__(self, split, tokenizer):
        """Load ``split`` and drop rows whose text is empty after stripping."""
        raw_split = load_dataset(
            "Salesforce/wikitext", "wikitext-103-raw-v1", split=split
        )
        self.dataset = raw_split.filter(lambda x: x["text"].strip() != "")
        # Stored on the instance; collate_batch reads the class attribute.
        self.tokenizer = tokenizer

    def __len__(self):
        """Number of retained rows."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the raw text of row ``idx``."""
        row = self.dataset[idx]
        return row["text"]

    @staticmethod
    def yield_texts(data_iter):
        """Yield every item of ``data_iter`` unchanged."""
        yield from data_iter

    @staticmethod
    def get_tokenizer_and_vocab():
        """Return the shared tokenizer, loading it from disk when possible.

        If "wikitext_tokenizer.json" exists in the working directory it
        replaces the class-level tokenizer. Otherwise the tokenizer is trained
        on the non-empty training lines and saved to that file.
        """
        tokenizer_file = "wikitext_tokenizer.json"

        # Reuse a previously saved tokenizer instead of retraining.
        if os.path.exists(tokenizer_file):
            logger.info("Tokenizer loaded from file.")
            WikiTextDataset.tokenizer = Tokenizer.from_file(tokenizer_file)
            return WikiTextDataset.tokenizer

        train_split = load_dataset(
            "Salesforce/wikitext", "wikitext-103-raw-v1", split="train"
        )
        train_iter = train_split.filter(lambda x: x["text"].strip() != "")["text"]
        WikiTextDataset.tokenizer.train_from_iterator(
            WikiTextDataset.yield_texts(train_iter), WikiTextDataset.trainer
        )

        # Save the tokenizer so later runs can skip training.
        WikiTextDataset.tokenizer.save(tokenizer_file)
        logger.info("Tokenizer trained and saved to file.")
        return WikiTextDataset.tokenizer

    @staticmethod
    def text_pipeline(text):
        """Encode ``text`` with the class-level tokenizer and return its token ids."""
        encoding = WikiTextDataset.tokenizer.encode(text)
        return encoding.ids

    @staticmethod
    def collate_batch(batch):
        """Tokenise and pad a list of raw texts.

        Returns:
            Tuple ``(labels, texts)``: an all-zero int64 label tensor (the
            corpus has no labels) and an int64 tensor of token ids padded with
            the id of "<pad>" to the longest sample, batch dimension first.
        """
        sequences = [
            torch.tensor(WikiTextDataset.text_pipeline(raw_text), dtype=torch.int64)
            for raw_text in batch
        ]
        # Dummy labels for wikitext
        dummy_labels = torch.zeros(len(sequences), dtype=torch.int64)
        padded = pad_sequence(
            sequences,
            padding_value=WikiTextDataset.tokenizer.token_to_id("<pad>"),
            batch_first=True,
        )
        return dummy_labels, padded

    @staticmethod
    def get_dataloaders(batch_size):
        """Build shuffled loaders over the "train" and "test" splits."""
        tokenizer = WikiTextDataset.tokenizer
        train_dataset = WikiTextDataset(split="train", tokenizer=tokenizer)
        test_dataset = WikiTextDataset(split="test", tokenizer=tokenizer)

        def collate(samples):
            return WikiTextDataset.collate_batch(samples)

        def make_loader(dataset):
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                collate_fn=collate,
            )

        train_dataloader = make_loader(train_dataset)
        test_dataloader = make_loader(test_dataset)
        return train_dataloader, test_dataloader


2024-07-27 23:32:37 - datasets - INFO - PyTorch version 2.3.1+cu118 available.


  from .autonotebook import tqdm as notebook_tqdm


Train Evaluate

Pretrain Train_Evaluate

In [9]:
import torch

import wandb


def pretrain_train(model, dataloader, criterion, optimizer, device, epoch, logger, use_wandb):
    """Train ``model`` for one epoch on next-token prediction.

    Each batch of token ids is split into inputs (all but the last token) and
    targets (all but the first token). The model output is flattened to
    (tokens, classes) before being passed to ``criterion``.

    Args:
        model: Module mapping (B, L) token ids to per-token scores (B, L, classes).
        dataloader: Yields ``(labels, texts)`` pairs; the labels are ignored.
        criterion: Loss taking (tokens, classes) scores and (tokens,) targets.
        optimizer: Optimizer over the model's parameters.
        device: Device the batch is moved to.
        epoch: Zero-based epoch index, used only for logging.
        logger: Logger receiving progress messages.
        use_wandb: If True, batch losses are also sent to Weights & Biases.

    Returns:
        Mean loss over the batches that were actually trained on (0.0 if none).
    """
    model.train()
    total_loss = 0.0
    trained_batches = 0
    for batch_idx, (_, texts) in enumerate(dataloader):
        # A batch with fewer than two tokens per row has no (input, target)
        # pair. Its loss would be the mean over zero elements (NaN) and would
        # poison the epoch average, so skip it.
        if texts.size(1) < 2:
            logger.debug(f"Skipping batch {batch_idx}: sequence length {texts.size(1)} < 2")
            continue

        texts = texts.to(device)
        optimizer.zero_grad()

        # Shift by one step: predict token t + 1 from tokens up to t.
        input_texts = texts[:, :-1]
        target_texts = texts[:, 1:]

        output = model(input_texts)
        loss = criterion(output.reshape(-1, output.size(-1)), target_texts.reshape(-1))
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        trained_batches += 1

        if batch_idx % 10 == 0:
            logger.info(
                f"Train Epoch: {epoch + 1} [{batch_idx * len(texts)}/{len(dataloader.dataset)} "
                f"({100. * batch_idx / len(dataloader):.0f}%)]\tLoss: {batch_loss:.6f}"
            )
            if use_wandb:
                wandb.log({"train_batch_loss": batch_loss})
            torch.cuda.empty_cache()

    # Average over the batches that contributed a loss; guard against zero.
    avg_loss = total_loss / max(trained_batches, 1)
    logger.info(f"====> Epoch: {epoch + 1} Average loss: {avg_loss:.4f}")
    return avg_loss


def pretrain_evaluate(model, dataloader, criterion, device, logger, use_wandb):
    """Evaluate ``model`` on next-token prediction without updating it.

    Args:
        model: Module mapping (B, L) token ids to per-token scores (B, L, classes).
        dataloader: Yields ``(labels, texts)`` pairs; the labels are ignored.
        criterion: Loss taking (tokens, classes) scores and (tokens,) targets.
        device: Device the batch is moved to.
        logger: Logger receiving progress messages.
        use_wandb: If True, batch losses are also sent to Weights & Biases.

    Returns:
        Mean loss over the batches that were actually scored (0.0 if none).
    """
    model.eval()
    total_loss = 0.0
    scored_batches = 0
    with torch.no_grad():
        for batch_idx, (_, texts) in enumerate(dataloader):
            # A batch with fewer than two tokens per row has no (input,
            # target) pair and would contribute a NaN loss, so skip it.
            if texts.size(1) < 2:
                logger.debug(f"Skipping batch {batch_idx}: sequence length {texts.size(1)} < 2")
                continue

            texts = texts.to(device)

            # Shift by one step: predict token t + 1 from tokens up to t.
            input_texts = texts[:, :-1]
            target_texts = texts[:, 1:]

            output = model(input_texts)
            loss = criterion(output.reshape(-1, output.size(-1)), target_texts.reshape(-1))
            batch_loss = loss.item()
            total_loss += batch_loss
            scored_batches += 1

            if batch_idx % 10 == 0:
                logger.info(
                    f"Eval Batch: {batch_idx + 1}/{len(dataloader)}\tLoss: {batch_loss:.6f}"
                )
                if use_wandb:
                    wandb.log({"eval_batch_loss": batch_loss})
                torch.cuda.empty_cache()

    # Average over the batches that contributed a loss; guard against zero.
    avg_loss = total_loss / max(scored_batches, 1)
    logger.info(f"====> Test set loss: {avg_loss:.4f}")
    return avg_loss


Finetune Train_Evaluate

In [10]:
import torch

import wandb


def finetune_train(model, dataloader, criterion, optimizer, device, epoch, logger, use_wandb):
    """Train ``model`` for one epoch on labelled sequences.

    For every ``(labels, texts)`` batch the model output is compared with the
    labels cast to float, and one optimizer step is taken. Every tenth batch
    is logged (and reported to Weights & Biases when ``use_wandb`` is set).

    Returns:
        The summed batch losses divided by ``len(dataloader)``.
    """
    model.train()
    batch_losses = []
    for step, batch in enumerate(dataloader):
        targets, tokens = batch
        targets = targets.to(device)
        tokens = tokens.to(device)

        optimizer.zero_grad()
        predictions = model(tokens)
        loss = criterion(predictions, targets.float())
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)

        if step % 10 != 0:
            continue
        # Progress report: samples seen so far / dataset size (percent done).
        logger.info(
            f"Train Epoch: {epoch + 1} [{step * len(targets)}/{len(dataloader.dataset)} "
            f"({100. * step / len(dataloader):.0f}%)]\tLoss: {loss_value:.6f}"
        )
        if use_wandb:
            wandb.log({"train_batch_loss": loss_value})
        torch.cuda.empty_cache()

    avg_loss = sum(batch_losses) / len(dataloader)
    logger.info(f"====> Epoch: {epoch + 1} Average loss: {avg_loss:.4f}")
    return avg_loss


def finetune_evaluate(model, dataloader, criterion, device, logger, use_wandb):
    """Score ``model`` on labelled sequences with gradients disabled.

    Every tenth batch is logged (and reported to Weights & Biases when
    ``use_wandb`` is set).

    Returns:
        The summed batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for step, batch in enumerate(dataloader):
            targets, tokens = batch
            targets = targets.to(device)
            tokens = tokens.to(device)

            loss_value = criterion(model(tokens), targets.float()).item()
            batch_losses.append(loss_value)

            if step % 10 != 0:
                continue
            logger.info(
                f"Eval Batch: {step + 1}/{len(dataloader)}\tLoss: {loss_value:.6f}"
            )
            if use_wandb:
                wandb.log({"eval_batch_loss": loss_value})
            torch.cuda.empty_cache()

    avg_loss = sum(batch_losses) / len(dataloader)
    logger.info(f"====> Test set loss: {avg_loss:.4f}")
    return avg_loss


Config

In [11]:
# Hyper-parameters for every model plus the settings shared by all runs.
# Each entry under "models" is passed as keyword arguments to the matching
# model class.
config = dict(
    models=dict(
        lstm=dict(embed_dim=16, hidden_dim=128, num_layers=2, output_dim=1),
        s4=dict(embed_dim=24, hidden_dim=128, output_dim=1, num_layers=2),
        transformer=dict(embed_dim=32, num_heads=2, hidden_dim=128, output_dim=1, num_layers=4),
    ),
    run_parameters=dict(batch_size=1, n_epochs=2, learning_rate=0.001),
)


Utils

In [12]:
import torch


def replace_final_layer(model, config, model_name, device, output_dim=1):
    """Swap the output head of ``model`` for a freshly initialised one, in place.

    Args:
        model: The model to modify.
        config: Configuration dict; ``config["models"][model_name]`` must
            provide "hidden_dim" and "embed_dim".
        model_name: "lstm" replaces ``Wy``/``by``, "s4" replaces ``decoder``;
            any other name replaces ``fc_out`` (the transformer head).
        device: Device the new head is created on.
        output_dim: Width of the new head (defaults to 1, the previous
            hard-coded value).
    """
    hidden_dim = config["models"][model_name]["hidden_dim"]
    embed_dim = config["models"][model_name]["embed_dim"]
    if model_name == "lstm":
        # Same initialisation as a freshly built LSTM head: Xavier weights, zero bias.
        model.Wy = torch.nn.Parameter(torch.empty(output_dim, hidden_dim, device=device))
        model.by = torch.nn.Parameter(torch.zeros(output_dim, 1, device=device))
        torch.nn.init.xavier_uniform_(model.Wy)
        # Keep the bookkeeping attribute in step with the new head.
        model.output_dim = output_dim
    elif model_name == "s4":
        model.decoder = torch.nn.Linear(hidden_dim, output_dim).to(device)
    else:
        model.fc_out = torch.nn.Linear(embed_dim, output_dim).to(device)

    # ``count_parameters`` is a cached_property on the models. If it was read
    # before the swap, its value is stored in the instance dict and would
    # still describe the old head. Drop it so it is recomputed on next access.
    vars(model).pop("count_parameters", None)


Main

In [17]:
import datetime

import torch
import wandb



model_options = list(config["models"].keys())
run_types = ["task", "lra_pretrain", "wikitext_pretrain", "task_finetune_lra_pretrain", "task_finetune_wikitext_pretrain"]


class Args:
    """Plain container for the three run settings chosen below."""

    def __init__(self, model, run_type, use_wandb = False):
        self.model = model
        self.run_type = run_type
        self.use_wandb = use_wandb


args = Args(model="transformer", run_type="wikitext_pretrain", use_wandb=False)

# Fail early on typos. An unknown run type would otherwise fall through
# silently to the IMDB pre-training branch.
if args.model not in model_options:
    raise ValueError(f"Unknown model {args.model!r}; expected one of {model_options}")
if args.run_type not in run_types:
    raise ValueError(f"Unknown run_type {args.run_type!r}; expected one of {run_types}")

model_config = config["models"][args.model]
run_parameters = config["run_parameters"]

# Setup logging
logger = setup_logger(__name__)

run_name = f"run_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}_{args.model}_{args.run_type}"

batch_size, n_epochs, learning_rate = run_parameters["batch_size"], run_parameters["n_epochs"], run_parameters[
    "learning_rate"]

# Load tokenizer and data loaders
if "wikitext" in args.run_type and "finetune" in args.run_type:
    # Fine-tuning on IMDB a model that was pre-trained with the WikiText tokenizer.
    get_tokenizer_and_vocab = WikiTextDataset.get_tokenizer_and_vocab
    get_dataloaders = IMDBDataset.get_dataloaders
elif "wikitext" in args.run_type:
    get_tokenizer_and_vocab = WikiTextDataset.get_tokenizer_and_vocab
    get_dataloaders = WikiTextDataset.get_dataloaders
else:
    get_tokenizer_and_vocab = IMDBDataset.get_tokenizer_and_vocab
    get_dataloaders = IMDBDataset.get_dataloaders
tokenizer = get_tokenizer_and_vocab()
if get_dataloaders is IMDBDataset.get_dataloaders:
    # IMDBDataset.get_dataloaders() tokenises with IMDBDataset.tokenizer.
    # Point it at the tokenizer selected above, so the WikiText fine-tuning
    # run does not tokenise IMDB with an untrained tokenizer.
    IMDBDataset.tokenizer = tokenizer
train_dataloader, test_dataloader = get_dataloaders(batch_size)

# Initialize model, criterion, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = len(tokenizer.get_vocab())
finetune = args.run_type in ["task", "task_finetune_lra_pretrain", "task_finetune_wikitext_pretrain"]

model_dic = {
    "lstm": CustomLSTMModel,
    "s4": S4Model,
    "transformer": CustomTransformerModel
}

# Training from scratch on the task needs a single output. Every other run
# builds a vocabulary-sized head (pre-training, or loading a pre-trained
# checkpoint before the head is replaced).
if args.run_type == "task":
    config["models"][args.model]["output_dim"] = 1
else:
    config["models"][args.model]["output_dim"] = vocab_size

model = model_dic[args.model](vocab_size=vocab_size, **config["models"][args.model], finetune=finetune)
model.to(device)

# One checkpoint file per (model, pre-training corpus); read when
# fine-tuning, written after pre-training.
pretrained = "lra_pretrained" if "lra" in args.run_type else "wikitext_pretrained"
checkpoint_path = f"{args.model}_{pretrained}.pth"
if finetune and args.run_type != "task":
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    logger.info(f"Checkpoint loaded from {checkpoint_path}")
    replace_final_layer(model, config, args.model, device)

if finetune:
    criterion = torch.nn.BCEWithLogitsLoss()
else:
    # Padded target positions carry no information; exclude them from the
    # language-modelling loss. -100 is CrossEntropyLoss's default ignore index.
    pad_id = tokenizer.token_to_id("<pad>")
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-100 if pad_id is None else pad_id)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

logger.info(f"parameter_cnt: {model.count_parameters}")

if args.use_wandb:
    # Initialize WandB. The API key is deliberately not stored in the
    # notebook: wandb.login() picks it up from the WANDB_API_KEY environment
    # variable or a previous `wandb login`. The key that used to be
    # hard-coded here should be revoked.
    wandb.login()
    wandb.init(project="model-comparison", name=run_name)

    # Log hyperparameters and model
    wandb.config.update(
        {
            **{
                "run_name": run_name,
                "model": args.model,
                "parameter_cnt": model.count_parameters,
            },
            **config["models"][args.model],
            **config["run_parameters"]}
    )

# Training loop
logger.info("Starting training...")

if finetune:
    train, evaluate = finetune_train, finetune_evaluate
else:
    train, evaluate = pretrain_train, pretrain_evaluate

for epoch in range(n_epochs):
    train_loss = train(model, train_dataloader, criterion, optimizer, device, epoch, logger, args.use_wandb)
    test_loss = evaluate(model, test_dataloader, criterion, device, logger, args.use_wandb)
    logger.info(f"Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    if args.use_wandb:
        # Log metrics to WandB
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "test_loss": test_loss
        })

# Only pre-training runs produce a checkpoint.
if not finetune:
    torch.save(model.state_dict(), checkpoint_path)
    logger.info(f"Checkpoint saved at {checkpoint_path}")

logger.info("Training completed.")

if args.use_wandb:
    wandb.finish()


2024-07-27 23:46:11 - __main__ - INFO - Tokenizer loaded from file.
2024-07-27 23:46:19 - __main__ - INFO - parameter_cnt: 700816
2024-07-27 23:46:19 - __main__ - INFO - Starting training...


KeyboardInterrupt: 