In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import copy
torch.set_default_device('mps')

In [3]:
tokenized_train = torch.load("tokenized_train.pt")

print(tokenized_train[0])

[{'input_ids': tensor([[    1,   530, 17986,  6559,  1139,   292, 10742, 29899,  1595,  2512,
           868,  2877,   297, 12101,  7208,  1199, 29889, 21724, 29968, 29899,
         29941,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])}, {'input_ids': tensor([[    1,  7684,  1655,  ...,   891,   891,   891],
        [    1,   425, 10732,  ...,   891,   891,   891],
        [    1,  7684,  1655,  ...,   891,   891,   891],
        ...,
        [    1, 14263, 29895,  ...,   891,   891,   891],
        [    1,   425, 10732,  ...,   891,   891,   891],
        [    1,   425, 10732,  ...,   891,   891,   891]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ...

In [4]:
llama_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
llama_tok.pad_token_id = 891

In [5]:
llama_mod = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
llama_mod.to('mps')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [6]:
class LlamaRetrofit(torch.nn.Module):
    def __init__(self, llama):
        super().__init__()

        self.emb = llama.model.embed_tokens
        self.blocks = llama.model.layers
        self.norm = llama.model.norm
        self.head = llama.lm_head

        # deep copy the first 3 blocks to repourpose for the encoder
        self.encoder = torch.nn.ModuleList([copy.deepcopy(self.blocks[i]) for i in range(2)])
        self.cross_attn = torch.nn.ModuleList([torch.nn.MultiheadAttention(4096, 32, batch_first=True) for _ in range(len(self.blocks)//8)])
    

    def forward(self, x, context):
        x_pad_mask = x['attention_mask']
        x = x['input_ids']
        context = context['input_ids']
        x = self.emb(x)
        context = self.emb(context)
        for i, block in enumerate(self.encoder):
            context = block(context)[0]
        # tile the mask to match the batch size (x.size[0]), and unsqueeze on dim 1
        mask = torch.full(
                (x.shape[0], 1, x.shape[1], x.shape[1]), float("-inf"), device=x.device
            )
        mask = torch.triu(mask, diagonal=1)
        # replace nan with 0
        mask = mask.masked_fill(mask != mask, 0.0)
        # combine the mask with the padding mask
        # attention mask has 0s where there is padding, so it needs to be replaced with -inf
        mask = mask.masked_fill(x_pad_mask.unsqueeze(1) == 0, float("-inf"))
        print(mask)
        for i, block in enumerate(self.blocks):
            x = block(x, mask)[0]
            if i % 8 == 0:
                # take the average of cross attention with each item in the context along the batch dim
                crossed = self.cross_attn[i//8](x, context[0].unsqueeze(0), context[0].unsqueeze(0))[0]
                for j in range(1, len(context)):
                    crossed += self.cross_attn[i//8](x, context[j].unsqueeze(0), context[j].unsqueeze(0))[0]
                # average the cross attention
                x = torch.mean(crossed, dim=0).unsqueeze(0)


        x = self.norm(x)
        x = self.head(x)
        return x

test_model = LlamaRetrofit(llama)
# save state dicts for encoder and cross_attn from test_model
torch.save(test_model.encoder.state_dict(), 'encoder.pt')
torch.save(test_model.cross_attn.state_dict(), 'cross_attn.pt')



cpu
tensor([[[[0., -inf, -inf,  ..., -inf, -inf, -inf],
          [0., 0., -inf,  ..., -inf, -inf, -inf],
          [0., 0., 0.,  ..., -inf, -inf, -inf],
          ...,
          [0., 0., 0.,  ..., -inf, -inf, -inf],
          [0., 0., 0.,  ..., -inf, -inf, -inf],
          [0., 0., 0.,  ..., -inf, -inf, -inf]]]], device='mps:0')
tensor([[[-0.5688, -0.9764,  2.7797,  ..., -5.1534, -3.9327,  0.4894],
         [-1.2674, -3.0520,  1.5501,  ..., -2.0579, -5.4599, -1.8719],
         [-2.3778, -3.8658, -0.2674,  ..., -2.4051, -4.9045,  0.6033],
         ...,
         [-1.8827, -3.6941,  0.5543,  ..., -1.9751, -4.6183,  0.0827],
         [-1.8824, -3.6932,  0.5546,  ..., -1.9755, -4.6181,  0.0824],
         [-1.8818, -3.6917,  0.5554,  ..., -1.9762, -4.6173,  0.0824]]],
       device='mps:0') torch.Size([1, 32, 32000])


In [None]:
model = LlamaRetrofit(llama_mod)

# optimize model.cross_attn.parameters() and model.encoder.parameters()
print(tokenized_train[0][0]['input_ids'].device)
with torch.no_grad():
    retrofit_out = model(tokenized_train[0][0].to('mps'), tokenized_train[0][1].to('mps'))
    print(retrofit_out, retrofit_out.shape)

In [11]:
# optimize both model.cross_attn.parameters() and model.encoder.parameters()
params_to_optimize = list(model.cross_attn.parameters()) + list(model.encoder.parameters())
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

optimizer.zero_grad()
for i in range(len(tokenized_train)):
    # label is shifted by 1 with an extra padding token, from tokenized_train[i][0]['input_ids']
    label = torch.roll(tokenized_train[i][0]['input_ids'], -1).to('mps')

    output = model(tokenized_train[i][0].to('mps'), tokenized_train[i][1].to('mps'))
    # argmax and decode
    stringy = torch.argmax(output, dim=-1)
    stringy = llama_tok.decode(stringy[0])
    print(stringy)
    print(output.shape, label.shape)
    # get the first unmasked token and clip the label and output, this can be done by getting the first '0' in the attention mask
    firstzero = torch.argmax(1 - tokenized_train[i][0]['attention_mask']).item()
    print(firstzero)
    label = label[:, :firstzero]
    output = output[:, :firstzero, :]
    # print the label as well
    print(llama_tok.decode(label[0]))
    loss = loss_fn(torch.permute(output, (0,2,1)), label)
    loss.backward()
    if i % 16 == 0:
        optimizer.step()
        optimizer.zero_grad()
    print(loss.item())

# save weights
torch.save(model.state_dict(), 'model.pt')

tensor([[[[0., -inf, -inf,  ..., -inf, -inf, -inf],
          [0., 0., -inf,  ..., -inf, -inf, -inf],
          [0., 0., 0.,  ..., -inf, -inf, -inf],
          ...,
          [0., 0., 0.,  ..., -inf, -inf, -inf],
          [0., 0., 0.,  ..., -inf, -inf, -inf],
          [0., 0., 0.,  ..., -inf, -inf, -inf]]]], device='mps:0')
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
torch.Size([1, 32, 32000]) torch.Size([1, 32])
21


In [None]:
# log current folder location
import os
os.getcwd()

In [None]:

optimizer = torch.optim.Adam(model.cross_attn.parameters(), lr=1e-4)
optimizer.zero_grad()

loss_fn = torch.nn.CrossEntropyLoss()
for i in range(len(tokenized_train)):
    label = tokenized_train[i][0][0][1:]
    # add a padding token to the beginning of the label
    label = torch.cat([torch.tensor([0]), label])
    inp = torch.cat([tokenized_train[i][0], tokenized_train[i][1]])
    output = model(inp)
    # argmax and decode
    stringy = torch.argmax(output, dim=-1)
    stringy = llama_tok.decode(stringy)
    print(stringy)
    loss = loss_fn(output, label)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss.item())



In [14]:
print(stringy)
print(llama_tok.decode(stringy))

tensor([ 3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130, 16020, 16020, 16020, 16020, 16020, 16020, 16020, 16020,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130, 

In [6]:
# test_adapter = LlamaT5Adapter(t5_mod, llama_mod, ctx_num=16)

# optimizer = torch.optim.Adam(test_adapter.parameters(), lr=1e-4)
# sample_out = test_adapter(torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]))
# print(sample_out.shape)


# sample_out.sum().backward()
# optimizer.step()

import sqlite3

conn = sqlite3.connect('data.db')

# create a cursor object
cur = conn.cursor()

cur.execute('SELECT content, author FROM discord ORDER BY timestamp')
data = cur.fetchall()

# concat author name to all messages
data = [f'{author}: {content}' for content, author in data]

# combine all messages into one string
data = '\n\n'.join(data)

data = data[:100000]

t5_tokenized_data = t5_tok.encode(data, return_tensors='pt')
llama_tokenized_data = llama_tok.encode(data, return_tensors='pt')
print(t5_tokenized_data.shape)
print(llama_tokenized_data.shape)


Token indices sequence length is longer than the specified maximum sequence length for this model (28985 > 512). Running this sequence through the model will result in indexing errors


torch.Size([1, 28985])
torch.Size([1, 30485])


In [7]:

adapter = LlamaT5Adapter(t5_mod, llama_mod, ctx_num=8)


params_to_optimize = list(adapter.llama_queries.parameters()) + list(adapter.t5_keys.parameters()) + list(adapter.t5_values.parameters())
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-4)
# optimizer = torch.optim.Adam(adapter.parameters(), lr=1e-4)

total_llama_tokens = llama_tokenized_data.shape[1]

# for the training loop, for each chunk passed to the llama decoder, the previous 16 chunks of 512 t5-tokens are passed to the t5 encoder
epochs = 4
llama_batch_size = 1024
t5_batch_size = 512
for epoch in range(epochs):
    for i in range(t5_batch_size*20, total_llama_tokens, llama_batch_size):
        print(f'Epoch {epoch}, batch {i}')
        # get the next batch of data
        llama_in = llama_tokenized_data[:, i:i+llama_batch_size]
        # to get approximately the same spot in the chat for the t5 tokenization, use the current index/total_llama_tokens, and multiply by the total t5 tokens
        t5_index = int(i/total_llama_tokens*t5_tokenized_data.shape[1])
        t5_in = t5_tokenized_data[:, t5_index-t5_batch_size*8:t5_index]
        # reshape the t5 input to be 16 chunks of 512 tokens
        t5_in = t5_in.reshape(-1, t5_batch_size)
        # pass the data through the model
        out = adapter(llama_in, t5_in)
        # calculate loss by shifting the target data by 1
        loss = torch.nn.functional.cross_entropy(out[:, :-1].reshape(-1, out.shape[-1]), llama_in[:, 1:].reshape(-1))
        # backpropagate
        loss.backward()
        # clip gradients
        torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0)
        # update parameters
        optimizer.step()
        # zero gradients
        optimizer.zero_grad()
        print(f'Loss: {loss.item()}')
        del llama_in, t5_in, out, loss



Epoch 0, batch 10240
Loss: 11.575541496276855
Epoch 0, batch 11264
Loss: 9.961394309997559
Epoch 0, batch 12288
