In [1]:
! pip install transformers > /dev/null
! pip install tqdm > /dev/null

In [2]:
#from google.colab import drive
#drive.mount('/content/drive')

In [3]:
import sys
# Make the local PalmTree sources importable; presumably this directory
# provides the `palmtree` package imported below (`dataset`, `model`).
# NOTE(review): absolute, machine-specific path — adjust for your environment.
sys.path.append("/home/jovyan/work/olivetree/palmtree/code/")

In [4]:
# General imports
import os
import json
import random
import itertools
from tqdm import tqdm

# pytorch imports
import torch
import torchmetrics
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, MetricCollection, Precision, Recall

# Transformer imports
from transformers import BertTokenizerFast
from transformers import BertForPreTraining, BertForMaskedLM

# Palmtree imports
from palmtree import dataset
from palmtree import model

`fused_weight_gradient_mlp_cuda` module not found. gradient accumulation fusion with weight gradient computation disabled.


In [5]:
# Expose a single GPU to this process: physical device 7, with device ids
# ordered by PCI bus id. These variables only take effect if they are set
# before CUDA is first initialised in the process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

In [21]:
# in common
# Root of the experiment tree.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
base_path = "/home/jovyan/work/olivetree/final_for_paper"

# Benchmark table (queries + ground truth) and the search-results directory.
base_data_path = os.path.join(base_path, "tests_cagate", "similarity", "strands", "data")
test_name = "test_strands_equivalence_others_DEFINITIVE_nDCG.csv"
test_path = os.path.join(base_data_path, "definitive", test_name)
base_res_path = os.path.join(base_path, "tests_cagate", "similarity", "strands", "search_results")

# for olivetree
# Pretrained and fine-tuned checkpoints, plus one tokenizer directory per
# tokenization variant (default, whitespace, whitespace2, unigram).
base_olivetree = os.path.join(base_path, "models")
base_olivetree_finetuned = os.path.join(base_path, "tests_cagate", "similarity", "strands", "fine_tuned_models", "olivetree")
tokenizer_path = os.path.join(base_path, "tokenizer")
w_tokenizer_path = os.path.join(base_path, "whitespace_tokenizer")
w2_tokenizer_path = os.path.join(base_path, "whitespace_tokenizer2")
u_tokenizer_path = os.path.join(base_path, "unigram_tokenizer")

# for palmtree
# Pretrained / fine-tuned checkpoints and the PalmTree vocabulary file.
base_palmtree = os.path.join(base_path, "..", "palmtree", "models")
base_palmtree_finetuned = os.path.join(base_path, "tests_cagate", "similarity", "strands", "fine_tuned_models", "palmtree")
vocab_path = os.path.join(base_path, "..", "palmtree", "data", "palmtree_complete_dataset", "vocab")

# Presumably the number of hidden-state layers of each encoder — TODO confirm.
olivetree_n_layers = 13
palmtree_n_layers = 12

# Hidden layer used as the embedding of the non-fine-tuned models
# (fine-tuned models use the last layer, -1, in the evaluation loops).
ol_focus_layer = 11 # 0-based
pt_focus_layer = 10 # 0-based

In [7]:
import pandas as pd
import numpy as np

df = pd.read_csv(test_path, sep='\t')
# Turn NaN cells into None so "missing" can be tested with `is None` below.
df = df.where(pd.notnull(df), None)

# Rows that carry a ground truth are the benchmark queries.
queries = df[~df['ground_truth'].isnull()]
gts = [json.loads(raw_gt) for raw_gt in df["ground_truth"].to_list() if raw_gt is not None]
lens = [len(gt) for gt in gts]
avg = sum(lens) / len(gts)

# Columns, number of strands, number of queries, ground-truth size stats.
print(df.keys())
print(len(df))
print(len(queries))
print(avg, max(lens), min(lens))

Index(['strand_id', 'ot_strand', 'pt_strand', 'ground_truth'], dtype='object')
42547
4942
8.611290975313638 30 5


In [8]:
# Quick sanity check of the loaded benchmark table.
print(df.shape)
df.head()

(42547, 4)


Unnamed: 0,strand_id,ot_strand,pt_strand,ground_truth
0,14940295,"mov rax, qword ptr [rbp - 1120] NEXT_I mov qwo...",mov rax [ rbp - 0x460 ] NEXT_I mov [ rbp - 0x4...,
1,27682307,"mov rax, qword ptr [rbp - 1120] NEXT_I mov rdi...",mov rax [ rbp - 0x460 ] NEXT_I mov rdi rax NEX...,"[""14940295"", ""27682307"", ""25128455"", ""10593229..."
2,25128455,"mov rax, qword ptr [rbp - 1120] NEXT_I mov rdi...",mov rax [ rbp - 0x460 ] NEXT_I mov rdi rax NEX...,
3,10593229,"mov rax, qword ptr [rbp - 1120] NEXT_I mov rdi...",mov rax [ rbp - 0x460 ] NEXT_I mov rdi rax NEX...,
4,23016934,"mov rax, qword ptr [rbp - 1120] NEXT_I mov rdi...",mov rax [ rbp - 0x460 ] NEXT_I mov rdi rax NEX...,


In [9]:
class HParams:
    """Fixed hyper-parameters of the Siamese PalmTree + BiLSTM head.

    All values are class attributes, so they are readable both as
    ``HParams.<name>`` and from an instance (``HParams().<name>``).
    """

    # BiLSTM stacked on top of the instruction embeddings
    lstm_hidden_size = 128
    lstm_num_layers = 2
    lstm_dropout = 0

    # optimisation / regularisation
    learning_rate = 0.0001
    dropout = 0.2

class SiameseFinenuting(pl.LightningModule):
    """Siamese model: pretrained PalmTree encoder + BiLSTM, compared by cosine.

    Each sample (strand) is a sequence of instructions. Every instruction is
    encoded by the PalmTree model and mean-pooled over its real tokens. The
    resulting sequence of instruction embeddings goes through a bidirectional
    LSTM whose outputs are averaged over the strand's valid steps. Two samples
    are compared with cosine similarity, regressed with MSE onto the labels
    (labels presumably in {-1, 1}; -1 is mapped to 0 for AUROC).

    Args:
        hparams: object exposing ``lstm_hidden_size``, ``lstm_num_layers``,
            ``lstm_dropout`` and ``learning_rate`` (see ``HParams``).
        vocab: PalmTree vocabulary (accepted for API compatibility, unused).
        pt_model_path: path of the pickled PalmTree model (``torch.load``).
        batch_size: batch size reported to ``self.log``.

    Raises:
        ValueError: if ``pt_model_path`` is not given.
    """

    def __init__(self, hparams, vocab, pt_model_path=None, batch_size=64):
        super().__init__()

        if pt_model_path is None:
            raise ValueError("pt_model_path is required: path of the pretrained PalmTree model")

        self.batch_size = batch_size
        # `self.hparams` is a read-only LightningModule property, so the
        # learning rate gets its own attribute (a float, not in the state_dict).
        self.learning_rate = hparams.learning_rate

        # NOTE(review): unpickling runs arbitrary code; only load trusted files.
        self.pt_embedding = torch.load(pt_model_path)

        self.lstm = torch.nn.LSTM(input_size=self.pt_embedding.hidden,
                                  hidden_size=hparams.lstm_hidden_size,
                                  num_layers=hparams.lstm_num_layers,
                                  dropout=hparams.lstm_dropout,
                                  batch_first=True,
                                  bidirectional=True)

        self.cosine = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

        # Criterion: regress the cosine similarity onto the labels.
        self.loss = torch.nn.MSELoss()

        # Separate metric objects so train / val / test states never mix.
        self.train_auc = torchmetrics.AUROC()
        self.val_auc = torchmetrics.AUROC()
        self.test_auc = torchmetrics.AUROC()

    def samples_embedding(self, batch):
        """Embed a batch of strands into one vector per strand.

        Args:
            batch: dict as built by ``prepare_for_palmtree_lstm``:
                "token_ids" (one row per padded instruction), "masks" (same
                shape, 1 on real tokens), "num_strands_ins" (instruction rows
                per strand, an int) and "seq_lens" (real instruction count of
                each strand).

        Returns:
            Tensor ``(num_strands, 2 * lstm_hidden_size)``: the BiLSTM outputs
            averaged over each strand's valid instructions.
        """
        token_ids = batch["token_ids"]
        num_strands_ins = batch["num_strands_ins"]
        masks = batch["masks"]
        seq_lens = batch["seq_lens"]

        # NOTE(review): the second argument's meaning is defined by the
        # PalmTree `encode_last` API; kept as in the original code.
        outputs = self.pt_embedding.encode_last(token_ids, torch.tensor(1, device=token_ids.device))

        # Instruction embeddings: mean over the real (mask == 1) tokens.
        denom = torch.sum(masks, -1, keepdim=True)
        denom[denom == 0] = 1  # fully padded instruction rows: avoid 0 / 0
        instructions_emb = torch.sum(outputs * masks.unsqueeze(-1), dim=1) / denom
        # Regroup flat instruction rows into (num_strands, num_strands_ins, hidden).
        strands_outputs = torch.stack(torch.split(instructions_emb, num_strands_ins))

        # Run the BiLSTM on each strand's real instructions only.
        packed = torch.nn.utils.rnn.pack_padded_sequence(strands_outputs,
                                                         seq_lens,
                                                         batch_first=True,
                                                         enforce_sorted=False)
        packed_outputs, _ = self.lstm(packed)
        unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(packed_outputs,
                                                                        batch_first=True)

        # Padded steps come back zero-filled, so sum / length is the mean over
        # the valid steps only.
        return unpacked.sum(1) / unpacked_len.unsqueeze(-1).to(unpacked.device)

    def forward(self, asm_input):
        """Cosine similarity of the "first" and "second" samples of each pair.

        Returns:
            dict with 'prediction' (one cosine per pair) and the input 'labels'.
        """
        first_embeddings = self.samples_embedding(asm_input["first"])
        second_embeddings = self.samples_embedding(asm_input["second"])
        return {
            'prediction': self.cosine(first_embeddings, second_embeddings),
            'labels': asm_input["labels"],
        }

    def _shared_step(self, batch, metric):
        """Forward ``batch``; return (MSE loss, per-batch AUROC from ``metric``)."""
        forward_output = self.forward(batch)
        prediction = forward_output["prediction"]
        labels = forward_output["labels"]

        loss = self.loss(prediction, labels.float())

        # AUROC needs {0, 1} targets; the negative class is labelled -1.
        binary_labels = labels.clone()
        binary_labels[binary_labels == -1] = 0
        auc = metric(prediction, binary_labels)
        # Only the per-batch value is logged, so drop the accumulated state;
        # otherwise the metric keeps every prediction of the run in memory.
        metric.reset()
        return loss, auc

    def training_step(self, batch, batch_idx):
        loss, auc = self._shared_step(batch, self.train_auc)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        self.log('train_auc', auc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        return {"loss": loss,
                "train_auc": auc}

    def validation_step(self, batch, batch_idx):
        loss, auc = self._shared_step(batch, self.val_auc)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        self.log('val_auc', auc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        return {"loss": loss,
                "val_auc": auc}

    def test_step(self, batch, batch_idx):
        loss, auc = self._shared_step(batch, self.test_auc)
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        self.log('test_auc', auc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        return {"loss": loss,
                "test_auc": auc}

    def configure_optimizers(self):
        # Previously referenced an undefined global LEARNING_RATE.
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [10]:
def parse_ss_search_csv(df, start, end, is_olivetree=True):
    """Parse rows ``start:end`` (positional) of a strand-search table.

    Args:
        df: DataFrame with columns "strand_id", "ot_strand", "pt_strand" and
            "ground_truth". Instructions inside a strand are separated by
            " NEXT_I ".
        start, end: positional slice bounds, as in ``df.iloc[start:end]``.
        is_olivetree: if True, use "ot_strand" with its instructions joined
            into one space-separated string. Otherwise split "pt_strand" into
            one string per instruction.

    Returns:
        List of ``(strand_id, instructions, ground_truth)`` tuples.
        ``instructions`` is a list of strings. ``ground_truth`` is a list of
        int strand ids, or None when the row has no ground truth.
    """
    rows = []
    for _, row in df.iloc[start:end].iterrows():
        if is_olivetree:
            info = [" ".join(row["ot_strand"].split(" NEXT_I "))]
        else:
            info = row["pt_strand"].split(" NEXT_I ")

        # Look the column up by name. Positional `row[3]` depends on column
        # order and is deprecated in pandas. Non-string values (None / NaN)
        # mean "no ground truth".
        raw_gt = row["ground_truth"]
        if isinstance(raw_gt, str):
            g_t = [int(x) for x in json.loads(raw_gt)]
        else:
            g_t = None

        rows.append((int(row["strand_id"]), info, g_t))

    return rows

In [11]:
def load_olivetree_model(tokenizer_path, best_checkpoint, mlm=False):
    """Load an OliveTree (BERT) tokenizer and checkpoint on GPU, in eval mode.

    Args:
        tokenizer_path: directory of the fast BERT tokenizer.
        best_checkpoint: directory of the model checkpoint.
        mlm: load a ``BertForMaskedLM`` when truthy, else ``BertForPreTraining``.

    Returns:
        ``(tokenizer, model)``; the model is configured to return all hidden states.
    """
    print("Loading Tokenizer ->", tokenizer_path)
    tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)

    print("Loading Model ->", best_checkpoint)
    model_cls = BertForMaskedLM if mlm else BertForPreTraining
    bert = model_cls.from_pretrained(best_checkpoint, output_hidden_states=True)

    bert.to("cuda")
    bert.eval()
    return tokenizer, bert

In [12]:
def load_palmtree_model(vocab_path, best_checkpoint):
    """Load the PalmTree vocabulary and a pickled PalmTree model on GPU, in eval mode.

    Args:
        vocab_path: path of the saved ``WordVocab``.
        best_checkpoint: path of the model saved as a whole module.

    Returns:
        ``(vocab, bert_model)``.
    """
    print("Loading Vocab ->", vocab_path)
    vocab = dataset.WordVocab.load_vocab(vocab_path)

    print("Loading Model ->", best_checkpoint)
    # The checkpoint holds the whole module, so no architecture has to be
    # instantiated first (building one here would only waste time/memory).
    # NOTE(review): unpickling runs arbitrary code; only load trusted files.
    bert_model = torch.load(best_checkpoint)

    bert_model.eval()
    bert_model.to("cuda")
    return vocab, bert_model

In [13]:
def load_palmtree_lstm_model(vocab_path, best_checkpoint):
    """Load the PalmTree vocabulary and a fine-tuned Siamese PalmTree+BiLSTM model.

    The model is first built around the pretrained PalmTree encoder
    (``complete_palmtree_model/transformer.ep0`` under ``base_palmtree``).
    The checkpoint's ``state_dict`` is then loaded into it.

    Returns:
        ``(vocab, model)`` with the model on GPU, in eval mode.
    """
    print("Loading Vocab ->", vocab_path)
    vocab = dataset.WordVocab.load_vocab(vocab_path)

    print("Loading Model ->", best_checkpoint)
    pretrained_path = os.path.join(base_palmtree, "complete_palmtree_model", "transformer.ep0")
    siamese = SiameseFinenuting(hparams=HParams(), vocab=vocab, pt_model_path=pretrained_path)
    saved = torch.load(best_checkpoint)
    siamese.load_state_dict(saved["state_dict"])

    siamese.eval()
    siamese.to("cuda")
    return vocab, siamese

In [14]:
# olivetree
from math import ceil
from tqdm.notebook import tqdm

def compute_embeddings_olivetree(tokenizer_path, model_path, df, layer, batch_size, is_cls=False):
    """Embed every strand of ``df`` with an OliveTree (BERT) model.

    Each strand is tokenized as one sequence. Its embedding is taken from
    ``output.hidden_states[layer]``. It is either the first token (presumably
    [CLS]) when ``is_cls`` is set, or the mean over the non-padding tokens.

    Args:
        tokenizer_path: tokenizer directory.
        model_path: checkpoint directory. Paths containing "next" load a
            pre-training (NSP) model; any other path loads a masked-LM model.
        df: strand table (see ``parse_ss_search_csv``).
        layer: index into ``hidden_states`` (negative indices allowed).
        batch_size: strands per forward pass.
        is_cls: pool with the first token instead of the masked mean.

    Returns:
        ``(strand_ids, embeddings)``: the ``strand_id`` column as a list and
        one 1-D tensor per strand, in the same order.
    """
    tokenizer, model = load_olivetree_model(tokenizer_path, model_path, mlm="next" not in model_path)
    n_iterations = ceil(df.shape[0] / batch_size)

    embeddings_list = []

    print(f"Selected batch size {batch_size} and n iterations {n_iterations}")

    for batch_idx in tqdm(range(n_iterations)):
        start = batch_idx * batch_size
        end = start + batch_size

        strands = parse_ss_search_csv(df, start, end, is_olivetree=True)
        batch_strands = [text for strand in strands for text in strand[1]]

        encoded = tokenizer(batch_strands, padding=True, truncation=True, max_length=512)
        tokenized_strands = {key: torch.tensor(value, device="cuda") for key, value in encoded.items()}

        with torch.no_grad():
            output = model(**tokenized_strands)

        hidden = output.hidden_states[layer]
        attention_mask = tokenized_strands["attention_mask"].bool()
        for row, hidden_row in enumerate(hidden):
            if is_cls:
                # clone: a bare view would keep the whole batch's hidden layer
                # alive on the GPU for as long as the embedding is stored
                embeddings_list.append(hidden_row[0].clone())
            else:
                # mean over the real (non-padding) tokens only
                embeddings_list.append(torch.mean(hidden_row[attention_mask[row]], 0))

        del encoded, tokenized_strands, output, hidden
        torch.cuda.empty_cache()

    return df['strand_id'].tolist(), embeddings_list

In [15]:
# palmtree
from math import ceil
from tqdm.notebook import tqdm

def compute_embeddings_palmtree(vocab_path, model_path, df, layer, batch_size):
    """Embed every strand of ``df`` with the pretrained PalmTree model.

    Each instruction is split on whitespace, mapped through the vocabulary,
    wrapped with SOS/EOS and right-padded. Its embedding is the mean of layer
    ``layer`` over the positions whose id is > 0. A strand's embedding is the
    mean of its instruction embeddings.

    Args:
        vocab_path: path of the PalmTree vocabulary.
        model_path: path of the pickled PalmTree model.
        df: strand table (see ``parse_ss_search_csv``).
        layer: index into the per-layer hidden states of ``all_layers_encode``.
        batch_size: strands per forward pass.

    Returns:
        ``(strand_ids, embeddings)``: the ``strand_id`` column as a list and
        one 1-D tensor per strand, in the same order.
    """
    vocab, model = load_palmtree_model(vocab_path, model_path)
    n_iterations = ceil(df.shape[0] / batch_size)

    embeddings_list = []

    print(f"Selected batch size {batch_size} and n iterations {n_iterations}")

    for batch_idx in tqdm(range(n_iterations)):
        start = batch_idx * batch_size
        end = start + batch_size

        strands = parse_ss_search_csv(df, start, end, is_olivetree=False)

        batch_strands_instr = []
        batch_strands_sizes = []
        for strand in strands:
            batch_strands_instr.extend(strand[1])
            batch_strands_sizes.append(len(strand[1]))

        split_instr = [instr.split() for instr in batch_strands_instr]
        max_batch = max(len(tokens) for tokens in split_instr)

        tokenized_instr = []
        for tokens in split_instr:
            tok_instr = [vocab.stoi.get(tok, vocab.unk_index) for tok in tokens]
            num_pad = max_batch - len(tok_instr)
            tokenized_instr.append([vocab.sos_index] + tok_instr + [vocab.eos_index] + [vocab.pad_index] * num_pad)

        tokenized_instr = torch.tensor(tokenized_instr, device="cuda")
        # Ids > 0 are real tokens (SOS/EOS included).
        # NOTE(review): assumes vocab.pad_index == 0 — confirm.
        # `.int()` replaces torch.tensor(tensor, ...), which warned on copy.
        masks = (tokenized_instr > 0).int()

        with torch.no_grad():
            layers_hidden_states = model.all_layers_encode(tokenized_instr, torch.tensor(1, device="cuda"))

        hidden = layers_hidden_states[layer]
        bool_masks = masks.bool()
        instructions_matrix = torch.stack([
            torch.mean(hidden_row[bool_masks[row]], 0)
            for row, hidden_row in enumerate(hidden)
        ])

        # Average the instruction embeddings belonging to each strand.
        strands_embeddings = torch.split(instructions_matrix, batch_strands_sizes)
        embeddings_list.extend(chunk.mean(dim=0) for chunk in strands_embeddings)

        torch.cuda.empty_cache()

    return df['strand_id'].tolist(), embeddings_list

In [16]:
# palmtree finetuned
from math import ceil
from tqdm.notebook import tqdm

def _pad(elem, max_instr_batch, max_tok_batch, vocab):
        
        masks = []
        batch_token_ids = []
        
        num_pad_inst = max_instr_batch - len(elem)

        for j, instr in enumerate(elem):
            
            num_pad_tok = max_tok_batch - len(instr)
            batch_token_ids.append(instr + [vocab.pad_index] * num_pad_tok)
            masks.append([1] * len(instr) + [0] * num_pad_tok)

        batch_token_ids += [[vocab.pad_index] * max_tok_batch] * num_pad_inst
        masks += [[0] * max_tok_batch] * num_pad_inst
           
        return masks, batch_token_ids
    
def prepare_for_palmtree_lstm(strands, vocab, device="cuda"):
    """Tokenize and pad a batch of strands for ``SiameseFinenuting.samples_embedding``.

    Args:
        strands: list of strands, each a list of instruction strings.
        vocab: PalmTree vocabulary (``stoi``, ``unk_index``, ``sos_index``,
            ``eos_index``, ``pad_index``).
        device: device of the returned tensors. Defaults to "cuda", the
            previously hard-coded value.

    Returns:
        dict with:
          - "num_strands_ins": max instruction count of a strand (every strand
            is padded to this many rows);
          - "token_ids": tensor ``(len(strands) * num_strands_ins, max_tokens)``;
          - "masks": tensor of the same shape, 1 on real tokens;
          - "seq_lens": real instruction count of each strand.
    """
    # list (strands) of list (instructions) of token ids, each instruction
    # wrapped with SOS/EOS
    tok_instr_list = [
        [[vocab.sos_index] + [vocab.stoi.get(tok, vocab.unk_index) for tok in instr.split()] + [vocab.eos_index]
         for instr in strand]
        for strand in strands
    ]

    # max number of instructions in a strand / tokens in an instruction
    max_instr_batch = max(len(strand) for strand in tok_instr_list)
    max_tok_batch = max(len(instr) for strand in tok_instr_list for instr in strand)

    masks = []
    seq_lens = []
    batch_token_ids = []
    for elem in tok_instr_list:
        seq_lens.append(len(elem))
        s_masks, s_batch_token_ids = _pad(elem, max_instr_batch, max_tok_batch, vocab)
        masks.extend(s_masks)
        batch_token_ids.extend(s_batch_token_ids)

    return {"num_strands_ins": max_instr_batch,
            "token_ids": torch.tensor(batch_token_ids, device=device),
            "masks": torch.tensor(masks, device=device),
            "seq_lens": seq_lens}
    
def compute_embeddings_palmtree_lstm(vocab_path, model_path, df, batch_size):
    """Embed every strand of ``df`` with a fine-tuned Siamese PalmTree+BiLSTM model.

    Args:
        vocab_path: path of the PalmTree vocabulary.
        model_path: path of the fine-tuned checkpoint.
        df: strand table (see ``parse_ss_search_csv``).
        batch_size: strands per forward pass.

    Returns:
        ``(strand_ids, embeddings)``: the ``strand_id`` column as a list and
        one embedding row per strand, in the same order.
    """
    vocab, model = load_palmtree_lstm_model(vocab_path, model_path)
    n_iterations = ceil(df.shape[0] / batch_size)
    print(f"Selected batch size {batch_size} and n iterations {n_iterations}")

    embeddings_list = []
    for batch_idx in tqdm(range(n_iterations)):
        batch_start = batch_idx * batch_size
        strands = parse_ss_search_csv(df, batch_start, batch_start + batch_size, is_olivetree=False)

        # keep only each strand's list of instructions
        instruction_lists = [strand[1] for strand in strands]
        batch_result = prepare_for_palmtree_lstm(instruction_lists, vocab)

        with torch.no_grad():
            embeddings = model.samples_embedding(batch_result)
        embeddings_list.extend(embeddings)

        torch.cuda.empty_cache()

    return df['strand_id'].tolist(), embeddings_list

In [17]:
# class for making fast cosine similarity
from operator import itemgetter 
import torch.nn.functional as F

class SearchEngine:
    """Cosine-similarity nearest-neighbour search over precomputed embeddings.

    Loads ``<checkpoint>/<ids_filename>``, a JSON list of ids with one per
    matrix row. Also loads ``<checkpoint>/embeddings_layer_<layer>.pt``, a 2-D
    tensor whose row ``i`` is the embedding of ``ids[i]``.

    Args:
        checkpoint: directory holding both files.
        ids_filename: name of the JSON ids file.
        layer: layer tag used in the embeddings file name.
        device: device the embedding matrix is moved to (default "cuda").
    """

    def __init__(self, checkpoint, ids_filename, layer, device="cuda"):
        self.layer = layer
        with open(os.path.join(checkpoint, ids_filename), "r") as f:
            self.ids = json.load(f)
        # `Tensor.to` is not in-place: its result must be kept.
        self.matrix = torch.load(os.path.join(checkpoint, f"embeddings_layer_{layer}.pt")).to(device)

        # id -> first row holding it (same answer as list.index, but O(1)).
        self._row_of = {}
        for row, strand_id in enumerate(self.ids):
            self._row_of.setdefault(strand_id, row)
        self._normed = None

    def _index_of(self, id_to_query):
        """Row of ``id_to_query``; raises ValueError if unknown (like ``list.index``)."""
        try:
            return self._row_of[id_to_query]
        except KeyError:
            raise ValueError(f"{id_to_query!r} is not in the ids list") from None

    def _normalized_matrix(self):
        """Row-normalised embedding matrix, computed on first use and cached.

        NOTE(review): an all-zero row yields NaN similarities.
        """
        if self._normed is None:
            self._normed = self.matrix / self.matrix.norm(dim=-1)[:, None]
        return self._normed

    def find_top_k(self, id_to_query, k):
        """Return the ``k`` ids most cosine-similar to ``id_to_query`` and their scores.

        The query is itself indexed, so it is normally ranked first.

        Returns:
            ``(ids, scores)`` sorted by decreasing similarity.

        Raises:
            ValueError: if ``id_to_query`` is not indexed.
        """
        query = self.matrix[self._index_of(id_to_query)]
        dist = F.cosine_similarity(self.matrix, query)
        top_k = torch.argsort(dist, descending=True)[:k]
        return [self.ids[i] for i in top_k.tolist()], dist[top_k].tolist()

    def find_top_k_batch(self, ids_to_query, k):
        """Batched ``find_top_k``.

        Returns:
            ``(res, scores)``, two lists of lists. ``res[i]`` / ``scores[i]``
            are the top-``k`` ids / similarities of ``ids_to_query[i]``.

        Raises:
            ValueError: if any query id is not indexed.
        """
        rows = [self._index_of(query_id) for query_id in ids_to_query]
        m_query = self.matrix[rows]

        # With unit-norm rows, a matrix product gives cosine similarities:
        # dist[i][j] = cos(query i, indexed row j).
        a_norm = self._normalized_matrix()
        b_norm = m_query / m_query.norm(dim=-1)[:, None]
        dist = torch.mm(b_norm, a_norm.transpose(0, 1))

        res, scores = [], []
        for row_dist in dist:
            top_k = torch.argsort(row_dist, descending=True)[:k]
            res.append([self.ids[i] for i in top_k.tolist()])
            scores.append(row_dist[top_k].tolist())
        return res, scores

In [18]:
from tqdm.notebook import tqdm

def find_top_k_similar(data_path, df, batch_size, focus_layer, k):
    """Run the nearest-neighbour search for every query row of ``df``.

    A query row has a non-null "ground_truth": a JSON list of the ids of its
    relevant strands.

    Args:
        data_path: directory with "strands_ids.json" and the embeddings file
            (see ``SearchEngine``).
        df: strand table with "strand_id" and "ground_truth" columns.
        batch_size: queries searched per ``find_top_k_batch`` call.
        focus_layer: layer tag of the embeddings file.
        k: neighbours retrieved per query (was previously ignored in favour of 200).

    Returns:
        One ``(relevance, n_relevant, scores)`` tuple per query, in row order.
        ``relevance`` is the 0/1 list over the retrieved ids (1 = in the
        ground truth), ``n_relevant`` is the ground-truth size and ``scores``
        are the similarities.
    """
    queries = df[~df['ground_truth'].isnull()]
    search_engine = SearchEngine(data_path, "strands_ids.json", focus_layer)
    n_queries = len(queries)

    answers = []
    gts, ids = [], []

    with tqdm(total=n_queries) as p_bar:
        for count, (_, row) in enumerate(queries.iterrows(), start=1):
            gts.append([int(x) for x in json.loads(row['ground_truth'])])
            ids.append(row["strand_id"])

            # search when the batch is full or on the last query
            if len(ids) == batch_size or count == n_queries:
                top_k_lists, scores = search_engine.find_top_k_batch(ids, k)
                for top_k_ids, gt, top_scores in zip(top_k_lists, gts, scores):
                    relevance = [1 if k_id in gt else 0 for k_id in top_k_ids]
                    answers.append((relevance, len(gt), top_scores))

                p_bar.update(len(top_k_lists))
                gts, ids = [], []

    return answers

In [19]:
import math
import numpy as np
from tqdm.notebook import tqdm

def find_dcg(element_list):
    """Discounted cumulative gain of a relevance list.

    The item at 0-based position ``j`` contributes ``rel_j / ln(j + 2)``.
    This uses the natural log; for nDCG, a ratio of two DCGs, the log base
    cancels out.
    """
    return sum((float(rel) / math.log(pos + 2) for pos, rel in enumerate(element_list)), 0.0)


def count_ones(element_list):
    """Number of entries equal to 1."""
    return sum(1 for value in element_list if value == 1)


def extract_info(answers):
    """Average nDCG@k, recall@k and precision@k curves over all queries.

    Args:
        answers: list of ``(relevance, n_relevant, scores)`` tuples as returned
            by ``find_top_k_similar``. ``relevance`` is the 0/1 list over the
            retrieved ids, ``n_relevant`` is the ground-truth size and
            ``scores`` is unused here.

    Returns:
        ``(ndcg, recall, precision)`` lists. Element ``k - 1`` is the metric at
        cut-off ``k``, averaged over queries, for ``k = 1 .. len(relevance)``.
        The last cut-off was previously skipped.
    """
    performance1 = []
    average_recall_k1 = []
    precision_at_k1 = []

    for data in tqdm(answers):
        relevance, n_relevant = data[0], data[1]

        ndcg_curve, recall_curve, precision_curve = [], [], []
        # Running sums give each cut-off in O(1) rather than re-summing the
        # whole prefix; the terms are added in the same order, so the values
        # are identical.
        dcg = 0.0
        ideal_dcg = 0.0
        hits = 0
        for pos, rel in enumerate(relevance):
            k = pos + 1
            discount = math.log(pos + 2)
            dcg += float(rel) / discount
            # ideal ranking: all n_relevant hits come first
            ideal_dcg += (1.0 if pos < n_relevant else 0.0) / discount
            if rel == 1:
                hits += 1

            ndcg_curve.append(dcg / ideal_dcg)
            recall_curve.append(float(hits) / n_relevant)
            precision_curve.append(float(hits) / k)

        performance1.append(ndcg_curve)
        average_recall_k1.append(recall_curve)
        precision_at_k1.append(precision_curve)

    avg_p1 = np.average(performance1, axis=0)
    avg_recall = np.average(average_recall_k1, axis=0)
    average_precision = np.average(precision_at_k1, axis=0)

    return list(avg_p1), list(avg_recall), list(average_precision)

In [20]:
# Accumulator of (nDCG, recall, precision) curve tuples of evaluated models.
infos = []

In [43]:
# olivetree

# Entries: (model_name, checkpoint, is_fine_tuned, is_cls)
ot_lst = [
    #("next_sentence_prediction_bert_normal_mask30", "checkpoint-67246", False, False),
    #("masked_language_model_only_bert_normal_mask30", "checkpoint-95846", False, False),
    ("next_sentence_prediction_bert_normal_unigram_tok_mask30", "checkpoint-33623", False, False),
    ("next_sentence_prediction_bert_normal_whitespace_tok_mask30", "checkpoint-67246", False, False),
    ("next_sentence_prediction_bert_normal_whitespace2_tok_mask30", "checkpoint-33623", False, False),

    #("nsp_normal_mask30_triplet_loss", "epoch-10", True, False),
    #("mlm_normal_mask30_triplet_loss", "epoch-6", True, False),
    #("nsp_normal_unigram_mask30_triplet_loss", "epoch-4", True, False),
    #("nsp_normal_whitespace_mask30_triplet_loss", "epoch-4", True, False),
    #("nsp_normal_whitespace2_mask30_triplet_loss", "epoch-12", True, False),
    #("from_scratch_normal_triplet_loss", "epoch-16", True, False)
]

# For every model: embed all strands, save ids + embedding matrix, run the
# top-200 search and store the nDCG / recall / precision curves.
for olivetree_model_name, olivetree_checkpoint, is_fine_tuned, is_cls in ot_lst:
    olivetree_model_path = os.path.join(
        base_olivetree_finetuned if is_fine_tuned else base_olivetree,
        olivetree_model_name,
        olivetree_checkpoint,
    )

    # fine-tuned models are read from their last hidden layer
    focus_layer = -1 if is_fine_tuned else ol_focus_layer

    # tokenizer matching the model ("whitespace2" must be tested before "whitespace")
    if "unigram" in olivetree_model_name:
        t_path = u_tokenizer_path
    elif "whitespace2" in olivetree_model_name:
        t_path = w2_tokenizer_path
    elif "whitespace" in olivetree_model_name:
        t_path = w_tokenizer_path
    else:
        t_path = tokenizer_path

    ot_strands_ids, ot_embeddings_list = compute_embeddings_olivetree(
        t_path, olivetree_model_path, df, focus_layer, 128, is_cls=is_cls)

    ot_path = os.path.join(base_data_path, "embeddings", olivetree_model_name, olivetree_checkpoint)
    os.makedirs(ot_path, exist_ok=True)
    with open(os.path.join(ot_path, "strands_ids.json"), "w") as f:
        json.dump(ot_strands_ids, f)

    embeddings = ot_embeddings_list
    matrix = torch.stack(embeddings)
    torch.save(matrix, os.path.join(ot_path, f"embeddings_layer_{focus_layer}.pt"))

    olivetree_answers = find_top_k_similar(ot_path, df, 1, focus_layer, 200)
    ol_avg_p1, ol_recal_p1, ol_pre_p1 = extract_info(olivetree_answers)

    with open(os.path.join(ot_path, f"{test_name}.json"), "w") as f:
        json.dump([ol_avg_p1, ol_recal_p1, ol_pre_p1], f)

Loading Tokenizer -> /home/jovyan/work/olivetree/final_for_paper/unigram_tokenizer
Loading Model -> /home/jovyan/work/olivetree/final_for_paper/models/next_sentence_prediction_bert_normal_unigram_tok_mask30/checkpoint-33623
Selected batch size 128 and n iterations 333


  0%|          | 0/333 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

Loading Tokenizer -> /home/jovyan/work/olivetree/final_for_paper/whitespace_tokenizer
Loading Model -> /home/jovyan/work/olivetree/final_for_paper/models/next_sentence_prediction_bert_normal_whitespace_tok_mask30/checkpoint-67246
Selected batch size 128 and n iterations 333


  0%|          | 0/333 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

Loading Tokenizer -> /home/jovyan/work/olivetree/final_for_paper/whitespace_tokenizer2
Loading Model -> /home/jovyan/work/olivetree/final_for_paper/models/next_sentence_prediction_bert_normal_whitespace2_tok_mask30/checkpoint-33623
Selected batch size 128 and n iterations 333


  0%|          | 0/333 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

In [23]:
# palmtree

# Entries: (model_name, checkpoint, is_fine_tuned)
pt_lst = [
    ("complete_palmtree_model", "transformer.ep0", False),
    ("epoch1_finetuning_lstm", "epoch=19-val_auc=0.9840.ckpt", True),
    ("from_scratch_finetuning_lstm", "epoch=19-val_auc=0.9720.ckpt", True)
]

# Same pipeline as for OliveTree: embed, save, search top-200, store curves.
for palmtree_model_name, palmtree_checkpoint, is_fine_tuned in pt_lst:
    # fine-tuned (BiLSTM) models produce a single final embedding
    focus_layer = -1 if is_fine_tuned else pt_focus_layer

    if not is_fine_tuned:
        palmtree_model_path = os.path.join(base_palmtree, palmtree_model_name, palmtree_checkpoint)
        pt_strands_ids, pt_embeddings_list = compute_embeddings_palmtree(
            vocab_path, palmtree_model_path, df, focus_layer, 128)
    else:
        palmtree_model_path = os.path.join(base_palmtree_finetuned, palmtree_model_name, palmtree_checkpoint)
        pt_strands_ids, pt_embeddings_list = compute_embeddings_palmtree_lstm(
            vocab_path, palmtree_model_path, df, 128)

    pt_path = os.path.join(base_data_path, "embeddings", palmtree_model_name, palmtree_checkpoint)
    os.makedirs(pt_path, exist_ok=True)
    with open(os.path.join(pt_path, "strands_ids.json"), "w") as f:
        json.dump(pt_strands_ids, f)

    embeddings = pt_embeddings_list
    matrix = torch.stack(embeddings)
    torch.save(matrix, os.path.join(pt_path, f"embeddings_layer_{focus_layer}.pt"))

    palmtree_answers = find_top_k_similar(pt_path, df, 1, focus_layer, 200)
    pt_avg_p1, pt_recal_p1, pt_pre_p1 = extract_info(palmtree_answers)

    with open(os.path.join(pt_path, f"{test_name}.json"), "w") as f:
        json.dump([pt_avg_p1, pt_recal_p1, pt_pre_p1], f)

Loading Vocab -> /home/jovyan/work/olivetree/final_for_paper/../palmtree/data/palmtree_complete_dataset/vocab
Loading Model -> /home/jovyan/work/olivetree/final_for_paper/../palmtree/models/complete_palmtree_model/transformer.ep0
Selected batch size 128 and n iterations 333


  0%|          | 0/333 [00:00<?, ?it/s]

  masks = torch.tensor(tokenized_instr>0, dtype=torch.int32)


  0%|          | 0/4942 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

Loading Vocab -> /home/jovyan/work/olivetree/final_for_paper/../palmtree/data/palmtree_complete_dataset/vocab
Loading Model -> /home/jovyan/work/olivetree/final_for_paper/tests_cagate/similarity/strands/fine_tuned_models/palmtree/epoch1_finetuning_lstm/epoch=19-val_auc=0.9840.ckpt
Selected batch size 128 and n iterations 333




  0%|          | 0/333 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

Loading Vocab -> /home/jovyan/work/olivetree/final_for_paper/../palmtree/data/palmtree_complete_dataset/vocab
Loading Model -> /home/jovyan/work/olivetree/final_for_paper/tests_cagate/similarity/strands/fine_tuned_models/palmtree/from_scratch_finetuning_lstm/epoch=19-val_auc=0.9720.ckpt
Selected batch size 128 and n iterations 333


  0%|          | 0/333 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

In [29]:
# TREX

# TREX embeddings are precomputed: each CSV row stores a JSON-encoded
# "embedding" next to its "strand_id".
for t_n, t_t in [("_intr_TREX.csv", "TREX_INTR"), ("_extr_TREX.csv", "TREX_EXTR")]:
    TREX_info = test_path.replace(".csv", t_n)
    TREX_path = os.path.join(base_data_path, "embeddings", t_t)

    df_trex = pd.read_csv(TREX_info, sep='\t')
    df_trex = df_trex.where(pd.notnull(df_trex), None)
    trex_functions_ids = df_trex['strand_id'].tolist()

    os.makedirs(TREX_path, exist_ok=True)
    with open(os.path.join(TREX_path, "strands_ids.json"), "w") as f:
        json.dump(trex_functions_ids, f)

    trex_embeddings = [
        torch.from_numpy(np.array(json.loads(raw_embedding)))
        for raw_embedding in df_trex['embedding'].to_list()
    ]
    matrix = torch.stack(trex_embeddings)
    torch.save(matrix, os.path.join(TREX_path, "embeddings_layer_-1.pt"))

    trex_answers = find_top_k_similar(TREX_path, df_trex, 128, -1, 200)
    trex_avg_p1, trex_recal_p1, trex_pre_p1 = extract_info(trex_answers)

    with open(os.path.join(TREX_path, f"{test_name}.json"), "w") as f:
        json.dump([trex_avg_p1, trex_recal_p1, trex_pre_p1], f)

# NOTE(review): this runs after the loop, so only the last variant
# (TREX_EXTR) is appended to `infos` — confirm this is intended.
infos.append((trex_avg_p1, trex_recal_p1, trex_pre_p1))

  0%|          | 0/4942 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

  0%|          | 0/4942 [00:00<?, ?it/s]

In [24]:
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline

def print_graph(scores, max_pos, file_name, label_y, titles, image_title=None):
    """Plot smoothed metric curves (one per model) and save them as a PDF.

    Args:
        scores: per-model metric lists; ``scores[i]`` belongs to ``titles[i]``
            and element ``j`` is the metric at cut-off ``j + 1``. Only the
            first ``max_pos`` values are plotted.
        max_pos: number of leading cut-offs to plot.
        file_name: output path (PDF, 500 dpi).
        label_y: y-axis label.
        titles: legend labels. Known model families get fixed colours; other
            titles get distinct fallback colours, starting with purple.
        image_title: optional figure title.
    """
    family_colors = {
        "BinBert": "tab:blue", "BinBert-FT": "tab:blue",
        "BinBert-MLM": "tab:orange", "BinBert-MLM-FT": "tab:orange",
        "TREX": "tab:green", "TREX-FT": "tab:green",
        "PalmTree": "tab:red", "PalmTree-FT": "tab:red",
    }
    fallback_colors = ["tab:purple", "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan"]

    plt.clf()
    plt.grid()

    plt.xlabel("Number of Nearest Results")
    plt.ylabel(label_y)
    plt.ylim([0, 1])

    n_fallback = 0
    for i, title in enumerate(titles):
        y = scores[i][:max_pos]
        # x is the cut-off k itself (element j is the metric at k = j + 1)
        x = np.arange(1, len(y) + 1)

        # Smooth with a spline evaluated only inside the data range, with no
        # extrapolation past the last point.
        spline = make_interp_spline(x, y)
        X_ = np.linspace(x[0], x[-1], 500)
        Y_ = spline(X_)

        color = family_colors.get(title)
        if color is None:
            color = fallback_colors[n_fallback % len(fallback_colors)]
            n_fallback += 1

        plt.plot(X_, Y_, label=title, color=color, linewidth=1)

    if image_title is not None:
        plt.title(image_title)

    plt.legend(prop={'size': 8})
    plt.savefig(file_name, dpi=500, format="pdf")
    # plt.close(file_name) would look for a figure *labelled* file_name and
    # close nothing; close the current figure so figures do not accumulate.
    plt.close()

In [33]:
from texttable import Texttable

def print_table(info, titles, idx=(9, 19, 39)):
    """Print a text table of per-model scores at selected cut-off positions.

    Args:
        info: per-model score sequences aligned with ``titles``; ``info[i][j]``
            is the score of model ``i`` at 0-based position ``j``.
        titles: model names, shown in the first column.
        idx: 0-based positions to report; position ``j`` is labelled
            ``hit-{j+1}``. Defaults to hit-10, hit-20 and hit-40. A wider
            selection used earlier was
            (4, 9, 14, 19, 29, 39, 49, 79, 129, 149, 198).
    """
    table = Texttable(max_width=1000)
    # The header is added as a regular row, so it is separated from the data
    # rows by the same '-' rule as every other row.
    table.add_row(['model'] + [f'hit-{j + 1}' for j in idx])

    # The loop variable is not called `model`, to avoid shadowing the
    # `palmtree.model` module imported at the top of the notebook.
    for i, title in enumerate(titles):
        table.add_row([title] + [info[i][j] for j in idx])

    print(table.draw())

In [44]:
# (directory under embeddings/, checkpoint sub-directory, display title).
# Uncomment entries to include further models in the tables/plots.
model_lists = [
    #("next_sentence_prediction_bert_normal_mask30", "checkpoint-67246", "BinBert"),
    #("masked_language_model_only_bert_normal_mask30", "checkpoint-95846", "BinBert-MLM"),
    ("next_sentence_prediction_bert_normal_unigram_tok_mask30", "checkpoint-33623", "BinBert-Unigram"),
    ("next_sentence_prediction_bert_normal_whitespace_tok_mask30", "checkpoint-67246", "BinBert-Whitespace"),
    ("next_sentence_prediction_bert_normal_whitespace2_tok_mask30", "checkpoint-33623", "BinBert-Whitespace2"),
    
    #("nsp_normal_mask30_triplet_loss", "epoch-10", "BinBert-FT"),
    #("mlm_normal_mask30_triplet_loss", "epoch-6", "BinBert-MLM-FT"),
    #("nsp_normal_unigram_mask30_triplet_loss", "epoch-4", "BinBert-FT-Unigram"),
    #("nsp_normal_whitespace_mask30_triplet_loss", "epoch-4", "BinBert-FT-Whitespace"),
    #("nsp_normal_whitespace2_mask30_triplet_loss", "epoch-12", "BinBert-FT-Whitespace2"),
    #("from_scratch_normal_triplet_loss", "epoch-16", "BinBert-FS"),
    
    #("complete_palmtree_model", "transformer.ep0", "PalmTree"),
    #("epoch1_finetuning_lstm", "epoch=19-val_auc=0.9840.ckpt", "PalmTree-FT"),
    # ("from_scratch_finetuning_lstm", "epoch=19-val_auc=0.9720.ckpt", ""),
    
    #("TREX_INTR", "", "TREX"),
    #("TREX_EXTR", "", "TREX-FT"),
]

titles = [title for _, _, title in model_lists]

# Result files produced by the evaluation cells (saved as f"{test_name}.json").
res_names = [
    "test_strands_equivalence_others_DEFINITIVE_nDCG.csv.json", 
]

# Per-model mean (n, r, p) and standard deviation (sn, sr, sp) across the
# result files in res_names; with a single file the std is all zeros.
n, r, p = [], [], []
sn, sr, sp = [], [], []

for model_name, checkpoint, _ in model_lists:
    
    json_base_path = os.path.join(base_data_path, "embeddings", model_name, checkpoint)
    
    n_local, r_local, p_local = [], [], []
    for res_name in res_names:
        # Each file holds three per-rank lists, read here as
        # [nDCG, recall, precision]. A `with` block closes the handle, which
        # json.load(open(...)) leaked.
        with open(os.path.join(json_base_path, res_name)) as res_file:
            res_json = json.load(res_file)
        n_local.append(res_json[0])
        r_local.append(res_json[1])
        p_local.append(res_json[2])
    
    n.append(np.mean(n_local, axis=0))
    r.append(np.mean(r_local, axis=0))
    p.append(np.mean(p_local, axis=0))
    
    sn.append(np.std(n_local, axis=0))
    sr.append(np.std(r_local, axis=0))
    sp.append(np.std(p_local, axis=0))
        
#print_graph(n, 50, os.path.join(base_res_path, "nDCG_strands_extrinsic.pdf"), "nDCG", 
#             titles, image_title="Extrinsic Similarity at Strands Level")
#print_graph(p, 50, os.path.join(base_res_path, "precision_strands_extrinsic.pdf"), "Precision", 
#             titles, image_title="Extrinsic Similarity at Strands Level")
#print_graph(r, 50, os.path.join(base_res_path, "recall_strands_extrinsic.pdf"), "Recall", 
#             titles, image_title="Extrinsic Similarity at Strands Level")

print("Results nDCG")
print_table(n, titles)
print()
print("Results Precision")
print_table(p, titles)
print()
print("Results Recall")
print_table(r, titles)

Results nDCG
+---------------------+--------+--------+--------+
| model               | hit-10 | hit-20 | hit-40 |
+---------------------+--------+--------+--------+
| BinBert-Unigram     | 0.590  | 0.600  | 0.622  |
+---------------------+--------+--------+--------+
| BinBert-Whitespace  | 0.630  | 0.641  | 0.661  |
+---------------------+--------+--------+--------+
| BinBert-Whitespace2 | 0.630  | 0.641  | 0.661  |
+---------------------+--------+--------+--------+

Results Precision
+---------------------+--------+--------+--------+
| model               | hit-10 | hit-20 | hit-40 |
+---------------------+--------+--------+--------+
| BinBert-Unigram     | 0.356  | 0.210  | 0.118  |
+---------------------+--------+--------+--------+
| BinBert-Whitespace  | 0.389  | 0.229  | 0.128  |
+---------------------+--------+--------+--------+
| BinBert-Whitespace2 | 0.389  | 0.230  | 0.128  |
+---------------------+--------+--------+--------+

Results Recall
+---------------------+--------+--