This notebook is used just to test using the evaluate package to calculate Rouge and Bert metrics.

In [17]:
!pip install evaluate
!pip install rouge_score
!pip install bert_score



In [18]:
from datasets import load_dataset
import transformers
import torch
import evaluate
import numpy as np

In [19]:
#Download dataset from HF
dataset = load_dataset('BI55/MedText', split='train')

In [20]:
#Get model and tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')

In [21]:
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")

In [22]:
device = torch.device('cpu')
model.to(device)
LIMIT = 10
limited_data_set = dataset.select(range(LIMIT))
predictions = []
references = []
for i in limited_data_set:
  prompt, completion = i['Prompt'], i['Completion']
  tokenized_prompt = tokenizer(prompt, return_tensors='pt')
  with torch.no_grad():
    output = model.generate(**tokenized_prompt, max_length=100, pad_token_id=tokenizer.eos_token_id)
  detokenized_output = tokenizer.decode(output[0], skip_special_tokens=True)
  predictions.append(detokenized_output)
  references.append(completion)
rouge_results = rouge.compute(predictions=predictions, references=references)
bert_results = bertscore.compute(predictions=predictions, references=references, lang="en")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
bert_percision = np.mean(bert_results['precision'])
bert_recall = np.mean(bert_results['recall'])
bert_f1 = np.mean(bert_results['f1'])
print(f"rouge1: {rouge_results["rouge1"]}, rouge2: {rouge_results["rouge2"]}, rougeL: {rouge_results["rougeL"]}")
print(f"bert precision: {bert_percision}, bert recall: {bert_recall}, bert f1: {bert_f1}")

rouge1: 0.28613307616326844, rouge2: 0.07950857749860055, rougeL: 0.17590193929247472
bert precision: 0.8484771311283111, bert recall: 0.8493280410766602, bert f1: 0.8488802552223206
