# In [None]:
# =========================================================
# NewsSumm LED Fine-Tuning with AMD GPU (DirectML)
# Single-file script
# =========================================================

import pandas as pd
import torch
import torch_directml
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    LEDForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
import evaluate
from bert_score import score

# ---------------------------------------------------------
# 1. SET DEVICE (AMD GPU via DirectML)
# ---------------------------------------------------------
# Ask torch-directml for a device handle. The model is moved onto this device
# further down in the script (`model.to(device)`).
# NOTE(review): with no index argument this presumably selects the default
# DirectML adapter — confirm on machines with more than one GPU.
device = torch_directml.device()
print("Using device:", device)

# ---------------------------------------------------------
# 2. LOAD CLEANED NEWS SUMM DATASET
# ---------------------------------------------------------
# Load the cleaned NewsSumm spreadsheet and keep only the two columns the
# model needs: the source article and its human-written summary.
df = pd.read_excel("newssumm_cleaned.xlsx")
df = df[["article_text", "human_summary"]]

# Rows with a missing article or summary would reach the tokenizer as None
# and abort the tokenization step, so drop them before sampling.
df = df.dropna(subset=["article_text", "human_summary"])

# VERY IMPORTANT: start small (you can increase later).
# The sample size is capped at the number of available rows because
# DataFrame.sample raises ValueError when asked for more rows than exist
# (sampling is without replacement).
df = df.sample(n=min(2000, len(df)), random_state=42)

# Convert to a HuggingFace Dataset. preserve_index=False stops the shuffled
# pandas index from being stored as an extra "__index_level_0__" column.
dataset = Dataset.from_pandas(df, preserve_index=False)

# Train / validation split (90% / 10%), seeded so the split is reproducible.
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_data = dataset["train"]
val_data = dataset["test"]

print("Train samples:", len(train_data))
print("Validation samples:", len(val_data))

# ---------------------------------------------------------
# 3. LOAD MODEL & TOKENIZER
# ---------------------------------------------------------
# Checkpoint identifier shared by the tokenizer and the model.
model_name = "allenai/led-base-16384"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the pretrained LED weights, then move the module onto the selected
# device (Module.to relocates the parameters in place).
model = LEDForConditionalGeneration.from_pretrained(model_name)
model.to(device)

# The original script treats this as a required LED setting: disable the
# `use_cache` option on the model config.
model.config.use_cache = False

# ---------------------------------------------------------
# 4. PREPROCESS FUNCTION (NEW API – NO ERROR)
# ---------------------------------------------------------
def preprocess(batch):
    """Tokenize a batch of articles and summaries for LED fine-tuning.

    Args:
        batch: Mapping with "article_text" and "human_summary" entries, each
            a list of strings (the batched form handed over by
            ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer output for the articles (truncated/padded to 1024
        tokens) with an added "labels" entry holding the summary token ids
        (truncated/padded to 256 tokens). Padding positions in the labels
        are set to -100.
    """
    model_inputs = tokenizer(
        batch["article_text"],
        truncation=True,
        padding="max_length",
        max_length=1024
    )

    labels = tokenizer(
        text_target=batch["human_summary"],
        truncation=True,
        padding="max_length",
        max_length=256
    )

    # The summaries are padded with the tokenizer's pad id. Left as-is, those
    # positions are scored by the loss, so the model is trained to emit
    # padding. -100 is the label value the loss ignores. The data collator
    # cannot do this for us: the sequences are already fixed-length, so it
    # never adds its own -100 padding.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in sequence]
        for sequence in labels["input_ids"]
    ]

    # NOTE(review): the LED documentation recommends global attention on the
    # first token for summarization; no global_attention_mask is set here —
    # consider adding one once it is confirmed to work on this backend.
    return model_inputs

# Tokenize
# Tokenize both splits the same way: run `preprocess` in batched mode and
# drop the raw text columns so only the tokenizer outputs remain.
train_tokenized, val_tokenized = (
    split.map(
        preprocess,
        batched=True,
        remove_columns=split.column_names,
    )
    for split in (train_data, val_data)
)

# ---------------------------------------------------------
# 5. TRAINING ARGUMENTS (DIRECTML SAFE)
# ---------------------------------------------------------
training_args = Seq2SeqTrainingArguments(
    # Where checkpoints are written, and no external experiment tracker.
    output_dir="./newssumm_led_results",
    report_to="none",
    # Optimisation: a single pass over the data, one example per step.
    num_train_epochs=1,
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # Mixed precision stays off; the original author notes that DirectML
    # does not support fp16.
    fp16=False,
    # Evaluate and checkpoint every 500 steps, log every 100.
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=100,
)

# ---------------------------------------------------------
# 6. TRAINER
# ---------------------------------------------------------
# Collator that assembles tokenized examples into seq2seq batches.
data_collator = DataCollatorForSeq2Seq(model=model, tokenizer=tokenizer)

# Wire the model, arguments, data and collator into the trainer.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
)

# ---------------------------------------------------------
# 7. TRAIN MODEL (THIS IS FINE-TUNING)
# ---------------------------------------------------------
# NOTE(review): the Trainer places the model on `training_args.device`, which
# it picks itself — confirm that this is the DirectML device and not the CPU.
trainer.train()

# ---------------------------------------------------------
# 8. GENERATE SUMMARY (EVALUATION)
# ---------------------------------------------------------
# Put the model in evaluation mode before generating.
model.eval()

# Use the first validation example. Only the first 1500 characters of the
# article are summarized; the reference is the full human summary.
sample_article = val_data[0]["article_text"][:1500]
reference_summary = val_data[0]["human_summary"]

inputs = tokenizer(
    sample_article,
    return_tensors="pt",
    truncation=True,
    max_length=1024
)

# Send the inputs to wherever the model's weights actually are. The Trainer
# moves the model to the device it selects, which is not necessarily the
# DirectML `device` chosen at the top of the script. Hard-coding `device`
# here can put inputs and weights on different devices and make generate()
# fail.
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Beam search with 4 beams, capped at 200 generated tokens. The attention
# mask is passed explicitly rather than left for generate() to infer.
summary_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=200,
    num_beams=4
)

generated_summary = tokenizer.decode(
    summary_ids[0],
    skip_special_tokens=True
)

print("\nGENERATED SUMMARY:\n", generated_summary)

# ---------------------------------------------------------
# 9. ROUGE SCORE
# ---------------------------------------------------------
# Load the ROUGE metric and score the single generated summary against its
# human-written reference.
rouge = evaluate.load("rouge")
rouge_scores = rouge.compute(
    references=[reference_summary],
    predictions=[generated_summary],
)
print("\nROUGE SCORES:", rouge_scores)

# ---------------------------------------------------------
# 10. BERTScore (WHAT THEY WANT)
# ---------------------------------------------------------
# Score the generated summary against the reference with BERTScore. The call
# returns three values, unpacked as P, R and F1; only the mean of F1 is
# reported.
P, R, F1 = score(
    [generated_summary],
    [reference_summary],
    model_type="microsoft/deberta-xlarge-mnli",
    lang="en",
)
print("\nBERTScore F1:", F1.mean().item())
