# Imports

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import random
import torch

# Load Data

In [56]:
# Training split of the "pubmed" configuration of "scientific_papers".
train_dataset = load_dataset(
    "scientific_papers",
    "pubmed",
    split="train",
)

Found cached dataset scientific_papers (C:/Users/ronna/.cache/huggingface/datasets/scientific_papers/pubmed/1.1.1/306757013fb6f37089b6a75469e6638a553bd9f009484938d8f75a4c5e84206f)


In [57]:
# Validation split of the "pubmed" configuration of "scientific_papers".
val_dataset = load_dataset(
    "scientific_papers",
    "pubmed",
    split="validation",
)

Found cached dataset scientific_papers (C:/Users/ronna/.cache/huggingface/datasets/scientific_papers/pubmed/1.1.1/306757013fb6f37089b6a75469e6638a553bd9f009484938d8f75a4c5e84206f)


In [58]:
# Test split of the "pubmed" configuration of "scientific_papers".
test_dataset = load_dataset(
    "scientific_papers",
    "pubmed",
    split="test",
)

Found cached dataset scientific_papers (C:/Users/ronna/.cache/huggingface/datasets/scientific_papers/pubmed/1.1.1/306757013fb6f37089b6a75469e6638a553bd9f009484938d8f75a4c5e84206f)


# Preprocess Data

## Load Tokenizer

In [59]:
# Load the tokenizer published under the "allenai/led-base-16384" checkpoint
# name; it is used below to encode both the articles and the abstracts.
tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")

## Set Params

In [60]:
# Length (in tokens) every article is padded/truncated to by the tokenizer.
max_input_length = 8192
# Length (in tokens) every abstract (the labels) is padded/truncated to.
max_output_length = 512
# Number of examples passed to the preprocessing function per `map` batch.
batch_size = 2

In [53]:
def process_data_to_model_inputs(batch):
    """Tokenize a batch of articles and abstracts into model-ready columns.

    Relies on the module-level ``tokenizer``, ``max_input_length`` and
    ``max_output_length``.

    Args:
        batch: Mapping with an ``"article"`` and an ``"abstract"`` entry, each
            a list of strings of the same length (one item per example).

    Returns:
        The same ``batch`` object with these entries added:

        * ``input_ids`` / ``attention_mask``: tokenizer output for the
          articles, padded/truncated to ``max_input_length``.
        * ``global_attention_mask``: one list per example, ``1`` at position 0
          and ``0`` everywhere else.
        * ``labels``: token ids of the abstracts, padded/truncated to
          ``max_output_length``, with every pad token id replaced by ``-100``.
    """
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=max_input_length)
    outputs = tokenizer(batch["abstract"], padding="max_length", truncation=True, max_length=max_output_length)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # Build an independent mask row per example. Multiplying a nested list
    # (`n * [[0, ...]]`) would make every row the *same* list object, and
    # indexing row 0 to size the mask would fail on an empty batch.
    batch["global_attention_mask"] = [
        [1 if position == 0 else 0 for position in range(len(ids))]
        for ids in batch["input_ids"]
    ]

    # Replace padding in the targets with -100.
    # NOTE(review): -100 is presumably the ignore index of the training loss
    # - confirm against the model documentation.
    pad_id = tokenizer.pad_token_id
    batch["labels"] = [
        [-100 if token == pad_id else token for token in label_ids]
        for label_ids in outputs.input_ids
    ]

    return batch

## Downsample

In [63]:
# Split each dataset into `num_shards` shards and keep a single one, so the
# rest of the notebook works on a small subset.
num_shards = 100

# Use a dedicated, seeded generator so the same shards are selected on every
# Restart & Run All, without touching the global `random` state.
shard_seed = 42
shard_rng = random.Random(shard_seed)

# randrange(num_shards) draws from 0 .. num_shards - 1, the valid shard indices.
sub_train_dataset = train_dataset.shard(num_shards=num_shards, index=shard_rng.randrange(num_shards))
sub_val_dataset = val_dataset.shard(num_shards=num_shards, index=shard_rng.randrange(num_shards))

## Tokenize and Convert to Torch

In [64]:
# Run the preprocessing function over both subsets in small batches and drop
# the raw text columns once they have been tokenized.
raw_text_columns = ["article", "abstract", "section_names"]
map_options = dict(batched=True, batch_size=batch_size, remove_columns=raw_text_columns)

sub_train_dataset = sub_train_dataset.map(process_data_to_model_inputs, **map_options)
sub_val_dataset = sub_val_dataset.map(process_data_to_model_inputs, **map_options)

                                                                                                                       

In [66]:
# Expose the tokenized columns of both subsets as torch tensors.
tensor_columns = ["input_ids", "attention_mask", "global_attention_mask", "labels"]
for tokenized_split in (sub_train_dataset, sub_val_dataset):
    tokenized_split.set_format(type="torch", columns=tensor_columns)

## Test Model 1

In [67]:
# Load the seq2seq model from the same checkpoint name as the tokenizer,
# passing the gradient-checkpointing and cache options through unchanged.
led_checkpoint = "allenai/led-base-16384"
led = AutoModelForSeq2SeqLM.from_pretrained(
    led_checkpoint,
    gradient_checkpointing=True,
    use_cache=False,
)

In [68]:
# Generation settings stored on the model config.
# NOTE(review): presumably picked up as defaults by `led.generate()` - confirm
# for the installed transformers version.
generation_settings = {
    "num_beams": 2,
    "max_length": 512,
    "min_length": 100,
    "length_penalty": 2.0,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
}
for setting_name, setting_value in generation_settings.items():
    setattr(led.config, setting_name, setting_value)

In [None]:
# Pick one random example from the tokenized validation subset.
random_index = random.randint(0, len(sub_val_dataset) - 1)
sample = sub_val_dataset[random_index]

# Add a leading batch dimension and move each input to the model's device.
input_ids = sample["input_ids"].unsqueeze(0).to(led.device)
attention_mask = sample["attention_mask"].unsqueeze(0).to(led.device)
global_attention_mask = sample["global_attention_mask"].unsqueeze(0).to(led.device)

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    summary_ids = led.generate(input_ids=input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)

# A single example was passed in, so take the first (only) generated sequence.
generated_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# The preprocessing step replaced pad tokens in the labels with -100, which is
# not a valid token id and cannot be decoded; drop those positions first.
label_ids = [token for token in sample["labels"].tolist() if token != -100]
actual_summary = tokenizer.decode(label_ids, skip_special_tokens=True)

# Print and compare both summaries
print("Generated Summary:")
print(generated_summary)
print("\nActual Summary:")
print(actual_summary)

