In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from tqdm import tqdm

# --- Hyperparameters -----------------------------------------------------
# The reference paper fine-tuned both BERT and DistilBERT; this run uses
# the uncased DistilBERT checkpoint.
MODEL_NAME: str = "distilbert-base-uncased"
BATCH_SIZE: int = 16  # (sentence, label) samples per DataLoader batch
EPOCHS: int = 3  # full passes over the training set
LR: float = 2e-5  # peak AdamW learning rate, reached at the end of warmup

# Dataset: flatten documents into independent (sentence, label) samples
class ExtractiveDataset(Dataset):
    """Sentence-level dataset for extractive-summarization fine-tuning.

    Every sentence of every document becomes one independent
    ``(sentence, label)`` sample, so a sequence classifier scores each
    sentence on its own.

    File formats (as parsed by the helpers below):
      * ``source_file``: one sentence per line. Documents are separated by
        one or more blank lines.
      * ``labels_file``: one line per document, holding whitespace-separated
        integer labels (one per sentence of that document, in order). Blank
        lines are ignored, mirroring the fact that blank lines never form a
        document in ``source_file``.

    Args:
        source_file: Path to the sentence file (read as UTF-8).
        labels_file: Path to the per-document label file (read as UTF-8).
        tokenizer: Hugging Face-style tokenizer callable. It is invoked with
            ``truncation``, ``max_length``, ``padding`` and ``return_tensors``.
        max_len: Token length that every sample is truncated/padded to.

    Raises:
        ValueError: If the files disagree on the number of documents, or on
            the number of sentences in any single document. Silently pairing
            the overlap would misalign sentences with their labels.
    """

    def __init__(self, source_file, labels_file, tokenizer, max_len=256):
        # Explicit UTF-8 so parsing does not depend on the platform locale.
        with open(source_file, encoding="utf-8") as f_src:
            docs = self._split_docs(f_src)
        with open(labels_file, encoding="utf-8") as f_lbl:
            labels = self._split_labels(f_lbl)

        if len(docs) != len(labels):
            raise ValueError(
                f"{source_file} has {len(docs)} documents but "
                f"{labels_file} has {len(labels)} label lines"
            )

        self.samples = []
        for doc_idx, (doc, lbl) in enumerate(zip(docs, labels)):
            if len(doc) != len(lbl):
                raise ValueError(
                    f"document {doc_idx}: {len(doc)} sentences but "
                    f"{len(lbl)} labels"
                )
            self.samples.extend(zip(doc, lbl))
        self.tokenizer = tokenizer
        self.max_len = max_len

    def _split_docs(self, f):
        """Group stripped, non-blank lines into documents.

        A blank line ends the current document. Runs of consecutive blank
        lines never produce an empty document.
        """
        docs, doc = [], []
        for raw in f:
            line = raw.strip()
            if line:
                doc.append(line)
            elif doc:
                docs.append(doc)
                doc = []
        if doc:
            docs.append(doc)
        return docs

    def _split_labels(self, f):
        """Parse each non-blank line into a list of int labels.

        Blank lines are skipped so they cannot shift later documents onto
        the wrong label line.
        """
        return [list(map(int, line.split())) for line in f if line.strip()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize one sentence and return model-ready tensors.

        Returns a dict with ``input_ids`` and ``attention_mask`` (each of
        length ``max_len``) and a scalar ``label`` of dtype ``torch.long``.
        """
        sent, label = self.samples[idx]
        enc = self.tokenizer(
            sent,
            truncation=True,
            max_length=self.max_len,
            padding="max_length",
            return_tensors="pt"
        )
        # For a single string, return_tensors="pt" adds a leading batch dim of
        # size 1. Drop only that dim: a bare squeeze() would also collapse the
        # sequence dim when max_len == 1.
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)
        }

# Pretrained DistilBERT with a freshly initialised 2-way classification head
# (presumably "keep sentence" vs "drop sentence" -- confirm against labels).
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model = DistilBertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
)

# Sentence-level training data, reshuffled on every pass.
train_ds = ExtractiveDataset(
    "./textrank_train.source",
    "./textrank_oracle_labels.txt",
    tokenizer,
)
train_dl = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE)

# AdamW with a linear schedule: warm up over the first 10% of all batches,
# then decay linearly to zero by the final batch.
optimizer = AdamW(model.parameters(), lr=LR)
total_steps = EPOCHS * len(train_dl)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_training_steps=total_steps,
    num_warmup_steps=int(total_steps * 0.1),
)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

# Clip the global gradient norm before every optimizer step, as the reference
# BERT fine-tuning code (clip_by_global_norm 1.0) and Hugging Face's Trainer
# (max_grad_norm=1.0) do. This keeps an occasional gradient spike from
# derailing fine-tuning.
max_grad_norm = 1.0

for epoch in range(EPOCHS):
    pbar = tqdm(train_dl, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batch in pbar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        # Passing labels makes the model return its classification loss.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # the LR schedule advances once per batch

        pbar.set_postfix({"loss": loss.item()})