# GRU with attention: training and inference

In [None]:
# Mount Google Drive at /content/drive; the data and checkpoint paths used
# later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## First attempt - LSTM (no attention)

In [None]:
"""
Train a Seq2Seq LSTM on the XSum dataset with early stopping, saving of best and final models, and tokenizer.
"""
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer
from tqdm import tqdm

# ─────────────────────────────────────────────
# CONFIGURATION
# ─────────────────────────────────────────────
MAX_INPUT_LEN = 512   # max_length passed to the tokenizer for source documents
MAX_TARGET_LEN = 64   # max_length passed to the tokenizer for target summaries
EMBEDDING_DIM = 256   # width of the token embedding
HIDDEN_DIM = 512      # hidden size of the encoder and decoder LSTMs
BATCH_SIZE = 64       # batch size for both DataLoaders
EPOCHS = 5            # upper bound on epochs; early stopping may end training sooner
LEARNING_RATE = 1e-3  # Adam learning rate
NUM_WORKERS = 2       # DataLoader worker processes
PATIENCE = 2    # epochs with no improvement before stopping
DELTA = 0.0     # minimum change to qualify as improvement

# Specify your data directory
DATA_DIR = "/content/drive/MyDrive/xsum"
# Checkpoints and the tokenizer are written next to the data.
SAVE_DIR = DATA_DIR
os.makedirs(SAVE_DIR, exist_ok=True)

# DEVICE
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ─────────────────────────────────────────────
# CUSTOM DATASET CLASS
# ─────────────────────────────────────────────
class XSumDataset(Dataset):
    """Dataset of (document, summary) pairs tokenized on access.

    Args:
        df: table with a "document" column and a "summary" column.
        tokenizer: callable tokenizer; it is invoked with ``max_length``,
            ``padding="max_length"``, ``truncation=True`` and
            ``return_tensors="pt"`` and must return a mapping that has an
            "input_ids" entry.
    """

    def __init__(self, df, tokenizer):
        self.inputs = df["document"].tolist()
        self.targets = df["summary"].tolist()
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.inputs)

    def _encode(self, text, max_length):
        """Tokenize one string and drop the leading batch axis of its ids."""
        encoded = self.tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return encoded["input_ids"].squeeze(0)

    def __getitem__(self, idx):
        """Return ``{"input_ids": ..., "target_ids": ...}`` for row ``idx``."""
        # The document is encoded before the summary, as dict items are
        # evaluated in order.
        return {
            "input_ids": self._encode(self.inputs[idx], MAX_INPUT_LEN),
            "target_ids": self._encode(self.targets[idx], MAX_TARGET_LEN),
        }

# ─────────────────────────────────────────────
# SEQ2SEQ LSTM MODEL DEFINITION
# ─────────────────────────────────────────────
class Seq2SeqLSTM(nn.Module):
    """Plain encoder-decoder LSTM without attention.

    One embedding table is shared by the encoder and the decoder. The
    encoder's final (hidden, cell) state initialises the decoder, which is
    run over the whole target sequence in a single call (teacher forcing).

    Args:
        vocab_size: number of rows in the embedding and of output logits.
        embedding_dim: embedding width.
        hidden_dim: hidden size of both LSTMs.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        # Sub-modules are created in the same order as before so that seeded
        # parameter initialisation is unchanged.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim
        )
        self.encoder = nn.LSTM(
            input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True
        )
        self.decoder = nn.LSTM(
            input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True
        )
        self.fc_out = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, src, trg):
        """Return logits of shape [batch_size, trg_len, vocab_size].

        Args:
            src: source token ids, [batch_size, src_len].
            trg: decoder input token ids, [batch_size, trg_len].
        """
        _, encoder_state = self.encoder(self.embedding(src))
        decoded, _ = self.decoder(self.embedding(trg), encoder_state)
        return self.fc_out(decoded)

# ─────────────────────────────────────────────
# LOAD DATA AND TOKENIZER
# ─────────────────────────────────────────────
# Load the pretrained "t5-small" tokenizer; the model below is trained from
# scratch and only borrows this vocabulary.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
# Rows with any missing value are dropped and the index is renumbered.
train_df = pd.read_csv(os.path.join(DATA_DIR, "xsum_train.csv")).dropna().reset_index(drop=True)
val_df   = pd.read_csv(os.path.join(DATA_DIR, "xsum_val.csv")).dropna().reset_index(drop=True)

# Optionally subsample for faster iteration
#train_df = train_df.sample(frac=0.5, random_state=42).reset_index(drop=True)
#val_df   = val_df.sample(frac=0.5, random_state=42).reset_index(drop=True)

train_dataset = XSumDataset(train_df, tokenizer)
val_dataset   = XSumDataset(val_df, tokenizer)

# Only the training loader shuffles; validation keeps file order.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True
)

# ─────────────────────────────────────────────
# INITIALIZE MODEL, OPTIMIZER, CRITERION
# ─────────────────────────────────────────────
vocab_size = tokenizer.vocab_size
model = Seq2SeqLSTM(vocab_size, EMBEDDING_DIM, HIDDEN_DIM).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Padding positions in the target do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# ─────────────────────────────────────────────
# TRAINING LOOP WITH EARLY STOPPING
# ─────────────────────────────────────────────
best_val_loss = float('inf')
epochs_no_improve = 0


def _shift_right(target_ids, start_id):
    """Build decoder inputs: `target_ids` moved one step right behind `start_id`.

    The last target position is dropped so the result keeps the shape of
    `target_ids`.
    """
    start_column = torch.full_like(target_ids[:, :1], start_id)
    return torch.cat([start_column, target_ids[:, :-1]], dim=1)


def _batch_loss(batch):
    """Return the teacher-forced cross-entropy loss for one batch.

    The decoder input begins with the pad token, which is the same start
    token the decoding functions in this notebook feed first. Every target
    position, including the very first summary token, is therefore a
    prediction target. Feeding `trg[:, :-1]` and scoring `trg[:, 1:]`
    (the previous behaviour) never trained the first-token prediction.
    """
    src = batch["input_ids"].to(DEVICE)
    trg = batch["target_ids"].to(DEVICE)
    logits = model(src, _shift_right(trg, tokenizer.pad_token_id))
    return criterion(logits.reshape(-1, vocab_size), trg.reshape(-1))


# NOTE: checkpoints written before this change were trained without the start
# token and must be retrained to benefit from it.
for epoch in range(1, EPOCHS + 1):
    # Training
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
        optimizer.zero_grad()
        loss = _batch_loss(batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_train_loss = train_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            val_loss += _batch_loss(batch).item()

    avg_val_loss = val_loss / max(len(val_loader), 1)
    print(f"\tValidation Loss: {avg_val_loss:.4f}")

    # Early Stopping Check: an epoch counts as an improvement only when it
    # beats the best validation loss by more than DELTA.
    if avg_val_loss < best_val_loss - DELTA:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        best_path = os.path.join(SAVE_DIR, 'best_model.pt')
        torch.save(model.state_dict(), best_path)
        print(f"Validation loss improved; saved best model to {best_path}")
    else:
        epochs_no_improve += 1
        print(f"No improvement for {epochs_no_improve} epoch(s)")
        if epochs_no_improve >= PATIENCE:
            print(f"Early stopping triggered at epoch {epoch}")
            break

# Final Save: the weights as of the last completed epoch (not necessarily the
# best ones) plus the tokenizer files.
final_model_path = os.path.join(SAVE_DIR, "seq2seq_final.pt")
torch.save(model.state_dict(), final_model_path)
print(f"Saved final model state_dict to {final_model_path}")
tokenizer.save_pretrained(os.path.join(SAVE_DIR, 'tokenizer'))
print(f"Saved tokenizer to {os.path.join(SAVE_DIR, 'tokenizer')}")

Training Epoch 1: 100%|██████████| 1594/1594 [07:41<00:00,  3.45it/s]


Epoch 1 | Train Loss: 5.0312


Validation: 100%|██████████| 89/89 [00:16<00:00,  5.53it/s]


	Validation Loss: 4.3337
Validation loss improved; saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 2: 100%|██████████| 1594/1594 [07:10<00:00,  3.71it/s]


Epoch 2 | Train Loss: 4.0052


Validation: 100%|██████████| 89/89 [00:16<00:00,  5.32it/s]


	Validation Loss: 3.9501
Validation loss improved; saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 3: 100%|██████████| 1594/1594 [07:12<00:00,  3.69it/s]


Epoch 3 | Train Loss: 3.6155


Validation: 100%|██████████| 89/89 [00:16<00:00,  5.49it/s]


	Validation Loss: 3.7822
Validation loss improved; saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 4: 100%|██████████| 1594/1594 [07:11<00:00,  3.70it/s]


Epoch 4 | Train Loss: 3.3577


Validation: 100%|██████████| 89/89 [00:15<00:00,  5.64it/s]


	Validation Loss: 3.6851
Validation loss improved; saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 5: 100%|██████████| 1594/1594 [07:10<00:00,  3.70it/s]


Epoch 5 | Train Loss: 3.1578


Validation: 100%|██████████| 89/89 [00:16<00:00,  5.52it/s]


	Validation Loss: 3.6371
Validation loss improved; saved best model to /content/drive/MyDrive/xsum/best_model.pt
Saved final model state_dict to /content/drive/MyDrive/xsum/seq2seq_final.pt
Saved tokenizer to /content/drive/MyDrive/xsum/tokenizer


## Final TRAINING - GRU with attention

In [None]:
"""
Seq2Seq GRU with attention on the XSum dataset.
Includes dropout regularization.
"""
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer
from tqdm import tqdm
from torch.amp import autocast

# ─────────────────────────────────────────────
# CONFIGURATION
# ─────────────────────────────────────────────
MAX_INPUT_LEN = 512   # max_length passed to the tokenizer for source documents
MAX_TARGET_LEN = 64   # max_length passed to the tokenizer for target summaries
EMBEDDING_DIM = 256   # width of the encoder and decoder embeddings
HIDDEN_DIM = 512      # hidden size of the encoder and decoder GRUs
BATCH_SIZE = 64       # batch size for both DataLoaders
EPOCHS = 15           # upper bound on epochs; early stopping may end training sooner
LEARNING_RATE = 1e-3  # Adam learning rate
NUM_WORKERS = 2       # DataLoader worker processes
PATIENCE = 2          # epochs without validation improvement before stopping
DROPOUT = 0.1         # dropout probability read by EncoderRNN and AttnDecoderRNN

DATA_DIR = "/content/drive/MyDrive/xsum"
# Checkpoints and the tokenizer are written next to the data.
SAVE_DIR = DATA_DIR
os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ─────────────────────────────────────────────
# DATASET CLASS
# ─────────────────────────────────────────────
class XSumDataset(Dataset):
    """Dataset of (document, summary) pairs tokenized on access.

    Args:
        df: table with a "document" column and a "summary" column.
        tokenizer: callable tokenizer returning a mapping with "input_ids".
    """

    def __init__(self, df, tokenizer):
        self.inputs = df["document"].tolist()
        self.targets = df["summary"].tolist()
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the token ids of document ``idx`` and of its summary."""
        # Both texts share every tokenizer option except the length budget.
        shared_options = dict(padding="max_length", truncation=True, return_tensors="pt")
        document_enc = self.tokenizer(
            self.inputs[idx], max_length=MAX_INPUT_LEN, **shared_options
        )
        summary_enc = self.tokenizer(
            self.targets[idx], max_length=MAX_TARGET_LEN, **shared_options
        )
        # squeeze(0) removes the batch axis added by return_tensors="pt".
        item = {}
        item["input_ids"] = document_enc["input_ids"].squeeze(0)
        item["target_ids"] = summary_enc["input_ids"].squeeze(0)
        return item

# ─────────────────────────────────────────────
# ENCODER AND DECODER CLASSES WITH DROPOUT
# ─────────────────────────────────────────────
class EncoderRNN(nn.Module):
    """Single-layer GRU encoder over dropout-regularised token embeddings.

    Args:
        vocab_size: number of embedding rows.
        embedding_dim: embedding width.
        hidden_dim: GRU hidden size.
        pad_token_id: id passed to the embedding as ``padding_idx``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, pad_token_id):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=pad_token_id,
        )
        self.dropout = nn.Dropout(p=DROPOUT)
        self.gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, input):
        """Encode ``input`` token ids; return ``(outputs, hidden)`` from the GRU."""
        token_vectors = self.embedding(input)
        per_step_states, final_state = self.gru(self.dropout(token_vectors))
        return per_step_states, final_state

class AttnDecoderRNN(nn.Module):
    """Teacher-forced GRU decoder with dot-product attention over encoder outputs.

    Args:
        vocab_size: number of embedding rows and of output logits.
        embedding_dim: embedding width.
        hidden_dim: GRU hidden size; must equal the encoder's, because the
            attention scores are dot products between decoder and encoder states.
        pad_token_id: id passed to the embedding as ``padding_idx``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, pad_token_id):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_token_id)
        self.dropout = nn.Dropout(DROPOUT)
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True, num_layers=1)
        self.attn_combine = nn.Linear(hidden_dim * 2, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, encoder_outputs, target):
        """Return logits of shape (batch, target_len, vocab_size).

        Position 0 of the result is left at zero: step ``t`` consumes
        ``target[:, t - 1]`` and writes the prediction for ``target[:, t]``.

        Args:
            encoder_outputs: encoder states, (batch, src_len, hidden_dim).
            target: decoder token ids, (batch, target_len).

        NOTE(review): attention weights are computed over every encoder
        position, including those produced by padding tokens; no source mask
        is applied.
        """
        batch_size = target.size(0)
        num_steps = target.size(1)

        # The decoder starts from the encoder state at the last source
        # position, reshaped to the (num_layers, batch, hidden) layout the GRU
        # expects. contiguous() is needed because the slice+transpose view is
        # not laid out contiguously.
        hidden = encoder_outputs[:, -1:, :].transpose(0, 1).contiguous()
        outputs = torch.zeros(
            batch_size, num_steps, self.out.out_features, device=target.device
        )

        # Embed all teacher-forcing inputs once instead of once per step; the
        # previous code also embedded the full target up front but never used it.
        embedded = self.dropout(self.embedding(target[:, :-1]))

        for t in range(1, num_steps):
            input_emb = embedded[:, t - 1:t, :]                  # (batch, 1, emb)
            output, hidden = self.gru(input_emb, hidden)         # (batch, 1, hid)

            # Dot-product attention: score each source position against the
            # current decoder state, normalise, and average encoder states.
            attn_scores = torch.bmm(encoder_outputs, output.transpose(1, 2)).squeeze(2)
            attn_weights = torch.softmax(attn_scores, dim=1).unsqueeze(1)
            context = torch.bmm(attn_weights, encoder_outputs)   # (batch, 1, hid)

            combined = torch.cat([output, context], dim=2)
            combined = torch.tanh(self.attn_combine(combined))
            outputs[:, t] = self.out(self.dropout(combined.squeeze(1)))

        return outputs

class Seq2SeqGRUAttention(nn.Module):
    """Encoder-decoder wrapper pairing ``EncoderRNN`` with ``AttnDecoderRNN``.

    Both halves are built with the same four constructor arguments.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, pad_token_id):
        super().__init__()
        shared_args = (vocab_size, embedding_dim, hidden_dim, pad_token_id)
        self.encoder = EncoderRNN(*shared_args)
        self.decoder = AttnDecoderRNN(*shared_args)

    def forward(self, src, trg):
        """Encode ``src`` and return the decoder's output for ``trg``."""
        # Only the per-position encoder outputs are forwarded; the encoder's
        # final hidden state is discarded here.
        memory = self.encoder(src)[0]
        return self.decoder(memory, trg)

# ─────────────────────────────────────────────
# LOAD DATA AND TOKENIZER
# ─────────────────────────────────────────────
# Load the pretrained "t5-small" tokenizer; the model below is trained from
# scratch and only borrows this vocabulary.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
# Rows with any missing value are dropped and the index is renumbered.
train_df = pd.read_csv(os.path.join(DATA_DIR, "xsum_train.csv")).dropna().reset_index(drop=True)
val_df = pd.read_csv(os.path.join(DATA_DIR, "xsum_val.csv")).dropna().reset_index(drop=True)

train_dataset = XSumDataset(train_df, tokenizer)
val_dataset = XSumDataset(val_df, tokenizer)

# Only the training loader shuffles; validation keeps file order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# ─────────────────────────────────────────────
# TRAINING SETUP
# ─────────────────────────────────────────────
vocab_size = tokenizer.vocab_size
pad_token_id = tokenizer.pad_token_id
model = Seq2SeqGRUAttention(vocab_size, EMBEDDING_DIM, HIDDEN_DIM, pad_token_id).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Padding positions in the target do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

# Early-stopping state: best validation loss so far and the number of
# consecutive epochs that failed to beat it.
best_val_loss = float('inf')
epochs_no_improve = 0

# ─────────────────────────────────────────────
# TRAINING LOOP
# ─────────────────────────────────────────────
# float16 autocast on CUDA needs gradient scaling to avoid gradient underflow.
# A disabled scaler is a pass-through, so the same code path also serves CPU.
use_grad_scaling = DEVICE == "cuda"
if hasattr(torch.amp, "GradScaler"):
    scaler = torch.amp.GradScaler(DEVICE, enabled=use_grad_scaling)
else:  # older PyTorch releases only expose the CUDA-specific class
    scaler = torch.cuda.amp.GradScaler(enabled=use_grad_scaling)


def _seq2seq_loss(batch):
    """Return the teacher-forced cross-entropy loss for one batch.

    A pad token is prepended to the target as the decoder's start token. The
    decoder predicts position t from position t - 1, so with the start token
    in place every real target token, including the first summary token, is a
    prediction target. Without it the first token was only ever an input.
    """
    src = batch["input_ids"].to(DEVICE)
    trg = batch["target_ids"].to(DEVICE)
    start_column = torch.full_like(trg[:, :1], pad_token_id)
    decoder_tokens = torch.cat([start_column, trg], dim=1)
    with autocast(device_type=DEVICE):
        output = model(src, decoder_tokens)
        # output[:, 0] is the decoder's unused zero slot; positions 1.. line
        # up one-to-one with the original target.
        return criterion(output[:, 1:].reshape(-1, vocab_size), trg.reshape(-1))


# NOTE: checkpoints written before this change were trained without the start
# token and must be retrained to benefit from it.
# NOTE(review): 'best_model.pt' is the same file name the LSTM training cell
# writes into the same directory, so each run overwrites the other's best
# checkpoint.
for epoch in range(1, EPOCHS + 1):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
        optimizer.zero_grad()
        loss = _seq2seq_loss(batch)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_train_loss = train_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f}")

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            val_loss += _seq2seq_loss(batch).item()

    avg_val_loss = val_loss / max(len(val_loader), 1)
    print(f"\tValidation Loss: {avg_val_loss:.4f}")

    # Early stopping: keep the best checkpoint and stop after PATIENCE
    # consecutive epochs without a lower validation loss.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        best_path = os.path.join(SAVE_DIR, 'best_model.pt')
        torch.save(model.state_dict(), best_path)
        print(f"Validation improved. Saved best model to {best_path}")
    else:
        epochs_no_improve += 1
        print(f"No improvement for {epochs_no_improve} epoch(s)")
        if epochs_no_improve >= PATIENCE:
            print(f"Early stopping at epoch {epoch}")
            break

# ─────────────────────────────────────────────
# SAVE FINAL MODEL AND TOKENIZER
# ─────────────────────────────────────────────
# The state_dict and tokenizer are written first so that a failing TorchScript
# export cannot prevent them from being saved.
final_model_path = os.path.join(SAVE_DIR, "seq2seq_gru_attention_final.pt")
torch.save(model.state_dict(), final_model_path)
tokenizer.save_pretrained(os.path.join(SAVE_DIR, 'tokenizer'))
print(f"Saved final model to {final_model_path}")

# The scripted export is best-effort: report a failure instead of aborting.
scripted_path = os.path.join(SAVE_DIR, "seq2seq_gru_attention_scripted.pt")
try:
    torch.jit.save(torch.jit.script(model), scripted_path)
    print(f"Saved scripted model to {scripted_path}")
except Exception as exc:
    print(f"WARNING: TorchScript export failed ({exc}); state_dict and tokenizer were still saved")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Training Epoch 1: 100%|██████████| 3188/3188 [35:21<00:00,  1.50it/s]


Epoch 1 | Train Loss: 4.9856


Validation: 100%|██████████| 177/177 [00:36<00:00,  4.84it/s]


	Validation Loss: 4.2914
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 2: 100%|██████████| 3188/3188 [35:18<00:00,  1.50it/s]


Epoch 2 | Train Loss: 4.2019


Validation: 100%|██████████| 177/177 [00:35<00:00,  5.05it/s]


	Validation Loss: 3.9865
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 3: 100%|██████████| 3188/3188 [35:25<00:00,  1.50it/s]


Epoch 3 | Train Loss: 3.9741


Validation: 100%|██████████| 177/177 [00:36<00:00,  4.88it/s]


	Validation Loss: 3.8581
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 4: 100%|██████████| 3188/3188 [35:14<00:00,  1.51it/s]


Epoch 4 | Train Loss: 3.8494


Validation: 100%|██████████| 177/177 [00:36<00:00,  4.88it/s]


	Validation Loss: 3.7888
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 5: 100%|██████████| 3188/3188 [35:12<00:00,  1.51it/s]


Epoch 5 | Train Loss: 3.7671


Validation: 100%|██████████| 177/177 [00:36<00:00,  4.92it/s]


	Validation Loss: 3.7288
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 6: 100%|██████████| 3188/3188 [35:12<00:00,  1.51it/s]


Epoch 6 | Train Loss: 3.7063


Validation: 100%|██████████| 177/177 [00:36<00:00,  4.90it/s]


	Validation Loss: 3.7016
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 7: 100%|██████████| 3188/3188 [35:11<00:00,  1.51it/s]


Epoch 7 | Train Loss: 3.6613


Validation: 100%|██████████| 177/177 [00:36<00:00,  4.90it/s]


	Validation Loss: 3.6796
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 8: 100%|██████████| 3188/3188 [35:11<00:00,  1.51it/s]


Epoch 8 | Train Loss: 3.6250


Validation: 100%|██████████| 177/177 [00:36<00:00,  4.89it/s]


	Validation Loss: 3.6677
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 9: 100%|██████████| 3188/3188 [35:10<00:00,  1.51it/s]


Epoch 9 | Train Loss: 3.5963


Validation: 100%|██████████| 177/177 [00:36<00:00,  4.90it/s]


	Validation Loss: 3.6574
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 10: 100%|██████████| 3188/3188 [35:09<00:00,  1.51it/s]


Epoch 10 | Train Loss: 3.5723


Validation: 100%|██████████| 177/177 [00:35<00:00,  4.92it/s]


	Validation Loss: 3.6442
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 11: 100%|██████████| 3188/3188 [35:09<00:00,  1.51it/s]


Epoch 11 | Train Loss: 3.5529


Validation: 100%|██████████| 177/177 [00:34<00:00,  5.07it/s]


	Validation Loss: 3.6408
Validation improved. Saved best model to /content/drive/MyDrive/xsum/best_model.pt


Training Epoch 12:  72%|███████▏  | 2295/3188 [25:19<09:56,  1.50it/s]

## Inference check - 5 samples

### No attention model

In [None]:
import os
import torch
import torch.nn.functional as F
from transformers import T5Tokenizer

# 1) Synchronous CUDA debugging (the original note says to set this BEFORE
#    importing torch in notebooks).
# NOTE(review): in this cell the variable is set after `import torch` above.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# 2) Device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 3) Generation parameters
MAX_GEN_LEN = 64      # same value as MAX_TARGET_LEN used in training
K = 50                # default k for top-k sampling
P = 0.9               # default threshold for top-p (nucleus) sampling
BEAM_SIZE = 5         # default beam width for beam search

# 4) Reload tokenizer and model
DATA_DIR = "/content/drive/MyDrive/xsum"
tokenizer = T5Tokenizer.from_pretrained(os.path.join(DATA_DIR, "tokenizer"))

vocab_size = tokenizer.vocab_size

# Repeat the Seq2SeqLSTM class exactly as it was at training time
class Seq2SeqLSTM(torch.nn.Module):
    """Encoder-decoder LSTM without attention, rebuilt for checkpoint loading.

    Attribute names (``embedding``, ``encoder``, ``decoder``, ``fc_out``) must
    stay as they are so that a saved state_dict loads into this class.

    Args:
        vocab_size: number of embedding rows and of output logits.
        embedding_dim: embedding width (default 256).
        hidden_dim: hidden size of both LSTMs (default 512).
    """

    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=512):
        super().__init__()
        self.embedding = torch.nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim
        )
        self.encoder = torch.nn.LSTM(
            input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True
        )
        self.decoder = torch.nn.LSTM(
            input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True
        )
        self.fc_out = torch.nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, src, trg):
        """Return vocabulary logits for every position of ``trg``."""
        # The encoder's final (hidden, cell) pair seeds the decoder.
        _, encoder_state = self.encoder(self.embedding(src))
        decoded, _ = self.decoder(self.embedding(trg), encoder_state)
        return self.fc_out(decoded)

# Instantiate the model and load its weights
model = Seq2SeqLSTM(vocab_size).to(DEVICE)
# NOTE(review): the training cell in this notebook writes "seq2seq_final.pt";
# this file name differs, so the checkpoint was presumably renamed by hand —
# confirm it exists.
ckpt_path = os.path.join(DATA_DIR, "seq2seq_final_NOATT.pt")
state = torch.load(ckpt_path, map_location=DEVICE)
model.load_state_dict(state)
# Switch to evaluation mode before generating.
model.eval()

# 5) Sampling helpers
def assert_tensor_ok(tensor, name):
    """Assert that ``tensor`` contains neither NaN nor infinite values.

    Args:
        tensor: tensor to inspect.
        name: label used in the assertion message.

    Raises:
        AssertionError: with message "<name> contains NaN" or
            "<name> contains Inf"; NaN is checked first.
    """
    for detector, label in ((torch.isnan, "NaN"), (torch.isinf, "Inf")):
        assert not detector(tensor).any(), f"{name} contains {label}"

def top_k_sampling(logits, k):
    """Sample one token id from the ``k`` highest-scoring entries of ``logits``.

    Args:
        logits: 1D tensor of unnormalised scores, shape (vocab_size,).
        k: number of candidates to keep. Values larger than the vocabulary
            are clamped to the vocabulary size.

    Returns:
        Tensor of shape (1,) holding the sampled index into ``logits``.

    Raises:
        ValueError: if ``k`` is smaller than 1.
        AssertionError: if the logits or probabilities contain NaN/Inf.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    assert_tensor_ok(logits, "logits (top-k)")
    # torch.topk raises when k exceeds the size of the selected dimension.
    k = min(int(k), logits.size(-1))
    values, indices = torch.topk(logits, k)
    # Probabilities over the top-k candidates only.
    probs = F.softmax(values, dim=-1)
    assert_tensor_ok(probs, "probs (top-k)")
    assert probs.sum() > 0, "sum(probs_top_k)==0"
    choice = torch.multinomial(probs, 1)
    # Map the position within the top-k list back to a vocabulary index.
    return indices.gather(-1, choice)

def top_p_sampling(logits, p):
    """Sample one token id from the nucleus (top-p) of ``logits``.

    Tokens are ranked by probability; a token is dropped when the cumulative
    probability of the tokens ranked before it already exceeds ``p``. The
    top-ranked token is always kept.

    NOTE: a later definition of ``top_p_sampling`` in this file rebinds the
    name, so this version is only in effect until that definition runs.

    Args:
        logits: 1D tensor of unnormalised scores, shape (vocab_size,).
        p: cumulative-probability threshold.

    Returns:
        Tensor of shape (1,) holding the sampled index into ``logits``.

    Raises:
        AssertionError: if the logits or probabilities contain NaN/Inf.
    """
    assert_tensor_ok(logits, "logits (top-p)")
    probs = F.softmax(logits, dim=-1)
    assert_tensor_ok(probs, "probs (top-p)")
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumprobs = torch.cumsum(sorted_probs, dim=-1)
    exceeds = cumprobs > p
    # Shift the "exceeds p" flags one position to the right into a fresh
    # tensor. The previous in-place `mask[..., 1:] = mask[..., :-1]` copied
    # between overlapping views of one tensor, and it never cleared position
    # 0, so the top token could be masked out.
    mask = torch.zeros_like(exceeds)
    mask[..., 1:] = exceeds[..., :-1]
    filtered = sorted_probs.masked_fill(mask, 0.0)
    assert filtered.sum() > 0, "sum(probs_top_p)==0"
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    choice_idx = torch.multinomial(filtered, 1)
    # Map the position within the sorted list back to a vocabulary index.
    return sorted_indices.gather(-1, choice_idx)

@torch.no_grad()
def generate_greedy(model, tokenizer, src_ids):
    """Decode a summary by taking the arg-max token at every step.

    Decoding starts from the pad token and stops at the first EOS token or
    after ``MAX_GEN_LEN`` steps. ``.item()`` is called on the chosen token,
    so ``src_ids`` must hold exactly one sequence.

    Args:
        model: model exposing ``embedding``, ``encoder``, ``decoder`` and
            ``fc_out``.
        tokenizer: provides ``pad_token_id``, ``eos_token_id`` and ``decode``.
        src_ids: source token ids of shape (1, src_len).

    Returns:
        The decoded text with special tokens skipped.
    """
    # Encode the source; the LSTM's (hidden, cell) pair seeds the decoder.
    _, state = model.encoder(model.embedding(src_ids))

    # The first decoder input is the pad token, shaped (batch, 1).
    current = torch.full(
        (src_ids.size(0), 1),
        tokenizer.pad_token_id,
        dtype=torch.long,
        device=src_ids.device,
    )
    generated = []

    for _ in range(MAX_GEN_LEN):
        # (batch, 1) -> (batch, 1, embed_dim)
        step_emb = model.embedding(current)
        assert step_emb.dim() == 3

        step_out, state = model.decoder(step_emb, state)
        step_logits = model.fc_out(step_out[:, -1, :])          # (batch, vocab)
        current = step_logits.argmax(dim=-1, keepdim=True)      # (batch, 1)

        token_id = current.item()
        if token_id == tokenizer.eos_token_id:
            break
        generated.append(token_id)

    return tokenizer.decode(generated, skip_special_tokens=True)

@torch.no_grad()
def generate_top_k(model, tokenizer, src_ids, k=K):
    """Decode a summary by top-k sampling at every step.

    Decoding starts from the pad token and stops at the first EOS token or
    after ``MAX_GEN_LEN`` steps. ``src_ids`` must hold exactly one sequence.

    Args:
        model: model exposing ``embedding``, ``encoder``, ``decoder`` and
            ``fc_out``.
        tokenizer: provides ``pad_token_id``, ``eos_token_id`` and ``decode``.
        src_ids: source token ids of shape (1, src_len).
        k: number of candidates passed to ``top_k_sampling``.

    Returns:
        The decoded text with special tokens skipped.
    """
    _, state = model.encoder(model.embedding(src_ids))

    current = torch.full(
        (src_ids.size(0), 1),
        tokenizer.pad_token_id,
        dtype=torch.long,
        device=src_ids.device,
    )
    generated = []

    for _ in range(MAX_GEN_LEN):
        step_emb = model.embedding(current)                     # (batch, 1, emb)
        step_out, state = model.decoder(step_emb, state)
        # Drop the batch axis: the sampler works on a 1D logits vector.
        step_logits = model.fc_out(step_out[:, -1, :]).squeeze(0)
        current = top_k_sampling(step_logits, k).unsqueeze(0)   # (1, 1)

        token_id = current.item()
        if token_id == tokenizer.eos_token_id:
            break
        generated.append(token_id)

    return tokenizer.decode(generated, skip_special_tokens=True)

def top_p_sampling(logits, p):
    """Sample one token id from the nucleus (top-p) of ``logits``.

    Tokens are ranked by probability; a token is dropped when the cumulative
    probability of the tokens ranked before it already exceeds ``p``. The
    top-ranked token is always kept.

    Args:
        logits: 1D tensor of unnormalised scores, shape (vocab_size,).
        p: cumulative-probability threshold.

    Returns:
        Tensor of shape (1,) holding the sampled index into ``logits``.
    """
    # Rank the probabilities from most to least likely.
    ranked_probs, ranked_ids = F.softmax(logits, dim=-1).sort(descending=True)
    running_total = ranked_probs.cumsum(dim=-1)

    # drop[i] is True when the mass before rank i already exceeds p. The flags
    # are written into a fresh tensor, so no overlapping copy takes place, and
    # rank 0 stays False.
    drop = torch.zeros_like(running_total, dtype=torch.bool)
    drop[..., 1:] = running_total[..., :-1] > p

    # Zero out the dropped ranks and renormalise what is left.
    kept = ranked_probs.masked_fill(drop, 0.0)
    kept = kept / kept.sum(dim=-1, keepdim=True)

    # Sample a rank, then translate it back to the original vocabulary index.
    sampled_rank = torch.multinomial(kept, 1)
    return ranked_ids.gather(-1, sampled_rank)

@torch.no_grad()
def generate_top_p(model, tokenizer, src_ids, p=P):
    """Decode a summary by top-p (nucleus) sampling at every step.

    Decoding starts from the pad token and stops at the first EOS token or
    after ``MAX_GEN_LEN`` steps.

    Args:
        model: the Seq2SeqLSTM, already in eval() mode and on DEVICE.
        tokenizer: the T5Tokenizer used in training.
        src_ids: tensor of shape (1, seq_len) on DEVICE.
        p: nucleus threshold (e.g. 0.9).

    Returns:
        The decoded text with special tokens skipped.
    """
    # Encode the source; the LSTM's (hidden, cell) pair seeds the decoder.
    _, state = model.encoder(model.embedding(src_ids))

    # The first decoder input is the pad token, shaped (batch_size=1, 1).
    current = torch.full(
        (src_ids.size(0), 1),
        tokenizer.pad_token_id,
        dtype=torch.long,
        device=src_ids.device,
    )
    generated = []

    for _ in range(MAX_GEN_LEN):
        # Embed and decode a single token.
        step_emb = model.embedding(current)                         # (1, 1, emb_dim)
        step_out, state = model.decoder(step_emb, state)
        step_logits = model.fc_out(step_out[:, -1, :]).squeeze(0)   # (vocab_size,)

        # Nucleus selection of the next token.
        current = top_p_sampling(step_logits, p).unsqueeze(0)       # (1, 1)

        token_id = current.item()
        if token_id == tokenizer.eos_token_id:
            break
        generated.append(token_id)

    return tokenizer.decode(generated, skip_special_tokens=True)

@torch.no_grad()
def generate_beam_search(model, tokenizer, src_ids, beam_size=BEAM_SIZE):
    """Decode a summary with beam search over summed log-probabilities.

    Each live hypothesis is extended by its ``beam_size`` most likely next
    tokens. Extensions ending in EOS move to the finished pool; the rest
    compete for the ``beam_size`` live slots. The search runs for at most
    ``MAX_GEN_LEN`` steps. The highest-scoring finished hypothesis wins; if
    none finished, the best live one is used. Scores are not
    length-normalised.

    Args:
        model: model exposing ``embedding``, ``encoder``, ``decoder`` and
            ``fc_out``.
        tokenizer: provides ``pad_token_id``, ``eos_token_id`` and ``decode``.
        src_ids: source token ids of shape (1, src_len).
        beam_size: number of live hypotheses and of expansions per hypothesis.

    Returns:
        The decoded text with special tokens skipped.
    """
    _, (h_init, c_init) = model.encoder(model.embedding(src_ids))
    eos_id = tokenizer.eos_token_id

    # A hypothesis is (token ids, hidden, cell, summed log-probability); every
    # hypothesis starts from the pad token.
    live = [([tokenizer.pad_token_id], h_init, c_init, 0.0)]
    finished = []

    for _ in range(MAX_GEN_LEN):
        candidates = []
        for tokens, h_prev, c_prev, logp_sum in live:
            # Feed the hypothesis' last token as a (1, 1) tensor.
            last_token = torch.tensor(
                [[tokens[-1]]], dtype=torch.long, device=src_ids.device
            )
            step_out, (h_next, c_next) = model.decoder(
                model.embedding(last_token), (h_prev, c_prev)
            )
            step_logits = model.fc_out(step_out[:, -1, :]).squeeze(0)
            step_logps = F.log_softmax(step_logits, dim=-1)
            best_lps, best_ids = torch.topk(step_logps, beam_size)

            for token_lp, token_id in zip(best_lps.tolist(), best_ids.tolist()):
                extended = tokens + [token_id]
                total = logp_sum + token_lp
                if token_id == eos_id:
                    finished.append((extended, total))
                else:
                    candidates.append((extended, h_next, c_next, total))

        # Keep the highest-scoring live hypotheses (stable sort on ties).
        candidates.sort(key=lambda beam: beam[3], reverse=True)
        live = candidates[:beam_size]
        if not live:
            break

    winner = max(finished, key=lambda item: item[1]) if finished else live[0]
    # Discard the leading pad token.
    return tokenizer.decode(winner[0][1:], skip_special_tokens=True)

# 6) Esempio di inferenza su val_ds
from torch.utils.data import DataLoader

# Ricarica il dataset di validazione
import pandas as pd
class XSumDataset(torch.utils.data.Dataset):
    """Input-only dataset: yields the token ids of each document.

    Args:
        df: table with a "document" column.
        tokenizer: callable tokenizer returning a mapping with "input_ids".
    """

    def __init__(self, df, tokenizer):
        self.inputs = df["document"].tolist()
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the token ids of document ``idx`` without the batch axis."""
        document = self.inputs[idx]
        encoded = self.tokenizer(
            document,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return encoded["input_ids"].squeeze(0)

# reset_index keeps label-based access (`val_df.loc[i]`) aligned with the
# dataset's positional access (`val_ds[i]`). Without it, any row removed by
# dropna() would pair a document with the wrong reference summary or raise
# KeyError.
val_df = pd.read_csv(os.path.join(DATA_DIR, "xsum_val.csv")).dropna().reset_index(drop=True)
val_ds = XSumDataset(val_df, tokenizer)

# Show up to five examples; fewer if the validation set is smaller.
for i in range(min(5, len(val_ds))):
    # Input
    src_cpu = val_ds[i]                           # shape [seq_len]
    src = src_cpu.unsqueeze(0).to(DEVICE)         # [1, seq_len]

    # Reference summary (ground truth)
    ref = val_df.loc[i, "summary"]

    # Generations with each decoding strategy
    out_greedy = generate_greedy(model, tokenizer, src)
    out_topk   = generate_top_k(model, tokenizer, src)
    out_topp   = generate_top_p(model, tokenizer, src)
    out_beam   = generate_beam_search(model, tokenizer, src)

    # Print the first 200 characters of the document and every output
    print(f"\n--- ESEMPIO {i+1} ---")
    print("Documento:", tokenizer.decode(src_cpu, skip_special_tokens=True)[:200], "…")
    print("Reference  :", ref)
    print("Greedy     :", out_greedy)
    print("Top-K      :", out_topk)
    print("Top-P      :", out_topp)
    print("Beam search:", out_beam)


--- ESEMPIO 1 ---
Documento: The ex-Reading defender denied fraudulent trading charges relating to the Sodje Sports Foundation - a charity to raise money for Nigerian sport. Mr Sodje, 37, is jointly charged with elder brothers Ef …
Reference  : Former Premier League footballer Sam Sodje has appeared in court alongside three brothers accused of charity fraud.
Greedy     : a man who was stabbed to death in a crash in County Antrim has been named as the new Bishop of the MCC.
Top-K      : sex seamer, one of the highest in a small group of men who travelled the RNLI could be revealed.
Top-P      : has been named as Liverpool's BSC All Blacks and Andyist's hopes of match in 13 games to the group's Six Nations.
Beam search: police have been arrested in connection with the death of a man who was stabbed to death.

--- ESEMPIO 2 ---
Documento: Voges was forced to retire hurt on 86 after suffering the injury while batting during the County Championship draw with Somerset on 4 June. Middlesex h

### Final GRU model with attention

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Inference for Seq2Seq GRU+Attention model on XSum.
Supports greedy, top-k, top-p, and beam search decoding.
"""
import os
import torch
import torch.nn.functional as F
import pandas as pd
from transformers import T5Tokenizer
from torch import nn
from torch.utils.data import Dataset

# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
DATA_DIR    = "/content/drive/MyDrive/xsum"
MAX_GEN_LEN = 64    # maximum number of generated tokens
K           = 50    # k for top-k sampling
P           = 0.9   # threshold for top-p (nucleus) sampling
BEAM_SIZE   = 5     # beam width for beam search

# ─────────────────────────────────────────────
# TOKENIZER
# ─────────────────────────────────────────────
# Load the tokenizer from the "tokenizer" directory under DATA_DIR.
tokenizer    = T5Tokenizer.from_pretrained(os.path.join(DATA_DIR, "tokenizer"))
VOCAB_SIZE   = tokenizer.vocab_size
PAD_ID       = tokenizer.pad_token_id
EOS_ID       = tokenizer.eos_token_id

# ─────────────────────────────────────────────
# MODEL DEFINITION
# ─────────────────────────────────────────────
# These sizes must match the checkpoint loaded below for load_state_dict to succeed.
EMBEDDING_DIM = 256
HIDDEN_DIM    = 512
DROPOUT       = 0.1

class EncoderRNN(nn.Module):
    """GRU encoder over dropout-regularised token embeddings.

    Args:
        vocab_size: number of embedding rows.
        emb_dim: embedding width.
        hid_dim: GRU hidden size.
        pad_id: id passed to the embedding as ``padding_idx``.
    """

    def __init__(self, vocab_size, emb_dim, hid_dim, pad_id):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=emb_dim, padding_idx=pad_id
        )
        self.dropout = nn.Dropout(p=DROPOUT)
        self.gru = nn.GRU(input_size=emb_dim, hidden_size=hid_dim, batch_first=True)

    def forward(self, x):
        """Encode token ids ``x``; return ``(outputs, hidden)`` from the GRU."""
        token_vectors = self.dropout(self.embedding(x))
        per_step_states, final_state = self.gru(token_vectors)
        return per_step_states, final_state

class AttnDecoderRNN(nn.Module):
    """GRU decoder with dot-product attention over the encoder outputs."""

    def __init__(self, vocab_size, emb_dim, hid_dim, pad_id):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=pad_id,
        )
        self.dropout = nn.Dropout(p=DROPOUT)
        self.gru = nn.GRU(input_size=emb_dim, hidden_size=hid_dim, batch_first=True)
        # Fuses the decoder state with the attention context (hence 2 * hid_dim in).
        self.attn_combine = nn.Linear(in_features=2 * hid_dim, out_features=hid_dim)
        self.out = nn.Linear(in_features=hid_dim, out_features=vocab_size)

    def step(self, input_tok, hidden, enc_outputs):
        """Run one decoding step.

        Args:
            input_tok: Previous token ids, one per batch element.
            hidden: GRU hidden state carried over from the previous step.
            enc_outputs: Encoder outputs that are attended over.

        Returns:
            ``(logits, hidden)``: vocabulary scores for the next token and the
            updated GRU hidden state.
        """
        embedded = self.embedding(input_tok)
        embedded = self.dropout(embedded).unsqueeze(1)               # (b, 1, emb)
        dec_out, hidden = self.gru(embedded, hidden)                 # (b, 1, hid)

        # Dot product between every encoder output and the decoder state.
        query = dec_out.transpose(1, 2)                              # (b, hid, 1)
        scores = torch.bmm(enc_outputs, query).squeeze(2)            # (b, seq)
        weights = torch.softmax(scores, dim=1).unsqueeze(1)          # (b, 1, seq)
        context = torch.bmm(weights, enc_outputs)                    # (b, 1, hid)

        fused = torch.cat([dec_out, context], dim=2)                 # (b, 1, 2*hid)
        fused = torch.tanh(self.attn_combine(fused))                 # (b, 1, hid)
        logits = self.out(self.dropout(fused.squeeze(1)))            # (b, vocab)
        return logits, hidden

class Seq2SeqGRUAttention(nn.Module):
    """Container pairing an ``EncoderRNN`` with an ``AttnDecoderRNN``."""

    def __init__(self, vocab_size, emb_dim, hid_dim, pad_id):
        super().__init__()
        layer_args = (vocab_size, emb_dim, hid_dim, pad_id)
        self.encoder = EncoderRNN(*layer_args)
        self.decoder = AttnDecoderRNN(*layer_args)

    def encode(self, src):
        """Run the encoder on ``src`` and return whatever it returns."""
        encoded = self.encoder(src)
        return encoded

    def decode_step(self, input_tok, hidden, enc_outputs):
        """Delegate a single decoding step to the decoder's ``step`` method."""
        step_result = self.decoder.step(input_tok, hidden, enc_outputs)
        return step_result

# ─────────────────────────────────────────────
# LOAD MODEL
# ─────────────────────────────────────────────
model = Seq2SeqGRUAttention(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    hid_dim=HIDDEN_DIM,
    pad_id=PAD_ID,
)
model = model.to(DEVICE)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
state = torch.load(
    os.path.join(DATA_DIR, "seq2seq_gru_attention_final.pt"),
    map_location=DEVICE,
)
model.load_state_dict(state)
# Switch to evaluation mode so the dropout layers are inactive during inference.
model.eval()

# ─────────────────────────────────────────────
# SAMPLING HELPERS
# ─────────────────────────────────────────────
def top_k_sampling(logits, k):
    """Sample one token id per row from the ``k`` most probable entries.

    Args:
        logits: Unnormalised scores whose last dimension is the vocabulary.
        k: Number of highest-probability candidates to sample from. Values
            larger than the vocabulary size are clamped to it.

    Returns:
        Tensor of sampled vocabulary indices with a trailing dimension of 1.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")
    # torch.topk raises when k exceeds the size of the dimension, so clamp it.
    k = min(k, logits.size(-1))
    probs = F.softmax(logits, dim=-1)
    top_probs, top_ids = torch.topk(probs, k, dim=-1)
    # Renormalise so the retained candidates form a proper distribution.
    top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
    picked = torch.multinomial(top_probs, 1)
    # `picked` indexes into the top-k list; map it back to vocabulary ids.
    return top_ids.gather(-1, picked)

def top_p_sampling(logits, p, eps=1e-8):
    """Nucleus (top-p) sampling: draw one token id per row.

    Tokens are sorted by probability and kept until their cumulative mass
    first exceeds ``p``; the token that crosses the threshold is kept too.

    Args:
        logits: Unnormalised scores whose last dimension is the vocabulary.
        p: Cumulative-probability threshold.
        eps: Rows whose retained mass falls below this value fall back to a
            uniform distribution.

    Returns:
        Tensor of sampled vocabulary indices with a trailing dimension of 1.
    """
    # 1) convert to probabilities
    probs = F.softmax(logits, dim=-1)
    # 2) sort in descending order
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    # 3) cumulative sum
    cumprobs = torch.cumsum(sorted_probs, dim=-1)
    # 4) shift the "over threshold" mask right by one, writing into a fresh
    #    tensor so the copy never reads and writes the same memory
    over = cumprobs > p
    drop = torch.zeros_like(over)
    drop[..., 1:] = over[..., :-1]
    # 5) apply the mask and renormalise each row independently
    filtered = sorted_probs.masked_fill(drop, 0.0)
    row_mass = filtered.sum(dim=-1, keepdim=True)
    # Only the rows that lost all their mass fall back to uniform sampling.
    filtered = torch.where(row_mass < eps, torch.ones_like(filtered), filtered)
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    # 6) sample a position in the sorted order
    idx_in_sorted = torch.multinomial(filtered, 1)
    # 7) map the sorted position back to a vocabulary index
    return sorted_indices.gather(-1, idx_in_sorted)

# ─────────────────────────────────────────────
# GENERATION FUNCTIONS
# ─────────────────────────────────────────────
def generate(model, tokenizer, src_ids, mode="greedy"):
    """Generate a summary for a single document with a step-wise decoder.

    Args:
        model: Object exposing ``encode(src_ids)`` and
            ``decode_step(input_tok, hidden, enc_outs)``.
        tokenizer: Tokenizer whose ``decode`` turns token ids into text.
        src_ids: Source token ids with a leading batch dimension of exactly 1.
        mode: ``"greedy"``, ``"top_k"`` (uses ``K``) or ``"top_p"`` (uses ``P``).

    Returns:
        The decoded summary string. Decoding stops at the first ``EOS_ID`` or
        after ``MAX_GEN_LEN`` steps.

    Raises:
        ValueError: If ``mode`` is unknown or ``src_ids`` is not a batch of one.
    """
    pickers = {
        "greedy": lambda logits: logits.argmax(dim=-1),
        "top_k": lambda logits: top_k_sampling(logits, K).squeeze(-1),
        "top_p": lambda logits: top_p_sampling(logits, P).squeeze(-1),
    }
    # Fail fast, before spending time in the encoder.
    if mode not in pickers:
        raise ValueError("Unknown mode")
    if src_ids.size(0) != 1:
        # The loop below reads a single token per step via .item().
        raise ValueError(
            f"generate() decodes one document at a time; got a batch of {src_ids.size(0)}"
        )
    pick = pickers[mode]

    with torch.no_grad():
        enc_outs, hidden = model.encode(src_ids)
        # The pad token doubles as the start-of-sequence token.
        input_tok = torch.full((1,), PAD_ID, dtype=torch.long, device=src_ids.device)
        seq = []
        for _ in range(MAX_GEN_LEN):
            logits, hidden = model.decode_step(input_tok, hidden, enc_outs)
            next_tok = pick(logits)
            tok_id = next_tok.item()
            if tok_id == EOS_ID:
                break
            seq.append(tok_id)
            input_tok = next_tok
    return tokenizer.decode(seq, skip_special_tokens=True)

def generate_beam_search(model, tokenizer, src_ids, beam_size=BEAM_SIZE):
    """Generate a summary for a single document with beam search.

    Hypotheses are scored by their summed token log-probabilities (no length
    normalisation). A hypothesis that emits ``EOS_ID`` is moved to the
    completed list; the best completed hypothesis wins, and the best live beam
    is used only when nothing completed within ``MAX_GEN_LEN`` steps.

    Args:
        model: Object exposing ``encode`` and ``decode_step``.
        tokenizer: Tokenizer whose ``decode`` turns token ids into text.
        src_ids: Source token ids; only the first batch row is decoded.
        beam_size: Number of live hypotheses kept after each step.

    Returns:
        The decoded summary string.
    """
    with torch.no_grad():
        enc_outs, hidden = model.encode(src_ids)
        # Each live beam is (token ids, decoder hidden state, log-prob score);
        # the pad token doubles as the start token.
        beams = [([PAD_ID], hidden, 0.0)]
        completed = []
        best_done = float("-inf")
        for _ in range(MAX_GEN_LEN):
            candidates = []
            for seq_ids, h, score in beams:
                prev = torch.tensor([seq_ids[-1]], device=src_ids.device)
                logits, h1 = model.decode_step(prev, h, enc_outs)
                logps = F.log_softmax(logits, dim=-1)
                top_lp, top_idx = torch.topk(logps, beam_size, dim=-1)
                # One host transfer per beam instead of one per candidate.
                for lp, tok in zip(top_lp[0].tolist(), top_idx[0].tolist()):
                    new_seq = seq_ids + [tok]
                    new_score = score + lp
                    if tok == EOS_ID:
                        completed.append((new_seq, new_score))
                        best_done = max(best_done, new_score)
                    else:
                        candidates.append((new_seq, h1, new_score))
            beams = sorted(candidates, key=lambda x: x[2], reverse=True)[:beam_size]
            if not beams:
                break
            # Log-probabilities are <= 0, so a live beam's score can only go
            # down. Once a completed hypothesis is at least as good as the best
            # live beam, nothing produced later can beat it.
            if best_done >= beams[0][2]:
                break
        best = max(completed, key=lambda x: x[1]) if completed else beams[0]
        # Drop the leading start token before decoding.
        return tokenizer.decode(best[0][1:], skip_special_tokens=True)

# ─────────────────────────────────────────────
# RUN EXAMPLES
# ─────────────────────────────────────────────
class XSumDataset(Dataset):
    """Dataset yielding tokenised source documents from a DataFrame.

    Args:
        df: DataFrame with a ``"document"`` column.
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"``.
        max_length: Length every document is padded or truncated to.
            Defaults to 512, the value previously hard-coded here.
    """

    def __init__(self, df, tokenizer, max_length=512):
        self.inputs = df["document"].tolist()
        self.tok = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, i):
        """Return the ``input_ids`` of document ``i`` without the batch dimension."""
        enc = self.tok(
            self.inputs[i],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return enc["input_ids"].squeeze(0)

# Load the validation split, dropping rows with missing values.
val_df = pd.read_csv(os.path.join(DATA_DIR, "xsum_val.csv"))
val_df = val_df.dropna().reset_index(drop=True)
val_ds = XSumDataset(val_df, tokenizer)

# Show the first five documents with the output of every decoding strategy.
for i in range(5):
    src = val_ds[i].unsqueeze(0).to(DEVICE)
    ref = val_df.loc[i, "summary"]
    print(f"\n--- EXAMPLE {i+1} ---")
    doc_preview = tokenizer.decode(src[0], skip_special_tokens=True)[:200]
    print("Doc:      ", doc_preview, "…")
    print("Reference:", ref)
    for label, mode in (
        ("Greedy:   ", "greedy"),
        ("Top-K:    ", "top_k"),
        ("Top-P:    ", "top_p"),
    ):
        print(label, generate(model, tokenizer, src, mode=mode))
    print("Beam:     ", generate_beam_search(model, tokenizer, src))


--- EXAMPLE 1 ---
Doc:       The ex-Reading defender denied fraudulent trading charges relating to the Sodje Sports Foundation - a charity to raise money for Nigerian sport. Mr Sodje, 37, is jointly charged with elder brothers Ef …
Reference: Former Premier League footballer Sam Sodje has appeared in court alongside three brothers accused of charity fraud.
Greedy:    s have been charged with a "serious" disciplinary hearing into the Hillsborough disaster.
Top-K:     fans to a taxi driver had a divorce, which led to a role on a former player by a sports tribunal for several months.
Top-P:     footage-Fclem John Sheridan has been granted legal action against a report after playing by two members of the sports team.
Beam:      prosecutors have been charged with raping a footballer who was beaten by a court in the United States.

--- EXAMPLE 2 ---
Doc:       Voges was forced to retire hurt on 86 after suffering the injury while batting during the County Championship draw with Somerset on 

## Final inference on the full test set (GRU with attention)

Different decoding strategies are implemented.

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fast, batched inference for Seq2Seq GRU+Attention on the full XSum test set.
Generates summaries via greedy, top-k, top-p, and beam search in batch.
Supports temperature scaling, early EOS stopping, and vectorized repetition penalty.
"""

import os
import torch
import torch.nn.functional as F
import pandas as pd
from transformers import T5Tokenizer
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast
from tqdm import tqdm

# ──────────────────────────────────────────────────────────────────────────
# CONFIGURATION
# ──────────────────────────────────────────────────────────────────────────
# Prefer the GPU whenever CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Folder holding the saved tokenizer, the checkpoint and the test CSV.
DATA_DIR = "/content/drive/MyDrive/xsum"
MAX_GEN_LEN = 64          # maximum number of decoding steps per summary
K = 50                    # default candidate pool size for top-k sampling
P = 0.9                   # default cumulative-probability threshold for top-p
BEAM_SIZE = 4             # default beam width
BATCH_SIZE = 64           # documents per DataLoader batch
NUM_WORKERS = 2           # DataLoader worker processes
TEMPERATURE = 1.0         # default divisor applied to the logits
REPETITION_PENALTY = 1.2  # penalty factor >1

# ──────────────────────────────────────────────────────────────────────────
# TOKENIZER & SPECIAL IDS
# ──────────────────────────────────────────────────────────────────────────
tokenizer = T5Tokenizer.from_pretrained(os.path.join(DATA_DIR, "tokenizer"))
PAD_ID, EOS_ID, VOCAB_SIZE = (
    tokenizer.pad_token_id,
    tokenizer.eos_token_id,
    tokenizer.vocab_size,
)

# ──────────────────────────────────────────────────────────────────────────
# MODEL DEFINITION
# ──────────────────────────────────────────────────────────────────────────
# Layer sizes and dropout rate read by the model classes below.
EMBED_DIM, HIDDEN_DIM, DROPOUT = 256, 512, 0.1

class EncoderRNN(nn.Module):
    """GRU encoder sized by the module-level ``EMBED_DIM`` / ``HIDDEN_DIM``."""

    def __init__(self, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=EMBED_DIM,
            padding_idx=PAD_ID,
        )
        self.drop = nn.Dropout(p=DROPOUT)
        self.gru = nn.GRU(input_size=EMBED_DIM, hidden_size=HIDDEN_DIM, batch_first=True)

    def forward(self, x):
        """Return the GRU's ``(outputs, hidden)`` pair for token ids ``x``."""
        embedded = self.embed(x)
        return self.gru(self.drop(embedded))

class AttnDecoderRNN(nn.Module):
    """GRU decoder with dot-product attention over the encoder outputs."""

    def __init__(self, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=EMBED_DIM,
            padding_idx=PAD_ID,
        )
        self.drop = nn.Dropout(p=DROPOUT)
        self.gru = nn.GRU(input_size=EMBED_DIM, hidden_size=HIDDEN_DIM, batch_first=True)
        # Fuses the decoder state with the attention context (2 * HIDDEN_DIM in).
        self.attn_combine = nn.Linear(in_features=2 * HIDDEN_DIM, out_features=HIDDEN_DIM)
        self.out = nn.Linear(in_features=HIDDEN_DIM, out_features=vocab_size)

    def step(self, input_tok, hidden, enc_outs):
        """Run one decoding step and return ``(logits, hidden)``.

        Args:
            input_tok: Previous token ids, one per batch element.
            hidden: GRU hidden state from the previous step.
            enc_outs: Encoder outputs that are attended over.
        """
        embedded = self.drop(self.embed(input_tok)).unsqueeze(1)
        dec_out, hidden = self.gru(embedded, hidden)

        # Dot-product attention: score each encoder output against the decoder state.
        query = dec_out.transpose(1, 2)
        attn = torch.softmax(torch.bmm(enc_outs, query).squeeze(2), dim=1)
        context = torch.bmm(attn.unsqueeze(1), enc_outs)

        fused = torch.cat([dec_out, context], dim=2)
        fused = torch.tanh(self.attn_combine(fused))
        logits = self.out(self.drop(fused.squeeze(1)))
        return logits, hidden

class Seq2SeqGRUAttention(nn.Module):
    """Container pairing an ``EncoderRNN`` with an ``AttnDecoderRNN``."""

    def __init__(self, vocab_size):
        super().__init__()
        self.encoder = EncoderRNN(vocab_size)
        self.decoder = AttnDecoderRNN(vocab_size)

    def encode(self, src):
        """Run the encoder on ``src`` and return whatever it returns."""
        encoded = self.encoder(src)
        return encoded

    def decode_step(self, tok, hidden, enc):
        """Delegate a single decoding step to the decoder's ``step`` method."""
        step_result = self.decoder.step(tok, hidden, enc)
        return step_result

# ──────────────────────────────────────────────────────────────────────────
# LOAD MODEL WITH KEY MAPPING
# ──────────────────────────────────────────────────────────────────────────
model = Seq2SeqGRUAttention(VOCAB_SIZE)
model = model.to(DEVICE)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
orig_state = torch.load(os.path.join(DATA_DIR, "best_model.pt"), map_location=DEVICE)

# Keys stored under `encoder.embedding.` / `decoder.embedding.` are renamed to
# `encoder.embed.` / `decoder.embed.`, the attribute name used by the classes
# in this cell. Every other key is copied unchanged.
state = {}
for k, v in orig_state.items():
    new_k = k
    for side in ("encoder", "decoder"):
        if k.startswith(f"{side}.embedding."):
            new_k = k.replace(f"{side}.embedding.", f"{side}.embed.")
            break
    state[new_k] = v
model.load_state_dict(state)
# Switch to evaluation mode so the dropout layers are inactive during inference.
model.eval()

# ──────────────────────────────────────────────────────────────────────────
# BATCHED GENERATION FUNCTIONS WITH REPETITION PENALTY
# ──────────────────────────────────────────────────────────────────────────

def _penalize_repeats(scores, used, penalty=REPETITION_PENALTY):
    """Lower the score of every token flagged in ``used``.

    Positive scores are divided by ``penalty`` and negative scores are
    multiplied by it. Dividing a negative score by a factor greater than 1
    would move it towards zero and make the repeated token *more* likely.
    """
    penalized = torch.where(scores < 0, scores * penalty, scores / penalty)
    return torch.where(used, penalized, scores)


def _decode_batch(src_ids, temperature, choose):
    """Shared decoding loop for the greedy / top-k / top-p generators.

    Args:
        src_ids: Batch of source token ids, already on the model's device.
        temperature: Divisor applied to the logits before token selection.
        choose: Callable mapping a ``(batch, vocab)`` score tensor to a
            ``(batch,)`` tensor of chosen token ids.

    Returns:
        One decoded string per batch row, cut at the row's first ``EOS_ID``.
    """
    B = src_ids.size(0)
    device = src_ids.device

    # Inference only: keep the encoder pass out of the autograd graph as well.
    with torch.no_grad():
        enc_outs, hidden = model.encode(src_ids)
        # The pad token doubles as the start token.
        prev = torch.full((B,), PAD_ID, dtype=torch.long, device=device)
        used = torch.zeros(B, VOCAB_SIZE, dtype=torch.bool, device=device)
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        rows = torch.arange(B, device=device)
        steps = []

        with autocast(device_type=device.type):
            for _ in range(MAX_GEN_LEN):
                logits, hidden = model.decode_step(prev, hidden, enc_outs)
                scaled = _penalize_repeats(logits.float() / temperature, used)
                # Finished rows can only emit EOS from now on.
                scaled[finished] = float("-inf")
                scaled[finished, EOS_ID] = 0.0

                prev = choose(scaled)
                steps.append(prev)
                used[rows, prev] = True
                finished = finished | (prev == EOS_ID)
                if finished.all():
                    break

    # One device-to-host transfer for the whole batch instead of one per token.
    token_rows = torch.stack(steps, dim=1).tolist() if steps else [[] for _ in range(B)]
    texts = []
    for seq in token_rows:
        if EOS_ID in seq:
            seq = seq[:seq.index(EOS_ID)]
        texts.append(tokenizer.decode(seq, skip_special_tokens=True))
    return texts


def generate_greedy_batch(src_ids, temperature=TEMPERATURE):
    """Greedy decoding for a batch: pick the highest-scoring token each step.

    Args:
        src_ids: Batch of source token ids on the model's device.
        temperature: Divisor applied to the logits.

    Returns:
        List with one generated summary string per batch row.
    """
    return _decode_batch(src_ids, temperature, lambda scores: scores.argmax(dim=-1))


def generate_top_k_batch(src_ids, k=K, temperature=TEMPERATURE):
    """Top-k sampling for a batch: sample among the ``k`` most probable tokens.

    Args:
        src_ids: Batch of source token ids on the model's device.
        k: Number of candidates kept per step (clamped to the vocabulary size).
        temperature: Divisor applied to the logits.

    Returns:
        List with one generated summary string per batch row.
    """
    def choose(scores):
        probs = F.softmax(scores, dim=-1)
        vals, idxs = torch.topk(probs, min(k, probs.size(-1)), dim=-1)
        vals = vals / vals.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(vals, 1)
        return idxs.gather(-1, choice).squeeze(-1)

    return _decode_batch(src_ids, temperature, choose)


def generate_top_p_batch(src_ids, p=P, temperature=TEMPERATURE):
    """Top-p (nucleus) sampling for a batch.

    Tokens are kept in descending probability order until their cumulative
    mass first exceeds ``p``; the token that crosses the threshold is kept.

    Args:
        src_ids: Batch of source token ids on the model's device.
        p: Cumulative-probability threshold.
        temperature: Divisor applied to the logits.

    Returns:
        List with one generated summary string per batch row.
    """
    def choose(scores):
        probs = F.softmax(scores, dim=-1)
        sorted_p, sorted_i = torch.sort(probs, dim=-1, descending=True)
        over = sorted_p.cumsum(dim=-1) > p
        # Shift right by one into a fresh tensor: an in-place shifted copy
        # would read and write overlapping memory.
        drop = torch.zeros_like(over)
        drop[..., 1:] = over[..., :-1]
        # Position 0 is never dropped, so every row keeps non-zero mass and
        # no uniform fallback is needed.
        filt = sorted_p.masked_fill(drop, 0.0)
        filt = filt / filt.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(filt, 1)
        return sorted_i.gather(-1, choice).squeeze(-1)

    return _decode_batch(src_ids, temperature, choose)


def generate_beam_batch(src_ids, beam_size=BEAM_SIZE, temperature=TEMPERATURE):
    """Batched beam search.

    Every source row is expanded into ``beam_size`` hypotheses stored
    contiguously, so hypothesis ``i`` of row ``b`` lives at flat index
    ``b * beam_size + i``. Hypotheses are scored by summed log-probabilities
    (no length normalisation) and the best-scoring one per row is returned.

    Args:
        src_ids: Batch of source token ids on the model's device.
        beam_size: Number of hypotheses kept per source row.
        temperature: Divisor applied to the logits.

    Returns:
        List with one generated summary string per batch row.
    """
    B = src_ids.size(0)
    device = src_ids.device
    n_hyp = B * beam_size

    # Inference only: keep the encoder pass out of the autograd graph as well.
    with torch.no_grad():
        enc_outs, hidden = model.encode(src_ids)
        enc = enc_outs.unsqueeze(1).expand(B, beam_size, -1, -1).reshape(n_hyp, -1, HIDDEN_DIM)
        h = hidden.unsqueeze(2).expand(-1, B, beam_size, -1).reshape(1, n_hyp, -1)

        # The pad token doubles as the start token.
        beams = torch.full((n_hyp, 1), PAD_ID, dtype=torch.long, device=device)
        # Only the first hypothesis of each row starts alive. If all copies
        # started at 0, top-k would pick the same best token from each copy
        # and the search would collapse to greedy decoding.
        scores = torch.full((B, beam_size), float("-inf"), device=device)
        scores[:, 0] = 0.0
        used = torch.zeros(n_hyp, VOCAB_SIZE, dtype=torch.bool, device=device)
        finished = torch.zeros(n_hyp, dtype=torch.bool, device=device)
        offsets = torch.arange(B, device=device).unsqueeze(1) * beam_size
        rows = torch.arange(n_hyp, device=device)

        for _ in range(MAX_GEN_LEN):
            logits, h = model.decode_step(beams[:, -1], h, enc)
            scaled = logits / temperature
            # Repetition penalty: always push the score of used tokens down
            # (divide positive scores, multiply negative ones).
            penalized = torch.where(
                scaled < 0, scaled * REPETITION_PENALTY, scaled / REPETITION_PENALTY
            )
            scaled = torch.where(used, penalized, scaled)
            # Finished hypotheses can only emit EOS, at log-prob 0, so their
            # score stays frozen.
            scaled[finished] = float("-inf")
            scaled[finished, EOS_ID] = 0.0

            logp = F.log_softmax(scaled, dim=-1).view(B, beam_size, -1)
            total = logp + scores.unsqueeze(-1)
            scores, flat_idx = total.view(B, -1).topk(beam_size, dim=-1)
            beam_idx = flat_idx // VOCAB_SIZE
            tok = (flat_idx % VOCAB_SIZE).view(-1)
            # Flat index of the hypothesis each survivor extends.
            parent = (offsets + beam_idx).view(-1)

            # All per-hypothesis state must follow the parent, not the new slot.
            beams = torch.cat([beams[parent], tok.unsqueeze(1)], dim=1)
            used = used[parent]
            used[rows, tok] = True
            finished = finished[parent] | (tok == EOS_ID)
            h = h[:, parent]
            if finished.all():
                break

    best = scores.argmax(dim=-1)
    beams = beams.view(B, beam_size, -1)
    # Best hypothesis per row, without the leading start token.
    chosen = beams[torch.arange(B, device=device), best, 1:].tolist()
    results = []
    for seq in chosen:
        if EOS_ID in seq:
            seq = seq[:seq.index(EOS_ID)]
        results.append(tokenizer.decode(seq, skip_special_tokens=True))
    return results

# ──────────────────────────────────────────────────────────────────────────
# DATASET & RUN
# ──────────────────────────────────────────────────────────────────────────
class XSumTestDataset(Dataset):
    """Dataset yielding tokenised documents from an XSum CSV file.

    Args:
        path: CSV file with a ``"document"`` column; rows containing missing
            values are dropped.
        tok: Callable tokenizer returning a mapping with ``"input_ids"``.
            Defaults to the module-level ``tokenizer``.
        max_length: Length every document is padded or truncated to.
            Defaults to 512, the value previously hard-coded here.
    """

    def __init__(self, path, tok=None, max_length=512):
        self.df = pd.read_csv(path).dropna().reset_index(drop=True)
        self.tok = tokenizer if tok is None else tok
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        """Return the ``input_ids`` of document ``i`` without the batch dimension."""
        enc = self.tok(
            self.df.loc[i, "document"],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return enc["input_ids"].squeeze(0)

csv_in = os.path.join(DATA_DIR, "xsum_test.csv")
loader = DataLoader(
    XSumTestDataset(csv_in),
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep predictions aligned with the CSV row order
    num_workers=NUM_WORKERS,
)

# One full pass over the test set per decoding strategy, in this order.
pred_g, pred_k, pred_p, pred_b = [], [], [], []
for desc, decode, sink in (
    ("Greedy", generate_greedy_batch, pred_g),
    ("Top-K", generate_top_k_batch, pred_k),
    ("Top-P", generate_top_p_batch, pred_p),
    ("Beam", generate_beam_batch, pred_b),
):
    for batch in tqdm(loader, desc=desc):
        sink.extend(decode(batch.to(DEVICE)))

# Re-read the CSV with the same dropna/reset_index steps the dataset applies,
# so each row lines up with its predictions.
out_df = pd.read_csv(csv_in).dropna().reset_index(drop=True)
for column, preds in (
    ("pred_greedy", pred_g),
    ("pred_top_k", pred_k),
    ("pred_top_p", pred_p),
    ("pred_beam", pred_b),
):
    out_df[column] = preds
out_df.to_csv(os.path.join(DATA_DIR, "xsum_test_with_preds.csv"), index=False)
print("Saved to xsum_test_with_preds.csv")

