In [1]:
# General imports
import os
import random
import math
import itertools
import pandas as pd
from tqdm import tqdm
import numpy as np
from datasets import load_metric

# pytorch imports
import torch
import pytorch_lightning as pl
import torchmetrics
from torch.utils.data import Dataset, DataLoader

# Transformer tokenizer imports
from transformers import BertTokenizerFast

# Transformers Bert model
from transformers import BertModel, BertForPreTraining, Trainer, TrainingArguments, EarlyStoppingCallback, BertConfig

MAX_SEQ_LEN = 512

In [2]:
# GPU settings
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["NVIDIA_VISIBLE_DEVICES"] = "0"

import random
from datetime import datetime
random.seed(datetime.now())

since Python 3.9 and will be removed in a subsequent version. The only 
supported seed types are: None, int, float, str, bytes, and bytearray.
  random.seed(datetime.now())


In [3]:
def load_tokenizer(tokenizer_path):
    # load tokenizer from dict
    tokenizer =  BertTokenizerFast.from_pretrained(tokenizer_path)
    return tokenizer

In [4]:
class SimDataset(torch.utils.data.Dataset):

    def __init__(self, dataset_path, tokenizer):
        
        self.data_store = []
        df = pd.read_csv(dataset_path, sep="\t").fillna('')
        #self.samples = df.iloc[:, [2, 5, 6]]
        self.samples = df[["ot_strand_anchor", "ot_strand_pos", "ot_strand_neg"]]
        self.tokenizer = tokenizer
            
        self.__init_structures()
         
    def __init_structures(self):
        
        for first_asm, second_asm, third_asm in tqdm(self.samples.values):
            first_asm_example = self.tokenizer(text=" ".join(first_asm.split(" NEXT_I ")), truncation=True, max_length=MAX_SEQ_LEN)
            second_asm_example = self.tokenizer(text=" ".join(second_asm.split(" NEXT_I ")), truncation=True, max_length=MAX_SEQ_LEN)
            third_asm_example = self.tokenizer(text=" ".join(third_asm.split(" NEXT_I ")), truncation=True, max_length=MAX_SEQ_LEN)
            
            example = {"anchor": first_asm_example,
                      "pos": second_asm_example,
                      "neg": third_asm_example}
            
            self.data_store.append(example)
            
        random.shuffle(self.data_store)
                
    def __len__(self) -> int:
        return len(self.data_store)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data_store[idx]
    
    def save_to_file(self,save_file):
        torch.save(self.data_store, save_file)

In [5]:
class AsmDataModule(pl.LightningDataModule):

    def __init__(self, train_path, val_path, test_path, batch_size, tokenizer):
        
        super().__init__()
        
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.tokenizer = tokenizer
        
        self.batch_size = batch_size

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        
        if stage == 'fit':
            self.train_dataset = SimDataset(self.train_path, self.tokenizer,)
            self.val_dataset   = SimDataset(self.val_path, self.tokenizer)
            
        elif stage == 'test':
            self.test_dataset = SimDataset(self.test_path, self.tokenizer)
            
    def __pad(self, samples):
        
        all_tokens_ids = []
        all_masks = []
        
        max_tok_batch = max([len(block["input_ids"]) for block in samples])
        
        for block in samples:
             
            num_pad_inst = max_tok_batch - len(block["input_ids"])
            all_tokens_ids.append(block["input_ids"] + [self.tokenizer.pad_token_id] * num_pad_inst)
            all_masks.append(block["attention_mask"] + [0] * num_pad_inst)
            
        return {"input_ids": torch.tensor(all_tokens_ids, device="cuda"),
               "attention_mask": torch.tensor(all_masks, device="cuda")}
        

    def collate_with_padding(self, batch):
        
        first_p  = [elem["anchor"] for elem in batch]
        second_p = [elem["pos"] for elem in batch]
        third_p = [elem["neg"] for elem in batch]
        
        first_p = self.__pad(first_p)
        second_p = self.__pad(second_p)
        third_p = self.__pad(third_p)
        
        batch_result = {"anchor": first_p,
                        "pos": second_p,
                        "neg": third_p}
        
        return batch_result

    def train_dataloader(self, *args, **kwargs):    
        return DataLoader(self.train_dataset, 
                          batch_size=self.batch_size, 
                          collate_fn=self.collate_with_padding)

    def val_dataloader(self, *args, **kwargs):
        return DataLoader(self.val_dataset, 
                          batch_size=self.batch_size, 
                          collate_fn=self.collate_with_padding)

    def test_dataloader(self, *args, **kwargs):
        return DataLoader(self.test_dataset, 
                          batch_size=self.batch_size, 
                          collate_fn=self.collate_with_padding)

In [6]:
base_path = "../../"
prt_model = os.path.join(base_path, "models", "pretraining_model", "checkpoint-67246")
train_path = os.path.join(base_path, "dataset", "finetuning_dataset", "similarity", "strands", "train_strands_similarity_triplets.csv")
val_path = os.path.join(base_path, "dataset", "finetuning_dataset", "similarity", "strands", "val_strands_similarity_triplets.csv")
tokenizer_path = os.path.join(base_path, "tokenizer")

model_name =  f"BinBert_strand_similarity"
output_model_path = os.path.join(base_path, "models", "finetuned_models", "similarity", "strands", model_name)

In [7]:
from_scratch = True

LEARNING_RATE = 0.00001
BATCH_SIZE = 8

NUM_TRAIN_EPOCHS = 20
PER_DEVICE_TRAIN_BATCH_SIZE = 16
PER_DEVICE_EVAL_BATCH_SIZE = 16
DATA_LOADER_NUM_WORKERS = 4
PATIENCE = 3

#models
BXSMAL="bert_xsmall"
BSMAL="bert_small"
BNORM="bert_normal"
BLARG="bert_larg"

MODEL=BNORM

if MODEL == BXSMAL:
    MAX_SEQ_LEN = 512
    MAX_POSITION_EMBEDDINGS = 514
    HIDDEN_SIZE = 128
    INTERMEDIATE_SIZE = 1024
    NUM_ATTENTION_HEADS = 8
    NUM_HIDDEN_LAYERS = 12
    TYPE_VOCAB_SIZE = 2

if MODEL == BSMAL:
    MAX_SEQ_LEN = 512
    MAX_POSITION_EMBEDDINGS = 514
    HIDDEN_SIZE = 512
    INTERMEDIATE_SIZE = 2048
    NUM_ATTENTION_HEADS = 8
    NUM_HIDDEN_LAYERS = 12
    TYPE_VOCAB_SIZE = 2
    
if MODEL == BNORM:
    MAX_SEQ_LEN = 512
    MAX_POSITION_EMBEDDINGS = 514
    HIDDEN_SIZE = 768
    INTERMEDIATE_SIZE = 3072
    NUM_ATTENTION_HEADS = 12
    NUM_HIDDEN_LAYERS = 12
    TYPE_VOCAB_SIZE = 2    
    
if MODEL == BLARG:
    MAX_SEQ_LEN = 512
    MAX_POSITION_EMBEDDINGS = 514
    HIDDEN_SIZE = 1024
    INTERMEDIATE_SIZE = 4096
    NUM_ATTENTION_HEADS = 16
    NUM_HIDDEN_LAYERS = 24
    TYPE_VOCAB_SIZE = 2

In [8]:
class SiameseFinenuting(pl.LightningModule):

    def __init__(self, model_path, batch_size, vocab=None):
        
        super().__init__()
        
        self.batch_size = batch_size
        
        # Model
        if from_scratch:
            print("From scratch")
            config = BertConfig(
                vocab_size = len(vocab),
                max_position_embeddings = MAX_POSITION_EMBEDDINGS,
                hidden_size = HIDDEN_SIZE,
                intermediate_size = INTERMEDIATE_SIZE,
                num_attention_heads = NUM_ATTENTION_HEADS,
                num_hidden_layers = NUM_HIDDEN_LAYERS,
                type_vocab_size = TYPE_VOCAB_SIZE,
                output_hidden_states=True
            )
            self.model = BertModel(config=config)
        else:
            self.model = BertModel.from_pretrained(model_path, output_hidden_states=True)
        
        # criterion
        self.triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2)
        
        self.cosine = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        self.train_auc = torchmetrics.AUROC()
        self.val_auc = torchmetrics.AUROC()

        
    def forward(self, pairs_asm_input):
        
        first = pairs_asm_input["anchor"]
        second = pairs_asm_input["pos"]
        third = pairs_asm_input["neg"]
        
        first_output = self.model(**first)
        second_output = self.model(**second)
        third_output = self.model(**third)
        
        first_hidden_states = first_output.hidden_states[-1]
        second_hidden_states = second_output.hidden_states[-1]
        third_hidden_states = third_output.hidden_states[-1]

        first_masks = first['attention_mask']
        second_masks = second['attention_mask']
        third_masks = third['attention_mask']

        first_partial_mul = first_hidden_states * first_masks.unsqueeze(-1)
        second_partial_mul = second_hidden_states * second_masks.unsqueeze(-1)
        third_partial_mul = third_hidden_states * third_masks.unsqueeze(-1)
        
        first_partial_sum = torch.sum(first_partial_mul, dim=1)
        second_partial_sum = torch.sum(second_partial_mul, dim=1)
        third_partial_sum = torch.sum(third_partial_mul, dim=1)
        
        first_n = torch.sum(first_masks, dim=1)
        second_n = torch.sum(second_masks, dim=1)
        third_n = torch.sum(third_masks, dim=1)
        
        first_embeddings = first_partial_sum / first_n.unsqueeze(-1)
        second_embeddings = second_partial_sum / second_n.unsqueeze(-1)
        third_embeddings = third_partial_sum / third_n.unsqueeze(-1)
        
        cosines_pos = self.cosine(first_embeddings, second_embeddings)
        cosines_neg = self.cosine(first_embeddings, third_embeddings)
        
        result = dict()
        result['first_embedding'] = first_embeddings
        result['second_embedding'] = second_embeddings
        result['third_embedding'] = third_embeddings
        result['cos_pos'] = cosines_pos
        result['cos_neg'] = cosines_neg
        
        return result

    def training_step(self, batch, batch_idx):
        forward_output = self.forward(batch)
        
        loss = self.triplet_loss(forward_output['first_embedding'], 
                                 forward_output['second_embedding'], 
                                 forward_output['third_embedding'])
         
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        
        pred_pos = forward_output['cos_pos']
        pred_neg = forward_output['cos_neg']
        labels_pos = torch.tensor([1]*pred_pos.shape[0]).to("cuda")
        labels_neg = torch.tensor([0]*pred_neg.shape[0]).to("cuda")
        
        m = self.train_auc(torch.cat((pred_pos,pred_neg),0), torch.cat((labels_pos,labels_neg),0))
        
        self.log('train_auc', m, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        
        return {"loss":loss, 'train_auc':m}

    def validation_step(self, batch, batch_idx):
        forward_output = self.forward(batch)
        
        loss = self.triplet_loss(forward_output['first_embedding'], 
                                 forward_output['second_embedding'], 
                                 forward_output['third_embedding'])
         
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        
        pred_pos = forward_output['cos_pos']
        pred_neg = forward_output['cos_neg']
        labels_pos = torch.tensor([1]*pred_pos.shape[0]).to("cuda")
        labels_neg = torch.tensor([0]*pred_neg.shape[0]).to("cuda")
        
        m = self.train_auc(torch.cat((pred_pos,pred_neg),0), torch.cat((labels_pos,labels_neg),0))
        
        self.log('val_auc', m, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        
        return {"loss":loss, 'val_auc':m}

    def test_step(self, batch, batch_idx):
        forward_output = self.forward(batch)
        
        loss = self.triplet_loss(forward_output['first_embedding'], 
                                 forward_output['second_embedding'], 
                                 forward_output['third_embedding'])
         
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        
        return {"loss":loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)
        return optimizer

In [9]:
def get_trainer(ckpt_dir):

    early_stopping = pl.callbacks.EarlyStopping(
        monitor='val_auc',
        patience=3,
        verbose=True,
        mode='max', # wheter we want to maximize (max) or minimize the "monitor" value.
    )

    check_point_callback = pl.callbacks.ModelCheckpoint(
        monitor='val_auc',
        verbose=True,
        save_top_k=1,
        mode='max', # wheter we want to maximize (max) or minimize the "monitor" value.
        dirpath=ckpt_dir,
        filename='{epoch}-{val_auc:.4f}',
        # save_weights_only = True
    )

    # the PyTorch Lightning Trainer
    trainer = pl.Trainer(
        max_epochs=NUM_TRAIN_EPOCHS,
        gpus=1,
        progress_bar_refresh_rate=5,
        # callbacks=[early_stopping, check_point_callback]
        callbacks=[check_point_callback]
    )

    return trainer

In [10]:
tokenizer = load_tokenizer(tokenizer_path)
data_module = AsmDataModule(train_path, val_path, None, BATCH_SIZE, tokenizer)

In [11]:
print(prt_model)
simaese_model = SiameseFinenuting(prt_model, BATCH_SIZE, vocab=tokenizer.vocab if from_scratch else None)

/home/jovyan/work/olivetree/final_for_paper/models/next_sentence_prediction_bert_normal_mask30/checkpoint-67246
From scratch




In [12]:
trainer = get_trainer(ckpt_dir=output_model_path)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
trainer.fit(model=simaese_model, datamodule=data_module)

100% 39978/39978 [00:11<00:00, 3381.28it/s]
100% 9996/9996 [00:02<00:00, 3717.82it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type              | Params
---------------------------------------------------
0 | model        | BertModel         | 92.0 M
1 | triplet_loss | TripletMarginLoss | 0     
2 | cosine       | CosineSimilarity  | 0     
3 | train_auc    | AUROC             | 0     
4 | val_auc      | AUROC             | 0     
---------------------------------------------------
92.0 M    Trainable params
0         Non-trainable params
92.0 M    Total params
368.176   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 0, global step 4997: val_auc reached 0.96934 (best 0.96934), saving model to "/home/jovyan/work/olivetree/final_for_paper/tests/similarity/strands/fine_tuned_models/olivetree/from_scratch_normal_triplet_loss/epoch=0-val_auc=0.9693.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 1, global step 9995: val_auc reached 0.97926 (best 0.97926), saving model to "/home/jovyan/work/olivetree/final_for_paper/tests/similarity/strands/fine_tuned_models/olivetree/from_scratch_normal_triplet_loss/epoch=1-val_auc=0.9793.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 2, global step 14993: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 3, global step 19991: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 4, global step 24989: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 5, global step 29987: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 6, global step 34985: val_auc reached 0.98279 (best 0.98279), saving model to "/home/jovyan/work/olivetree/final_for_paper/tests/similarity/strands/fine_tuned_models/olivetree/from_scratch_normal_triplet_loss/epoch=6-val_auc=0.9828.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 7, global step 39983: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 8, global step 44981: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 9, global step 49979: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 10, global step 54977: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 11, global step 59975: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 12, global step 64973: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 13, global step 69971: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 14, global step 74969: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 15, global step 79967: val_auc was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 16, global step 84965: val_auc reached 0.98720 (best 0.98720), saving model to "/home/jovyan/work/olivetree/final_for_paper/tests/similarity/strands/fine_tuned_models/olivetree/from_scratch_normal_triplet_loss/epoch=16-val_auc=0.9872.ckpt" as top 1


In [None]:
##To convert Model Output
config = BertConfig(
            vocab_size = len(tokenizer.vocab),
            max_position_embeddings = MAX_POSITION_EMBEDDINGS,
            hidden_size = HIDDEN_SIZE,
            intermediate_size = INTERMEDIATE_SIZE,
            num_attention_heads = NUM_ATTENTION_HEADS,
            num_hidden_layers = NUM_HIDDEN_LAYERS,
            type_vocab_size = TYPE_VOCAB_SIZE)
model = BertModel(config=config)

In [None]:
model_torch = torch.load(os.path.join(output_model_path, "epoch=16-val_auc=0.9872.ckpt"))

In [None]:
new_dict = dict()
for k in model_torch["state_dict"]:
    new_dict[k.replace("model.","")] = model_torch["state_dict"][k]
del model_torch

In [None]:
model.load_state_dict(new_dict)

In [None]:
model.save_pretrained(os.path.join(output_model_path, "epoch-16"))