In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pickle
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt

In [10]:
with open('preprocessed_data.pkl', 'rb') as f:
    data = pickle.load(f)


In [11]:
train_premise_idx = data['train_premise_idx']
train_hypothesis_idx = data['train_hypothesis_idx']
train_labels = data['train_labels']

val_premise_idx = data['val_premise_idx']
val_hypothesis_idx = data['val_hypothesis_idx']
val_labels = data['val_labels']

test_premise_idx = data['test_premise_idx']
test_hypothesis_idx = data['test_hypothesis_idx']
test_labels = data['test_labels']

vocab_size = data['vocab_size']
word_to_ix = data['word_to_ix']
label_to_ix = data['label_to_ix']
MAX_LENGTH_PREMISE = min(data['MAX_LENGTH_PREMISE'], 400)  # cap long premises
MAX_LENGTH_HYPOTHESIS = min(data['MAX_LENGTH_HYPOTHESIS'], 60)


In [12]:
# check sequence length statistics
premise_lens = [len(p) for p in train_premise_idx]
hyp_lens = [len(h) for h in train_hypothesis_idx]

print(f"Premise lengths - Max: {max(premise_lens)}, Mean: {np.mean(premise_lens):.1f}, 95th percentile: {np.percentile(premise_lens, 95):.1f}")
print(f"Hypothesis lengths - Max: {max(hyp_lens)}, Mean: {np.mean(hyp_lens):.1f}, 95th percentile: {np.percentile(hyp_lens, 95):.1f}")
print(f"\nNote: Sequences will be truncated to {200} tokens to manage memory")

Premise lengths - Max: 11640, Mean: 18.3, 95th percentile: 33.0
Hypothesis lengths - Max: 36, Mean: 11.8, 95th percentile: 20.0

Note: Sequences will be truncated to 200 tokens to manage memory


In [16]:
# device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f" Using device: {device}")



 Using device: cuda


In [17]:
# dataset class
class NLIDataset(Dataset):
    def __init__(self, premise_idx, hypothesis_idx, labels):
        self.premise_idx = premise_idx
        self.hypothesis_idx = hypothesis_idx
        self.labels = labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            'premise': self.premise_idx[idx],
            'hypothesis': self.hypothesis_idx[idx],
            'label': self.labels[idx]
        }

In [28]:
# -- replace your collate_fn with this (uses the caps you computed earlier) --

MAX_SEQ_LENGTH = 400 
PAD_ID = 0

def collate_fn(batch):
    premises = [item['premise'][:MAX_LENGTH_PREMISE] for item in batch]
    hypotheses = [item['hypothesis'][:MAX_LENGTH_HYPOTHESIS] for item in batch]
    labels = [item['label'] for item in batch]

    max_premise_len = max(len(p) for p in premises)
    max_hypothesis_len = max(len(h) for h in hypotheses)

    def pad_to(seqs, L):
        return [seq + [PAD_ID]*(L - len(seq)) for seq in seqs]

    premises_tensor   = torch.LongTensor(pad_to(premises,   max_premise_len))
    hypotheses_tensor = torch.LongTensor(pad_to(hypotheses, max_hypothesis_len))
    labels_tensor     = torch.LongTensor(labels)

    return premises_tensor, hypotheses_tensor, labels_tensor


In [29]:
# create datasets
train_dataset = NLIDataset(train_premise_idx, train_hypothesis_idx, train_labels)
val_dataset = NLIDataset(val_premise_idx, val_hypothesis_idx, val_labels)
test_dataset = NLIDataset(test_premise_idx, test_hypothesis_idx, test_labels)

# hyperparameters - automatically adjust batch size based on GPU
BATCH_SIZE = 64 if torch.cuda.is_available() else 32  # use 64 for GPU, 32 for CPU
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
NUM_CLASSES = 2
LEARNING_RATE = 0.001
NUM_EPOCHS = 10

print(f"Hyperparameters:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Embedding dim: {EMBEDDING_DIM}")
print(f"  Hidden dim: {HIDDEN_DIM}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Number of epochs: {NUM_EPOCHS}")

# create dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

print(f"\nDataloaders created:")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

Hyperparameters:
  Batch size: 64
  Embedding dim: 128
  Hidden dim: 256
  Learning rate: 0.001
  Number of epochs: 10

Dataloaders created:
  Training batches: 361
  Validation batches: 21
  Test batches: 34


In [52]:
class SelfAttentionPool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, 1, bias=False)

    def forward(self, H, mask):  # H:[B,T,D], mask:[B,T] (1=real, 0=pad)
        scores = self.v(torch.tanh(self.W(H))).squeeze(-1)     # [B,T]

        # Mask pads to large negative so they contribute ~0 after softmax
        scores = scores.masked_fill(mask == 0, -1e9)

        # If a whole row is PADs, set that row's scores to [0, -inf, -inf, ...]
        # so softmax yields a safe one-hot at position 0 — without touching alpha later.
        all_pad = (mask.sum(dim=1) == 0)                       # [B]
        if all_pad.any():
            # first set all to -inf
            scores[all_pad] = -1e9
            # then make index 0 = 0 so softmax -> 1 at idx 0
            scores[all_pad, 0] = 0.0

        alpha = torch.softmax(scores, dim=-1)                  # [B,T]
        p_star = torch.bmm(alpha.unsqueeze(1), H).squeeze(1)   # [B,D]
        return p_star, alpha

        

class BahdanauAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W_h = nn.Linear(dim, dim, bias=False)
        self.W_s = nn.Linear(dim, dim, bias=False)
        self.v   = nn.Linear(dim, 1,  bias=False)

    def forward(self, Hk, q, mask):      # Hk:[B,Tp,2H], q:[B,2H], mask:[B,Tp] (1=real,0=pad)
        scores = self.v(torch.tanh(self.W_h(Hk) + self.W_s(q).unsqueeze(1))).squeeze(-1)  # [B,Tp]
        scores = scores.masked_fill(mask == 0, -1e9)

        # if a row is all pads -> set [0, -inf, -inf, ...] so softmax is well-defined
        all_pad = (mask.sum(dim=1) == 0)          # [B]
        if all_pad.any():
            scores[all_pad] = -1e9
            scores[all_pad, 0] = 0.0

        alpha = torch.softmax(scores, dim=-1)     # [B,Tp]
        ctx   = torch.bmm(alpha.unsqueeze(1), Hk).squeeze(1)   # [B,2H]
        return ctx, alpha



In [43]:
class EncoderSideAttentionNLI(nn.Module):
    """
    Premise: BiLSTM -> SelfAttentionPool -> p*
    Hypothesis: BiLSTM -> MaxPool -> h*
    Classify on [p*, h*, |p*-h*|, p*∘h*]
    """
    def __init__(self, vocab_size, emb_dim, hid_dim, num_classes, pad_idx=0, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.enc_p = nn.LSTM(emb_dim, hid_dim, batch_first=True, bidirectional=True)
        self.enc_h = nn.LSTM(emb_dim, hid_dim, batch_first=True, bidirectional=True)
        self.pool_p = SelfAttentionPool(2*hid_dim)
        self.dropout = nn.Dropout(dropout)

        feat_dim = 4 * (2*hid_dim)  # concat p*, h*, |.|, ∘
        self.fc1 = nn.Linear(feat_dim, 2*hid_dim)
        self.fc2 = nn.Linear(2*hid_dim, num_classes)

    def _mask(self, x, pad=0):
        return (x != pad).long()

    def _encode_bilstm(self, x, lstm):
        emb = self.embedding(x)                      # [B,T,E]
        H, _ = lstm(emb)                             # [B,T,2H]
        return H

    def _max_pool_time(self, H, mask):         # inside your EncoderSideAttentionNLI
        masked = H.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))
        pooled = masked.max(dim=1).values      # [B,D]
        # handle all-pad rows: replace -inf with 0
        is_all_pad = (mask.sum(dim=1) == 0)
        if is_all_pad.any():
            pooled[is_all_pad] = 0.0
        return pooled

    def forward(self, prem_ids, hyp_ids):
        prem_mask = self._mask(prem_ids)             # [B,Tp]
        hyp_mask  = self._mask(hyp_ids)              # [B,Th]

        H_p = self._encode_bilstm(prem_ids, self.enc_p)       # [B,Tp,2H]
        p_star, attn_p = self.pool_p(H_p, prem_mask)          # [B,2H], [B,Tp]

        H_h = self._encode_bilstm(hyp_ids,  self.enc_h)       # [B,Th,2H]
        h_star = self._max_pool_time(H_h, hyp_mask)           # [B,2H]

        # NLI matching features
        diff = torch.abs(p_star - h_star)
        prod = p_star * h_star
        z = torch.cat([p_star, h_star, diff, prod], dim=-1)   # [B,8H]

        out = self.dropout(torch.relu(self.fc1(z)))
        logits = self.fc2(out)                                 # [B,2]
        return logits, attn_p  # return attn over premise for qualitative plots


In [53]:
class DecoderSideAttentionNLI(nn.Module):
    """
    Premise encoder: BiLSTM -> H_p (keys/values for attention)
    Decoder over hypothesis (uni-LSTM):
        at each step t, attend over H_p with Bahdanau attention using s_t
    Pool over time: mean([s_t; c_t]) -> classifier
    """
    def __init__(self, vocab_size, emb_dim, hid_dim, num_classes, pad_idx=0, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.enc_p = nn.LSTM(emb_dim, hid_dim, batch_first=True, bidirectional=True)
        self.dec_h = nn.LSTM(emb_dim + 2*hid_dim, hid_dim, batch_first=True, bidirectional=False)
        self.attn  = BahdanauAttention(2*hid_dim)  # query dim will be hid_dim; project it to 2H first

        # project decoder state to 2H to match attention query dim
        self.proj_q = nn.Linear(hid_dim, 2*hid_dim, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hid_dim + 2*hid_dim, 2*hid_dim)  # features after pooling
        self.fc2 = nn.Linear(2*hid_dim, num_classes)

    def _mask(self, x, pad=0):
        return (x != pad).long()

    def forward(self, prem_ids, hyp_ids):
        B = prem_ids.size(0)

        prem_mask = self._mask(prem_ids)                      # [B,Tp]
        hyp_mask  = self._mask(hyp_ids)                       # [B,Th]

        # Encode premise
        H_p, _ = self.enc_p(self.embedding(prem_ids))         # [B,Tp,2H]

        # Decoder over hypothesis with attention at every step
        # We'll run a standard LSTM but feed previous context
        Th = hyp_ids.size(1)
        emb_h = self.embedding(hyp_ids)                       # [B,Th,E]

        # init s_0, c_0 as zeros
        h_t = torch.zeros(1, B, self.dec_h.hidden_size, device=emb_h.device)
        c_t = torch.zeros(1, B, self.dec_h.hidden_size, device=emb_h.device)
        ctx_prev = torch.zeros(B, 2*self.enc_p.hidden_size, device=emb_h.device)  # [B,2H]

        dec_outputs = []
        attn_all = []  # collect attention per step for qualitative heatmap

        for t in range(Th):
            # input to decoder: [y_t ; ctx_{t-1}]
            dec_in_t = torch.cat([emb_h[:, t, :], ctx_prev], dim=-1).unsqueeze(1)  # [B,1,E+2H]
            dec_out_t, (h_t, c_t) = self.dec_h(dec_in_t, (h_t, c_t))               # dec_out_t: [B,1,H]
            s_t = dec_out_t.squeeze(1)                                             # [B,H]

            # attention: query is projected s_t -> 2H
            q_t = self.proj_q(s_t)                                                 # [B,2H]
            ctx_t, alpha_t = self.attn(H_p, q_t, prem_mask)                        # [B,2H], [B,Tp]

            # store and roll
            dec_outputs.append(torch.cat([s_t, ctx_t], dim=-1))                    # [B,H+2H]
            attn_all.append(alpha_t.unsqueeze(1))                                  # [B,1,Tp]
            ctx_prev = ctx_t

        # pool over time (mean)
        dec_stack = torch.stack(dec_outputs, dim=1)   # [B,Th,H+2H]
        g = dec_stack.mean(dim=1)
                                                 # [B,H+2H]

        out = self.dropout(torch.relu(self.fc1(g)))
        logits = self.fc2(out)                                                      # [B,2]

        # concat attention maps along time: [B,Th,Tp]
        attn_map = torch.cat(attn_all, dim=1) if attn_all else None
        return logits, attn_map


In [45]:
criterion = nn.CrossEntropyLoss()

def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss, all_preds, all_labels = 0.0, [], []
    for prem, hyp, labels in dataloader:
        prem, hyp, labels = prem.to(device), hyp.to(device), labels.to(device)
        optimizer.zero_grad()
        logits, _ = model(prem, hyp)
        loss = criterion(logits, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        all_preds.extend(logits.argmax(dim=1).detach().cpu().numpy())
        all_labels.extend(labels.detach().cpu().numpy())
    return total_loss/len(dataloader), accuracy_score(all_labels, all_preds)


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss, all_preds, all_labels = 0.0, [], []
    for prem, hyp, labels in dataloader:
        prem, hyp, labels = prem.to(device), hyp.to(device), labels.to(device)
        logits, _ = model(prem, hyp)
        loss = criterion(logits, labels)
        total_loss += loss.item()
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    return total_loss/len(dataloader), accuracy_score(all_labels, all_preds), all_preds, all_labels


In [46]:
# ----- Model 2A -----
model_2a = EncoderSideAttentionNLI(vocab_size, EMBEDDING_DIM, HIDDEN_DIM, NUM_CLASSES, pad_idx=0).to(device)
opt_2a = optim.Adam(model_2a.parameters(), lr=LEARNING_RATE)

best_val = 0.0
for epoch in range(NUM_EPOCHS):
    tr_loss, tr_acc = train_epoch(model_2a, train_loader, criterion, opt_2a, device)
    va_loss, va_acc, _, _ = evaluate(model_2a, val_loader, criterion, device)
    print(f"[2A] Epoch {epoch+1}: train {tr_loss:.4f}/{tr_acc:.4f} | val {va_loss:.4f}/{va_acc:.4f}")
    if va_acc > best_val:
        best_val = va_acc
        torch.save(model_2a.state_dict(), "best_model_2a.pth")
        print("  ↳ saved best_model_2a.pth")


[2A] Epoch 1: train 0.4869/0.7654 | val 0.5953/0.6787
  ↳ saved best_model_2a.pth
[2A] Epoch 2: train 0.3445/0.8517 | val 0.6330/0.6879
  ↳ saved best_model_2a.pth
[2A] Epoch 3: train 0.2539/0.8936 | val 0.6118/0.7255
  ↳ saved best_model_2a.pth
[2A] Epoch 4: train 0.1491/0.9415 | val 0.8373/0.7078
[2A] Epoch 5: train 0.0706/0.9735 | val 1.0385/0.6986
[2A] Epoch 6: train 0.0391/0.9860 | val 1.2304/0.7048
[2A] Epoch 7: train 0.0306/0.9896 | val 1.7898/0.6925
[2A] Epoch 8: train 0.0245/0.9922 | val 1.7954/0.7025
[2A] Epoch 9: train 0.0193/0.9938 | val 1.7884/0.6963
[2A] Epoch 10: train 0.0163/0.9952 | val 2.4001/0.6833


In [47]:
model_2a.load_state_dict(torch.load("best_model_2a.pth"))
test_loss, test_acc, test_preds, test_true = evaluate(model_2a, test_loader, criterion, device)
print(f"[2A] Test: loss {test_loss:.4f} acc {test_acc:.4f}")
print(classification_report(test_true, test_preds, target_names=['entails','neutral']))


[2A] Test: loss 0.6921 acc 0.6990


  model_2a.load_state_dict(torch.load("best_model_2a.pth"))


              precision    recall  f1-score   support

     entails       0.65      0.52      0.58       842
     neutral       0.72      0.81      0.77      1284

    accuracy                           0.70      2126
   macro avg       0.69      0.67      0.67      2126
weighted avg       0.69      0.70      0.69      2126



In [49]:
import matplotlib.pyplot as plt
import numpy as np

ix_to_word = {v:k for k,v in word_to_ix.items()}

@torch.no_grad()
def viz_encoder_attn(model, dataset, idx, topk=15):
    model.eval()
    prem_ids = torch.LongTensor([dataset.premise_idx[idx][:MAX_LENGTH_PREMISE]]).to(device)
    hyp_ids  = torch.LongTensor([dataset.hypothesis_idx[idx][:MAX_LENGTH_HYPOTHESIS]]).to(device)
    logits, alpha = model(prem_ids, hyp_ids)         # alpha: [1, Tp]
    alpha = alpha.squeeze(0).cpu().numpy()
    tokens = [ix_to_word.get(i, "<UNK>") for i in prem_ids.squeeze(0).cpu().numpy()]
    # top-k table
    order = np.argsort(alpha)[::-1][:topk]
    print("Top-k premise tokens by attention:")
    for j in order:
        print(f"{tokens[j]:<15}  {alpha[j]:.4f}")

    # heatmap (horizontal)
    plt.figure(figsize=(min(16, 0.6*len(tokens)), 2.8))
    plt.imshow(alpha[np.newaxis, :], aspect='auto')
    plt.yticks([]); plt.xticks(range(len(tokens)), tokens, rotation=90)
    plt.title("Encoder-side attention over premise")
    plt.tight_layout(); plt.show()

# Example:
# viz_encoder_attn(model_2a, test_dataset, idx=0, topk=15)


In [54]:
model_2b = DecoderSideAttentionNLI(vocab_size, EMBEDDING_DIM, HIDDEN_DIM, NUM_CLASSES, pad_idx=0).to(device)
opt_2b   = optim.Adam(model_2b.parameters(), lr=LEARNING_RATE)

best_val = 0.0
for epoch in range(NUM_EPOCHS):
    tr_loss, tr_acc = train_epoch(model_2b, train_loader, criterion, opt_2b, device)
    va_loss, va_acc, _, _ = evaluate(model_2b, val_loader, criterion, device)
    print(f"[2B] Epoch {epoch+1}: train {tr_loss:.4f}/{tr_acc:.4f} | val {va_loss:.4f}/{va_acc:.4f}")
    if va_acc > best_val:
        best_val = va_acc
        torch.save(model_2b.state_dict(), "best_model_2b.pth")
        print("  ↳ saved best_model_2b.pth")

# test
try:
    state = torch.load("best_model_2b.pth", weights_only=True)
except TypeError:
    state = torch.load("best_model_2b.pth")
model_2b.load_state_dict(state)

test_loss, test_acc, test_preds, test_true = evaluate(model_2b, test_loader, criterion, device)
print(f"[2B] Test: loss {test_loss:.4f} acc {test_acc:.4f}")
print(classification_report(test_true, test_preds, target_names=['entails','neutral'], zero_division=0))


[2B] Epoch 1: train 0.5297/0.7323 | val 0.6151/0.6656
  ↳ saved best_model_2b.pth
[2B] Epoch 2: train 0.3675/0.8374 | val 0.6734/0.6434
[2B] Epoch 3: train 0.2644/0.8892 | val 0.7521/0.6741
  ↳ saved best_model_2b.pth
[2B] Epoch 4: train 0.1713/0.9313 | val 1.0444/0.6695
[2B] Epoch 5: train 0.0996/0.9606 | val 1.2775/0.6840
  ↳ saved best_model_2b.pth
[2B] Epoch 6: train 0.0587/0.9789 | val 1.7307/0.6802
[2B] Epoch 7: train 0.0387/0.9865 | val 1.5759/0.6902
  ↳ saved best_model_2b.pth
[2B] Epoch 8: train 0.0240/0.9923 | val 1.7992/0.6948
  ↳ saved best_model_2b.pth
[2B] Epoch 9: train 0.0205/0.9935 | val 2.0759/0.6925
[2B] Epoch 10: train 0.0244/0.9919 | val 2.0109/0.6825
[2B] Test: loss 1.8880 acc 0.7197
              precision    recall  f1-score   support

     entails       0.69      0.54      0.60       842
     neutral       0.73      0.84      0.78      1284

    accuracy                           0.72      2126
   macro avg       0.71      0.69      0.69      2126
weighted avg 

In [55]:
@torch.no_grad()
def cross_attn_heatmap(model, dataset, idx):
    model.eval()
    prem = torch.LongTensor([dataset.premise_idx[idx][:MAX_LENGTH_PREMISE]]).to(device)
    hyp  = torch.LongTensor([dataset.hypothesis_idx[idx][:MAX_LENGTH_HYPOTHESIS]]).to(device)
    logits, A = model(prem, hyp)   # A: [1, Th, Tp]
    A = A.squeeze(0).cpu().numpy()
    prem_toks = [ix_to_word.get(i, "<UNK>") for i in prem.squeeze(0).cpu().numpy()]
    hyp_toks  = [ix_to_word.get(i, "<UNK>") for i in hyp.squeeze(0).cpu().numpy()]

    plt.figure(figsize=(min(16, 0.4*len(prem_toks)), min(10, 0.4*len(hyp_toks))))
    plt.imshow(A, aspect='auto')
    plt.yticks(range(len(hyp_toks)), hyp_toks)
    plt.xticks(range(len(prem_toks)), prem_toks, rotation=90)
    plt.title("Decoder-side cross-attention (rows=hyp, cols=prem)")
    plt.tight_layout(); plt.show()

# Example:
# cross_attn_heatmap(model_2b, test_dataset, idx=0)
