In [None]:
import random
import time
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import evaluate
import torch
from fvcore.nn import FlopCountAnalysis

# Evaluate facebook/bart-large-cnn on a random subset of the CNN/DailyMail test
# split after dropping the input tokens whose embedding norm falls in the
# lowest quantile. Reports an average FLOP count, the average generation time,
# and ROUGE / BLEU scores against the reference highlights.


def _select_tokens_to_keep(importance_scores, special_mask, quantile):
    """Return a boolean mask of the token positions that survive pruning.

    Args:
        importance_scores: 1-D tensor holding one score per token position.
        special_mask: 1-D bool tensor, True where the position holds a special
            token. Special tokens are never pruned.
        quantile: Fraction in [0, 1]. Content tokens whose score is not
            strictly above this quantile of the content scores are dropped.

    Returns:
        1-D bool tensor with the same length as ``importance_scores``. When no
        content token would survive (for example when every score is equal),
        every position is kept, i.e. the input is left unpruned.
    """
    content_mask = ~special_mask
    # torch.quantile only accepts float32/float64 input, so cast in case the
    # model weights are stored in reduced precision.
    scores = importance_scores.float()
    content_scores = scores[content_mask]

    if content_scores.numel() == 0:
        # Nothing but special tokens (e.g. an empty article): keep everything.
        return torch.ones_like(special_mask)

    threshold = torch.quantile(content_scores, quantile)
    kept_content = (scores > threshold) & content_mask

    if not kept_content.any():
        # Fall back to the original tokens rather than feeding an empty input.
        return torch.ones_like(special_mask)

    return kept_content | special_mask


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").to(device)

ds = load_dataset("abisee/cnn_dailymail", "3.0.0")
test_dataset = ds['test']

# Experiment settings.
subset_size = 100          # number of test articles to evaluate
sample_seed = 42           # fixed seed so the same subset is drawn on every run
max_input_tokens = 1024    # articles are truncated to this many tokens
prune_quantile = 0.25      # fraction of lowest-norm content tokens to drop

# A dedicated Random instance keeps the sampling reproducible without touching
# the global random state; the min() guards against a subset larger than the
# dataset, which would make sample() raise ValueError.
random_indices = random.Random(sample_seed).sample(
    range(len(test_dataset)),
    min(subset_size, len(test_dataset)),
)
test_dataset = test_dataset.select(random_indices)
rouge = evaluate.load('rouge')
bleu = evaluate.load('bleu')
total_flops = 0
total_time = 0
outputs = []
targets = []

for example in tqdm(test_dataset):
    article = example['article']
    reference_summary = example['highlights']

    # Calling the tokenizer directly (instead of tokenize() followed by
    # convert_tokens_to_ids()) adds the model's special tokens and also
    # returns a mask marking them, so they can be excluded from pruning.
    encoded = tokenizer(
        article,
        max_length=max_input_tokens,
        truncation=True,
        return_special_tokens_mask=True,
        return_tensors='pt',
    )
    token_ids_tensor = encoded['input_ids'][0].to(device)
    special_mask = encoded['special_tokens_mask'][0].bool().to(device)

    # Score every token by the L2 norm of its row in the model's shared
    # embedding layer.
    with torch.no_grad():
        token_embeddings = model.model.shared(token_ids_tensor)
        importance_scores = torch.norm(token_embeddings, dim=1)

    mask = _select_tokens_to_keep(importance_scores, special_mask, prune_quantile)
    pruned_token_ids = token_ids_tensor[mask]

    # Batch of one; every remaining token is attended to.
    inputs = {
        'input_ids': pruned_token_ids.unsqueeze(0),
        'attention_mask': torch.ones_like(pruned_token_ids).unsqueeze(0),
    }

    # FLOPs of calling the model on the pruned input ids. Gradients are not
    # needed for counting, so avoid building an autograd graph while tracing.
    with torch.no_grad():
        flops_analysis = FlopCountAnalysis(model, inputs['input_ids'])
        total_flops += flops_analysis.total()

    # CUDA kernels run asynchronously; synchronize on both sides of the timed
    # region so that only the generate() call is measured.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    summary_ids = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
    )

    if device.type == 'cuda':
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    generated_summary = tokenizer.decode(
        summary_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    total_time += (end_time - start_time)
    outputs.append(generated_summary)
    targets.append(reference_summary)

average_flops = total_flops / len(test_dataset)
average_time = total_time / len(test_dataset)

rouge_results = rouge.compute(predictions=outputs, references=targets)
bleu_results = bleu.compute(predictions=outputs, references=targets)

# NOTE: the timed region above is the whole model.generate() call, while the
# FLOP figure comes from FlopCountAnalysis on the model with input_ids only;
# the two numbers therefore do not describe the same unit of work.
print("Inference Results:")
print(f"Average FLOPs: {average_flops:.2e}")
print(f"Average forward pass time: {average_time:.4f} seconds")
print("ROUGE Scores:", rouge_results)
print("BLEU Score:", bleu_results)
