In [None]:
import spacy
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

from utils import (
    tokenize_to_ids,
)

# Load the large English pipeline, excluding every component listed below at
# load time; the tokenizer and the vocab's word-vector table are not pipeline
# components, so they remain available.
nlp = spacy.load(
    "en_core_web_lg",
    exclude=["parser", "tagger", "ner", "textcat", "lemmatizer", "attribute_ruler", "tok2vec"],
)
# Number of rows (unique vectors) in the static vector table.
print("unique vector size", len(nlp.vocab.vectors))

# --- Hyper-parameters --------------------------------------------------------
MAX_LEN = 64  # passed to tokenize_to_ids as the sequence length
NUM_CLASSES = 3
NR_UNK = 100  # passed to tokenize_to_ids (presumably the OOV bucket count — confirm in utils)

# --- Device ------------------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# --- Labels ------------------------------------------------------------------
label_map = {"entailment": 0, "contradiction": 1, "neutral": 2}

# Inverse lookup: class id -> label name.
id2label = dict((idx, name) for name, idx in label_map.items())

# ESIM Model NLI Inference 

In [None]:
def infer_and_plot(model, text1: str, text2: str):
    """Run ESIM inference on one sentence pair and plot both attention directions.

    Prints the predicted label and the softmax score of every class, then shows
    two heatmaps side by side: premise→hypothesis and hypothesis→premise
    attention, each cropped to the non-padding tokens of the two sentences.

    Args:
        model: ESIM-style model called as
            ``model(x1, l1, x2, l2, return_attention=True)`` and returning
            ``(logits, (att_p2h, att_h2p))``.
        text1: Premise sentence.
        text2: Hypothesis sentence.

    Returns:
        None; results are printed and plotted.
    """
    # 1) Tokenize → ids → tensors (batch of one). Cast to int64 so the tensor
    #    dtype does not depend on what tokenize_to_ids happens to return.
    x1 = torch.tensor(
        tokenize_to_ids([text1], nlp, MAX_LEN, NR_UNK), dtype=torch.long, device=device
    )
    x2 = torch.tensor(
        tokenize_to_ids([text2], nlp, MAX_LEN, NR_UNK), dtype=torch.long, device=device
    )
    # Id 0 is treated as padding, so the real length is the count of non-zero ids.
    l1 = (x1 != 0).sum(1)
    l2 = (x2 != 0).sum(1)

    # 2) Forward pass with attention weights, without autograd.
    model.eval()
    with torch.no_grad():
        logits, (att_p2h, att_h2p) = model(x1, l1, x2, l2, return_attention=True)
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]

    # int() so id2label is indexed with a plain Python int, as in the BERT cell.
    pred = int(probs.argmax())
    print(f"Prediction: {id2label[pred]}")
    print("Scores: " + ", ".join(f"{id2label[i]}={p:.3f}" for i, p in enumerate(probs)))

    # 3) Token strings for the axis labels, truncated to the model-side lengths.
    # NOTE(review): assumes tokenize_to_ids emits one id per spaCy token, in order,
    # without skipping any (e.g. punctuation/whitespace); otherwise labels drift.
    toks1 = [t.text for t in list(nlp(text1))[: l1.item()]]
    toks2 = [t.text for t in list(nlp(text2))[: l2.item()]]

    def _draw(ax, attn, row_toks, col_toks, cmap, title, xlabel, ylabel):
        """Render one attention matrix (rows attend over columns) as a heatmap on ``ax``."""
        sns.heatmap(
            attn,
            ax=ax,
            xticklabels=col_toks,
            yticklabels=row_toks,
            cmap=cmap,
            cbar_kws={"label": "attention"},
        )
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.tick_params(axis="x", rotation=45)
        # Right-align rotated labels so each ends under its column (matches the BERT plot).
        plt.setp(ax.get_xticklabels(), ha="right")
        ax.tick_params(axis="y", rotation=0)

    # 4) Plot both attention directions side by side.
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    _draw(
        axes[0],
        att_p2h[0, : len(toks1), : len(toks2)].cpu().numpy(),
        toks1,
        toks2,
        "viridis",
        "Premise → Hypothesis",
        "Hypothesis tokens",
        "Premise tokens",
    )
    _draw(
        axes[1],
        att_h2p[0, : len(toks2), : len(toks1)].cpu().numpy(),
        toks2,
        toks1,
        "magma",
        "Hypothesis → Premise",
        "Premise tokens",
        "Hypothesis tokens",
    )

    plt.tight_layout()
    plt.show()

In [None]:
# weights_only=False unpickles the whole saved object (not just tensors), which
# can execute arbitrary code — only load checkpoints from a trusted source.
model = torch.load(
    "data/esim_nli_model.pt",
    map_location=device,
    weights_only=False,
)
model.eval()

In [None]:
premise, hypothesis = (
    "in the park alice plays a flute solo",
    "someone playing music outside",
)

# Prints the ESIM prediction and shows both attention heatmaps.
infer_and_plot(model, text1=premise, text2=hypothesis)

# BERT Model NLI Inference

In [None]:
tokenizer = AutoTokenizer.from_pretrained("data/checkpoints/bert-snli")

# Eager attention is requested so attention weights can actually be returned
# (fused SDPA/flash kernels do not expose them); hidden states are requested too.
model = AutoModelForSequenceClassification.from_pretrained(
    "data/checkpoints/bert-snli",
    attn_implementation="eager",
    output_attentions=True,
    output_hidden_states=True,
)
model = model.to(device)

model.eval()

In [None]:
# Same sentence pair as the ESIM example, so the two models can be compared.
premise, hypothesis = (
    "in the park alice plays a flute solo",
    "someone playing music outside",
)

In [None]:
# Encode premise + hypothesis as one sentence-pair input, padded/truncated to
# exactly 128 tokens, returned as PyTorch tensors moved to `device`.
inputs = tokenizer(
    text=premise,
    text_pair=hypothesis,
    max_length=128,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
).to(device)

In [None]:
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

# Inference (no autograd); softmax over the class logits of the single example.
with torch.no_grad():
    outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
attentions = outputs.attentions  # tuple of (layers) x (batch, heads, L, L)

# Label names for the BERT head. Prefer the checkpoint's own mapping: its class
# order is fixed by how it was fine-tuned, not by this notebook's `label_map`.
# Hugging Face stores "LABEL_<i>" placeholders when no names were saved; in
# that case fall back to the ESIM `id2label`.
# NOTE(review): when falling back, confirm fine-tuning used `label_map`'s order —
# e.g. the HF `snli` dataset uses entailment=0, neutral=1, contradiction=2.
bert_id2label = {int(i): name for i, name in model.config.id2label.items()}
if all(name == f"LABEL_{i}" for i, name in bert_id2label.items()):
    bert_id2label = id2label

pred = int(probs.argmax())
print(f"Prediction: {bert_id2label[pred]}")

print(
    "Scores: "
    + ", ".join(f"{bert_id2label[i]}={p:.3f}" for i, p in enumerate(probs))
)

In [None]:
# Final-layer contextual embeddings of the single example in the batch.
last_hidden = outputs.hidden_states[-1][0]  # shape (seq_len, hidden_dim)

# Find token indices for premise & hypothesis (skip special tokens and padding).
ids = inputs["input_ids"][0].cpu().tolist()
sep_id = tokenizer.sep_token_id
sep_positions = [i for i, tok in enumerate(ids) if tok == sep_id]
if inputs.is_fast:
    # Fast tokenizers record which input each token came from (0 = premise,
    # 1 = hypothesis, None = special or padding token), so this does not rely
    # on the exact [CLS] A [SEP] B [SEP] layout of the checkpoint.
    seq_ids = inputs.sequence_ids(0)
    premise_idx = [i for i, s in enumerate(seq_ids) if s == 0]
    hypo_idx = [i for i, s in enumerate(seq_ids) if s == 1]
else:
    # Slow tokenizer: assume the BERT layout [CLS] premise [SEP] hypothesis [SEP].
    premise_idx = list(range(1, sep_positions[0]))
    hypo_idx = list(range(sep_positions[0] + 1, sep_positions[1]))
if not premise_idx or not hypo_idx:
    raise ValueError("premise or hypothesis has no tokens left after tokenization/truncation")

# Slice embeddings
P = last_hidden[premise_idx]  # (Lp, D)
H = last_hidden[hypo_idx]  # (Lh, D)

# Raw dot-product similarity, softmax-normalised in each direction like ESIM.
# NOTE(review): the dot products are unscaled, so with large hidden-state norms
# the softmax may be very peaked; dividing by sqrt(D) would soften it.
sim_matrix = torch.matmul(P, H.T)  # (Lp, Lh)
attn_p2h = F.softmax(sim_matrix, dim=1)  # premise→hypo rows sum to 1
attn_h2p = F.softmax(sim_matrix.T, dim=1)  # hypo→premise rows sum to 1

# Convert IDs → tokens
premise_tokens = tokenizer.convert_ids_to_tokens([ids[i] for i in premise_idx])
hypothesis_tokens = tokenizer.convert_ids_to_tokens([ids[i] for i in hypo_idx])

In [None]:
# Plot side-by-side like ESIM: one heatmap per attention direction, with rows =
# tokens doing the attending and columns = tokens attended to.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = [
    {
        "matrix": attn_p2h,
        "rows": premise_tokens,
        "cols": hypothesis_tokens,
        "cmap": "viridis",
        "title": "Premise → Hypothesis",
        "xlabel": "Hypothesis tokens",
        "ylabel": "Premise tokens",
    },
    {
        "matrix": attn_h2p,
        "rows": hypothesis_tokens,
        "cols": premise_tokens,
        "cmap": "magma",
        "title": "Hypothesis → Premise",
        "xlabel": "Premise tokens",
        "ylabel": "Hypothesis tokens",
    },
]

for ax, panel in zip(axes, panels):
    sns.heatmap(
        panel["matrix"].cpu().numpy(),
        ax=ax,
        xticklabels=panel["cols"],
        yticklabels=panel["rows"],
        cmap=panel["cmap"],
        cbar_kws={"label": "attention"},
        fmt=".2f",
    )
    ax.set_title(panel["title"])
    ax.set_xlabel(panel["xlabel"])
    ax.set_ylabel(panel["ylabel"])
    ax.tick_params(axis="x", rotation=45)
    # Right-align the rotated labels so each one ends under its column.
    plt.setp(ax.get_xticklabels(), ha="right")
    ax.tick_params(axis="y", rotation=0)

plt.tight_layout()
plt.show()