In [1]:
# Development convenience: with autoreload mode 2, IPython re-imports edited
# modules before running each cell, so changes to imported project code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from cot.data import Parity, BinaryCopy, Copy
from cot.config import RAW_DIR
from cot.models import Transformer, TransformerConfig

import logging
logging.basicConfig(level=logging.INFO)

## DataLoader

In [3]:
# Dataset class used for the rest of the notebook; `Parity` and `Copy` are the
# other problem classes imported from `cot.data`.
Problem = BinaryCopy

In [4]:
# Seed the NumPy generator that is handed to `Problem.generate_datafiles`, so
# that a Restart & Run All regenerates the same data files and split.
seed = 0
rng = np.random.default_rng(seed)

seq_length = 20
max_nb_data_per_len = 1000
random = False  # NOTE(review): not used in the visible cells — confirm it is still needed

# One entry per input length (1, 2, ...), passed to `Problem.generate_datafiles`.
# Presumably the fraction of sequences of that length assigned to the training
# split — TODO confirm against `cot.data`.
split_probas_by_len = [1, 1, 1, 1, .9, .8, .9, .3, .2]

# Uniform weights over the input lengths, passed to `set_as_trainset`. Derived
# from `split_probas_by_len` so the two vectors cannot drift out of sync.
nb_lengths = len(split_probas_by_len)
probas_by_len = np.full(nb_lengths, 1 / nb_lengths)

# Input lengths covered by the experiment: 1, ..., nb_lengths.
lengths = list(np.arange(1, nb_lengths + 1))

In [5]:
# NOTE(review): the instance created here is discarded; presumably the
# constructor records `vocab_size` as class-level state for the 'copy'
# problem — confirm against `cot.data`.
if Problem.prefix == 'copy':
    Problem(vocab_size=20)

# Generate the data files for every length (the log output below reports where
# they are saved and the per-length split). This must run before the datasets
# are instantiated.
Problem.generate_datafiles(max_nb_data_per_len, split_probas_by_len, rng)

# Training view over the requested lengths, weighted by `probas_by_len`.
trainset = Problem()
trainset.set_as_trainset(lengths, probas_by_len)

# Test view over the same lengths.
testset = Problem()
testset.set_as_testset(lengths)

INFO:cot.data.data_processing:Sequences of length 1 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (2/2 split).
INFO:cot.data.data_processing:Sequences of length 2 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (4/4 split).
INFO:cot.data.data_processing:Sequences of length 3 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (8/8 split).
INFO:cot.data.data_processing:Sequences of length 4 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (16/16 split).
INFO:cot.data.data_processing:Sequences of length 5 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (28/32 split).
INFO:cot.data.data_processing:Sequences of length 6 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (51/64 split).
INFO:cot.data.data_processing:Sequences of length 7 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (115/128 split).
INFO:cot.data.data_proce

In [6]:
# Mini-batches of 32 sequences; the sampling order is delegated to the sampler
# exposed by the training set.
loader = DataLoader(trainset, batch_size=32, sampler=trainset.sampler)

In [7]:
# Sanity check: shape of a single training example.
trainset[1].shape

torch.Size([20])

## Training loop

In [8]:
# Model hyper-parameters.
config = TransformerConfig(
    # The training cell shifts tokens by +1, so EoS (-1) becomes 0 and EoI is 3;
    # presumably the two remaining ids are the binary symbols — TODO confirm.
    vocab_size=4,
    emb_dim=128,
    pos_emb=True,
    # Matches the length of one dataset item (`trainset[1].shape` is [20]).
    seq_len=20,
    emb_dropout=0.1,
    n_head=2,
    n_layer=2,
)


In [9]:
# Instantiate the model once to inspect its architecture; the training cell
# below builds its own fresh instance.
model = Transformer(config)
print(model)

Transformer(
  (embeddings): Embedding(
    (token_emb): Embedding(4, 128)
    (pos_emb): Embedding(20, 128)
  )
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (norm_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): SelfAttention(
        (qkv_mat): Linear(in_features=128, out_features=384, bias=False)
        (output): Linear(in_features=128, out_features=128, bias=False)
      )
      (norm_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ffn): FeedForward(
        (fc1): Linear(in_features=128, out_features=512, bias=False)
        (fc2): Linear(in_features=512, out_features=128, bias=False)
      )
    )
  )
  (output_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (output): Linear(in_features=128, out_features=4, bias=False)
)


In [47]:
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Token id of the end-of-input (EoI) marker once the raw tokens have been
# shifted by +1 in the loop below (the same shift maps the raw EoS token -1 to 0).
eoi_token = 3

# Build a fresh model so that re-running this cell restarts training from scratch.
model = Transformer(config)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# NOTE: `scheduler.step()` is called once per batch in the loop below, so
# `step_size=100` counts optimizer steps, not epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

nb_epochs = 100

# Per-epoch training curves.
losses = np.empty(nb_epochs)
accuracies = np.empty(nb_epochs)
test_losses = []  # not filled in by this cell

model.train()
for t in range(nb_epochs):
    running_loss = 0.0  # sum over the epoch of the per-batch mean losses
    nb_correct = 0      # sequences whose whole chain-of-thought is predicted correctly
    nb_seen = 0         # sequences actually drawn by the sampler during this epoch
    for sequence in loader:
        # Deal with the EoS token being represented as -1: shift every token by
        # one so that all ids are valid embedding indices. The shift is done out
        # of place so the tensor yielded by the loader is never mutated.
        sequence = sequence.to(device=device, dtype=torch.long) + 1

        # Next-token prediction: position i of `inputs` predicts position i of `targets`.
        inputs = sequence[:, :-1]
        targets = sequence[:, 1:]

        # Only train on the chain-of-thought: keep the target positions that come
        # strictly after the EoI marker (the EoI position itself is excluded).
        ind = targets == eoi_token
        cot_mask = (ind.cumsum(dim=1) > 0) & ~ind

        logits = model(inputs)
        # Boolean indexing already flattens the selected positions to
        # (nb_selected, vocab) logits and (nb_selected,) targets.
        loss = F.cross_entropy(logits[cot_mask], targets[cot_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        with torch.no_grad():
            running_loss += loss.item()
            # A sequence counts as correct only if every position inside the
            # mask is predicted exactly; positions outside the mask are ignored.
            predictions = logits.argmax(dim=-1)
            sequence_ok = ((predictions == targets) | ~cot_mask).all(dim=1)
            nb_correct += int(sequence_ok.sum().item())
            nb_seen += sequence.size(0)

    # Normalise by the number of sequences actually seen, which stays correct
    # even if the sampler does not draw exactly `len(trainset)` samples.
    accuracy = nb_correct / max(nb_seen, 1)

    accuracies[t] = accuracy
    # Record the epoch-level loss that is printed below, rather than the loss
    # of the last mini-batch only.
    losses[t] = running_loss
    print(f'Loss: {running_loss:.4f}, Accuracy: {accuracy:.4f}')

Loss: 7.6854, Accuracy: 0.0100
Loss: 4.7039, Accuracy: 0.1244
Loss: 3.6002, Accuracy: 0.2587
Loss: 2.6853, Accuracy: 0.3607
Loss: 2.0849, Accuracy: 0.4204
Loss: 1.9623, Accuracy: 0.4279
Loss: 1.7897, Accuracy: 0.4726
Loss: 1.5350, Accuracy: 0.5075
Loss: 1.4144, Accuracy: 0.5348
Loss: 1.2897, Accuracy: 0.5945
Loss: 0.9764, Accuracy: 0.6493
Loss: 0.6903, Accuracy: 0.7662
Loss: 0.5577, Accuracy: 0.7985
Loss: 0.4651, Accuracy: 0.8209
Loss: 0.4739, Accuracy: 0.8184
Loss: 0.4335, Accuracy: 0.8557
Loss: 0.3097, Accuracy: 0.8955
Loss: 0.2448, Accuracy: 0.9254
Loss: 0.2171, Accuracy: 0.9204
Loss: 0.2344, Accuracy: 0.9229
Loss: 0.1915, Accuracy: 0.9353
Loss: 0.1941, Accuracy: 0.9428
Loss: 0.1926, Accuracy: 0.9353
Loss: 0.1018, Accuracy: 0.9751
Loss: 0.1010, Accuracy: 0.9652
Loss: 0.0892, Accuracy: 0.9751
Loss: 0.0803, Accuracy: 0.9801
Loss: 0.0764, Accuracy: 0.9776
Loss: 0.0685, Accuracy: 0.9876
Loss: 0.1012, Accuracy: 0.9677
Loss: 0.0776, Accuracy: 0.9751
Loss: 0.0648, Accuracy: 0.9801
Loss: 0.

## Results Analysis

There are several quantities to monitor:
- Make sure that `-1` (the end-of-sequence token) is an absorbing state: once it has been produced, every subsequent token should also be `-1`.
- Check the validity of the full chain of thoughts.
- Check the validity of the final answer in the chain, or of intermediate answers in the chain.
- Cluster results by length of the input sequence.