In [26]:
from ben_utils import SaveForward
from collections import defaultdict
import csv
import einops
import matplotlib.pyplot as plt
import sys
import torch as t
import torchtext
from tqdm import tqdm
import transformers

In [2]:
# CUDA device strings used by this notebook; GPU indices 2 and 3 are hard-coded.
devices = [f'cuda:{gpu_index}' for gpu_index in (2, 3)]

In [3]:
# Pretrained GPT-2 tokenizer, kept as a module-level name that later cells read directly.
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

In [4]:
# Pretrained GPT-2 LM-head model, moved to the first configured device and put in eval mode.
gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(devices[0]).eval()

In [5]:
class BeheadedGPT2(t.nn.Module):
    """Stack of newly constructed GPT-2 blocks with a LayerNorm and a linear vocabulary head.

    The module has no embedding layer: ``forward`` feeds its input tensor
    straight into the first block. Only ``gpt2.config`` is read from the model
    passed in; no weights are copied from it.

    Args:
        gpt2: object exposing ``config`` with ``n_embd``, ``vocab_size`` and
            (when ``n_layer`` is omitted) ``n_layer``.
        n_layer: number of blocks to build; defaults to ``gpt2.config.n_layer``.
    """

    def __init__(self, gpt2, n_layer=None):
        super().__init__()
        config = gpt2.config
        depth = config.n_layer if n_layer is None else n_layer

        layers = []
        for _ in range(depth):
            layers.append(transformers.models.gpt2.modeling_gpt2.GPT2Block(config))
        self.blocks = t.nn.ModuleList(layers)

        width = config.n_embd
        self.ln_f = t.nn.LayerNorm((width,), eps=1e-05, elementwise_affine=True)
        self.lm_head = t.nn.Linear(in_features=width, out_features=config.vocab_size)

    def forward(self, x):
        """Run ``x`` through every block, then the final LayerNorm and the linear head.

        Each block call is unpacked as a 1-tuple, so a block returning any
        other number of outputs raises here.
        """
        hidden = x
        for layer in self.blocks:
            (hidden,) = layer(hidden)
        return self.lm_head(self.ln_f(hidden))

In [6]:
def batches_from_data(data, max_seq_len):
    """Pack tokenized text items into fixed-length rows of token ids.

    Every item of ``data`` is encoded with the module-level ``tokenizer`` and
    written back-to-back into 1-D ``long`` tensors of length ``max_seq_len``.
    When an item ends with room left in the current row, a single
    ``tokenizer.eos_token_id`` is written after it as a separator. An item
    that does not fit in the remaining space continues in the next row.

    Args:
        data: iterable of text items accepted by ``tokenizer.encode``.
        max_seq_len: length of every returned row; must be positive.

    Returns:
        List of 1-D ``long`` tensors, each of length ``max_seq_len``. Only
        completely filled rows are returned; a trailing, partially filled row
        is discarded.
    """
    assert max_seq_len > 0
    batch_list = []
    next_batch = t.zeros(max_seq_len, dtype=t.long)
    remaining_space_in_next_batch = max_seq_len
    for item in data:
        tokens = tokenizer.encode(item, return_tensors="pt").squeeze(0)
        while len(tokens):
            if remaining_space_in_next_batch <= 0:
                # Current row is full: emit it and start a fresh one.
                batch_list.append(next_batch)
                next_batch = t.zeros(max_seq_len, dtype=t.long)
                remaining_space_in_next_batch = max_seq_len
            n = remaining_space_in_next_batch
            tokens_left = len(tokens)
            for_this_batch, tokens = tokens[:n], tokens[n:]
            assert for_this_batch.shape == (min(n, tokens_left),), (n, tokens_left, tokens, for_this_batch)
            write_start = max_seq_len - remaining_space_in_next_batch
            write_end = write_start + len(for_this_batch)
            next_batch[write_start:write_end] = for_this_batch
            remaining_space_in_next_batch -= len(for_this_batch)
            # Space can only remain when the item has been fully consumed, so
            # this marks the end of the item with an EOS separator.
            if remaining_space_in_next_batch > 0:
                next_batch[write_end] = tokenizer.eos_token_id
                remaining_space_in_next_batch -= 1
    # Rows are emitted lazily (only when a later token arrives), so a row that
    # the final tokens filled exactly would otherwise be lost.
    if remaining_space_in_next_batch <= 0:
        batch_list.append(next_batch)
    return batch_list

# Sanity check: pack three short strings into rows of 6 tokens and show each
# row next to its decoded text. Tokens left over in a row that never fills up
# are not returned.
[
    (batch, tokenizer.decode(batch))
    for batch in batches_from_data(
            ['hello there', 'how are you doing', 'here is some sample data'],
            max_seq_len=6,
        )
]

[(tensor([31373,   612, 50256,  4919,   389,   345]),
  'hello there<|endoftext|>how are you'),
 (tensor([ 1804, 50256,  1456,   318,   617,  6291]),
  ' doing<|endoftext|>here is some sample')]

In [7]:
def load_data(data, batch_size, max_seq_len):
    """Pack ``data`` with ``batches_from_data`` and wrap the rows in a shuffled DataLoader.

    Args:
        data: iterable of text items, forwarded to ``batches_from_data``.
        batch_size: number of packed rows per DataLoader batch.
        max_seq_len: length of each packed row.

    Returns:
        A ``torch.utils.data.DataLoader`` over the packed rows with ``shuffle=True``.
    """
    packed_rows = batches_from_data(data, max_seq_len=max_seq_len)
    return t.utils.data.DataLoader(packed_rows, batch_size=batch_size, shuffle=True)

def load_dataset(**kwargs):
    """Build DataLoaders for the WikiText2 train and test splits.

    The splits are requested from torchtext with ``root='.data'``. All keyword
    arguments are forwarded unchanged to ``load_data`` for both splits.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    train_split, test_split = torchtext.datasets.WikiText2(root='.data', split=('train', 'test'))
    train_loader = load_data(train_split, **kwargs)
    test_loader = load_data(test_split, **kwargs)
    return train_loader, test_loader

In [12]:
# Build the train/test loaders: one packed row per batch, each row
# gpt2.config.n_positions tokens long.
data_train, data_test = load_dataset(batch_size=1, max_seq_len=gpt2.config.n_positions)

In [27]:
def train(beheaded, dataloader):
    """Train ``beheaded`` to predict the input token ids from the input of GPT-2's first block.

    For every batch the module-level ``gpt2`` is run without gradients while
    ``SaveForward`` is attached to ``gpt2.transformer.h[0]``; the tensor read
    from ``saved.saved_input`` (presumably the input that block received --
    confirm against ``ben_utils.SaveForward``) is fed to ``beheaded``. The
    resulting logits are scored with cross-entropy against the original token
    ids and ``beheaded`` is updated with Adam (default hyperparameters).
    The per-step losses are shown as a scatter plot once the loader is exhausted.

    Args:
        beheaded: module mapping a ``(batch, seq, n_embd)`` tensor to
            ``(batch, seq, vocab_size)`` logits (both shapes are asserted).
        dataloader: iterable yielding 2-D ``(batch, seq)`` tensors of token ids.

    Returns:
        None.
    """
    layer_input_to_invert = gpt2.transformer.h[0]
    optimizer = t.optim.Adam(beheaded.parameters())
    logged_losses = []

    for i, input_ids in enumerate(dataloader):
        input_ids = input_ids.to(gpt2.device)
        batch_size, seq_len = input_ids.shape
        with SaveForward(layer_input_to_invert) as saved:
            # gpt2 is frozen here; it only supplies the activations to invert.
            with t.no_grad():
                gpt2(input_ids)
            (to_invert,) = saved.saved_input
            assert to_invert.shape == (batch_size, seq_len, gpt2.config.n_embd)
        output_logits = beheaded(to_invert)
        assert output_logits.shape == (batch_size, seq_len, tokenizer.vocab_size)
        # Flatten batch and sequence so cross_entropy sees one row per token.
        output_logits = einops.rearrange(output_logits, 'b i w -> (b i) w')
        assert output_logits.shape == (batch_size * seq_len, tokenizer.vocab_size)
        target = einops.rearrange(input_ids, 'b i -> (b i)')
        assert target.shape == (batch_size * seq_len,)
        loss = t.nn.functional.cross_entropy(output_logits, target)
        logged_losses.append(loss.item())
        # Clear gradients from the previous step; without this, backward()
        # keeps adding to .grad and every update uses the sum of all past gradients.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    plt.scatter(range(len(logged_losses)), logged_losses)
    plt.show()

In [22]:
import gc
# Force a garbage-collection pass; the value displayed below the cell is
# gc.collect()'s return value.
gc.collect()

458

In [None]:
# Number of batches the training loader yields per pass.
print(len(data_train))

In [29]:
# Build a 2-block BeheadedGPT2 on the same device as gpt2, switch it to
# training mode, and run one pass over the training loader.
beheaded = BeheadedGPT2(gpt2, n_layer=2).to(gpt2.device).train()
train(beheaded, dataloader=data_train)

KeyboardInterrupt: 