In [4]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformerConfig, HookedTransformer

In [5]:

def get_seq_lengths(total_length, min_length, max_length, rng):
    """Draw random segment lengths whose sum does not exceed ``total_length``.

    Candidate lengths come from ``rng.integers(min_length, max_length, ...)``,
    so each sampled length lies in ``[min_length, max_length)`` (the upper
    bound is exclusive).  The longest prefix of candidates whose running sum
    still fits in ``total_length`` is kept.  If the space left over is itself
    an acceptable length (``min_length <= leftover <= max_length``) it is
    appended as one final segment; otherwise it stays unassigned and the
    returned lengths sum to less than ``total_length``.

    Args:
        total_length: Budget that the returned lengths must fit into.
        min_length: Smallest length that may be sampled (inclusive).
        max_length: Upper bound for sampled lengths (exclusive).
        rng: Object exposing ``integers(low, high, size)``, e.g. a
            ``numpy.random.Generator``.

    Returns:
        list[int]: Segment lengths as plain Python ints, in sampling order.
    """
    # Deliberate oversampling: min_length * total_length draws of at least
    # min_length each are enough to fill the budget when min_length >= 1.
    # The draw count is kept as-is so existing seeded streams are unchanged.
    candidates = rng.integers(min_length, max_length, min_length * total_length)

    # For positive lengths the running sum is increasing, so this boolean
    # mask selects a prefix of the candidates.
    fits = np.cumsum(candidates) <= total_length
    kept = candidates[fits]
    leftover = total_length - np.sum(kept)

    sequence_lengths = kept.tolist()
    if min_length <= leftover <= max_length:
        # Cast so the result is homogeneous: ``leftover`` is a NumPy scalar,
        # while ``tolist()`` above produced built-in ints.
        sequence_lengths.append(int(leftover))
    return sequence_lengths
    

def generate_cum_parity(total_seq_len: int, rng):
    """Build one cumulative-parity example of exactly ``total_seq_len`` tokens.

    Token layout (``k = (total_seq_len - 2) // 2``)::

        [3] b_1 ... b_k [2] p_1 ... p_k [3 if total_seq_len is odd]

    where each ``b_i`` is a random bit, ``p_i`` is the parity of
    ``b_1 + ... + b_i``, ``2`` is the "equals" marker and ``3`` is the
    separator.  An even ``total_seq_len`` leaves no room for a trailing
    separator; an odd one gets exactly one.

    Args:
        total_seq_len: Length of the returned array; must be at least 4 so
            that at least one bit fits.
        rng: Object exposing ``integers(low, high, size)``, e.g. a
            ``numpy.random.Generator``.

    Returns:
        1-D NumPy array of length ``total_seq_len``.

    Raises:
        AssertionError: If ``total_seq_len`` leaves no room for any bit.
    """
    separator_token, equals_token = 3, 2

    # One slot goes to the leading separator and one to the '=' marker; the
    # remainder is split evenly between input bits and their running parity.
    n_bits = (total_seq_len - 2) // 2
    assert n_bits > 0, total_seq_len

    bits = rng.integers(0, 2, n_bits)
    parity = np.cumsum(bits) % 2
    body = np.concatenate([bits, np.array([equals_token]), parity], axis=0)

    # Start from an all-separator buffer and drop the body in after the
    # leading separator; any slot past the body keeps the separator value.
    out = np.full(total_seq_len, separator_token, dtype=body.dtype)
    out[1:1 + len(body)] = body
    return out


def generate_packed_parity(total_seq_length, min_seq_length=6, max_seq_length=30, rng=None):
    """Pack several cumulative-parity examples into one row of fixed length.

    Segment lengths are drawn with ``get_seq_lengths``; each segment is then
    produced by ``generate_cum_parity`` using its own child generator spawned
    from ``rng``.  The segments are concatenated and right-padded with the
    separator token ``3`` up to ``total_seq_length``.

    Args:
        total_seq_length: Length of the returned array.
        min_seq_length: Smallest segment length (inclusive).  Values below 4
            can make ``generate_cum_parity`` fail its assertion, because it
            needs room for at least one bit.
        max_seq_length: Upper bound passed on to ``get_seq_lengths``.
        rng: ``numpy.random.Generator`` to draw from; a fresh, unseeded one is
            created when omitted.

    Returns:
        1-D NumPy array of length ``total_seq_length``.  If no segment fits
        (for example ``total_seq_length < min_seq_length``) the row consists
        of separator tokens only.
    """
    if rng is None:
        rng = np.random.default_rng()
    sep = 3

    # Lengths are drawn from ``rng`` before any child stream is spawned; this
    # order is kept so seeded runs reproduce the same data.
    sequence_lengths = get_seq_lengths(total_seq_length, min_seq_length, max_seq_length, rng)
    assert sum(sequence_lengths) <= total_seq_length

    if not sequence_lengths:
        # np.concatenate raises ValueError on an empty list, so handle the
        # "nothing fits" case explicitly with an all-separator row.
        return np.full(total_seq_length, sep, dtype=np.int64)

    streams = rng.spawn(len(sequence_lengths))
    segments = [
        generate_cum_parity(seq_len, stream)
        for seq_len, stream in zip(sequence_lengths, streams)
    ]
    packed = np.concatenate(segments)

    # Fill whatever budget the segments did not use with separators.
    unused = total_seq_length - len(packed)
    return np.pad(packed, (0, unused), mode='constant', constant_values=(sep, sep))
    




In [7]:
# Hyperparameters for a small single-layer transformer.
# - d_vocab=4 matches the token values emitted by the generators in this
#   notebook: bits 0/1, the '=' marker 2 and the separator 3.
# - n_ctx=512 equals the packed row length used when building `parities`.
cfg = dict(
    d_model=128,
    d_head=32,
    n_heads=2,
    d_mlp=512,
    n_ctx=512,
    n_layers=1,
    d_vocab=4,
    act_fn="relu",
)

# Wrap the plain dict in the library's config object, then build the model.
config = HookedTransformerConfig(**cfg)
model = HookedTransformer(config)

In [8]:
# Move the model to the first CUDA device; this cell needs a GPU to run.
# The later forward-pass cell hard-codes the same 'cuda:0' device string.
model = model.to('cuda:0')

Moving model to device:  cuda:0


In [17]:
# Build a (512, 512) batch of packed parity rows.  The parent generator is
# seeded for reproducibility and every row gets its own spawned child stream.
rng = np.random.default_rng(10)
parities = np.stack(
    [
        generate_packed_parity(512, min_seq_length=8, max_seq_length=32, rng=child)
        for child in rng.spawn(512)
    ],
    axis=0,
)


In [18]:
# Forward pass over the whole batch: the NumPy token array is copied onto
# cuda:0 (the device the model was moved to) and fed to the model.
# NOTE(review): return_type='loss' presumably yields the model's built-in
# language-modelling loss rather than logits -- confirm in TransformerLens docs.
loss = model(torch.asarray(parities, device='cuda:0'), return_type='loss')

In [19]:
# Backpropagate the loss to populate parameter gradients.  No optimizer is
# created or stepped in the cells shown here, so this only checks that the
# backward pass runs.
loss.backward()

In [55]:

# NOTE(review): CumulativeParityDataset is not defined in any cell visible in
# this notebook (execution counts jump from 19 to 55), so this cell raises
# NameError on a fresh kernel unless the class is defined or imported first.
# The meaning of its five positional arguments cannot be confirmed from here.
dataset = CumulativeParityDataset(512, 8, 30, 512, 10)
# Batches of 512 items, loaded by 8 worker processes into pinned host memory.
# The recorded output below shows each batch with shape [512, 512].
loader = DataLoader(dataset, batch_size=512, num_workers=8, pin_memory=True)

In [56]:
# Smoke-test the loader: print the shape of the first six batches (i = 0..5).
# The loop stops when the seventh batch (i == 6) arrives, so that batch is
# fetched but not printed.
for i, batch in enumerate(loader):
    if i > 5:
        break
    print(batch.shape)

torch.Size([512, 512])
torch.Size([512, 512])
torch.Size([512, 512])
torch.Size([512, 512])
torch.Size([512, 512])
torch.Size([512, 512])
