In [1]:
# Re-import edited modules automatically before each cell runs, so changes to
# the local `cot` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from cot.data import Parity, BinaryCopy, Copy
from cot.config import RAW_DIR
from cot.models import Transformer, TransformerConfig

import logging
# Surface INFO-level records, e.g. the progress messages emitted by
# `cot.data.data_processing` while the data files are generated.
logging.basicConfig(level=logging.INFO)

## DataLoader

In [3]:
# Dataset class used by every cell below.
# NOTE(review): `Parity` and `Copy` are imported as well and are presumably
# drop-in alternatives -- confirm they expose the same interface.
Problem = BinaryCopy

In [4]:
# Seeded generator: `rng` is handed to `Problem.generate_datafiles`, so an
# unseeded generator would yield a different train/test split on every kernel
# restart even though the training cell pins `torch.manual_seed(0)`.
rng = np.random.default_rng(0)

# NOTE(review): `seq_length` and `random` are not referenced by any visible
# cell; they are kept in case later cells rely on them.
seq_length = 20
max_nb_data_per_len = 1000
random = False

# One entry per input length (index i <-> length i + 1).
# NOTE(review): presumably the fraction of the sequences of that length that
# is assigned to the training split -- confirm in `generate_datafiles`.
split_probas_by_len =  [1, 1, 1, 1, .9, .8, .9, .3, .2]

# Weights attached to each length when the training set is configured
# (passed to `set_as_trainset`), normalised below so that they sum to one.
probas_by_len = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1]).astype(float)
probas_by_len /= probas_by_len.sum()

# Both lists are written by hand; fail loudly if they drift out of sync.
assert len(probas_by_len) == len(split_probas_by_len), (
    "probas_by_len and split_probas_by_len must have one entry per length"
)

# Input lengths covered by the experiment: 1, 2, ..., len(split_probas_by_len).
lengths = list(np.arange(len(split_probas_by_len)) + 1)

In [5]:
# Generate the data files, then wrap them as a training and a test dataset.
# The statement order is kept as is: the calls below may depend on class-level
# state set by the earlier ones.

# NOTE(review): the instance built here is discarded; presumably the
# constructor stores `vocab_size` somewhere `generate_datafiles` can read it
# for the 'copy' problem -- confirm in `cot.data`.
if Problem.prefix == 'copy':
    Problem(vocab_size=20)

# Produces the per-length files reported by the INFO log lines in the output.
Problem.generate_datafiles(max_nb_data_per_len, split_probas_by_len, rng)

# Training view over the generated data, restricted to `lengths` and
# configured with the per-length weights `probas_by_len`.
trainset = Problem()
trainset.set_as_trainset(lengths, probas_by_len)

# Test view over the same lengths.
testset = Problem()
testset.set_as_testset(lengths)

INFO:cot.data.data_processing:Sequences of length 1 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (2/2 split).
INFO:cot.data.data_processing:Sequences of length 2 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (4/4 split).
INFO:cot.data.data_processing:Sequences of length 3 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (8/8 split).
INFO:cot.data.data_processing:Sequences of length 4 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (16/16 split).
INFO:cot.data.data_processing:Sequences of length 5 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (28/32 split).
INFO:cot.data.data_processing:Sequences of length 6 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (51/64 split).
INFO:cot.data.data_processing:Sequences of length 7 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (115/128 split).
INFO:cot.data.data_proce

In [6]:
# Mini-batch iterator over the training set; the order in which examples are
# drawn is delegated to the sampler exposed by the dataset itself.
loader = DataLoader(
    dataset=trainset,
    batch_size=32,
    sampler=trainset.sampler,
)

In [7]:
# Sanity check: shape of a single training example (one token sequence).
trainset[1].shape

torch.Size([20])

## Training loop

In [8]:
# Hyper-parameters of the small transformer trained below.
config = TransformerConfig(
    # Vocabulary and context window. The training loop shifts token ids by +1
    # (EoS is stored as -1) and uses ids up to 3, hence four tokens.
    vocab_size=4,
    seq_len=20,
    # Embedding layer.
    emb_dim=128,
    pos_emb=True,
    emb_dropout=0.1,
    # Depth and attention heads.
    n_layer=2,
    n_head=2,
)


In [9]:
# Instantiate the model once just to display its architecture.
# This instance is not the one that gets trained: the training cell below
# rebuilds `model` from `config` after fixing the torch seed.
model = Transformer(config)
print(model)

Transformer(
  (embeddings): Embedding(
    (token_emb): Embedding(4, 128)
    (pos_emb): Embedding(20, 128)
  )
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (norm_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): SelfAttention(
        (qkv_mat): Linear(in_features=128, out_features=384, bias=False)
        (output): Linear(in_features=128, out_features=128, bias=False)
      )
      (norm_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ffn): FeedForward(
        (fc1): Linear(in_features=128, out_features=512, bias=False)
        (fc2): Linear(in_features=512, out_features=128, bias=False)
      )
    )
  )
  (output_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (output): Linear(in_features=128, out_features=4, bias=False)
)


In [28]:
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The data stores EoS as -1; shifting every token by one makes all ids
# non-negative so they can index the embedding table and serve as class labels.
TOKEN_SHIFT = 1
# Token id (after the shift) that the loss mask keys on.
# NOTE(review): presumably the end-of-input marker that precedes the chain of
# thought -- confirm against the token layout defined in `cot.data`.
COT_START_TOKEN = 3


def chain_of_thought_mask(targets, start_token=COT_START_TOKEN):
    """Select the target positions that come strictly after `start_token`.

    Parameters
    ----------
    targets : torch.Tensor
        Integer tensor of shape (batch, positions) holding the shifted token ids.
    start_token : int
        Token id marking the start of the region to train on.

    Returns
    -------
    torch.Tensor
        Boolean tensor with the same shape as `targets`. An entry is True when
        `start_token` occurred earlier in the same row; the positions holding
        `start_token` itself are False.
    """
    is_start = targets == start_token
    # The running count turns positive at the first marker and stays positive.
    return (is_start.cumsum(dim=1) > 0) & ~is_start


model = Transformer(config)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# The scheduler is stepped once per batch (see below), so the learning rate is
# halved every 100 optimizer updates, not every 100 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

losses = []
# NOTE(review): never filled in this cell; kept because later cells may use it.
test_losses = []

n_epochs = 30
model.train()
for _ in range(n_epochs):
    # Sum of the per-batch mean losses; its scale depends on the number of batches.
    total_loss = 0
    for sequence in loader:
        # Cast and shift out of place so the tensor owned by the loader is
        # left untouched.
        sequence = sequence.to(device=device, dtype=torch.long) + TOKEN_SHIFT

        # Next-token prediction: position t is trained to predict token t + 1.
        inputs = sequence[:, :-1]
        targets = sequence[:, 1:]

        # only train on the chain-of-thoughts process
        cot_mask = chain_of_thought_mask(targets)

        logits = model(inputs)

        # A batch without any chain-of-thought target would give a NaN mean
        # loss (mean over zero elements) and poison `total_loss`; skip it.
        # The forward pass is done first so that `logits`, `targets` and
        # `cot_mask` always describe the same batch after the loop.
        if not cot_mask.any():
            continue

        # Boolean indexing already flattens to (n_selected, vocab) / (n_selected,).
        loss = F.cross_entropy(logits[cot_mask], targets[cot_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

    losses.append(total_loss)
    print(f'Loss: {total_loss}')

Loss: 7.685434967279434
Loss: 4.7039487063884735
Loss: 3.600172221660614
Loss: 2.6852837651968002
Loss: 2.084871418774128
Loss: 1.9622856751084328
Loss: 1.7896827161312103
Loss: 1.5350349992513657
Loss: 1.414369598031044
Loss: 1.2896583676338196
Loss: 0.9764289855957031
Loss: 0.6902799997478724
Loss: 0.5577131491154432
Loss: 0.46510985493659973
Loss: 0.4739277195185423
Loss: 0.43351088277995586
Loss: 0.30967373214662075
Loss: 0.24477957095950842
Loss: 0.21711145667359233
Loss: 0.23435441683977842
Loss: 0.19153718138113618
Loss: 0.19407782796770334
Loss: 0.19256279850378633
Loss: 0.10184032842516899
Loss: 0.1010023639537394
Loss: 0.08920602547004819
Loss: 0.08030583639629185
Loss: 0.07640748959966004
Loss: 0.0685235713608563
Loss: 0.10115879005752504


## Results Analysis

There are several quantities to monitor:
- Make sure that `-1` (the EoS token) is an absorbing state, i.e. once it has been emitted every subsequent token is `-1` as well.
- Check the validity of the full chain of thoughts.
- Check the validity of the final answer in the chain, or of intermediate answers in the chain.
- Cluster results by the length of the input sequence.

In [30]:
# Difference between the target token and the predicted token on the
# chain-of-thought positions; 0 means the prediction is correct.
# NOTE: `targets`, `logits` and `cot_mask` are leftovers from the last batch
# of the training loop, so this inspects training data with the model still
# in train mode.
targets[cot_mask] - logits.argmax(dim=-1)[cot_mask]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0])