In [2]:
!pip install sacremoses

Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: sacremoses
Successfully installed sacremoses-0.1.1


In [3]:
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from math import sqrt
from loguru import logger
from zeta import FeedForward, OutputHead
from zeta.nn.modules.simple_rmsnorm import SimpleRMSNorm

class DiffAttn(nn.Module):
    def __init__(self, d: int, embedding_dim: int):
        super(DiffAttn, self).__init__()
        self.d = d
        self.W_q = nn.Linear(embedding_dim, 2 * d)
        self.W_k = nn.Linear(embedding_dim, 2 * d)
        self.W_v = nn.Linear(embedding_dim, d)

    def forward(self, X: Tensor, λ: float) -> Tensor:
        logger.info("Executing DiffAttn forward pass")

        Q = self.W_q(X)
        K = self.W_k(X)
        V = self.W_v(X)

        Q1, Q2 = self.split(Q)
        K1, K2 = self.split(K)

        s = 1 / sqrt(self.d)

        A1 = (Q1 @ K1.transpose(-1, -2)) * s
        A2 = (Q2 @ K2.transpose(-1, -2)) * s

        A1_softmax = F.softmax(A1, dim=-1)
        A2_softmax = F.softmax(A2, dim=-1)

        result = (A1_softmax - λ * A2_softmax) @ V
        return result

    @staticmethod
    def split(X: Tensor) -> (Tensor, Tensor):
        half_dim = X.shape[-1] // 2
        return X[..., :half_dim], X[..., half_dim:]

class MultiHeadDifferentialAttention(nn.Module):
    def __init__(self, h: int, d: int, embedding_dim: int, λinit: float):
        super(MultiHeadDifferentialAttention, self).__init__()
        self.h = h
        self.d = d
        self.λinit = λinit
        self.embedding_dim = embedding_dim
        self.diff_attn_heads = nn.ModuleList([DiffAttn(d, embedding_dim) for _ in range(h)])
        self.W_o = nn.Linear(h * d, embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, X: Tensor, λ: float) -> Tensor:
        logger.info("Executing MultiHead forward pass")

        O_list = [head(X, λ) for head in self.diff_attn_heads]
        O_concat = torch.cat(O_list, dim=-1)

        result = self.W_o(O_concat)
        result = self.norm(result)
        result = result * (1 - self.λinit)

        return result

class DifferentialTransformerBlock(nn.Module):
    def __init__(self, dim: int, heads: int = 12, dropout: float = 0.1, λinit: float = 0.05, *args, **kwargs):
        super(DifferentialTransformerBlock, self).__init__()
        self.dim = dim
        self.heads = heads
        self.dropout = dropout
        self.λinit = λinit

        self.attn = MultiHeadDifferentialAttention(heads, dim, dim, *args, λinit=λinit, **kwargs)
        self.ffn = FeedForward(dim, dim, mult=4, swiglu=True, post_act_ln=True)
        self.norm = SimpleRMSNorm(dim)

    def forward(self, x: Tensor, λ: float = 0.1, *args, **kwargs):
        residual = x
        attended = self.attn(self.norm(x), λ) + residual
        logger.info(f"First attention output shape: {attended.shape}")

        residual_two = attended
        attended = self.attn(self.norm(residual_two), λ) + residual_two
        logger.info(f"Second attention output shape: {attended.shape}")

        return attended

class DifferentialTransformer(nn.Module):
    def __init__(self, dim: int = 3072, heads: int = 12, dropout: float = 0.1, λinit: float = 0.8, depth: int = 24, num_tokens: int = 30000, *args, **kwargs):
        super(DifferentialTransformer, self).__init__()
        self.dim = dim
        self.heads = heads
        self.dropout = dropout
        self.λinit = λinit
        self.depth = depth
        self.num_tokens = num_tokens

        self.layers = nn.ModuleList(
            [
                DifferentialTransformerBlock(
                    dim=dim,
                    heads=heads,
                    dropout=dropout,
                    λinit=λinit,
                    *args, **kwargs
                ) for _ in range(depth)
            ]
        )

        self.embed = nn.Embedding(num_embeddings=num_tokens, embedding_dim=dim)
        self.norm = SimpleRMSNorm(dim)

    def forward(self, x, λ: float = 0.1):
        device = next(self.parameters()).device
        x = x.to(device).long()

        # Clamp indices to ensure they are within embedding vocabulary size
        x = x.clamp(0, self.num_tokens - 1)

        x = self.embed(x)
        x = self.norm(x)

        for layer in self.layers:
            x = layer(x)

        return OutputHead(self.dim, vocab_size=self.num_tokens)(x)

class TranslationTransformer(nn.Module):
    def __init__(self, dim: int = 512, heads: int = 8, dropout: float = 0.1, λinit: float = 0.8, depth: int = 6, src_vocab_size: int = 0, tgt_vocab_size: int = 0, max_seq_length: int = 128):
        super().__init__()
        self.dim = dim

        self.encoder = DifferentialTransformer(dim=dim, heads=heads, dropout=dropout, λinit=λinit, depth=depth, num_tokens=src_vocab_size)
        self.decoder = DifferentialTransformer(dim=dim, heads=heads, dropout=dropout, λinit=λinit, depth=depth, num_tokens=tgt_vocab_size)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_length, dim))

    def forward(self, src, tgt, λ: float = 0.1):
        device = src.device
        src = src.long().to(device)
        tgt = tgt.long().to(device)

        # Clamp to ensure indices are within bounds
        src = src.clamp(0, self.encoder.num_tokens - 1)
        tgt = tgt.clamp(0, self.decoder.num_tokens - 1)

        src = self.encoder.embed(src) + self.pos_embedding[:, :src.size(1), :].expand(src.size(0), -1, -1)
        tgt = self.decoder.embed(tgt) + self.pos_embedding[:, :tgt.size(1), :].expand(tgt.size(0), -1, -1)

        enc_output = self.encoder(src, λ)
        dec_output = self.decoder(tgt, λ)

        return dec_output

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

dataset = load_dataset("opus_books", "en-fr", split="train[:1000]")

src_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tgt_tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_uncased")

def preprocess_data(examples):
    src_texts = [item['en'] for item in examples['translation']]
    tgt_texts = [item['fr'] for item in examples['translation']]

    src_encoded = src_tokenizer(src_texts, padding=True, truncation=True, max_length=128)
    tgt_encoded = tgt_tokenizer(tgt_texts, padding=True, truncation=True, max_length=128)

    return {
        "src_ids": src_encoded["input_ids"],
        "tgt_ids": tgt_encoded["input_ids"]
    }

def train_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TranslationTransformer(
        src_vocab_size=src_tokenizer.vocab_size,
        tgt_vocab_size=tgt_tokenizer.vocab_size
    ).to(device)

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss(ignore_index=tgt_tokenizer.pad_token_id)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    for epoch in range(10):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            optimizer.zero_grad()

            src_ids = torch.stack([torch.tensor(ids, dtype=torch.long) for ids in batch["src_ids"]]).to(device)
            tgt_ids = torch.stack([torch.tensor(ids, dtype=torch.long) for ids in batch["tgt_ids"]]).to(device)

            tgt_input = tgt_ids[:, :-1]
            tgt_labels = tgt_ids[:, 1:]

            outputs = model(src_ids, tgt_input)

            outputs = outputs.view(-1, outputs.size(-1))
            tgt_labels = tgt_labels.view(-1)

            loss = criterion(outputs, tgt_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        # Step scheduler
        scheduler.step()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch} completed. Average Loss: {avg_loss:.4f}")

# Create training dataloader
processed_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset.column_names)
dataloader = DataLoader(processed_dataset, batch_size=16, shuffle=True)


Downloading readme:   0%|          | 0.00/28.1k [00:00<?, ?B/s]

Downloading data: 100%|██████████| 21.0M/21.0M [00:01<00:00, 11.1MB/s]


Generating train split:   0%|          | 0/127085 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.50k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.56M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/917k [00:00<?, ?B/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
train_model()

[32m2024-11-20 18:20:40.508[0m | [1mINFO    [0m | [36m__main__[0m:[36mforward[0m:[36m55[0m - [1mExecuting MultiHead forward pass[0m
[32m2024-11-20 18:20:40.815[0m | [1mINFO    [0m | [36m__main__[0m:[36mforward[0m:[36m18[0m - [1mExecuting DiffAttn forward pass[0m
