In [1]:
from datasets import load_dataset, load_metric, concatenate_datasets
from transformers import BartForConditionalGeneration, BartTokenizer
from transformers import Seq2SeqTrainer
from transformers import Seq2SeqTrainingArguments
import torch

In [2]:
from torch.utils.data import DataLoader
from torch.nn import functional as F
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.meteor_score import meteor_score

In [3]:
device = torch.device('mps')

In [4]:
test_data = load_dataset('wikisql', split='test')

In [5]:
START_TOK = '[SOS] '
def format_dataset(example):
    return {'input': START_TOK+example['question'], 'target': example['sql']['human_readable']}

test_data = test_data.map(format_dataset, remove_columns=test_data.column_names)

In [6]:
BUFFER = 2 # start end tokens
MAX_LENGTH = 64 + BUFFER

In [7]:
CHECKPOINT = 'facebook/bart-base'
tokenizer = BartTokenizer.from_pretrained(CHECKPOINT)

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

In [8]:
def convert_to_features(example_batch):
    input_encodings = tokenizer.batch_encode_plus(example_batch['input'], padding='max_length', max_length=MAX_LENGTH, truncation=True)
    target_encodings = tokenizer.batch_encode_plus(example_batch['target'], padding='max_length', max_length=MAX_LENGTH, truncation=True)
    
    encodings = {
        'input_ids': input_encodings['input_ids'], 
        'attention_mask': input_encodings['attention_mask'],
        'labels': target_encodings['input_ids'],
        'decoder_attention_mask': target_encodings['attention_mask']
    }


    return encodings

In [9]:
finaltest_data = test_data.map(convert_to_features, batched=True, remove_columns=test_data.column_names, num_proc=4)


Map (num_proc=4):   0%|          | 0/15878 [00:00<?, ? examples/s]

In [10]:
columns = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask']
finaltest_data.set_format(type='torch', columns=columns)

In [11]:
local = './bart-base-model'
model = BartForConditionalGeneration.from_pretrained(local, device_map=device)

In [12]:
test_dl = DataLoader(finaltest_data, batch_size=50, shuffle=True)

In [13]:
model = model.eval()

In [14]:
# Assuming testdata is a DataLoader that batches your Dataset
total_loss = 0
total_bleu = 0
total_meteor = 0

with torch.no_grad():  # No need to track gradients in evaluation
    for batch in test_dl:
        # Send your batch of inputs to the device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        decoder_attention_mask = batch['decoder_attention_mask'].to(device)

        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, decoder_attention_mask=decoder_attention_mask)

        # Compute loss
        loss = outputs.loss
        total_loss += loss.item()

        # Compute BLEU score
        predictions = outputs.logits.argmax(-1)  # Get the model's predictions
        for prediction, label in zip(predictions, labels):
            # Convert tensors to lists
            prediction = prediction.tolist()
            label = label.tolist()

            # Compute the BLEU score between the predicted and actual sentence
            bleu_score = sentence_bleu([label], prediction)
            total_bleu += bleu_score
            
            label_str = list(map(lambda x: str(x),label))
            prediction_str = list(map(lambda x: str(x),prediction))
            # Compute the Meteor score between the predicted and actual sentence
            meteor_scr = meteor_score([label_str], prediction_str)
            total_meteor += meteor_scr


In [15]:
# Compute the average loss and BLEU score over all the batches
avg_loss = total_loss / finaltest_data.shape[0]
avg_bleu = total_bleu / finaltest_data.shape[0]
avg_meteor = total_meteor / finaltest_data.shape[0]

print(f'Average loss: {avg_loss}, Average BLEU score: {avg_bleu}, Average Meteor score: {avg_meteor}')

Average loss: 0.0019856277123114253, Average BLEU score: 0.9555418548086508, Average Meteor score: 0.9798122369119409


In [None]:
# T5 - Average loss: 0.002602017767439262, Average BLEU score: 0.94424568132835, Average Meteor score: 0.9748364119186308
# BART - Average loss: 0.0019856277123114253, Average BLEU score: 0.9555418548086508, Average Meteor score: 0.9798122369119409
