
# Error Atlas — t-SNE Walkthrough

This notebook rebuilds the error dataframe, computes query embeddings, runs PCA→t-SNE, and renders interactive Plotly views. Adjust the toggles to explore all or only incorrect rows.



## Implementation Plan & Tasks (Notebook‑Driven)

**Goal:** Extend the current t‑SNE analysis into a full error‑analysis workflow, adding taxonomy, ontology enrichment, residual embeddings, and diagnostics.

### Tasks (check as you execute)
1. [ ] Data audit: missingness, duplicates, imbalance, schema sanity
2. [ ] Failure taxonomy: ambiguity, near‑miss, thresholding, retrieval vs ranking
3. [ ] Ontology enrichment: map gold/pred IDs → labels/synonyms/definitions
4. [ ] Multi‑view embeddings: query, gold, pred, residuals (q−pred, q−gold)
5. [ ] 2D maps: t‑SNE for each view; facet by model/dataset/error_type
6. [ ] Diagnostics: neighborhood purity, silhouette, error concentration
7. [ ] Export: enriched parquet + plots for sharing


In [28]:

# Optional: pip installs (uncomment if running in a fresh environment)
# !pip install -q sentence-transformers scikit-learn plotly pandas pyarrow tqdm


In [29]:
import json
from pathlib import Path
import pandas as pd
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import TSNE
import plotly.express as px

# --- Project root discovery (robust to running from /notebooks) ---

def find_project_root(start=None, max_up=6):
    p = Path(start or Path.cwd()).resolve()
    for _ in range(max_up):
        if (p / 'data').exists() and (p / 'src').exists():
            return p
        p = p.parent
    return Path.cwd().resolve()

PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT / 'data'
ANALYSIS_DIR = DATA_DIR / 'analysis'
EMB_DIR = DATA_DIR / 'embeddings'
ERRORS_PATH = ANALYSIS_DIR / 'errors.parquet'

encoder_name = 'intfloat/e5-small-v2'
perplexity = 30
pca_components = 50
only_incorrect = True
random_state = 42

print('PROJECT_ROOT:', PROJECT_ROOT)


PROJECT_ROOT: /Users/jandrole/projects/onto_rag_paper_version


In [30]:

# Load dataframe
assert ERRORS_PATH.exists(), f"Missing {ERRORS_PATH} — run scripts/build_error_frame.py first."
df = pd.read_parquet(ERRORS_PATH)
if only_incorrect:
    df = df[df['is_correct'] == False].copy()
print(df.shape)
df.head()


(2566, 31)


Unnamed: 0,query,gold_ids,predicted_id,predicted_label,is_correct,confidence,candidate_count,candidate_labels,gold_in_candidates,gold_first_found_at_attempt,...,concurrent_requests,error_type,error,query_lower,query_len,query_tokens,query_has_digit,query_has_hyphen,query_is_upper,query_has_greek
0,glucose,[CHEBI:17234],,,False,,0,[],False,1.0,...,20,no_prediction,,glucose,7,1,False,False,False,False
2,dipotassium phosphate,[CHEBI:32031],CHEBI:131527,dipotassium hydrogen phosphate,False,0.95,30,"[dipotassium hydrogen phosphate, dipotassium b...",False,,...,20,retrieval_miss,,dipotassium phosphate,21,2,False,False,False,False
12,Alizarin red,[CHEBI:16866],CHEBI:87358,alizarin red S,False,0.85,28,"[alizarin red S, alizarin, neutral red, 3,4-di...",True,1.0,...,20,ranking_miss,,alizarin red,12,2,False,False,False,False
22,TAMRA,[CHEBI:51657],CHEBI:52282,tetramethylrhodamine,False,0.95,17,"[tetramethylrhodamine, 5-carboxytetramethylrho...",True,1.0,...,20,ranking_miss,,tamra,5,1,False,False,True,False
24,cineol,[CHEBI:23243],CHEBI:27961,"1,8-cineole",False,0.95,22,"[cineole, 1,8-cineole, 2-exo-hydroxy-1,8-cineo...",True,1.0,...,20,ranking_miss,,cineol,6,1,False,False,False,False



## 1) Data audit & schema sanity


In [33]:
import ast
import numpy as np
from collections import Counter

# Ensure gold_ids is a list (parquet should already preserve lists)
def ensure_list(x):
    if isinstance(x, list):
        return x
    if isinstance(x, tuple):
        return list(x)
    if isinstance(x, np.ndarray):
        return x.tolist()
    if x is None:
        return []
    if isinstance(x, float) and np.isnan(x):
        return []
    if isinstance(x, str):
        try:
            v = ast.literal_eval(x)
            return v if isinstance(v, list) else [v]
        except Exception:
            return [x]
    return [x]

_df = df.copy()
_df['gold_ids_list'] = _df['gold_ids'].apply(ensure_list)

# Missingness summary
missing = _df[['predicted_id','confidence','error','candidate_count']].isna().mean().sort_values(ascending=False)
print('Missingness (fraction):')
print(missing)

# Duplicate query+gold across runs
_df['gold_primary'] = _df['gold_ids_list'].apply(lambda xs: xs[0] if xs else None)
dup_counts = _df.groupby(['dataset','query','gold_primary']).size().sort_values(ascending=False)
print('Top duplicates (dataset, query, gold_primary):')
print(dup_counts.head(10))

# Class balance by dataset/model
print('Correct/Incorrect by dataset:')
print(_df.groupby('dataset')['is_correct'].value_counts().unstack(fill_value=0))

print('Correct/Incorrect by model:')
print(_df.groupby('model')['is_correct'].value_counts().unstack(fill_value=0))


Missingness (fraction):
error              0.974669
predicted_id       0.121980
confidence         0.058067
candidate_count    0.000000
dtype: float64
Top duplicates (dataset, query, gold_primary):
dataset       query      gold_primary  
chebi         glucose    CHEBI:17234       31
ctd_diseases  FAP        MESH:D011125      25
chebi         amyloid    CHEBI:60425       25
ctd_diseases  MPS IVA    OMIM:253000       20
ncbi_gene     STAT3      NCBIGene:6774     20
              EGFR       NCBIGene:1956     15
              CD82       NCBIGene:83628    15
              IL-6       NCBIGene:3569     15
ctd_diseases  CLD        MESH:C536210      15
chebi         molecules  CHEBI:36357       15
dtype: int64
Correct/Incorrect by dataset:
is_correct    False
dataset            
chebi           416
ctd_diseases    509
foodon          592
ncbi_gene      1049
Correct/Incorrect by model:
is_correct                     False
model                               
DeepSeek-R1-Distill-Qwen-32B     523



## 2) Failure taxonomy (beyond retrieval_miss/ranking_miss)


In [34]:

# Heuristic flags for ambiguity / genericity
GENERIC_TERMS = set([
    'molecule','molecules','compound','compounds','chemical','chemicals','substance','substances',
    'drug','drugs','salt','salts','acid','acids','ion','ions','agent','agents','factor','factors'
])

_df['query_lower'] = _df['query'].str.lower()
_df['is_generic'] = _df['query_lower'].isin(GENERIC_TERMS)
_df['is_short'] = _df['query'].str.len() <= 3
_df['is_abbrev'] = _df['query'].str.isupper() & (_df['query'].str.len() <= 6)

# Failure bucket

def failure_bucket(row):
    if row['is_correct']:
        return 'correct'
    if row['error_type'] == 'system_error':
        return 'system_error'
    if row['error_type'] == 'no_prediction':
        return 'no_prediction'
    if row['error_type'] == 'retrieval_miss':
        return 'retrieval_miss'
    if row['error_type'] == 'ranking_miss':
        return 'ranking_miss'
    return 'other'

_df['failure_bucket'] = _df.apply(failure_bucket, axis=1)

# Ambiguity flags for incorrect rows
_df['ambiguous_flag'] = (~_df['is_correct']) & (_df['is_generic'] | _df['is_abbrev'] | _df['is_short'])

print(_df['failure_bucket'].value_counts())
print('Ambiguity rate among incorrect:', _df.loc[~_df['is_correct'], 'ambiguous_flag'].mean())


failure_bucket
retrieval_miss    1784
ranking_miss       469
no_prediction      248
system_error        65
Name: count, dtype: int64
Ambiguity rate among incorrect: 0.3604832424006235



## 3) Ontology enrichment (gold/pred → labels/synonyms/definition)


In [None]:
from functools import lru_cache

ONTO_DUMPS = {
    'chebi': DATA_DIR / 'chebi' / 'ontology_dump.json',
    'ctd_diseases': DATA_DIR / 'ctd_diseases' / 'ontology_dump.json',
    'ncbi_gene': DATA_DIR / 'ncbi_gene' / 'ontology_dump.json',
    'foodon': DATA_DIR / 'foodon' / 'ontology_dump.json',
}

CTD_TSV = DATA_DIR / 'ontologies' / 'CTD_diseases.tsv'
NCBI_TSV = DATA_DIR / 'ontologies' / 'gene_info.tsv'

missing = [ds for ds, p in ONTO_DUMPS.items() if not p.exists()]
if missing:
    print('Missing ontology dumps for:', missing)

@lru_cache(maxsize=None)
def load_ontology_dump(dataset):
    path = ONTO_DUMPS.get(dataset)
    if not path or not path.exists():
        return {}
    with path.open('r', encoding='utf-8') as f:
        return json.load(f)

@lru_cache(maxsize=None)
def load_ctd_alt_map():
    alt_map = {}
    if not CTD_TSV.exists():
        return alt_map
    header = None
    with CTD_TSV.open('r', encoding='utf-8') as f:
        for line in f:
            if line.startswith('# DiseaseName'):
                header = line.lstrip('#').strip().split('	')
                break
        if not header:
            return alt_map
        for line in f:
            if line.startswith('#') or not line.strip():
                continue
            parts = line.rstrip('
').split('	')
            if len(parts) != len(header):
                continue
            row = dict(zip(header, parts))
            primary = row.get('DiseaseID')
            alt_ids = row.get('AltDiseaseIDs', '')
            if primary and alt_ids:
                for alt in alt_ids.split('|'):
                    alt = alt.strip()
                    if alt:
                        alt_map[alt] = primary
    return alt_map


def build_ncbi_fallback(missing_ids):
    # Build a minimal lookup for missing NCBI Gene IDs from gene_info.tsv
    if not missing_ids:
        return {}
    # normalize to bare gene ids
    targets = set()
    for curie in missing_ids:
        if isinstance(curie, str) and curie.startswith('NCBIGene:'):
            targets.add(curie.split(':', 1)[1])
        elif isinstance(curie, str):
            targets.add(curie)
    if not NCBI_TSV.exists():
        return {}
    lookup = {}
    with NCBI_TSV.open('r', encoding='utf-8') as f:
        for line in f:
            if line.startswith('#') or not line.strip():
                continue
            parts = line.rstrip('').split('	')
            if len(parts) < 9:
                continue
            gene_id = parts[1]
            if gene_id not in targets:
                continue
            symbol = parts[2]
            synonyms = parts[4]
            desc = parts[8] if len(parts) > 8 else ''
            curie = f'NCBIGene:{gene_id}'
            lookup[curie] = {
                'label': symbol,
                'synonyms': [s for s in synonyms.split('|') if s and s != '-'],
                'definition': desc,
            }
    return lookup


def normalize_id(dataset, curie):
    if curie is None:
        return None
    # Map OMIM/DO → MESH for CTD when possible
    if dataset == 'ctd_diseases':
        alt_map = load_ctd_alt_map()
        if curie in alt_map:
            return alt_map[curie]
    if dataset == 'ncbi_gene' and isinstance(curie, str) and curie.isdigit():
        return f'NCBIGene:{curie}'
    return curie


def build_term_text(term):
    if not term:
        return None
    label = term.get('label')
    if not label:
        return None
    parts = [label]
    syns = term.get('synonyms') or []
    if syns:
        parts.append(' ; '.join(syns[:5]))
    definition = term.get('definition') or ''
    if definition:
        parts.append(definition[:200])
    return ' | '.join(parts)

# Add gold/pred labels and text

def lookup_term_text(dataset, curie, fallback_label=None, ncbi_fallback=None):
    dump = load_ontology_dump(dataset)
    term = dump.get(curie) if curie else None
    if term:
        return build_term_text(term), (term.get('label') if term else None)
    if dataset == 'ncbi_gene' and ncbi_fallback and curie in ncbi_fallback:
        term = ncbi_fallback[curie]
        return build_term_text(term), term.get('label')
    if fallback_label:
        return str(fallback_label), str(fallback_label)
    return None, None

# Normalize gold/pred IDs
_df['gold_primary'] = _df['gold_ids_list'].apply(lambda xs: xs[0] if xs else None)
_df['gold_primary'] = _df.apply(lambda r: normalize_id(r['dataset'], r['gold_primary']), axis=1)
_df['predicted_id_norm'] = _df.apply(lambda r: normalize_id(r['dataset'], r['predicted_id']), axis=1)

# For gold_ids, prefer any ID that exists in dump (useful for multi-ID cases)
for ds in _df['dataset'].unique():
    dump = load_ontology_dump(ds)
    if not dump:
        continue
    mask = _df['dataset'] == ds
    def pick_best(ids):
        for gid in ids:
            gid_norm = normalize_id(ds, gid)
            if gid_norm in dump:
                return gid_norm
        return normalize_id(ds, ids[0]) if ids else None
    _df.loc[mask, 'gold_primary'] = _df.loc[mask, 'gold_ids_list'].apply(pick_best)

# Build NCBI fallback only for missing IDs
ncbi_dump = load_ontology_dump('ncbi_gene')
missing_ncbi = set(_df.loc[_df['dataset']=='ncbi_gene','gold_primary'].dropna()) - set(ncbi_dump.keys())
if missing_ncbi:
    print('Building NCBI fallback for', len(missing_ncbi), 'IDs')
    ncbi_fallback = build_ncbi_fallback(missing_ncbi)
else:
    ncbi_fallback = {}

texts = []
labels = []
ptexts = []
plabels = []

for _, row in _df.iterrows():
    gtext, glabel = lookup_term_text(row['dataset'], row['gold_primary'], ncbi_fallback=ncbi_fallback)
    ptext, plabel = lookup_term_text(row['dataset'], row['predicted_id_norm'], row.get('predicted_label'), ncbi_fallback=ncbi_fallback)
    texts.append(gtext)
    labels.append(glabel)
    ptexts.append(ptext)
    plabels.append(plabel)

_df['gold_text'] = texts
_df['gold_label'] = labels
_df['pred_text'] = ptexts
_df['pred_label_enriched'] = plabels

print('gold_text coverage:', _df['gold_text'].notna().mean())
print('pred_text coverage:', _df['pred_text'].notna().mean())

_df[['gold_primary','gold_label','predicted_id_norm','pred_label_enriched']].head()


SyntaxError: EOL while scanning string literal (2003655079.py, line 41)


## 4) Multi‑view embeddings (query / gold / pred / residuals)


In [None]:

import numpy as np

ENCODER = encoder_name  # reuse from earlier cell

def embed_texts(unique_texts, cache_path):
    if cache_path.exists():
        emb = np.load(cache_path)
        if emb.shape[0] == len(unique_texts):
            print(f"Loaded {cache_path}")
            return emb
    model = SentenceTransformer(ENCODER)
    emb = model.encode(unique_texts, batch_size=64, convert_to_numpy=True,
                       show_progress_bar=True, normalize_embeddings=True)
    np.save(cache_path, emb)
    print(f"Saved {cache_path}")
    return emb

# Build embeddings for query, gold_text, pred_text
EMB_DIR.mkdir(parents=True, exist_ok=True)

# Query embeddings (all rows)
q_texts = _df['query'].astype(str).tolist()
q_unique = sorted(set(q_texts))
q_map = {t:i for i,t in enumerate(q_unique)}
q_cache = EMB_DIR / f"{ENCODER.replace('/', '_')}_query.npy"
emb_q_unique = embed_texts(q_unique, q_cache)
emb_q = emb_q_unique[[q_map[t] for t in q_texts]]

# Gold embeddings (rows with gold_text)
_g = _df['gold_text'].fillna('')
_g_unique = sorted(set([t for t in _g if t]))
g_map = {t:i for i,t in enumerate(_g_unique)}
g_cache = EMB_DIR / f"{ENCODER.replace('/', '_')}_gold.npy"
emb_g_unique = embed_texts(_g_unique, g_cache) if _g_unique else None

# Pred embeddings (rows with pred_text)
_p = _df['pred_text'].fillna('')
_p_unique = sorted(set([t for t in _p if t]))
p_map = {t:i for i,t in enumerate(_p_unique)}
p_cache = EMB_DIR / f"{ENCODER.replace('/', '_')}_pred.npy"
emb_p_unique = embed_texts(_p_unique, p_cache) if _p_unique else None

# Build aligned arrays with NaNs for missing
emb_g = np.full_like(emb_q, np.nan)
emb_p = np.full_like(emb_q, np.nan)

if emb_g_unique is not None:
    for i, t in enumerate(_g):
        if t:
            emb_g[i] = emb_g_unique[g_map[t]]

if emb_p_unique is not None:
    for i, t in enumerate(_p):
        if t:
            emb_p[i] = emb_p_unique[p_map[t]]

print('emb_q', emb_q.shape, 'emb_g', emb_g.shape, 'emb_p', emb_p.shape)



## 5) Residual embeddings + t‑SNE views


In [None]:
# Residuals: query - pred / query - gold / pred - gold
mask_qp = ~np.isnan(emb_p).any(axis=1)
mask_qg = ~np.isnan(emb_g).any(axis=1)
mask_pg = mask_qp & mask_qg

res_qp = emb_q[mask_qp] - emb_p[mask_qp]
res_qg = emb_q[mask_qg] - emb_g[mask_qg]
res_pg = emb_p[mask_pg] - emb_g[mask_pg]

print('Residual shapes:', res_qp.shape, res_qg.shape, res_pg.shape)
print('mask_qp:', mask_qp.sum(), 'mask_qg:', mask_qg.sum(), 'mask_pg:', mask_pg.sum())

def run_tsne(data, perplexity=30, pca_components=50, seed=42):
    from sklearn.decomposition import TruncatedSVD
    from sklearn.manifold import TSNE
    if data.shape[0] < 5:
        raise ValueError('Not enough points for t-SNE')
    if pca_components and pca_components < data.shape[1]:
        svd = TruncatedSVD(n_components=pca_components, random_state=seed)
        data = svd.fit_transform(data)
    perp = min(perplexity, max(5, data.shape[0] - 1))
    tsne = TSNE(n_components=2, perplexity=perp, metric='cosine', init='pca',
                random_state=seed, learning_rate='auto', max_iter=1500, verbose=1)
    return tsne.fit_transform(data)

# Choose a view to project
view = 'q_minus_pred'  # options: q_minus_pred, q_minus_gold, pred_minus_gold

# Fallback if not enough points
if view == 'q_minus_pred' and res_qp.shape[0] < 5:
    print('Not enough points for q_minus_pred; falling back to q_minus_gold')
    view = 'q_minus_gold'
if view == 'q_minus_gold' and res_qg.shape[0] < 5:
    print('Not enough points for q_minus_gold; falling back to pred_minus_gold')
    view = 'pred_minus_gold'
if view == 'pred_minus_gold' and res_pg.shape[0] < 5:
    print('Not enough points for residuals; falling back to query embeddings')
    view = 'query'

if view == 'q_minus_pred':
    coords = run_tsne(res_qp, perplexity=30, pca_components=50)
    df_view = _df[mask_qp].copy()
elif view == 'q_minus_gold':
    coords = run_tsne(res_qg, perplexity=30, pca_components=50)
    df_view = _df[mask_qg].copy()
elif view == 'pred_minus_gold':
    coords = run_tsne(res_pg, perplexity=30, pca_components=50)
    df_view = _df[mask_pg].copy()
else:
    coords = run_tsne(emb_q, perplexity=30, pca_components=50)
    df_view = _df.copy()

# Add coords
_df_view = df_view.copy()
_df_view['tsne_x'] = coords[:,0]
_df_view['tsne_y'] = coords[:,1]
_df_view.head()



## 6) Visualize residual space (t‑SNE)


In [None]:

fig_res = px.scatter(
    _df_view,
    x='tsne_x', y='tsne_y',
    color='error_type', symbol='dataset',
    hover_data={
        'query': True,
        'pred_label_enriched': True,
        'gold_label': True,
        'model': True,
        'confidence': True,
        'is_correct': True,
        'tsne_x': False,
        'tsne_y': False,
    },
    title=f"t-SNE residual view: {view}",
    opacity=0.85,
)
fig_res.update_traces(marker=dict(size=6, line=dict(width=0)))
fig_res.show()



## 7) Diagnostics (cluster purity, silhouette, neighborhood hit‑rate)


In [None]:
from sklearn.metrics import silhouette_score
from sklearn.neighbors import NearestNeighbors
import numpy as np

# Use embeddings for the CURRENT dataframe _df (may be incorrect-only)
# Guard against NaN/inf
finite_mask = np.isfinite(emb_q).all(axis=1)
emb_q_finite = emb_q[finite_mask]
_df_finite = _df.iloc[np.where(finite_mask)[0]].copy()

# Neighborhood purity for error_type (current view)
k = 15
k = min(k, len(emb_q_finite) - 1) if len(emb_q_finite) > 1 else 1
nbrs = NearestNeighbors(n_neighbors=k, metric='cosine').fit(emb_q_finite)
_, idx = nbrs.kneighbors(emb_q_finite)

labels = _df_finite['error_type'].astype(str).values
purity = []
for i, neighbors in enumerate(idx):
    # skip self at index 0
    neigh_labels = labels[neighbors[1:]] if len(neighbors) > 1 else labels[neighbors]
    purity.append((neigh_labels == labels[i]).mean())

_df_finite['nn_purity_error_type'] = purity
print('Avg neighborhood purity (error_type):', _df_finite['nn_purity_error_type'].mean())

# Silhouette score by error_type (exclude tiny classes)
label_counts = _df_finite['error_type'].value_counts()
valid = _df_finite['error_type'].isin(label_counts[label_counts > 10].index)
try:
    if valid.sum() > 10:
        sil = silhouette_score(emb_q_finite[valid], _df_finite.loc[valid, 'error_type'])
        print('Silhouette (error_type):', sil)
    else:
        print('Silhouette skipped: not enough samples after filtering')
except Exception as e:
    print('Silhouette failed:', e)

# Error rates should be computed on FULL dataset, not filtered
_df_full = pd.read_parquet(ERRORS_PATH)

print('
Accuracy by dataset (full):')
print((_df_full.groupby('dataset')['is_correct'].mean().rename('accuracy')))

print('
Accuracy by model (full):')
print((_df_full.groupby('model')['is_correct'].mean().rename('accuracy')))



## 8) Export enriched dataframe


In [None]:
# Export enriched dataframe
ANALYSIS_DIR.mkdir(parents=True, exist_ok=True)
out_path = ANALYSIS_DIR / 'errors_enriched.parquet'
_df.to_parquet(out_path, index=False)
print('Wrote', out_path)


In [None]:

# Helper: encode unique texts with caching on disk
import hashlib

def cache_path_for(encoder: str):
    slug = encoder.replace('/', '_').replace(':', '_')
    return EMB_DIR / f"{slug}_query_embeddings.npy"

EMB_DIR.mkdir(parents=True, exist_ok=True)
cache_path = cache_path_for(encoder_name)

texts = df['query'].astype(str).tolist()
unique_texts = sorted(set(texts))
text_to_idx = {t: i for i, t in enumerate(unique_texts)}

if cache_path.exists():
    emb_unique = np.load(cache_path)
    if emb_unique.shape[0] != len(unique_texts):
        print('Cache size mismatch; recomputing embeddings...')
        emb_unique = None
else:
    emb_unique = None

if emb_unique is None:
    model = SentenceTransformer(encoder_name)
    emb_unique = model.encode(unique_texts, batch_size=64, convert_to_numpy=True,
                              show_progress_bar=True, normalize_embeddings=True)
    np.save(cache_path, emb_unique)
    print(f"Saved embeddings to {cache_path}")
else:
    print(f"Loaded embeddings from {cache_path}")

emb_full = emb_unique[[text_to_idx[t] for t in texts]]
emb_full.shape


In [None]:

# PCA (SVD) before t-SNE for speed/stability
svd = TruncatedSVD(n_components=min(pca_components, emb_full.shape[1]-1), random_state=random_state)
emb_svd = svd.fit_transform(emb_full)
emb_svd.shape


In [None]:

# t-SNE projection
perp = min(perplexity, max(5, len(df) - 1))
tsne = TSNE(n_components=2, perplexity=perp, metric='cosine', init='pca',
            random_state=random_state, learning_rate='auto', max_iter=1500, verbose=1)
coords = tsne.fit_transform(emb_svd)
df_plot = df.copy()
df_plot['tsne_x'] = coords[:,0]
df_plot['tsne_y'] = coords[:,1]
coords[:5]


In [None]:

# Interactive scatter
fig = px.scatter(
    df_plot,
    x='tsne_x', y='tsne_y',
    color='error_type', symbol='dataset',
    hover_data={
        'query': True,
        'predicted_label': True,
        'gold_ids': True,
        'model': True,
        'run_id': True,
        'confidence': True,
        'is_correct': True,
        'tsne_x': False,
        'tsne_y': False,
    },
    title=f"t-SNE of queries (encoder={encoder_name}, perplexity={perp})",
    opacity=0.85,
)
fig.update_traces(marker=dict(size=6, line=dict(width=0)))
fig.show()



## Explore neighborhoods
Pick a query and see its nearest neighbors in embedding space.


In [None]:

# Compute cosine similarity vs all points for an example query
example_query = df_plot.iloc[0]['query']
q_idx = df_plot.index[df_plot['query'] == example_query][0]
q_vec = emb_full[q_idx]

# use dot product because vectors are normalized
sims = emb_full @ q_vec
nn_idx = sims.argsort()[::-1][:15]

nn = df_plot.iloc[nn_idx][['query','error_type','dataset','model','confidence','predicted_label','gold_ids']]
print(f"Example query: {example_query}")


## By model / dataset faceting


In [None]:
fig_facet = px.scatter(
    df_plot,
    x='tsne_x', y='tsne_y',
    color='error_type',
    facet_col='dataset', facet_row='model',
    height=900,
    width=1400,
    opacity=0.8,
    title='t-SNE faceted by dataset/model',
)
fig_facet.update_traces(marker=dict(size=4, line=dict(width=0)))
fig_facet.update_annotations(textangle=0)
for annotation in fig_facet.layout.annotations:
    if 'model=' in annotation.text:
        annotation.textangle = 80
fig_facet.show()
