In [None]:
# General imports
import os
import random
import math
import itertools
import pandas as pd
from tqdm import tqdm
import fastrand

# pytorch imports
import torch
torch.backends.cuda.matmul.allow_tf32 = True

# Transformer tokenizer imports
from tokenizers import BertWordPieceTokenizer
from tokenizers.pre_tokenizers import Whitespace
from transformers import BertTokenizerFast
from tokenizers.processors import TemplateProcessing
from transformers.tokenization_utils_base import BatchEncoding

# Transformers data collator
from transformers.data.data_collator import DataCollatorForLanguageModeling

# Transformers Bert model
from transformers import BertConfig, BertForPreTraining, BertForMaskedLM, Trainer, TrainingArguments, EarlyStoppingCallback

In [3]:
# Data and models path settings
# NOTE: sub-directory names must NOT start with "/". os.path.join discards
# every earlier component as soon as it meets an absolute one, so
# os.path.join("./", "/models") evaluates to "/models" (filesystem root)
# and silently ignores base_path.
base_path = "./"
base_model_path = os.path.join(base_path, "models")
base_data_path = os.path.join(base_path, "dataset")


# Tab-separated train / validation files read by the tokenizer and dataset cells
train_path = os.path.join(base_data_path, "train.csv")
val_path = os.path.join(base_data_path, "val.csv")

# Pre-training task; the "name" entry selects the sample-building strategy
TASK_CONFIG = {"name": "next_sentence_prediction"}

# Directory where the trained tokenizer is saved to / loaded from
tokenizer_path = os.path.join(base_path, "tokenizer")

# Parameter tokenizer
VOCAB_SIZE = 30000
MIN_FREQ = 2

# Dataset parameters
MAX_CHAR_LEN_SYM_EXPR = 5000  # None -- no truncate
RECREATE_TOKENIZED_DATASET = True

# Models
BSMAL = "bert_small"
BNORM = "bert_normal"
BLARG = "bert_large"

MODEL = BNORM

# One row per supported model size. Only the hidden size, feed-forward size,
# number of attention heads and number of layers differ between presets.
_MODEL_PRESETS = {
    BSMAL: {"hidden_size": 512, "intermediate_size": 2048, "num_attention_heads": 8, "num_hidden_layers": 12},
    BNORM: {"hidden_size": 768, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 12},
    BLARG: {"hidden_size": 1024, "intermediate_size": 4096, "num_attention_heads": 16, "num_hidden_layers": 24},
}

# Fail here, with a clear message, instead of hitting a NameError on an
# undefined size constant many cells later.
if MODEL not in _MODEL_PRESETS:
    raise ValueError("Unknown MODEL {!r}; expected one of {}".format(MODEL, sorted(_MODEL_PRESETS)))

# Values shared by every preset
MAX_SEQ_LEN = 512
MAX_POSITION_EMBEDDINGS = 514
TYPE_VOCAB_SIZE = 2

# Values that depend on the selected preset
HIDDEN_SIZE = _MODEL_PRESETS[MODEL]["hidden_size"]
INTERMEDIATE_SIZE = _MODEL_PRESETS[MODEL]["intermediate_size"]
NUM_ATTENTION_HEADS = _MODEL_PRESETS[MODEL]["num_attention_heads"]
NUM_HIDDEN_LAYERS = _MODEL_PRESETS[MODEL]["num_hidden_layers"]

# Output directory for checkpoints, named "<task name>_<model preset>"
model_path = os.path.join(base_model_path, f"{TASK_CONFIG['name']}_{MODEL}")

# Training Parameters
train_from_scratch = True   # False -> the training cell passes best_ckp to trainer.train
best_ckp = ""               # checkpoint used when train_from_scratch is False
NUM_TRAIN_EPOCHS = 1
LEARNING_RATE = 0.0001
PER_DEVICE_TRAIN_BATCH_SIZE = 32
PER_DEVICE_EVAL_BATCH_SIZE = 64
MASKING_RATE = 0.30         # masking probability handed to the MLM data collator
PATIENCE = 3

# GPU settings
os.environ.update({
    "TOKENIZERS_PARALLELISM": "true",
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0,1,2,3",
    "NVIDIA_VISIBLE_DEVICES": "0,1,2,3",
})

In [4]:
def train_tokenizer(data_path, tokenizer_path, vocab_size, min_freq, max_char_len_sym=None):
    """Train a WordPiece tokenizer on the dataset and save it to disk.

    Args:
        data_path: tab-separated file with "instructions" and
            "sym_expression" columns.
        tokenizer_path: directory the wrapped tokenizer is saved to.
        vocab_size: vocabulary size passed to the trainer.
        min_freq: minimum frequency passed to the trainer.
        max_char_len_sym: when not None, each symbolic expression is cut to
            this many characters before training.

    Returns:
        The trained BertWordPieceTokenizer (not the BertTokenizerFast wrapper
        that is written to ``tokenizer_path``).
    """
    # Missing cells become empty strings so the concatenation below is safe
    frame = pd.read_csv(data_path, sep="\t").fillna('')
    print("Input Data Read")

    # One training line per row: instructions (separator marker removed),
    # a space, then the (optionally truncated) symbolic expression
    instructions = frame["instructions"].apply(lambda text: text.replace("NEXT_I ",""))
    expressions = frame["sym_expression"]
    if max_char_len_sym is not None:
        expressions = expressions.str.slice(0,max_char_len_sym)
    corpus = instructions + " " + expressions

    # Train the tokenizer
    wordpiece = BertWordPieceTokenizer()
    wordpiece.pre_tokenizer = Whitespace()
    wordpiece.train_from_iterator(
        iterator=tqdm(corpus, total=len(corpus)),
        vocab_size=vocab_size,
        min_frequency=min_freq,
    )

    # Post processing: wrap single sequences and pairs with [CLS] / [SEP];
    # the second segment of a pair gets type id 1
    special_ids = [(token, wordpiece.token_to_id(token)) for token in ("[CLS]", "[SEP]")]
    wordpiece.post_processor = TemplateProcessing(
        single="[CLS] $A [SEP]",
        pair="[CLS] $A [SEP] $B:1 [SEP]:1",
        special_tokens=special_ids,
    )

    # Persist in the format BertTokenizerFast.from_pretrained reads back
    BertTokenizerFast(tokenizer_object=wordpiece, model_max_length=MAX_SEQ_LEN).save_pretrained(tokenizer_path)

    return wordpiece

In [5]:
def load_tokenizer(tokenizer_path):
    """Return the BertTokenizerFast stored under ``tokenizer_path``."""
    return BertTokenizerFast.from_pretrained(tokenizer_path)

In [6]:
# Train the WordPiece tokenizer only when nothing exists at tokenizer_path yet
if os.path.exists(tokenizer_path):
    print("Tokenizer Already Trained!")
else:
    print("Training tokenizer...")
    print(tokenizer_path)
    tok = train_tokenizer(
        train_path,
        tokenizer_path,
        VOCAB_SIZE,
        MIN_FREQ,
        MAX_CHAR_LEN_SYM_EXPR,
    )
    print("Training Ended.")

Tokenizer Already Trained!


In [11]:
class AsmToSymbolicDataset(torch.utils.data.Dataset):
    """Pre-tokenized (instructions, symbolic expression) examples.

    Every row of the tab-separated input file is tokenized once, up front,
    and kept in ``self.data_store`` as a dict of tensors. For every task
    except "masked_language_model_only", each example also carries a
    ``next_sentence_label``: 0 for an untouched pair, 1 for a pair that was
    deliberately corrupted according to the task.

    Args:
        dataset_path: tab-separated file with "instructions" and
            "sym_expression" columns (ignored when ``save_file`` is given).
        tokenizer: callable tokenizer applied to the text (pairs).
        task_config: dict whose "name" selects the task. "contain_intruders"
            also reads "intruder_percentage" and "is_shuffled" reads
            "shuffling_percentage".
        save_file: when given, ``data_store`` is loaded from this file instead
            of being rebuilt from ``dataset_path``.
        incorrect_mapping_prob: probability of trying to corrupt a pair.
        num_trials: maximum attempts at producing a pair that really differs
            from the original.
        max_char_len_sym: when not None, symbolic expressions are cut to this
            many characters before tokenization.
        batch_size: number of examples tokenized per tokenizer call; None
            (or 0) tokenizes one example per call.

    Raises:
        ValueError: if ``task_config["name"]`` is not a supported task (only
            checked when the dataset is built, not when it is loaded).
    """

    def __init__(self, dataset_path, tokenizer, task_config, save_file=None, incorrect_mapping_prob=0.5, num_trials=10, max_char_len_sym=None, batch_size=None):

        self.data_store = []

        self.tokenizer = tokenizer
        self.incorrect_mapping_prob = incorrect_mapping_prob
        self.num_trials = num_trials
        self.max_char_len_sym = max_char_len_sym
        self.batch_size = batch_size
        self.task_config = task_config

        if save_file is None:
            valid_task_names = ("contain_intruders", "is_shuffled", "next_sentence_prediction", "masked_language_model_only")
            # Validate before reading the (potentially very large) input file
            if task_config["name"] not in valid_task_names:
                raise ValueError("{!r} is not a valid task name, please use one of: {}".format(
                    task_config["name"], ", ".join(valid_task_names)))

            self.df = pd.read_csv(dataset_path, sep="\t").fillna('')

            if task_config["name"] == "masked_language_model_only":
                # MLM only looks at the instructions, so identical ones are redundant
                self.df = self.df.drop_duplicates(["instructions"])

            self.__init_structures()
        else:
            print("Loading from save file: {}".format(save_file))
            self.__init_structures_from_file(save_file)

    def __insert_intruders(self, instruction_list, instruction_set):
        """Overwrite random positions of ``instruction_list`` with random instructions.

        ``instruction_list`` is modified in place. Returns ``(modified, text)``
        where ``text`` is the space-joined list and ``modified`` tells whether
        it differs from the original text after at most ``num_trials`` attempts.
        """
        original_instructions = " ".join(instruction_list)
        # Defined up front so the return value exists even if the loop never runs
        modified_instructions = original_instructions
        modified = False

        # random.sample needs a sequence: sampling from a set is rejected on
        # Python 3.11+ (and deprecated since 3.9).
        if not isinstance(instruction_set, (list, tuple)):
            instruction_set = list(instruction_set)

        for _ in range(self.num_trials):
            number_intruders = math.ceil(len(instruction_list) * self.task_config["intruder_percentage"])
            positions = random.sample(range(len(instruction_list)), number_intruders)
            random_elems = random.sample(instruction_set, number_intruders)

            for position, random_elem in zip(positions, random_elems):
                instruction_list[position] = random_elem

            modified_instructions = " ".join(instruction_list)
            if modified_instructions != original_instructions:
                modified = True
                break

        return modified, modified_instructions

    def __shuffle(self, instruction_list):
        """Permute the instructions found at a random subset of positions.

        ``instruction_list`` is modified in place. Returns ``(modified, text)``
        with the same meaning as in ``__insert_intruders``.
        """
        original_instructions = " ".join(instruction_list)
        # Defined up front so the return value exists even if the loop never runs
        modified_instructions = original_instructions
        modified = False

        for _ in range(self.num_trials):
            number_shuffled = math.ceil(len(instruction_list) * self.task_config["shuffling_percentage"])
            positions = random.sample(range(len(instruction_list)), number_shuffled)
            # Set lookup keeps the selection linear in the list length
            selected = set(positions)
            random_elems = [elem for idx, elem in enumerate(instruction_list) if idx in selected]
            random.shuffle(random_elems)

            for position, random_elem in zip(positions, random_elems):
                instruction_list[position] = random_elem
            modified_instructions = " ".join(instruction_list)

            if modified_instructions != original_instructions:
                modified = True
                break

        return modified, modified_instructions

    def __build_pair(self, x, y, pairs, instruction_pool):
        """Return ``(text, text_pair, label)`` for one row of a paired task.

        With probability ``incorrect_mapping_prob`` the pair is corrupted
        according to the task; the label is 1 only when the corruption
        actually changed something.
        """
        modified = False
        task_name = self.task_config["name"]

        if random.random() < self.incorrect_mapping_prob:
            if task_name == "contain_intruders":
                instruction_list = [inst.strip() for inst in x.split("NEXT_I")]
                modified, x = self.__insert_intruders(instruction_list, instruction_pool)

            elif task_name == "is_shuffled":
                instruction_list = [inst.strip() for inst in x.split("NEXT_I")]
                modified, x = self.__shuffle(instruction_list)

            elif task_name == "next_sentence_prediction":
                x = x.replace("NEXT_I ", "")
                # Swap in the expression of a random row, retrying when the
                # draw happens to be identical to the true expression
                for _ in range(self.num_trials):
                    y_new = pairs[fastrand.pcg32bounded(len(pairs))][1]
                    if y_new != y:
                        modified = True
                        y = y_new
                        break
        else:
            x = x.replace("NEXT_I ", "")

        if self.max_char_len_sym is not None and len(y) > self.max_char_len_sym:
            y = y[:self.max_char_len_sym]

        label = torch.tensor(1 if modified else 0, dtype=torch.long)
        return x, y, label

    def __tokenize_and_store(self, texts, text_pairs=None, labels=None):
        """Tokenize a batch in one call and append one example per input text.

        ``text_pairs`` and ``labels`` are None for the MLM-only task; in that
        case the stored examples carry no ``next_sentence_label``.
        """
        if not texts:
            return

        if text_pairs is None:
            encoded = self.tokenizer(text=texts, truncation=True, max_length=MAX_SEQ_LEN)
        else:
            encoded = self.tokenizer(text=texts, text_pair=text_pairs, truncation=True, max_length=MAX_SEQ_LEN)

        for idx in range(len(texts)):
            example = {
                "input_ids": torch.tensor(encoded["input_ids"][idx], dtype=torch.long),
                "token_type_ids": torch.tensor(encoded["token_type_ids"][idx], dtype=torch.long),
            }
            if labels is not None:
                example["next_sentence_label"] = labels[idx]
            self.data_store.append(example)

    def __init_structures(self):
        """Build ``data_store`` from ``self.df``, tokenizing ``batch_size`` rows per call."""
        task_name = self.task_config["name"]
        paired = task_name != "masked_language_model_only"

        # The pool of all known instructions is only needed to draw intruders.
        # It is sorted so that runs with the same random seed are reproducible.
        instruction_pool = None
        if task_name == "contain_intruders":
            split_rows = self.df["instructions"].apply(lambda row: row.split("NEXT_I")).values
            instruction_pool = sorted({elem.strip() for elem in itertools.chain.from_iterable(split_rows)})

        pairs = self.df[["instructions", "sym_expression"]].values

        # None (or 0) means "no batching": flush after every single example
        flush_size = self.batch_size or 1

        # Accumulators for batch tokenization
        accum_x = []
        accum_y = []
        accum_labels = []

        for x, y in tqdm(pairs):
            if paired:
                x, y, label = self.__build_pair(x, y, pairs, instruction_pool)
                accum_y.append(y)
                accum_labels.append(label)
            else:
                x = x.replace("NEXT_I ", "")
            accum_x.append(x)

            if len(accum_x) >= flush_size:
                self.__tokenize_and_store(accum_x, accum_y if paired else None, accum_labels if paired else None)
                accum_x.clear()
                accum_y.clear()
                accum_labels.clear()

        # Rows left over when the dataset size is not a multiple of the batch size
        self.__tokenize_and_store(accum_x, accum_y if paired else None, accum_labels if paired else None)

    def __init_structures_from_file(self, save_file):
        # NOTE(review): torch.load unpickles the file; only load files written
        # by save_to_file from a trusted location.
        self.data_store = torch.load(save_file)

    def __len__(self) -> int:
        return len(self.data_store)

    def __getitem__(self, idx: int) -> dict:
        return self.data_store[idx]

    def save_to_file(self, save_file):
        """Write ``data_store`` to ``save_file`` with ``torch.save``."""
        torch.save(self.data_store, save_file)

In [12]:
# Load the tokenizer previously saved under tokenizer_path
tokenizer = load_tokenizer(tokenizer_path)

In [13]:
# Build the training dataset, or reload the tokenized copy cached next to
# train.csv unless RECREATE_TOKENIZED_DATASET forces a rebuild.
print("Creating or loading train dataset")
train_tokenized_path=train_path+".token"
train_dataset =None
if os.path.isfile(train_tokenized_path) and not RECREATE_TOKENIZED_DATASET:
    print("Tokenized train dataset appear to exists, loading")
    train_dataset = AsmToSymbolicDataset(train_path, tokenizer,TASK_CONFIG, save_file=train_tokenized_path,max_char_len_sym=MAX_CHAR_LEN_SYM_EXPR,batch_size=128)
else:
    print("Tokenized train dataset does not appear to exists, creating and saving")
    train_dataset = AsmToSymbolicDataset(train_path, tokenizer, TASK_CONFIG,max_char_len_sym=MAX_CHAR_LEN_SYM_EXPR)
    print("Saving to file..")
    # Write the cache: without this call the loading branch above can never be taken
    train_dataset.save_to_file(train_tokenized_path)

Creating or loading train dataset
Tokenized train dataset does not appear to exists, creating and saving


 15%|█▍        | 2575165/17215046 [14:10<1:01:10, 3989.04it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

 25%|██▌       | 4343699/17215046 [23:42<1:10:14, 3054.28it/s]


KeyboardInterrupt: 

In [None]:
# Build the validation dataset, or reload the tokenized copy cached next to
# val.csv unless RECREATE_TOKENIZED_DATASET forces a rebuild.
print("Creating or loading validation dataset")
val_tokenized_path=val_path+".token"
val_dataset =None
if os.path.isfile(val_tokenized_path) and not RECREATE_TOKENIZED_DATASET:
    print("Tokenized validation dataset appear to exists, loading")
    val_dataset = AsmToSymbolicDataset(val_path, tokenizer,TASK_CONFIG,  save_file=val_tokenized_path,max_char_len_sym=MAX_CHAR_LEN_SYM_EXPR)
else:
    print("Tokenized validation dataset does not appear to exists, creating and saving")
    val_dataset = AsmToSymbolicDataset(val_path, tokenizer, TASK_CONFIG,max_char_len_sym=MAX_CHAR_LEN_SYM_EXPR)
    print("Saving to file..")
    # Write the cache: without this call the loading branch above can never be taken
    val_dataset.save_to_file(val_tokenized_path)

In [14]:
# Language-modeling data collator built on the tokenizer, with MASKING_RATE as its masking probability
mlm_collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=MASKING_RATE)

In [15]:
# Report how many training examples are labelled 1 (labels exist only for the paired tasks)
if TASK_CONFIG["name"] != "masked_language_model_only":
    total_samples = len(train_dataset.data_store)
    incorrect_samples = sum(
        1
        for sample in train_dataset.data_store
        if sample["next_sentence_label"] == 1
    )
    print(f"{incorrect_samples} Incorrent samples over {total_samples} total samples")

8605970 Incorrent samples over 17215046 total samples


In [16]:
# Model hyper-parameters come from the size constants set in the configuration
# cell; the vocabulary size is read from the loaded tokenizer.
bert_config_kwargs = {
    "vocab_size": len(tokenizer.vocab),
    "max_position_embeddings": MAX_POSITION_EMBEDDINGS,
    "hidden_size": HIDDEN_SIZE,
    "intermediate_size": INTERMEDIATE_SIZE,
    "num_attention_heads": NUM_ATTENTION_HEADS,
    "num_hidden_layers": NUM_HIDDEN_LAYERS,
    "type_vocab_size": TYPE_VOCAB_SIZE,
}
confignew = BertConfig(**bert_config_kwargs)

In [17]:
# BertForMaskedLM for the MLM-only task, BertForPreTraining for every other task
model = (
    BertForMaskedLM
    if TASK_CONFIG["name"] == "masked_language_model_only"
    else BertForPreTraining
)(config=confignew)

# Put the newly created model in training mode
model.train()

print("Total number of parameters: ", model.num_parameters())

Total number of parameters:  92645512


In [18]:
# Training configuration: one checkpoint and one evaluation per epoch, with the
# best checkpoint (lowest eval_loss) reloaded when training ends.
training_args = TrainingArguments(
            output_dir = model_path,
            overwrite_output_dir = True,
            num_train_epochs = NUM_TRAIN_EPOCHS,
            learning_rate = LEARNING_RATE,
            per_device_train_batch_size = PER_DEVICE_TRAIN_BATCH_SIZE,
            save_strategy = 'epoch',
            # save_steps=7126,
            # save_steps=67246,
            save_total_limit = 100,
            warmup_steps =1425,
            logging_strategy = 'steps',
            logging_steps = 500,
            prediction_loss_only = True,
            load_best_model_at_end = True,
            # fp16=True,
            do_eval = True,
            gradient_accumulation_steps=2,
            evaluation_strategy = 'epoch',
            # eval_steps=7126,
            # eval_steps=67246
            metric_for_best_model = 'eval_loss',
            # Use the configuration constant instead of a hard-coded value so
            # the knob defined in the settings cell actually takes effect
            per_device_eval_batch_size = PER_DEVICE_EVAL_BATCH_SIZE,
            dataloader_num_workers =4,
            dataloader_pin_memory=True
            )

# The MLM collator masks tokens per batch; datasets are the pre-tokenized ones built above
trainer = Trainer(
            model = model,
            args = training_args,
            data_collator = mlm_collator,
            train_dataset = train_dataset,
            eval_dataset = val_dataset,
            # callbacks = [EarlyStoppingCallback(early_stopping_patience=PATIENCE)]
            )

In [19]:
# Release unused cached GPU memory held by PyTorch before training starts
torch.cuda.empty_cache()

In [None]:
# Start from scratch, or resume from best_ckp when train_from_scratch is False.
# Passing None is the same as calling trainer.train() with no argument.
resume_checkpoint = None if train_from_scratch else best_ckp
trainer.train(resume_from_checkpoint=resume_checkpoint)



***** Running training *****
  Num examples = 17215046
  Num Epochs = 1
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 256
  Gradient Accumulation steps = 2
  Total optimization steps = 67246


Epoch,Training Loss,Validation Loss
