In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import evaluate
from tqdm import tqdm
import random
import torchprofile
import time
from fvcore.nn import FlopCountAnalysis


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the BART checkpoint (tokenizer + seq2seq weights) and the CNN/DailyMail
# 3.0.0 dataset. Only the test split is evaluated in the next cell.
checkpoint = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model = model.to(device)

ds = load_dataset("abisee/cnn_dailymail", "3.0.0")
test_dataset = ds["test"]


In [None]:
# Evaluate summarization quality (ROUGE / BLEU), approximate compute (fvcore
# FLOP count) and wall-clock generation latency on a random test subset.

SEED = 42                # fixed seed -> same subset on every Restart & Run All
subset_size = 100        # number of test articles to evaluate
MAX_INPUT_TOKENS = 1024  # tokenizer truncation length (presumably the model's max input length — confirm)
N_EXAMPLES_TO_SHOW = 3   # print only a few generated/reference pairs instead of all of them

random.seed(SEED)
torch.manual_seed(SEED)

# Sample from the full test split rather than from `test_dataset` itself, so
# re-running this cell re-creates the same subset instead of sub-sampling an
# already-sampled subset. Clamp the size so random.sample cannot raise.
full_test_split = ds['test']
subset_size = min(subset_size, len(full_test_split))
random_indices = random.sample(range(len(full_test_split)), subset_size)
test_dataset = full_test_split.select(random_indices)
assert len(test_dataset) > 0, "evaluation subset is empty"

rouge = evaluate.load('rouge')
bleu = evaluate.load('bleu')
vocab_size = model.config.vocab_size  # loop-invariant, read once


def _synchronize():
    """Wait for queued CUDA work so wall-clock timings cover the whole call."""
    if device.type == 'cuda':
        torch.cuda.synchronize()


total_flops = 0
total_time = 0.0
outputs = []
targets = []

for i, example in enumerate(tqdm(test_dataset)):
    article = example['article']
    reference_summary = example['highlights']

    inputs = tokenizer(
        article,
        return_tensors='pt',
        max_length=MAX_INPUT_TOKENS,
        truncation=True,
    ).to(device)

    # Out-of-range ids would otherwise surface later as an opaque embedding /
    # device-side error; fail here with a readable message instead.
    max_id = int(inputs['input_ids'].max())
    assert max_id < vocab_size, f"token id {max_id} >= vocab size {vocab_size}"

    # NOTE(review): only input_ids are given, so this traces a single model
    # forward pass, not the full generate() loop — it under-counts the cost of
    # producing a summary. fvcore reportedly counts one fused multiply-add as
    # one flop; verify before comparing with other tools.
    # no_grad: tracing does not need an autograd graph.
    with torch.no_grad():
        total_flops += FlopCountAnalysis(model, inputs['input_ids']).total()

    # CUDA kernels run asynchronously; synchronize on both sides so the timer
    # measures the generation work rather than just the kernel launches.
    _synchronize()
    start_time = time.perf_counter()
    # No decoding kwargs are passed, so the checkpoint's generation config applies.
    summary_ids = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
    )
    _synchronize()
    total_time += time.perf_counter() - start_time

    generated_summary = tokenizer.decode(
        summary_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    if i < N_EXAMPLES_TO_SHOW:
        print("GENERATED SUMMARY")
        print(generated_summary)
        print("TARGET SUMMARY")
        print(reference_summary)

    outputs.append(generated_summary)
    targets.append(reference_summary)

n_examples = len(test_dataset)
average_flops = total_flops / n_examples
average_time = total_time / n_examples

# NOTE(review): if rougeLsum is reported, it expects newline-separated
# sentences; predictions are not sentence-split here — confirm whether that
# matters for the scores you compare against.
rouge_results = rouge.compute(predictions=outputs, references=targets)
bleu_results = bleu.compute(predictions=outputs, references=targets)

print("Inference Results:")
print(f"Average FLOPs (single forward pass): {average_flops:.2e}")
print(f"Average generation time per example: {average_time:.4f} seconds")
print("ROUGE Scores:", rouge_results)
print("BLEU Score:", bleu_results)