In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import torch
# Prefer Apple's Metal (MPS) backend, but fall back to the CPU when it is not
# available so the notebook still runs on other machines.
torch.set_default_device('mps' if torch.backends.mps.is_available() else 'cpu')

In [2]:
# Tokenizer of the PygmalionAI/pygmalion-350m checkpoint; used further down to
# turn the training text into id tensors and to decode model predictions.
pig_tok = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-350m")

In [3]:
# Pretrained causal LM whose layers are reused by the retrieval wrapper below.
pig_mod = AutoModelForCausalLM.from_pretrained("PygmalionAI/pygmalion-350m")
# Move the weights to whatever default device was configured in the first cell
# instead of repeating a device literal here. The call is left as the cell's
# last expression so the architecture summary is still displayed.
pig_mod.to(torch.empty(0).device)

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 512, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)
      (project_out): Linear(in_features=1024, out_features=512, bias=False)
      (project_in): Linear(in_features=512, out_features=1024, bias=False)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=409

In [4]:
from sentence_transformers import SentenceTransformer
# Sentence-embedding model (BAAI/bge-large-en checkpoint) used below to embed
# every link and every summary.
emb_mod = SentenceTransformer('BAAI/bge-large-en')


In [5]:

# load link_pairs.json
import json


In [6]:
# Read the raw link/summary pairs. JSON text is UTF-8, so the encoding is given
# explicitly rather than relying on the platform's default locale encoding.
with open('link_pairs.json', 'r', encoding='utf-8') as f:
    link_pairs = json.load(f)


In [7]:
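# Optional shortcut: uncomment to load the pairs together with their cached
# embeddings from link_pairs_emb.json (written by the embedding cell below)
# instead of recomputing them.
# with open('link_pairs_emb.json', 'r') as f:
#     link_pairs = json.load(f)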

In [10]:

# Enrich every pair in place with its embeddings:
#   'lembed'  - embedding of the 'link' text
#   'sembeds' - one embedding per entry of 'summaries', in the same order
# Vectors are converted to plain lists so they can be serialised as JSON.
for entry in link_pairs:
    entry['lembed'] = emb_mod.encode(entry['link']).tolist()
    entry['sembeds'] = [emb_mod.encode(text).tolist() for text in entry['summaries']]

# Persist the enriched pairs to link_pairs_emb.json.
with open('link_pairs_emb.json', 'w') as out_file:
    json.dump(link_pairs, out_file)




In [None]:
# import cosine similarity function
from sklearn.metrics.pairwise import cosine_similarity

# Given a query embedding and the context (link) it belongs to, return the 6
# most similar *other* contexts followed by the original context (7 in total).
def get_top7(query_embed, context):
    """Return the nearest other contexts for a query, plus its own context.

    Reads the module-level ``link_pairs`` list.

    Args:
        query_embed: embedding vector (sequence of floats) that is compared
            against the 'lembed' vector of every candidate pair.
        context: the 'link' string the query belongs to. Pairs whose 'link'
            equals it are excluded from the candidates; it is appended as the
            last element of the result.

    Returns:
        A list of 'link' strings: up to 6 candidates ordered by decreasing
        cosine similarity, followed by ``context``. The list is shorter than 7
        when fewer than 6 other contexts exist.
    """
    # Keep the candidates in their own list so that a position in the score
    # array always refers to the same pair. (Indexing link_pairs with a
    # position taken from a list that skipped the context pair returned the
    # wrong link for every pair stored after the context.)
    candidates = [pair for pair in link_pairs if pair['link'] != context]
    if not candidates:
        return [context]

    # One vectorised call: result has shape (1, len(candidates)).
    sims = cosine_similarity([query_embed], [pair['lembed'] for pair in candidates])[0]

    # Stable sort: ties keep their original link_pairs order.
    ranked = sorted(range(len(candidates)), key=lambda idx: sims[idx], reverse=True)
    top7 = [candidates[idx]['link'] for idx in ranked[:6]]
    top7.append(context)
    return top7


In [None]:
# One training example per summary: [summary text, contexts from get_top7],
# where get_top7 receives the summary's embedding and the pair's own link.
train_data = [
    [summ, get_top7(sem, pair['link'])]
    for pair in link_pairs
    for sem, summ in zip(pair['sembeds'], pair['summaries'])
]

print(len(train_data))
print(train_data[0])



In [None]:
import random

In [None]:

# For every training example: shuffle its contexts, prefix each 'http' in the
# context at (1-based) position k with ' [LINK-k] ', and append ' LINK-<k of the
# correct context>' to the summary text. The example is modified in place.
for sample in train_data:
    contexts = sample[1]
    # Before shuffling, the correct context is the last element.
    correct_link = contexts[-1]
    random.shuffle(contexts)

    correct_id = 0
    for position, text in enumerate(contexts, start=1):
        # Compare against the untagged text to find where the correct context landed.
        if text == correct_link:
            correct_id = position
        contexts[position - 1] = text.replace('http', ' [LINK-' + str(position) + '] http')

    # The summary ends with the id of the correct context.
    sample[0] = sample[0] + ' LINK-' + str(correct_id)


print(train_data[0])

In [None]:
# Pad on the left so the real tokens sit at the end of each fixed-length
# sequence; the training loop below looks for the first non-pad token.
pig_tok.padding_side = 'left'

In [None]:
# Encode each example as [summary ids, context ids]; both are padded or
# truncated to 512 tokens and returned as PyTorch tensors.
encode_opts = dict(padding='max_length', truncation=True, max_length=512, return_tensors='pt')
tokenized_train = [
    [
        pig_tok(summary_text, **encode_opts)['input_ids'],
        pig_tok(context_texts, **encode_opts)['input_ids'],
    ]
    for summary_text, context_texts in train_data
]

print(tokenized_train[0][0])
print(tokenized_train[0][1])

# shuffle the examples in place
random.shuffle(tokenized_train)

tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,  

In [None]:
class LlamaRetrofit(torch.nn.Module):
    """Retrieval wrapper around a pretrained OPT-style causal LM.

    The wrapper reuses the wrapped model's token embedding, positional
    embedding, in/out projections, decoder layers and LM head. It adds new
    multi-head attention modules through which the query sequence (batch row 0)
    reads from the retrieved sequences (rows 1..rtr_num) after decoder layers
    4, 8, 12, ... (every layer index ``i`` with ``i % 4 == 0`` and ``i != 0``).

    Args:
        llama: model exposing ``model.decoder.{embed_tokens, embed_positions,
            project_in, project_out, layers}`` and ``lm_head``.
        rtr_num: number of retrieved sequences that accompany the query.
    """

    def __init__(self, llama, rtr_num=7):
        super().__init__()

        self.rtr_num = rtr_num

        decoder = llama.model.decoder
        self.emb = decoder.embed_tokens
        self.blocks = decoder.layers
        # No final decoder norm is applied; the original line is kept for reference.
        # self.norm = llama.model.decoder.norm
        self.head = llama.lm_head
        self.embed_positions = decoder.embed_positions
        self.project_out = decoder.project_out
        self.project_in = decoder.project_in

        # forward() indexes this list with i // 4, so it needs one slot up to
        # (len - 1) // 4. For depths that are a multiple of 4 this is the same
        # count as len // 4; for other depths it avoids an IndexError.
        # Slot 0 is never used because layer 0 is skipped.
        n_cross = (len(self.blocks) - 1) // 4 + 1
        self.cross_attn = torch.nn.ModuleList(
            [torch.nn.MultiheadAttention(1024, 32, batch_first=True) for _ in range(n_cross)]
        )

    @staticmethod
    def _decoder_mask(keep, dtype):
        """Build the additive (batch, 1, seq, seq) self-attention mask.

        A query position may attend to a key position only if the key is not
        in the future and is a real (non-pad) token. Hidden entries hold the
        most negative finite value of ``dtype`` rather than ``-inf`` so that a
        fully hidden row (a padded query position) softmaxes to uniform
        weights instead of NaN.
        """
        seq_len = keep.shape[1]
        idx = torch.arange(seq_len, device=keep.device)
        causal = idx[None, :] <= idx[:, None]                      # (query, key)
        visible = causal[None, None, :, :] & keep[:, None, None, :]
        mask = torch.zeros(visible.shape, dtype=dtype, device=keep.device)
        return mask.masked_fill(~visible, torch.finfo(dtype).min)

    def forward(self, x):
        """Return next-token logits for the query sequence.

        Args:
            x: integer token ids of shape (1 + rtr_num, seq_len). Row 0 is the
                query; the remaining rows are the retrieved sequences.

        Returns:
            Tensor of shape (seq_len, vocab_size): the logits of row 0.

        Raises:
            ValueError: if ``x`` is not 2-D or does not hold 1 + rtr_num rows.
        """
        if x.dim() != 2 or x.shape[0] != self.rtr_num + 1:
            raise ValueError(
                'expected token ids of shape (%d, seq_len), got %s'
                % (self.rtr_num + 1, tuple(x.shape))
            )

        # True on real tokens, False on padding.
        pad_id = self.emb.padding_idx
        if pad_id is None:
            keep = torch.ones_like(x, dtype=torch.bool)
        else:
            keep = x != pad_id

        # Everything stays batch-first, (batch, seq, hidden): the reused decoder
        # layers and the positional embedding both read dim 1 as the sequence
        # axis, so transposing to (seq, batch, hidden) made them attend across
        # batch rows and assign positions per row instead of per token.
        h = self.project_in(self.emb(x))
        h = h + self.embed_positions(keep.long())
        attn_mask = self._decoder_mask(keep, h.dtype)

        for i, block in enumerate(self.blocks):
            h = block(h, attention_mask=attn_mask)[0]
            if i % 4 == 0 and i != 0:
                retrieved = h[1:]
                # One copy of the query per retrieved sequence; copy k attends
                # over the tokens of retrieved sequence k.
                query = h[0].unsqueeze(0).repeat(self.rtr_num, 1, 1)
                attended = self.cross_attn[i // 4](query, retrieved, retrieved)[0]
                # Average over the retrieved sequences; the result replaces the
                # query row, the retrieved rows pass through unchanged.
                h = torch.cat([attended.mean(dim=0, keepdim=True), retrieved], dim=0)

        return self.head(self.project_out(h))[0]

# Wrap the pretrained model. Only the wrapper's cross-attention modules are
# handed to the optimizer.
model = LlamaRetrofit(pig_mod)

trainable_params = model.cross_attn.parameters()
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)
optimizer.zero_grad()



In [None]:
loss_fn = torch.nn.CrossEntropyLoss()
ACCUM_STEPS = 16  # examples whose gradients are accumulated per optimizer step
# Pad id of the tokenizer (1 is kept as the fallback used previously).
pad_id = 1 if pig_tok.pad_token_id is None else pig_tok.pad_token_id
num_examples = len(tokenized_train)

optimizer.zero_grad()
for i, (query_ids, context_ids) in enumerate(tokenized_train):
    # Row 0 is the summary (query); the remaining rows are its contexts.
    inp = torch.cat([query_ids, context_ids])
    logits = model(inp)  # (seq_len, vocab) for the query row

    # Sequences are left-padded: find where the real tokens start.
    tokens = query_ids[0]
    first_real = int(torch.argmax((tokens != pad_id).long()))

    # Next-token objective: the logits at position t are scored against the
    # token at position t + 1. (Scoring them against the token at position t
    # only asks the model to echo its own input.) Every summary ends with an
    # appended ' LINK-<id>', so there is always at least one target token.
    pred = logits[first_real:-1]
    target = tokens[first_real + 1:]
    loss = loss_fn(pred, target)
    # Scale so the accumulated gradient is an average over the group.
    (loss / ACCUM_STEPS).backward()

    # Step after every full group of ACCUM_STEPS examples, and once more at the
    # end so a trailing partial group is not discarded.
    if (i + 1) % ACCUM_STEPS == 0 or i == num_examples - 1:
        optimizer.step()
        optimizer.zero_grad()
        # Progress report for the most recent example: greedy prediction,
        # target text and its (unscaled) loss.
        print(pig_tok.decode(torch.argmax(pred, dim=-1)))
        print(pig_tok.decode(target))
        print(loss.item())