# Notebook used to calculate the BartScores for BART

Please note: this notebook only works with a pre-trained model, and the path to that model has to be specified in the `trained_bart` variable.

In [9]:
# Install the third-party packages used below (versions are not pinned).
# NOTE(review): nltk, scikit-learn and tensorflow are imported in the next cell
# but are not installed here — presumably preinstalled in the runtime; confirm.
!pip install transformers sentencepiece torch datasets
!pip install rouge_score



In [10]:
import nltk
nltk.download("punkt", quiet=True)
from sklearn.metrics import classification_report
import numpy as np
import tensorflow as tf
import random as python_random
import datasets

from torch.utils.data import DataLoader
import torch

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BartTokenizer, BartForConditionalGeneration
)

import torch
import torch.nn as nn
import traceback
from typing import List

import pickle as pickle_rick

# TODO: Set the path to your pre-trained BART model here.
# TODO: Change the path to match your local setup.
trained_bart = ""
if not trained_bart:
    # Fail fast with an actionable message instead of an opaque error from the
    # model loader when the path was left empty.
    raise ValueError(
        "`trained_bart` is empty: set it to the path of a pre-trained BART model "
        "before running this notebook."
    )

# Global vars used in the project -> used in functions as well
tokenizer = AutoTokenizer.from_pretrained(trained_bart)
model = AutoModelForSeq2SeqLM.from_pretrained(trained_bart)


In [1]:
class BARTScorer:
    """BARTScore-style scorer: rates how likely a target text is given a source text.

    The score of a (source, target) pair is the negated average negative
    log-likelihood of the non-padding target tokens under
    ``BartForConditionalGeneration`` conditioned on the source, so a higher
    (closer to 0) score means the target is more likely.
    """

    def __init__(self, device='cuda:0', max_length=1024, checkpoint='facebook/bart-large-cnn'):
        """Load tokenizer and model from ``checkpoint`` and move the model to ``device``.

        Args:
            device: torch device string used for the model and input tensors.
            max_length: token limit for sources and targets; longer texts are truncated.
            checkpoint: model id or local path passed to ``from_pretrained``.
        """
        # Set up model (inference only)
        self.device = device
        self.max_length = max_length
        self.tokenizer = BartTokenizer.from_pretrained(checkpoint)
        self.model = BartForConditionalGeneration.from_pretrained(checkpoint)
        self.model.eval()
        self.model.to(device)

        # Set up loss: per-token NLL (no reduction) that ignores padding positions.
        self.loss_fct = nn.NLLLoss(reduction='none', ignore_index=self.model.config.pad_token_id)
        self.lsm = nn.LogSoftmax(dim=1)

    def load(self, path=None):
        """Load model weights from paraphrase finetuning.

        Args:
            path: state-dict file to load; defaults to 'models/bart.pth'.
        """
        if path is None:
            path = 'models/bart.pth'
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def score(self, srcs, tgts, batch_size=4):
        """Score each (source, target) pair.

        Args:
            srcs: list of source strings.
            tgts: list of target strings, aligned index-by-index with ``srcs``.
            batch_size: number of pairs per forward pass.

        Returns:
            list of floats, one per pair: the negated mean NLL of the target tokens.

        Raises:
            ValueError: if ``srcs`` and ``tgts`` differ in length.
            RuntimeError: re-raised after printing the failing batch.
        """
        if len(srcs) != len(tgts):
            raise ValueError(
                f'srcs and tgts must have the same length, got {len(srcs)} and {len(tgts)}'
            )

        score_list = []
        for i in range(0, len(srcs), batch_size):
            src_list = srcs[i: i + batch_size]
            tgt_list = tgts[i: i + batch_size]
            try:
                with torch.no_grad():
                    encoded_src = self.tokenizer(
                        src_list,
                        max_length=self.max_length,
                        truncation=True,
                        padding=True,
                        return_tensors='pt'
                    )
                    encoded_tgt = self.tokenizer(
                        tgt_list,
                        max_length=self.max_length,
                        truncation=True,
                        padding=True,
                        return_tensors='pt'
                    )
                    src_tokens = encoded_src['input_ids'].to(self.device)
                    src_mask = encoded_src['attention_mask'].to(self.device)

                    tgt_tokens = encoded_tgt['input_ids'].to(self.device)
                    tgt_mask = encoded_tgt['attention_mask']
                    # Number of real (non-padding) target tokens per example.
                    tgt_len = tgt_mask.sum(dim=1).to(self.device)

                    output = self.model(
                        input_ids=src_tokens,
                        attention_mask=src_mask,
                        labels=tgt_tokens
                    )
                    logits = output.logits.view(-1, self.model.config.vocab_size)
                    loss = self.loss_fct(self.lsm(logits), tgt_tokens.view(-1))
                    loss = loss.view(tgt_tokens.shape[0], -1)
                    # Padding contributes 0 (ignore_index), so this is a mean over real tokens.
                    loss = loss.sum(dim=1) / tgt_len
                    score_list += [-x.item() for x in loss]

            except RuntimeError:
                traceback.print_exc()
                print(f'source: {src_list}')
                print(f'target: {tgt_list}')
                # Re-raise instead of exit(0): exiting would stop the process/kernel
                # while reporting success; raising lets the caller see the failure.
                raise
        return score_list

    def multi_ref_score(self, srcs, tgts: List[List[str]], agg="mean", batch_size=4):
        """Score each source against several references and aggregate per source.

        Args:
            srcs: list of source strings.
            tgts: one list of references per source; every list must contain the
                same, non-zero number of references.
            agg: "mean" or "max" over the references of each source.
            batch_size: forwarded to ``score``.

        Returns:
            list with one aggregated score per source ([] when there are no sources).

        Raises:
            NotImplementedError: for an unsupported ``agg`` (checked before scoring).
            ValueError: if ``srcs``/``tgts`` lengths differ, or reference counts
                differ or are zero.
        """
        aggregators = {"mean": np.mean, "max": np.max}
        if agg not in aggregators:
            raise NotImplementedError
        if len(srcs) != len(tgts):
            raise ValueError(
                f'srcs and tgts must have the same length, got {len(srcs)} and {len(tgts)}'
            )
        if not tgts:
            return []

        # Assert we have the same number of references
        ref_nums = {len(x) for x in tgts}
        if len(ref_nums) > 1:
            raise ValueError("You have different number of references per test sample.")
        ref_num = len(tgts[0])
        if ref_num == 0:
            raise ValueError("Each test sample needs at least one reference.")

        score_matrix = []
        for i in range(ref_num):
            curr_tgts = [x[i] for x in tgts]
            score_matrix.append(self.score(srcs, curr_tgts, batch_size))
        return list(aggregators[agg](score_matrix, axis=0))

    def test(self, batch_size=3):
        """Smoke test: print the scores of three hand-written (source, target) pairs."""
        src_list = [
            'This is a very good idea. Although simple, but very insightful.',
            'Can I take a look?',
            'Do not trust him, he is a liar.'
        ]

        tgt_list = [
            "That's stupid.",
            "What's the problem?",
            'He is trustworthy.'
        ]

        print(self.score(src_list, tgt_list, batch_size))

NameError: name 'List' is not defined

In [12]:
# BARTScorer defaults to device='cuda:0', and moving a model there raises when
# CUDA is unavailable; fall back to CPU so the notebook also runs without a GPU.
bart_scorer = BARTScorer(
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
    checkpoint=trained_bart,
)

Found GPU at: /device:GPU:0


In [13]:
# Make runs reproducible as much as possible: seed numpy, Python's `random`
# and torch with the same value (also used to shuffle the test set in `main`).
seed = 1234
np.random.seed(seed)
python_random.seed(seed)
torch.manual_seed(seed)

# Values calculated in the data-explore notebook
max_input_length = 125
max_output_length = 193


def preprocess(data):
    """
    Build the seq2seq input/output strings for a single e-SNLI example.

    The input is the premise and the hypothesis joined by the '</s>' token,
    which acts as the separator. The output is the label, a single space,
    and then the first explanation (e.g. label 1 -> "1 <explanation_1>").

    Args:
        data: one example with 'premise', 'hypothesis', 'label' and
            'explanation_1' fields.

    Returns:
        dict with the new 'input' and 'output' strings.
    """
    return {
        "input": data['premise'] + '</s>' + data['hypothesis'],
        "output": str(data['label']) + ' ' + data['explanation_1'],
    }

# adapted from:
# https://github.com/huggingface/transformers/blob/main/examples/pytorch/summarization/run_summarization.py
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    """
    Tokenize a batch of 'input'/'output' strings for the seq2seq model.

    Meant for ``Dataset.map(batched=True)`` so tokenization happens batch by
    batch, which saves RAM.

    Args:
        batch: mapping with list-valued 'input' and 'output' columns.
        tokenizer: tokenizer called with max-length padding and truncation.
        max_source_length: padded/truncated token length of the inputs.
        max_target_length: padded/truncated token length of the labels.

    Returns:
        dict with every field the tokenizer produced for the inputs plus
        'labels', the target ids with padding ids replaced by -100.
    """
    encoded_inputs = tokenizer(
        batch['input'], padding="max_length", truncation=True, max_length=max_source_length
    )
    encoded_targets = tokenizer(
        batch['output'], padding="max_length", truncation=True, max_length=max_target_length
    )

    features = dict(encoded_inputs.items())

    # Ignore padding in the loss: padded label positions become -100.
    pad_id = tokenizer.pad_token_id
    features["labels"] = [
        [tok if tok != pad_id else -100 for tok in target_ids]
        for target_ids in encoded_targets["input_ids"]
    ]
    return features



# Adapted from:
# https://github.com/huggingface/transformers/blob/main/examples/pytorch/summarization/run_summarization.py
# ROUGE metric, loaded once at module level and used by `compute_metrics`.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets` releases
# (metrics moved to the separate `evaluate` library), and `compute_metrics` relies
# on the old aggregated result format (`value.mid.fmeasure`) — pin a compatible
# `datasets` version or confirm before upgrading.
rouge_metric = datasets.load_metric("rouge")

def postprocess_text(preds, labels):
    """
    Reformat decoded texts for ROUGE and extract NLI labels for a classification report.

    The first character of each text is treated as its NLI label.

    Args:
        preds: decoded model generations.
        labels: decoded gold texts, aligned with ``preds``.

    Returns:
        (preds, labels, nli_preds, nli_golds):
        - ``preds`` and ``labels`` are stripped, with one sentence per line.
        - ``nli_preds`` holds each prediction's first character when it is
          '0', '1' or '2', otherwise 'NA' (including empty generations).
        - ``nli_golds`` holds each label's first character, or 'NA' if the
          label is empty.

    Raises:
        ValueError: if ``preds`` and ``labels`` differ in length.
    """
    if len(preds) != len(labels):
        raise ValueError(
            f'preds and labels must have the same length, got {len(preds)} and {len(labels)}'
        )

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in labels]

    # Code that is needed to get the classification report.
    # Slicing ([:1]) instead of indexing ([0]) keeps empty strings from raising IndexError.
    correct_labels = {'0', '1', '2'}
    nli_preds = [pred[:1] if pred[:1] in correct_labels else 'NA' for pred in preds]
    nli_golds = [label[:1] or 'NA' for label in labels]

    return preds, labels, nli_preds, nli_golds


def compute_metrics(eval_preds):
    """
    Print the NLI classification report and return ROUGE scores for generations.

    Args:
        eval_preds: pair ``(preds, labels)``.
            - ``preds`` holds the generated token ids, or is a tuple whose
              first element holds them.
            - ``labels`` holds the gold token ids, where -100 marks positions
              to ignore.

    Returns:
        dict of ROUGE F-measures (x100) plus 'gen_len', the mean count of
        non-pad tokens per prediction; every value is rounded to 4 decimals.
    """
    print('Started compute metrics')
    generated, gold = eval_preds
    if isinstance(generated, tuple):
        generated = generated[0]

    pad_id = tokenizer.pad_token_id
    decoded_preds = tokenizer.batch_decode(generated, skip_special_tokens=True)
    # take care of padding: map the ignored (-100) positions back to the pad id
    gold = np.where(gold != -100, gold, pad_id)
    decoded_labels = tokenizer.batch_decode(gold, skip_special_tokens=True)

    decoded_preds, decoded_labels, nli_preds, nli_golds = postprocess_text(decoded_preds, decoded_labels)
    print(classification_report(nli_golds, nli_preds))

    rouge_result = rouge_metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Keep the F-measure of each score's 'mid' aggregate, as a percentage.
    metrics = {name: agg.mid.fmeasure * 100 for name, agg in rouge_result.items()}

    non_pad_counts = [np.count_nonzero(seq != pad_id) for seq in generated]
    metrics["gen_len"] = np.mean(non_pad_counts)
    return {name: round(value, 4) for name, value in metrics.items()}


def generate_predictions(model, test_dict):
    """
    Generate test-set outputs, report metrics and BARTScore every generation.

    Args:
        model: seq2seq model exposing ``generate`` and ``device``.
        test_dict: dataset with 'input' and 'output' string columns
            (as produced by ``preprocess``).

    Returns:
        (output_str, bart_scores):
        - ``output_str`` is the list of decoded generations.
        - ``bart_scores`` is a list of ``(generation, reference, score)``
          tuples, scored with the global ``bart_scorer``.
    """
    test_data = test_dict.map(
        lambda batch: batch_tokenize_preprocess(
            batch, tokenizer, max_input_length, max_output_length
        ),
        batched=True,
        remove_columns=test_dict.column_names,
    )
    test_data.set_format("torch")

    dataloader = DataLoader(test_data, batch_size=32)
    # Use the trained model to generate outputs using the input ids and attention mask
    outputs = []
    for idx, batch in enumerate(dataloader):
        print(f'Running batch: {idx + 1} of total {len(dataloader)}')

        # Obtain input ids and attention mask
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)

        # Move each batch's generations to CPU so they do not accumulate on the device.
        outputs.append(model.generate(input_ids, attention_mask=attention_mask).cpu())

    # generate() only pads within a batch, so batches can differ in sequence length
    # and a plain torch.cat would fail; right-pad all of them to the longest first.
    longest = max(out.shape[1] for out in outputs)
    test_outputs = torch.cat(
        [
            torch.nn.functional.pad(out, (0, longest - out.shape[1]), value=tokenizer.pad_token_id)
            for out in outputs
        ],
        dim=0,
    )
    print('Test Rouge scores:', compute_metrics((test_outputs.cpu().detach(), test_data["labels"].cpu().detach())))

    # Converts output_ids back to string representation
    output_str = tokenizer.batch_decode(test_outputs, skip_special_tokens=True)

    # Fetch the reference column once: indexing a dataset by column name builds the
    # whole column, so doing it inside a per-example loop made this quadratic.
    references = test_dict['output']
    # Score every (generation, reference) pair in one call; the scorer batches internally.
    scores = bart_scorer.score(output_str, references)
    bart_scores = list(zip(output_str, references, scores))

    return output_str, bart_scores


def main():
    """
    Main function of the script.

    Generates predictions for the shuffled e-SNLI test split, prints metrics,
    and pickles the BARTScores to ``pickles/BART_bart_scores.pk``. The pickle
    is a list of ``(generation, reference, score)`` tuples.
    """
    from pathlib import Path

    # Create the output directory before the long evaluation, so a missing
    # directory cannot make the final dump fail after all the work is done.
    pickle_path = Path('pickles') / 'BART_bart_scores.pk'
    pickle_path.parent.mkdir(parents=True, exist_ok=True)

    test = datasets.load_dataset('esnli', split='test').shuffle(seed)

    test_dict = test.map(preprocess, remove_columns=['premise', 'hypothesis', 'label', 'explanation_1', 'explanation_2',
                                                     'explanation_3'])

    # Evaluate on our test set (the decoded predictions themselves are not needed here)
    _, bart_scores = generate_predictions(model, test_dict)

    with open(pickle_path, 'wb') as f:
        pickle_rick.dump(bart_scores, f)


if __name__ == "__main__":
    main()

Reusing dataset esnli (/root/.cache/huggingface/datasets/esnli/plain_text/0.0.2/a160e6a02bbb8d828c738918dafec4e7d298782c334b5109af632fec6d779bbc)


  0%|          | 0/9824 [00:00<?, ?ex/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

Running batch: 1 of total 307
Running batch: 2 of total 307
Running batch: 3 of total 307
Running batch: 4 of total 307
Running batch: 5 of total 307
Running batch: 6 of total 307
Running batch: 7 of total 307
Running batch: 8 of total 307
Running batch: 9 of total 307
Running batch: 10 of total 307
Running batch: 11 of total 307
Running batch: 12 of total 307
Running batch: 13 of total 307
Running batch: 14 of total 307
Running batch: 15 of total 307
Running batch: 16 of total 307
Running batch: 17 of total 307
Running batch: 18 of total 307
Running batch: 19 of total 307
Running batch: 20 of total 307
Running batch: 21 of total 307
Running batch: 22 of total 307
Running batch: 23 of total 307
Running batch: 24 of total 307
Running batch: 25 of total 307
Running batch: 26 of total 307
Running batch: 27 of total 307
Running batch: 28 of total 307
Running batch: 29 of total 307
Running batch: 30 of total 307
Running batch: 31 of total 307
Running batch: 32 of total 307
Running batch: 33