In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import torch
# Prefer Apple's Metal (MPS) backend, but fall back to CPU so the notebook
# still runs on machines where MPS is not available.
torch.set_default_device('mps' if torch.backends.mps.is_available() else 'cpu')

In [2]:

# Tokenizer and causal-LM weights are loaded from the same checkpoint id.
llama_checkpoint = "meta-llama/Llama-2-7b-hf"
llama_tok = AutoTokenizer.from_pretrained(llama_checkpoint)
llama_mod = AutoModelForCausalLM.from_pretrained(llama_checkpoint)




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [15]:
import random
import string

In [19]:
# Stand-in prompt batch: 16 rows of 512 random token ids, each id drawn from [0, 1000).
sample_in = torch.randint(low=0, high=1000, size=(16, 512))


In [21]:

# Smoke test: extend every random prompt by two tokens and decode the result.
# `max_new_tokens` expresses the budget relative to the prompt, so this keeps
# working if the length of `sample_in` changes (the previous `max_length=514`
# silently assumed 512-token prompts).
with torch.no_grad():
    output = llama_mod.generate(sample_in, max_new_tokens=2)
print(llama_tok.batch_decode(output, skip_special_tokens=True))

['oneightраng R em strGject и�� _se commagetr o� go +ian down hadvank8 "on�),readound��� arep6ier{\\iesitt return             just7ontlectлиKiosttext B $ er\x16 yourleiew Sres�ru_ Bop hasph\x14каptabов thisialestionht *V par filougfter� Zps�ang\x05 ag� di            �is numésok�oryallageainomeel� ad� Vber set ccri herдаF have� im� value //irре�priz Alavht� des�be R\x16 / duIive useover In&iveothUam Inie�� pre some�� am�� de emxtode su�ith...actumreeIn any _ре other shick�овoteto поu (ptionag overваGans {icaber�� haveert dbe /ic inst�andund returnij your Comment toseieldml� Aniot\x17 newxtstrirstat< lo likeange Leand does whereickтаromra�то� wh� м et Youilrom наable oneodeield byethodsk les value���� ha [ vent \\amplebetern \'ipда helassie diec", pr usance[ [ide queound where "�ableize�� Seultof� QuestionM He�� they8qhttps likephexranong \\elodIize what\x1b!yngete -ameicO Seiseere Tнеpe vain rolowinal soos\x19 all suти H зio w Tnt�ль� only#enste��� Zror -ely sh�ка поен\x0e вase\x7fazuво

In [3]:
# Tokenizer and seq2seq weights are loaded from the same checkpoint id.
t5_checkpoint = "google/flan-t5-large"
t5_tok = AutoTokenizer.from_pretrained(t5_checkpoint)
t5_mod = AutoModelForSeq2SeqLM.from_pretrained(t5_checkpoint)


In [4]:
# Shape probes for the two backbones. These are inference-only checks, so they
# run under no_grad to avoid building autograd graphs over the full models.
# The tensors are created on the default device, so no explicit `.to(...)` is
# needed.
sample_ints = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 1]])
with torch.no_grad():
    # One entry per encoder hidden state (the recorded run shows 25 of width 1024).
    sample_out = t5_mod.encoder(sample_ints, output_hidden_states=True).hidden_states
print(len(sample_out))
print(sample_out[0].shape)

with torch.no_grad():
    # A single Llama decoder layer fed a (2, 4, 4096) input; element 0 of its
    # return value is the hidden state.
    print(llama_mod.model.layers[3](torch.randn(2, 4, 4096))[0].shape)

25
torch.Size([2, 4, 1024])
torch.Size([2, 4, 4096])


In [5]:
class LlamaT5Adapter(torch.nn.Module):
    """Llama decoder stack interleaved with attention over T5 encoder states.

    The Llama decoder layers are split into ``num_cross_blocks`` equally sized
    groups. After each group, the running decoder hidden state is projected to
    a query and attends (``scaled_dot_product_attention``, no head split) over
    one T5 encoder hidden state, namely entry ``block * enc_layer_stride`` of
    the encoder's ``hidden_states`` tuple, projected to keys and values of the
    decoder width. Every decoder sequence is paired with ``ctx_num``
    consecutive encoder sequences and the per-context attention outputs are
    averaged.

    NOTE(review): the attention output *replaces* the decoder hidden state;
    there is no residual connection around the cross-attention. Confirm this
    is intended.
    NOTE(review): the decoder layers are called with the hidden state only (no
    attention mask or position ids). Whether causal masking is applied in that
    case depends on the installed transformers version; confirm.

    Args:
        t5: model exposing ``encoder`` and ``config.d_model``.
        llama: model exposing ``model.embed_tokens``, ``model.layers``,
            ``model.norm``, ``lm_head`` and ``config.hidden_size``.
        ctx_num: number of encoder sequences paired with each decoder sequence.
        num_cross_blocks: number of cross-attention insertion points.
        enc_layer_stride: step between the encoder hidden states that
            successive cross-attention blocks read.

    Raises:
        ValueError: if ``ctx_num`` or ``num_cross_blocks`` is not positive, or
            the decoder layer count is not divisible by ``num_cross_blocks``.
    """

    def __init__(self, t5, llama, ctx_num=8, num_cross_blocks=8, enc_layer_stride=3):
        super().__init__()
        self.encoder = t5.encoder
        self.decoder_embed_tokens = llama.model.embed_tokens
        self.decoder_layers = llama.model.layers
        self.decoder_norm = llama.model.norm
        self.lm_head = llama.lm_head
        self.context_num = ctx_num
        self.num_cross_blocks = num_cross_blocks
        self.enc_layer_stride = enc_layer_stride

        if ctx_num < 1:
            raise ValueError(f'ctx_num must be positive, got {ctx_num}')
        num_layers = len(self.decoder_layers)
        if num_cross_blocks < 1 or num_layers % num_cross_blocks != 0:
            raise ValueError(
                f'{num_layers} decoder layers cannot be split into '
                f'{num_cross_blocks} equal blocks'
            )
        self.layers_per_block = num_layers // num_cross_blocks

        # Widths come from the model configs instead of being hard-coded
        # (4096 / 1024 for the checkpoints loaded in this notebook).
        dec_dim = llama.config.hidden_size
        enc_dim = t5.config.d_model

        # Queries stay at decoder width; encoder states are projected up to it.
        self.llama_queries = torch.nn.ModuleList(
            [torch.nn.Linear(dec_dim, dec_dim) for _ in range(num_cross_blocks)])
        self.t5_keys = torch.nn.ModuleList(
            [torch.nn.Linear(enc_dim, dec_dim) for _ in range(num_cross_blocks)])
        self.t5_values = torch.nn.ModuleList(
            [torch.nn.Linear(enc_dim, dec_dim) for _ in range(num_cross_blocks)])

    def forward(self, decoder_in, encoder_in):
        """Return LM-head logits for ``decoder_in`` conditioned on ``encoder_in``.

        Args:
            decoder_in: token ids for the decoder; dim 0 is the batch.
            encoder_in: token ids for the encoder; dim 0 must equal
                ``decoder_in.shape[0] * ctx_num``, with the contexts of each
                decoder row stored consecutively.

        Returns:
            The output of ``lm_head`` applied to the normalised decoder state.

        Raises:
            ValueError: on a batch-size mismatch, or if the encoder returns too
                few hidden states for the configured blocks and stride.
        """
        expected_enc_batch = decoder_in.shape[0] * self.context_num
        if encoder_in.shape[0] != expected_enc_batch:
            raise ValueError(
                f'encoder batch must be decoder batch * ctx_num '
                f'({expected_enc_batch}), got {encoder_in.shape[0]}'
            )

        enc_hidden_states = self.encoder(encoder_in, output_hidden_states=True).hidden_states
        last_enc_index = (self.num_cross_blocks - 1) * self.enc_layer_stride
        if last_enc_index >= len(enc_hidden_states):
            raise ValueError(
                f'encoder returned {len(enc_hidden_states)} hidden states, '
                f'but index {last_enc_index} is required'
            )

        decoder_out = self.decoder_embed_tokens(decoder_in)

        for block in range(self.num_cross_blocks):
            # Run this block's group of decoder layers; element 0 of a layer's
            # return value is the hidden state.
            first_layer = block * self.layers_per_block
            for layer_idx in range(first_layer, first_layer + self.layers_per_block):
                decoder_out = self.decoder_layers[layer_idx](decoder_out)[0]

            enc_state = enc_hidden_states[block * self.enc_layer_stride]

            # Project once, then repeat each decoder row ctx_num times
            # (interleaved) so row b lines up with encoder rows
            # b*ctx_num .. b*ctx_num + ctx_num - 1. Projecting before the repeat
            # gives the same values as projecting the repeated tensor without
            # redoing the matmul ctx_num times.
            queries = self.llama_queries[block](decoder_out).repeat_interleave(self.context_num, dim=0)
            attended = torch.nn.functional.scaled_dot_product_attention(
                query=queries,
                key=self.t5_keys[block](enc_state),
                value=self.t5_values[block](enc_state),
            )
            # Collapse the ctx_num per-context results of each decoder row by
            # averaging, restoring the decoder batch size.
            decoder_out = attended.reshape(
                -1, self.context_num, attended.shape[1], attended.shape[2]
            ).mean(dim=1)

        decoder_out = self.decoder_norm(decoder_out)
        decoder_out = self.lm_head(decoder_out)
        return decoder_out

In [6]:
# test_adapter = LlamaT5Adapter(t5_mod, llama_mod, ctx_num=16)

# optimizer = torch.optim.Adam(test_adapter.parameters(), lr=1e-4)
# sample_out = test_adapter(torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]))
# print(sample_out.shape)


# sample_out.sum().backward()
# optimizer.step()

import sqlite3

# Number of leading characters of the chat log kept for this experiment.
MAX_CHARS = 100000

# Read every message in timestamp order. The connection is closed even if the
# query fails, so the database file is not left open for the rest of the session.
conn = sqlite3.connect('data.db')
try:
    cur = conn.cursor()
    cur.execute('SELECT content, author FROM discord ORDER BY timestamp')
    rows = cur.fetchall()
finally:
    conn.close()

# Prefix each message with its author, then join everything into one string
# with a blank line between messages.
messages = [f'{author}: {content}' for content, author in rows]
data = '\n\n'.join(messages)[:MAX_CHARS]

# Tokenise the same text with both tokenizers. The resulting (1, n_tokens)
# tensors have different lengths because the vocabularies differ; the T5
# tokenizer warns that the sequence exceeds its 512-token maximum.
t5_tokenized_data = t5_tok.encode(data, return_tensors='pt')
llama_tokenized_data = llama_tok.encode(data, return_tensors='pt')
print(t5_tokenized_data.shape)
print(llama_tokenized_data.shape)


Token indices sequence length is longer than the specified maximum sequence length for this model (28985 > 512). Running this sequence through the model will result in indexing errors


torch.Size([1, 28985])
torch.Size([1, 30485])


In [7]:

# Number of T5 context chunks fed to the encoder for every decoder batch.
# Shared by the adapter constructor and the slicing below so they cannot drift.
CTX_NUM = 8

adapter = LlamaT5Adapter(t5_mod, llama_mod, ctx_num=CTX_NUM)

# Only the cross-attention projections are trained.
params_to_optimize = list(adapter.llama_queries.parameters()) + list(adapter.t5_keys.parameters()) + list(adapter.t5_values.parameters())

# Freeze everything else. Otherwise backward() computes and accumulates
# gradients for every backbone weight, and optimizer.zero_grad() never clears
# them because they are not in the optimizer's parameter groups.
for param in adapter.parameters():
    param.requires_grad_(False)
for param in params_to_optimize:
    param.requires_grad_(True)

optimizer = torch.optim.Adam(params_to_optimize, lr=1e-4)
# To fine-tune all weights instead, skip the freezing above and use:
# optimizer = torch.optim.Adam(adapter.parameters(), lr=1e-4)

total_llama_tokens = llama_tokenized_data.shape[1]
total_t5_tokens = t5_tokenized_data.shape[1]

# Training loop: each chunk of Llama tokens given to the decoder is paired with
# the CTX_NUM chunks of t5_batch_size T5 tokens that precede the (approximately)
# same position in the text.
epochs = 4
llama_batch_size = 1024
t5_batch_size = 512
t5_context_len = t5_batch_size * CTX_NUM
for epoch in range(epochs):
    # NOTE(review): the start offset t5_batch_size*20 is measured in Llama
    # tokens; the factor 20 is unexplained here. It only needs to be large
    # enough that t5_context_len T5 tokens precede the mapped position.
    for i in range(t5_batch_size*20, total_llama_tokens, llama_batch_size):
        print(f'Epoch {epoch}, batch {i}')
        # get the next batch of decoder tokens
        llama_in = llama_tokenized_data[:, i:i+llama_batch_size]
        if llama_in.shape[1] < 2:
            # A next-token loss needs at least one (input, target) pair.
            continue
        # Map the Llama-token position to roughly the same place in the T5
        # token stream by scaling with the ratio of the two stream lengths.
        t5_index = int(i/total_llama_tokens*total_t5_tokens)
        if t5_index < t5_context_len:
            # Not enough preceding context; a negative slice start would
            # silently wrap around to the end of the stream.
            continue
        t5_in = t5_tokenized_data[:, t5_index-t5_context_len:t5_index]
        # reshape the context into CTX_NUM rows of t5_batch_size tokens
        t5_in = t5_in.reshape(-1, t5_batch_size)
        # pass the data through the model
        out = adapter(llama_in, t5_in)
        # next-token loss: the prediction at position t is scored against token t+1
        logits = out[:, :-1].reshape(-1, out.shape[-1])
        targets = llama_in[:, 1:].reshape(-1)
        loss = torch.nn.functional.cross_entropy(logits, targets)
        # backpropagate
        loss.backward()
        # clip gradients
        torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0)
        # update parameters
        optimizer.step()
        # zero gradients
        optimizer.zero_grad()
        print(f'Loss: {loss.item()}')
        # drop references so the tensors can be freed before the next batch
        del llama_in, t5_in, out, logits, targets, loss



Epoch 0, batch 10240
Loss: 11.575541496276855
Epoch 0, batch 11264
Loss: 9.961394309997559
Epoch 0, batch 12288
