In [1]:
import torch
import numpy as np
from transformer_lens import HookedTransformerConfig, HookedTransformer

In [2]:
def get_seq_lengths(total_length, min_length, max_length, rng):
    too_many = rng.integers(min_length, max_length, min_length * total_length)
    sequence_lengths = too_many[np.cumsum(too_many) <= total_length]
    diff = total_length - np.sum(sequence_lengths)
    sequence_lengths = sequence_lengths.tolist()
    if diff >= min_length and diff <= max_length:
        sequence_lengths += [diff]
    return sequence_lengths
    

def generate_cum_parity(total_seq_len: int, rng):
    sep = np.array([2])
    prob = rng.beta(2, 2)
    bits = rng.choice(2, (total_seq_len - 1,), replace=True, p=[prob, 1. - prob])
    parities = np.concatenate([sep, np.cumsum(bits) % 2])
    bits = np.concatenate([sep, bits])
    return bits, parities


def generate_packed_parity(total_seq_length, min_seq_length=6, max_seq_length=30, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    sep = 2
    sequence_lengths = get_seq_lengths(total_seq_length, min_seq_length, max_seq_length, rng)
    assert sum(sequence_lengths) <= total_seq_length
    [bits, parities] = list(zip(*[
            generate_cum_parity(seq_len, stream) for seq_len, stream in zip(sequence_lengths, rng.spawn(len(sequence_lengths)))
    ]))
    bits = np.concatenate(bits)
    parities = np.concatenate(parities)
    diff = total_seq_length - len(parities)
    bits = np.pad(bits, (0, diff), mode='constant', constant_values=(sep, sep))
    parities = np.pad(parities, (0, diff), mode='constant', constant_values=(sep, sep))
    return bits, parities


def generate_fixed_parity(total_seq_length, seq_length=30, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    sep = 2
    sequence_lengths = [seq_length] * (total_seq_length // seq_length)
    assert sum(sequence_lengths) <= total_seq_length
    [bits, parities] = list(zip(*[
            generate_cum_parity(seq_len, stream) for seq_len, stream in zip(sequence_lengths, rng.spawn(len(sequence_lengths)))
    ]))
    bits = np.concatenate(bits)
    parities = np.concatenate(parities)
    diff = total_seq_length - len(parities)
    bits = np.pad(bits, (0, diff), mode='constant', constant_values=(sep, sep))
    parities = np.pad(parities, (0, diff), mode='constant', constant_values=(sep, sep))
    return bits, parities



In [3]:
cfg = {
        "d_model": 64,
        "d_head": 32,
        "n_heads": 2,
        "d_mlp": 256,
        "n_ctx": 512,
        "n_layers": 1,
        "d_vocab": 4,
        "act_fn": "relu"
}

ckpt = torch.load('checkpoints/199999.pth')

config = HookedTransformerConfig(**cfg)
model = HookedTransformer(config)
model.load_state_dict(ckpt['model'])

<All keys matched successfully>

In [4]:
model = model.to('cuda:0')

Moving model to device:  cuda:0


In [5]:
from torch.utils.data import IterableDataset

class CumulativeParityFixed(IterableDataset):
    def __init__(self, total_sequence_length: int, sequence_length: int, batch_size: int, rng_seed: int = 0):
        super().__init__()
        self.total_sequence_length = total_sequence_length
        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.rng = np.random.default_rng(rng_seed)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            all_rngs = [stream for stream in self.rng.spawn(worker_info.num_workers)]
            rng = all_rngs[worker_info.id]
        else:
            rng = self.rng
        while True:
            [bits, parities] = list(zip(*[
                generate_fixed_parity(
                    self.total_sequence_length,
                    self.sequence_length,
                    stream
                ) for stream in rng.spawn(self.batch_size)
            ]))
            yield torch.asarray(np.stack(bits, axis=0)), torch.asarray(np.stack(parities, axis=0))

In [51]:
data_generator51 = CumulativeParityFixed(512, 32, 512, 2)

In [52]:
iter_data = iter(data_generator51)

In [53]:
bits, parities = next(iter_data)

output = model(bits.to('cuda:0'))

In [46]:
from torch.nn.functional import log_softmax

def seq2seq_cross_entropy_loss(logits, tokens, ignore_token=2):
    log_probs = log_softmax(logits, dim=-1)
    # Use torch.gather to find the log probs of the correct tokens
    # Not using offsets because we're predicting the same token position, new _sequence
    # None and [..., 0] needed because the tensor used in gather must have the same rank.
    predicted_log_probs = log_probs[..., :, :].gather(
        dim=-1, index=tokens[..., :, None]
    )[..., 0]

    log_probs = torch.where(
        tokens != ignore_token,
        predicted_log_probs,
        torch.zeros_like(predicted_log_probs)
    )

    return -log_probs.mean()

In [54]:
seq2seq_cross_entropy_loss(output, parities.to('cuda'), ignore_token=2)

tensor(0.5235, device='cuda:0', grad_fn=<NegBackward0>)

In [49]:

def seq2seq_accuracy(logits, tokens):
    predicted_tok = logits.argmax(dim=-1)
    return (predicted_tok == tokens).to(torch.float32).mean()

In [55]:
seq2seq_accuracy(output, parities.to('cuda'))

tensor(0.6275, device='cuda:0')

In [22]:
from torch.utils.data import IterableDataset, DataLoader
from concurrent.futures import ProcessPoolExecutor, as_completed



class CumulativeParityDataset:

    def __init__(self, total_sequence_length: int, min_sequence_length: int, max_sequence_length: int, batch_size: int, rng_seed: int = 0):
        super().__init__()
        self.total_sequence_length = total_sequence_length
        self.min_sequence_length = min_sequence_length
        self.max_sequence_length = max_sequence_length
        self.batch_size = batch_size
        self.rng = np.random.default_rng(rng_seed)

    def next(self):
        #worker_info = torch.utils.data.get_worker_info()
        #if worker_info is not None:
        #    all_rngs = [stream for stream in self.rng.spawn(worker_info.num_workers)]
        #    rng = all_rngs[worker_info.id]
        #else:
        rng = self.rng
        parities = [
            generate_packed_parity(
                self.total_sequence_length,
                self.min_sequence_length,
                self.max_sequence_length,
                stream
            ) for stream in rng.spawn(self.batch_size)
        ]
   
        return torch.asarray(np.stack(parities, axis=0))


In [11]:

dataset = CumulativeParityDataset(512, 8, 30, 512, 10)
loader = DataLoader(dataset, num_workers=8, pin_memory=True)

In [13]:
next(iter(loader))

tensor([[3, 0, 0,  ..., 1, 1, 3],
        [3, 1, 1,  ..., 1, 0, 3],
        [3, 1, 0,  ..., 1, 1, 3],
        ...,
        [3, 1, 0,  ..., 1, 0, 1],
        [3, 0, 0,  ..., 3, 3, 3],
        [3, 1, 0,  ..., 1, 0, 3]])

In [30]:
from tqdm.notebook import trange


def train(model, optimizer, scheduler, num_steps, dataloader):

    #loader = iter(dataloader)

    with trange(num_steps) as t:
        for i in t:
            data = dataloader.next()
            optimizer.zero_grad()
            loss = model(data.to('cuda:0'), return_type='loss')
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            if i % 100 == 0:
                t.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr())



In [31]:
num_warmup = 100
num_steps = 10_000
seed = 0


cfg = {
    "d_model": 128,
    "d_head": 32,
    "n_heads": 2,
    "d_mlp": 512,
    "n_ctx": 512,
    "n_layers": 1,
    "d_vocab": 4,
    "act_fn": "relu"
}

config = HookedTransformerConfig(**cfg)
model = HookedTransformer(config)


optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.1)
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, end_factor=1.0, total_iters=num_warmup)
annealing = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(num_steps - num_warmup), eta_min=1.0e-6)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [warmup, annealing], milestones=[num_warmup])

dataset = CumulativeParityDataset(512, 6, 32, 512, seed)


In [None]:
train(model, optimizer, scheduler, num_steps, dataset)

  0%|          | 0/10000 [00:00<?, ?it/s]