In [5]:
import torch
import math
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer
from diff_attn import DifferentialAttention

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from diff_attn import DifferentialAttention

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence of embeddings.

    A table of shape ``(1, max_seq_length, d_model)`` is precomputed once and
    stored as a non-trainable buffer named ``pe``. Even feature columns hold
    sines and odd feature columns hold cosines of the scaled position.

    Args:
        d_model: Size of the feature (last) dimension of the inputs.
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer odd column than even columns,
        # so the cosine half must drop the last frequency to match.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading axis of size 1 lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed along dimension 1 as the sequence axis and must have
        ``d_model`` as its last dimension for the addition to broadcast.
        """
        return x + self.pe[:, :x.size(1)]

class TranslationEncoder(nn.Module):
    """Stack of post-norm encoder layers built on ``DifferentialAttention``.

    Each layer applies self-attention and a two-layer ReLU feed-forward
    network, each followed by dropout, a residual connection and LayerNorm.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding / hidden size.
        d_head: Per-head size forwarded to ``DifferentialAttention``.
        n_heads: Number of heads forwarded to ``DifferentialAttention``.
        n_layers: Number of encoder layers.
        dropout: Dropout probability (also forwarded to the attention module).
    """

    def __init__(self, vocab_size, d_model, d_head, n_heads, n_layers, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)

        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': DifferentialAttention(d_model, d_head, n_heads, dropout),
                'norm1': nn.LayerNorm(d_model),
                'ffn': nn.Sequential(
                    nn.Linear(d_model, 4 * d_model),
                    nn.ReLU(),
                    nn.Linear(4 * d_model, d_model)
                ),
                'norm2': nn.LayerNorm(d_model)
            }) for _ in range(n_layers)
        ])

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode a batch of token ids.

        Args:
            x: Token ids, ``[batch_size, seq_len]``.
            mask: Optional mask passed unchanged to every attention layer.
                NOTE(review): the layout expected by ``DifferentialAttention``
                is not visible here — confirm against that module.

        Returns:
            Tensor of shape ``[batch_size, seq_len, d_model]``.
        """
        # Scale by sqrt(d_model). The width is read from the embedding table
        # because `x` is still the id tensor at this point, whose last
        # dimension is the sequence length rather than the model width.
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        for layer in self.layers:
            # Self-attention sub-layer (residual + post-norm)
            attn_out = layer['attention'](x, mask=mask)
            x = layer['norm1'](x + self.dropout(attn_out))

            # Feed-forward sub-layer (residual + post-norm)
            ffn_out = layer['ffn'](x)
            x = layer['norm2'](x + self.dropout(ffn_out))

        return x

class TranslationDecoder(nn.Module):
    """Stack of post-norm decoder layers built on ``DifferentialAttention``.

    Each layer applies self-attention, cross-attention over the encoder output
    and a two-layer ReLU feed-forward network; every sub-layer is followed by
    dropout, a residual connection and LayerNorm. A final linear layer maps
    the hidden states to ``vocab_size`` scores per position.

    Args:
        vocab_size: Size of the embedding table and of the output projection.
        d_model: Embedding / hidden size.
        d_head: Per-head size forwarded to ``DifferentialAttention``.
        n_heads: Number of heads forwarded to ``DifferentialAttention``.
        n_layers: Number of decoder layers.
        dropout: Dropout probability (also forwarded to the attention modules).
    """

    def __init__(self, vocab_size, d_model, d_head, n_heads, n_layers, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)

        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'self_attention': DifferentialAttention(d_model, d_head, n_heads, dropout),
                'norm1': nn.LayerNorm(d_model),
                'cross_attention': DifferentialAttention(d_model, d_head, n_heads, dropout),
                'norm2': nn.LayerNorm(d_model),
                'ffn': nn.Sequential(
                    nn.Linear(d_model, 4 * d_model),
                    nn.ReLU(),
                    nn.Linear(4 * d_model, d_model)
                ),
                'norm3': nn.LayerNorm(d_model)
            }) for _ in range(n_layers)
        ])

        self.dropout = nn.Dropout(dropout)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """Decode target token ids against the encoder output.

        Args:
            x: Target token ids, ``[batch_size, seq_len]``.
            encoder_output: Encoder hidden states used as cross-attention context.
            src_mask: Optional mask passed to cross-attention.
            tgt_mask: Optional mask passed to self-attention.
                NOTE(review): the mask layout expected by
                ``DifferentialAttention`` is not visible here — confirm.

        Returns:
            Tensor of shape ``[batch_size, seq_len, vocab_size]``.
        """
        # Scale by sqrt(d_model). The width is read from the embedding table
        # because `x` is still the id tensor at this point, whose last
        # dimension is the sequence length rather than the model width.
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        for layer in self.layers:
            # Self-attention over the target sequence
            self_attn = layer['self_attention'](x, mask=tgt_mask)
            x = layer['norm1'](x + self.dropout(self_attn))

            # Cross-attention over the encoder output
            cross_attn = layer['cross_attention'](x, context=encoder_output, mask=src_mask)
            x = layer['norm2'](x + self.dropout(cross_attn))

            # Feed-forward sub-layer
            ffn_out = layer['ffn'](x)
            x = layer['norm3'](x + self.dropout(ffn_out))

        return self.output_layer(x)

class TranslationModel(nn.Module):
    """Encoder-decoder translation network.

    Wires a ``TranslationEncoder`` over the source vocabulary to a
    ``TranslationDecoder`` over the target vocabulary; both share the same
    width, head and depth settings.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, d_head=64, n_heads=8,
                 n_layers=6, dropout=0.1):
        super().__init__()
        # Hyper-parameters common to both halves, in constructor order.
        shared = (d_model, d_head, n_heads, n_layers, dropout)
        self.encoder = TranslationEncoder(src_vocab_size, *shared)
        self.decoder = TranslationDecoder(tgt_vocab_size, *shared)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src`` and decode ``tgt`` against it.

        Args:
            src: Source token ids, ``[batch_size, src_seq_len]``.
            tgt: Target token ids, ``[batch_size, tgt_seq_len]``.
            src_mask: Optional mask given to the encoder and to the decoder.
            tgt_mask: Optional mask given to the decoder.

        Returns:
            The decoder output, ``[batch_size, tgt_seq_len, tgt_vocab_size]``.
        """
        memory = self.encoder(src, src_mask)  # [batch_size, src_seq_len, d_model]
        return self.decoder(tgt, memory, src_mask, tgt_mask)

def create_masks(src, tgt, pad_idx):
    """Build boolean attention masks for a source/target batch.

    ``True`` marks a position that may be attended to.

    Args:
        src: Source token ids, ``[batch_size, src_len]``.
        tgt: Target token ids, ``[batch_size, tgt_len]``.
        pad_idx: Token id treated as padding.

    Returns:
        ``(src_mask, tgt_mask)`` where ``src_mask`` has shape
        ``[batch_size, 1, 1, src_len]`` and ``tgt_mask`` has shape
        ``[batch_size, 1, tgt_len, tgt_len]``. ``tgt_mask`` combines the
        target padding mask with a causal (lower-triangular) mask. Both masks
        live on the device of their respective input.
    """
    # Padding masks, with two singleton axes inserted after the batch axis.
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
    tgt_padding = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)

    # Lower-triangular "may attend" matrix, allocated on tgt's device so the
    # `&` below does not fail with a CPU/CUDA mismatch.
    seq_len = tgt.size(1)
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=tgt.device).tril()

    # [B, 1, 1, T] & [1, T, T] broadcasts to [B, 1, T, T].
    tgt_mask = tgt_padding & causal.unsqueeze(0)

    return src_mask, tgt_mask

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import MarianTokenizer
from datasets import load_dataset
from tqdm.auto import tqdm
import wandb
import numpy as np

class TranslationDataset(Dataset):
    """Tokenizes ``item['translation']`` pairs into fixed-length id tensors.

    Args:
        dataset: Indexable collection whose items hold a ``'translation'``
            mapping from language code to text.
        tokenizer: Callable accepting ``(text, max_length=, padding=,
            truncation=, return_tensors=)`` and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors that have a
            leading batch axis of size 1.
        max_length: Length every sequence is padded / truncated to.
        src_lang: Key of the source text inside ``'translation'``.
        tgt_lang: Key of the target text inside ``'translation'``.
    """

    def __init__(self, dataset, tokenizer, max_length=128, src_lang='en', tgt_lang='cs'):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self):
        return len(self.dataset)

    def _encode(self, text):
        """Tokenize one string to ``max_length`` with padding and truncation."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return source/target ids and attention masks for example ``idx``."""
        pair = self.dataset[idx]['translation']
        source_encoding = self._encode(pair[self.src_lang])
        # NOTE(review): the target is tokenized through the same call path as
        # the source. Some tokenizers use a separate target-side vocabulary —
        # confirm this is intended for the tokenizer in use.
        target_encoding = self._encode(pair[self.tgt_lang])

        # squeeze(0) removes only the batch axis added by return_tensors='pt';
        # a bare squeeze() would also drop the sequence axis when max_length == 1.
        return {
            'input_ids': source_encoding['input_ids'].squeeze(0),
            'attention_mask': source_encoding['attention_mask'].squeeze(0),
            'labels': target_encoding['input_ids'].squeeze(0),
            'decoder_attention_mask': target_encoding['attention_mask'].squeeze(0)
        }

def _seq2seq_loss(model, batch, loss_fct, pad_idx, device):
    """Run one teacher-forced forward pass over ``batch`` and return the loss."""
    # Teacher forcing: the decoder reads tokens [0, T-1) and is scored on [1, T).
    # NOTE(review): no explicit decoder start token is prepended here, so the
    # first label token is only ever an input, never a prediction target —
    # confirm whether that is intended.
    decoder_input = batch['labels'][:, :-1]
    targets = batch['labels'][:, 1:].to(device)

    # Masks are built from the tensor actually fed to the decoder so that the
    # target mask length matches the decoder sequence length. They are built
    # before anything is moved to `device` and transferred afterwards.
    src_mask, tgt_mask = create_masks(batch['input_ids'], decoder_input, pad_idx)

    outputs = model(
        batch['input_ids'].to(device),
        decoder_input.to(device),
        src_mask.to(device),
        tgt_mask.to(device),
    )
    return loss_fct(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))


def train_model(max_length=128, batch_size=16, num_epochs=3, learning_rate=2e-5,
                train_size=100000, val_size=1000):
    """Train a ``TranslationModel`` on a subset of WMT14 cs-en and log to wandb.

    Args:
        max_length: Token length every example is padded / truncated to.
        batch_size: Batch size for both the training and validation loaders.
        num_epochs: Number of epochs; also the cosine schedule's ``T_max``.
        learning_rate: Initial AdamW learning rate.
        train_size: Maximum number of training examples to use.
        val_size: Maximum number of validation examples to use.

    Returns:
        ``(model, tokenizer)`` after the final epoch.

    Side effects:
        Starts and finishes a wandb run, downloads the dataset and tokenizer,
        and writes ``checkpoint_epoch_*.pt`` files to the working directory.
    """
    # Initialize wandb
    wandb.init(project="translation-differential-attention")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset("wmt14", "cs-en")  # Explicitly specify language pair

    # Take a subset of the data for faster training; clamp to what is available.
    train_dataset = dataset['train'].select(range(min(train_size, len(dataset['train']))))
    val_dataset = dataset['validation'].select(range(min(val_size, len(dataset['validation']))))

    print(f"Training on {len(train_dataset)} examples")
    print(f"Validating on {len(val_dataset)} examples")

    # Initialize tokenizer
    print("Loading tokenizer...")
    tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-cs")
    pad_idx = tokenizer.pad_token_id

    # Create custom datasets
    train_data = TranslationDataset(train_dataset, tokenizer, max_length)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    val_data = TranslationDataset(val_dataset, tokenizer, max_length)
    val_loader = DataLoader(val_data, batch_size=batch_size)

    # Initialize model
    print("Initializing model...")
    model = TranslationModel(
        src_vocab_size=tokenizer.vocab_size,
        tgt_vocab_size=tokenizer.vocab_size,
        d_model=512,
        d_head=64,
        n_heads=8,
        n_layers=6,
        dropout=0.1
    ).to(device)

    # Initialize optimizer and scheduler (the scheduler is stepped once per epoch)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    # Padding positions in the targets do not contribute to the loss.
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)

    # Training loop
    print("Starting training...")
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}')

        for batch_idx, batch in enumerate(progress_bar):
            optimizer.zero_grad()
            loss = _seq2seq_loss(model, batch, loss_fct, pad_idx, device)

            # Backward pass with gradient-norm clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            progress_bar.set_postfix({'loss': loss_value})

            # Log to wandb
            wandb.log({
                'train_loss': loss_value,
                'learning_rate': scheduler.get_last_lr()[0]
            })

            # Save checkpoint every 5000 batches
            if batch_idx % 5000 == 0:
                checkpoint = {
                    'epoch': epoch,
                    'batch_idx': batch_idx,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss_value,
                }
                torch.save(checkpoint, f'checkpoint_epoch_{epoch}_batch_{batch_idx}.pt')

        # max(..., 1) avoids a ZeroDivisionError if the loader is empty.
        avg_train_loss = total_loss / max(len(train_loader), 1)
        print(f'Average training loss: {avg_train_loss}')

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation'):
                val_loss += _seq2seq_loss(model, batch, loss_fct, pad_idx, device).item()

        avg_val_loss = val_loss / max(len(val_loader), 1)
        print(f'Validation loss: {avg_val_loss}')
        wandb.log({'val_loss': avg_val_loss})

        # Save epoch checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_train_loss,
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pt')

        scheduler.step()

    wandb.finish()
    return model, tokenizer



# Function to translate text
def translate(model, tokenizer, text, max_length=128):
    """Greedily decode a translation of ``text``.

    Args:
        model: Callable as ``model(src_ids, tgt_ids, src_mask, tgt_mask)`` and
            returning per-position vocabulary scores.
        tokenizer: Tokenizer exposing ``bos_token_id``, ``pad_token_id``,
            ``eos_token_id`` and ``decode``.
        text: A single source sentence (decoding stops on a scalar EOS check,
            so batches of several sentences are not supported).
        max_length: Source truncation length and maximum number of generated
            tokens.

    Returns:
        The decoded string with special tokens skipped.

    Raises:
        ValueError: If the tokenizer defines neither a BOS nor a pad token to
            start decoding from.
    """
    model.eval()
    device = next(model.parameters()).device

    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True, padding=True)
    input_ids = inputs["input_ids"].to(device)

    # Some tokenizers (e.g. Marian) define no BOS token; fall back to the pad
    # id as the decoder start token instead of building a tensor from None.
    start_id = tokenizer.bos_token_id
    if start_id is None:
        start_id = tokenizer.pad_token_id
    if start_id is None:
        raise ValueError("tokenizer defines neither bos_token_id nor pad_token_id")

    # Source padding mask, [batch, 1, 1, src_len]; True = may attend.
    # It does not change during decoding, so it is built once.
    src_mask = (input_ids != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2)

    generated_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)

    with torch.no_grad():
        # Generate translation one token at a time
        for _ in range(max_length):
            cur_len = generated_ids.size(1)
            # Causal-only target mask, [1, 1, cur_len, cur_len]. Padding is
            # deliberately not masked here: nothing generated is padding, and
            # the start token may itself be the pad id.
            tgt_mask = torch.ones(cur_len, cur_len, dtype=torch.bool, device=device).tril()
            tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)

            outputs = model(input_ids, generated_ids, src_mask, tgt_mask)
            # Greedy choice from the scores at the last position.
            next_token = outputs[:, -1:].argmax(dim=-1)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Train the model, then run a single smoke-test translation.
if __name__ == "__main__":
    import traceback

    try:
        model, tokenizer = train_model()

        # Test translation
        test_sentence = "Hello, how are you?"
        translation = translate(model, tokenizer, test_sentence)
        print(f"Source: {test_sentence}")
        print(f"Translation: {translation}")

    except Exception as e:
        print(f"An error occurred: {str(e)}")
        # Show where the failure happened; the message alone is rarely enough.
        traceback.print_exc()
    finally:
        # Runs on KeyboardInterrupt as well, so an interrupted run is still
        # closed; it is a no-op when no wandb run is active.
        wandb.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011173725921738271, max=1.0…

Loading dataset...


Downloading data:   0%|          | 0.00/168M [00:00<?, ?B/s]

KeyboardInterrupt: 

RuntimeError: expand(torch.LongTensor{[32, 1, 1, 1, 1, 128]}, size=[32, 12, 128, 128]): the number of sizes provided (4) must be greater or equal to the number of dimensions in the tensor (6)