# Model building

Now we start designing and training models on the rākau world-state data and their text descriptions.

In [None]:
import math
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from ataarangi.train import encode_world_state, TextTokenizer, WorldStateTokenizer, RākauDataset

# Initialize tokenizers (these two instances are the ones later passed to
# RākauDataset).
world_state_tokenizer = WorldStateTokenizer()
text_tokenizer = TextTokenizer()

# Load the CSV, decode the JSON-encoded `rākau` column into Python objects,
# and keep only the rows whose `num_rākau` is at most 10 (index renumbered).
rākau_data = pd.read_csv('../data/rākau_data.csv')
rākau_data['rākau'] = rākau_data.rākau.apply(json.loads)
rākau_data = rākau_data[rākau_data.num_rākau <= 10].reset_index(drop=True)

# NOTE: sort_values returns a new DataFrame and the result is not assigned,
# so `rākau_data` keeps its original row order. In a notebook this line only
# has a visible effect (displaying the sorted frame) when it is the last
# expression of its cell.
rākau_data.sort_values('num_rākau', ascending=False)

# A second pair of tokenizer instances: `text_tokenizer` is rebound to a new
# TextTokenizer, and `ws_tokenizer` is a separate WorldStateTokenizer from
# `world_state_tokenizer` above.
text_tokenizer = TextTokenizer()
ws_tokenizer = WorldStateTokenizer()

# Add tokenized versions of the world state and of the description as new
# columns.
# NOTE(review): in the code visible here these two columns are never read
# again; the dataset is built from the raw `rākau` and `description` columns.
rākau_data['input'] = rākau_data.rākau.apply(ws_tokenizer.tokenize)
rākau_data['target'] = rākau_data.description.apply(text_tokenizer.tokenize)

import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    """Encoder-decoder Transformer over a single shared token vocabulary.

    Source and target token ids are embedded with the same ``nn.Embedding``,
    summed with fixed sinusoidal positional encodings, passed through
    ``nn.Transformer`` (``batch_first=True``) and projected back to
    vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output
            logits).
        embed_size: Model/embedding dimension (``d_model``).
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Hidden size of the feed-forward sub-layers.
        max_seq_length: Longest sequence (source or target) the positional
            encoding table supports.
        dropout: Dropout probability used inside ``nn.Transformer``.
    """

    def __init__(self, vocab_size, embed_size, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length, dropout=0.1):
        super().__init__()
        self.embed_size = embed_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Registered as a buffer so the table follows the module across
        # devices (.to / .cuda) without being a trainable parameter.
        self.register_buffer('positional_encodings', self.create_positional_encodings(max_seq_length, embed_size))
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def create_positional_encodings(self, max_len, embed_size):
        """Build the sinusoidal positional-encoding table.

        Even feature columns hold sines and odd columns hold cosines of
        ``position * 10000^(-2i / embed_size)``.

        Args:
            max_len: Number of positions to generate.
            embed_size: Feature dimension; may be even or odd.

        Returns:
            Float tensor of shape ``(1, max_len, embed_size)``.
        """
        pos_enc = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        # For an odd embed_size there is one fewer cosine column than sine
        # column, so trim the frequencies to match; this is a no-op when
        # embed_size is even.
        pos_enc[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
        return pos_enc.unsqueeze(0)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
        """Compute vocabulary logits for every target position.

        A causal mask is always applied to the decoder self-attention so
        that position ``t`` of the output depends only on ``tgt[:, :t + 1]``
        (and on ``src``), never on later target tokens.

        Args:
            src: Integer token ids of shape ``(batch, src_len)``.
            tgt: Integer token ids of shape ``(batch, tgt_len)``.
            src_key_padding_mask: Optional ``(batch, src_len)`` mask marking
                source padding positions to ignore (``True`` = ignore). It is
                also used as the memory mask for cross-attention.
            tgt_key_padding_mask: Optional ``(batch, tgt_len)`` mask marking
                target padding positions to ignore (``True`` = ignore).

        Returns:
            Float tensor of logits with shape ``(batch, tgt_len, vocab_size)``.

        Raises:
            ValueError: If ``src`` or ``tgt`` is longer than the positional
                encoding table built from ``max_seq_length``.
        """
        max_len = self.positional_encodings.size(1)
        for name, seq in (('src', src), ('tgt', tgt)):
            if seq.size(1) > max_len:
                raise ValueError(
                    f'{name} sequence length {seq.size(1)} exceeds max_seq_length {max_len}'
                )

        src_pos = self.positional_encodings[:, :src.size(1), :]
        tgt_pos = self.positional_encodings[:, :tgt.size(1), :]
        src_emb = self.embedding(src) + src_pos
        tgt_emb = self.embedding(tgt) + tgt_pos

        # Boolean causal mask: True above the diagonal means "may not attend",
        # i.e. each target position is blocked from seeing later positions.
        tgt_len = tgt.size(1)
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )

        output = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.fc_out(output)

# Model instantiation
# The vocabulary size is taken as the largest id in the text tokenizer's
# token map plus one, so every id in that map is a valid embedding index.
# NOTE(review): TransformerModel embeds both `src` and `tgt` with the same
# embedding table, so every token id fed to the model must be below this
# value — confirm that world-state token ids also fit in this range.
model = TransformerModel(
    vocab_size=max(text_tokenizer.token_map.values())+1,
    embed_size=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    max_seq_length=500,
    dropout=0.1
)

def custom_collate_fn(batch):
    """Collate variable-length examples into right-padded batch tensors.

    Each element of ``batch`` is a mapping with 1-D tensors under the keys
    ``'input_ids'``, ``'token_type_ids'`` and ``'attention_mask'``. Every
    sequence is extended with zeros (dtype ``torch.long``) up to the length
    of the longest ``'input_ids'`` entry in the batch, then the sequences of
    each field are stacked along a new leading batch dimension.

    Returns:
        dict with the same three keys, each a stacked tensor with one row
        per example.
    """
    fields = ('input_ids', 'token_type_ids', 'attention_mask')

    # Gather each field into its own list, one entry per example.
    columns = {key: [example[key] for example in batch] for key in fields}

    # The padded width is dictated by the longest input_ids sequence only.
    width = max(len(seq) for seq in columns['input_ids'])

    def right_pad(seq):
        # Append as many zeros as needed to reach the batch width.
        filler = torch.zeros(width - len(seq), dtype=torch.long)
        return torch.cat([seq, filler])

    return {
        key: torch.stack([right_pad(seq) for seq in seqs])
        for key, seqs in columns.items()
    }

# Create dataset from the raw world-state and description columns, using the
# tokenizer instances created first (`world_state_tokenizer`) and the most
# recently bound `text_tokenizer`.
dataset = RākauDataset(rākau_data.rākau, rākau_data.description, world_state_tokenizer, text_tokenizer)
# Shuffled batches of 64 examples, padded per batch by custom_collate_fn.
# NOTE(review): num_workers=4 requires custom_collate_fn to be importable /
# picklable by worker processes; on platforms that spawn workers this can
# fail for functions defined inside a notebook — confirm on the target OS.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, collate_fn=custom_collate_fn)

# Define the device: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move model to the appropriate device
model = model.to(device)

# Loss function
# NOTE(review): no ignore_index is set, so positions that custom_collate_fn
# filled with 0 are scored like real tokens — confirm whether 0 is a padding
# id that should be excluded from the loss.
criterion = nn.CrossEntropyLoss()

# Optimizer: Adam over all model parameters with a fixed learning rate.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of epochs (full passes over the dataloader)
num_epochs = 100

# Train for `num_epochs` passes over `dataloader`, printing the loss of every
# 100th batch and the mean per-batch loss at the end of each epoch.
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    epoch_loss = 0  # running sum of the per-batch loss values for this epoch

    for batch_idx, batch in enumerate(dataloader):
        # Both tensors are slices of the same padded 'input_ids' batch:
        # `src` drops the last column and `tgt` drops the first.
        # NOTE(review): `tgt` is passed to the model as the decoder input
        # below AND used as the label in the loss, so the label at each
        # position is the very token the decoder receives at that position.
        # Teacher forcing normally feeds the target shifted by one relative
        # to the labels — confirm this is intended. How RākauDataset lays
        # out 'input_ids' (world state vs. description) is not visible here.
        src = batch['input_ids'][:, :-1].to(device)  # all but the last for input
        tgt = batch['input_ids'][:, 1:].to(device)   # all but the first for target

        # Forward pass: logits for every position of `tgt`
        output = model(src, tgt)

        # Compute loss: flatten logits to (batch_size*seq_len, vocab_size)
        # and the targets to (batch_size*seq_len,) as CrossEntropyLoss expects
        # class scores in the last dimension of a 2-D input.
        loss = criterion(output.view(-1, output.size(-1)), tgt.reshape(-1))

        # Backward pass: clear old gradients, backpropagate, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() extracts a Python float so no graph is kept alive
        epoch_loss += loss.item()

        # Optional: Log progress every 100 batches
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item()}')

    # Average loss for the epoch (mean of the per-batch losses)
    print(f'Epoch {epoch+1} completed, Average Loss: {epoch_loss / len(dataloader)}')