In [None]:
!pip install transformers datasets nltk rouge-score

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import BartTokenizer, BartForConditionalGeneration, get_linear_schedule_with_warmup
from tqdm import tqdm
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk.tokenize import sent_tokenize

# === Load your DistilBERT extractive model ===
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Use the GPU when PyTorch can see one; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Identifier handed to from_pretrained for both the classifier and its tokenizer.
distilbert_model_path = "emmabry/disilbert-eurlex"
distilbert_model = DistilBertForSequenceClassification.from_pretrained(
    distilbert_model_path
)
distilbert_model = distilbert_model.to(DEVICE)
distilbert_tokenizer = DistilBertTokenizer.from_pretrained(distilbert_model_path)

# The classifier is only used for prediction here, so put it in evaluation mode.
distilbert_model.eval()

def distilbert_extract(text, max_tokens=1024, batch_size=32):
    """Return the sentences of *text* that the DistilBERT classifier labels 1.

    The text is split with NLTK's ``sent_tokenize``; each sentence is
    classified by ``distilbert_model`` and the ones predicted as class 1 are
    kept, in document order, until the token budget is exhausted.

    Args:
        text: Raw document text.
        max_tokens: Upper bound on the summed token count of the selected
            sentences. Counts come from ``distilbert_tokenizer`` (without
            truncation), not from the tokenizer of any downstream model.
        batch_size: Number of sentences classified per forward pass. Larger
            values mean fewer model calls at the cost of more memory.

    Returns:
        The selected sentences joined by single spaces; an empty string when
        no sentence is selected.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    sentences = sent_tokenize(text)
    selected_sents = []
    current_tokens = 0

    # Classify sentences in padded mini-batches instead of one forward pass
    # per sentence; the attention mask keeps padding out of the prediction.
    for start in range(0, len(sentences), batch_size):
        chunk = sentences[start:start + batch_size]
        inputs = distilbert_tokenizer(
            chunk, truncation=True, padding=True, return_tensors="pt"
        ).to(DEVICE)
        with torch.no_grad():
            logits = distilbert_model(**inputs).logits
        preds = torch.argmax(logits, dim=1).tolist()

        for sent, pred in zip(chunk, preds):
            if pred != 1:
                continue
            # Budget is measured on the untruncated sentence.
            tokens = len(distilbert_tokenizer(sent)["input_ids"])
            if current_tokens + tokens > max_tokens:
                # Stop at the first selected sentence that does not fit;
                # later sentences are not considered.
                return " ".join(selected_sents)
            selected_sents.append(sent)
            current_tokens += tokens

    return " ".join(selected_sents)

# === Custom Dataset using distilbert_extract ===
class SummarizationDataset(Dataset):
    """Line-aligned source/target pairs prepared for seq2seq fine-tuning.

    Line *i* of ``source_file`` is paired with line *i* of ``target_file``.
    Each source is first reduced with ``distilbert_extract`` and then both
    sides are tokenized with the supplied tokenizer.

    Args:
        source_file: Path to a UTF-8 text file with one source document per line.
        target_file: Path to a UTF-8 text file with one target summary per line.
        tokenizer: Callable tokenizer exposing ``pad_token_id``; it is called
            with ``max_length``, ``padding``, ``truncation`` and
            ``return_tensors`` keyword arguments.
        max_source_len: Padded/truncated length of the encoded source; also
            used as the token budget passed to ``distilbert_extract``.
        max_target_len: Padded/truncated length of the encoded target.
    """

    def __init__(self, source_file, target_file, tokenizer, max_source_len=1024, max_target_len=128):
        import warnings

        with open(source_file, encoding='utf-8') as f_src:
            sources = [line.strip() for line in f_src]
        with open(target_file, encoding='utf-8') as f_tgt:
            targets = [line.strip() for line in f_tgt]

        # Pairs are matched purely by line number, so a length mismatch means
        # some lines have no partner. Keep the aligned prefix (as zip() would)
        # but make the problem visible instead of dropping lines silently.
        n_pairs = min(len(sources), len(targets))
        if len(sources) != len(targets):
            warnings.warn(
                f"{source_file} has {len(sources)} lines but {target_file} has "
                f"{len(targets)} lines; keeping only the first {n_pairs} pairs.",
                stacklevel=2,
            )

        self.sources = sources[:n_pairs]
        self.targets = targets[:n_pairs]
        self.tokenizer = tokenizer
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len
        # idx -> extracted text. Extraction runs a classifier over every
        # sentence, so do it once per example rather than once per epoch.
        self._extracted_cache = {}

    def __len__(self):
        """Return the number of aligned source/target pairs."""
        return len(self.sources)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for pair *idx*."""
        target = self.targets[idx]

        # Apply extractive summarisation (memoised per index).
        extracted = self._extracted_cache.get(idx)
        if extracted is None:
            extracted = distilbert_extract(
                self.sources[idx], max_tokens=self.max_source_len
            )
            self._extracted_cache[idx] = extracted

        # Tokenize
        source_enc = self.tokenizer(
            extracted,
            max_length=self.max_source_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        target_enc = self.tokenizer(
            target,
            max_length=self.max_target_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Drop only the leading batch dimension added by return_tensors='pt';
        # a bare squeeze() would also collapse a length-1 sequence dimension.
        labels = target_enc['input_ids'].squeeze(0)
        labels[labels == self.tokenizer.pad_token_id] = -100  # ignore pad tokens in loss

        return {
            'input_ids': source_enc['input_ids'].squeeze(0),
            'attention_mask': source_enc['attention_mask'].squeeze(0),
            'labels': labels
        }

# === Model, Tokenizer and Hyperparameters ===
MODEL_NAME = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)
model = model.to(DEVICE)

# Training hyperparameters.
BATCH_SIZE, EPOCHS, LR = 2, 3, 2e-5

# === Load dataset ===
train_ds = SummarizationDataset(
    source_file='train.source',
    target_file='train.target',
    tokenizer=tokenizer,
)
train_dl = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)

# === Optimizer & Scheduler ===
optimizer = AdamW(params=model.parameters(), lr=LR)
# The scheduler is stepped once per batch, so the schedule has to span every
# batch of every epoch; the first 10% of those steps are warmup.
total_steps = EPOCHS * len(train_dl)
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps,
)

# === Training Loop ===
model.train()  # training mode for fine-tuning
for epoch in range(EPOCHS):
    pbar = tqdm(train_dl, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batch in pbar:
        # Move the three tensors the model consumes onto the training device.
        input_ids, attention_mask, labels = (
            batch[key].to(DEVICE)
            for key in ("input_ids", "attention_mask", "labels")
        )

        # Clear gradients left over from the previous step before the new
        # backward pass accumulates into them.
        optimizer.zero_grad()

        # Passing labels makes the model compute the training loss itself.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss.backward()

        # Apply the update, then advance the per-batch learning-rate schedule.
        optimizer.step()
        scheduler.step()

        pbar.set_postfix(loss=loss.item())
