# Setup





In [None]:
# !pip install transformers -q
# !pip install datasets -q
# !pip install rouge -q
# !pip install torch -q
# !pip install tqdm -q

In [None]:
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import os
import string
import operator
import random

from torch.utils.data import Dataset, DataLoader
from transformers import AdamW
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm, trange

# Fixed seed, applied to every RNG this notebook relies on (Python `random`,
# NumPy, PyTorch CPU and CUDA) so DataLoader shuffling and other torch-side
# randomness are repeatable across runs.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

torch.cuda.empty_cache()
# Prefer the GPU; fall back to CPU so the notebook can still be imported/debugged without one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Module Utils

In [None]:
# ---- Optimisation ----
n_gpu = '0'  # GPU id(s), as a string
gradient_accumulation_steps = 1
lr = 2e-4
adam_epsilon = 1e-8
weight_decay = 0.0
num_warmup_steps = 0  # a scheduler step count, hence an int
num_train_epochs = 20

# ---- Checkpointing ----
save_model = True
save_last_k = 1
save_last = True

# ---- Batching / model ----
train_batch_size = 8
eval_batch_size = 128
model_checkpoint = 'VietAI/vit5-base'  # alternative: 'google/mt5-base'
max_seq_length = 256  # max token length
# Names of the five tuple fields, in output order.
elem_dict = ["subject", "object", "aspect", "predicate", "label"]

# ---- Paths ----
data_dir = "/kaggle/input/t5-data-new/"
working_dir = "/kaggle/working"
result_dir = f"{working_dir}/result/model"
inference_dir = f"{working_dir}/result/inference"

# Create the output folders. exist_ok avoids the check-then-create race and
# makes re-running this cell a no-op.
os.makedirs(result_dir, exist_ok=True)
os.makedirs(inference_dir, exist_ok=True)

## Data utils

In [None]:
import re

def read_data_file(data_path):
    """Read a ``sentence===>tuples`` file into two parallel lists.

    Each non-blank line is split at the FIRST ``===>``. Text on either side is
    kept verbatim (only the trailing newline is removed). Blank or
    whitespace-only lines, such as a trailing empty line, are skipped.

    Args:
        data_path: path to a UTF-8 text file.

    Returns:
        ``(sents, labels)``: lists of source sentences and target tuple strings.

    Raises:
        ValueError: if a non-blank line contains no ``===>`` separator.
    """
    sents, labels = [], []
    with open(data_path, 'r', encoding='UTF-8') as fp:
        for line_no, line in enumerate(fp, start=1):
            line = line.rstrip("\n")
            if not line.strip():
                continue
            if '===>' not in line:
                raise ValueError(f"{data_path}:{line_no}: missing '===>' separator")
            # maxsplit=1: a stray '===>' inside the target must not break unpacking
            sent, tuples = line.split('===>', 1)
            sents.append(sent)
            labels.append(tuples)
    return sents, labels


In [None]:
def get_max_length(inputs, tokenizer):
    """Return the length of the longest ``tokenizer.encode`` output over ``inputs``.

    Raises ValueError (from ``max``) when ``inputs`` is empty.
    """
    encoded_lengths = map(len, map(tokenizer.encode, inputs))
    return max(encoded_lengths)

In [None]:
class MyDataset(Dataset):
    """Seq2seq dataset that pre-tokenizes (input, target) text pairs.

    Every example is tokenized once at construction time, padded/truncated to
    ``max_len`` tokens, and served by ``__getitem__`` as a dict of squeezed
    tensors: ``source_ids``, ``source_mask``, ``target_ids``, ``target_mask``.

    Args:
        tokenizer: tokenizer exposing ``batch_encode_plus``.
        inputs: source texts. An item that is a list is joined with spaces.
        targets: target texts, same convention and same length as ``inputs``.
        max_len: padding/truncation length; defaults to the global ``max_seq_length``.

    Raises:
        ValueError: if ``inputs`` and ``targets`` differ in length.
    """

    def __init__(self, tokenizer, inputs=None, targets=None, max_len=None):
        self.tokenizer = tokenizer
        self.inputs, self.targets = inputs or [], targets or []
        self.max_len = max_seq_length if max_len is None else max_len
        self.input_tensor_list, self.target_tensor_list = self.encode(self.inputs, self.targets)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        source = self.input_tensor_list[idx]
        target = self.target_tensor_list[idx]
        # each stored encoding has a leading batch dim of 1; squeeze it away
        return {
            "source_ids": source["input_ids"].squeeze(),
            "source_mask": source["attention_mask"].squeeze(),
            "target_ids": target["input_ids"].squeeze(),
            "target_mask": target["attention_mask"].squeeze(),
        }

    def encode(self, inputs=None, targets=None):
        """Tokenize each input/target pair separately; returns two parallel lists of encodings."""
        inputs = inputs or []
        targets = targets or []
        if len(inputs) != len(targets):
            raise ValueError(
                f"inputs and targets must have the same length, got {len(inputs)} and {len(targets)}"
            )

        input_tensor_list, target_tensor_list = [], []
        for input_i, target_i in zip(inputs, targets):
            input_tensor_list.append(self._tokenize(input_i))
            target_tensor_list.append(self._tokenize(target_i))
        return input_tensor_list, target_tensor_list

    def _tokenize(self, text):
        """Encode one text (or list of words) to a fixed-length ``[1, max_len]`` encoding."""
        if isinstance(text, list):
            text = ' '.join(text)
        return self.tokenizer.batch_encode_plus(
            [text],
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )

def get_dataset(file_path, tokenizer, mode="train"):
    """Load a ``sentence===>tuples`` file and wrap it in a ``MyDataset``.

    ``mode`` is accepted for call-site compatibility but is not used.
    """
    sentences, tuple_strings = read_data_file(file_path)
    return MyDataset(tokenizer, inputs=sentences, targets=tuple_strings)




In [None]:

# SPECIAL_TOKENS = ['<sub>', '<obj>', '<asp>', '<pred>', '<lab>', '<unk>', 'COM', 'COM+', 'COM-', 'SUP', 'SUP+', 'SUP-', 'EQL', 'DIF', '(', ')']

# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
# tokenizer.add_tokens(SPECIAL_TOKENS)
# train_data = get_dataset(os.path.join(data_dir, 'test.txt'), tokenizer)

# max_length = 0
# max_seq = ''
# for data in train_data:
#     if torch.count_nonzero(data['target_ids']) > max_length:
#         max_length = torch.count_nonzero(data['target_ids'])
#         max_seq = tokenizer.convert_ids_to_tokens(data['target_ids'], skip_special_tokens=True) #tokenizer.decode(data['target_ids'])
        
# print(max_seq)

# max_length = 0
# max_seq = ''
# for data in train_data:
#     if torch.count_nonzero(data['source_ids']) > max_length:
#         max_length = torch.count_nonzero(data['source_ids'])
#         max_seq = tokenizer.decode(data['source_ids'], skip_special_tokens=True)
        
# print(max_seq)



## Infer

In [None]:
def calculate_inference_loss(model, tokenizer, input_text, true_sequences):
    """Teacher-forced seq2seq loss of ``true_sequences`` given ``input_text``.

    ``input_text`` is encoded as the model input and ``true_sequences`` as the
    labels. Label padding is replaced by -100, the ignore index the training
    loop also uses, and the model's own ``loss`` is returned. No gradients are
    tracked.

    Args:
        model: seq2seq model whose forward accepts ``input_ids``,
            ``attention_mask`` and ``labels`` and returns an object with ``.loss``.
        tokenizer: callable tokenizer with a ``pad_token_id`` attribute.
        input_text: source string or list of strings.
        true_sequences: target string or list of strings (same count as ``input_text``).

    Returns:
        The loss as a Python float.
    """
    # Run on whatever device the model lives on instead of assuming a global.
    model_device = next(model.parameters()).device

    encoded = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    labels = tokenizer(true_sequences, return_tensors="pt", truncation=True, padding=True)["input_ids"]
    labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)

    with torch.no_grad():
        outputs = model(
            input_ids=encoded["input_ids"].to(model_device),
            attention_mask=encoded["attention_mask"].to(model_device),
            labels=labels.to(model_device),
        )
    return outputs.loss.item()

In [None]:
# Tuple markers, comparison labels and punctuation that constrained decoding may always emit.
SPECIAL_TOKENS = ['<sub>', '<obj>', '<asp>', '<pred>', '<lab>', '[UNK]', 'COM', 'COM+', 'COM-', 'SUP', 'SUP+', 'SUP-', 'EQL', 'DIF', '(', ')', ';']

def prepare_constrained_vocab(name, include_input_vocab=False):
    """Return the extra vocabulary allowed during constrained decoding.

    By default this is a fresh copy of ``SPECIAL_TOKENS``, and no file is read.
    Previously the split file was read and tokenized on every call, only for
    the result to be discarded.

    Args:
        name: split name. Only used when ``include_input_vocab`` is True, to
            read ``{data_dir}/{name}.txt``.
        include_input_vocab: if True, also include every whitespace-separated
            word of that split's input sentences (order not preserved).

    Returns:
        A new list of token strings.
    """
    if not include_input_vocab:
        return list(SPECIAL_TOKENS)

    inputs, _ = read_data_file(os.path.join(data_dir, f"{name}.txt"))
    constrained_vocab = set(" ".join(inputs).split())
    constrained_vocab.update(SPECIAL_TOKENS)
    return list(constrained_vocab)
    
    
    
class Prefix_fn_cls:
    """State behind ``prefix_allowed_tokens_fn`` for constrained decoding.

    For batch row ``batch_id``, the allowed next-token ids are the unique ids
    present in that row of ``input_enc_idxs`` plus the ids of every token in
    ``prepare_constrained_vocab(name)``. ``previous_tokens`` is ignored, so
    the allowed set is the same at every decoding step.
    """

    def __init__(self, tokenizer, name, input_enc_idxs):
        self.input_enc_idxs = input_enc_idxs
        self.tokenizer = tokenizer
        self.constrained_vocab = prepare_constrained_vocab(name)
        # only add special_tokens for extract process
        encoded = self.tokenizer(self.constrained_vocab, add_special_tokens=False)['input_ids']
        self.special_ids = list({token_id for ids in encoded for token_id in ids})
        # batch_id -> allowed ids, filled lazily by get()
        self._allowed_cache = {}

    def get(self, batch_id, previous_tokens):
        """Return the allowed token ids for row ``batch_id`` (a new list each call).

        ``get`` runs once per decoding step per hypothesis. The row's ids are
        pulled off the tensor (``.tolist()``, a device→host copy on GPU) only
        on the first call and cached afterwards.
        """
        allowed = self._allowed_cache.get(batch_id)
        if allowed is None:
            allowed = list(set(self.input_enc_idxs[batch_id].tolist())) + self.special_ids
            self._allowed_cache[batch_id] = allowed
        # copy so callers cannot mutate the cached list
        return list(allowed)

In [None]:
def _make_decoder(tokenizer, keep_mask):
    """Return a function that turns one id sequence into text.

    With ``keep_mask`` only EOS/PAD ids are dropped before detokenizing.
    Otherwise ``tokenizer.decode(..., skip_special_tokens=True)`` is used.
    """
    if not keep_mask:
        return lambda ids: tokenizer.decode(ids, skip_special_tokens=True)

    print("Keep mask: ", keep_mask)
    unwanted_ids = set(tokenizer.convert_tokens_to_ids([tokenizer.eos_token, tokenizer.pad_token]))

    def filter_decode(ids):
        # plain ints so set membership works (0-d tensors hash by identity)
        ids = ids.tolist() if hasattr(ids, "tolist") else list(ids)
        kept = [i for i in ids if i not in unwanted_ids]
        return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(kept))

    return filter_decode


def _mean_eval_loss(data_loader, model, tokenizer):
    """Mean teacher-forced loss over all batches of ``data_loader`` (0.0 if it is empty)."""
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in data_loader:
            # -100 marks padding positions for the model's loss to ignore
            lm_labels = batch["target_ids"].masked_fill(batch["target_ids"] == tokenizer.pad_token_id, -100)
            outs = model(
                batch["source_ids"].to(device),
                attention_mask=batch["source_mask"].to(device),
                labels=lm_labels.to(device),
                decoder_attention_mask=batch["target_mask"].to(device),
                decoder_input_ids=None,
            )
            # accumulate inside the loop so every batch contributes to the mean
            total_loss += outs[0].item()
            num_batches += 1
    return total_loss / num_batches if num_batches else 0.0


def _generate_predictions(data_loader, model, tokenizer, name, constrained, decode, decode_dict):
    """Decode every batch with ``model.generate``; returns (inputs, outputs, targets) as text lists."""
    inputs, outputs, targets = [], [], []
    with torch.no_grad():
        for batch in data_loader:
            source_ids = batch['source_ids'].to(device)
            prefix_fn = Prefix_fn_cls(tokenizer, name, source_ids).get if constrained else None
            # only `sequences` is read, so per-step scores are not requested (saves memory)
            outs_dict = model.generate(
                input_ids=source_ids,
                attention_mask=batch['source_mask'].to(device),
                return_dict_in_generate=True,
                max_length=max_seq_length,
                prefix_allowed_tokens_fn=prefix_fn,
                **decode_dict,
            )
            inputs.extend(decode(ids) for ids in batch['source_ids'])
            outputs.extend(decode(ids) for ids in outs_dict['sequences'])
            targets.extend(decode(ids) for ids in batch['target_ids'])
    return inputs, outputs, targets


def infer(dataset, model, tokenizer, batch_size, keep_mask=False, name="eval", constrained=False, **decode_dict):
    """Evaluate or decode ``dataset`` with ``model``.

    Two modes, selected by ``name``:

    * ``name == "eval"``: teacher-forced forward passes. Returns the mean
      per-batch loss and empty prediction lists.
    * any other ``name``: ``model.generate`` decoding, optionally constrained
      through ``Prefix_fn_cls``. Returns decoded sources, predictions and
      targets, with ``average_loss`` left at 0.

    In both modes the (possibly empty) predictions are written to
    ``{inference_dir}/{name}_output_{constrained}.txt`` as ``input ===> output`` lines.

    Args:
        dataset: dataset whose items carry ``source_ids``/``source_mask``/``target_ids``/``target_mask``.
        model: seq2seq model; batches are moved to the global ``device`` before the call.
        tokenizer: tokenizer used to build ``dataset``.
        batch_size: DataLoader batch size.
        keep_mask: drop only EOS/PAD when decoding instead of using ``skip_special_tokens=True``.
        name: ``"eval"`` for loss-only evaluation, otherwise the split name (also passed to ``Prefix_fn_cls``).
        constrained: restrict generation with ``Prefix_fn_cls``.
        **decode_dict: extra keyword arguments forwarded to ``model.generate``.

    Returns:
        ``(average_loss, inputs, outputs, targets)``.
    """
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
    decode = _make_decoder(tokenizer, keep_mask)
    model.eval()

    average_loss = 0
    inputs, outputs, targets = [], [], []
    if name == "eval":
        average_loss = _mean_eval_loss(data_loader, model, tokenizer)
    else:
        inputs, outputs, targets = _generate_predictions(
            data_loader, model, tokenizer, name, constrained, decode, decode_dict
        )

    with open(os.path.join(inference_dir, f"{name}_output_{constrained}.txt"), "w", encoding="utf-8") as f:
        for source_text, prediction in zip(inputs, outputs):
            f.write(f"{source_text} ===> {prediction}\n")

    return average_loss, inputs, outputs, targets




## eval metrics

In [None]:
import copy
from sklearn.metrics import f1_score, precision_recall_fscore_support
import re

def extract_elements(input_string):
    """Parse ``;``-separated ``(<sub> .. <obj> .. <asp> .. <pred> .. <lab> ..)`` tuples.

    Each segment is stripped, its first and last characters are dropped
    (expected to be the enclosing parentheses), and the five marker-delimited
    fields are captured and stripped.

    Returns:
        One entry per segment: a list of 5 strings, or ``None`` when the
        segment does not match the tuple pattern.
    """
    tuple_re = re.compile(r'<sub>(.*?)<obj>(.*?)<asp>(.*?)<pred>(.*?)<lab>(.*?)$')
    parsed = []
    for segment in input_string.split(';'):
        body = segment.strip()[1:-1].strip()
        found = tuple_re.match(body)
        parsed.append([field.strip() for field in found.groups()] if found else None)
    return parsed

def compute_metrics(predicted_list, gold_list):
    """Per-position classification scores for aligned predicted/gold tuples.

    Column ``k`` of all predicted tuples is scored against column ``k`` of all
    gold tuples as a multiclass problem over exact string values.

    Args:
        predicted_list: predicted tuples, aligned 1:1 with ``gold_list``.
        gold_list: gold tuples.

    Returns:
        Five lists with one entry per tuple position: precision, recall,
        micro-F1, macro-F1 and F1. P/R/F1 are the per-class scores of the first
        label in sklearn's sorted label order.
    """
    # Transpose rows of tuples into one list per position.
    predicted_positions = [list(column) for column in zip(*predicted_list)]
    gold_positions = [list(column) for column in zip(*gold_list)]

    precision_scores, recall_scores, f1_scores = [], [], []
    micro_f1_scores, macro_f1_scores = [], []

    for predicted, gold in zip(predicted_positions, gold_positions):
        # sklearn's signature is (y_true, y_pred); reversing it swaps precision and recall.
        micro_f1_scores.append(f1_score(gold, predicted, average='micro'))
        macro_f1_scores.append(f1_score(gold, predicted, average='macro'))

        p, r, f1, _ = precision_recall_fscore_support(gold, predicted, average=None)
        # NOTE(review): index 0 is the alphabetically-first label at this position,
        # not an aggregate over all labels — confirm this is the intended metric.
        precision_scores.append(p[0])
        recall_scores.append(r[0])
        f1_scores.append(f1[0])

    return precision_scores, recall_scores, micro_f1_scores, macro_f1_scores, f1_scores


def eval(pred_tups, gold_tups, verbose="quiet", elem_dict=None):
    """Score predicted tuple strings against gold tuple strings.

    Both sides are parsed with ``extract_elements``, and tuples are aligned by
    position within each example. Only aligned pairs where both sides parsed
    into 5 fields are scored. Everything else (unparsable, extra or missing
    tuples) is logged to ``{inference_dir}/error_prediction.txt``.

    Args:
        pred_tups: predicted strings, one per example.
        gold_tups: gold strings, same length as ``pred_tups``.
        verbose: any value other than ``"quiet"`` prints the score dict.
            The old misspelt default ``"quite"`` is also treated as quiet.
        elem_dict: names of the 5 tuple positions; defaults to
            subject/object/aspect/predicate/label.

    Returns:
        dict mapping each element name to its P/R/F1/macro-F1/micro-F1 scores
        (all 0.0 when no pair could be scored).
    """
    assert len(pred_tups) == len(gold_tups)
    if elem_dict is None:
        elem_dict = ["subject", "object", "aspect", "predicate", "label"]

    all_labels, all_predictions, error_preds = [], [], []
    for index, (pred_str, gold_str) in enumerate(zip(pred_tups, gold_tups)):
        predict_list = extract_elements(pred_str)
        gold_list = extract_elements(gold_str)

        if len(gold_list) > len(predict_list):
            error_preds.append(f"{index} Incomplete Prediction: {gold_str} ===> {pred_str}")

        for i, predicted in enumerate(predict_list):
            if i >= len(gold_list):
                error_preds.append(f"{index} Redundant Prediction: {pred_str}")
                continue
            gold = gold_list[i]
            if predicted is None or gold is None or len(gold) != len(predicted) or len(predicted) != 5:
                error_preds.append(f"{index}: {gold_str} ===> {pred_str}")
            else:
                all_labels.append(gold)
                all_predictions.append(predicted)

    if all_predictions:
        precision_scores, recall_scores, micro_f1, macro_f1, f1_scores = compute_metrics(all_predictions, all_labels)
    else:
        # Nothing scorable: compute_metrics would return empty lists and indexing would fail.
        zeros = [0.0] * len(elem_dict)
        precision_scores = recall_scores = micro_f1 = macro_f1 = f1_scores = zeros

    scores_dict = {}
    for i, elem in enumerate(elem_dict):
        scores_dict[elem] = {"P": precision_scores[i], "R": recall_scores[i], "F1": f1_scores[i], "Marco - F1": macro_f1[i], "Micro - F1": micro_f1[i]}

    with open(os.path.join(inference_dir, 'error_prediction.txt'), 'w', encoding='utf-8') as fout:
        fout.writelines(f"{error}\n" for error in error_preds)

    print(f"The number of error predictions: {len(error_preds)}")
    if verbose not in ("quiet", "quite"):
        print(f"Evaluation Result: {scores_dict}")

    return scores_dict


## train

In [None]:
def _build_optimizer(model, lr):
    """AdamW with weight decay disabled for bias and layer-norm weight parameters."""
    # Match both naming styles: HF T5 modules are named `layer_norm` / `final_layer_norm`,
    # BERT-style ones `LayerNorm`. The old "LayerNorm.Weight" (capital W) matched nothing.
    no_decay = ("bias", "layer_norm.weight", "LayerNorm.weight")
    decay_params, no_decay_params = [], []
    for param_name, param in model.named_parameters():
        if any(nd in param_name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    grouped_parameters = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    # torch's AdamW replaces the deprecated transformers.AdamW.
    return torch.optim.AdamW(grouped_parameters, lr=lr, eps=adam_epsilon)


def _write_lines(path, values):
    """Write one value per line to ``path``, overwriting it."""
    with open(path, 'w') as f:
        f.writelines(f"{value}\n" for value in values)


def train(model, tokenizer, train_data, val_data, epochs, lr, train_batch_size, eval_batch_size, acc_step=None, save_model=False, save_last=False, elem_dict=elem_dict, constrained=False):
    """Fine-tune ``model`` on ``train_data`` and compute validation loss every epoch.

    Uses AdamW with a linear warmup/decay schedule, gradient clipping at 1.0,
    and gradient accumulation over ``acc_step`` batches. After each epoch the
    teacher-forced loss on ``val_data`` is computed with ``infer(name="eval")``.
    Per-epoch losses are written to ``train_losses.txt`` / ``eval_losses.txt``
    under ``inference_dir``. With ``save_last``, the final model and tokenizer
    are saved under ``result_dir``.

    Args:
        model: seq2seq model already on ``device``.
        tokenizer: tokenizer matching ``model`` (saved alongside it).
        train_data, val_data: datasets yielding source/target ids and masks.
        epochs: number of epochs.
        lr: peak learning rate.
        train_batch_size: training (and, currently, validation) batch size.
        eval_batch_size: accepted for API compatibility; not used (see note at the eval call).
        acc_step: batches per optimizer step; defaults to ``gradient_accumulation_steps``.
        save_model, elem_dict: accepted for API compatibility; not used.
        save_last: save the final model after the last epoch.
        constrained: forwarded to ``infer`` and embedded in the checkpoint name.
    """
    print("#"*20+" BEGIN TRAINING "+ "#"*20)
    if acc_step is None:
        acc_step = gradient_accumulation_steps

    optimizer = _build_optimizer(model, lr)
    train_loader = DataLoader(train_data, batch_size=train_batch_size, drop_last=True, shuffle=True)
    # Total optimizer/scheduler steps over the run (drop_last makes len(train_loader) exact).
    t_total = (len(train_loader) // acc_step) * int(epochs)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)

    train_losses, eval_losses = [], []
    last_epoch = None
    for n_epoch in trange(int(epochs), dynamic_ncols=True, desc="Epoch"):
        model.train()  # infer() at the end of the previous epoch left the model in eval mode
        epoch_train_loss = 0.0

        for step, batch in enumerate(train_loader):
            lm_labels = batch["target_ids"]
            # -100 marks padding positions for the model's loss to ignore
            lm_labels[lm_labels == tokenizer.pad_token_id] = -100
            outputs = model(
                batch["source_ids"].to(device),
                attention_mask=batch["source_mask"].to(device),
                labels=lm_labels.to(device),
                decoder_attention_mask=batch["target_mask"].to(device),
                decoder_input_ids=None,
            )

            loss = outputs[0]
            # Scale so the accumulated gradient is a mean over acc_step batches.
            (loss / acc_step).backward()
            epoch_train_loss += loss.item()

            if (step + 1) % acc_step == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                model.zero_grad()

        # NOTE(review): validation deliberately keeps train_batch_size rather than
        # eval_batch_size. A larger eval batch materialises a proportionally larger
        # logits tensor, so confirm it fits in memory before switching.
        eval_loss, _, _, _ = infer(val_data, model, tokenizer, batch_size=train_batch_size, name="eval", constrained=constrained)
        eval_losses.append(eval_loss)

        mean_train_loss = epoch_train_loss / max(1, len(train_loader))
        train_losses.append(mean_train_loss)
        last_epoch = n_epoch

        print(f"Epoch {n_epoch} - Average epoch train loss: {mean_train_loss:.5f} lr: {scheduler.get_last_lr()}")
        print(f"Average Evaluation Loss: {eval_loss:.5f}")

    # last_epoch stays None when epochs == 0; there is nothing meaningful to save then.
    if save_last and last_epoch is not None:
        save_dir = os.path.join(result_dir, f"{model_checkpoint.split('/')[-1]}-e{last_epoch}-extract-tuple-constrained-model-{constrained}")
        os.makedirs(save_dir, exist_ok=True)

        model.save_pretrained(save_dir)
        tokenizer.save_pretrained(save_dir)

        print(f"Save model checkpoint to {save_dir}")

    _write_lines(os.path.join(inference_dir, "train_losses.txt"), train_losses)
    _write_lines(os.path.join(inference_dir, "eval_losses.txt"), eval_losses)

    print("#"*20+" FINISH TRAINING "+ "#"*20)


# Run

In [None]:
import os

# Turn off HF tokenizers' internal parallelism. Presumably this avoids its
# fork warning/deadlock guard once DataLoader worker processes are started.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

def _run_training(constrained):
    """Build tokenizer, datasets and model from ``model_checkpoint``, then call ``train``.

    The model is deleted and the CUDA cache emptied on return, so the test
    phase starts without the training weights resident on the GPU.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    tokenizer.add_tokens(SPECIAL_TOKENS)

    train_path = os.path.join(data_dir, "train.txt")
    dev_path = os.path.join(data_dir, "dev.txt")

    _, train_labels = read_data_file(train_path)
    train_max_length = get_max_length(train_labels, tokenizer)
    _, eval_labels = read_data_file(dev_path)
    eval_max_length = get_max_length(eval_labels, tokenizer)

    train_data = get_dataset(train_path, tokenizer=tokenizer)
    eval_data = get_dataset(dev_path, tokenizer=tokenizer)

    model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
    # The tokenizer gained SPECIAL_TOKENS, so the embedding matrix must grow to match.
    model.resize_token_embeddings(len(tokenizer))
    model.to(device)

    print(f"Train: {len(train_data)}, Eval: {len(eval_data)}, Model: {model_checkpoint}")
    print(f"Train Label Max Length : {train_max_length}\nEval Label Max Length: {eval_max_length}")

    print("*"*20+" Training "+"*"*20)
    train(model, tokenizer, train_data, eval_data, epochs=num_train_epochs, lr=lr, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, save_model=True, save_last=True, elem_dict=elem_dict, constrained=constrained)

    del model
    torch.cuda.empty_cache()


def _run_testing(test_label, constrained):
    """Decode test.txt with every saved checkpoint of this base model and write the predictions."""
    print("*"*20+" TESTING "+"*"*20)
    base_name = model_checkpoint.split('/')[-1]
    # sorted() so checkpoints are processed in a reproducible order
    all_checkpoints = [
        os.path.join(result_dir, entry)
        for entry in sorted(os.listdir(result_dir))
        if 'constrained-model' in entry and base_name in entry
    ]

    test_path = os.path.join(data_dir, "test.txt")
    test_inputs, _ = read_data_file(test_path)
    print(f"Test: {len(test_inputs)}")

    for checkpoint in all_checkpoints:
        model_name = os.path.basename(checkpoint)
        print(f"Load model from checkpoint {checkpoint}")

        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        test_data = get_dataset(test_path, tokenizer=tokenizer)
        model.to(device)

        _, sents, predictions, golds = infer(test_data, model, tokenizer, batch_size=train_batch_size, name="test", constrained=constrained)

        # Show a few predictions as a quick sanity check.
        for sent, prediction in list(zip(sents, predictions))[:5]:
            print(f"{sent} ===> {prediction}")

        if test_label:
            eval(predictions, golds, verbose="info", elem_dict=elem_dict)

        with open(os.path.join(inference_dir, f"test-{model_name}.txt"), 'w', encoding="utf-8") as fout:
            for sent, prediction in zip(sents, predictions):
                fout.write(f"{sent} ===> {prediction}\n")

        # Release this checkpoint's weights before loading the next one.
        del model
        torch.cuda.empty_cache()


def main(do_train, do_test, test_label, constrained=False):
    """Entry point: optionally fine-tune, then optionally decode the test split.

    Args:
        do_train: fine-tune ``model_checkpoint`` on train.txt (validating on
            dev.txt) and save the final checkpoint under ``result_dir``.
        do_test: decode test.txt with every saved checkpoint whose directory
            name contains the base model name and ``constrained-model``.
        test_label: also score test predictions with ``eval`` (test.txt must carry gold tuples).
        constrained: use constrained decoding; also embedded in saved checkpoint names.
    """
    if do_train:
        _run_training(constrained)
    if do_test:
        _run_testing(test_label, constrained)


In [None]:
# Fine-tune a fresh model, then decode test.txt with constrained decoding.
# test_label=False: predictions are written to disk but not scored.
main(
    do_train=True,
    do_test=True,
    test_label=False,
    constrained=True,
)

In [None]:
# !rm -rf /kaggle/working/*