## Procesamiento de Lenguaje Natural


![Colegio Bourbaki](./Images/Bourbaki.png)

## Named Entity Disambiguation

La desambiguación de entidades nombradas (Named Entity Disambiguation, **NED**) es una tarea fundamental en el campo del Procesamiento del Lenguaje Natural. Su objetivo es identificar correctamente a qué entidad del mundo real se refiere una mención ambigua en un texto.

Por ejemplo, cuando un texto menciona **“Amazon”**, esta palabra puede referirse a:

- La empresa Amazon Inc.

- El río Amazonas

- La selva amazónica

La **NED** busca resolver esa ambigüedad seleccionando la entidad correcta según el contexto.

La NED suele ir después del proceso de Reconocimiento de Entidades Nombradas (NER): NER detecta las menciones de entidades en el texto (por ejemplo: “Amazon” es una Organización). NED determina a qué “Amazon” se refiere, vinculando esa mención con una entrada específica en una base de conocimiento, como Wikidata, DBpedia o Wikipedia.


El proceso de NED se puede dividir en tres pasos principales:

- Identificación de la mención:
Se detecta la palabra o frase que podría ser una entidad (por ejemplo, “Paris”).

- Generación de candidatos:
Se buscan todas las posibles entidades que podrían corresponder a esa mención.
Ejemplo: “Paris” → {Paris, France; Paris Hilton; Paris, Texas}.

- Desambiguación (selección del candidato correcto):
Se usa el contexto del texto, las relaciones semánticas y la frecuencia para decidir cuál de las opciones es la correcta.

Por ejemplo, en la frase “Paris es una de las ciudades más visitadas del mundo”, el contexto “ciudad” ayuda a escoger Paris, France.

Existen varios métodos para realizar NED:

- Basados en reglas y diccionarios:
Usan coincidencias exactas con nombres en bases de datos o listas de entidades.

- Basados en similitud semántica:
Calculan qué tan similar es el contexto del texto con las descripciones de cada entidad candidata.

- Modelos de aprendizaje automático o profundo (Deep Learning):
Utilizan modelos entrenados sobre grandes corpus (como BERT, ELMo o spaCy transformers) para entender el contexto y predecir la entidad correcta de manera más precisa.

- Sistemas híbridos:
Combinan reglas, información semántica y modelos neuronales para mejorar la precisión.



![NED](./Images/NED_example.png)

**Ejemplo**

“Apple presentó el nuevo iPhone en su evento anual.”

NER: Detecta “Apple” como una entidad del tipo Organización.

NED: Decide que “Apple” se refiere a Apple Inc., y no a “apple” (la fruta).

El sistema puede vincular esta mención con una base de conocimiento:

Apple Inc. → Wikidata ID: Q312

**Aplicaciones**

- Motores de búsqueda semántica: mejoran la comprensión de consultas ambiguas.

- Análisis de noticias y redes sociales: para identificar correctamente a personas o empresas mencionadas.

- Sistemas de preguntas y respuestas (QA): permiten enlazar menciones a entidades reales.

- Desambiguación en bases de conocimiento: ayudan a mantener consistencia entre datos textuales y estructurados.

### Base de Datos

La base de datos que utilizaremos proviene de los sistemas **AIDA-YAGO**

AIDA-YAGO es un conjunto de datos y sistema de referencia ampliamente utilizado en investigación de Named Entity Disambiguation (NED).
Combina dos componentes principales:

AIDA → un sistema de desambiguación de entidades.

YAGO → una base de conocimiento semántico (similar a Wikidata o DBpedia).

En conjunto, AIDA-YAGO se usa para evaluar, entrenar y comparar algoritmos de NED, proporcionando textos con menciones ya enlazadas a entidades concretas del mundo real.

**YAGO: la base de conocimiento**

YAGO (Yet Another Great Ontology) es una base de conocimiento semántico creada en la Universidad de Múnich (Max Planck Institute for Informatics).
Su propósito es unir la estructura de WordNet (una base léxica del inglés) con la información de Wikipedia.

Características clave de YAGO:

-Contiene millones de entidades: personas, lugares, organizaciones, eventos, etc.

- Cada entidad está enlazada con una página de Wikipedia.

- Incluye tipos semánticos y relaciones (por ejemplo, “Barack Obama —isA→ Person”, “Obama —bornIn→ Hawaii”).

- Tiene alta precisión (≈95%), porque los datos son verificados automáticamente y por reglas consistentes.

En pocas palabras, YAGO es una gran red de conocimiento estructurado que representa hechos sobre el mundo real, usada como “referencia” para enlazar menciones de texto.

**AIDA: el sistema de desambiguación**

AIDA (Accurate Online Disambiguation of Entities) es un sistema automático de desambiguación desarrollado por el mismo grupo de investigación del Max Planck Institute. Recibe un texto con menciones detectadas (por ejemplo, de un sistema NER). Genera candidatos de entidades usando YAGO. Utiliza características contextuales y semánticas (coherencia entre entidades, similitud de contexto, relaciones en YAGO) para seleccionar la mejor entidad. Puede procesar texto en línea y enlazarlo automáticamente con YAGO o Wikipedia.

**AIDA-YAGO Dataset**

El dataset AIDA-YAGO es un conjunto de artículos de noticias en inglés (extraídos de Reuters) donde:

- Se han anotado manualmente las menciones de entidades (personas, lugares, organizaciones).

- Cada mención está enlazada a una entidad en YAGO (y por tanto a Wikipedia).

Este corpus es muy valioso porque:

- Permite evaluar la precisión de sistemas de NED.

- Contiene textos reales, no ejemplos sintéticos.

El dataset AIDA-YAGO es uno de los estándares más usados para evaluar modelos de entity linking y NED. Se emplea para:

- Entrenar modelos de aprendizaje profundo (como BERT, BLINK, REL, GENRE).

- Comparar resultados entre diferentes enfoques.

- Servir como benchmark público reproducible.

**Ejemplo**

Texto del dataset:

“Apple CEO Steve Jobs introduced the iPhone at Macworld in San Francisco.”

En el corpus AIDA-YAGO:

“Apple” → Apple Inc. (YAGO entity)

“Steve Jobs” → Steve_Jobs

“Macworld” → Macworld_Conference_and_Expo

“San Francisco” → San_Francisco

Cada una de estas menciones está etiquetada con su correspondiente entidad YAGO, lo que permite a los modelos aprender a desambiguar.


Link: https://resources.mpi-inf.mpg.de/yago-naga/aida/downloads.html?

### Librerias

In [1]:
import itertools
import os
import csv
import random
import sys

from dataclasses import dataclass
from collections import defaultdict
from collections.abc import Iterable
from copy import deepcopy

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

### Configuraciones

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

Device: cuda


In [3]:
print("Device:", DEVICE)

Device: cuda


In [4]:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
torch.backends.cuda.matmul.fp32_precision = 'ieee'  # torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.conv.fp32_precision = 'tf32'# torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False

In [5]:
print("__Python VERSION:", sys.version)
print("__pyTorch VERSION:", torch.__version__)
print(
    "__CUDA VERSION",
)
! nvidia-smi
print("__CUDNN VERSION:", torch.backends.cudnn.version())
print("__Number CUDA Devices:", torch.cuda.device_count())
print("__Devices")
print("Active CUDA Device: GPU", torch.cuda.current_device())
print("Available devices ", torch.cuda.device_count())
print("Current cuda device ", torch.cuda.current_device())

__Python VERSION: 3.12.11 (main, Sep  5 2025, 19:35:43) [GCC 13.3.0]
__pyTorch VERSION: 2.9.0+cu128
__CUDA VERSION
Thu Nov  6 19:24:12 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce GTX 1650 Ti     Off |   00000000:01:00.0  On |                  N/A |
| N/A   64C    P5              9W /   50W |     350MiB /   4096MiB |     32%      Default |
|                                         |                        |                  N/A |
+------------------------

### Funciones de ayuda

In [6]:
def set_seed(seed: int = 42):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [7]:
def pad_sequences(seqs: list[list[int]], pad_idx: int):
    ''' Pads a list of sequences to the same length with pad_idx.
    Returns a tensor of shape (batch_size, max_seq_length) and a tensor of lengths.
    '''
    lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)
    max_len = int(lengths.max()) if lengths.numel() > 0 else 0
    out = torch.full((len(seqs), max_len), pad_idx, dtype=torch.long)
    for i, s in enumerate(seqs):
        if len(s) > 0:
            out[i, : len(s)] = torch.tensor(s, dtype=torch.long)
    return out, lengths

Vamos a construir una clase Vocab simple para manejar el vocabulario y la codificación de tokens.

In [27]:
class Vocab:
    ''' A simple vocabulary class that builds a mapping from tokens to indices and vice versa.
        It also keeps track of token frequencies and allows filtering by minimum frequency.
    '''
    def __init__(self, min_freq: int = 1, specials: list[str] = None):
        if specials is None:
            specials = ["<pad>", "<unk>"]
        self.min_freq = min_freq
        self.freqs: dict[str, int] = {}
        self.itos: list[str] = []
        self.stoi: dict[str, int] = {}
        self.specials = specials

    def build(self, texts: Iterable[str]):
        ''' Build the vocabulary from an iterable of texts.'''
        for t in texts:
            for tok in t.strip().split():
                self.freqs[tok] = self.freqs.get(tok, 0) + 1
        self.itos = list(self.specials)
        for i, sp in enumerate(self.specials):
            self.stoi[sp] = i
        for tok, f in sorted(self.freqs.items(), key=lambda x: (-x[1], x[0])):
            if f >= self.min_freq and tok not in self.stoi:
                self.stoi[tok] = len(self.itos)
                self.itos.append(tok)

    @property
    def pad_idx(self):
        return self.stoi["<pad>"]

    @property
    def unk_idx(self):
        return self.stoi["<unk>"]

    def encode(self, text: str):
        return [self.stoi.get(tok, self.unk_idx) for tok in text.strip().split()]

    def __len__(self):
        return len(self.itos)

Veamos como funciona el vocabulario con un ejemplo sencillo:

In [9]:
toy_data = [
    ("this is a test", 1),
    ("another test example", 0)
]
toy_vocab = Vocab(min_freq=1)
toy_vocab.build([text for text, _ in toy_data])
print("Vocab size:", len(toy_vocab))
for text, label in toy_data:
    print("Text:", text)
    print("Encoded:", toy_vocab.encode(text))
    print("Label:", label)  

Vocab size: 8
Text: this is a test
Encoded: [7, 6, 3, 2]
Label: 1
Text: another test example
Encoded: [4, 2, 5]
Label: 0


In [10]:
toy_data = [("this is a test", 1), ("another test example", 0)]
toy_vocab = Vocab(min_freq=1)
toy_vocab.build([text for text, _ in toy_data])
print("Vocab size:", len(toy_vocab))
for text, label in toy_data:
    print("Text:", text)
    print("Encoded:", toy_vocab.encode(text))
    print("Label:", label)

Vocab size: 8
Text: this is a test
Encoded: [7, 6, 3, 2]
Label: 1
Text: another test example
Encoded: [4, 2, 5]
Label: 0


In [12]:
toy_vocab.stoi

{'<pad>': 0,
 '<unk>': 1,
 'test': 2,
 'a': 3,
 'another': 4,
 'example': 5,
 'is': 6,
 'this': 7}

Ahora, crearemos 3 clases para manejar la base de Datos:

- **NEDExample**: representa un ejemplo individual con mención, contexto, candidato y etiqueta.
- **NEDDataset**: maneja la carga y almacenamiento de múltiples ejemplos desde un archivo CSV.
- **NEDCollator**: prepara lotes de datos para el entrenamiento, incluyendo el padding de secuencias. En definitiva, convierte una lista de NEDExamples en tensores de Pytorch. Primero, recibe una clase Vocabulario que se encarga de codificar, entonces por cada ejemplo en el batch va a codificar el texto, truncarlo y definir la etiqueta segun la base de datos. También realizará padding de secuencias.

In [28]:
@dataclass
class NEDExample:
    ''' A single example for Named Entity Disambiguation.
    mention: str
    context: str
    candidate: str
    label: int
    ''' 
    mention: str
    context: str
    candidate: str
    label: int


class NEDDataset(Dataset):
    def __init__(self, path: str, delimiter: str | None = None):
        self.examples: list[NEDExample] = []
        with open(path, newline="", encoding="utf-8") as f:
            sample = f.read(4096)
            f.seek(0)

            # Prefer explicit delimiter; otherwise auto-pick (tabs vs commas)
            if delimiter is None:
                delimiter = "\t" if sample.count("\t") >= sample.count(",") else ","

            reader = csv.DictReader(f, delimiter=delimiter)
            if not reader.fieldnames:
                raise ValueError("File has no header row.")

            # Normalize header names (and strip BOM)
            headers = [h.lstrip("\ufeff").strip() for h in reader.fieldnames]
            lower = [h.lower() for h in headers]

            required = {"mention", "context", "candidate", "label"}
            missing = required - set(lower)
            if missing:
                raise ValueError(f"Missing required columns: {missing}. Got: {lower}")

            header_map = {orig: low for orig, low in zip(headers, lower)}

            for raw_row in reader:
                # Only use declared headers; ignore DictReader’s None/overflow bucket
                row = {header_map[h]: (raw_row.get(h) or "") for h in headers}
                self.examples.append(
                    NEDExample(
                        mention=row["mention"].strip(),
                        context=row["context"].strip(),
                        candidate=row["candidate"].strip(),
                        label=int(str(row["label"]).strip() or 0),
                    )
                )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


class NEDCollator:
    """Collate function for NED examples."""

    def __init__(self, vocab: Vocab, ctx_max_len: int, cand_max_len: int):
        self.vocab = vocab
        self.ctx_max_len = ctx_max_len
        self.cand_max_len = cand_max_len

    def truncate(self, ids: list[int], max_len: int):
        return ids[:max_len]

    def __call__(self, batch: list[NEDExample]):
        assert batch and len(batch) > 0, "Empty batch!"
        ctx_ids, cand_ids, labels = [], [], []
        # prefer UNK if available, else PAD
        unk_idx = getattr(self.vocab, "unk_idx", self.vocab.pad_idx)

        for ex in batch:
            ctx_text = (getattr(ex, "context", "") or "").strip()
            cand_text = (getattr(ex, "candidate", "") or "").strip()

            ctx = self.vocab.encode(ctx_text)
            cand = self.vocab.encode(cand_text)

            # --- SAFETY: avoid zero-length sequences ---
            if not ctx:
                ctx = [unk_idx]
            if not cand:
                cand = [unk_idx]
            # ------------------------------------------

            ctx_ids.append(self.truncate(ctx, self.ctx_max_len))
            cand_ids.append(self.truncate(cand, self.cand_max_len))
            labels.append(int(getattr(ex, "label", 0) or 0))

        ctx_pad, ctx_len = pad_sequences(ctx_ids, self.vocab.pad_idx)
        cand_pad, cand_len = pad_sequences(cand_ids, self.vocab.pad_idx)
        labels = torch.tensor(labels, dtype=torch.float32)
        return ctx_pad, ctx_len, cand_pad, cand_len, labels

### Atención y Arquitectura Transfomer

![S2S](./Images/Seq2Seq.png)

![ex](./Images/Attn_ex.png)

![dot](./Images/dot_prod.png)

![MHA](./Images/MHA.png)

![MHA_ex](./Images/MHA_economy.png)

![Trans](./Images/Original_Transformer.png)

Fuente: Chapter 7 - Python Deep Learning - Ivan Vasilev

Por consiguiente, vamos a construir 2 clases (**LSTMEncoder** & **SelfAttentionEncoder**) que convierten un batch de tokens en vectores de tamano por secuencia, es decir, un vector por contexto o candidato:
1. token_ids -> vectores (embeddings)
2. codifica la secuencia de vectores con LSTM o atención
3. Mean Pooling y proyección, para obtener una representación fija por secuencia.


In [13]:
class SequenceEncoder(nn.Module):
    def forward(self, ids, lengths):
        raise NotImplementedError


class LSTMEncoder(SequenceEncoder):
    """BiLSTM + mean pooling over non-pad tokens."""

    def __init__(
        self, vocab_size, emb_dim, hidden_dim, num_layers=1, dropout=0.1, pad_idx=0
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(
            emb_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, ids, lengths):
        emb = self.emb(ids)
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths.to("cpu"), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True
        )  # [B, T', 2H]
        # mask over actual tokens (use pad_idx, not hardcoded 0)
        mask = (ids[:, : out.size(1)] != self.pad_idx).unsqueeze(-1).float()
        summed = (out * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1.0)
        pooled = summed / denom
        return self.dropout(torch.tanh(self.proj(pooled)))


class SelfAttentionEncoder(SequenceEncoder):
    """TransformerEncoder + mean pooling over non-pad tokens."""

    def __init__(
        self,
        vocab_size,
        emb_dim,
        hidden_dim,
        n_heads=4,
        n_layers=2,
        dropout=0.1,
        pad_idx=0,
        max_len=1024,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.max_len = max_len
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.pos = nn.Embedding(max_len, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=n_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers,
            enable_nested_tensor=False,
        )
        self.proj = nn.Linear(emb_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, ids, lengths):
        B, T = ids.size()
        # clip or wrap positions if needed
        if T > self.max_len:
            pos_ids = torch.arange(T, device=ids.device) % self.max_len
        else:
            pos_ids = torch.arange(T, device=ids.device)
        pos_ids = pos_ids.unsqueeze(0).expand(B, T)  # [B, T]

        x = self.emb(ids) + self.pos(pos_ids)  # [B, T, D]
        kpm = ids == self.pad_idx  # True where padding
        x = self.encoder(x, src_key_padding_mask=kpm)
        mask = (~kpm).unsqueeze(-1).float()  # non-pad mask
        pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.dropout(torch.tanh(self.proj(pooled)))

Vamos a proceder con la lógica de entrenamiento. Tenemos dos posibilidades:

- **Entrenamiento par-a-par (pairwise):**

 Lo que le enseñas al modelo:

 «Para este par (contexto, candidato), ¿es la entidad correcta? Sí (1) / No (0)». Para cada mención, comparamos el candidato correcto contra uno incorrecto.
 
 Forma de los datos:

 Filas como: (mención = «Apple», contexto = «Me comí una manzana madura...», candidato = «Apple, la fruta», etiqueta = 1) (mención = «Apple», contexto = «Me comí una manzana madura...», candidato = «Apple Inc., la empresa», etiqueta = 0)

Salida y pérdida del modelo:

Un logit por fila → pasar por BCEWithLogitsLoss contra la etiqueta 0/1. Cada candidato se aprende de forma independiente. 

Inferencia (cómo elegir una entidad):

para una mención con K candidatos, puntúa cada par de forma independiente y elige el logit más alto (o softmax sobre esos logits para obtener una sensación probabilística). 

Ventajas: creación de datos sencilla, flexible (se puede reutilizar como puntuador de plausibilidad general). 

Desventajas: no hay competencia explícita entre los candidatos; puede producir puntuaciones mal calibradas entre los candidatos.

- **Entrenamiento por lista (listwise):**

Lo que le enseñas al modelo:

«Dados todos los candidatos para esta mención a la vez, haz que el candidato ideal supere al resto». Para cada mención, consideramos todos los  candidatos juntos y optimizamos para que el correcto supere a los demás.

Forma de los datos:

Agrupa a los candidatos por mención (un identificador), por ejemplo, para la misma (mención, contexto): candidatos = [fruta (etiqueta=1), empresa (etiqueta=0), …]

Resultado y pérdida del modelo:

El modelo produce una puntuación por candidato en el grupo. Aplica softmax dentro del grupo y minimiza NLL del candidato ideal: -log p(candidato ideal). Los candidatos compiten directamente.

Inferencia:

Compara siempre el grupo en su conjunto y elige la puntuación argmax (o probabilidad softmax).

Ventajas: optimiza la decisión real (clasificación dentro de un conjunto) → normalmente mejor precisión top-1.

Desventajas: necesita datos agrupados (≥1 positivo + ≥1 negativo por mención). Si un grupo solo tiene un candidato (Ki=1), la pérdida se reduce a 0 → no hay aprendizaje.

**Ejemplo**

Mención: «Manzana» en «Hoy me comí una manzana madura».

Candidatos: [fruta SI, empresa NO]

Por pares: Entrena fruta=1, empresa=0 de forma independiente. En el momento de la prueba, calcula ambas puntuaciones y elige la mayor.

Por lista: Entrena ambas juntas para esa mención, de modo que softmax(puntuación_fruta, puntuación_empresa) otorgue mayor probabilidad a la fruta.

Esa es la diferencia fundamental: plausibilidad independiente (por pares) frente a competencia directa (por lista).

Para esto, primero vamos a instanciar una clase que permita realizar el scoring:

In [14]:
class PairwiseScorer(nn.Module):
    """Pairwise scorer model for NED."""

    def __init__(self, encoder: SequenceEncoder, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.encoder = encoder
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, ctx_ids, ctx_len, cand_ids, cand_len):
        ctx_vec = self.encoder(ctx_ids, ctx_len)
        cand_vec = self.encoder(cand_ids, cand_len)
        feat = torch.cat(
            [ctx_vec, cand_vec, torch.abs(ctx_vec - cand_vec), ctx_vec * cand_vec],
            dim=-1,
        )
        logits = self.mlp(feat).squeeze(-1)
        return logits

In [15]:
@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total = correct = 0
    loss_sum = 0.0
    bce = nn.BCEWithLogitsLoss()
    for batch in loader:
        ctx_ids, ctx_len, cand_ids, cand_len, labels = [
            x.to(device) if torch.is_tensor(x) else x for x in batch
        ]
        logits = model(ctx_ids, ctx_len, cand_ids, cand_len)
        loss = bce(logits, labels)
        preds = (torch.sigmoid(logits) >= 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.numel()
        loss_sum += loss.item() * labels.size(0)
    return correct / max(total, 1), loss_sum / max(total, 1)

* PairwiseScorer: 

El modelo aprende una función

f(contexto,candidato)→logit

que indica qué tan compatible es una pareja.

Durante el entrenamiento, se usa un dataset con etiquetas 1 (correcto) o 0 (incorrecto).
Se optimiza con BCEWithLogitsLoss, como en evaluate().

Luego de entrenarlo, el modelo sabe producir un puntaje alto para contextos y entidades correctas.

* evaluate: evaluación en lote

Esta función se usa para medir el rendimiento general del modelo (accuracy y pérdida promedio) sobre un conjunto de validación o prueba.

Procesa miles de ejemplos en batches.

Calcula métricas globales.

No sirve para uso interactivo, sino para comparar modelos o hiperparámetros.

Entonces, podemos armar los entrenamientos:

In [25]:
def train_pairwise(args):
    ''' Trains a pairwise NED model based on the provided arguments. '''
    set_seed(args["seed"])
    train_ds = NEDDataset(args["train"])
    dev_ds = NEDDataset(args["dev"]) if args.get("dev") else None

    # Optionally mark mentions in context
    def mark_mention(ex):
        m = (ex.mention or "").strip()
        if m:
            ex.context = f"[MENTION] {m} [/MENTION] " + ex.context
        return ex

    train_ds.examples = [mark_mention(ex) for ex in train_ds.examples]
    if dev_ds:
        dev_ds.examples = [mark_mention(ex) for ex in dev_ds.examples]
    
    # Vocab and DataLoaders
    vocab = Vocab(min_freq=args["min_freq"])
    vocab.build(
        [ex.context for ex in train_ds.examples]
        + [ex.candidate for ex in train_ds.examples]
    )
    collate = NEDCollator(vocab, args["ctx_max_len"], args["cand_max_len"])
    train_loader = DataLoader(
        train_ds, batch_size=args["batch_size"], shuffle=True, collate_fn=collate
    )
    dev_loader = (
        DataLoader(
            dev_ds, batch_size=args["batch_size"], shuffle=False, collate_fn=collate
        )
        if dev_ds
        else None
    )

    # Model
    if args["encoder"] == "lstm":
        enc = LSTMEncoder(
            len(vocab),
            args["emb_dim"],
            args["hidden_dim"],
            args["lstm_layers"],
            args["dropout"],
            pad_idx=vocab.pad_idx,
        )
    else:
        enc = SelfAttentionEncoder(
            len(vocab),
            args["emb_dim"],
            args["hidden_dim"],
            args["attn_heads"],
            args["attn_layers"],
            args["dropout"],
            pad_idx=vocab.pad_idx,
            max_len=max(args["ctx_max_len"], args["cand_max_len"]),
        )
    model = PairwiseScorer(
        enc, hidden_dim=args["hidden_dim"], dropout=args["dropout"]
    ).to(DEVICE)
    opt = torch.optim.AdamW(
        model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
    )
    bce = nn.BCEWithLogitsLoss() #BCELoss
    best_dev = 0.0
    
    # Training loop
    for epoch in range(1, args["epochs"] + 1):
        model.train()
        total = correct = 0
        loss_sum = 0.0
        for _, batch in enumerate(train_loader, 1):
            ctx_ids, ctx_len, cand_ids, cand_len, labels = [
                x.to(DEVICE) if torch.is_tensor(x) else x for x in batch
            ]
            # Forward pass
            opt.zero_grad()
            logits = model(ctx_ids, ctx_len, cand_ids, cand_len)
            loss = bce(logits, labels)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            # Compute metrics
            preds = (torch.sigmoid(logits) >= 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.numel()
            loss_sum += loss.item() * labels.size(0)
        if dev_loader is not None:
            dev_acc, dev_loss = evaluate(model, dev_loader, DEVICE)
            print(f"Epoch {epoch} DONE | pairwise dev_loss={dev_loss:.4f} pairwise dev_acc={dev_acc:.4f}")
            best_dev = max(best_dev, dev_acc)

    print("Training complete. Best dev acc:", best_dev)
    
    return model, vocab

In [24]:
def train_listwise(args, augment_singletons=True):

    def _read_groups_from_csv(path):
        groups = defaultdict(lambda: {"context": None, "mention": None, "cands": []})
        with open(path, newline="", encoding="utf-8") as f:
            rdr = csv.DictReader(f)
            for r in rdr:
                gid = str(r["id"])
                if groups[gid]["context"] is None:
                    groups[gid]["context"] = r["context"]
                if groups[gid]["mention"] is None:
                    groups[gid]["mention"] = r.get("mention", "")
                groups[gid]["cands"].append((r["candidate"], int(r["label"])))
        return [{"id": k, **v} for k, v in groups.items()]

    set_seed(args["seed"])
    train_groups = _read_groups_from_csv(args["train"])
    dev_groups = _read_groups_from_csv(args["dev"]) if args.get("dev") else []

    # Mark mention in context
    def mark(g):
        m = (g.get("mention", "") or "").strip()
        if m:
            g["context"] = f"[MENTION] {m} [/MENTION] " + g["context"]
        return g

    train_groups = [mark(g) for g in train_groups]
    dev_groups = [mark(g) for g in dev_groups]

    # Ensure Ki >= 2 (augment singletons with a negative)
    def ensure_two(groups):
        kept = []
        # pool of candidate strings to sample negatives from
        pool = list({c for g in groups for c, _ in g["cands"]})
        for g in groups:
            Ki = len(g["cands"])
            pos_idx = next((i for i, (_, y) in enumerate(g["cands"]) if y == 1), None)
            if Ki >= 2:
                kept.append(g)
                continue
            if not augment_singletons:
                continue
            gold = g["cands"][pos_idx][0] if pos_idx is not None else None
            # pick any different candidate as negative
            neg = next((c for c in pool if c != gold), None)
            if neg is None:
                neg = (gold or "DUMMY_ENTITY") + "_NEG"
            if pos_idx is None:
                base = g["cands"][0][0] if g["cands"] else "SOME_ENTITY"
                g["cands"] = [(base, 1), (neg, 0)]
            else:
                g["cands"].append((neg, 0))
            kept.append(g)
        return kept

    train_groups = ensure_two(train_groups)
    dev_groups = ensure_two(dev_groups)

    # Build vocab
    texts = []
    for g in train_groups:
        texts.append(g["context"])
        texts.extend(c for c, _ in g["cands"])
    vocab = Vocab(min_freq=args["min_freq"])
    vocab.build(texts)
    # Model
    if args["encoder"] == "lstm":
        enc = LSTMEncoder(
            len(vocab),
            args["emb_dim"],
            args["hidden_dim"],
            args["lstm_layers"],
            args["dropout"],
            pad_idx=vocab.pad_idx,
        )
    else:
        enc = SelfAttentionEncoder(
            len(vocab),
            args["emb_dim"],
            args["hidden_dim"],
            args["attn_heads"],
            args["attn_layers"],
            args["dropout"],
            pad_idx=vocab.pad_idx,
            max_len=max(args["ctx_max_len"], args["cand_max_len"]),
        )
    model = PairwiseScorer(
        enc, hidden_dim=args["hidden_dim"], dropout=args["dropout"]
    ).to(DEVICE)
    opt = torch.optim.AdamW(
        model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
    )

    def encode_texts(vocab, texts, max_len):
        unk = getattr(vocab, "unk_idx", vocab.pad_idx)
        ids = []
        for t in texts:
            seq = vocab.encode((t or "").strip())
            if not seq:            # <-- guard: ensure length >= 1
                seq = [unk]
            ids.append(seq[:max_len])
        pad, lengths = pad_sequences(ids, vocab.pad_idx)
        return pad.to(DEVICE), lengths.to(DEVICE)

    # Training loop
    for epoch in range(1, args["epochs"] + 1):
        random.shuffle(train_groups)
        model.train()
        total_loss = steps = 0
        for i in range(0, len(train_groups), args["batch_size"]):
            batch = [
                g
                for g in train_groups[i : i + args["batch_size"]]
                if len(g["cands"]) >= 2
            ]
            if not batch:
                continue
            ctx_texts = [g["context"] for g in batch]
            cand_texts = list(
                itertools.chain.from_iterable(
                    [[c for c, _ in g["cands"]] for g in batch]
                )
            )
            ctx_ids, ctx_len = encode_texts(vocab, ctx_texts, args["ctx_max_len"])
            cand_ids, cand_len = encode_texts(vocab, cand_texts, args["cand_max_len"])

            ctx_vec = model.encoder(ctx_ids, ctx_len)  # [B, H]
            cand_vec = model.encoder(cand_ids, cand_len)  # [sumK, H]

            # Per-group scores and NLL on local gold index
            scores = []
            start = 0
            for bi, g in enumerate(batch):
                Ki = len(g["cands"])
                c_i = cand_vec[start : start + Ki]
                ctx_i = ctx_vec[bi].unsqueeze(0).expand(Ki, -1)
                feat = torch.cat(
                    [ctx_i, c_i, torch.abs(ctx_i - c_i), ctx_i * c_i], dim=-1
                )
                scores.append(model.mlp(feat).squeeze(-1))  # [Ki]
                start += Ki

            losses = []
            for s, g in zip(scores, batch):
                logp = torch.log_softmax(s, dim=0)
                gold_local = next(
                    (j for j, (_, y) in enumerate(g["cands"]) if y == 1), 0
                )
                losses.append(-logp[gold_local])
            loss = torch.stack(losses).mean()

            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total_loss += float(loss.item())
            steps += 1

        print(f"Epoch {epoch}: listwise loss={total_loss/max(steps,1):.4f}")

    return model, vocab

**score_pair_in_memory**	Puntúa una pareja (contexto, candidato)	(prob, logit)

**choose_among_candidates**	Compara varios candidatos y elige el mejor	(best_index, softmax_probs, logits)

**compare_two	Compara dos (uno correcto y otro incorrecto)**	Diccionario con resultados detallados

In [23]:
@torch.no_grad()
def score_pair_in_memory(
    model: PairwiseScorer,
    vocab: Vocab,
    context: str,
    candidate: str,
    ctx_max_len: int = 64,
    cand_max_len: int = 64,
) -> tuple[float, float]:
    model.eval()
    ctx_ids = torch.tensor(
        [vocab.encode(context)[:ctx_max_len]], dtype=torch.long, device=DEVICE
    )
    ctx_len = torch.tensor([ctx_ids.size(1)], dtype=torch.long, device=DEVICE)
    cand_ids = torch.tensor(
        [vocab.encode(candidate)[:cand_max_len]], dtype=torch.long, device=DEVICE
    )
    cand_len = torch.tensor([cand_ids.size(1)], dtype=torch.long, device=DEVICE)
    logit = model(ctx_ids, ctx_len, cand_ids, cand_len)  # shape [1]
    prob = torch.sigmoid(logit).item()
    return prob, logit.item()  # (sigmoid prob, raw logit)


@torch.no_grad()
def choose_among_candidates(
    model: PairwiseScorer,
    vocab: Vocab,
    context: str,
    candidates: list[str],
    ctx_max_len: int = 64,
    cand_max_len: int = 64,
):
    """Return (best_index, softmax_probs[list], logits[list]) for a set of candidates."""
    model.eval()
    K = len(candidates)
    if K == 0:
        return None, [], []
    # encode one context, tile to K
    ctx_enc = vocab.encode(context)[:ctx_max_len]
    ctx_ids = torch.tensor([ctx_enc], dtype=torch.long, device=DEVICE).expand(K, -1)
    ctx_len = torch.tensor([len(ctx_enc)], dtype=torch.long, device=DEVICE).expand(K)

    # encode candidates with padding
    cand_seqs = [vocab.encode(c)[:cand_max_len] for c in candidates]
    maxL = max((len(s) for s in cand_seqs), default=0)
    pad = vocab.pad_idx
    cand_ids = torch.full((K, maxL), pad, dtype=torch.long, device=DEVICE)
    cand_len = torch.tensor(
        [len(s) for s in cand_seqs], dtype=torch.long, device=DEVICE
    )
    for i, s in enumerate(cand_seqs):
        if s:
            cand_ids[i, : len(s)] = torch.tensor(s, dtype=torch.long, device=DEVICE)

    logits = model(ctx_ids, ctx_len, cand_ids, cand_len)  # [K]
    probs = torch.softmax(logits, dim=0)  # [K]
    best = int(torch.argmax(logits).item())
    return best, probs.cpu().tolist(), logits.cpu().tolist()


def compare_two(
    model: PairwiseScorer,
    vocab: Vocab,
    context: str,
    correct_candidate: str,
    incorrect_candidate: str,
):
    best, probs, logits = choose_among_candidates(
        model, vocab, context, [correct_candidate, incorrect_candidate]
    )
    return {
        "chosen": "correct" if best == 0 else "incorrect",
        "softmax_probs": {"correct": probs[0], "incorrect": probs[1]},
        "logits": {"correct": logits[0], "incorrect": logits[1]},
    }

Entonces, armaremos un toy example para definir los hiperparámetros y entrenar el modelo:

In [17]:
rows = [
    # Apple, company context
    {
        "id": 1,
        "mention": "Apple",
        "context": "I love my new Apple laptop",
        "candidate": "Apple Inc. is a technology company",
        "label": 1,
    },
    {
        "id": 1,
        "mention": "Apple",
        "context": "I love my new Apple laptop",
        "candidate": "Apple is a fruit often red or green",
        "label": 0,
    },
    # Apple, fruit context
    {
        "id": 2,
        "mention": "Apple",
        "context": "I ate a ripe apple today",
        "candidate": "Apple is a fruit often red or green",
        "label": 1,
    },
    {
        "id": 2,
        "mention": "Apple",
        "context": "I ate a ripe apple today",
        "candidate": "Apple Inc. is a technology company",
        "label": 0,
    },
    # Amazon, cloud context
    {
        "id": 3,
        "mention": "Amazon",
        "context": "Amazon released new cloud features",
        "candidate": "Amazon.com is an e-commerce and cloud company",
        "label": 1,
    },
    {
        "id": 3,
        "mention": "Amazon",
        "context": "Amazon released new cloud features",
        "candidate": "Amazon rainforest is in South America",
        "label": 0,
    },
    # Amazon, rainforest context
    {
        "id": 4,
        "mention": "Amazon",
        "context": "Tree species in the Amazon are diverse",
        "candidate": "Amazon rainforest is in South America",
        "label": 1,
    },
    {
        "id": 4,
        "mention": "Amazon",
        #"candidate":
        "context": "Amazon.com is an e-commerce and cloud company",
        "label": 0,
    },
    # Amazon, AWS context
    {
        "id": 5,
        "mention": "Amazon",
        "context": "AWS is a part of Amazon services",
        "candidate": "Amazon.com is an e-commerce and cloud company",
        "label": 1,
    },
    {
        "id": 5,
        "mention": "Amazon",
        "context": "Amazon rainforest is in South America",
        "label": 0,
    },
]

In [18]:
rows

[{'id': 1,
  'mention': 'Apple',
  'context': 'I love my new Apple laptop',
  'candidate': 'Apple Inc. is a technology company',
  'label': 1},
 {'id': 1,
  'mention': 'Apple',
  'context': 'I love my new Apple laptop',
  'candidate': 'Apple is a fruit often red or green',
  'label': 0},
 {'id': 2,
  'mention': 'Apple',
  'context': 'I ate a ripe apple today',
  'candidate': 'Apple is a fruit often red or green',
  'label': 1},
 {'id': 2,
  'mention': 'Apple',
  'context': 'I ate a ripe apple today',
  'candidate': 'Apple Inc. is a technology company',
  'label': 0},
 {'id': 3,
  'mention': 'Amazon',
  'context': 'Amazon released new cloud features',
  'candidate': 'Amazon.com is an e-commerce and cloud company',
  'label': 1},
 {'id': 3,
  'mention': 'Amazon',
  'context': 'Amazon released new cloud features',
  'candidate': 'Amazon rainforest is in South America',
  'label': 0},
 {'id': 4,
  'mention': 'Amazon',
  'context': 'Tree species in the Amazon are diverse',
  'candidate': 'A

In [19]:
os.makedirs("data", exist_ok=True)
for split in ["train", "dev"]:
    with open(f"data/{split}.csv", "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(
            f, fieldnames=["id", "mention", "context", "candidate", "label"]
        )
        w.writeheader()
        w.writerows(rows)
print("Toy data written to data/train.csv and data/dev.csv")

Toy data written to data/train.csv and data/dev.csv


In [None]:
ARGS = {
    'mode': 'pw',
    "train": "data/train.csv",
    "dev": "data/dev.csv",
    "save_dir": "models/",
    "encoder": "lstm",  # 'lstm' or 'attn'
    "emb_dim": 256,
    "hidden_dim": 512,
    "dropout": 0.2,
    "lstm_layers": 1,
    "attn_heads": 4,
    "attn_layers": 2,
    "batch_size": 1,
    "epochs": 30,
    "lr": 1e-4,
    "weight_decay": 0.001,
    "log_every": 5,
    "min_freq": 1,
    "ctx_max_len": 64,
    "cand_max_len": 64,
    "seed": 42,
}


{'mode': 'pw',
 'train': 'data/train.csv',
 'dev': 'data/dev.csv',
 'save_dir': 'models/',
 'encoder': 'lstm',
 'emb_dim': 256,
 'hidden_dim': 512,
 'dropout': 0.2,
 'lstm_layers': 1,
 'attn_heads': 4,
 'attn_layers': 2,
 'batch_size': 1,
 'epochs': 30,
 'lr': 0.0001,
 'weight_decay': 0.001,
 'log_every': 5,
 'min_freq': 1,
 'ctx_max_len': 64,
 'cand_max_len': 64,
 'seed': 42}

In [21]:
ARGS

{'mode': 'pw',
 'train': 'data/train.csv',
 'dev': 'data/dev.csv',
 'save_dir': 'models/',
 'encoder': 'lstm',
 'emb_dim': 256,
 'hidden_dim': 512,
 'dropout': 0.2,
 'lstm_layers': 1,
 'attn_heads': 4,
 'attn_layers': 2,
 'batch_size': 1,
 'epochs': 30,
 'lr': 0.0001,
 'weight_decay': 0.001,
 'log_every': 5,
 'min_freq': 1,
 'ctx_max_len': 64,
 'cand_max_len': 64,
 'seed': 42}

In [29]:
model_pw_lstm, vocab_pw_lstm = train_pairwise(ARGS)

Epoch 1 DONE | pairwise dev_loss=0.6800 pairwise dev_acc=0.8000
Epoch 2 DONE | pairwise dev_loss=0.6642 pairwise dev_acc=0.7000
Epoch 3 DONE | pairwise dev_loss=0.6443 pairwise dev_acc=0.7000
Epoch 4 DONE | pairwise dev_loss=0.6153 pairwise dev_acc=0.8000
Epoch 5 DONE | pairwise dev_loss=0.5768 pairwise dev_acc=0.9000
Epoch 6 DONE | pairwise dev_loss=0.5321 pairwise dev_acc=0.9000
Epoch 7 DONE | pairwise dev_loss=0.4844 pairwise dev_acc=0.9000
Epoch 8 DONE | pairwise dev_loss=0.4328 pairwise dev_acc=0.8000
Epoch 9 DONE | pairwise dev_loss=0.3906 pairwise dev_acc=0.8000
Epoch 10 DONE | pairwise dev_loss=0.3509 pairwise dev_acc=0.9000
Epoch 11 DONE | pairwise dev_loss=0.3177 pairwise dev_acc=0.9000
Epoch 12 DONE | pairwise dev_loss=0.2839 pairwise dev_acc=0.9000
Epoch 13 DONE | pairwise dev_loss=0.2441 pairwise dev_acc=0.9000
Epoch 14 DONE | pairwise dev_loss=0.2044 pairwise dev_acc=1.0000
Epoch 15 DONE | pairwise dev_loss=0.1619 pairwise dev_acc=1.0000
Epoch 16 DONE | pairwise dev_loss=

In [30]:
prob, logit = score_pair_in_memory(
    model_pw_lstm,
    vocab_pw_lstm,
    context="I ate a ripe apple today",
    candidate="Apple is a fruit often red or green",
)
print("P(candidate | context) =", prob, "logit:", logit)

P(candidate | context) = 0.9976624250411987 logit: 6.056290626525879


In [31]:
best_idx, probs, logits = choose_among_candidates(
    model_pw_lstm,
    vocab_pw_lstm,
    context="I ate a ripe apple today",
    candidates=[
        "Apple is a fruit often red or green",  # correct
        "Apple Inc. is a technology company",  # incorrect
    ],
)
print("softmax probs:", probs)
print("chosen:", ["fruit", "company"][best_idx])

softmax probs: [0.9999979734420776, 1.9989972770417808e-06]
chosen: fruit


In [None]:
ARGS.update({"mode": "lw"})


{'mode': 'lw',
 'train': 'data/train.csv',
 'dev': 'data/dev.csv',
 'save_dir': 'models/',
 'encoder': 'lstm',
 'emb_dim': 256,
 'hidden_dim': 512,
 'dropout': 0.2,
 'lstm_layers': 1,
 'attn_heads': 4,
 'attn_layers': 2,
 'batch_size': 1,
 'epochs': 30,
 'lr': 0.0001,
 'weight_decay': 0.001,
 'log_every': 5,
 'min_freq': 1,
 'ctx_max_len': 64,
 'cand_max_len': 64,
 'seed': 42}

In [33]:
ARGS

{'mode': 'lw',
 'train': 'data/train.csv',
 'dev': 'data/dev.csv',
 'save_dir': 'models/',
 'encoder': 'lstm',
 'emb_dim': 256,
 'hidden_dim': 512,
 'dropout': 0.2,
 'lstm_layers': 1,
 'attn_heads': 4,
 'attn_layers': 2,
 'batch_size': 1,
 'epochs': 30,
 'lr': 0.0001,
 'weight_decay': 0.001,
 'log_every': 5,
 'min_freq': 1,
 'ctx_max_len': 64,
 'cand_max_len': 64,
 'seed': 42}

In [34]:
model_lw_lstm, vocab_lw_lstm = train_listwise(ARGS)

Epoch 1: listwise loss=0.7017
Epoch 2: listwise loss=0.6738
Epoch 3: listwise loss=0.6464
Epoch 4: listwise loss=0.6258
Epoch 5: listwise loss=0.5884
Epoch 6: listwise loss=0.5527
Epoch 7: listwise loss=0.5035
Epoch 8: listwise loss=0.4596
Epoch 9: listwise loss=0.4101
Epoch 10: listwise loss=0.3695
Epoch 11: listwise loss=0.3215
Epoch 12: listwise loss=0.2865
Epoch 13: listwise loss=0.2573
Epoch 14: listwise loss=0.2245
Epoch 15: listwise loss=0.2020
Epoch 16: listwise loss=0.1669
Epoch 17: listwise loss=0.1360
Epoch 18: listwise loss=0.1085
Epoch 19: listwise loss=0.0906
Epoch 20: listwise loss=0.0586
Epoch 21: listwise loss=0.0318
Epoch 22: listwise loss=0.0207
Epoch 23: listwise loss=0.0148
Epoch 24: listwise loss=0.0087
Epoch 25: listwise loss=0.0042
Epoch 26: listwise loss=0.0044
Epoch 27: listwise loss=0.0022
Epoch 28: listwise loss=0.0023
Epoch 29: listwise loss=0.0015
Epoch 30: listwise loss=0.0014


In [35]:
compare_two(
    model_lw_lstm,
    vocab_lw_lstm,
    context="I ate a ripe apple today",
    correct_candidate="Apple is a fruit often red or green",
    incorrect_candidate="Apple Inc. is a technology company",
)

{'chosen': 'correct',
 'softmax_probs': {'correct': 0.9989790916442871,
  'incorrect': 0.0010209004394710064},
 'logits': {'correct': 2.348507881164551, 'incorrect': -4.537540912628174}}

In [36]:
best_idx, probs, logits = choose_among_candidates(
    model_lw_lstm,
    vocab_lw_lstm,
    context="Amazon released new cloud features",
    candidates=[
        "Amazon.com is an e-commerce and cloud company",  # correct
        "Amazon rainforest is in South America",  # incorrect
    ],
)
print("softmax probs:", probs)
print("chosen idx:", best_idx)

softmax probs: [0.9964815378189087, 0.003518466604873538]
chosen idx: 0


In [None]:
ARGS.update({"mode": "pw", "encoder": "attn"})


{'mode': 'pw',
 'train': 'data/train.csv',
 'dev': 'data/dev.csv',
 'save_dir': 'models/',
 'encoder': 'attn',
 'emb_dim': 256,
 'hidden_dim': 512,
 'dropout': 0.2,
 'lstm_layers': 1,
 'attn_heads': 4,
 'attn_layers': 2,
 'batch_size': 1,
 'epochs': 30,
 'lr': 0.0001,
 'weight_decay': 0.001,
 'log_every': 5,
 'min_freq': 1,
 'ctx_max_len': 64,
 'cand_max_len': 64,
 'seed': 42}

In [38]:
ARGS

{'mode': 'pw',
 'train': 'data/train.csv',
 'dev': 'data/dev.csv',
 'save_dir': 'models/',
 'encoder': 'attn',
 'emb_dim': 256,
 'hidden_dim': 512,
 'dropout': 0.2,
 'lstm_layers': 1,
 'attn_heads': 4,
 'attn_layers': 2,
 'batch_size': 1,
 'epochs': 30,
 'lr': 0.0001,
 'weight_decay': 0.001,
 'log_every': 5,
 'min_freq': 1,
 'ctx_max_len': 64,
 'cand_max_len': 64,
 'seed': 42}

In [39]:
model_pw_attn, vocab_pw_attn = train_pairwise(ARGS)

Epoch 1 DONE | pairwise dev_loss=0.6689 pairwise dev_acc=0.5000
Epoch 2 DONE | pairwise dev_loss=0.6018 pairwise dev_acc=0.7000
Epoch 3 DONE | pairwise dev_loss=0.5489 pairwise dev_acc=0.7000
Epoch 4 DONE | pairwise dev_loss=0.5205 pairwise dev_acc=0.7000
Epoch 5 DONE | pairwise dev_loss=0.4947 pairwise dev_acc=0.7000
Epoch 6 DONE | pairwise dev_loss=0.5160 pairwise dev_acc=0.7000
Epoch 7 DONE | pairwise dev_loss=0.5653 pairwise dev_acc=0.7000
Epoch 8 DONE | pairwise dev_loss=0.4404 pairwise dev_acc=0.7000
Epoch 9 DONE | pairwise dev_loss=0.4889 pairwise dev_acc=0.7000
Epoch 10 DONE | pairwise dev_loss=0.4386 pairwise dev_acc=0.7000
Epoch 11 DONE | pairwise dev_loss=0.3989 pairwise dev_acc=0.8000
Epoch 12 DONE | pairwise dev_loss=0.3496 pairwise dev_acc=0.7000
Epoch 13 DONE | pairwise dev_loss=0.3770 pairwise dev_acc=0.8000
Epoch 14 DONE | pairwise dev_loss=0.2440 pairwise dev_acc=0.8000
Epoch 15 DONE | pairwise dev_loss=0.1104 pairwise dev_acc=1.0000
Epoch 16 DONE | pairwise dev_loss=

In [40]:
prob, logit = score_pair_in_memory(
    model_pw_attn,
    vocab_pw_attn,
    context="I ate a ripe apple today",
    candidate="Apple is a fruit often red or green",
)
print("P(candidate | context) =", prob, "logit:", logit)

P(candidate | context) = 0.9996460676193237 logit: 7.9459357261657715


In [41]:
print("P(candidate | context) =", prob, "logit:", logit)

P(candidate | context) = 0.9996460676193237 logit: 7.9459357261657715


In [42]:
best_idx, probs, logits = choose_among_candidates(
    model_pw_attn,
    vocab_pw_attn,
    context="I ate a ripe apple today",
    candidates=[
        "Apple is a fruit often red or green",  # correct
        "Apple Inc. is a technology company",  # incorrect
    ],
)
print("softmax probs:", probs)
print("chosen:", ["fruit", "company"][best_idx])

softmax probs: [0.9999996423721313, 3.729472553004598e-07]
chosen: fruit


In [43]:
ARGS.update({"mode": "lw"})

In [44]:
model_lw_attn, vocab_lw_attn = train_listwise(ARGS)

Epoch 1: listwise loss=0.6821
Epoch 2: listwise loss=0.5622
Epoch 3: listwise loss=0.5009
Epoch 4: listwise loss=0.4575
Epoch 5: listwise loss=0.3950
Epoch 6: listwise loss=0.3535
Epoch 7: listwise loss=0.3000
Epoch 8: listwise loss=0.3067
Epoch 9: listwise loss=0.2919
Epoch 10: listwise loss=0.2532
Epoch 11: listwise loss=0.2052
Epoch 12: listwise loss=0.1877
Epoch 13: listwise loss=0.1598
Epoch 14: listwise loss=0.1291
Epoch 15: listwise loss=0.0954
Epoch 16: listwise loss=0.0628
Epoch 17: listwise loss=0.0366
Epoch 18: listwise loss=0.0160
Epoch 19: listwise loss=0.0056
Epoch 20: listwise loss=0.0029
Epoch 21: listwise loss=0.0023
Epoch 22: listwise loss=0.0020
Epoch 23: listwise loss=0.0011
Epoch 24: listwise loss=0.0009
Epoch 25: listwise loss=0.0012
Epoch 26: listwise loss=0.0006
Epoch 27: listwise loss=0.0009
Epoch 28: listwise loss=0.0006
Epoch 29: listwise loss=0.0006
Epoch 30: listwise loss=0.0004


In [45]:
compare_two(
    model_lw_attn,
    vocab_lw_attn,
    context="I ate a ripe apple today",
    correct_candidate="Apple is a fruit often red or green",
    incorrect_candidate="Apple Inc. is a technology company",
)

{'chosen': 'correct',
 'softmax_probs': {'correct': 0.9998024106025696,
  'incorrect': 0.00019758193229790777},
 'logits': {'correct': 2.694718837738037, 'incorrect': -5.834440231323242}}

In [46]:
best_idx, probs, logits = choose_among_candidates(
    model_lw_attn,
    vocab_lw_attn,
    context="Amazon released new cloud features",
    candidates=[
        "Amazon.com is an e-commerce and cloud company",  # correct
        "Amazon rainforest is in South America",  # incorrect\
    ],
)
print("softmax probs:", probs)
print("chosen idx:", best_idx)

softmax probs: [0.9950313568115234, 0.0049686492420732975]
chosen idx: 0


Link: https://github.com/cyanic-selkie/aida-conll-yago-wikidata/tree/main/data

Ahora utilicemos el Dataset AIDA-YAGO. Las dos celdas siguientes pueden correrlas para generar el dataset para poder ser aprovechado.

DESCOMENTAR LAS DOS CELDAS SUBSIGUIENTES PARA GENERAR EL DATASET

In [87]:
# import os, re, csv, random, urllib.parse
# from collections import defaultdict


# def _title_from_url(url: str):
#     if not url:
#         return None
#     try:
#         t = url.split("/wiki/", 1)[1]
#     except Exception:
#         return None
#     return urllib.parse.unquote(t.replace(" ", "_"))


# def load_means(means_tsv, cap=100):
#     alias = defaultdict(list)
#     with open(means_tsv, encoding="utf-8") as f:
#         rdr = csv.reader(f, delimiter="\t")
#         for row in rdr:
#             if not row:
#                 continue
#             m = row[0].strip().strip('"').lower()
#             e = row[1].strip()
#             if e and e != "--NME--" and len(alias[m]) < cap:
#                 alias[m].append(e)
#     return alias


# def parse_aida_token_link_tsv(aida_tok_tsv_path, window=50):
#     """
#     Reads the AIDA token+link TSV (your FIRST file) and returns:
#       splits = {'train': [...], 'testa': [...], 'testb': [...]}
#     Each item: {'mention': str, 'context': str, 'gold': str}
#     """
#     docs = []
#     cur = {"doc": None, "split": "train", "tokens": [], "rows": []}

#     def flush():
#         if cur["doc"] is not None:
#             docs.append(
#                 {
#                     "doc": cur["doc"],
#                     "split": cur["split"],
#                     "tokens": cur["tokens"][:],
#                     "rows": cur["rows"][:],
#                 }
#             )

#     with open(aida_tok_tsv_path, encoding="utf-8") as f:
#         for raw in f:
#             line = raw.rstrip("\n")
#             if not line:
#                 continue
#             if line.startswith("-DOCSTART-"):
#                 flush()
#                 # doc name inside parentheses
#                 m = re.search(r"-DOCSTART-\s*\(([^)]+)\)", line)
#                 docname = m.group(1) if m else line
#                 split = (
#                     "testa"
#                     if "testa" in docname.lower()
#                     else ("testb" if "testb" in docname.lower() else "train")
#                 )
#                 cur = {"doc": docname, "split": split, "tokens": [], "rows": []}
#                 continue

#             parts = line.split("\t")
#             token = parts[0] if parts else ""
#             cur["tokens"].append(token)

#             # Heuristic parse: many entity rows look like:
#             # TOKEN  B|I  MENTION_SURFACE  WIKI_TITLE  WIKI_URL  WIKI_ID  MID
#             bio = parts[1] if len(parts) >= 2 and parts[1] in {"B", "I", "O"} else None
#             wiki_url = (
#                 parts[4]
#                 if len(parts) >= 5 and parts[4].startswith("http")
#                 else (
#                     parts[3] if len(parts) >= 4 and parts[3].startswith("http") else ""
#                 )
#             )
#             wiki_title = (
#                 parts[3] if len(parts) >= 4 and not parts[3].startswith("http") else ""
#             )
#             gold = _title_from_url(wiki_url) or (
#                 wiki_title if wiki_title and wiki_title != "--NME--" else None
#             )

#             mention_surface = parts[2] if len(parts) >= 3 else ""
#             cur["rows"].append(
#                 {
#                     "token": token,
#                     "bio": bio,  # may be None for non-entity rows
#                     "surf": mention_surface,
#                     "gold": gold,  # None for --NME-- or non-entity
#                 }
#             )

#     flush()

#     # Build contiguous gold spans using BIO (+ same gold) and make context windows
#     splits = {"train": [], "testa": [], "testb": []}
#     for d in docs:
#         toks = d["tokens"]
#         rows = d["rows"]
#         i, N = 0, len(rows)
#         while i < N:
#             r = rows[i]
#             if r["bio"] != "B" or not r["gold"]:
#                 i += 1
#                 continue
#             gold = r["gold"]
#             j = i + 1
#             while j < N and rows[j]["bio"] == "I" and rows[j]["gold"] == gold:
#                 j += 1
#             # mention text: use tokens from i..j
#             mention = " ".join(toks[i:j])
#             # context window
#             L = max(0, i - window)
#             R = min(N, j + window)
#             context = " ".join(toks[L:R]) or mention
#             splits[d["split"]].append(
#                 {"mention": mention, "context": context, "gold": gold}
#             )
#             i = j
#     return splits


# def build_candidates(mention_text, gold, alias, fallback_golds, K_neg=5):
#     # candidates: [gold] + K_neg negatives from alias table (by surface), fallback to other golds
#     pool = [e for e in alias.get(mention_text.lower(), []) if e != gold]
#     random.shuffle(pool)
#     negs = pool[:K_neg]
#     if len(negs) < K_neg:
#         extra = [g for g in fallback_golds if g != gold]
#         random.shuffle(extra)
#         for e in extra:
#             if len(negs) >= K_neg:
#                 break
#             if e not in negs:
#                 negs.append(e)
#     if not negs:  # ultra edge-case
#         negs = [gold + "_NEG"]
#     return [gold] + negs


# def write_pairwise_and_listwise_from_token_link_tsv(
#     aida_tok_tsv_path,
#     means_tsv_path,
#     outdir="data/aida_csv",
#     K_neg=5,
#     seed=42,
#     window=50,
# ):
#     random.seed(seed)
#     splits = parse_aida_token_link_tsv(aida_tok_tsv_path, window=window)
#     alias = load_means(means_tsv_path)
#     os.makedirs(outdir, exist_ok=True)

#     paths = {}
#     for split, rows in splits.items():
#         pw_path = os.path.join(outdir, f"{split}_pairs.csv")
#         lw_path = os.path.join(outdir, f"{split}_listwise.csv")
#         with open(pw_path, "w", newline="", encoding="utf-8") as f1, open(
#             lw_path, "w", newline="", encoding="utf-8"
#         ) as f2:
#             cols = ["id", "mention", "context", "candidate", "label"]
#             pw = csv.DictWriter(f1, fieldnames=cols)
#             pw.writeheader()
#             lw = csv.DictWriter(f2, fieldnames=cols)
#             lw.writeheader()
#             gid = rid = 0
#             all_golds = [r["gold"] for r in rows]
#             for ex in rows:
#                 m, ctx, gold = ex["mention"], ex["context"], ex["gold"]
#                 cands = build_candidates(m, gold, alias, all_golds, K_neg=K_neg)
#                 # listwise group
#                 for c in cands:
#                     lw.writerow(
#                         {
#                             "id": gid,
#                             "mention": m,
#                             "context": ctx,
#                             "candidate": c,
#                             "label": int(c == gold),
#                         }
#                     )
#                 gid += 1
#                 # pairwise rows
#                 for c in cands:
#                     pw.writerow(
#                         {
#                             "id": rid,
#                             "mention": m,
#                             "context": ctx,
#                             "candidate": c,
#                             "label": int(c == gold),
#                         }
#                     )
#                     rid += 1
#         paths[split] = {"pairwise": pw_path, "listwise": lw_path}
#         print("Wrote", pw_path, "and", lw_path)
#     return paths

In [88]:
# paths = write_pairwise_and_listwise_from_token_link_tsv(
#     aida_tok_tsv_path="/media/pdconte/hdd/Pablo/Personal/Colegio_Bourbaki/Natural_Language_Processsing/Semana4/aida_yago2/aida-yago2-dataset/AIDA-YAGO2-dataset.tsv",  # ← your FIRST file
#     means_tsv_path="/media/pdconte/hdd/Pablo/Personal/Colegio_Bourbaki/Natural_Language_Processsing/Semana4/aida_means.tsv",  # uncompressed from aida_means.tsv.bz2
#     outdir="data/aida_csv",
#     K_neg=5,
#     seed=42,
#     window=50,
# )

Vamos a entrenar los modelos con la nueva data!

In [89]:
AARGS = dict(ARGS)
AARGS.update(
    {
        "mode": "pw",
        "train": "data/aida_csv/train_pairs.csv",
        "dev": "data/aida_csv/testa_pairs.csv",
        "encoder": "lstm",
        "batch_size": 256,
        'epochs': 10,
    }
)
AARGS

{'mode': 'pw',
 'train': 'data/aida_csv/train_pairs.csv',
 'dev': 'data/aida_csv/testa_pairs.csv',
 'save_dir': 'models/',
 'encoder': 'lstm',
 'emb_dim': 256,
 'hidden_dim': 512,
 'dropout': 0.2,
 'lstm_layers': 1,
 'attn_heads': 4,
 'attn_layers': 2,
 'batch_size': 256,
 'epochs': 10,
 'lr': 0.0001,
 'weight_decay': 0.001,
 'log_every': 5,
 'min_freq': 1,
 'ctx_max_len': 64,
 'cand_max_len': 64,
 'seed': 42}

Vamos a entrenar pocas etapas:

In [37]:
model_pw_lstm, vocab_pw_lstm = train_pairwise(AARGS)

Epoch 1 DONE | pairwise dev_loss=0.4073 pairwise dev_acc=0.8387
Epoch 2 DONE | pairwise dev_loss=0.3521 pairwise dev_acc=0.8699
Epoch 3 DONE | pairwise dev_loss=0.3069 pairwise dev_acc=0.8876
Epoch 4 DONE | pairwise dev_loss=0.2908 pairwise dev_acc=0.8742
Epoch 5 DONE | pairwise dev_loss=0.3235 pairwise dev_acc=0.8568
Epoch 6 DONE | pairwise dev_loss=0.6087 pairwise dev_acc=0.7828
Epoch 7 DONE | pairwise dev_loss=0.4946 pairwise dev_acc=0.8162
Epoch 8 DONE | pairwise dev_loss=0.5463 pairwise dev_acc=0.8228
Epoch 9 DONE | pairwise dev_loss=0.6777 pairwise dev_acc=0.8101
Epoch 10 DONE | pairwise dev_loss=0.8773 pairwise dev_acc=0.7975
Training complete. Best dev acc: 0.8876017532874139


In [38]:
def save_model(model, vocab, args):
    """Saves the trained model and vocabulary to disk."""
    os.makedirs(args.get("save_dir", "models"), exist_ok=True)
    last_path = os.path.join(
        args["save_dir"], f"ned_{args['mode']}_{args['encoder']}_latest.pt"
    )

    ckpt1 = {
        "model_state": model.state_dict(),
        "vocab_itos": getattr(vocab, "itos", None),  # your Vocab stores tokens here
        "args": dict(args),
    }
    torch.save(ckpt1, last_path)
    print("Saved", last_path)

In [39]:
save_model(model_pw_lstm, vocab_pw_lstm, AARGS)

Saved models/ned_pw_lstm_latest.pt


Podemos usar esta rutina para cargar modelo y vocabulario

In [None]:
# --- Load back (example) ---
last_path = os.path.join(
    AARGS["save_dir"], f"ned_{AARGS['mode']}_{AARGS['encoder']}_latest.pt"
)
ckpt1 = torch.load(last_path, map_location=DEVICE)
saved_args = ckpt1["args"]

# Rebuild vocab from itos
vocab2 = Vocab()
vocab2.itos = deepcopy(ckpt1["vocab_itos"])
vocab2.stoi = {t: i for i, t in enumerate(vocab2.itos)}

# Rebuild encoder & model from saved args
if saved_args["encoder"] == "lstm":
    enc2 = LSTMEncoder(
        len(vocab2),
        saved_args["emb_dim"],
        saved_args["hidden_dim"],
        saved_args["lstm_layers"],
        saved_args["dropout"],
        pad_idx=vocab2.pad_idx,
    )
else:
    enc2 = SelfAttentionEncoder(
        len(vocab2),
        saved_args["emb_dim"],
        saved_args["hidden_dim"],
        saved_args["attn_heads"],
        saved_args["attn_layers"],
        saved_args["dropout"],
        pad_idx=vocab2.pad_idx,
        max_len=max(saved_args["ctx_max_len"], saved_args["cand_max_len"]),
    )

model2 = PairwiseScorer(
    enc2, hidden_dim=saved_args["hidden_dim"], dropout=saved_args["dropout"]
).to(DEVICE)
model2.load_state_dict(ckpt1["model_state"])
model2.eval()

In [41]:
prob, logit = score_pair_in_memory(
    model_pw_lstm,
    vocab_pw_lstm,
    context="I ate a ripe apple today",
    candidate="Apple is a fruit often red or green",
)
print("P(candidate | context) =", prob, "logit:", logit)

P(candidate | context) = 0.32778704166412354 logit: -0.7182109951972961


In [42]:
best_idx, probs, logits = choose_among_candidates(
    model_pw_lstm,
    vocab_pw_lstm,
    context="I ate a ripe apple today",
    candidates=[
        "Apple is a fruit often red or green",  # correct
        "Apple Inc. is a technology company",  # incorrect
    ],
)
print("softmax probs:", probs)
print("chosen:", ["fruit", "company"][best_idx])

softmax probs: [0.9376494884490967, 0.06235046312212944]
chosen: fruit


In [43]:
AARGS.update(
    {
        'mode':'lw',
        "train": "data/aida_csv/train_listwise.csv",
        "dev": "data/aida_csv/testa_listwise.csv",
    }
)
AARGS

{'mode': 'lw',
 'train': 'data/aida_csv/train_listwise.csv',
 'dev': 'data/aida_csv/testa_listwise.csv',
 'save_dir': 'models/',
 'encoder': 'lstm',
 'emb_dim': 256,
 'hidden_dim': 512,
 'dropout': 0.2,
 'lstm_layers': 1,
 'attn_heads': 4,
 'attn_layers': 2,
 'batch_size': 256,
 'epochs': 10,
 'lr': 0.0001,
 'weight_decay': 0.001,
 'log_every': 5,
 'min_freq': 1,
 'ctx_max_len': 64,
 'cand_max_len': 64,
 'seed': 42}

In [44]:
model_lw_lstm, vocab_lw_lstm = train_listwise(AARGS)

Epoch 1: listwise loss=1.7207
Epoch 2: listwise loss=1.4952
Epoch 3: listwise loss=1.2094
Epoch 4: listwise loss=0.9450
Epoch 5: listwise loss=0.7301
Epoch 6: listwise loss=0.5623
Epoch 7: listwise loss=0.4387
Epoch 8: listwise loss=0.3422
Epoch 9: listwise loss=0.2679
Epoch 10: listwise loss=0.2039


In [45]:
save_model(model_lw_lstm, vocab_lw_lstm, AARGS)

Saved models/ned_lw_lstm_latest.pt


In [46]:
compare_two(
    model_lw_lstm,
    vocab_lw_lstm,
    context="I ate a ripe apple today",
    correct_candidate="Apple is a fruit often red or green",
    incorrect_candidate="Apple Inc. is a technology company",
)

{'chosen': 'correct',
 'softmax_probs': {'correct': 0.9026119112968445,
  'incorrect': 0.09738808870315552},
 'logits': {'correct': -4.809988021850586, 'incorrect': -7.036576747894287}}

In [47]:
best_idx, probs, logits = choose_among_candidates(
    model_lw_lstm,
    vocab_lw_lstm,
    context="Amazon released new cloud features",
    candidates=[
        "Amazon.com is an e-commerce and cloud company",  # correct
        "Amazon rainforest is in South America",  # incorrect
    ],
)
print("softmax probs:", probs)
print("chosen idx:", best_idx)

softmax probs: [0.26784127950668335, 0.7321587800979614]
chosen idx: 1


In [48]:
AARGS.update(
    {
        'mode':'pw',
        "train": "data/aida_csv/train_pairs.csv",
        "dev": "data/aida_csv/testa_pairs.csv",
        "encoder": "attn",
    }
)
AARGS

{'mode': 'pw',
 'train': 'data/aida_csv/train_pairs.csv',
 'dev': 'data/aida_csv/testa_pairs.csv',
 'save_dir': 'models/',
 'encoder': 'attn',
 'emb_dim': 256,
 'hidden_dim': 512,
 'dropout': 0.2,
 'lstm_layers': 1,
 'attn_heads': 4,
 'attn_layers': 2,
 'batch_size': 256,
 'epochs': 10,
 'lr': 0.0001,
 'weight_decay': 0.001,
 'log_every': 5,
 'min_freq': 1,
 'ctx_max_len': 64,
 'cand_max_len': 64,
 'seed': 42}

In [49]:
model_pw_attn, vocab_pw_attn = train_pairwise(AARGS)

Epoch 1 DONE | pairwise dev_loss=0.3735 pairwise dev_acc=0.8342
Epoch 2 DONE | pairwise dev_loss=0.4763 pairwise dev_acc=0.8415
Epoch 3 DONE | pairwise dev_loss=0.3677 pairwise dev_acc=0.8597
Epoch 4 DONE | pairwise dev_loss=0.4256 pairwise dev_acc=0.8614
Epoch 5 DONE | pairwise dev_loss=0.4423 pairwise dev_acc=0.8650
Epoch 6 DONE | pairwise dev_loss=0.5396 pairwise dev_acc=0.8533
Epoch 7 DONE | pairwise dev_loss=0.6196 pairwise dev_acc=0.8489
Epoch 8 DONE | pairwise dev_loss=0.6101 pairwise dev_acc=0.8601
Epoch 9 DONE | pairwise dev_loss=0.7078 pairwise dev_acc=0.8472
Epoch 10 DONE | pairwise dev_loss=0.8683 pairwise dev_acc=0.8422
Training complete. Best dev acc: 0.8649899116398804


In [50]:
save_model(model_pw_attn, vocab_pw_attn, AARGS)

Saved models/ned_pw_attn_latest.pt


In [51]:
prob, logit = score_pair_in_memory(
    model_pw_attn,
    vocab_pw_attn,
    context="I ate a ripe apple today",
    candidate="Apple is a fruit often red or green",
)
print("P(candidate | context) =", prob, "logit:", logit)

P(candidate | context) = 0.9879438877105713 logit: 4.406055450439453


In [52]:
best_idx, probs, logits = choose_among_candidates(
    model_pw_attn,
    vocab_pw_attn,
    context="Amazon released new cloud features",
    candidates=[
        "Amazon.com is an e-commerce and cloud company",  # correct
        "Amazon rainforest is in South America",  # incorrect
    ],
)
print("softmax probs:", probs)
print("chosen idx:", best_idx)

softmax probs: [0.038246218115091324, 0.961753785610199]
chosen idx: 1


In [53]:
AARGS.update(
    {
        'mode': 'lw',
        "train": "data/aida_csv/train_listwise.csv",
        "dev": "data/aida_csv/testa_listwise.csv",
    }
)
AARGS

{'mode': 'lw',
 'train': 'data/aida_csv/train_listwise.csv',
 'dev': 'data/aida_csv/testa_listwise.csv',
 'save_dir': 'models/',
 'encoder': 'attn',
 'emb_dim': 256,
 'hidden_dim': 512,
 'dropout': 0.2,
 'lstm_layers': 1,
 'attn_heads': 4,
 'attn_layers': 2,
 'batch_size': 256,
 'epochs': 10,
 'lr': 0.0001,
 'weight_decay': 0.001,
 'log_every': 5,
 'min_freq': 1,
 'ctx_max_len': 64,
 'cand_max_len': 64,
 'seed': 42}

In [54]:
model_lw_attn, vocab_lw_attn = train_listwise(AARGS)

Epoch 1: listwise loss=1.7203
Epoch 2: listwise loss=1.5338
Epoch 3: listwise loss=1.1776
Epoch 4: listwise loss=0.9476
Epoch 5: listwise loss=0.7918
Epoch 6: listwise loss=0.6660
Epoch 7: listwise loss=0.5714
Epoch 8: listwise loss=0.4860
Epoch 9: listwise loss=0.4123
Epoch 10: listwise loss=0.3608


In [55]:
save_model(model_lw_attn, vocab_lw_attn, AARGS)

Saved models/ned_lw_attn_latest.pt


In [56]:
compare_two(
    model_lw_attn,
    vocab_lw_attn,
    context="I ate a ripe apple today",
    correct_candidate="Apple is a fruit often red or green",
    incorrect_candidate="Apple Inc. is a technology company",
)

{'chosen': 'incorrect',
 'softmax_probs': {'correct': 0.21120314300060272,
  'incorrect': 0.7887968420982361},
 'logits': {'correct': -1.2091745138168335, 'incorrect': 0.10851380228996277}}

In [57]:
best_idx, probs, logits = choose_among_candidates(
    model_lw_attn,
    vocab_lw_attn,
    context="Amazon released new cloud features",
    candidates=[
        "Amazon.com is an e-commerce and cloud company",  # correct
        "Amazon rainforest is in South America",  # incorrect
    ],
)
print("softmax probs:", probs)
print("chosen idx:", best_idx)

softmax probs: [0.40954166650772095, 0.590458333492279]
chosen idx: 1


### EJERCICIO:
* Entrenar con más epocas para mejorar la salida de los modelos.

![Lenguaje Matemático](./Images/Matematicas.png)

![Contacto](./Images/Contacto.png)