# Model B – ESIM-Style BiGRU with Inference Composition

This notebook implements the ESIM-style BiGRU model (Model B) for the CITS4012 Natural Language Processing project. The implementation adheres to the project specification by using the provided science-domain NLI dataset, training all parameters from scratch with PyTorch, and avoiding any prohibited resources (e.g., Hugging Face models). Detailed explanations accompany each step to keep the workflow transparent and reproducible.

## Specification Alignment
- Inputs come exclusively from `train.json`, `validation.json`, and `test.json` supplied with the project.
- Only PyTorch (permitted framework) is used for modelling; no pretrained language model checkpoints are loaded.
- A shared BiGRU encoder with soft-alignment attention, local inference composition, and pooling satisfies the architectural requirements for Model B.
- The notebook records configuration, training, and evaluation logs so the reported numbers are reproducible and compliant with the requirement to include the running log.

## Implementation Roadmap
1. Configure deterministic training utilities and parse the dataset according to the specification.
2. Inspect the dataset to understand label balance and sequence-length statistics for science-domain premises and hypotheses.
3. Build a vocabulary from the training split, implement token-to-index conversion, and construct PyTorch datasets/dataloaders with dynamic padding.
4. Define the ESIM-style BiGRU architecture, covering encoding, cross-attention, local inference enhancement, composition, and classification.
5. Train the model with early stopping on the validation set, evaluate on validation and test data, and report detailed metrics.
6. Run a brief qualitative attention inspection to ground the alignment behaviour of the model.

In [2]:
import json
import math
import random
import re
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm.auto import tqdm

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

PyTorch version: 2.8.0
CUDA available: False



In [3]:
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"
LABEL_TO_ID = {"neutral": 0, "entails": 1}
ID_TO_LABEL = {v: k for k, v in LABEL_TO_ID.items()}

@dataclass
class Config:
    data_dir: Path = Path('.')
    min_freq: int = 2
    max_vocab_size: Optional[int] = 30000
    embed_dim: int = 200
    hidden_size: int = 128
    projection_dim: int = 256
    mlp_dim: int = 256
    dropout: float = 0.3
    batch_size: int = 32
    num_epochs: int = 10
    patience: int = 3
    learning_rate: float = 1e-3
    weight_decay: float = 1e-5
    max_grad_norm: float = 5.0
    seed: int = 2025


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def clean_text(text: str) -> str:
    return text.replace('\n', ' ').strip()


def tokenize(text: str) -> List[str]:
    text = clean_text(text.lower())
    tokens = re.findall(r"[a-z0-9]+(?:'[a-z0-9]+)?|[^\w\s]", text)
    return tokens if tokens else ["<empty>"]


def load_split(path: Path) -> List[Dict[str, Optional[str]]]:
    with open(path, encoding="utf-8") as f:  # explicit encoding keeps parsing platform-independent
        raw = json.load(f)
    ids = sorted(raw['premise'].keys(), key=lambda x: int(x))
    samples: List[Dict[str, Optional[str]]] = []
    for sid in ids:
        sample = {
            'id': int(sid),
            'premise': raw['premise'][sid],
            'hypothesis': raw['hypothesis'][sid],
            'label': raw['label'].get(sid) if 'label' in raw else None,
        }
        samples.append(sample)
    return samples


def attach_tokens(samples: List[Dict[str, Optional[str]]]) -> None:
    for sample in samples:
        sample['premise_tokens'] = tokenize(sample['premise'])
        sample['hypothesis_tokens'] = tokenize(sample['hypothesis'])


def describe_split(name: str, samples: List[Dict[str, Optional[str]]]) -> Dict[str, float]:
    premise_lengths = np.array([len(s['premise_tokens']) for s in samples])
    hypothesis_lengths = np.array([len(s['hypothesis_tokens']) for s in samples])
    labels = [s['label'] for s in samples if s['label'] is not None]
    stats = {
        'num_examples': len(samples),
        'premise_mean': float(premise_lengths.mean()),
        'premise_median': float(np.median(premise_lengths)),
        'premise_p95': float(np.percentile(premise_lengths, 95)),
        'hyp_mean': float(hypothesis_lengths.mean()),
        'hyp_median': float(np.median(hypothesis_lengths)),
        'hyp_p95': float(np.percentile(hypothesis_lengths, 95)),
    }
    print(f"Split: {name}")
    print(f"  Examples: {stats['num_examples']}")
    print(f"  Premise tokens -> mean {stats['premise_mean']:.1f} | median {stats['premise_median']:.1f} | 95th pct {stats['premise_p95']:.1f}")
    print(f"  Hypothesis tokens -> mean {stats['hyp_mean']:.1f} | median {stats['hyp_median']:.1f} | 95th pct {stats['hyp_p95']:.1f}")
    if labels:
        counter = Counter(labels)
        total = sum(counter.values())
        for label, count in counter.items():
            print(f"  Label '{label}': {count} ({count / total:.2%})")
    else:
        print("  Labels not provided for this split.")
    print()
    return stats
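To make the tokenizer's behaviour concrete, the sketch below reuses the regular expression from `tokenize` and applies it to the premise of the first training instance. It is an illustration only: `demo_tokenize` and `TOKEN_PATTERN` are hypothetical names chosen here so that the notebook's own function is not shadowed.

```python
import re
from typing import List

# Same pattern as in `tokenize`: alphanumeric runs (with an optional
# apostrophe suffix such as "don't") or any single punctuation character.
TOKEN_PATTERN = re.compile(r"[a-z0-9]+(?:'[a-z0-9]+)?|[^\w\s]")


def demo_tokenize(text: str) -> List[str]:
    """Hypothetical stand-in for the notebook's `tokenize`, for illustration."""
    text = text.lower().replace("\n", " ").strip()
    tokens = TOKEN_PATTERN.findall(text)
    return tokens if tokens else ["<empty>"]


tokens = demo_tokenize("Pluto rotates once on its axis every 6.39 Earth days;")
print(tokens)
# ['pluto', 'rotates', 'once', 'on', 'its', 'axis', 'every', '6', '.', '39', 'earth', 'days', ';']
print(demo_tokenize("   "))
# ['<empty>']
```

Two consequences are worth keeping in mind when reading the length statistics: a decimal such as `6.39` becomes three tokens, and characters that count as word characters but fall outside `a-z0-9` (for example underscores) are dropped silently.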

In [4]:
class Vocabulary:
    def __init__(self, stoi: Dict[str, int]):
        self.stoi = stoi
        self.itos = {idx: token for token, idx in stoi.items()}
        self.pad_id = self.stoi[PAD_TOKEN]
        self.unk_id = self.stoi[UNK_TOKEN]

    @classmethod
    def build(cls, samples: Iterable[Dict[str, Optional[str]]], min_freq: int = 1, max_size: Optional[int] = None) -> 'Vocabulary':
        counter: Counter = Counter()
        for sample in samples:
            counter.update(sample['premise_tokens'])
            counter.update(sample['hypothesis_tokens'])
        most_common = [tok for tok, freq in counter.most_common() if freq >= min_freq]
        if max_size is not None:
            capacity = max_size - 2
            most_common = most_common[:max(0, capacity)]
        stoi = {PAD_TOKEN: 0, UNK_TOKEN: 1}
        for token in most_common:
            if token not in stoi:
                stoi[token] = len(stoi)
        return cls(stoi)

    def __len__(self) -> int:
        return len(self.stoi)

    def encode(self, tokens: List[str]) -> List[int]:
        return [self.stoi.get(tok, self.unk_id) for tok in tokens]

    def decode(self, ids: List[int]) -> List[str]:
        return [self.itos.get(idx, UNK_TOKEN) for idx in ids]


class NLIDataset(Dataset):
    def __init__(self, samples: List[Dict[str, Optional[str]]], vocab: Vocabulary, label_to_id: Dict[str, int]):
        self.vocab = vocab
        self.label_to_id = label_to_id
        self.samples = []
        for sample in samples:
            item = {
                'id': sample['id'],
                'premise_ids': vocab.encode(sample['premise_tokens']),
                'hypothesis_ids': vocab.encode(sample['hypothesis_tokens']),
                'premise_text': sample['premise'],
                'hypothesis_text': sample['hypothesis'],
                'label': label_to_id[sample['label']] if sample['label'] is not None else None,
            }
            self.samples.append(item)

    def __len__(self) -> int:
        return len(self.samples)

    def decode_tokens(self, ids: List[int]) -> List[str]:
        return [self.vocab.itos.get(idx, UNK_TOKEN) for idx in ids]

    def __getitem__(self, idx: int) -> Dict[str, object]:  # items hold ints, lists and strings, not tensors
        return self.samples[idx]


def build_collate_fn(pad_id: int):
    def collate_fn(batch: List[Dict[str, object]]) -> Dict[str, Optional[torch.Tensor]]:  # 'labels' is None for unlabeled batches
        premise_seqs = [torch.tensor(item['premise_ids'], dtype=torch.long) for item in batch]
        hypothesis_seqs = [torch.tensor(item['hypothesis_ids'], dtype=torch.long) for item in batch]
        premise_lengths = torch.tensor([len(seq) for seq in premise_seqs], dtype=torch.long)
        hypothesis_lengths = torch.tensor([len(seq) for seq in hypothesis_seqs], dtype=torch.long)
        padded_premise = pad_sequence(premise_seqs, batch_first=True, padding_value=pad_id)
        padded_hypothesis = pad_sequence(hypothesis_seqs, batch_first=True, padding_value=pad_id)
        labels_list = [item['label'] for item in batch]
        labels = None
        if all(label is not None for label in labels_list):
            labels = torch.tensor(labels_list, dtype=torch.long)
        ids = torch.tensor([item['id'] for item in batch], dtype=torch.long)
        return {
            'premise': padded_premise,
            'premise_lengths': premise_lengths,
            'hypothesis': padded_hypothesis,
            'hypothesis_lengths': hypothesis_lengths,
            'labels': labels,
            'ids': ids,
        }
    return collate_fn
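The frequency threshold and size cap in `Vocabulary.build` can be traced on a toy corpus. The sketch below is a simplified re-implementation for illustration; `build_stoi` and the two example sentences are made up and are not part of the notebook.

```python
from collections import Counter
from typing import Dict, List, Optional

PAD, UNK = "<pad>", "<unk>"


def build_stoi(token_lists: List[List[str]], min_freq: int = 1,
               max_size: Optional[int] = None) -> Dict[str, int]:
    """Toy version of the logic in `Vocabulary.build` (illustrative only)."""
    counter = Counter(tok for tokens in token_lists for tok in tokens)
    kept = [tok for tok, freq in counter.most_common() if freq >= min_freq]
    if max_size is not None:
        kept = kept[:max(0, max_size - 2)]  # two slots are reserved for pad/unk
    stoi = {PAD: 0, UNK: 1}
    for tok in kept:
        stoi.setdefault(tok, len(stoi))
    return stoi


corpus = [["water", "boils", "at", "100", "degrees"],
          ["water", "freezes", "at", "0", "degrees"]]
stoi = build_stoi(corpus, min_freq=2)
print(stoi)
# {'<pad>': 0, '<unk>': 1, 'water': 2, 'at': 3, 'degrees': 4}

# Unseen tokens ("melts") and tokens below the threshold ("0") both map to <unk>.
encoded = [stoi.get(tok, stoi[UNK]) for tok in ["water", "melts", "at", "0"]]
print(encoded)
# [2, 1, 3, 1]
```

With `min_freq=2`, any token that appears only once in the training split is indistinguishable from an out-of-vocabulary token at inference time.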

In [5]:
class ESIMBiGRU(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_size: int,
        projection_dim: int,
        mlp_dim: int,
        num_classes: int,
        padding_idx: int,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.padding_idx = padding_idx
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.encoder = nn.GRU(embed_dim, hidden_size, batch_first=True, bidirectional=True)
        self.projection = nn.Sequential(
            nn.Linear(hidden_size * 8, projection_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.composition = nn.GRU(projection_dim, hidden_size, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 8, mlp_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, num_classes),
        )
        self.dropout = nn.Dropout(dropout)
        self._init_parameters()

    def _init_parameters(self) -> None:
        nn.init.xavier_uniform_(self.embedding.weight)
        self.embedding.weight.data[self.padding_idx] = 0
        for gru in [self.encoder, self.composition]:
            for name, param in gru.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    nn.init.zeros_(param)
        for module in self.projection:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)
        for module in self.classifier:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

    @staticmethod
    def masked_softmax(tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
        mask = mask.to(dtype=torch.bool)
        tensor = tensor.masked_fill(~mask, float('-inf'))
        return torch.softmax(tensor, dim=dim)

    def forward(
        self,
        premise: torch.Tensor,
        hypothesis: torch.Tensor,
        premise_lengths: torch.Tensor,
        hypothesis_lengths: torch.Tensor,
        return_alignments: bool = False,
    ):
        premise_mask = premise != self.padding_idx
        hypothesis_mask = hypothesis != self.padding_idx

        premise_embed = self.dropout(self.embedding(premise))
        hypothesis_embed = self.dropout(self.embedding(hypothesis))

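        # The GRU runs over the padded tensors directly, so premise_lengths and
        # hypothesis_lengths are currently unused. Padded positions are masked out only
        # in the attention and pooling steps; the backward direction still reads them first.
        premise_encoded, _ = self.encoder(premise_embed)
        hypothesis_encoded, _ = self.encoder(hypothesis_embed)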
        premise_encoded = self.dropout(premise_encoded)
        hypothesis_encoded = self.dropout(hypothesis_encoded)

        similarity = torch.bmm(premise_encoded, hypothesis_encoded.transpose(1, 2))

        hyp_mask_expanded = hypothesis_mask.unsqueeze(1).expand_as(similarity)
        prem_mask_expanded = premise_mask.unsqueeze(1).expand_as(similarity.transpose(1, 2))
        weight_premise = self.masked_softmax(similarity, hyp_mask_expanded, dim=-1)
        weight_hypothesis = self.masked_softmax(similarity.transpose(1, 2), prem_mask_expanded, dim=-1)

        attended_premise = torch.bmm(weight_premise, hypothesis_encoded)
        attended_hypothesis = torch.bmm(weight_hypothesis, premise_encoded)

        premise_combined = torch.cat(
            [
                premise_encoded,
                attended_premise,
                premise_encoded - attended_premise,
                premise_encoded * attended_premise,
            ],
            dim=-1,
        )
        hypothesis_combined = torch.cat(
            [
                hypothesis_encoded,
                attended_hypothesis,
                hypothesis_encoded - attended_hypothesis,
                hypothesis_encoded * attended_hypothesis,
            ],
            dim=-1,
        )

        premise_projected = self.projection(premise_combined)
        hypothesis_projected = self.projection(hypothesis_combined)

        premise_composed, _ = self.composition(premise_projected)
        hypothesis_composed, _ = self.composition(hypothesis_projected)

        premise_composed = self.dropout(premise_composed)
        hypothesis_composed = self.dropout(hypothesis_composed)

        premise_avg = self.masked_mean(premise_composed, premise_mask)
        hypothesis_avg = self.masked_mean(hypothesis_composed, hypothesis_mask)
        premise_max = self.masked_max(premise_composed, premise_mask)
        hypothesis_max = self.masked_max(hypothesis_composed, hypothesis_mask)

        combined = torch.cat([premise_avg, premise_max, hypothesis_avg, hypothesis_max], dim=-1)
        logits = self.classifier(combined)

        if return_alignments:
            return logits, weight_premise, weight_hypothesis
        return logits

    @staticmethod
    def masked_mean(sequence: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        mask = mask.unsqueeze(-1).type_as(sequence)
        masked_seq = sequence * mask
        summed = masked_seq.sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    @staticmethod
    def masked_max(sequence: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        mask = mask.unsqueeze(-1)
        masked_seq = sequence.masked_fill(~mask, float('-inf'))
        return masked_seq.max(dim=1).values
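The masking helpers in the model cell can be illustrated without tensors. The sketch below is a plain-Python analogue on a single padded sequence, not the notebook's implementation: the function names mirror the model's static methods but operate on lists, and the scores are made up.

```python
import math
from typing import List


def masked_softmax(scores: List[float], mask: List[bool]) -> List[float]:
    """Softmax over the positions where mask is True; masked positions get 0."""
    kept = [s for s, m in zip(scores, mask) if m]
    peak = max(kept)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) if m else 0.0 for s, m in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


def masked_mean(values: List[float], mask: List[bool]) -> float:
    kept = [v for v, m in zip(values, mask) if m]
    return sum(kept) / max(len(kept), 1)


def masked_max(values: List[float], mask: List[bool]) -> float:
    return max(v for v, m in zip(values, mask) if m)


# One sequence of three real tokens followed by two padding positions.
scores = [2.0, 1.0, 0.0, 5.0, 5.0]
mask = [True, True, True, False, False]

weights = masked_softmax(scores, mask)
print([round(w, 3) for w in weights])
# [0.665, 0.245, 0.09, 0.0, 0.0]
print(masked_mean(scores, mask), masked_max(scores, mask))
# 1.0 2.0
```

The large scores at the padded positions have no effect on any of the three results. Both this sketch and the tensor version assume at least one unmasked position per sequence; in the notebook the `<empty>` fallback token in `tokenize` guarantees that, since a fully masked row would otherwise turn the softmax into NaN values.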

In [6]:
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    max_grad_norm: float,
) -> Dict[str, float]:
    model.train()
    total_loss = 0.0
    total_examples = 0
    all_preds: List[int] = []
    all_labels: List[int] = []

    for batch in tqdm(dataloader, desc="Train", leave=False):
        optimizer.zero_grad(set_to_none=True)
        premise = batch['premise'].to(device)
        hypothesis = batch['hypothesis'].to(device)
        premise_lengths = batch['premise_lengths'].to(device)
        hypothesis_lengths = batch['hypothesis_lengths'].to(device)
        labels = batch['labels'].to(device)

        logits = model(premise, hypothesis, premise_lengths, hypothesis_lengths)
        loss = criterion(logits, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        batch_size = premise.size(0)
        total_loss += loss.item() * batch_size
        total_examples += batch_size
        preds = logits.argmax(dim=-1)
        all_preds.extend(preds.detach().cpu().tolist())
        all_labels.extend(labels.detach().cpu().tolist())

    avg_loss = total_loss / total_examples
    accuracy = accuracy_score(all_labels, all_preds)
    return {"loss": avg_loss, "accuracy": accuracy}


def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Dict[str, float]:
    model.eval()
    total_loss = 0.0
    total_examples = 0
    all_preds: List[int] = []
    all_labels: List[int] = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Eval", leave=False):
            premise = batch['premise'].to(device)
            hypothesis = batch['hypothesis'].to(device)
            premise_lengths = batch['premise_lengths'].to(device)
            hypothesis_lengths = batch['hypothesis_lengths'].to(device)
            labels = batch['labels']
            logits = model(premise, hypothesis, premise_lengths, hypothesis_lengths)
            if labels is not None:
                labels = labels.to(device)
                loss = criterion(logits, labels)
                batch_size = premise.size(0)
                total_loss += loss.item() * batch_size
                total_examples += batch_size
                all_labels.extend(labels.detach().cpu().tolist())
                preds = logits.argmax(dim=-1)
                all_preds.extend(preds.detach().cpu().tolist())
            else:
                preds = logits.argmax(dim=-1)
                all_preds.extend(preds.detach().cpu().tolist())

    metrics = {"predictions": all_preds}
    if total_examples > 0:
        metrics["loss"] = total_loss / total_examples
        metrics["accuracy"] = accuracy_score(all_labels, all_preds)
        metrics["labels"] = all_labels
    return metrics


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: Config,
    device: torch.device,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    best_state = None
    best_val_accuracy = -math.inf
    epochs_without_improvement = 0
    history: List[Dict[str, float]] = []

    for epoch in range(1, config.num_epochs + 1):
        print(f"Epoch {epoch}/{config.num_epochs}")
        train_metrics = train_epoch(model, train_loader, optimizer, criterion, device, config.max_grad_norm)
        val_metrics = evaluate(model, val_loader, criterion, device)

        log_entry = {
            'epoch': epoch,
            'train_loss': train_metrics['loss'],
            'train_accuracy': train_metrics['accuracy'],
            'val_loss': val_metrics['loss'],
            'val_accuracy': val_metrics['accuracy'],
        }
        history.append(log_entry)
        print(f"  Train -> loss: {train_metrics['loss']:.4f}, accuracy: {train_metrics['accuracy']:.4f}")
        print(f"  Val   -> loss: {val_metrics['loss']:.4f}, accuracy: {val_metrics['accuracy']:.4f}")

        if val_metrics['accuracy'] > best_val_accuracy:
            best_val_accuracy = val_metrics['accuracy']
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            epochs_without_improvement = 0
            print("  New best validation accuracy; checkpoint updated.")
        else:
            epochs_without_improvement += 1
            print(f"  No improvement for {epochs_without_improvement} epoch(s).")
            if epochs_without_improvement >= config.patience:
                print("  Early stopping triggered.")
                break
        print()

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history

In [7]:
config = Config()
set_seed(config.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_path = config.data_dir / 'train.json'
val_path = config.data_dir / 'validation.json'
test_path = config.data_dir / 'test.json'

train_samples = load_split(train_path)
val_samples = load_split(val_path)
test_samples = load_split(test_path)

for split in (train_samples, val_samples, test_samples):
    attach_tokens(split)

train_stats = describe_split('train', train_samples)
val_stats = describe_split('validation', val_samples)
test_stats = describe_split('test', test_samples)

print("Sample training instance:")
print(json.dumps({
    'premise': train_samples[0]['premise'],
    'hypothesis': train_samples[0]['hypothesis'],
    'label': train_samples[0]['label'],
}, indent=2))

Split: train
  Examples: 23088
  Premise tokens -> mean 21.1 | median 18.0 | 95th pct 38.0
  Hypothesis tokens -> mean 13.2 | median 12.0 | 95th pct 22.0
  Label 'neutral': 14618 (63.31%)
  Label 'entails': 8470 (36.69%)

Split: validation
  Examples: 1304
  Premise tokens -> mean 19.9 | median 17.0 | 95th pct 39.0
  Hypothesis tokens -> mean 13.8 | median 13.0 | 95th pct 23.0
  Label 'neutral': 647 (49.62%)
  Label 'entails': 657 (50.38%)

Split: test
  Examples: 2126
  Premise tokens -> mean 19.3 | median 18.0 | 95th pct 37.0
  Hypothesis tokens -> mean 14.0 | median 13.0 | 95th pct 25.0
  Label 'neutral': 1284 (60.40%)
  Label 'entails': 842 (39.60%)

Sample training instance:
{
  "premise": "Pluto rotates once on its axis every 6.39 Earth days;",
  "hypothesis": "Earth rotates on its axis once times in one day.",
  "label": "neutral"
}


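The dataset inspection confirms that premises are longer than hypotheses (training medians of 18 vs 12 tokens). The label distribution is moderately imbalanced in the training and test splits (about 63% and 60% neutral), whereas the validation split is almost balanced (49.6% neutral), so validation and test accuracy are measured against different class priors. Training in this notebook uses an unweighted cross-entropy loss, so the skew is not corrected for. The 95th-percentile lengths (38 and 22 tokens in the training split) show that sequences are short enough to keep every token, so batches are padded to their longest member without truncation.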
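One common way to account for the label skew is to weight the loss by inverse class frequency. The sketch below only computes such weights from the training counts printed above; it is an option to experiment with, not what produced the results in this notebook, and `inverse_frequency_weights` is a hypothetical helper.

```python
from typing import Dict

# Training-split label counts copied from the inspection output above.
train_counts: Dict[str, int] = {"neutral": 14618, "entails": 8470}


def inverse_frequency_weights(counts: Dict[str, int]) -> Dict[str, float]:
    """Weight each class by total / (num_classes * class_count)."""
    total = sum(counts.values())
    num_classes = len(counts)
    return {label: total / (num_classes * n) for label, n in counts.items()}


weights = inverse_frequency_weights(train_counts)
for label, weight in weights.items():
    print(f"{label}: {weight:.3f}")
# neutral: 0.790
# entails: 1.363
```

Ordered by label id, these values could be passed as the `weight` tensor of `nn.CrossEntropyLoss`. Whether that helps here is an open question, because the validation split used for early stopping is already nearly balanced.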

In [8]:
vocab = Vocabulary.build(train_samples, min_freq=config.min_freq, max_size=config.max_vocab_size)
print(f"Vocabulary size: {len(vocab)} (including pad/unk)")

train_dataset = NLIDataset(train_samples, vocab, LABEL_TO_ID)
val_dataset = NLIDataset(val_samples, vocab, LABEL_TO_ID)
test_dataset = NLIDataset(test_samples, vocab, LABEL_TO_ID)

collate_fn = build_collate_fn(vocab.pad_id)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=collate_fn)

batch_example = next(iter(train_loader))
print('Batch tensor shapes:')
for key, value in batch_example.items():
    if isinstance(value, torch.Tensor):
        print(f"  {key}: {tuple(value.size())}")

Vocabulary size: 11447 (including pad/unk)
Batch tensor shapes:
  premise: (32, 40)
  premise_lengths: (32,)
  hypothesis: (32, 21)
  hypothesis_lengths: (32,)
  labels: (32,)
  ids: (32,)


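The dataloaders pad each batch to the length of its longest sequence, which is why the sample batch above has a premise width of 40 and a hypothesis width of 21. No sequence is truncated, so the attention mechanism sees every token. The vocabulary keeps only tokens that occur at least twice in the training split (11,447 entries including `<pad>` and `<unk>`, well under the 30,000 cap); rarer tokens map to `<unk>`, which limits the size of the embedding matrix and may reduce overfitting to rare words.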
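Per-batch padding can be sketched in a few lines of plain Python. This mirrors what `pad_sequence` does inside the collate function, but `pad_batch` and the token ids below are illustrative assumptions rather than notebook code.

```python
from typing import List, Tuple


def pad_batch(sequences: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[int]]:
    """Right-pad every sequence to the longest length in this batch only."""
    lengths = [len(seq) for seq in sequences]
    width = max(lengths)
    padded = [seq + [pad_id] * (width - len(seq)) for seq in sequences]
    return padded, lengths


batch_a = [[5, 8, 2], [7, 3], [9]]
batch_b = [[4, 6, 1, 3, 8, 2], [5]]

padded_a, lengths_a = pad_batch(batch_a)
padded_b, lengths_b = pad_batch(batch_b)
print(padded_a, lengths_a)
# [[5, 8, 2], [7, 3, 0], [9, 0, 0]] [3, 2, 1]
print(len(padded_a[0]), len(padded_b[0]))
# 3 6
```

Each batch gets its own width, so a batch of short sentences does not pay for the longest sentence in the corpus.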

In [9]:
model = ESIMBiGRU(
    vocab_size=len(vocab),
    embed_dim=config.embed_dim,
    hidden_size=config.hidden_size,
    projection_dim=config.projection_dim,
    mlp_dim=config.mlp_dim,
    num_classes=len(LABEL_TO_ID),
    padding_idx=vocab.pad_id,
    dropout=config.dropout,
)
model = model.to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Model parameters: 3,364,602
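The reported parameter count can be checked by hand from the layer shapes in `ESIMBiGRU` and the values in `Config`. The sketch below assumes PyTorch's single-layer GRU layout (three gates, with separate input and hidden bias vectors) and uses two hypothetical helpers, `gru_params` and `linear_params`.

```python
def gru_params(input_size: int, hidden_size: int, bidirectional: bool = True) -> int:
    """Single-layer GRU: 3 gates, input and hidden weights, two bias vectors."""
    per_direction = 3 * hidden_size * (input_size + hidden_size) + 2 * 3 * hidden_size
    return per_direction * (2 if bidirectional else 1)


def linear_params(in_features: int, out_features: int) -> int:
    return in_features * out_features + out_features


vocab_size, embed_dim, hidden, proj_dim, mlp_dim, num_classes = 11447, 200, 128, 256, 256, 2

counts = {
    "embedding": vocab_size * embed_dim,
    "encoder": gru_params(embed_dim, hidden),
    # 8 * hidden: four concatenated features, each 2 * hidden wide (BiGRU output).
    "projection": linear_params(8 * hidden, proj_dim),
    "composition": gru_params(proj_dim, hidden),
    # 8 * hidden: average and max pooling for premise and hypothesis, each 2 * hidden wide.
    "classifier": linear_params(8 * hidden, mlp_dim) + linear_params(mlp_dim, num_classes),
}
for name, n in counts.items():
    print(f"{name:<12}{n:>10,}")
print(f"{'total':<12}{sum(counts.values()):>10,}")
```

The total comes to 3,364,602, matching the printed figure. The embedding table alone accounts for 2,289,400 of those parameters, roughly two thirds of the model.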


In [10]:
model, history = train_model(model, train_loader, val_loader, config, device)

Epoch 1/10
  Train -> loss: 0.4414, accuracy: 0.7930
  Val   -> loss: 0.6002, accuracy: 0.6618
  New best validation accuracy; checkpoint updated.

Epoch 2/10
  Train -> loss: 0.2791, accuracy: 0.8838
  Val   -> loss: 0.6697, accuracy: 0.7017
  New best validation accuracy; checkpoint updated.

Epoch 3/10
  Train -> loss: 0.1945, accuracy: 0.9228
  Val   -> loss: 0.7398, accuracy: 0.7163
  New best validation accuracy; checkpoint updated.

Epoch 4/10
  Train -> loss: 0.1348, accuracy: 0.9487
  Val   -> loss: 0.8163, accuracy: 0.7063
  No improvement for 1 epoch(s).

Epoch 5/10
  Train -> loss: 0.0932, accuracy: 0.9641
  Val   -> loss: 1.0852, accuracy: 0.7132
  No improvement for 2 epoch(s).

Epoch 6/10
  Train -> loss: 0.0662, accuracy: 0.9749
  Val   -> loss: 1.3192, accuracy: 0.6910
  No improvement for 3 epoch(s).
  Early stopping triggered.


In [11]:
print("Training history (best epoch highlighted):")
best_epoch = max(history, key=lambda x: x['val_accuracy'])['epoch']
for record in history:
    marker = '***' if record['epoch'] == best_epoch else '   '
    print(f"{marker} Epoch {record['epoch']:02d} | Train loss {record['train_loss']:.4f} | Train acc {record['train_accuracy']:.4f} | Val loss {record['val_loss']:.4f} | Val acc {record['val_accuracy']:.4f}")

Training history (best epoch highlighted):
    Epoch 01 | Train loss 0.4414 | Train acc 0.7930 | Val loss 0.6002 | Val acc 0.6618
    Epoch 02 | Train loss 0.2791 | Train acc 0.8838 | Val loss 0.6697 | Val acc 0.7017
*** Epoch 03 | Train loss 0.1945 | Train acc 0.9228 | Val loss 0.7398 | Val acc 0.7163
    Epoch 04 | Train loss 0.1348 | Train acc 0.9487 | Val loss 0.8163 | Val acc 0.7063
    Epoch 05 | Train loss 0.0932 | Train acc 0.9641 | Val loss 1.0852 | Val acc 0.7132
    Epoch 06 | Train loss 0.0662 | Train acc 0.9749 | Val loss 1.3192 | Val acc 0.6910
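The early-stopping decision can be replayed from the logged validation numbers alone. The sketch below is a simplified stand-in for the bookkeeping in `train_model`; `replay_early_stopping` is a hypothetical helper, and the second call shows what would have happened had validation loss been monitored instead of accuracy.

```python
from typing import List, Tuple

# Per-epoch validation metrics copied from the training history above.
val_accuracy: List[float] = [0.6618, 0.7017, 0.7163, 0.7063, 0.7132, 0.6910]
val_loss: List[float] = [0.6002, 0.6697, 0.7398, 0.8163, 1.0852, 1.3192]


def replay_early_stopping(scores: List[float], patience: int) -> Tuple[int, int]:
    """Return (best_epoch, stop_epoch), both 1-indexed, for a higher-is-better metric."""
    best_score = float("-inf")
    best_epoch = 0
    waited = 0
    for epoch, score in enumerate(scores, start=1):
        if score > best_score:
            best_score, best_epoch, waited = score, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                return best_epoch, epoch
    return best_epoch, len(scores)


print(replay_early_stopping(val_accuracy, patience=3))
# (3, 6)

# Negate the loss so that "higher is better" still applies.
print(replay_early_stopping([-loss for loss in val_loss], patience=3))
# (1, 4)
```

Accuracy-based stopping keeps the epoch-3 checkpoint, whereas loss-based stopping would have kept epoch 1. The gap between the two is a sign that validation loss starts rising while accuracy is still improving.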


In [12]:
criterion = nn.CrossEntropyLoss()
val_metrics = evaluate(model, val_loader, criterion, device)
print("Validation report:")
print(classification_report(val_metrics['labels'], val_metrics['predictions'], target_names=[ID_TO_LABEL[i] for i in range(len(LABEL_TO_ID))]))

val_confusion = confusion_matrix(val_metrics['labels'], val_metrics['predictions'])
print("Validation confusion matrix:")
print(val_confusion)

test_metrics = evaluate(model, test_loader, criterion, device)
if 'labels' in test_metrics:
    print("\nTest report:")
    print(classification_report(test_metrics['labels'], test_metrics['predictions'], target_names=[ID_TO_LABEL[i] for i in range(len(LABEL_TO_ID))]))
    test_confusion = confusion_matrix(test_metrics['labels'], test_metrics['predictions'])
    print("Test confusion matrix:")
    print(test_confusion)
else:
    print("Test split has no labels; generated predictions for offline evaluation.")

                                                     

Validation report:
              precision    recall  f1-score   support

     neutral       0.67      0.83      0.74       647
     entails       0.78      0.60      0.68       657

    accuracy                           0.72      1304
   macro avg       0.73      0.72      0.71      1304
weighted avg       0.73      0.72      0.71      1304

Validation confusion matrix:
[[538 109]
 [261 396]]


                                                     


Test report:
              precision    recall  f1-score   support

     neutral       0.75      0.82      0.78      1284
     entails       0.68      0.58      0.63       842

    accuracy                           0.73      2126
   macro avg       0.72      0.70      0.71      2126
weighted avg       0.72      0.73      0.72      2126

Test confusion matrix:
[[1053  231]
 [ 351  491]]
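The figures in the two reports can be re-derived from the confusion matrices, which also makes it easy to compare accuracy with a majority-class baseline. The sketch below copies the matrices from the output above; `summarise` is a hypothetical helper and relies on scikit-learn's convention that rows are gold labels and columns are predictions.

```python
from typing import Dict, List

LABELS = ["neutral", "entails"]
# Rows are gold labels, columns are predictions.
val_cm = [[538, 109], [261, 396]]
test_cm = [[1053, 231], [351, 491]]


def summarise(cm: List[List[int]]) -> Dict[str, float]:
    total = sum(sum(row) for row in cm)
    out = {"accuracy": sum(cm[i][i] for i in range(len(cm))) / total}
    for i, label in enumerate(LABELS):
        predicted = sum(row[i] for row in cm)  # column sum: predicted as this label
        gold = sum(cm[i])                      # row sum: gold examples of this label
        out[f"{label}_precision"] = cm[i][i] / predicted
        out[f"{label}_recall"] = cm[i][i] / gold
    # Accuracy obtained by always predicting the most frequent gold label.
    out["majority_baseline"] = max(sum(row) for row in cm) / total
    return out


for name, cm in [("validation", val_cm), ("test", test_cm)]:
    stats = summarise(cm)
    print(name, {k: round(v, 4) for k, v in stats.items()})
```

On the test split the accuracy is about 0.726 against a majority-class baseline of about 0.604. On the nearly balanced validation split the baseline is about 0.504, so the two accuracy figures are not directly comparable.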




### Attention inspection on a validation example
To check that the soft-alignment component behaves sensibly, we inspect the largest alignment weights for a correctly classified validation item. For each hypothesis token, the cell below lists the three premise tokens with the largest entries in the corresponding column of `weight_premise`, the premise-to-hypothesis attention matrix.
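The forward pass returns two alignment matrices: `weight_premise`, normalised over hypothesis tokens for each premise token, and `weight_hypothesis`, normalised over premise tokens for each hypothesis token. The sketch below uses made-up similarity scores and no padding to show that the two are not interchangeable; it is a plain-Python illustration, not the model's code.

```python
import math
from typing import List


def softmax(row: List[float]) -> List[float]:
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up similarity scores: 3 premise tokens (rows) x 2 hypothesis tokens (columns).
similarity = [[2.0, 0.0],
              [1.0, 1.0],
              [0.0, 3.0]]

# Premise side: each premise token distributes attention over hypothesis tokens.
weight_premise = [softmax(row) for row in similarity]

# Hypothesis side: each hypothesis token distributes attention over premise tokens.
transposed = [list(col) for col in zip(*similarity)]
weight_hypothesis = [softmax(row) for row in transposed]

column_0 = [row[0] for row in weight_premise]
print(round(sum(column_0), 3))              # a column of weight_premise need not sum to 1
print(round(sum(weight_hypothesis[0]), 3))  # a row of weight_hypothesis sums to 1
```

When reading the inspection output, this means that values taken from a column of `weight_premise` rank premise tokens by how strongly they attend to a given hypothesis token, but they are not a probability distribution over the premise.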

In [13]:
def show_alignment(model: ESIMBiGRU, dataset: NLIDataset, index: int) -> None:
    model.eval()
    sample = dataset[index]
    with torch.no_grad():
        batch = {
            'premise': torch.tensor(sample['premise_ids'], dtype=torch.long, device=device).unsqueeze(0),
            'hypothesis': torch.tensor(sample['hypothesis_ids'], dtype=torch.long, device=device).unsqueeze(0),
            'premise_lengths': torch.tensor([len(sample['premise_ids'])], dtype=torch.long, device=device),
            'hypothesis_lengths': torch.tensor([len(sample['hypothesis_ids'])], dtype=torch.long, device=device),
        }
        logits, weight_premise, weight_hypothesis = model(
            batch['premise'], batch['hypothesis'], batch['premise_lengths'], batch['hypothesis_lengths'], return_alignments=True
        )
        predicted = logits.argmax(dim=-1).item()

    premise_tokens = dataset.decode_tokens(sample['premise_ids'])
    hypothesis_tokens = dataset.decode_tokens(sample['hypothesis_ids'])

    print(f"Example ID: {sample['id']}")
    print(f"Premise: {' '.join(premise_tokens)}")
    print(f"Hypothesis: {' '.join(hypothesis_tokens)}")
    print(f"Predicted label: {ID_TO_LABEL[predicted]} | Gold label: {ID_TO_LABEL[sample['label']] if sample['label'] is not None else 'n/a'}")
    print()

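    # weight_premise has shape (premise_len, hypothesis_len); each row is a softmax over
    # hypothesis tokens. Column i therefore ranks premise tokens by how strongly they attend
    # to hypothesis token i, and it does not sum to 1. For the distribution of hypothesis
    # token i over premise tokens, read weight_hypothesis.squeeze(0)[i] instead.
    alignment = weight_premise.squeeze(0).cpu().numpy()
    for i, hyp_token in enumerate(hypothesis_tokens):
        attention_row = alignment[:, i]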
        top_indices = attention_row.argsort()[-3:][::-1]
        top_pairs = [(premise_tokens[j], attention_row[j]) for j in top_indices]
        formatted = ', '.join(f"{tok}:{score:.3f}" for tok, score in top_pairs)
        print(f"Hypothesis token '{hyp_token}' attends to -> {formatted}")


show_alignment(model, val_dataset, index=0)

Example ID: 0
Premise: an introduction to atoms and elements , compounds , atomic structure and bonding , the molecule and chemical reactions .
Hypothesis: replace another in a molecule happens to atoms during a substitution reaction .
Predicted label: neutral | Gold label: neutral

Hypothesis token 'replace' attends to -> the:0.069, ,:0.069, reactions:0.068
Hypothesis token 'another' attends to -> the:0.077, molecule:0.072, reactions:0.072
Hypothesis token 'in' attends to -> an:0.074, the:0.073, reactions:0.071
Hypothesis token 'a' attends to -> the:0.084, molecule:0.083, ,:0.078
Hypothesis token 'molecule' attends to -> molecule:0.133, the:0.080, an:0.073
Hypothesis token 'happens' attends to -> structure:0.111, and:0.106, and:0.104
Hypothesis token 'to' attends to -> introduction:0.115, and:0.102, an:0.100
Hypothesis token 'atoms' attends to -> elements:0.142, introduction:0.126, atoms:0.120
Hypothesis token 'during' attends to -> reactions:0.095, the:0.090, atomic:0.085
Hypothesis 

### Notes and next steps
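- The ESIM-style architecture exposes alignment weights that can be inspected directly. With the epoch-3 checkpoint it reaches 0.72 accuracy on validation and 0.73 on test, against a majority-class rate of about 0.60 on the test split. Recall on `entails` is the weaker side (0.60 on validation, 0.58 on test).
- Training accuracy climbs to 0.97 by epoch 6 while validation loss rises from epoch 1 onward, so the model overfits early.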
- Further work (for the full report) can explore ablations such as removing the difference/product features or exchanging max pooling for attentive pooling, as well as hyperparameter sweeps guided by the logged metrics.