In [1]:
from datasets import load_dataset
from evaluate import load

# Fetch the ROUGE metric used for scoring generated summaries, then the XSum
# dataset (train / validation / test splits with document, summary and id columns).
metric = load("rouge")
raw_datasets = load_dataset("xsum")



In [2]:
# Hub id of the pretrained checkpoint; shared by the tokenizer and the model below.
model_checkpoint = "google/flan-t5-small"

In [3]:
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
    
# Tokenizer and seq2seq model from the same checkpoint.
# device_map="auto" lets the library decide where the weights are placed, so
# inputs must later be moved to `model.device` before generation.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_checkpoint,
    device_map="auto",
    low_cpu_mem_usage=True,
)



  return torch._C._cuda_getDeviceCount() > 0


In [4]:
max_input_length = 1024   # encoder inputs are truncated/padded to this many tokens
max_target_length = 128   # reference summaries are truncated to this many tokens
prefix = "summarize: "    # instruction prepended to every document


def preprocess_function(examples):
    """Tokenize a batch of examples for seq2seq summarization.

    Args:
        examples: a batch as passed by ``Dataset.map(..., batched=True)``;
            ``examples["document"]`` and ``examples["summary"]`` are lists of
            strings.

    Returns:
        The tokenizer encoding of the prefixed documents (``input_ids`` and
        ``attention_mask``, truncated and padded to ``max_input_length``) with
        an added ``labels`` entry holding the summary token ids, truncated to
        ``max_target_length`` but not padded.
    """
    prompted_documents = [prefix + document for document in examples["document"]]
    encoded = tokenizer(
        prompted_documents,
        max_length=max_input_length,
        truncation=True,
        padding='max_length',
    )

    # text_target= encodes the summaries as decoder targets.
    summary_encoding = tokenizer(
        text_target=examples["summary"],
        max_length=max_target_length,
        truncation=True,
    )
    encoded["labels"] = summary_encoding["input_ids"]
    return encoded


In [5]:
# Inspect split sizes and column names of the raw dataset.
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

In [6]:
# Look at one raw training example (its document, summary and id fields).
raw_datasets["train"][5]

{'document': 'Simone Favaro got the crucial try with the last move of the game, following earlier touchdowns by Chris Fusaro, Zander Fagerson and Junior Bulumakau.\nRynard Landman and Ashton Hewitt got a try in either half for the Dragons.\nGlasgow showed far superior strength in depth as they took control of a messy match in the second period.\nHome coach Gregor Townsend gave a debut to powerhouse Fijian-born Wallaby wing Taqele Naiyaravoro, and centre Alex Dunbar returned from long-term injury, while the Dragons gave first starts of the season to wing Aled Brew and hooker Elliot Dee.\nGlasgow lost hooker Pat McArthur to an early shoulder injury but took advantage of their first pressure when Rory Clegg slotted over a penalty on 12 minutes.\nIt took 24 minutes for a disjointed game to produce a try as Sarel Pretorius sniped from close range and Landman forced his way over for Jason Tovey to convert - although it was the lock\'s last contribution as he departed with a chest injury shor

In [7]:
# Tokenize every split in batches; the raw columns are kept next to the new
# input_ids / attention_mask / labels columns.
tokenized_datasets = raw_datasets.map(function=preprocess_function, batched=True)

In [8]:
# The same training example after preprocessing (raw fields plus token ids).
tokenized_datasets["train"][5]

{'document': 'Simone Favaro got the crucial try with the last move of the game, following earlier touchdowns by Chris Fusaro, Zander Fagerson and Junior Bulumakau.\nRynard Landman and Ashton Hewitt got a try in either half for the Dragons.\nGlasgow showed far superior strength in depth as they took control of a messy match in the second period.\nHome coach Gregor Townsend gave a debut to powerhouse Fijian-born Wallaby wing Taqele Naiyaravoro, and centre Alex Dunbar returned from long-term injury, while the Dragons gave first starts of the season to wing Aled Brew and hooker Elliot Dee.\nGlasgow lost hooker Pat McArthur to an early shoulder injury but took advantage of their first pressure when Rory Clegg slotted over a penalty on 12 minutes.\nIt took 24 minutes for a disjointed game to produce a try as Sarel Pretorius sniped from close range and Landman forced his way over for Jason Tovey to convert - although it was the lock\'s last contribution as he departed with a chest injury shor

In [9]:
# Confirm the tokenized columns were added to every split.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 11334
    })
})

In [10]:
# Decode the tokenized test labels back to text, one reference per test example.
# NOTE(review): these are round-tripped through the tokenizer, so they are
# truncated to max_target_length tokens and may not match the raw summaries
# exactly; raw_datasets["test"]["summary"] may be the better ROUGE reference -- confirm.
labels = tokenizer.batch_decode(tokenized_datasets["test"]["labels"], skip_special_tokens=True)
# Preview only: the full list has one entry per test example.
labels[:10]

['There is a "chronic" need for more housing for prison leavers in Wales, according to a charity.',
 'A man has appeared in court after firearms, ammunition and cash were seized by police in Edinburgh.',
 'Four people accused of kidnapping and torturing a mentally disabled man in a "racially motivated" attack streamed on Facebook have been denied bail.',
 'West Brom have appointed Nicky Hammond as technical director, ending his 20-year association with Reading.',
 'The pancreas can be triggered to regenerate itself through a type of fasting diet, say US researchers.',
 'Since their impending merger was announced in January, there has been remarkably little comment about the huge proposed deal to combine Essilor and Luxottica.',
 'A "medal at any cost" approach created a "culture of fear" at British Cycling, says former rider Wendy Houvenaghel.',
 'Have you heard the one about the computer programmer who bought a failing comedy club in Texas and turned it into a million dollar a year bu

In [11]:
import nltk
# Sentence-tokenizer models used by nltk.sent_tokenize (no-op if already present).
# NOTE(review): newer NLTK releases look up 'punkt_tab' instead -- confirm for the installed version.
nltk.download('punkt')
# Put each sentence of a reference on its own line. Keep ALL references: the
# evaluation loop below slices `labels` in step with the full test split, so
# truncating here (previously `[:10]`) misaligned predictions and references.
labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in labels]
# Preview only.
labels[:10]

[nltk_data] Downloading package punkt to /home/mifs/hln35/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


['There is a "chronic" need for more housing for prison leavers in Wales, according to a charity.',
 'A man has appeared in court after firearms, ammunition and cash were seized by police in Edinburgh.',
 'Four people accused of kidnapping and torturing a mentally disabled man in a "racially motivated" attack streamed on Facebook have been denied bail.',
 'West Brom have appointed Nicky Hammond as technical director, ending his 20-year association with Reading.',
 'The pancreas can be triggered to regenerate itself through a type of fasting diet, say US researchers.',
 'Since their impending merger was announced in January, there has been remarkably little comment about the huge proposed deal to combine Essilor and Luxottica.',
 'A "medal at any cost" approach created a "culture of fear" at British Cycling, says former rider Wendy Houvenaghel.',
 'Have you heard the one about the computer programmer who bought a failing comedy club in Texas and turned it into a million dollar a year bu

In [12]:
import torch
import numpy as np

In [13]:
# Sanity check: with padding='max_length', every encoded input should be
# exactly max_input_length tokens long.
for encoded_input in tokenized_datasets["test"]["input_ids"][:10]:
    print(len(encoded_input))

1024
1024
1024
1024
1024
1024
1024
1024
1024
1024


In [14]:
# tokenizer2 = AutoTokenizer.from_pretrained(model_checkpoint)
# inputs = tokenizer2(raw_datasets["test"]["document"], return_tensors = "pt").input_ids

In [15]:
# Generate greedy summaries for the test split in fixed-size batches and score
# each prediction against its reference with ROUGE-L (one score per example).
EVAL_BATCH_SIZE = 10

test_tensor = torch.tensor(tokenized_datasets["test"]["input_ids"])
test_attention_mask = torch.tensor(tokenized_datasets["test"]["attention_mask"])
print(test_tensor.shape)

# Predictions are paired with references by position, so the two must line up.
assert len(labels) == len(test_tensor), (
    f"expected one reference per test example, got {len(labels)} references "
    f"for {len(test_tensor)} inputs"
)

results = []  # flat list: one ROUGE-L F-measure per test example
for start in range(0, len(test_tensor), EVAL_BATCH_SIZE):
    # Slicing past the end simply yields the final, shorter batch.
    stop = start + EVAL_BATCH_SIZE
    # The model was loaded with device_map="auto"; inputs must be on its device.
    batch_ids = test_tensor[start:stop].to(model.device)
    # Inputs are padded to max_input_length, so pass the mask explicitly.
    batch_mask = test_attention_mask[start:stop].to(model.device)
    generated = model.generate(
        input_ids=batch_ids,
        attention_mask=batch_mask,
        max_new_tokens=max_target_length,
        do_sample=False,
    )
    preds = tokenizer.batch_decode(generated, skip_special_tokens=True)
    result = metric.compute(
        predictions=preds,
        references=labels[start:stop],
        use_stemmer=True,
        use_aggregator=False,
    )
    results.extend(result["rougeL"])






torch.Size([11334, 1024])
torch.Size([1024])
torch.Size([1024])


ValueError: not enough values to unpack (expected 2, got 1)

In [None]:
# Shape of the stacked test inputs: (number of test examples, max_input_length).
print(test_tensor.shape)

In [None]:
# result = metric.compute(predictions=preds, references=labels, use_stemmer=True, use_aggregator=False)
# # Extract a few results
# result = {key: value for key, value in result.items()}

In [None]:
# ROUGE-L scores collected by the evaluation loop.
results
