In [1]:
import os
# One cache location for the whole notebook: handed to `load_dataset` explicitly
# further down, and exported through the environment for `transformers`.
cache_dir = "/scratches/dialfs/alta/hln35/.cache"
os.environ["TRANSFORMERS_CACHE"] = cache_dir

In [2]:
from datasets import load_dataset
from evaluate import load

# XSum document/summary pairs, cached under `cache_dir`.
raw_datasets = load_dataset(path="xsum", cache_dir=cache_dir)
# ROUGE scorer used both by the trainer callback and by the manual evaluation loops.
metric = load(path="rouge")



In [3]:
import torch
# Hub identifier of the pretrained checkpoint that is fine-tuned below.
model_checkpoint = "google/flan-t5-small"
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
    
# Tokenizer and seq2seq model are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# Place the weights on the device selected earlier in the notebook.
model = model.to(device)



In [5]:
# Truncation limits (in tokens) for source documents and target summaries,
# and the task prefix prepended to every document.
max_input_length = 1024
max_target_length = 128
prefix = "summarize: "


def preprocess_function(examples):
    """Tokenize a batch of examples for seq2seq training.

    Args:
        examples: mapping with a "document" list and a "summary" list.

    Returns:
        The tokenizer output for the prefixed documents, with an added
        "labels" entry holding the token ids of the tokenized summaries.
    """
    documents = examples["document"]
    prompts = [prefix + text for text in documents]
    encoded = tokenizer(prompts, truncation=True, max_length=max_input_length)

    # Summaries go through the tokenizer's target-side path.
    target_encoding = tokenizer(
        text_target=examples["summary"],
        truncation=True,
        max_length=max_target_length,
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded


In [6]:
# Show the loaded DatasetDict: splits, feature names and row counts.
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

In [7]:
# Inspect a single raw training example (document / summary / id).
raw_datasets["train"][1]

{'document': 'A fire alarm went off at the Holiday Inn in Hope Street at about 04:20 BST on Saturday and guests were asked to leave the hotel.\nAs they gathered outside they saw the two buses, parked side-by-side in the car park, engulfed by flames.\nOne of the tour groups is from Germany, the other from China and Taiwan. It was their first night in Northern Ireland.\nThe driver of one of the buses said many of the passengers had left personal belongings on board and these had been destroyed.\nBoth groups have organised replacement coaches and will begin their tour of the north coast later than they had planned.\nPolice have appealed for information about the attack.\nInsp David Gibson said: "It appears as though the fire started under one of the buses before spreading to the second.\n"While the exact cause is still under investigation, it is thought that the fire was started deliberately."',
 'summary': 'Two tourist buses have been destroyed by fire in a suspected arson attack in Belf

In [8]:
# Apply `preprocess_function` to every split, one batch of examples at a time.
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

In [9]:
# Inspect a single training example after preprocessing.
tokenized_datasets["train"][5]

{'document': 'Simone Favaro got the crucial try with the last move of the game, following earlier touchdowns by Chris Fusaro, Zander Fagerson and Junior Bulumakau.\nRynard Landman and Ashton Hewitt got a try in either half for the Dragons.\nGlasgow showed far superior strength in depth as they took control of a messy match in the second period.\nHome coach Gregor Townsend gave a debut to powerhouse Fijian-born Wallaby wing Taqele Naiyaravoro, and centre Alex Dunbar returned from long-term injury, while the Dragons gave first starts of the season to wing Aled Brew and hooker Elliot Dee.\nGlasgow lost hooker Pat McArthur to an early shoulder injury but took advantage of their first pressure when Rory Clegg slotted over a penalty on 12 minutes.\nIt took 24 minutes for a disjointed game to produce a try as Sarel Pretorius sniped from close range and Landman forced his way over for Jason Tovey to convert - although it was the lock\'s last contribution as he departed with a chest injury shor

In [10]:
# Check that every split now carries input_ids / attention_mask / labels columns.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 11334
    })
})

In [15]:
# NOTE(review): despite the `test_` prefix, both values below are taken from the
# *train* split. The names are kept because later cells read them.
train_split = tokenized_datasets["train"]
test_input_ids = train_split["input_ids"]
# Decode the label ids of the same split back to text to serve as ROUGE references.
labels = tokenizer.batch_decode(train_split["labels"], skip_special_tokens=True)

In [18]:
import numpy as np

In [22]:
# Load a locally saved checkpoint and compare its greedy output with that of the
# notebook-level `model` on one example, scoring the former with ROUGE.
model_small_distill = AutoModelForSeq2SeqLM.from_pretrained(
    "model/flant5_small_lr_10-4_race_distill_epoch2",
    local_files_only=True,
).to(device)

generation_kwargs = dict(max_new_tokens=max_target_length, do_sample=False)
results_small_distill = {}
for i in range(1, 2):
    # A batch of one: wrap the id list so the tensor has a leading batch axis.
    test_tensor = torch.tensor([test_input_ids[i]]).to(device)

    preds_distill = tokenizer.batch_decode(
        model_small_distill.generate(test_tensor, **generation_kwargs),
        skip_special_tokens=True,
    )
    preds = tokenizer.batch_decode(
        model.generate(test_tensor, **generation_kwargs),
        skip_special_tokens=True,
    )

    print(preds_distill, len(preds_distill[0].split(" ")))
    print(preds)

    # use_aggregator=False: per-prediction scores rather than a single aggregate.
    result = metric.compute(
        predictions=preds_distill,
        references=[labels[i]],
        use_stemmer=True,
        use_aggregator=False,
    )
    for key, value in result.items():
        if key in results_small_distill:
            results_small_distill[key] += value
        else:
            results_small_distill[key] = value

# Collapse the per-example scores into one average per ROUGE variant.
results_small_distill_agg = {}
for k in results_small_distill:
    v = results_small_distill[k]
    results_small_distill_agg[k] = np.average(v)

print("the average score is: ")
print(results_small_distill_agg)

[' fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm  fire alarm alarm alarm alarm alarm alarm alarm alarm alarm alarm alarm alarm alarm alarm alarm'] 128
['A group of tourists has been attacked by a fire in a bus park in Northern Ireland.']
the average score is: 
{'rouge1': 0.01869158878504673, 'rouge2': 0.0, 'rougeL': 0.01869158878504673, 'rougeLsum': 0.01869158878504673}


In [11]:
from transformers import DataCollatorForSeq2Seq

# Dynamic-padding collator for seq2seq batches. The `model` argument expects the
# model object itself (the collator uses it to derive decoder inputs from the
# labels); the original passed the checkpoint *name*, a plain string.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [12]:
import torch
import numpy as np

In [13]:
def compute_metrics(eval_pred):
    """Compute ROUGE and mean generated length for one evaluation pass.

    Args:
        eval_pred: pair ``(predictions, labels)`` of token-id arrays. The
            predictions may also arrive wrapped in a tuple, in which case the
            first element is used.

    Returns:
        Dict with the ROUGE scores returned by ``metric.compute`` plus
        ``gen_len`` (mean number of non-pad tokens per prediction), every value
        rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # -100 is the ignore index used for padding; it is not a valid token id, so
    # swap it for the pad token in BOTH arrays before decoding. Leaving it in
    # the predictions can break decoding and would be counted as real tokens
    # in gen_len.
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [14]:
# Re-select the compute device (GPU if visible, else CPU) and display it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [15]:
# Fine-tuning configuration: 4 epochs at lr 1e-5 with weight decay 0.01,
# evaluating and checkpointing once per epoch.
training_args = Seq2SeqTrainingArguments(
    output_dir="model/flant5_small_lr_10-5_wd_10-2",
    # --- optimisation ---
    num_train_epochs=4,
    learning_rate=1e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # --- evaluation / checkpointing ---
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    # Evaluate on generated sequences so `compute_metrics` can decode them for ROUGE.
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)

# Last expression of the cell: the returned TrainOutput is displayed.
trainer.train()

Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,2.5513,2.275294,0.2926,0.0854,0.2321,0.2321,18.8551
2,2.5177,2.25592,0.2954,0.0875,0.2348,0.2348,18.8309
3,2.5031,2.24727,0.2968,0.0884,0.236,0.2359,18.8311
4,2.4916,2.244451,0.2972,0.0886,0.2361,0.236,18.8269




TrainOutput(global_step=25508, training_loss=2.5254513894724626, metrics={'train_runtime': 18017.8211, 'train_samples_per_second': 45.298, 'train_steps_per_second': 1.416, 'total_flos': 3.02870562806956e+17, 'train_loss': 2.5254513894724626, 'epoch': 4.0})

In [None]:
# tokenizer2 = AutoTokenizer.from_pretrained(model_checkpoint)
# inputs = tokenizer2(raw_datasets["test"]["document"], return_tensors = "pt").input_ids

In [None]:
# test_tensor = torch.tensor(tokenized_datasets["test"]["input_ids"])[:1000]
# preds = model.generate(test_tensor, max_new_tokens=max_target_length, do_sample=False)  
# # preds = torch.tensor(preds)                                                                  
# preds = tokenizer.batch_decode(preds, skip_special_tokens=True)        
# print(preds.shape)

In [None]:

# result = metric.compute(predictions=preds, references=labels, use_stemmer=True, use_aggregator=False)
# result = {key: value for key, value in result.items()}
# result

In [None]:
# for input_id in tokenized_datasets["test"]["input_ids"]:
#     print(input_id)
#     output = model.generate(input_id, max_new_tokens=max_target_length, do_sample=False)
import json

# Score the model on the held-out test split, one document at a time, and dump
# the per-example ROUGE scores to disk.
test_input_ids = tokenized_datasets["test"]["input_ids"]
# References must come from the same split as the inputs. The notebook-level
# `labels` list is decoded from the *train* split, so indexing it here would
# score each test prediction against an unrelated training summary.
test_references = tokenizer.batch_decode(
    tokenized_datasets["test"]["labels"], skip_special_tokens=True
)

results = {}
for i in range(len(test_input_ids)):
    # Batch of one; the tensor has to be on the same device as the model.
    test_tensor = torch.tensor([test_input_ids[i]]).to(device)
    preds = model.generate(test_tensor, max_new_tokens=max_target_length, do_sample=False)
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # use_aggregator=False: a list of per-prediction scores for each ROUGE key.
    result = metric.compute(
        predictions=preds,
        references=[test_references[i]],
        use_stemmer=True,
        use_aggregator=False,
    )
    for key, value in result.items():
        results.setdefault(key, []).extend(value)

with open("rouge_small_fine_tuned_.txt", "w") as fp:
    json.dump(results, fp)





In [None]:
# Shape of the most recently built single-example input tensor.
print(test_tensor.shape)

In [None]:
# result = metric.compute(predictions=preds, references=labels, use_stemmer=True, use_aggregator=False)
# # Extract a few results
# result = {key: value for key, value in result.items()}

In [None]:
# Per-example ROUGE scores accumulated by the evaluation loop, keyed by ROUGE variant.
results


In [None]:
# Average the per-example scores so each ROUGE variant maps to a single number.
results_small_agg = {
    rouge_key: np.average(scores) for rouge_key, scores in results.items()
}

In [None]:
# NOTE(review): `results_large_agg` is not assigned anywhere in the visible
# notebook; this cell presumably relies on state from another run and would
# raise NameError on a fresh kernel — confirm or remove.
results_large_agg

In [None]:
# Display the averaged ROUGE scores computed from `results`.
results_small_agg