In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules (such as
# the local `cot` package) are reloaded before each cell executes, so edits to
# the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from cot.data import Parity, BinaryCopy, Copy
from cot.config import RAW_DIR
from cot.models import Transformer, TransformerConfig

import logging
# Show INFO-level log records in the notebook (the data-generation cell relies
# on this to report progress from `cot.data.data_processing`).
logging.basicConfig(level=logging.INFO)

## DataLoader

In [3]:
# Dataset class used for the rest of the notebook. `Parity` and `Copy` are
# imported alongside it — presumably drop-in alternatives; TODO confirm they
# expose the same interface before swapping.
Problem = BinaryCopy

In [4]:
# Data-generation settings.

# Fixed seed so that everything driven by `rng` (it is handed to
# `Problem.generate_datafiles`) is reproducible under Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

seq_length = 20
max_nb_data_per_len = 1000
# NOTE(review): `random` is not read by any cell visible here — confirm it is
# still needed.
random = False

# One entry per sequence length 1..9. Presumably the fraction of sequences of
# that length placed on one side of the train/test split — TODO confirm in
# cot.data.
split_probas_by_len =  [1, 1, 1, 1, .9, .8, .9, .3, .2]

# Uniform weights over the lengths, normalised to sum to 1. Derived from
# `split_probas_by_len` so the two always have the same number of entries.
probas_by_len = np.ones(len(split_probas_by_len), dtype=float)
probas_by_len /= probas_by_len.sum()

# Sequence lengths covered by the dataset: 1, 2, ..., len(split_probas_by_len).
lengths = list(np.arange(len(split_probas_by_len)) + 1)

In [5]:
# Only for the problem whose prefix is 'copy': instantiate it once with
# vocab_size=20 and discard the instance. Presumably the constructor records
# the vocabulary size as class-level state — TODO confirm in cot.data.
if Problem.prefix == 'copy':
    Problem(vocab_size=20)

# Generate the data files for every length, using `rng` for the split
# (progress is reported through the INFO log lines shown below the cell).
Problem.generate_datafiles(max_nb_data_per_len, split_probas_by_len, rng)

# Training view over `lengths`; `probas_by_len` is presumably the per-length
# sampling weight used by `trainset.sampler` — TODO confirm.
trainset = Problem()
trainset.set_as_trainset(lengths, probas_by_len)

# Test view over the same lengths.
testset = Problem()
testset.set_as_testset(lengths)

INFO:cot.data.data_processing:Sequences of length 1 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (2/2 split).
INFO:cot.data.data_processing:Sequences of length 2 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (4/4 split).
INFO:cot.data.data_processing:Sequences of length 3 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (8/8 split).
INFO:cot.data.data_processing:Sequences of length 4 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (16/16 split).
INFO:cot.data.data_processing:Sequences of length 5 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (28/32 split).
INFO:cot.data.data_processing:Sequences of length 6 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (51/64 split).
INFO:cot.data.data_processing:Sequences of length 7 done. Saved in /home/vivc/Code/llm/Compositionality/data/raw/binary_copy (115/128 split).
INFO:cot.data.data_proce

In [6]:
# Mini-batches of 32 training sequences, drawn in the order produced by the
# sampler that the training set itself exposes.
loader = DataLoader(trainset, batch_size=32, sampler=trainset.sampler)

In [7]:
# Sanity check: shape of a single training example.
trainset[1].shape

torch.Size([20])

## Training loop

In [8]:
# Model hyper-parameters.
config = TransformerConfig(
    # The training cell shifts token ids by +1 (EoS is stored as -1), so ids
    # fed to the model are expected to lie in [0, 4) — TODO confirm for
    # problems other than BinaryCopy.
    vocab_size=4,
    emb_dim=128,
    pos_emb=True,
    # Reuse the constant from the data-configuration cell instead of repeating
    # the literal 20, so the two cannot drift apart.
    seq_len=seq_length,
    emb_dropout=0.1,
    n_head=2,
    n_layer=2,
)


In [9]:
# Build the model once just to print its architecture. The training cell
# below creates a fresh `model` after seeding torch, so this instance is not
# the one that gets trained.
model = Transformer(config)
print(model)

Transformer(
  (embeddings): Embedding(
    (token_emb): Embedding(4, 128)
    (pos_emb): Embedding(20, 128)
  )
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (norm_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): SelfAttention(
        (qkv_mat): Linear(in_features=128, out_features=384, bias=False)
        (output): Linear(in_features=128, out_features=128, bias=False)
      )
      (norm_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ffn): FeedForward(
        (fc1): Linear(in_features=128, out_features=512, bias=False)
        (fc2): Linear(in_features=512, out_features=128, bias=False)
      )
    )
  )
  (output_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (output): Linear(in_features=128, out_features=4, bias=False)
)


In [10]:
# Train the transformer with next-token prediction on batches from `loader`.

# Seed torch's global RNG before building the model.
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fresh model (replaces the instance printed in the previous cell).
model = Transformer(config)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate every 100 scheduler steps. `scheduler.step()` is
# called once per batch below, so that is every 100 batches, not epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

# Per-epoch training loss history. `test_losses` is created here but never
# filled in this cell.
losses = []
test_losses = []

n_epochs = 30
model.train()
for _ in range(n_epochs):
    # Sum (not mean) of the per-batch losses over the epoch, so its scale
    # depends on the number of batches.
    total_loss = 0
    for sequence in loader:
        # EoS is stored as -1 in the dataset; shift every id by +1 (in place,
        # on the batch tensor) so that ids are non-negative and can index the
        # embedding table.
        sequence += 1
        sequence = sequence.to(device=device, dtype=torch.long)

        # Next-token prediction: position t of `inputs` predicts position t+1
        # of the sequence.
        inputs = sequence[:, :-1]
        targets = sequence[:, 1:]

        logits = model(inputs)
        # Flatten the leading dimensions so cross-entropy sees one row of
        # class scores per target token. No ignore_index is passed, so every
        # position contributes to the loss.
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

    losses.append(total_loss)
    print(f'Loss: {total_loss}')

# After the loop, `logits` and `targets` still hold the last training batch
# (computed in train mode); the next cell reads them.

  from .autonotebook import tqdm as notebook_tqdm


Loss: 10.217584311962128
Loss: 7.802834510803223
Loss: 6.353094637393951
Loss: 5.898797899484634
Loss: 5.46126389503479
Loss: 5.370787471532822
Loss: 5.105739146471024
Loss: 4.880185544490814
Loss: 4.755508065223694
Loss: 4.687484264373779
Loss: 4.442636847496033
Loss: 4.120648562908173
Loss: 4.088984429836273
Loss: 3.8568408489227295
Loss: 3.941498562693596
Loss: 3.887954920530319
Loss: 3.7642501294612885
Loss: 3.720744490623474
Loss: 3.5997921526432037
Loss: 3.53359092772007
Loss: 3.4738723188638687
Loss: 3.665855199098587
Loss: 3.5462906509637833
Loss: 3.47406467795372
Loss: 3.575624018907547
Loss: 3.4897633641958237
Loss: 3.577207922935486
Loss: 3.447888672351837
Loss: 3.4478889852762222
Loss: 3.4787664711475372


In [11]:
# Inspect the last training batch left over from the training cell:
# (predicted id - target id) per position, so 0 marks a correct prediction.
# Positions whose target is token 3 are overwritten with the sentinel 6 to
# make them stand out (token 3 is presumably the marker that ends the input
# part of the sequence — TODO confirm). Note that `logits` was produced with
# the model in train mode.
tmp = (logits.argmax(-1) - targets).masked_fill(targets == 3, 6)
print(tmp)

tensor([[-1,  1,  0, -1,  0, -1,  1,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0],
        [ 0,  0,  0,  0, -1,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0],
        [ 0,  1, -1, -1, -1,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0],
        [ 1,  0, -1, -1,  1,  1,  1,  1,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0],
        [ 6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0],
        [-1,  0,  1,  0,  0, -1,  0,  2,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0],
        [ 0,  1,  0, -1,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0],
        [-1,  0,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0],
        [ 0,  1,  1,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0],
        [-1, -1,  0,  0, -1,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0],
        [ 0, -1, -1, -1, -1,  6,  0,  0,  0,  0,  