In [None]:
!pip install datasets
!pip install accelerate -U
!pip install transformers
!pip install bert-score
import re
import numpy as np
import nltk
from nltk.translate.bleu_score import sentence_bleu
from bert_score import BERTScorer

In [None]:
def extract_time_ranges(entity_list):
    """Pull the parenthesised, year-led span out of each entry that has one.

    Args:
        entity_list: iterable of strings such as ``"Name (2010-2012)"``.

    Returns:
        A list with one ``"(...)"`` string per entry in which the pattern was
        found; entries without a match contribute nothing. The match is
        greedy, so it runs from the first ``(`` that is followed by four
        digits up to the last ``)`` in the string.
    """
    year_span = re.compile(r'\((\d{4}.*)\)')
    hits = (year_span.search(item) for item in entity_list)
    # Re-wrap the captured text in parentheses, skipping entries with no hit.
    return [f"({hit.group(1)})" for hit in hits if hit]

def extract_entities(entity_list):
    """Strip the trailing time range (and optional qualifier) from each entry.

    Args:
        entity_list: iterable of strings such as ``"Name (2010-2012)"`` or
            ``"Name (qualifier) (2010-2012)"``.

    Returns:
        A list of the same length as the input. Entries that match the
        pattern are reduced to the whitespace-stripped text before the
        parentheses; entries that do not match are passed through unchanged.
    """
    name_pattern = re.compile(r'(.*?)(?: \((\D+?)\))? \((\d{4}.*)\)')
    names = []
    for raw in entity_list:
        hit = name_pattern.match(raw)
        # Fall back to the raw string when there is no year-led parenthesis.
        names.append(hit.group(1).strip() if hit else raw)
    return names

def compute_time_bleu(gt_list, pred_list, tokenizer):
    """Average sentence-level BLEU of predicted time ranges.

    Time ranges are extracted from both lists with ``extract_time_ranges``,
    tokenised with ``tokenizer``, and every predicted range is scored with
    ``sentence_bleu`` using *all* ground-truth ranges as references.

    Args:
        gt_list: ground-truth strings.
        pred_list: predicted strings.
        tokenizer: callable tokenizer returning a mapping with ``input_ids``
            and accepting ``add_special_tokens`` (Hugging Face style).

    Returns:
        Mean BLEU over the predicted time ranges, or ``0.0`` when either side
        yields no time range.
    """
    gt_time_list = extract_time_ranges(gt_list)
    pred_time_list = extract_time_ranges(pred_list)

    # Nothing to score, or nothing to score against.
    if len(pred_time_list) == 0 or len(gt_time_list) == 0:
        return 0.0

    def _token_ids(text):
        # Tokenise without special tokens. The earlier approach removed a
        # trailing id ``1`` by hand, which only fits tokenizers whose EOS id
        # is 1; for any other tokenizer the BOS/EOS markers remained in both
        # references and hypotheses and were counted as matching n-grams.
        return list(tokenizer(text, add_special_tokens=False)["input_ids"])

    references = [_token_ids(text) for text in gt_time_list]
    hypotheses = [_token_ids(text) for text in pred_time_list]

    total_bleu = sum(sentence_bleu(references, hyp) for hyp in hypotheses)
    return total_bleu / len(hypotheses)

def compute_em(gt_list, pred_list):
    """Fraction of ground-truth items that appear verbatim among the predictions.

    Membership is tested against the whole ``pred_list``; positions are not
    aligned.

    Args:
        gt_list: sequence of ground-truth items.
        pred_list: sequence of predicted items.

    Returns:
        ``matches / len(gt_list)`` as a float, or ``0.0`` when ``gt_list`` is
        empty (previously this raised ``ZeroDivisionError``).
    """
    if len(gt_list) == 0:
        return 0.0

    # Count ground-truth items that occur anywhere in the predictions.
    hits = sum(1 for gt in gt_list if gt in pred_list)
    return hits / len(gt_list)

def compute_f1(gt_tokens, pred_tokens):
    """Harmonic mean of ``compute_precision`` and ``compute_recall``.

    Both components are computed over the whole batch first, and the F1 is
    derived from those two averages.

    Args:
        gt_tokens: per-example ground-truth token sequences.
        pred_tokens: per-example predicted token sequences.

    Returns:
        The F1 score, or ``0.0`` when precision and recall sum to zero.
    """
    p = compute_precision(gt_tokens, pred_tokens)
    r = compute_recall(gt_tokens, pred_tokens)

    # Avoid 0/0 when both components vanish.
    return 0.0 if p + r == 0 else 2 * (p * r) / (p + r)

# Computed as # correct tokens / # pred tokens
def compute_precision(gt_tokens, pred_tokens):
    """Mean per-example token precision.

    For each (ground truth, prediction) pair the score is the number of
    *distinct* tokens shared by both, divided by the full length of the
    prediction (duplicates included in the denominator).

    Args:
        gt_tokens: per-example ground-truth token sequences.
        pred_tokens: per-example predicted token sequences, paired with
            ``gt_tokens`` by position.

    Returns:
        The mean precision as a float; ``0.0`` when there are no pairs.
        An empty prediction scores ``0.0`` for that example.
    """
    prec = []
    for gt, pred in zip(gt_tokens, pred_tokens):
        if len(pred) == 0:
            # Previously this fell through to the division below and raised
            # ZeroDivisionError after appending the 0.0.
            prec.append(0.0)
            continue
        true_positives = len(set(gt) & set(pred))
        prec.append(true_positives / len(pred))
    # np.mean([]) would yield nan with a warning; report 0.0 instead.
    return float(np.mean(prec)) if prec else 0.0


# Evaluates completeness: # correct answers / # GT answers
def compute_recall(gt_tokens, pred_tokens):
    """Mean per-example token recall.

    For each (ground truth, prediction) pair the score is the number of
    *distinct* tokens shared by both, divided by the full length of the
    ground truth (duplicates included in the denominator).

    Args:
        gt_tokens: per-example ground-truth token sequences.
        pred_tokens: per-example predicted token sequences, paired with
            ``gt_tokens`` by position.

    Returns:
        The mean recall as a float; ``0.0`` when there are no pairs.
        An empty ground truth scores ``0.0`` for that example.
    """
    rec = []
    for gt, pred in zip(gt_tokens, pred_tokens):
        if len(gt) == 0:
            # Previously this fell through to the division below and raised
            # ZeroDivisionError after appending the 0.0.
            rec.append(0.0)
            continue
        true_positives = len(set(gt) & set(pred))
        rec.append(true_positives / len(gt))
    # np.mean([]) would yield nan with a warning; report 0.0 instead.
    return float(np.mean(rec)) if rec else 0.0

def compute_entity_bert(gt_list, pred_list):
    """BERTScore F1 between the entity names of ground truth and predictions.

    Entity names are obtained with ``extract_entities``; each side is joined
    into a single space-separated string and the two strings are scored as
    one pair.

    Args:
        gt_list: ground-truth strings.
        pred_list: predicted strings.

    Returns:
        The F1 value as a Python float, or ``0`` when no predicted entities
        are available.
    """
    gt_names = extract_entities(gt_list)
    pred_names = extract_entities(pred_list)

    if not pred_names:
        return 0

    # A new scorer is built on every call, so the model is loaded each time;
    # this can take a while.
    scorer = BERTScorer(model_type='bert-base-uncased')

    # NOTE(review): the ground truth is passed as the first argument and the
    # prediction as the second; only F1 is used from the result. Confirm the
    # ordering against the bert_score documentation if P/R are ever reported.
    _, _, f1 = scorer.score([' '.join(gt_names)], [' '.join(pred_names)])

    # Single pair in, single score out: unwrap the tensor to a plain float.
    return f1.tolist()[0]


# Define compute_metrics function
def compute_metrics(eval_pred):
    """Compute EM, token F1/recall, time-range BLEU and entity BERTScore.

    Relies on the module-level ``tokenizer`` for decoding.

    Args:
        eval_pred: pair ``(predictions, labels)``. ``predictions`` may itself
            be a tuple, in which case its first element is used.

    Returns:
        Dict with keys ``'em'``, ``'f1'``, ``'recall'``, ``'bleu'`` and
        ``'Entity BERT'``.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 marks ignored positions and is not a real token id; swap it for the
    # pad id on BOTH arrays before decoding. Previously only the labels were
    # cleaned, so a -100 in the predictions would reach batch_decode.
    pad_id = tokenizer.pad_token_id
    labels = np.where(labels != -100, labels, pad_id)
    predictions = np.where(predictions != -100, predictions, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # NOTE(review): compute_em tests each decoded label for membership in the
    # whole list of decoded predictions, not against the prediction at the
    # same index — confirm this is the intended definition of EM here.
    bleu_score = compute_time_bleu(decoded_labels, decoded_preds, tokenizer)
    em = compute_em(decoded_labels, decoded_preds)

    # NOTE(review): the token-level scores receive the id arrays as-is, so pad
    # ids are part of both the overlap and the length denominators.
    f1 = compute_f1(labels, predictions)
    recall = compute_recall(labels, predictions)
    entity_bert = compute_entity_bert(decoded_labels, decoded_preds)

    return {'em': em, 'f1': f1, 'recall': recall, 'bleu': bleu_score, 'Entity BERT': entity_bert}

def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token ids before they are stored for metrics.

    Workaround for a possible memory problem in the stock Trainer: keeping
    only the argmax ids avoids holding on to tensors that are not needed.

    Args:
        logits: indexable whose first element is the score tensor; the last
            dimension is reduced with argmax.
        labels: passed through untouched.

    Returns:
        Tuple ``(pred_ids, labels)``.
    """
    token_scores = logits[0]
    return token_scores.argmax(dim=-1), labels

In [None]:
import json
import torch
from datasets import Dataset, load_metric
from transformers import BartTokenizer, BartForConditionalGeneration, BartForQuestionAnswering, Seq2SeqTrainer, Seq2SeqTrainingArguments
from sklearn.model_selection import train_test_split
from nltk.translate.bleu_score import sentence_bleu

In [None]:
# Load dataset
# The encoding is pinned to UTF-8 so that decoding does not depend on the
# platform's locale default.
with open('/kaggle/input/tlqa2dataset/train_TLQA.json', 'r', encoding='utf-8') as f:
    train = json.load(f)

with open('/kaggle/input/tlqa2dataset/val_TLQA.json', 'r', encoding='utf-8') as f:
    val = json.load(f)

# Flatten the answers
def flatten_answers(answers):
    """Join a list of answers into one comma-separated string.

    Anything that is not a ``list`` is returned unchanged.
    """
    if not isinstance(answers, list):
        return answers
    return ', '.join(answers)

def _as_columns(records):
    """Turn a list of QA records into a column-wise dict of questions and flattened answers."""
    return {
        "question": [record["question"] for record in records],
        "answers": [flatten_answers(record['answers']) for record in records],
    }


train_dict = _as_columns(train)
val_dict = _as_columns(val)

# Wrap the column dicts as Dataset objects for the tokenisation step.
train_dataset = Dataset.from_dict(train_dict)
val_dataset = Dataset.from_dict(val_dict)

# Initialize tokenizer and model from the same pretrained checkpoint
BART_CHECKPOINT = 'facebook/bart-base'
tokenizer = BartTokenizer.from_pretrained(BART_CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(BART_CHECKPOINT)

# Tokenize the dataset
def preprocess_data(examples):
    """Tokenise a batch of questions and answers for seq2seq training.

    Uses the module-level ``tokenizer``. Inputs and targets are both
    truncated/padded to 512 tokens.

    Args:
        examples: batch mapping with ``'question'`` and ``'answers'`` columns.

    Returns:
        The tokenised inputs with an added ``'labels'`` entry in which every
        pad token id is replaced by -100.
    """
    model_inputs = tokenizer(
        examples['question'], max_length=512, truncation=True, padding='max_length'
    )

    # `text_target=` is the supported replacement for the deprecated
    # `tokenizer.as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=examples['answers'], max_length=512, truncation=True, padding='max_length'
    )

    # Replace padding token id's in the labels by -100, so they're ignored in the loss computation
    pad_id = tokenizer.pad_token_id
    model_inputs['labels'] = [
        [(token_id if token_id != pad_id else -100) for token_id in label_ids]
        for label_ids in labels['input_ids']
    ]
    return model_inputs

def _tokenize_split(split):
    """Apply preprocess_data to one dataset split in batched mode."""
    return split.map(preprocess_data, batched=True)


# Both names are rebound to their tokenised versions.
train_dataset = _tokenize_split(train_dataset)
val_dataset = _tokenize_split(val_dataset)


In [None]:
import torch
# Training configuration
training_args = Seq2SeqTrainingArguments(
    # Where checkpoints and logs are written
    output_dir='./results',
    logging_dir='./logs',
    # Schedule: epochs and per-device batch sizes
    num_train_epochs=2,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Optimisation: learning-rate warmup and weight-decay regularisation
    warmup_steps=500,
    weight_decay=0.01,
    # Log every 10 steps, evaluate once per epoch, no external reporting
    logging_steps=10,
    evaluation_strategy="epoch",
    report_to='none',
)

# Build the trainer: model + data + metric hooks
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # Logits are reduced to token ids before compute_metrics sees them
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
)

# Run fine-tuning, then write the model weights to disk
trainer.train()
model.save_pretrained('./my_model_v2', safe_serialization=False)