# ROR → Postgres + pgvector (multilingual-e5-large)
**Order:** 1) Imports & setup → 2) Schema → 3) Load & normalize → 4) Core upsert → 5) Child tables → 6) index_text views → 7) Embed (with progress bars) → 8) ANN indexes → 9) Search helper

## 1) Imports & setup

In [None]:
%pip install -q "psycopg[binary]>=3.1" "pgvector>=0.3.2" sentence-transformers torch tqdm orjson zipfile36 ipywidgets

import os, json, zipfile
from pathlib import Path
from typing import List, Dict, Any, Iterable

import numpy as np
import orjson
from tqdm.auto import tqdm
from urllib.parse import urlparse

import psycopg
from pgvector.psycopg import register_vector

import torch
from sentence_transformers import SentenceTransformer

# --- Database connection parameters ---
# Every value can be overridden through the environment so that real
# credentials never have to be written into the notebook. The fallbacks are
# the original placeholder values.
from urllib.parse import quote

DB_HOST = os.environ.get('ROR_DB_HOST', 'localhost')
DB_PORT = int(os.environ.get('ROR_DB_PORT', '5432'))
DB_NAME = os.environ.get('ROR_DB_NAME', 'ror_db')
DB_USER = os.environ.get('ROR_DB_USER', 'username')
DB_PASSWORD = os.environ.get('ROR_DB_PASSWORD', 'password')

# Percent-encode user and password so characters such as '@', ':' or '/' in
# a credential cannot corrupt the connection URL.
DSN = (
    f"postgresql://{quote(DB_USER, safe='')}:{quote(DB_PASSWORD, safe='')}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)

# --- File paths and model settings ---
ROR_ZIP_PATH = Path(os.environ.get('ROR_ZIP', 'ror-latest.zip')).expanduser()
MODEL_NAME = 'intfloat/multilingual-e5-large'  # 1024-dim
# Prefer CUDA, then Apple MPS, then fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'
EMBED_DIM = 1024

# Never echo the password: notebook output is routinely saved and shared.
print('Using DSN:', f'postgresql://{DB_USER}:***@{DB_HOST}:{DB_PORT}/{DB_NAME}')
print('ROR dump exists:', ROR_ZIP_PATH.exists(), str(ROR_ZIP_PATH))
print('Device:', DEVICE)

# Lazy model loader + embedding helper
_model = None


def embed_texts(texts: Iterable[str], is_passage: bool = True, batch_size: int = 64,
                show_progress_bar: bool = True) -> np.ndarray:
    """Embed *texts* with the shared sentence-transformer model.

    The model is created on the first call and cached in the module-level
    ``_model``. Each text is prefixed with ``'passage: '`` or ``'query: '``
    (chosen by *is_passage*) before encoding; ``None`` or empty entries are
    embedded as the bare prefix.

    Args:
        texts: Strings to embed. A single ``str`` is treated as one text
            rather than being iterated character by character.
        is_passage: ``True`` for documents being indexed, ``False`` for
            search queries.
        batch_size: Batch size forwarded to ``encode``.
        show_progress_bar: Forwarded to ``encode``; pass ``False`` for
            single-query lookups or when an outer progress bar is in use.

    Returns:
        A ``float32`` array built from the encoder output, one row per text
        (``normalize_embeddings=True`` is requested from the encoder).
    """
    global _model
    if _model is None:
        _model = SentenceTransformer(MODEL_NAME, device=DEVICE)
    if isinstance(texts, str):
        texts = [texts]
    prefix = 'passage: ' if is_passage else 'query: '
    tagged = [prefix + (t or '') for t in texts]
    vecs = _model.encode(
        tagged,
        batch_size=batch_size,
        normalize_embeddings=True,
        show_progress_bar=show_progress_bar,
    )
    return np.asarray(vecs, dtype=np.float32)


## 2) Create schema (pgvector + tables)

In [None]:
# Create the pgvector extension and every ROR table/index. All statements are
# idempotent (IF NOT EXISTS), so the cell can be re-run safely.
with psycopg.connect(DSN, autocommit=True) as conn:
    with conn.cursor() as cur:
        # The extension must exist BEFORE register_vector(): registration
        # looks up the "vector" type and fails on a fresh database otherwise.
        cur.execute('CREATE EXTENSION IF NOT EXISTS vector;')
    register_vector(conn)
    with conn.cursor() as cur:
        # NOTE(review): ror_org_name_uq includes the nullable "lang" column;
        # PostgreSQL treats NULLs as distinct in unique indexes by default, so
        # rows with lang IS NULL are presumably not deduplicated by it.
        cur.execute('''
CREATE TABLE IF NOT EXISTS ror_org (
  ror_id            text PRIMARY KEY,
  status            text NOT NULL CHECK (status IN ('active','inactive','withdrawn')),
  types             text[] NOT NULL,
  established       integer,
  created_date      date NOT NULL,
  created_schema_version text NOT NULL,
  last_modified_date date NOT NULL,
  last_modified_schema_version text NOT NULL,
  country_code      char(2),
  search_text       text,
  embedding         vector(1024)
);
CREATE INDEX IF NOT EXISTS ror_org_status_idx ON ror_org(status);
CREATE INDEX IF NOT EXISTS ror_org_types_gin  ON ror_org USING GIN (types);
CREATE INDEX IF NOT EXISTS ror_org_cc_idx     ON ror_org(country_code);

CREATE TABLE IF NOT EXISTS ror_org_name (
  id      bigserial PRIMARY KEY,
  ror_id  text NOT NULL REFERENCES ror_org(ror_id) ON DELETE CASCADE,
  value   text NOT NULL,
  types   text[] NOT NULL,
  lang    char(2)
);
CREATE UNIQUE INDEX IF NOT EXISTS ror_org_name_uq ON ror_org_name(ror_id, value, lang, types);

CREATE TABLE IF NOT EXISTS ror_org_link (
  id     bigserial PRIMARY KEY,
  ror_id text NOT NULL REFERENCES ror_org(ror_id) ON DELETE CASCADE,
  type   text NOT NULL,
  value  text NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS ror_org_link_uq ON ror_org_link(ror_id, type, value);

CREATE TABLE IF NOT EXISTS ror_org_external_id (
  id        bigserial PRIMARY KEY,
  ror_id    text NOT NULL REFERENCES ror_org(ror_id) ON DELETE CASCADE,
  type      text NOT NULL,
  all_ids   text[] NOT NULL,
  preferred text
);
CREATE UNIQUE INDEX IF NOT EXISTS ror_org_extid_uq ON ror_org_external_id(ror_id, type);

CREATE TABLE IF NOT EXISTS ror_org_location (
  id        bigserial PRIMARY KEY,
  ror_id    text NOT NULL REFERENCES ror_org(ror_id) ON DELETE CASCADE,
  geonames_id integer NOT NULL,
  name      text NOT NULL,
  lat       double precision,
  lng       double precision,
  continent_code text,
  continent_name text,
  country_code char(2),
  country_name text,
  country_subdivision_code text,
  country_subdivision_name text
);
CREATE INDEX IF NOT EXISTS ror_org_loc_country_idx ON ror_org_location(country_code);
CREATE INDEX IF NOT EXISTS ror_org_loc_geonames_idx ON ror_org_location(geonames_id);

CREATE TABLE IF NOT EXISTS ror_org_relationship (
  id        bigserial PRIMARY KEY,
  ror_id    text NOT NULL REFERENCES ror_org(ror_id) ON DELETE CASCADE,
  rel_type  text NOT NULL,
  target_id text NOT NULL,
  label     text NOT NULL
);
CREATE INDEX IF NOT EXISTS ror_org_rel_src_idx ON ror_org_relationship(ror_id, rel_type);
CREATE INDEX IF NOT EXISTS ror_org_rel_tgt_idx ON ror_org_relationship(target_id);

CREATE TABLE IF NOT EXISTS ror_org_domain (
  id      bigserial PRIMARY KEY,
  ror_id  text NOT NULL REFERENCES ror_org(ror_id) ON DELETE CASCADE,
  domain  text NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS ror_org_domain_uq ON ror_org_domain(ror_id, domain);
''')
        # "INSERT ... ON CONFLICT DO NOTHING" only skips rows that violate a
        # unique index. Without these two, re-running a load into the location
        # and relationship tables inserts every row again. Each statement runs
        # on its own so pre-existing duplicates in one table do not block the
        # other.
        for ddl in (
            'CREATE UNIQUE INDEX IF NOT EXISTS ror_org_location_uq '
            'ON ror_org_location(ror_id, geonames_id);',
            'CREATE UNIQUE INDEX IF NOT EXISTS ror_org_relationship_uq '
            'ON ror_org_relationship(ror_id, rel_type, target_id);',
        ):
            try:
                cur.execute(ddl)
            except psycopg.errors.UniqueViolation as exc:
                print('Skipped unique index (remove duplicate rows first):', exc)
print('Schema ready.')


## 3) Load & normalize ROR (v2 JSON-in-ZIP)

In [None]:
def _pick_json_from_zip(zip_path: Path) -> str:
    """Return the member name of the JSON file to load from the dump archive.

    Members whose name contains ``v2`` (case-insensitive) are preferred; if
    there are none, every ``.json`` member is considered. Among the
    candidates the one with the largest uncompressed size is returned.

    Raises:
        RuntimeError: if the archive contains no ``.json`` member.
    """
    with zipfile.ZipFile(zip_path, 'r') as archive:
        json_members = [member for member in archive.namelist() if member.endswith('.json')]
        v2_members = [member for member in json_members if 'v2' in member.lower()]
        pool = v2_members if v2_members else json_members
        if not pool:
            raise RuntimeError('No JSON file found in the ROR dump zip.')

        def _size(member):
            return archive.getinfo(member).file_size

        return max(pool, key=_size)

def load_ror(zip_path: Path):
    """Parse and return the JSON document chosen by ``_pick_json_from_zip``.

    The selected archive member is read fully into memory and decoded with
    ``orjson``.
    """
    with zipfile.ZipFile(zip_path, 'r') as archive:
        member = _pick_json_from_zip(zip_path)
        with archive.open(member) as handle:
            payload = handle.read()
    return orjson.loads(payload)

def collect_name_variants(org: Dict[str, Any]) -> List[str]:
    """Return every distinct, non-blank name of *org*, stripped and sorted.

    Both record layouts are read: the ``names`` array of ``{'value', ...}``
    entries, and the older ``name`` / ``aliases`` / ``labels`` / ``acronyms``
    keys whose entries are plain strings or dicts keyed by ``label``,
    ``value`` or ``name``.
    """
    found = set()
    if org.get('name'):
        found.add(org['name'])
    for key in ('names', 'aliases', 'labels', 'acronyms'):
        for entry in org.get(key) or []:
            if isinstance(entry, dict):
                entry = entry.get('label') or entry.get('value') or entry.get('name')
            if isinstance(entry, str):
                found.add(entry)
    return sorted({n.strip() for n in found if isinstance(n, str) and n.strip()})


def _ror_display_name(org: Dict[str, Any]) -> str:
    """Return the first ``names`` entry typed ``ror_display``, or ``''``."""
    for entry in org.get('names') or []:
        if isinstance(entry, dict) and 'ror_display' in (entry.get('types') or []):
            value = entry.get('value')
            if isinstance(value, str) and value.strip():
                return value.strip()
    return ''


def _location_country_code(org: Dict[str, Any]) -> str:
    """Return the first country code found in ``locations[].geonames_details``, or ``''``."""
    for loc in org.get('locations') or []:
        if isinstance(loc, dict):
            code = (loc.get('geonames_details') or {}).get('country_code')
            if code:
                return code
    return ''


def normalize(org: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten one raw ROR record into the dict consumed by the loaders.

    The primary name is ``org['name']`` if present, otherwise the
    ``ror_display`` entry of ``org['names']``, otherwise the first collected
    variant. The country code comes from ``org['country']`` or, failing
    that, from the first location carrying one. ``search_text`` is
    ``"<primary> (country:XX) | aka: a; b; ..."`` with at most ten
    alternative names, never repeating the primary name.

    Missing status, admin dates and schema versions fall back to
    ``'active'``, ``'1970-01-01'`` and ``'2.1'`` respectively.
    """
    rid = org.get('id') or org.get('ror_id') or ''
    names = collect_name_variants(org)
    cc = ((org.get('country') or {}).get('country_code')) or _location_country_code(org)
    primary = org.get('name') or _ror_display_name(org) or (names[0] if names else '')

    search_text = primary
    if cc:
        search_text += f' (country:{cc})'
    # ``names`` is sorted alphabetically, so the primary name must be removed
    # by value rather than by assuming it sits at index 0.
    aka = [n for n in names if n != primary.strip()][:10]
    if aka:
        search_text += ' | aka: ' + '; '.join(aka)

    admin = org.get('admin') or {}
    created = admin.get('created') or {}
    modified = admin.get('last_modified') or {}
    return {
        'ror_id': rid,
        'status': org.get('status') or 'active',
        'types': org.get('types') or [],
        'established': org.get('established'),
        'created_date': created.get('date') or '1970-01-01',
        'created_schema_version': created.get('schema_version') or '2.1',
        'last_modified_date': modified.get('date') or '1970-01-01',
        'last_modified_schema_version': modified.get('schema_version') or '2.1',
        'country_code': cc or None,
        'search_text': search_text,
        'name': primary,
        'names': names,
        'links': org.get('links') or [],
        'external_ids': org.get('external_ids') or [],
        'locations': org.get('locations') or [],
        'relationships': org.get('relationships') or [],
        'domains': org.get('domains') or [],
    }

# Read the dump once and flatten every organization record.
raw = load_ror(ROR_ZIP_PATH)
rows = list(map(normalize, raw))
print('Organizations:', len(rows))


## 4) Upsert core orgs

In [None]:
# Upsert statement for ror_org: insert a new organization, or overwrite every
# non-key column when the ror_id already exists.
ins_org = (
    "INSERT INTO ror_org ("
    "  ror_id, status, types, established,"
    "  created_date, created_schema_version,"
    "  last_modified_date, last_modified_schema_version,"
    "  country_code, search_text) "
    "VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) "
    "ON CONFLICT (ror_id) DO UPDATE SET "
    "  status=EXCLUDED.status, types=EXCLUDED.types, established=EXCLUDED.established,"
    "  created_date=EXCLUDED.created_date, created_schema_version=EXCLUDED.created_schema_version,"
    "  last_modified_date=EXCLUDED.last_modified_date, last_modified_schema_version=EXCLUDED.last_modified_schema_version,"
    "  country_code=EXCLUDED.country_code, search_text=EXCLUDED.search_text;"
)

# Send the organizations in slices of BATCH rows inside a single transaction.
with psycopg.connect(DSN, autocommit=False) as conn:
    with conn.cursor() as cur:
        BATCH = 1000
        for start in tqdm(range(0, len(rows), BATCH), desc='Upserting orgs'):
            params = []
            for r in rows[start:start + BATCH]:
                params.append((
                    r['ror_id'],
                    r['status'],
                    r['types'],
                    r['established'],
                    r['created_date'],
                    r['created_schema_version'],
                    r['last_modified_date'],
                    r['last_modified_schema_version'],
                    r['country_code'],
                    r['search_text'],
                ))
            cur.executemany(ins_org, params)
    conn.commit()
print('Core orgs upserted.')


## 5) Upsert child tables (names, links, external_ids, locations, relationships, domains)

In [None]:
def _nz(x):
    """Strip surrounding whitespace from strings; return any other value unchanged."""
    if isinstance(x, str):
        return x.strip()
    return x

def rows_names(r):
    """Build ``(ror_id, value, types, lang)`` tuples for ``ror_org_name``.

    The primary name (``r['name']``) is emitted once with type
    ``ror_display``. Every other entry of ``r['names']`` is emitted once,
    typed ``acronym`` when it is all upper-case and at most 12 characters
    long, otherwise ``alias``. ``lang`` is always ``None``. Returns an empty
    list when the record has no ``ror_id``.
    """
    rid = r.get('ror_id')
    if not rid:
        return []
    out = []
    primary = _nz(r.get('name') or '')
    if primary:
        out.append((rid, primary, ['ror_display'], None))
    # Values already emitted; seeded with the primary name so it is not
    # repeated as an alias when r['names'] contains it as well.
    seen = {primary}
    for v in r.get('names') or []:
        v = _nz(v)
        if not v or v in seen:
            continue
        seen.add(v)
        typ = 'acronym' if (v.isupper() and len(v) <= 12) else 'alias'
        out.append((rid, v, [typ], None))
    return out

def rows_links(r):
    """Build ``(ror_id, type, value)`` tuples for ``ror_org_link``.

    A missing link type defaults to ``'website'``; links without a value are
    skipped. Returns an empty list when the record has no ``ror_id``.
    """
    rid = r.get('ror_id')
    if not rid:
        return []
    result = []
    for link in r.get('links') or []:
        kind = _nz(link.get('type') or 'website')
        url = _nz(link.get('value') or '')
        if not url:
            continue
        result.append((rid, kind, url))
    return result

def rows_extids(r):
    """Build ``(ror_id, type, all_ids, preferred)`` tuples for ``ror_org_external_id``.

    The id list is read from ``all`` (or ``all_ids``) and stringified;
    entries with no type or no ids are skipped. Returns an empty list when
    the record has no ``ror_id``.
    """
    rid = r.get('ror_id')
    if not rid:
        return []
    result = []
    for ext in r.get('external_ids') or []:
        kind = _nz(ext.get('type') or '')
        ids = ext.get('all') or ext.get('all_ids') or []
        preferred = _nz(ext.get('preferred') or None)
        if not (kind and ids):
            continue
        result.append((rid, kind, list(map(str, ids)), preferred))
    return result

def rows_locations(r):
    """Build parameter tuples for ``ror_org_location``.

    Each tuple is ``(ror_id, geonames_id, name, lat, lng, continent_code,
    continent_name, country_code, country_name, country_subdivision_code,
    country_subdivision_name)``. Locations without a geonames id or a name
    are skipped; missing or empty optional fields become ``None``. Returns
    an empty list when the record has no ``ror_id``.
    """
    rid = r.get('ror_id')
    if not rid:
        return []

    def _text(details, key):
        # Empty strings are stored as NULL.
        return _nz(details.get(key) or None)

    def _coord(details, key):
        # 0.0 is a valid coordinate: only a missing or empty value becomes
        # NULL (a plain ``value or None`` would discard the zero).
        value = details.get(key)
        return None if value is None or value == '' else _nz(value)

    out = []
    for loc in r.get('locations') or []:
        g = loc.get('geonames_id')
        det = loc.get('geonames_details') or {}
        name = _nz(det.get('name') or '')
        if not g or not name:
            continue
        out.append((
            rid, int(g),
            name,
            _coord(det, 'lat'),
            _coord(det, 'lng'),
            _text(det, 'continent_code'),
            _text(det, 'continent_name'),
            _text(det, 'country_code'),
            _text(det, 'country_name'),
            _text(det, 'country_subdivision_code'),
            _text(det, 'country_subdivision_name'),
        ))
    return out

def rows_relationships(r):
    """Build ``(ror_id, rel_type, target_id, label)`` tuples for ``ror_org_relationship``.

    A relationship is kept only when its type, target id and label are all
    non-empty. Returns an empty list when the record has no ``ror_id``.
    """
    rid = r.get('ror_id')
    if not rid:
        return []
    result = []
    for rel in r.get('relationships') or []:
        fields = (
            _nz(rel.get('type') or ''),
            _nz(rel.get('id') or ''),
            _nz(rel.get('label') or ''),
        )
        if all(fields):
            result.append((rid,) + fields)
    return result

def rows_domains(r):
    """Build de-duplicated ``(ror_id, domain)`` tuples for ``ror_org_domain``.

    The record's declared ``domains`` are used when there are any. Otherwise
    the lower-cased host name of each link URL is used instead; links whose
    value is missing, not a string, or not parseable as a URL are skipped.
    First-seen order is preserved. Returns an empty list when the record has
    no ``ror_id``.
    """
    rid = r.get('ror_id')
    if not rid:
        return []
    domains = [d for d in map(_nz, r.get('domains') or []) if d]
    if not domains:
        for lk in r.get('links') or []:
            url = _nz(lk.get('value') or '')
            if not url or not isinstance(url, str):
                continue
            try:
                host = urlparse(url).hostname
            except ValueError:
                # urlparse rejects malformed authorities (e.g. an unbalanced
                # IPv6 bracket); such links simply contribute no domain.
                continue
            if host:
                domains.append(host.lower())
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return [(rid, d) for d in dict.fromkeys(domains)]

# Parameterized INSERT statements for the child tables. Each is executed with
# executemany() using the tuples produced by the matching rows_*() helper.

# Names: rows matching an existing (ror_id, value, lang, types) are skipped.
ins_name = """
INSERT INTO ror_org_name (ror_id, value, types, lang)
VALUES (%s,%s,%s,%s)
ON CONFLICT (ror_id, value, lang, types) DO NOTHING;
"""
# Links: rows matching an existing (ror_id, type, value) are skipped.
ins_link = """
INSERT INTO ror_org_link (ror_id, type, value)
VALUES (%s,%s,%s)
ON CONFLICT (ror_id, type, value) DO NOTHING;
"""
# External ids: one row per (ror_id, type); a conflict overwrites the id list
# and the preferred id.
ins_extid = """
INSERT INTO ror_org_external_id (ror_id, type, all_ids, preferred)
VALUES (%s,%s,%s,%s)
ON CONFLICT (ror_id, type) DO UPDATE SET all_ids = EXCLUDED.all_ids, preferred = EXCLUDED.preferred;
"""
# Locations: ON CONFLICT has no target here, so a row is skipped only if it
# violates some unique constraint/index on the table.
# NOTE(review): confirm the schema defines one; otherwise re-runs insert duplicates.
ins_loc = """
INSERT INTO ror_org_location (
  ror_id, geonames_id, name, lat, lng,
  continent_code, continent_name,
  country_code, country_name,
  country_subdivision_code, country_subdivision_name
) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
ON CONFLICT DO NOTHING;
"""
# Relationships: same target-less ON CONFLICT as locations.
# NOTE(review): the same unique-index caveat applies.
ins_rel = """
INSERT INTO ror_org_relationship (ror_id, rel_type, target_id, label)
VALUES (%s,%s,%s,%s)
ON CONFLICT DO NOTHING;
"""
# Domains: rows matching an existing (ror_id, domain) are skipped.
ins_dom = """
INSERT INTO ror_org_domain (ror_id, domain)
VALUES (%s,%s)
ON CONFLICT (ror_id, domain) DO NOTHING;
"""

BATCH = 1000


def _insert_batched(cur, sql, records, extract, batch_size=BATCH):
    """Insert the tuples produced by ``extract(record)`` for every record.

    Tuples are buffered and sent with ``cur.executemany(sql, ...)`` whenever
    at least *batch_size* have accumulated, then once more for the remainder.
    """
    buf = []
    for rec in records:
        buf.extend(extract(rec))
        if len(buf) >= batch_size:
            cur.executemany(sql, buf)
            buf.clear()
    if buf:
        cur.executemany(sql, buf)


# Load every child table inside one transaction, in the original order.
# Links now go through the same batching as the other tables instead of one
# executemany() call per organization.
with psycopg.connect(DSN, autocommit=False) as conn:
    with conn.cursor() as cur:
        for label, sql, extract in (
            ('names', ins_name, rows_names),
            ('links', ins_link, rows_links),
            ('external IDs', ins_extid, rows_extids),
            ('locations', ins_loc, rows_locations),
            ('relationships', ins_rel, rows_relationships),
            ('domains', ins_dom, rows_domains),
        ):
            _insert_batched(cur, sql, tqdm(rows, desc=f'Loading {label}', unit='org'), extract)

    conn.commit()
print('Child tables loaded.')


## 6) Compose index_text from all tables (SQL views)

In [None]:
# Create (or replace) the two views that compose the text to embed.
#
# ror_org_index_text: one row per organization with these columns:
#   display_name   - first name typed 'ror_display', else the first name, else ''
#   aka_block      - names typed alias/acronym/label, joined with '; '
#   domain_block   - domains, joined with '; '
#   links_block    - link values, joined with '; '
#   extid_block    - "type:preferred|id1,id2" per external id, joined with ' ; '
#   location_block - "name subdivision country (country:XX)" per location
#   rel_block      - "rel_type: label" per relationship
#   country_code   - copied from ror_org
#
# ror_org_index_text_full: concatenates the non-empty blocks above into a
# single index_text string, each block introduced by its own ' | xxx: ' label.
#
# NOTE(review): in location_block the NULL parts are coalesced to '' before
# concat_ws(' ', ...), so missing parts presumably still contribute a
# separator (double spaces) - confirm against real output.
with psycopg.connect(DSN, autocommit=True) as conn:
    with conn.cursor() as cur:
        cur.execute('''
CREATE OR REPLACE VIEW ror_org_index_text AS
SELECT
  o.ror_id,
  coalesce(
    (SELECT n.value FROM ror_org_name n WHERE n.ror_id = o.ror_id AND 'ror_display' = ANY(n.types) ORDER BY n.id LIMIT 1),
    (SELECT n.value FROM ror_org_name n WHERE n.ror_id = o.ror_id ORDER BY n.id LIMIT 1),
    ''
  ) AS display_name,
  coalesce((
    SELECT string_agg(n.value, '; ' ORDER BY n.id)
    FROM ror_org_name n
    WHERE n.ror_id = o.ror_id AND ('alias' = ANY(n.types) OR 'acronym' = ANY(n.types) OR 'label' = ANY(n.types))
  ), '') AS aka_block,
  coalesce((
    SELECT string_agg(d.domain, '; ' ORDER BY d.domain)
    FROM ror_org_domain d
    WHERE d.ror_id = o.ror_id
  ), '') AS domain_block,
  coalesce((
    SELECT string_agg(l.value, '; ' ORDER BY l.id)
    FROM ror_org_link l
    WHERE l.ror_id = o.ror_id
  ), '') AS links_block,
  coalesce((
    SELECT string_agg(
             concat(e.type, ':', coalesce(e.preferred,''),
                    CASE WHEN e.preferred IS NOT NULL AND cardinality(e.all_ids) > 0 THEN '|' ELSE '' END,
                    array_to_string(e.all_ids, ',')),
             ' ; ' ORDER BY e.id)
    FROM ror_org_external_id e
    WHERE e.ror_id = o.ror_id
  ), '') AS extid_block,
  coalesce((
    SELECT string_agg(
             concat_ws(' ',
               coalesce(l.name,''),
               coalesce(l.country_subdivision_name,''),
               coalesce(l.country_name,''),
               CASE WHEN l.country_code IS NOT NULL THEN '(country:'||l.country_code||')' ELSE '' END
             ), ' ; ' ORDER BY l.id)
    FROM ror_org_location l
    WHERE l.ror_id = o.ror_id
  ), '') AS location_block,
  coalesce((
    SELECT string_agg(concat(r.rel_type, ': ', r.label), ' ; ' ORDER BY r.id)
    FROM ror_org_relationship r
    WHERE r.ror_id = o.ror_id
  ), '') AS rel_block,
  o.country_code
FROM ror_org o;
CREATE OR REPLACE VIEW ror_org_index_text_full AS
SELECT
  ror_id,
  trim(both ' ' from
    concat(
      display_name,
      CASE WHEN display_name <> '' AND country_code IS NOT NULL THEN ' (country:'||country_code||')' ELSE '' END,
      CASE WHEN aka_block     <> '' THEN ' | aka: '     || aka_block     ELSE '' END,
      CASE WHEN domain_block  <> '' THEN ' | domains: ' || domain_block  ELSE '' END,
      CASE WHEN links_block   <> '' THEN ' | links: '   || links_block   ELSE '' END,
      CASE WHEN extid_block   <> '' THEN ' | ids: '     || extid_block   ELSE '' END,
      CASE WHEN location_block<> '' THEN ' | locs: '    || location_block ELSE '' END,
      CASE WHEN rel_block     <> '' THEN ' | rel: '     || rel_block     ELSE '' END
    )
  ) AS index_text
FROM ror_org_index_text;
''')
print('Views created.')


In [None]:
import psycopg

# Add per-table ror_id indexes (used by the correlated sub-selects of the
# index_text views) and refresh planner statistics.
# The connection reuses the notebook-wide DSN: the previous hard-coded URL
# carried its own credentials and could target a different role or database
# than every other cell.
with psycopg.connect(DSN, autocommit=True) as conn:
    with conn.cursor() as cur:
        cur.execute("""
        CREATE INDEX IF NOT EXISTS ror_org_name_rid_idx        ON ror_org_name(ror_id);
        CREATE INDEX IF NOT EXISTS ror_org_link_rid_idx        ON ror_org_link(ror_id);
        CREATE INDEX IF NOT EXISTS ror_org_external_id_rid_idx ON ror_org_external_id(ror_id);
        CREATE INDEX IF NOT EXISTS ror_org_location_rid_idx    ON ror_org_location(ror_id);
        CREATE INDEX IF NOT EXISTS ror_org_relationship_rid_idx ON ror_org_relationship(ror_id);
        CREATE INDEX IF NOT EXISTS ror_org_domain_rid_idx      ON ror_org_domain(ror_id);

        ANALYZE ror_org;
        ANALYZE ror_org_name;
        ANALYZE ror_org_link;
        ANALYZE ror_org_external_id;
        ANALYZE ror_org_location;
        ANALYZE ror_org_relationship;
        ANALYZE ror_org_domain;
        """)
print("Indexes ensured and stats updated.")

In [None]:
# 🔎 ROR DB dashboard (SQLAlchemy + Pandas) — counts + samples

import pandas as pd
from sqlalchemy import create_engine, text

# Connection settings come from the setup cell (DB_USER, DB_PASSWORD, ...).
ENGINE_URL = f"postgresql+psycopg://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
engine = create_engine(ENGINE_URL, future=True)

TABLES = [
    "ror_org",
    "ror_org_name",
    "ror_org_link",
    "ror_org_external_id",
    "ror_org_location",
    "ror_org_relationship",
    "ror_org_domain",
]

# counts
# The accumulator is deliberately NOT called `rows`: that name holds the
# normalized organizations used by the upsert cells, and overwriting it here
# would break them if they were re-run afterwards.
with engine.connect() as conn:
    table_counts = []
    for t in TABLES:
        # Table names come from the constant list above, never from user input.
        cnt = conn.execute(text(f"SELECT COUNT(*) FROM {t}")).scalar_one()
        table_counts.append({"table": t, "rows": cnt})
    overview = pd.DataFrame(table_counts).sort_values("rows", ascending=False)

display(overview)

# samples
with engine.connect() as conn:
    for t in TABLES:
        print(f"\n🔹 Sample from {t}:")
        try:
            df = pd.read_sql_query(text(f"SELECT * FROM {t} LIMIT 5"), conn)
            display(df)
        except Exception as e:
            # Best effort: one unreadable table must not hide the others.
            print(f"(could not fetch sample: {e})")

# quick vector sanity checks
with engine.connect() as conn:
    print("\n✅ Vectors present:")
    display(pd.read_sql_query(
        text("SELECT COUNT(*) AS with_vectors FROM ror_org WHERE embedding IS NOT NULL"),
        conn,
    ))
    print("\nTop countries:")
    display(pd.read_sql_query(
        text("SELECT country_code, COUNT(*) AS n FROM ror_org GROUP BY country_code ORDER BY n DESC LIMIT 10"),
        conn,
    ))

## 7) Embed index_text and store vectors (with progress bars)

In [None]:
from tqdm.auto import tqdm
import os

# Apple Silicon friendly defaults
# NOTE(review): torch is already imported by the setup cell at this point;
# confirm these variables are still honoured when set after the import.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")

# Tune once
BATCH = 16  # try 8–16 on Apple M-series
UPDATE_SQL = "UPDATE ror_org SET search_text = %s, embedding = %s WHERE ror_id = %s"


def _ensure_model():
    """Return the shared SentenceTransformer, creating it on first use.

    The setup cell only initializes ``_model`` to ``None``; nothing before
    this cell loads it, so it cannot be used directly.
    """
    global _model
    if _model is None:
        _model = SentenceTransformer(MODEL_NAME, device=DEVICE)
    return _model


def embed_batch(texts):
    """Embed a small batch as passages with the inner progress bar disabled."""
    tagged = ["passage: " + (t or "") for t in texts]
    vecs = _ensure_model().encode(
        tagged,
        batch_size=BATCH,
        normalize_embeddings=True,
        show_progress_bar=False,  # <- no 'Batches:' lines
    )
    return vecs


# Everything below runs in ONE transaction: the temp table is dropped on
# commit and the server-side cursor lives inside the same transaction, so an
# interruption rolls back all vectors written so far.
with psycopg.connect(DSN) as conn:
    with conn.cursor() as cur:
        cur.execute("SET LOCAL work_mem = '256MB';")
        print("Building pending set…", flush=True)
        # Pending = organizations without a vector, or whose stored
        # search_text differs from the freshly composed index_text.
        cur.execute("""
            DROP TABLE IF EXISTS tmp_ror_pending;
            CREATE TEMP TABLE tmp_ror_pending ON COMMIT DROP AS
            SELECT o.ror_id, v.index_text
            FROM ror_org o
            JOIN ror_org_index_text_full v USING (ror_id)
            WHERE o.embedding IS NULL OR o.search_text IS DISTINCT FROM v.index_text;
            ANALYZE tmp_ror_pending;
        """)
        cur.execute("SELECT COUNT(*) FROM tmp_ror_pending;")
        total = cur.fetchone()[0]
        print(f"To (re)embed: {total:,}", flush=True)

    if total == 0:
        print("Nothing to update.")
    else:
        # Loads the model if needed and warms it up so the first real batch
        # doesn't feel frozen.
        embed_batch(["warmup"])

        # Stream rows with a server-side cursor. The progress bars are
        # context managers so they are closed even if a batch fails.
        with conn.cursor(name="embed_stream") as scur:
            scur.itersize = max(1000, BATCH)
            scur.execute("SELECT ror_id, index_text FROM tmp_ror_pending ORDER BY ror_id;")

            with tqdm(total=total, desc="Embedding", unit="org") as pbar_embed:
                with tqdm(total=total, desc="Writing to Postgres", unit="org") as pbar_write:
                    while True:
                        # Not named `rows`: that global holds the normalized
                        # organizations used by the upsert cells.
                        pending = scur.fetchmany(BATCH)
                        if not pending:
                            break

                        rids = [p[0] for p in pending]
                        texts = [p[1] or "" for p in pending]

                        vecs = embed_batch(texts)
                        pbar_embed.update(len(pending))

                        with conn.cursor() as wcur:
                            for rid, txt, v in zip(rids, texts, vecs):
                                wcur.execute(UPDATE_SQL, (txt, v.tolist(), rid))
                                pbar_write.update(1)

    conn.commit()

print("✅ Embedding pass completed.")

## 8) Ensure ANN indexes (IVFFLAT + optional HNSW)

In [None]:
# Build the approximate-nearest-neighbour indexes on ror_org.embedding.
# Run this after the embedding pass so the indexes are built over real vectors.
with psycopg.connect(DSN, autocommit=True) as conn:
    with conn.cursor() as cur:
        cur.execute(
            """
            CREATE INDEX IF NOT EXISTS ror_org_embedding_ivfflat_idx
            ON ror_org USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
            """
        )
        # HNSW is optional: pgvector builds without the "hnsw" access method
        # reject the statement, and the IVFFlat index above is then used on
        # its own. IF NOT EXISTS replaces the previous pg_indexes lookup,
        # which matched the index name in any schema.
        try:
            cur.execute(
                """
                CREATE INDEX IF NOT EXISTS ror_org_embedding_hnsw_idx
                ON ror_org USING hnsw (embedding vector_cosine_ops);
                """
            )
        except psycopg.errors.UndefinedObject as exc:
            print('HNSW index skipped (access method not available):', exc)
print('Vector indexes ensured.')


## 9) Search helper (cosine)

In [None]:
def search_affiliations(qtext: str, k: int = 10):
    """Return the *k* organizations whose embedding is closest to *qtext*.

    The query is embedded with the ``query: `` prefix and compared with the
    pgvector cosine-distance operator ``<=>``.

    Args:
        qtext: Free-text affiliation string to look up.
        k: Maximum number of rows to return.

    Returns:
        A list of ``(ror_id, search_text, country_code, score)`` tuples,
        best match first, where ``score`` is ``1 - cosine distance``.
    """
    qvec = embed_texts([qtext], is_passage=False)[0]
    sql = (
        "SELECT ror_id, search_text, country_code, 1 - (embedding <=> %s) AS score "
        "FROM ror_org WHERE embedding IS NOT NULL ORDER BY embedding <=> %s LIMIT %s;"
    )
    with psycopg.connect(DSN) as conn:
        # Registering the pgvector adapters makes the numpy array travel as a
        # "vector". A plain Python list is sent as double precision[], for
        # which no "vector <=> double precision[]" operator exists.
        register_vector(conn)
        with conn.cursor() as cur:
            cur.execute(sql, (qvec, qvec, k))
            return cur.fetchall()

# Example usage (uncomment once embeddings exist):
# for q in ['University of Groningen', 'Rijksuniversiteit Groningen', 'Université de Genève', 'ETH Zürich']:
#     print('\nQuery:', q)
#     for rid, txt, cc, score in search_affiliations(q, 5):
#         print(f"  {score:.4f}  {rid}  [{cc}]  {txt[:100]}…")
