Import libraries, define the dataset and transformer.

In [107]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from datetime import datetime
from devinterp.slt.sampler import estimate_learning_coeff_with_summary, SGLD
from devinterp.utils import default_nbeta

# Set up device
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

class ModularArithmeticDataset(Dataset):
    def __init__(self, data, encoder, max_seq_length, padding_char):
        self.data = data
        self.encoder = encoder
        self.max_seq_length = max_seq_length
        self.pad_token_id = self.encoder(padding_char)[0]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        equation = item['input']
        result = item['output']

        equation_tokens = self.encoder(equation)

        # Pad or truncate to max_seq_length
        if len(equation_tokens) < self.max_seq_length:
            equation_tokens += [self.pad_token_id] * (self.max_seq_length - len(equation_tokens))
        else:
            equation_tokens = equation_tokens[:self.max_seq_length]

        return torch.tensor(equation_tokens, dtype=torch.long), torch.tensor(int(result), dtype=torch.long)

class ArithmeticTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward, max_seq_length, max_result):
        super(ArithmeticTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = nn.Embedding(max_seq_length, d_model)

        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)

        self.fc = nn.Linear(d_model, max_result)

    def forward(self, src):
        # src shape: [batch_size, seq_len]

        # Create a mask for padded elements
        mask = (src == 0).to(device)

        # Create positional encodings
        positions = torch.arange(0, src.size(1)).unsqueeze(0).expand(src.size()).to(device)

        # Combine token embeddings and positional encodings
        x = self.embedding(src) + self.pos_encoder(positions)

        # Pass through the transformer
        output = self.transformer_encoder(x, src_key_padding_mask=mask)

        # Use the output of the last non-padded token for classification
        output = output[:, 0, :]  # Use the first token's output for classification

        # Project to the number of possible results
        output = self.fc(output)

        return output

def evaluate(model, data):
    inputs, outputs = data

    return nn.functional.cross_entropy(model(inputs).logits, outputs), {
        "logits": model(inputs).logits
    }  # transformers don't output a vector

def training_loop(num_epochs, optimizer, train_loader, test_loader, model, criterion):

    # Training loop
    start_time = datetime.now()
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for equation, result in train_loader:
            equation, result = equation.to(device), result.to(device)
            optimizer.zero_grad()
            output = model(equation)
            loss = criterion(output, result)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for equation, result in test_loader:
                equation, result = equation.to(device), result.to(device)
                output = model(equation)
                _, predicted = torch.max(output, 1)
                total += result.size(0)
                correct += (predicted == result).sum().item()

        accuracy = 100 * correct / total

        learning_coeff_stats = estimate_learning_coeff_with_summary(
                                model,
                                loader=train_loader,
                                evaluate=criterion,
                                sampling_method=SGLD,
                                optimizer_kwargs=dict(lr=4e-4, localization=100.0),
                                num_chains=3,  # How many independent chains to run
                                num_draws=100,  # How many samples to draw per chain
                                num_burnin_steps=0,  # How many samples to discard at the beginning of each chain
                                num_steps_bw_draws=1,  # How many steps to take between each sample
                                device=device,
                                online=True)

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader):.3f}, Accuracy: {accuracy:.2f}% LLC: {llc_mean:.3f}")
        
    end_time = datetime.now()
    training_duration = end_time - start_time
    print(f"Training duration: {training_duration.total_seconds():0.2f} seconds")

Using device: cpu


### Import the data and create the tokenizer functions

In [108]:
# Get paths to the test and train data
data_set_name = "modular_arithmetic_three_numbers"
data_path = os.path.join("..", "data")
train_data_path = os.path.join(data_path,f'{data_set_name}_train.csv')
test_data_path = os.path.join(data_path, f'{data_set_name}_test.csv')

# Load and preprocess the data
train_data = pd.read_csv(train_data_path)
test_data = pd.read_csv(test_data_path)

print("Train dataset size:", len(train_data))
print("Test dataset size:", len(test_data))

print("Train data sample:")
print(train_data.head())

# Get set of unique characters in training data input
unique_characters = set("".join(train_data['input']))
padding_char = "P"
unique_characters.add(padding_char) # Add character for padding
charToInt = {char: i for i, char in enumerate(unique_characters)}
intToChar = {i: char for i, char in enumerate(unique_characters)}
vocab_size = len(unique_characters)

encoder = lambda string: [charToInt[char] for char in string]
decoder = lambda arr: "".join([intToChar[i] for i in arr])

all_chars = "".join(unique_characters)
print(f"vocab_size = {vocab_size}")
print(f'encode({all_chars}) = {encoder(all_chars)}')
print(f'decode(encode({train_data["input"][0]})) = {decoder(encoder(train_data["input"][0]))}')

Train dataset size: 8000
Test dataset size: 2000
Train data sample:
                 input  output  modulus       operation
0   (70 * 66 * 2) % 10       0       10  multiplication
1  (80 + 33 + 79) % 10       2       10        addition
2  (18 + 34 * 74) % 10       4       10     addAndMulti
3  (56 + 62 + 98) % 10       6       10        addition
4  (73 * 73 + 78) % 10       7       10     addAndMulti
vocab_size = 17
encode(%6+370*89)P42(51 ) = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
decode(encode((70 * 66 * 2) % 10)) = (70 * 66 * 2) % 10


### Train the model

In [109]:
# Determine max_length based on the longest equation in the dataset
max_seq_length = train_data['input'].str.len().max()

# Create datasets and dataloaders
train_dataset = ModularArithmeticDataset(train_data, encoder, max_seq_length, padding_char)
test_dataset = ModularArithmeticDataset(test_data, encoder, max_seq_length, padding_char)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Model parameters
d_model = 32
nhead = 2
num_layers = 3
dim_feedforward = 128
max_result = train_data['output'].max() + 1  # +1 because we start counting from 0

# Initialize the model
model = ArithmeticTransformer(vocab_size, d_model, nhead, num_layers, dim_feedforward, max_seq_length, max_result)
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), weight_decay=1e-5)

num_epochs = 50

training_loop(num_epochs=num_epochs,
              optimizer=optimizer,
              train_loader=train_loader,
              test_loader=test_loader,
              model=model,
              criterion=criterion)



TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not ArithmeticTransformer

Test the model

In [54]:
model.eval()
with torch.no_grad():
    test_samples = 20  # Number of samples to test
    sample_count = 0
    for equation, result in test_loader:
        equation, result = equation.to(device), result.to(device)
        output = model(equation)
        _, predicted = torch.max(output, 1)
        for i in range(len(equation)):
            if sample_count >= test_samples:
                break
            eq_str = decoder(equation[i].tolist())
            eq_str = eq_str.replace(padding_char, "")
            wrong_message = " <--- WRONG!" if predicted[i] != result[i] else ""
            print(f"Equation: {eq_str}, Predicted: {predicted[i].item()}, Actual: {result[i].item()}{wrong_message}")
            sample_count += 1
        if sample_count >= test_samples:
            break

Equation: (30 * 42 + 13) % 10, Predicted: 3, Actual: 3
Equation: (85 + 7 * 61) % 10, Predicted: 2, Actual: 2
Equation: (55 * 40 * 3) % 10, Predicted: 0, Actual: 0
Equation: (60 + 42 + 91) % 10, Predicted: 3, Actual: 3
Equation: (16 + 58 + 0) % 10, Predicted: 4, Actual: 4
Equation: (96 + 92 + 16) % 10, Predicted: 4, Actual: 4
Equation: (5 * 70 * 45) % 10, Predicted: 0, Actual: 0
Equation: (5 * 5 + 25) % 10, Predicted: 0, Actual: 0
Equation: (21 * 1 * 44) % 10, Predicted: 4, Actual: 4
Equation: (35 + 29 + 96) % 10, Predicted: 0, Actual: 0
Equation: (0 * 36 * 18) % 10, Predicted: 0, Actual: 0
Equation: (92 * 19 + 63) % 10, Predicted: 1, Actual: 1
Equation: (41 * 33 + 21) % 10, Predicted: 4, Actual: 4
Equation: (7 + 59 * 12) % 10, Predicted: 5, Actual: 5
Equation: (42 * 82 + 29) % 10, Predicted: 3, Actual: 3
Equation: (33 * 70 + 7) % 10, Predicted: 7, Actual: 7
Equation: (20 + 17 * 31) % 10, Predicted: 7, Actual: 7
Equation: (82 * 32 + 76) % 10, Predicted: 0, Actual: 0
Equation: (49 + 13 +