In [29]:
# General imports
import os
import random
import math
import json
import itertools
import pandas as pd
from tqdm import tqdm
import numpy as np
from datasets import load_metric

# pytorch imports
import torch

# Transformer tokenizer imports
from transformers import BertTokenizerFast

# Transformers data collator
from transformers.data.data_collator import DataCollatorWithPadding

# Transformers Bert model
from transformers import AutoModelForSequenceClassification, BertForPreTraining, Trainer, TrainingArguments, EarlyStoppingCallback, BertConfig

MAX_SEQ_LEN = 512

In [2]:
# Process environment for tokenizer parallelism and GPU visibility.
# Both *_VISIBLE_DEVICES variables receive the same device-id list.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "true",
        "CUDA_VISIBLE_DEVICES": "0,1,2",
        "NVIDIA_VISIBLE_DEVICES": "0,1,2",
    }
)

In [3]:
def load_tokenizer(tokenizer_path):
    """Return the ``BertTokenizerFast`` stored at ``tokenizer_path``.

    Args:
        tokenizer_path: Location forwarded unchanged to
            ``BertTokenizerFast.from_pretrained``.

    Returns:
        The tokenizer instance produced by ``from_pretrained``.
    """
    return BertTokenizerFast.from_pretrained(tokenizer_path)

In [4]:
class StrandExecDataset(torch.utils.data.Dataset):
    """Pre-tokenized sentence-pair dataset read from a tab-separated file.

    Every row contributes three columns: ``ot_instructions`` (first text
    segment), ``ot_inputs`` and ``ot_concrete_expr`` (dict-like strings).
    The second segment is `` k = v`` for every input followed by `` k`` for
    every expression key, and the value of the last expression key becomes
    the example's ``label``.

    Args:
        dataset_path: Path of the tab-separated file to read.
        tokenizer: Callable invoked as
            ``tokenizer(text=..., text_pair=..., truncation=True, max_length=MAX_SEQ_LEN)``.
        shuffle: Shuffle the tokenized examples once after loading
            (default ``True``, the historical behaviour).
        seed: Optional seed for a private ``random.Random`` used for the
            shuffle; ``None`` keeps using the global ``random`` state.
    """

    def __init__(self, dataset_path, tokenizer, shuffle=True, seed=None):

        self.data_store = []
        df = pd.read_csv(dataset_path, sep="\t").fillna('')
        self.samples = df[["ot_instructions", "ot_inputs", "ot_concrete_expr"]]
        self.tokenizer = tokenizer
        self.shuffle = shuffle
        self.seed = seed

        self.__init_structures()

    @staticmethod
    def _parse_mapping(raw):
        """Decode a dict-like cell; blank cells (from ``fillna('')``) give ``{}``."""
        if not raw.strip():
            return {}
        # Cells use single quotes, so they are turned into JSON first.
        return json.loads(raw.replace("\'", "\""))

    @staticmethod
    def _build_pair(inputs, expr):
        """Return ``(second_segment_text, label)`` for one sample.

        Raises:
            ValueError: If ``expr`` is empty, because no label can be derived.
        """
        if not expr:
            raise ValueError("'ot_concrete_expr' is empty, so the sample has no label")
        pieces = [f" {k} = {inputs[k]}" for k in inputs]
        pieces.extend(f" {k}" for k in expr)
        # The label is the value of the last expression key, as before.
        label = list(expr.values())[-1]
        return "".join(pieces), label

    def __init_structures(self):
        """Tokenize every row into ``self.data_store`` and optionally shuffle it."""
        for row_idx, (x, inputs, expr) in enumerate(tqdm(self.samples.values)):
            inputs = self._parse_mapping(inputs)
            expr = self._parse_mapping(expr)
            try:
                y, label = self._build_pair(inputs, expr)
            except ValueError as err:
                # Previously an empty expression silently reused the label of
                # the preceding row (or raised UnboundLocalError on row 0).
                raise ValueError(f"row {row_idx}: {err}") from err
            example = self.tokenizer(text=x, text_pair=y, truncation=True, max_length=MAX_SEQ_LEN)
            example["label"] = label
            self.data_store.append(example)

        if self.shuffle:
            rng = random if self.seed is None else random.Random(self.seed)
            rng.shuffle(self.data_store)

    def __len__(self) -> int:
        return len(self.data_store)

    def __getitem__(self, idx: int):
        # Returns the stored tokenizer output (with its "label" entry), not a tensor.
        return self.data_store[idx]

    def save_to_file(self,save_file):
        """Serialize the tokenized examples to ``save_file`` with ``torch.save``."""
        torch.save(self.data_store, save_file)

# **Training**

In [5]:
from_scratch = True

# A single learning rate is used regardless of `from_scratch`
# (0.0001 was the alternative previously sketched for from-scratch runs).
LEARNING_RATE = 0.00001

NUM_TRAIN_EPOCHS = 20
PER_DEVICE_TRAIN_BATCH_SIZE = 32
PER_DEVICE_EVAL_BATCH_SIZE = 32
DATA_LOADER_NUM_WORKERS = 4
PATIENCE = 3

#models
BXSMAL="bert_xsmall"
BSMAL="bert_small"
BNORM="bert_normal"
BLARG="bert_larg"

# Encoder sizes per model name:
# (hidden size, intermediate size, attention heads, hidden layers).
_MODEL_SHAPES = {
    BXSMAL: (128, 1024, 8, 12),
    BSMAL: (512, 2048, 8, 12),
    BNORM: (768, 3072, 12, 12),
    BLARG: (1024, 4096, 16, 24),
}

MODEL=BNORM
# Fail here rather than leaving the size constants undefined (or stale from an
# earlier execution of this cell) when MODEL is not one of the known names.
if MODEL not in _MODEL_SHAPES:
    raise ValueError(f"Unknown MODEL {MODEL!r}; expected one of {sorted(_MODEL_SHAPES)}")

# Values shared by every model size.
MAX_SEQ_LEN = 512
MAX_POSITION_EMBEDDINGS = 514
TYPE_VOCAB_SIZE = 2

HIDDEN_SIZE, INTERMEDIATE_SIZE, NUM_ATTENTION_HEADS, NUM_HIDDEN_LAYERS = _MODEL_SHAPES[MODEL]

In [6]:
# All locations are resolved relative to the repository root.
base_path = "../../"
tokenizer_path = os.path.join(base_path, "tokenizer")
prt_model = os.path.join(base_path, "models", "pretraining_model", "checkpoint-67246")

# Train/validation splits live in the same dataset directory.
_strand_data_dir = os.path.join(base_path, "dataset", "finetuning_dataset", "strand_execution")
train_path = os.path.join(_strand_data_dir, "train_concrete_execution.csv")
val_path = os.path.join(_strand_data_dir, "val_concrete_execution.csv")

# The fine-tuned model is written below models/finetuned_models/strand_execution;
# `output_dir` and `model_path` name the same directory.
model_name = "BinBert_strand_execution"
model_path = os.path.join(base_path, "models", "finetuned_models", "strand_execution", model_name)
output_dir = model_path

/home/jovyan/work/olivetree/final_for_paper/tests/strand_execution/finetuned_models_new/olivetree/normal_strandexec_from_scratch


In [7]:
# Load the tokenizer once, then tokenize the train split followed by the
# validation split with it.
tokenizer = load_tokenizer(tokenizer_path)
train_dataset, val_dataset = (
    StrandExecDataset(split_path, tokenizer) for split_path in (train_path, val_path)
)

100%|██████████| 40000/40000 [00:05<00:00, 7465.08it/s]
100%|██████████| 5000/5000 [00:00<00:00, 8314.18it/s]


In [8]:
# Distinct label values present in the training examples.
labels = {example["label"] for example in train_dataset.data_store}
print(labels)

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200}


In [9]:
# Label values are used directly as class indices, so the classification head
# needs max(label) + 1 outputs. len(labels) is only equivalent when the ids
# seen in the training split are contiguous from 0; a gap would otherwise
# yield a head that is too small for the largest id.
num_labels = max(labels) + 1

if not from_scratch:
    # Start from the pre-trained checkpoint and attach a new classification head.
    model = AutoModelForSequenceClassification.from_pretrained(prt_model, num_labels=num_labels)
else:
    # Randomly initialised encoder sized by the constants of the config cell.
    config = BertConfig(
                vocab_size = len(tokenizer.vocab),
                max_position_embeddings = MAX_POSITION_EMBEDDINGS,
                hidden_size = HIDDEN_SIZE,
                intermediate_size = INTERMEDIATE_SIZE,
                num_attention_heads = NUM_ATTENTION_HEADS,
                num_hidden_layers = NUM_HIDDEN_LAYERS,
                type_vocab_size = TYPE_VOCAB_SIZE
    )
    config.num_labels = num_labels
    model = AutoModelForSequenceClassification.from_config(config)

In [10]:
training_args = TrainingArguments(
    # Output location and checkpoint retention.
    output_dir=output_dir,
    overwrite_output_dir=True,
    save_strategy='epoch',
    save_total_limit=1,
    # Optimisation schedule.
    num_train_epochs=NUM_TRAIN_EPOCHS,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    # Evaluation and best-model selection (by eval_accuracy, once per epoch).
    do_eval=True,
    evaluation_strategy='epoch',
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    metric_for_best_model='eval_accuracy',
    load_best_model_at_end=True,
    # Logging and data loading.
    logging_strategy='epoch',
    dataloader_num_workers=DATA_LOADER_NUM_WORKERS,
    # Left disabled: prediction_loss_only=True, fp16=True
)

In [11]:
# Batch collator built from the tokenizer with padding enabled; handed to the Trainer below.
collator = DataCollatorWithPadding(tokenizer, padding=True)

In [12]:
# Accuracy metric object consumed by compute_metrics.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases in favour of the `evaluate` package — confirm the pinned version.
metric = load_metric("accuracy")

In [13]:
def compute_metrics(eval_pred):
    """Score a ``(logits, labels)`` pair with the module-level ``metric``.

    The predicted class of each row is the arg-max over the last axis of the
    logits; the result of ``metric.compute`` is returned unchanged.
    """
    scores, gold = eval_pred
    return metric.compute(predictions=np.argmax(scores, axis=-1), references=gold)

In [14]:
# Wire the model, datasets, collator and metric function into the Trainer.
# Early stopping (EarlyStoppingCallback with PATIENCE) is left disabled.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [15]:
# Run fine-tuning with the settings in `training_args`; the returned
# TrainOutput is displayed as the cell result.
trainer.train()

***** Running training *****
  Num examples = 40000
  Num Epochs = 20
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 96
  Gradient Accumulation steps = 1
  Total optimization steps = 8340


Epoch,Training Loss,Validation Loss,Accuracy
1,5.0481,4.961988,0.0234
2,4.8663,4.735435,0.0404
3,4.6897,4.576944,0.051
4,4.5346,4.420834,0.0638
5,4.3556,4.1852,0.1206
6,4.075,3.843431,0.2182
7,3.7741,3.580924,0.2268
8,3.5578,3.425163,0.2354
9,3.3982,3.280021,0.242
10,3.2728,3.1851,0.2452


***** Running Evaluation *****
  Num examples = 5000
  Batch size = 96
Saving model checkpoint to /home/jovyan/work/olivetree/final_for_paper/tests/strand_execution/finetuned_models_new/olivetree/normal_strandexec_from_scratch/checkpoint-417
Configuration saved in /home/jovyan/work/olivetree/final_for_paper/tests/strand_execution/finetuned_models_new/olivetree/normal_strandexec_from_scratch/checkpoint-417/config.json
Model weights saved in /home/jovyan/work/olivetree/final_for_paper/tests/strand_execution/finetuned_models_new/olivetree/normal_strandexec_from_scratch/checkpoint-417/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 96
Saving model checkpoint to /home/jovyan/work/olivetree/final_for_paper/tests/strand_execution/finetuned_models_new/olivetree/normal_strandexec_from_scratch/checkpoint-834
Configuration saved in /home/jovyan/work/olivetree/final_for_paper/tests/strand_execution/finetuned_models_new/olivetree/normal_strandexec_from_scratch/

TrainOutput(global_step=8340, training_loss=3.5433929516543015, metrics={'train_runtime': 1687.0445, 'train_samples_per_second': 474.202, 'train_steps_per_second': 4.944, 'total_flos': 7.97681615033618e+16, 'train_loss': 3.5433929516543015, 'epoch': 20.0})

# **Testing**

In [30]:
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

In [16]:
# Locations used by the testing section, relative to the repository root.
base_path = "../../"
tokenizer_path = os.path.join(base_path, "tokenizer")
test_path = os.path.join(base_path, "dataset", "finetuning_dataset", "strand_execution", "test_concrete_execution.csv")

# NOTE(review): the training log in this notebook reports 8340 optimization
# steps and the saved output of the load cell shows checkpoint-8340 — confirm
# that checkpoint-12500 is the checkpoint that should be evaluated.
model_name = "BinBert_strand_execution/checkpoint-12500"
model_path = os.path.join(base_path, "models", "finetuned_models", "strand_execution", model_name)

# Result files share a prefix derived from the model name, with path
# separators flattened to underscores.
_flat_model_name = "_".join(model_name.split(os.sep))
res_filename = os.path.join(base_path, "results", "strand_execution", _flat_model_name)


In [18]:
def load_binbert_model(best_checkpoint, device=None):
    """Load a sequence-classification checkpoint in evaluation mode.

    Args:
        best_checkpoint: Checkpoint location passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        device: Target device for the model. ``None`` (default) selects
            ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise, so the
            function no longer fails on machines without a GPU.

    Returns:
        The loaded model, moved to ``device`` and switched to ``eval()``.
        It is created with ``output_hidden_states=True``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loading Model ->", best_checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(best_checkpoint, output_hidden_states=True)

    model.to(device)
    model.eval()

    return model

In [19]:
# Accuracy metric for the testing section; the same call also appears in the
# training section of this notebook.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases in favour of the `evaluate` package — confirm the pinned version.
metric = load_metric("accuracy")

In [20]:
def compute_metrics(eval_pred):
    """Return ``metric.compute`` for arg-max predictions against the references.

    ``eval_pred`` unpacks into ``(logits, labels)``. This repeats the
    definition given in the training section so the testing cells can be run
    on their own.
    """
    raw_scores, targets = eval_pred
    predicted_ids = np.argmax(raw_scores, axis=-1)
    result = metric.compute(predictions=predicted_ids, references=targets)
    return result

In [21]:
# Load the tokenizer and the fine-tuned checkpoint selected by `model_path`.
tokenizer = load_tokenizer(tokenizer_path)
model = load_binbert_model(model_path)

Didn't find file /home/jovyan/work/olivetree/final_for_paper/tokenizer/added_tokens.json. We won't load it.
loading file /home/jovyan/work/olivetree/final_for_paper/tokenizer/vocab.txt
loading file /home/jovyan/work/olivetree/final_for_paper/tokenizer/tokenizer.json
loading file None
loading file /home/jovyan/work/olivetree/final_for_paper/tokenizer/special_tokens_map.json
loading file /home/jovyan/work/olivetree/final_for_paper/tokenizer/tokenizer_config.json
loading configuration file /home/jovyan/work/olivetree/final_for_paper/tests/strand_execution/finetuned_models_new/olivetree/normal_strandexec_from_scratch/checkpoint-8340/config.json
Model config BertConfig {
  "_name_or_path": "/home/jovyan/work/olivetree/final_for_paper/tests/strand_execution/finetuned_models_new/olivetree/normal_strandexec_from_scratch/checkpoint-8340",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "

Loading Model -> /home/jovyan/work/olivetree/final_for_paper/tests/strand_execution/finetuned_models_new/olivetree/normal_strandexec_from_scratch/checkpoint-8340


All model checkpoint weights were used when initializing BertForSequenceClassification.

All the weights of BertForSequenceClassification were initialized from the model checkpoint at /home/jovyan/work/olivetree/final_for_paper/tests/strand_execution/finetuned_models_new/olivetree/normal_strandexec_from_scratch/checkpoint-8340.
If your task is similar to the task the model of the checkpoint was trained on, you can already use BertForSequenceClassification for predictions without further training.


In [22]:
# Tokenize the test split and serve it in batches of 64, collated with the
# tokenizer-based padding collator.
test_dataset = StrandExecDataset(test_path, tokenizer)
collator = DataCollatorWithPadding(tokenizer, padding=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=collator)

100%|██████████| 5000/5000 [00:00<00:00, 8250.29it/s]


In [23]:
def compute_test_accuracy(test_data_loader, eval_model=None):
    """Run the model over a data loader and compute accuracy.

    Args:
        test_data_loader: Iterable of batches; each batch is a mapping of
            tensors that includes a ``"labels"`` entry and is passed to the
            model as keyword arguments.
        eval_model: Model to evaluate. ``None`` (default) uses the
            notebook-level ``model``, as before.

    Returns:
        ``(accuracy, predictions, references)`` where ``predictions`` holds the
        concatenated logits and ``references`` the concatenated labels.

    Raises:
        ValueError: If the loader yields no batches.
    """
    net = model if eval_model is None else eval_model
    # Follow the device the model actually lives on instead of assuming "cuda".
    device = next(net.parameters()).device

    logits = []
    references = []
    for batch in tqdm(test_data_loader):
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        with torch.no_grad():
            pred = net(**batch).logits.cpu().detach().numpy()
        ref = batch["labels"].cpu().detach().numpy()
        logits.append(pred)
        references.append(ref)

    if not logits:
        raise ValueError("test_data_loader produced no batches; nothing to evaluate")

    predictions = np.concatenate(logits)
    references = np.concatenate(references)
    accuracy = compute_metrics((predictions, references))["accuracy"]
    return accuracy, predictions, references

In [24]:
def compute_test_conf_matrix(predictions, references, res_filename, labels, accuracy=None):
    """Plot the confusion matrix as a heatmap and save it as a PNG.

    Args:
        predictions: Per-class scores; the arg-max over the last axis is the
            predicted class.
        references: True class ids.
        res_filename: Output path prefix; ``_acc_<accuracy>.png`` is appended.
        labels: Class ids, used both to order the matrix and as tick labels.
        accuracy: Accuracy shown in the title and file name. ``None``
            (default) computes it from ``predictions`` and ``references``
            instead of reading a notebook-level variable.
    """
    predicted = np.argmax(predictions, axis=-1)
    if accuracy is None:
        accuracy = float(np.mean(predicted == np.asarray(references)))

    labels = list(labels)
    # Pin rows/columns to `labels` so the matrix always lines up with the ticks.
    cm = confusion_matrix(references, predicted, labels=labels)

    fig, ax = plt.subplots()
    # Tick labels are passed to seaborn directly; the previous
    # ax.xaxis.set_xticks / ax.yaxis.set_yticks calls do not exist on
    # matplotlib Axis objects and raised AttributeError.
    sns.heatmap(cm, annot=True, fmt='g', ax=ax, cmap="Blues",
                xticklabels=labels, yticklabels=labels)

    # labels and title
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    ax.set_title(f'Confusion Matrix with Accuracy {round(accuracy,4)}')

    res_filename =  f"{res_filename}_acc_{round(accuracy,4)}.png"
    print(res_filename)
    fig.savefig(res_filename, dpi=300)

In [25]:
# sorted(list(labels))

In [26]:
# Evaluate the loaded model on the test loader; keeps the accuracy together
# with the raw logits and reference labels for the report below.
accuracy, predictions, references = compute_test_accuracy(test_data_loader)
# Confusion-matrix plot is currently disabled:
# compute_test_conf_matrix(predictions, references, res_filename, sorted(list(labels)))

100%|██████████| 79/79 [00:04<00:00, 16.09it/s]


In [27]:
# Rebuild the prefix from `model_name` rather than appending to the current
# `res_filename`, so re-running this cell does not stack "_acc_….txt" suffixes.
res_filename = os.path.join(base_path, "results", "strand_execution", model_name.replace(os.sep,"_"))
res_filename =  f"{res_filename}_acc_{round(accuracy,4)}.txt"

In [31]:
# Per-class precision/recall/F1. zero_division=0 keeps the reported values
# (undefined scores are shown as 0) without emitting the undefined-metric warnings.
a = classification_report(references, np.argmax(predictions,axis=1), zero_division=0)

# Make sure the results directory exists before writing the report.
report_dir = os.path.dirname(res_filename)
if report_dir:
    os.makedirs(report_dir, exist_ok=True)
with open(res_filename, "w") as f:
    f.write(a)
print(a)

              precision    recall  f1-score   support

           0       0.74      0.91      0.82        46
           1       0.86      0.90      0.88        71
           2       0.40      0.55      0.46        53
           3       0.31      0.26      0.28        43
           4       0.28      0.49      0.36        47
           5       0.32      0.52      0.40        61
           6       0.52      0.44      0.47        32
           7       0.44      0.35      0.39        49
           8       0.21      0.43      0.29        46
           9       0.43      0.24      0.31        42
          10       0.14      0.18      0.16        44
          11       0.37      0.27      0.31        52
          12       0.13      0.13      0.13        38
          13       0.40      0.33      0.36        49
          14       0.29      0.43      0.34        42
          15       0.35      0.29      0.32        45
          16       0.32      0.51      0.39        55
          17       0.53    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
