# Model building

Now we start designing models to do stuff

In [1]:
import math
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from ataarangi.data import encode_world_state, TextTokenizer, WorldStateTokenizer, RākauDataset

In [2]:
# Initialize tokenizers
world_state_tokenizer = WorldStateTokenizer()
text_tokenizer = TextTokenizer()

In [3]:
rākau_data = pd.read_csv('../data/rākau_data.csv')
rākau_data['rākau'] = rākau_data.rākau.apply(json.loads)
rākau_data = rākau_data[rākau_data.num_rākau <= 10].reset_index(drop=True)

In [4]:
rākau_data.sort_values('num_rākau', ascending=False)

Unnamed: 0,id,entropy,num_rākau,rākau,description
653,eb825132-1e22-44ff-8b94-c62644d80390,8.090296,10,"[{'color': 'white', 'height': 7, 'location': 3...",te rākau mā me te rākau māwhero nui rawa me te...
323,ab4d2e1e-880e-4664-8ce1-01593cedbf76,7.690296,10,"[{'color': 'black', 'height': 5, 'location': 3...",ngā rākau kikorangi me ngā rākau whero me ngā ...
337,9ef16aba-598c-4c55-bf0b-a0d1fdfc56da,8.490296,10,"[{'color': 'red', 'height': 5, 'location': 20,...",te rākau kikorangi me te rākau mā iti
972,8ee330e2-3f6a-42bd-b28b-68a8bc9e757f,7.814807,10,"[{'color': 'white', 'height': 10, 'location': ...",te rākau mā me ngā rākau parauri nui rawa e ru...
647,9cc7f006-0cac-4ef7-9769-810204eae383,7.890296,10,"[{'color': 'pink', 'height': 8, 'location': 19...",ngā rākau katoa hāunga te rākau mā iti rawa
...,...,...,...,...,...
420,9fc48bf4-9c4c-4726-8829-113665d0558c,3.000000,2,"[{'color': 'black', 'height': 4, 'location': 2...",te rākau pango
419,c762566a-b4e1-4191-bf2b-6d7ab7e77f82,3.000000,2,"[{'color': 'brown', 'height': 6, 'location': 7...",ngā rākau
418,56ee6946-e664-4b23-bd1c-824d5ce9436e,2.000000,2,"[{'color': 'brown', 'height': 4, 'location': 9...",te rākau pango
417,edfa86f8-1ec4-4470-bfc2-03684bd2d516,3.000000,2,"[{'color': 'yellow', 'height': 7, 'location': ...",te rākau māwhero


In [5]:
text_tokenizer = TextTokenizer()
ws_tokenizer = WorldStateTokenizer()

In [6]:
rākau_data['input'] = rākau_data.rākau.apply(ws_tokenizer.tokenize)
rākau_data['target'] = rākau_data.description.apply(text_tokenizer.tokenize)

In [7]:
rākau_data

Unnamed: 0,id,entropy,num_rākau,rākau,description,input,target
0,800d23b5-574c-46d3-94cf-1066082e9d7d,3.0,2,"[{'color': 'blue', 'height': 4, 'location': 3,...",te rākau mā,"[2, 12, 6, 16, 19]","[21, 20, 24, 50]"
1,659289fa-7473-4617-a8f0-61324cc0e3b1,2.0,2,"[{'color': 'blue', 'height': 1, 'location': 8,...",ngā rākau,"[2, 9, 3, 9, 19]","[22, 20, 50]"
2,7629eca4-df36-4921-973b-0ca7859c5018,3.0,2,"[{'color': 'blue', 'height': 1, 'location': 4,...",te rākau iti,"[2, 9, 3, 17, 19]","[21, 20, 32, 50]"
3,89413927-94ca-4078-8198-85b957338400,3.0,2,"[{'color': 'green', 'height': 2, 'location': 1...",te rākau mā,"[6, 16, 3, 10, 19]","[21, 20, 24, 50]"
4,eb110390-7be8-4d7a-93e0-211c33c033f1,3.0,2,"[{'color': 'black', 'height': 10, 'location': ...",ngā rākau,"[5, 18, 6, 16, 19]","[22, 20, 50]"
...,...,...,...,...,...,...,...
1007,3d543250-f7c8-47c4-9491-ab29d5c8c8fd,3.0,2,"[{'color': 'blue', 'height': 8, 'location': 18...",te rākau kikorangi,"[5, 17, 2, 16, 19]","[21, 20, 27, 50]"
1008,b6fcfe8b-5604-47d5-8e1a-1a2830219683,3.0,2,"[{'color': 'white', 'height': 4, 'location': 1...",te rākau mā,"[6, 12, 5, 18, 19]","[21, 20, 24, 50]"
1009,ecc53392-b801-4159-8fe3-fe9475e50672,3.0,2,"[{'color': 'black', 'height': 2, 'location': 1...",te rākau whero,"[1, 17, 5, 10, 19]","[21, 20, 30, 50]"
1010,bd1d8542-5e5f-4d03-9bf3-8b3d889fe737,3.0,2,"[{'color': 'yellow', 'height': 3, 'location': ...",ngā rākau,"[1, 18, 4, 11, 19]","[22, 20, 50]"


## Defining a transformer model

In [8]:
import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embed_size, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.embed_size = embed_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Create a positional encoding that is large enough for any sequence you expect to process
        self.register_buffer('positional_encodings', self.create_positional_encodings(max_seq_length, embed_size))
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def create_positional_encodings(self, max_len, embed_size):
        """Create positional encodings for transformer model."""
        pos_enc = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        return pos_enc.unsqueeze(0)

    def forward(self, src, tgt):
        src_pos = self.positional_encodings[:, :src.size(1), :]
        tgt_pos = self.positional_encodings[:, :tgt.size(1), :]
        src = self.embedding(src) + src_pos
        tgt = self.embedding(tgt) + tgt_pos
        output = self.transformer(src, tgt)
        return self.fc_out(output)

# Model instantiation
model = TransformerModel(
    vocab_size=max(text_tokenizer.token_map.values())+1,
    embed_size=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    max_seq_length=500,
    dropout=0.1
)

In [9]:
ws_tokenizer.token_map.values(), text_tokenizer.token_map.values()

(dict_values([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]),
 dict_values([20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]))

In [10]:
def custom_collate_fn(batch):
    # Extracting input_ids, token_type_ids, and attention_mask from the batch
    input_ids = [item['input_ids'] for item in batch]
    token_type_ids = [item['token_type_ids'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]

    # Find the maximum sequence length in this batch
    max_len = max(len(ids) for ids in input_ids)

    # Pad all sequences to this maximum length
    padded_input_ids = torch.stack([torch.cat([ids, torch.zeros(max_len - len(ids), dtype=torch.long)]) for ids in input_ids])
    padded_token_type_ids = torch.stack([torch.cat([ids, torch.zeros(max_len - len(ids), dtype=torch.long)]) for ids in token_type_ids])
    padded_attention_mask = torch.stack([torch.cat([mask, torch.zeros(max_len - len(mask), dtype=torch.long)]) for mask in attention_mask])

    return {
        'input_ids': padded_input_ids,
        'token_type_ids': padded_token_type_ids,
        'attention_mask': padded_attention_mask
    }

In [11]:
# Create dataset
dataset = RākauDataset(rākau_data.rākau, rākau_data.description, world_state_tokenizer, text_tokenizer)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, collate_fn=custom_collate_fn)

In [None]:
# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move model to the appropriate device
model = model.to(device)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of epochs
num_epochs = 100

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    epoch_loss = 0

    for batch_idx, batch in enumerate(dataloader):
        # Assuming 'input_ids' are the source and target sequences
        src = batch['input_ids'][:, :-1].to(device)  # all but the last for input
        tgt = batch['input_ids'][:, 1:].to(device)   # all but the first for target

        # Forward pass
        output = model(src, tgt)

        # Compute loss; assume output is reshaped to (batch_size*seq_len, vocab_size)
        # and tgt is reshaped accordingly for CrossEntropyLoss
        loss = criterion(output.view(-1, output.size(-1)), tgt.reshape(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Optional: Log progress
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item()}')

    # Average loss for the epoch
    print(f'Epoch {epoch+1} completed, Average Loss: {epoch_loss / len(dataloader)}')

Epoch 1/100, Batch 0, Loss: 3.98132061958313
Epoch 1 completed, Average Loss: 2.432126149535179
Epoch 2/100, Batch 0, Loss: 2.12595796585083
Epoch 2 completed, Average Loss: 1.9801320880651474
Epoch 3/100, Batch 0, Loss: 1.983611822128296
Epoch 3 completed, Average Loss: 1.975804142653942
Epoch 4/100, Batch 0, Loss: 1.917512059211731
Epoch 4 completed, Average Loss: 1.961048997938633
Epoch 5/100, Batch 0, Loss: 2.0927631855010986
Epoch 5 completed, Average Loss: 1.9646124467253685
Epoch 6/100, Batch 0, Loss: 1.95881986618042
Epoch 6 completed, Average Loss: 2.0049845948815346
Epoch 7/100, Batch 0, Loss: 1.8378864526748657
Epoch 7 completed, Average Loss: 1.991580843925476
Epoch 8/100, Batch 0, Loss: 1.7605085372924805
Epoch 8 completed, Average Loss: 1.926990658044815
Epoch 9/100, Batch 0, Loss: 1.9239059686660767
Epoch 9 completed, Average Loss: 1.946223922073841
Epoch 10/100, Batch 0, Loss: 1.7644635438919067
Epoch 10 completed, Average Loss: 1.991080828011036
Epoch 11/100, Batch 0, 