<a href="https://colab.research.google.com/github/edwin-19/Text-Generation-Comparison/blob/main/T5_Text_Comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [32]:
!pip install transformers
!pip install datasets
!pip install tokenizers
!pip install sentencepiece
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.0.4-py2.py3-none-any.whl (22 kB)
Installing collected packages: rouge-score
Successfully installed rouge-score-0.0.4


# Import libraries

In [41]:
from transformers import AutoTokenizer, T5ForConditionalGeneration
from datasets import load_metric
from pprint import pprint

import torch
import random
import numpy as np

from datasets import load_dataset
from datasets import load_metric
from tqdm.notebook import tqdm

In [3]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed the Python, NumPy and PyTorch (CPU and current CUDA device) RNGs so the
# sampled generations below are repeatable.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
  seed_fn(SEED)

In [20]:
# Load the pretrained T5-base checkpoint together with its matching tokenizer.
MODEL_NAME = "t5-base"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
t5_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# Inference only: switch to eval mode, then place the model on the same
# `device` the evaluation loop moves its inputs to. Using `.to(device)` rather
# than `.cuda()` keeps the notebook runnable on CPU-only runtimes, where
# `device` resolves to 'cpu'. Trailing ';' suppresses the module repr.
t5_model.eval();
t5_model.to(device);

In [5]:
# Fetch the `d0r1h/ILC` dataset from the Hugging Face Hub; the evaluation loop
# below reads the 'Title' and 'Summary' columns of its 'train' split.
dataset = load_dataset(path='d0r1h/ILC')

Using custom data configuration d0r1h--ILC-50a6ccb795aeb7a6


Downloading and preparing dataset csv/d0r1h--ILC to /root/.cache/huggingface/datasets/csv/d0r1h--ILC-50a6ccb795aeb7a6/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/36.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.7M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/d0r1h--ILC-50a6ccb795aeb7a6/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [33]:
# ROUGE metric used to score generated text against the reference summaries.
metric = load_metric(path='rouge')

In [34]:
# Display the metric object; its repr documents the accepted arguments
# (predictions, references, rouge_types, use_stemmer, use_aggregator).
metric

Metric(name: "rouge", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLSum"`: rougeLsum splits text using `"
"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_aggregator: Return aggregates if this is set to True
Retu

In [56]:
def get_batch(batch_input, batch_size=1):
  """Yield consecutive slices of ``batch_input`` holding at most ``batch_size`` items.

  Args:
    batch_input: any sliceable object supporting ``len()``, e.g. a
      ``datasets.Dataset`` split (slicing it returns a dict of column lists)
      or a plain list.
    batch_size: maximum number of rows per yielded slice; must be >= 1.

  Yields:
    ``batch_input[start:end]`` for each consecutive window; the final slice
    may be shorter than ``batch_size``.

  Raises:
    ValueError: if ``batch_size`` is smaller than 1 (raised on first iteration).
  """
  if batch_size < 1:
    raise ValueError(f'batch_size must be >= 1, got {batch_size}')
  # Size the loop from the argument itself, not the global `dataset`, so the
  # helper works for any split or sequence it is handed.
  total = len(batch_input)
  for start in range(0, total, batch_size):
    yield batch_input[start:min(start + batch_size, total)]

In [65]:
# Generate text from each Title with T5 and score it against the reference
# Summary; `score` collects one ROUGE-1 F-measure per batch.
BATCH_SIZE = 32
num_rows = dataset['train'].num_rows
# Ceiling division: get_batch also yields the final partial batch, so floor
# division would under-report the progress-bar total.
num_batches = (num_rows + BATCH_SIZE - 1) // BATCH_SIZE

score = []
for data in tqdm(get_batch(dataset['train'], batch_size=BATCH_SIZE), total=num_batches):
  # Tokenize the batch of titles and move every tensor in the encoding
  # (input_ids, attention_mask) to the model's device in one call.
  inputs = t5_tokenizer(data['Title'], return_tensors='pt', padding=True, truncation=True).to(device)

  # Sampling combined with 5 beams, output forced to 100-300 tokens with no
  # repeated bigrams.
  outputs = t5_model.generate(
      **inputs, do_sample=True, num_beams=5, no_repeat_ngram_size=2, min_length=100, max_length=300, early_stopping=True
  )

  text_generated = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
  results = metric.compute(predictions=text_generated, references=data['Summary'])
  # Keep the 'mid' aggregate of the ROUGE-1 F-measure for this batch.
  score.append(results["rouge1"].mid.fmeasure)

  0%|          | 0/64 [00:00<?, ?it/s]

In [66]:
# Unweighted mean of the per-batch ROUGE-1 F-measures; note the final, possibly
# smaller batch counts the same as a full one.
f'Average Rouge Score: {np.mean(score)}'

'Average Rouge Score: 0.149423660577378'

In [63]:
# Shell out to the NVIDIA driver CLI to inspect the GPU model, memory usage
# and running processes of this runtime.
!nvidia-smi

Thu Apr 28 17:51:29 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   72C    P0    73W / 149W |   8198MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [39]:
# Debug view of the tokenized `inputs` and ROUGE `results` left over from the
# most recent evaluation batch.
# NOTE(review): the saved output below comes from an earlier run (execution
# count 39, a single-row batch), not from the batch-size-32 loop; consider
# re-running or removing this cell.
inputs, results

({'input_ids': tensor([[ 1958,   164,   916, 24819,     3,  3565,  5142,    13,     3,  3597,
           6294,     3,    10,  8531,  2243,     1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')},
 {'rouge1': AggregateScore(low=Score(precision=0.5373134328358209, recall=0.0975609756097561, fmeasure=0.1651376146788991), mid=Score(precision=0.5373134328358209, recall=0.0975609756097561, fmeasure=0.1651376146788991), high=Score(precision=0.5373134328358209, recall=0.0975609756097561, fmeasure=0.1651376146788991)),
  'rouge2': AggregateScore(low=Score(precision=0.09090909090909091, recall=0.016304347826086956, fmeasure=0.027649769585253454), mid=Score(precision=0.09090909090909091, recall=0.016304347826086956, fmeasure=0.027649769585253454), high=Score(precision=0.09090909090909091, recall=0.016304347826086956, fmeasure=0.027649769585253454)),
  'rougeL': AggregateScore(low=Score(precision=0.3283582089552239, recall=0.0596205