In [3]:
import torch
import numpy as np
from transformer_lens import HookedTransformerConfig, HookedTransformer

In [4]:

def get_seq_lengths(total_length, min_length, max_length, rng):
    """Split a budget of ``total_length`` tokens into random sub-sequence lengths.

    Candidate lengths are drawn with ``rng.integers(min_length, max_length)``, i.e. from
    ``[min_length, max_length)`` (``max_length`` itself is never drawn). They are accepted
    greedily while their running total stays within ``total_length``. If the remaining
    budget is itself in ``[min_length, max_length]`` it is appended as a final length;
    otherwise it is left over for the caller to pad.

    NOTE(review): the drawn range excludes ``max_length`` while the leftover check includes
    it; kept as-is so seeded datasets stay reproducible.

    Args:
        total_length: token budget to fill.
        min_length: smallest allowed sub-sequence length; must be >= 1.
        max_length: exclusive upper bound for drawn lengths.
        rng: a ``numpy.random.Generator``.

    Returns:
        list[int]: lengths whose sum is <= ``total_length``.

    Raises:
        ValueError: if ``min_length`` < 1.
    """
    if min_length < 1:
        raise ValueError(f"min_length must be >= 1, got {min_length}")
    # No more than total_length // min_length lengths can ever fit, so drawing more is
    # wasted work. Draws are sequential, so the accepted prefix is unchanged.
    max_count = max(total_length // min_length, 0)
    candidates = rng.integers(min_length, max_length, max_count)
    # All candidates are positive, so this mask selects a prefix of `candidates`.
    accepted = candidates[np.cumsum(candidates) <= total_length]
    lengths = accepted.tolist()
    remainder = int(total_length - accepted.sum())
    if min_length <= remainder <= max_length:
        lengths.append(remainder)
    return lengths
    

def generate_cum_parity(total_seq_len: int, rng):
    """Return one cumulative-parity example of exactly ``total_seq_len`` tokens.

    Token layout, with ``n = (total_seq_len - 2) // 2``::

        3, b_1 .. b_n, 2, p_1 .. p_n[, 3]

    ``b`` are random bits drawn from ``rng``, ``p_i = (b_1 + ... + b_i) % 2`` is the
    running parity, ``2`` is the "equals" marker and ``3`` is the separator. The trailing
    separator is present only when ``total_seq_len`` is odd.

    Raises:
        AssertionError: if ``total_seq_len`` is too short to hold a single bit (< 4).
    """
    n_bits = (total_seq_len - 2) // 2
    assert n_bits > 0, total_seq_len
    equals_token = np.array([2])
    sep_token = 3
    bits = rng.integers(0, 2, n_bits)
    running_parity = bits.cumsum() % 2
    body = np.concatenate((bits, equals_token, running_parity))
    # One separator in front; whatever is still missing goes at the end.
    n_trailing = total_seq_len - body.size - 1
    return np.pad(body, (1, n_trailing), mode='constant', constant_values=(sep_token, sep_token))


def generate_packed_parity(total_seq_length, min_seq_length=6, max_seq_length=30, rng=None):
    """Pack independent cumulative-parity examples into one row of ``total_seq_length`` tokens.

    Sub-sequence lengths come from ``get_seq_lengths``. Each sub-sequence is generated by
    ``generate_cum_parity`` from its own child stream of ``rng``. Any unused tail is filled
    with the separator token ``3``.

    Args:
        total_seq_length: length of the returned row.
        min_seq_length: passed to ``get_seq_lengths`` as ``min_length``.
        max_seq_length: passed to ``get_seq_lengths`` as ``max_length``.
        rng: ``numpy.random.Generator``; a fresh unseeded one is created when ``None``.

    Returns:
        1-D integer array of length ``total_seq_length``. If no sub-sequence fits
        (``total_seq_length < min_seq_length``), it contains only separators.
    """
    if rng is None:
        rng = np.random.default_rng()
    sep = 3
    sequence_lengths = get_seq_lengths(total_seq_length, min_seq_length, max_seq_length, rng)
    assert sum(sequence_lengths) <= total_seq_length
    # One independent child stream per sub-sequence.
    streams = rng.spawn(len(sequence_lengths))
    parities = [generate_cum_parity(seq_len, stream) for seq_len, stream in zip(sequence_lengths, streams)]
    # np.concatenate raises on an empty list, which happens when nothing fits the budget.
    packed = np.concatenate(parities) if parities else np.empty(0, dtype=np.int64)
    diff = total_seq_length - len(packed)
    return np.pad(packed, (0, diff), mode='constant', constant_values=(sep, sep))
    




In [5]:
# One-layer transformer over a 4-token vocabulary: bits 0/1, 2 = "equals", 3 = separator.
cfg = dict(
    d_model=128,
    d_head=32,
    n_heads=2,
    d_mlp=512,
    n_ctx=512,
    n_layers=1,
    d_vocab=4,
    act_fn="relu",
)

config = HookedTransformerConfig(**cfg)
model = HookedTransformer(config)

In [6]:
# NOTE(review): hard-coded GPU device; this cell requires a CUDA-capable machine.
model = model.to('cuda:0')

Moving model to device:  cuda:0


In [7]:
# Fixed-seed batch of 512 packed rows (512 tokens each) for the forward/backward cells below.
rng = np.random.default_rng(10)
parities = np.stack(
    [
        generate_packed_parity(512, min_seq_length=8, max_seq_length=32, rng=child)
        for child in rng.spawn(512)
    ]
)


In [18]:
# Single forward pass on the fixed batch; HookedTransformer returns the LM loss directly.
tokens = torch.asarray(parities, device='cuda:0')
loss = model(tokens, return_type='loss')

In [19]:
# Backpropagate the loss once; gradients accumulate into the model's parameters (no optimizer step here).
loss.backward()

In [29]:
from torch.utils.data import IterableDataset, DataLoader


from concurrent.futures import ProcessPoolExecutor, as_completed




class CumulativeParityDataset:
    """Endless, reproducible source of packed cumulative-parity batches.

    Each batch is a tensor of shape ``(batch_size, total_sequence_length)``. Every row is
    produced by ``generate_packed_parity`` from its own child stream of ``self.rng``.
    Batches are obtained with ``dataset.next()`` or ``next(dataset)``, and the object can
    be iterated (``for batch in dataset``).

    NOTE(review): this is a plain iterator, not a ``torch.utils.data.Dataset`` /
    ``IterableDataset``. Using it behind a multi-worker ``DataLoader`` would also need
    per-worker RNG streams; otherwise each worker would likely copy the same generator
    state and emit identical batches.
    """

    def __init__(self, total_sequence_length: int, min_sequence_length: int, max_sequence_length: int, batch_size: int, rng_seed: int = 0):
        """Store the packing parameters and seed the batch generator.

        Args:
            total_sequence_length: length of every row.
            min_sequence_length: forwarded to ``generate_packed_parity`` as ``min_seq_length``.
            max_sequence_length: forwarded to ``generate_packed_parity`` as ``max_seq_length``.
            batch_size: rows per batch.
            rng_seed: seed for ``numpy.random.default_rng``.
        """
        self.total_sequence_length = total_sequence_length
        self.min_sequence_length = min_sequence_length
        self.max_sequence_length = max_sequence_length
        self.batch_size = batch_size
        self.rng = np.random.default_rng(rng_seed)

    def __iter__(self):
        """The dataset is its own (infinite) iterator."""
        return self

    def __next__(self):
        """Iterator-protocol alias for :meth:`next`."""
        return self.next()

    def next(self):
        """Return the next batch as a tensor of shape ``(batch_size, total_sequence_length)``.

        Each call spawns ``batch_size`` new child streams from ``self.rng``, one per row.
        Consecutive batches therefore differ, but the whole sequence of batches is
        reproducible from ``rng_seed``.
        """
        rows = [
            generate_packed_parity(
                self.total_sequence_length,
                self.min_sequence_length,
                self.max_sequence_length,
                stream,
            )
            for stream in self.rng.spawn(self.batch_size)
        ]
        return torch.asarray(np.stack(rows, axis=0))


In [11]:

# NOTE(review): CumulativeParityDataset is not a torch Dataset/IterableDataset subclass and has
# no __getitem__, so DataLoader treats it as map-style and iterating `loader` likely fails on a
# fresh kernel. The printed batch in the next cell presumably came from an earlier class version.
dataset = CumulativeParityDataset(
    total_sequence_length=512,
    min_sequence_length=8,
    max_sequence_length=30,
    batch_size=512,
    rng_seed=10,
)
loader = DataLoader(dataset, num_workers=8, pin_memory=True)

In [13]:
# Peek at one batch from the loader.
# NOTE(review): `loader` wraps an object that is neither an IterableDataset nor indexable, so this
# likely raises on a fresh kernel; the output shown presumably came from an earlier version of the class.
next(iter(loader))

tensor([[3, 0, 0,  ..., 1, 1, 3],
        [3, 1, 1,  ..., 1, 0, 3],
        [3, 1, 0,  ..., 1, 1, 3],
        ...,
        [3, 1, 0,  ..., 1, 0, 1],
        [3, 0, 0,  ..., 3, 3, 3],
        [3, 1, 0,  ..., 1, 0, 3]])

In [30]:
from tqdm.notebook import trange


def train(model, optimizer, scheduler, num_steps, dataloader, device=None, log_every=100):
    """Run ``num_steps`` optimizer/scheduler steps on batches drawn from ``dataloader``.

    Args:
        model: called as ``model(tokens, return_type='loss')``; must return a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per optimizer step.
        num_steps: number of training steps.
        dataloader: batch source. Objects with a ``.next()`` method (e.g.
            ``CumulativeParityDataset``) are used through it. Anything else, such as a torch
            ``DataLoader``, is consumed as an iterator.
        device: where batches are moved; defaults to the device of the model's first parameter.
        log_every: refresh the progress-bar loss/lr every this many steps.
    """
    if device is None:
        # Follow the model instead of hard-coding a GPU, so CPU-only runs also work.
        device = next(model.parameters()).device
    if hasattr(dataloader, 'next'):
        next_batch = dataloader.next
    else:
        next_batch = iter(dataloader).__next__

    with trange(num_steps) as t:
        for i in t:
            data = next_batch()
            optimizer.zero_grad()
            loss = model(data.to(device), return_type='loss')
            loss.backward()
            optimizer.step()
            scheduler.step()

            # loss.item() forces a device sync, so only refresh the display periodically.
            if i % log_every == 0:
                t.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr())



In [31]:
num_warmup = 100
num_steps = 10_000
seed = 0

# Packed-parity data parameters (each row is padded to exactly n_ctx tokens).
min_seq_len = 6
max_seq_len = 32
batch_size = 512

cfg = {
    "d_model": 128,
    "d_head": 32,
    "n_heads": 2,
    "d_mlp": 512,
    "n_ctx": 512,
    "n_layers": 1,
    "d_vocab": 4,
    "act_fn": "relu"
}

# Seed torch's global RNG before building the model so parameter initialisation is
# reproducible too; otherwise `seed` would only control the data stream.
torch.manual_seed(seed)
config = HookedTransformerConfig(**cfg)
model = HookedTransformer(config)


optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.1)
# Linear warm-up for `num_warmup` steps, then cosine decay over the remaining steps.
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, end_factor=1.0, total_iters=num_warmup)
annealing = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(num_steps - num_warmup), eta_min=1.0e-6)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [warmup, annealing], milestones=[num_warmup])

# Tie the row length to the model's context window so the two cannot drift apart.
dataset = CumulativeParityDataset(cfg["n_ctx"], min_seq_len, max_seq_len, batch_size, seed)


In [None]:
# Full training run; the dataset object is passed as the batch source.
train(model, optimizer, scheduler, num_steps=num_steps, dataloader=dataset)

  0%|          | 0/10000 [00:00<?, ?it/s]