In [2]:
# autoreload to reload modules when they change
%load_ext autoreload
%autoreload 2

import rollout
from torch import optim
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn

import networkx  as nx 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [59]:
from itertools import permutations


def generate_reasoning_combinations(num_tokens, sequence_length, n_back, return_output=True):
    """Enumerate every prompt sequence for the n-back "successor" task.

    Each row is built as follows:

    * the first ``sequence_length - 1`` entries are one permutation of distinct
      tokens drawn from ``range(num_tokens)`` (every such permutation appears
      exactly once);
    * the last entry (the "prompt") repeats the token found at a randomly
      chosen position ``p`` with ``0 <= p < sequence_length - 1 - n_back``;
    * the target is ``(token at position p + n_back) + 1``, modulo
      ``num_tokens``.

    Prompt positions are drawn from torch's global RNG, so call
    ``torch.manual_seed`` beforehand for reproducible data.

    Args:
        num_tokens (int): Size of the vocabulary; tokens are ``0..num_tokens-1``.
        sequence_length (int): Total row length, including the final prompt token.
        n_back (int): Offset from the prompted position to the position whose
            token determines the target. Must satisfy
            ``0 <= n_back < sequence_length - 1``.
        return_output (bool, optional): When True (default) also return the targets.

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ``combinations`` with
        shape ``(n_rows, sequence_length)`` and dtype long and, when
        ``return_output`` is True, ``reasoned_outputs`` with shape ``(n_rows,)``.

    Raises:
        ValueError: If ``n_back`` leaves no valid prompt position.
    """
    # TODO: Could subsample the permutations to reduce the number of combinations
    prefix_length = sequence_length - 1
    # The prompted position must have a token n_back steps after it that is
    # still inside the prefix, which leaves this many admissible positions.
    n_prompt_positions = prefix_length - n_back
    if n_back < 0 or n_prompt_positions <= 0:
        raise ValueError(
            f"n_back must satisfy 0 <= n_back < sequence_length - 1, "
            f"got n_back={n_back} with sequence_length={sequence_length}"
        )

    # generate all permutations of the first sequence_length-1 elements
    all_perms = list(permutations(range(num_tokens), prefix_length))
    n_rows = len(all_perms)
    combinations = torch.zeros(n_rows, sequence_length, dtype=torch.long)
    if all_perms:
        # Guarded because torch.tensor([]) is 1-D and cannot fill a (0, k) slice.
        combinations[:, :prefix_length] = torch.tensor(all_perms, dtype=torch.long)

    # now set the last element to a randomly chosen element of the prefix
    row_inds = torch.arange(n_rows)
    prompt_inds = torch.randint(0, n_prompt_positions, (n_rows,))
    combinations[:, -1] = combinations[row_inds, prompt_inds]

    if not return_output:
        return combinations

    # The answer is the successor (mod num_tokens) of the token n_back steps
    # after the prompted position.
    output_inds = prompt_inds + n_back
    output = combinations[row_inds, output_inds]
    reasoned_outputs = (output + 1) % num_tokens
    return combinations, reasoned_outputs

In [63]:
# Try the generator on a small configuration: 10 tokens, rows of length 6,
# target taken one step after the prompted position.
n_tokens, sequence_length, n_back = 10, 6, 1
comb, reason = generate_reasoning_combinations(
    num_tokens=n_tokens,
    sequence_length=sequence_length,
    n_back=n_back,
)

In [64]:
 
class ReasoningDataset:
    """A fixed pool of reasoning sequences with a shuffled train/test split.

    The full pool is produced once by ``data_generator``, shuffled, and cut into
    a leading training part and a trailing test part. Batches are then sampled
    (with replacement) from either part via ``generate_batch``.
    """
    # TODO could subclass torch.utils.data.Dataset for more flexibility
    def __init__(self, num_tokens, sequence_length, n_back=1, random_seed=42, train_fraction=0.8, data_generator=generate_reasoning_combinations):
        """
        Initializes the ReasoningDataset with the given parameters.

        Note: this reseeds torch's *global* RNG with ``random_seed``, which also
        affects any torch randomness that runs afterwards.

        Args:
            num_tokens (int): Total number of unique tokens.
            sequence_length (int): Length of each sequence.
            n_back (int, optional): Offset passed to the data generator. Defaults to 1.
            random_seed (int, optional): Random seed for reproducibility. Defaults to 42.
            train_fraction (float, optional): Fraction of data to be used for training,
                in [0, 1]. Defaults to 0.8.
            data_generator (callable, optional): Called as
                ``data_generator(num_tokens, sequence_length, n_back)`` and expected to
                return an ``(X, y)`` pair of equally long tensors. Defaults to
                ``generate_reasoning_combinations``.
        """
        torch.manual_seed(random_seed)
        assert num_tokens > sequence_length, "num_tokens must be greater than sequence_length"
        assert n_back < sequence_length-1, "n_back must be less than sequence_length-1"
        assert 0.0 <= train_fraction <= 1.0, "train_fraction must be between 0 and 1"
        self.n = num_tokens
        self.n_back = n_back
        self.train_fraction = train_fraction

        self.X, self.y = data_generator(num_tokens, sequence_length, n_back)
        # Shuffle once so that the contiguous split below is a random split.
        shuffle_idx = torch.randperm(len(self.X))
        self.X = self.X[shuffle_idx]
        self.y = self.y[shuffle_idx]
        self.n_samples = len(self.X)

        # Honour the caller's train_fraction (previously hard-coded to 0.8).
        self.n_train = int(self.n_samples * train_fraction)
        self.train_idx = torch.arange(self.n_train)

        self.test_idx = torch.arange(self.n_train, self.n_samples)
        self.n_test = len(self.test_idx)

    def __len__(self):
        """Returns the total number of samples (train and test together)."""
        return self.n_samples

    def generate_batch(self, batch_size, type='train'):
        """
        Generates a batch of data for training or testing.

        Rows are drawn uniformly at random with replacement from the chosen
        split, so a batch may contain duplicates and may be larger than the split.

        Args:
            batch_size (int): Number of samples in the batch.
            type (str, optional): Type of data to generate ('train' or 'test'). Defaults to 'train'.
        Returns:
            tuple: A tuple containing the input sequences (X) and the targets (y).
        Raises:
            ValueError: If the requested split contains no samples.
        """
        assert type in ['train', 'test'], "type must be either 'train' or 'test'"
        split_idx = self.train_idx if type == 'train' else self.test_idx
        if len(split_idx) == 0:
            # Fail with a clear message instead of an opaque torch.randint error.
            raise ValueError(f"the '{type}' split is empty; adjust train_fraction")
        idx = split_idx[torch.randint(0, len(split_idx), (batch_size,))]
        X = self.X[idx]
        y = self.y[idx]
        return X, y
    

In [67]:
# Experiment 1: transformer with a single attention layer (n_attn_layers=1).
d_model = 256
n_tokens = 10
sequence_length = 6
n_heads = 1
# Constructing the dataset reseeds torch's global RNG (random_seed defaults to 42),
# so it must stay ahead of the model construction below.
# NOTE(review): presumably this is what makes the weight initialisation
# reproducible — confirm against rollout.models.
dataset = ReasoningDataset(n_tokens, sequence_length)
simpleModel = rollout.models.FlexibleTransformer(d_model, n_tokens, sequence_length, n_heads, n_attn_layers=1)
optimizer = optim.AdamW(simpleModel.parameters(), lr=0.001)
criterion = nn.functional.cross_entropy

# NOTE(review): the two return values are assumed to be train/test loss
# histories, going by the names they are bound to — confirm in rollout.models.
simple_train_losses, simple_test_losses = rollout.models.optimize_model(simpleModel, criterion, optimizer, dataset, n_epochs=500, batch_size=1024)

Epoch 0: Train Loss: 2.4965081214904785, Test Loss: 2.366917848587036
Epoch 100: Train Loss: 1.3366155624389648, Test Loss: 1.344161033630371
Epoch 200: Train Loss: 1.315226435661316, Test Loss: 1.3030773401260376
Epoch 300: Train Loss: 1.2873501777648926, Test Loss: 1.291556477546692
Epoch 400: Train Loss: 1.3003287315368652, Test Loss: 1.2621301412582397


In [66]:
# Experiment 2: same setup, but with two attention layers (n_attn_layers=2).
d_model = 256
n_tokens = 10
sequence_length = 6
n_heads = 1
# A fresh dataset is built with the default seed (42), which reseeds torch's
# global RNG before the model is created; keep this order.
# NOTE(review): presumably this yields the same data and split as any other
# default-constructed ReasoningDataset — confirm the generator is deterministic
# under a fixed seed.
dataset = ReasoningDataset(n_tokens, sequence_length)
complexModel = rollout.models.FlexibleTransformer(d_model, n_tokens, sequence_length, n_heads, n_attn_layers=2)
# Rebinds `optimizer` and `criterion`; any earlier objects with these names are replaced.
optimizer = optim.AdamW(complexModel.parameters(), lr=0.001)
criterion = nn.functional.cross_entropy

# NOTE(review): the two return values are assumed to be train/test loss
# histories, going by the names they are bound to — confirm in rollout.models.
complex_train_losses, complex_test_losses = rollout.models.optimize_model(complexModel, criterion, optimizer, dataset, n_epochs=500, batch_size=1024)

Epoch 0: Train Loss: 2.528334140777588, Test Loss: 2.7821390628814697
Epoch 100: Train Loss: 0.00027958714053966105, Test Loss: 0.00026847951812669635
Epoch 200: Train Loss: 8.501573756802827e-05, Test Loss: 8.46980547066778e-05
Epoch 300: Train Loss: 5.293273716233671e-05, Test Loss: 5.359837814467028e-05
Epoch 400: Train Loss: 3.7338297261158004e-05, Test Loss: 3.9357313653454185e-05
