In [1]:
# Load IPython's autoreload extension and reload all modules before each cell
# runs, so edits to the imported project code are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from cot.data import Parity, BinaryCopy, Copy
from cot.config import RAW_DIR
from cot.models import Transformer, TransformerConfig

import logging
logging.basicConfig(level=logging.INFO)

## DataLoader

In [3]:
# Dataset class used by every cell below; `BinaryCopy` and `Copy` are imported
# alongside `Parity` and can be substituted here.
Problem = Parity

In [4]:
# Seed the generator handed to `Problem.generate_datafiles` so that the
# generated train/test split is the same after a kernel restart.
seed = 0
rng = np.random.default_rng(seed)
seq_length = 20
max_nb_data_per_len = 1000
# NOTE(review): `random` is not read by any visible cell and shadows the
# stdlib module name; kept in case a later cell depends on it.
random = False
# One entry per input length (lengths start at 1). Presumably the fraction of
# sequences of that length assigned to the training split — TODO confirm
# against `generate_datafiles`.
split_probas_by_len =  [1, 1, 1, 1, .9, .8, .9, .3, .2]
# Uniform weights over lengths, sized from the split list so the two can never
# fall out of sync, then normalised to sum to one.
probas_by_len = np.ones(len(split_probas_by_len), dtype=float)
probas_by_len /= probas_by_len.sum()

# Input lengths covered by the experiment: 1 .. len(split_probas_by_len).
lengths = list(np.arange(len(split_probas_by_len)) + 1)

In [5]:
# Only the copy problem takes this branch. The instance is discarded, so the
# call presumably registers the vocabulary size as class-level state —
# TODO confirm in cot.data.
if Problem.prefix == 'copy':
    Problem(vocab_size=20)

# Generate the per-length data files; the INFO log below reports, for each
# length, the directory they are saved in and the resulting split sizes.
Problem.generate_datafiles(max_nb_data_per_len, split_probas_by_len, rng)

# Training view over the requested lengths, weighted by `probas_by_len`.
trainset = Problem()
trainset.set_as_trainset(lengths, probas_by_len)

# Test view over the same lengths.
testset = Problem()
testset.set_as_testset(lengths)

INFO:cot.data.data_processing:Sequences of length 1 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/parity (2/2 split).
INFO:cot.data.data_processing:Sequences of length 2 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/parity (4/4 split).
INFO:cot.data.data_processing:Sequences of length 3 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/parity (8/8 split).
INFO:cot.data.data_processing:Sequences of length 4 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/parity (16/16 split).
INFO:cot.data.data_processing:Sequences of length 5 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/parity (28/32 split).
INFO:cot.data.data_processing:Sequences of length 6 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/parity (51/64 split).
INFO:cot.data.data_processing:Sequences of length 7 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/parity (115/128 split).
INFO:cot.data.data_processing:Sequences of length 8 done. S

In [6]:
# Mini-batches of up to 256 training sequences, drawn in the order decided by
# the sampler the training set exposes.
loader = DataLoader(
    dataset=trainset,
    sampler=trainset.sampler,
    batch_size=256,
)

## Training loop

In [7]:
# Vocabulary size: one past the largest token id found in the training data.
n_tokens = trainset.data.max().item() + 1
# Context length: the length of one stored training sequence.
context_len = len(trainset[0])

# Small two-layer, two-head transformer with learned positional embeddings.
config = TransformerConfig(
    vocab_size=n_tokens,
    emb_dim=128,
    pos_emb=True,
    seq_len=context_len,
    emb_dropout=0.1,
    n_head=2,
    n_layer=2,
)


In [8]:
# Instantiate the model once just to print its architecture; the training cell
# builds a fresh instance after seeding torch.
model = Transformer(config)
print(model)

Transformer(
  (embeddings): Embedding(
    (token_emb): Embedding(5, 128)
    (pos_emb): Embedding(21, 128)
  )
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (norm_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): SelfAttention(
        (qkv_mat): Linear(in_features=128, out_features=384, bias=False)
        (output): Linear(in_features=128, out_features=128, bias=False)
      )
      (norm_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ffn): FeedForward(
        (fc1): Linear(in_features=128, out_features=512, bias=False)
        (fc2): Linear(in_features=512, out_features=128, bias=False)
      )
    )
  )
  (output_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (output): Linear(in_features=128, out_features=5, bias=False)
)


In [9]:
# Seed torch before building the model so weight initialisation and dropout
# are reproducible.
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Transformer(config)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# NOTE: `scheduler.step()` is called once per batch below, so `step_size`
# counts optimizer updates, not epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

nb_epochs = 100

# Per-epoch training statistics (plain floats, one entry per epoch).
losses = np.empty(nb_epochs)
accuracies = np.empty(nb_epochs)
test_losses = []

model.train()
for t in range(nb_epochs):
    running_loss = 0.0  # sum of the batch losses over the epoch
    nb_correct = 0      # sequences whose masked positions are all predicted correctly
    nb_seen = 0         # sequences actually drawn by the sampler this epoch
    for sequence in loader:
        sequence = sequence.to(device=device, dtype=torch.long)

        # Next-token prediction: position i of `inputs` predicts position i of `targets`.
        inputs = sequence[:, :-1]
        targets = sequence[:, 1:]

        # only train on the chain-of-thoughts process, EoI is represented by 1 in our case
        # The mask is True strictly after the first EoI target; positions whose
        # target is itself 1 are excluded.
        ind = targets == 1
        cot_mask = ind.cumsum(dim=1)
        cot_mask[ind] = 0
        cot_mask = cot_mask.to(dtype=torch.bool)

        logits = model(inputs)
        loss = F.cross_entropy(logits[cot_mask].view(-1, logits.size(-1)), targets[cot_mask].reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        with torch.no_grad():
            running_loss += loss.item()
            # A sequence counts as correct when every masked position matches;
            # positions outside the mask are ignored.
            is_correct = (logits.argmax(dim=-1) == targets) | ~cot_mask
            nb_correct += is_correct.all(dim=1).sum().item()
            nb_seen += sequence.size(0)

    # Normalise by the number of sequences evaluated, which is what the
    # numerator counts (guarded against an empty loader).
    accuracy = nb_correct / max(nb_seen, 1)

    accuracies[t] = accuracy
    # Record the same epoch loss that is logged, as a plain float, instead of
    # the graph-attached tensor of the last batch only.
    losses[t] = running_loss
    print(f'Loss: {running_loss:.4f}, Accuracy: {accuracy:.4f}')

  from .autonotebook import tqdm as notebook_tqdm


Loss: 2.4651, Accuracy: 0.0050
Loss: 1.2965, Accuracy: 0.0124
Loss: 0.9876, Accuracy: 0.0050
Loss: 0.9343, Accuracy: 0.0000
Loss: 0.8938, Accuracy: 0.0000
Loss: 0.8772, Accuracy: 0.0000
Loss: 0.8131, Accuracy: 0.0522
Loss: 0.8033, Accuracy: 0.0871
Loss: 0.7537, Accuracy: 0.0821
Loss: 0.7520, Accuracy: 0.0995
Loss: 0.7314, Accuracy: 0.1741
Loss: 0.7091, Accuracy: 0.1667
Loss: 0.6855, Accuracy: 0.1791
Loss: 0.6968, Accuracy: 0.1368
Loss: 0.6702, Accuracy: 0.1716
Loss: 0.6580, Accuracy: 0.1866
Loss: 0.6054, Accuracy: 0.1965
Loss: 0.6208, Accuracy: 0.1841
Loss: 0.5773, Accuracy: 0.1965
Loss: 0.5381, Accuracy: 0.2289
Loss: 0.5025, Accuracy: 0.2512
Loss: 0.4681, Accuracy: 0.2612
Loss: 0.4385, Accuracy: 0.2786
Loss: 0.4216, Accuracy: 0.2861
Loss: 0.4177, Accuracy: 0.3010
Loss: 0.4084, Accuracy: 0.2786
Loss: 0.3955, Accuracy: 0.3234
Loss: 0.3902, Accuracy: 0.3284
Loss: 0.3451, Accuracy: 0.3657
Loss: 0.3689, Accuracy: 0.3507
Loss: 0.3312, Accuracy: 0.3682
Loss: 0.3353, Accuracy: 0.3930
Loss: 0.

## Results Analysis

There are several quantities to monitor:
- Make sure that `-1` is an absorbing state.
- Check the validity of the full chain of thoughts.
- Check the validity of the final answer in the chain, or of intermediate answers in the chain.
- Cluster results by the length of the input sequence.