In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import torch
torch.set_default_device('mps')

In [2]:
llama_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
llama_tok.pad_token_id = 891

In [3]:
llama_mod = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
llama_mod.to('mps')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [35]:
from sentence_transformers import SentenceTransformer
emb_mod = SentenceTransformer('BAAI/bge-large-en')


In [4]:

# load link_pairs.json
import json


In [6]:
with open('link_pairs.json', 'r') as f:
    link_pairs = json.load(f)


In [5]:
with open('link_pairs_emb.json', 'r') as f:
    link_pairs = json.load(f)

In [None]:

# for each pair, add a third item, the embedding vector
for pair in link_pairs:
    link_embed = emb_mod.encode(pair['link']).tolist()
    pair['lembed'] = link_embed
    summary_embeds = []
    for summary in pair['summaries']:
        summary_embeds.append(emb_mod.encode(summary).tolist())
    pair['sembeds'] = summary_embeds

# save link_pairs.json
with open('link_pairs_emb.json', 'w') as f:
    json.dump(link_pairs, f)




In [6]:
# import cosine similarity function
from sklearn.metrics.pairwise import cosine_similarity

# function, given a query and a context (link), return the top 7 most similar contexts plus the original context if it isnt in the top 7
def get_top7(query_embed, context):
    # get the cosine similarity between the query and all the other contexts
    cos_sims = []
    for pair in link_pairs:
        if pair['link'] != context:
            cos_sims.append(cosine_similarity([query_embed], [pair['lembed']]))
    # get the top 7 most similar contexts
    top7 = []
    for i in range(6):
        top7.append(link_pairs[cos_sims.index(max(cos_sims))]['link'])
        cos_sims[cos_sims.index(max(cos_sims))] = -1
    top7.append(context)
    return top7


In [7]:
train_data = []
for pair in link_pairs:
    for sem, summ in zip(pair['sembeds'], pair['summaries']):
        train_data.append([summ, get_top7(sem, pair['link'])])

print(len(train_data))
print(train_data[0])



628
['An experimental study questioning wave-particle duality in quantum mechanics.', ['obscillesk: https://www.tomshardware.com/news/quantum-computing-researchers-achieve-100-million-quantum-operations more quantum shenanigans\nLink: "Quantum Computing: Researchers Achieve 100 Million Quantum Operations" Description: "That\'s a lot of processing power being used in five-second workloads."\n\n\njcorvinus: I think I\'m gonna need a bigger R&D budget lol\n\n\nlauren0001: manhattan project for ASI is publicly underway though so idk @_@\n\n\nlauren0001: absolutely\n\n\njcorvinus: I wonder if they\'ll designate it as ITAR restricted\n\n\nlauren0001: idk about full singularity\n\n\nlauren0001: life extension should become a button to push\n\n\nlauren0001: if we can get to hard superintelligence in the next couple of years after human level, then assuming the org that does this is aligned with humanity personally, then global warming should be possible to solve via unreasonably advanced bioen

In [8]:
import random

In [None]:

# for each item in train data, shuffle the top7, and then replace each link string in the chat snippet with LINK-(1-8), and append the correct link ID to the end of the summary
for i in range(len(train_data)):
    # the last item in train_data[i][1] is the correct link context, so we need to track its position
    correct_link = train_data[i][1][-1]
    # shuffle the top7
    random.shuffle(train_data[i][1])
    correct_id = 0
    for j in range(len(train_data[i][1])):
        if train_data[i][1][j] == correct_link:
            correct_id = j+1
        # add [LINK-1] [LINK-2] ... [LINK-8] before the 'http' in the chat snippet
        train_data[i][1][j] = train_data[i][1][j].replace('http', ' [LINK-' + str(j+1) + '] http')

    # append the correct link ID to the end of the summary
    train_data[i][0] = train_data[i][0] + ' LINK-' + str(correct_id)


print(train_data[0])

In [10]:
llama_tok.padding_side = 'left'

In [11]:
tokenized_train = []
for i in range(len(train_data)):
    item = []
    item.append(llama_tok(train_data[i][0], padding='max_length', truncation=True, max_length=512, return_tensors='pt')['input_ids'])
    item.append(llama_tok(train_data[i][1], padding='max_length', truncation=True, max_length=512, return_tensors='pt')['input_ids'])
    tokenized_train.append(item)

print(tokenized_train[0][0])
print(tokenized_train[0][1])

tensor([[  891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,   891,   891,   891,   891,
           891,   891,   891,   891,   891,   891,  

In [12]:
class LlamaRetrofit(torch.nn.Module):
    def __init__(self, llama, rtr_num=7):
        super().__init__()

        self.rtr_num = rtr_num

        self.emb = llama.model.embed_tokens
        self.blocks = llama.model.layers
        self.norm = llama.model.norm
        self.head = llama.lm_head

        self.cross_attn = torch.nn.ModuleList([torch.nn.MultiheadAttention(4096, 32, batch_first=True) for _ in range(len(self.blocks)//4)])

    def forward(self, x):
        x = self.emb(x)
        for i, block in enumerate(self.blocks):
            x = block(x)[0]
            if i % 4 == 0 and i != 0:
                # do self attention as normal, but then reshape the batch into [-1, rtr_num+1, seq_len, hidden_size], and use the q, k, and v from the current layer to perform cross attention from item 0 to 1-rtr_num, and vis versa, to move information between the retrieved items and the query.
                tiled = x[0].unsqueeze(0).repeat(self.rtr_num, 1, 1)
                print(tiled.shape)
                new_query = self.cross_attn[i//4](tiled, x[1:], x[1:])[0]
                # average across batch
                new_query = new_query.mean(dim=0)
                x = torch.cat([new_query.unsqueeze(0), x[1:]], dim=0)
        x = self.norm(x)
        x = self.head(x)
        return x[0]

# create test model and pass a sample batch of 16
model = LlamaRetrofit(llama_mod)
# move to cuda
model = model.cuda()

optimizer = torch.optim.Adam(model.cross_attn.parameters(), lr=1e-4)
optimizer.zero_grad()

loss_fn = torch.nn.CrossEntropyLoss()
for i in range(len(tokenized_train)):
    label = tokenized_train[i][0][0][1:]
    # add a padding token to the beginning of the label
    label = torch.cat([torch.tensor([0]), label])
    inp = torch.cat([tokenized_train[i][0], tokenized_train[i][1]])
    output = model(inp)
    # argmax and decode
    stringy = torch.argmax(output, dim=-1)
    stringy = llama_tok.decode(stringy)
    print(stringy)
    loss = loss_fn(output, label)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss.item())



torch.Size([7, 512, 4096])
torch.Size([7, 512, 4096])
torch.Size([7, 512, 4096])
torch.Size([7, 512, 4096])
torch.Size([7, 512, 4096])
torch.Size([7, 512, 4096])
torch.Size([7, 512, 4096])
HColisolisHColisHCHCHCظظظظظظظظظظظظظظظظظظظظظظظظظظظظظظHCHCantinHCHCHCHCHCHCHCHCHColisHCHCHCHColisolisHCHCHCHCHCLSLSHCLSLSLSHCHCHCHCHCHCHCHCظظظظظظظظظظظظظظHCHCLSLSLSHCessaHCHCHColiseedHCHCeedHCHCHCHCHCHCHCHCHCHCHCHCHCLSLSLSHCLSLSLSLSLSLSLSLSLSLSLSLSLSLSLSLSLSظظظظظظظظظظظظظظHCHCHCHC Victoria Victoria Victoria Victoria VictoriaHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCHCظظظظظظظظظظظHCظظظظessaLSLSLSLSLSLSLSLSLSLSLStegerHCHCHCHCHCHCHCHCHCHCHCHCHCHCLSLSLSLSHCHCHCHCHCHCظHCHCHCHCHCHCHCershellershellHCLSHCHCHCHCHCLSLSLSHCHCLSLSLSLSLSLSLS Victoria VictoriaHCHC Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Victoria Vict

In [14]:
print(stringy)
print(llama_tok.decode(stringy))

tensor([ 3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130, 16020, 16020, 16020, 16020, 16020, 16020, 16020, 16020,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130,
         3130,  3130,  3130,  3130,  3130,  3130,  3130,  3130, 

In [6]:
# test_adapter = LlamaT5Adapter(t5_mod, llama_mod, ctx_num=16)

# optimizer = torch.optim.Adam(test_adapter.parameters(), lr=1e-4)
# sample_out = test_adapter(torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]))
# print(sample_out.shape)


# sample_out.sum().backward()
# optimizer.step()

import sqlite3

conn = sqlite3.connect('data.db')

# create a cursor object
cur = conn.cursor()

cur.execute('SELECT content, author FROM discord ORDER BY timestamp')
data = cur.fetchall()

# concat author name to all messages
data = [f'{author}: {content}' for content, author in data]

# combine all messages into one string
data = '\n\n'.join(data)

data = data[:100000]

t5_tokenized_data = t5_tok.encode(data, return_tensors='pt')
llama_tokenized_data = llama_tok.encode(data, return_tensors='pt')
print(t5_tokenized_data.shape)
print(llama_tokenized_data.shape)


Token indices sequence length is longer than the specified maximum sequence length for this model (28985 > 512). Running this sequence through the model will result in indexing errors


torch.Size([1, 28985])
torch.Size([1, 30485])


In [7]:

adapter = LlamaT5Adapter(t5_mod, llama_mod, ctx_num=8)


params_to_optimize = list(adapter.llama_queries.parameters()) + list(adapter.t5_keys.parameters()) + list(adapter.t5_values.parameters())
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-4)
# optimizer = torch.optim.Adam(adapter.parameters(), lr=1e-4)

total_llama_tokens = llama_tokenized_data.shape[1]

# for the training loop, for each chunk passed to the llama decoder, the previous 16 chunks of 512 t5-tokens are passed to the t5 encoder
epochs = 4
llama_batch_size = 1024
t5_batch_size = 512
for epoch in range(epochs):
    for i in range(t5_batch_size*20, total_llama_tokens, llama_batch_size):
        print(f'Epoch {epoch}, batch {i}')
        # get the next batch of data
        llama_in = llama_tokenized_data[:, i:i+llama_batch_size]
        # to get approximately the same spot in the chat for the t5 tokenization, use the current index/total_llama_tokens, and multiply by the total t5 tokens
        t5_index = int(i/total_llama_tokens*t5_tokenized_data.shape[1])
        t5_in = t5_tokenized_data[:, t5_index-t5_batch_size*8:t5_index]
        # reshape the t5 input to be 16 chunks of 512 tokens
        t5_in = t5_in.reshape(-1, t5_batch_size)
        # pass the data through the model
        out = adapter(llama_in, t5_in)
        # calculate loss by shifting the target data by 1
        loss = torch.nn.functional.cross_entropy(out[:, :-1].reshape(-1, out.shape[-1]), llama_in[:, 1:].reshape(-1))
        # backpropagate
        loss.backward()
        # clip gradients
        torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0)
        # update parameters
        optimizer.step()
        # zero gradients
        optimizer.zero_grad()
        print(f'Loss: {loss.item()}')
        del llama_in, t5_in, out, loss



Epoch 0, batch 10240
Loss: 11.575541496276855
Epoch 0, batch 11264
Loss: 9.961394309997559
Epoch 0, batch 12288
