In [None]:
from google.colab import drive
import json
import pandas as pd
import duckdb
import pathlib as pl
import numpy as np

In [None]:
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
json_path='/content/drive/MyDrive/review-California_10.json.gz'

In [None]:
meta_path='/content/drive/MyDrive/meta-California.json.gz'

In [None]:

out_dir   = '/content/drive/MyDrive/processed'
pl.Path(out_dir).mkdir(parents=True, exist_ok=True)

In [None]:


con = duckdb.connect()

# -------------------------------
# 1) Filter PLACES by categories
# -------------------------------
# We unnest 'category' → lowercase → apply include/exclude → keep distinct gmap_ids.
places_sql = f"""
WITH m AS (
  SELECT *
  FROM read_json_auto('{meta_path}', records=true, sample_size=-1)
),
cats AS (
  SELECT
    m.*,
    LOWER(TRIM(cat)) AS cat
  FROM m
  CROSS JOIN UNNEST(COALESCE(m.category, [])) AS u(cat)
),
filtered AS (
  SELECT DISTINCT
    gmap_id,
    name,
    latitude  AS lat,
    longitude AS lon,
    category,
    avg_rating,
    description,
    num_of_reviews,
    state,
  FROM cats
  WHERE gmap_id IS NOT NULL
    AND latitude IS NOT NULL
    AND longitude IS NOT NULL
    AND (
      state IS NULL
      OR LOWER(state) NOT LIKE '%permanently closed%'
    )
    -- include rule (restaurant / bar / cafe, including café)
    AND (
      cat LIKE '%restaurant%' OR
      cat LIKE '% bar%' OR cat LIKE 'bar %' OR cat = 'bar' OR
      cat LIKE '%cafe%' OR cat LIKE '%café%' OR cat = 'cafe' OR cat = 'café'
    )
    -- exclude specific false positives
    AND NOT (
      cat LIKE '%barber shop%' OR
      cat LIKE '%internet cafe%' OR
      cat LIKE '%hookah bar%' OR
      cat LIKE '%tapas bar%' OR
      cat LIKE '%bar stool supplier%' OR
      cat LIKE '%dart bar%' OR
      cat LIKE '%barber school%' OR
      cat LIKE '%piano bar%' OR
      cat LIKE '%dog cafe%' OR
      cat LIKE '%children% cafe%' OR  -- covers "children's"/"childrens"
      cat = 'barn' OR
      cat LIKE '%cabaret club%' OR
      cat LIKE '%carbaret club%' OR   -- common misspelling
      cat LIKE '%bartending school%' OR
      cat LIKE '%bariatric surgeon%' OR
      cat LIKE '%bariartic surgeon%' OR
      cat LIKE '%barrel supplier%'
    )
)
SELECT * FROM filtered
"""
places_df = con.sql(places_sql).df()

# ----------------------------------------
# 2) Filter REVIEWS by required non-nulls
#    then keep only those whose gmap_id
#    survives the category filter above.
# ----------------------------------------
reviews_sql = f"""
WITH r AS (
  SELECT *
  FROM read_json_auto('{json_path}', records=true, sample_size=-1)
),
clean AS (
  SELECT
    user_id, gmap_id, name,
    text, rating, time
  FROM r
  WHERE gmap_id IS NOT NULL
    AND user_id IS NOT NULL
    AND name    IS NOT NULL
    AND rating  IS NOT NULL
    AND text    IS NOT NULL
)
SELECT c.*
FROM clean c
JOIN (
  {places_sql}
) p
USING (gmap_id)
"""
reviews_df = con.sql(reviews_sql).df()

# Tidy types (optional but recommended)
reviews_df["time"] = pd.to_datetime(reviews_df["time"], unit="s", errors="coerce")

# --------------------
# 3) Write to Parquet
# --------------------
places_out  = f"{out_dir}/places_filtered.parquet"
reviews_out = f"{out_dir}/reviews_filtered.parquet"

places_df.to_parquet(places_out, index=False)
reviews_df.to_parquet(reviews_out, index=False)



FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [None]:
# === 4k users → what to extract (users & places) ===
from pathlib import Path

# ---------- paths ----------
REVIEWS_PATH = reviews_out
PLACES_PATH  = places_out
OUT_DIR      = Path("/content/drive/MyDrive/processed/slice_4k")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ---------- params ----------
SAMPLE_USERS       = 4000 #try for 5000-6000
SEED               = 42
USE_TEMPORAL_SPLIT = True     # avoid leakage
TEST_RATIO         = 0.2
MIN_TRAIN_PER_USER = 1

# ---------- load ----------
rev_cols = ["user_id","gmap_id","name","text","rating","time"]
reviews  = pd.read_parquet(REVIEWS_PATH, columns=rev_cols)
reviews  = reviews[reviews.user_id.notna() & reviews.gmap_id.notna() & reviews.text.notna()].copy()
reviews["time"] = pd.to_datetime(reviews["time"], errors="coerce")

# places_cols = ["gmap_id","name","lat","lon","category","avg_rating","description","num_of_reviews","state"]
# places = pd.read_parquet(PLACES_PATH, columns=[c for c in places_cols if c in pd.read_parquet(PLACES_PATH, nrows=1).columns])

def read_parquet_subset(path, want_cols):
    import pyarrow.parquet as pq
    try:
        names = set(pq.read_schema(path).names)
        cols  = [c for c in want_cols if c in names]
        return pd.read_parquet(path, columns=cols)
    except Exception:
        # fallback: read all then select intersection
        df_all = pd.read_parquet(path)
        cols   = [c for c in want_cols if c in df_all.columns]
        return df_all[cols]

places_cols = ["gmap_id","name","lat","lon","category","avg_rating","description","num_of_reviews","state"]
places = read_parquet_subset(PLACES_PATH, places_cols)
# ---------- sample 4k users ----------
rng = np.random.default_rng(SEED)
all_users = pd.Index(reviews["user_id"].unique())
k = min(SAMPLE_USERS, len(all_users))
sampled_users = pd.Index(rng.choice(all_users, size=k, replace=False))

df = reviews[reviews["user_id"].isin(sampled_users)].copy()
df.sort_values(["user_id","time"], inplace=True, na_position="last")

# ---------- per-user split (temporal 80/20) ----------
parts = []
if USE_TEMPORAL_SPLIT:
    for uid, g in df.groupby("user_id", sort=False):
        g = g.reset_index(drop=True)
        n = len(g)
        n_test  = max(1, int(np.ceil(n * TEST_RATIO)))
        n_train = max(MIN_TRAIN_PER_USER, n - n_test)
        # ensure at least MIN_TRAIN train rows
        if n_train + n_test > n:
            n_test = max(0, n - n_train)
        g.loc[:n_train-1, "split"] = "train"
        g.loc[n_train:,   "split"] = "test"
        parts.append(g)
else:
    for uid, g in df.groupby("user_id", sort=False):
        g = g.reset_index(drop=True)
        if len(g) == 1:
            g.loc[0, "split"] = "test"
        else:
            g.loc[:-1, "split"] = "train"
            g.loc[-1,  "split"] = "test"
        parts.append(g)

split_df = pd.concat(parts, ignore_index=True)
train_df = split_df[split_df["split"]=="train"].copy()
test_df  = split_df[split_df["split"]=="test"].copy()

# ---------- WHICH places to extract? ----------
# Warm items = items with ≥1 TRAIN review from these users
warm_items = pd.Index(train_df["gmap_id"].unique())
# Cold items = appear only in TEST (no train text)
all_items  = pd.Index(split_df["gmap_id"].unique())
cold_items = all_items.difference(warm_items)

# Join warm items to places metadata → this is exactly the set of places to PyABSA
places_for_extraction = places[places["gmap_id"].isin(warm_items)].drop_duplicates("gmap_id")

# ---------- user & item texts for extraction (TRAIN only) ----------
# Users: run PyABSA over their TRAIN reviews
user_reviews_train = train_df.sort_values(["user_id","time"]).reset_index(drop=True)

# Items: run PyABSA over TRAIN reviews grouped by gmap_id (these are the warm items)
item_reviews_train = train_df.sort_values(["gmap_id","time"]).reset_index(drop=True)

# ---------- optional: metadata fallback for cold items ----------


def _category_to_list(val):
    """Return a clean list[str] for many possible 'category' shapes."""
    # Missing
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return []
    # Already list/tuple
    if isinstance(val, (list, tuple)):
        seq = list(val)
    # numpy array
    elif isinstance(val, np.ndarray):
        seq = val.tolist()
    # String: maybe JSON, else plain label
    elif isinstance(val, str):
        s = val.strip()
        if s.startswith("[") or s.startswith("(") or s.startswith("{"):
            try:
                parsed = json.loads(s)
                if isinstance(parsed, dict):
                    seq = list(parsed.values())
                elif isinstance(parsed, (list, tuple)):
                    seq = list(parsed)
                else:
                    seq = [val]
            except Exception:
                seq = [val]
        else:
            seq = [val]
    else:
        # Fallback: single value
        seq = [val]
    # Drop nulllikes and stringify
    out = [str(x) for x in seq if x is not None and not (isinstance(x, float) and np.isnan(x))]
    return out

# --- build fallback metadata safely ---
cold_meta = places[places["gmap_id"].isin(cold_items)].copy()

if "category" in cold_meta.columns:
    cold_meta["category_text"] = cold_meta["category"].apply(_category_to_list).apply(lambda xs: "; ".join(xs))
else:
    cold_meta["category_text"] = ""

cols_to_use = [c for c in ["name", "category_text", "description", "state"] if c in cold_meta.columns]
if cols_to_use:
    cold_meta["fallback_text"] = (
        cold_meta[cols_to_use]
        .astype(str)
        .replace({"None": "", "nan": ""})
        .agg(" | ".join, axis=1)
        .str.replace(r"\s+", " ", regex=True)
        .str.strip()
    )
else:
    cold_meta["fallback_text"] = ""

cold_item_metadata = cold_meta[["gmap_id", "fallback_text"]].drop_duplicates("gmap_id")


# ---------- save artifacts ----------
(user_reviews_train
 .to_parquet(OUT_DIR/"user_reviews_train.parquet", index=False))
(item_reviews_train
 .to_parquet(OUT_DIR/"item_reviews_train.parquet", index=False))
(places_for_extraction
 .to_parquet(OUT_DIR/"places_for_extraction.parquet", index=False))
(pd.DataFrame({"gmap_id": warm_items})
 .to_parquet(OUT_DIR/"warm_items.parquet", index=False))
(pd.DataFrame({"gmap_id": cold_items})
 .to_parquet(OUT_DIR/"cold_items.parquet", index=False))
cold_item_metadata.to_parquet(OUT_DIR/"cold_item_metadata.parquet", index=False)

# ---------- quick report ----------
print("=== 4k USER SLICE REPORT ===")
print(f"Users sampled:                {len(sampled_users):,}")
print(f"Train interactions (rows):    {len(train_df):,}")
print(f"Test interactions (rows):     {len(test_df):,}")
print(f"Warm items (train-covered):   {len(warm_items):,}  <- PyABSA these")
print(f"Cold items (test-only):       {len(cold_items):,}  <- use metadata fallback")
print("\nSaved:")
for p in [
    "user_reviews_train.parquet",
    "item_reviews_train.parquet",
    "places_for_extraction.parquet",
    "warm_items.parquet",
    "cold_items.parquet",
    "cold_item_metadata.parquet",
]:
    print(f"- {OUT_DIR/p}")


=== 4k USER SLICE REPORT ===
Users sampled:                4,000
Train interactions (rows):    29,601
Test interactions (rows):     9,099
Warm items (train-covered):   19,585  <- PyABSA these
Cold items (test-only):       4,965  <- use metadata fallback

Saved:
- /content/drive/MyDrive/processed/slice_4k/user_reviews_train.parquet
- /content/drive/MyDrive/processed/slice_4k/item_reviews_train.parquet
- /content/drive/MyDrive/processed/slice_4k/places_for_extraction.parquet
- /content/drive/MyDrive/processed/slice_4k/warm_items.parquet
- /content/drive/MyDrive/processed/slice_4k/cold_items.parquet
- /content/drive/MyDrive/processed/slice_4k/cold_item_metadata.parquet


In [None]:
!pip install -qU "pyabsa==2.4.1"  # pin a version that avoids the .config misdetection you hit

import shutil,sys
# === ETA + setup for aspect extraction (no writes) ===
import os, time, math

SLICE_DIR   = Path("/content/drive/MyDrive/processed/slice_4k")
USER_TRAIN  = SLICE_DIR/"user_reviews_train.parquet"
ITEM_TRAIN  = SLICE_DIR/"item_reviews_train.parquet"

# sanity
for p in [USER_TRAIN, ITEM_TRAIN]:
    assert p.exists(), f"missing: {p}"

# load small samples just for calibration
user_sample = pd.read_parquet(USER_TRAIN, columns=["text"]).dropna().sample(n=min(2000,  len(pd.read_parquet(USER_TRAIN, columns=['text']))), random_state=42)["text"].astype(str).tolist()
item_sample = pd.read_parquet(ITEM_TRAIN, columns=["text"]).dropna().sample(n=min(2000,  len(pd.read_parquet(ITEM_TRAIN, columns=['text']))), random_state=42)["text"].astype(str).tolist()

from pyabsa import ATEPCCheckpointManager

# isolate cache (prevents ~/.config quirks)
os.environ.setdefault("PYABSA_HOME", "/content/pyabsa_ckpt")
os.environ.setdefault("PYABSA_VERBOSE", "1")

# model
BATCH_SIZE = 64 #drop to 32 or go upto 96 depending on your compute
extractor = ATEPCCheckpointManager.get_aspect_extractor(
    checkpoint="english", auto_device=True, force_download=False, batch_size=BATCH_SIZE
)

# warm-up
_ = extractor.extract_aspect(user_sample[:128], pred_sentiment=True, print_result=False, save_result=False)

def measure_rps(texts):
    if not texts: return 0.0
    t0 = time.time()
    _ = extractor.extract_aspect(texts, pred_sentiment=True, print_result=False, save_result=False)
    dt = max(time.time() - t0, 1e-9)
    return len(texts)/dt, dt

rps_u, dt_u = measure_rps(user_sample)
rps_i, dt_i = measure_rps(item_sample)

# totals
n_user_rows = int(pd.read_parquet(USER_TRAIN, columns=["text"]).shape[0])
n_item_rows = int(pd.read_parquet(ITEM_TRAIN, columns=["text"]).shape[0])

eta_user_min = n_user_rows / max(rps_u, 1e-9) / 60
eta_item_min = n_item_rows / max(rps_i, 1e-9) / 60

print("=== CALIBRATION ===")
print(f"User sample: {len(user_sample)} in {dt_u:.1f}s -> {rps_u:.1f} rev/s")
print(f"Item sample: {len(item_sample)} in {dt_i:.1f}s -> {rps_i:.1f} rev/s")
print("\n=== ESTIMATED RUNTIMES (train only) ===")
print(f"User reviews: {n_user_rows:,} → ~{eta_user_min:.1f} min")
print(f"Item reviews: {n_item_rows:,} → ~{eta_item_min:.1f} min")
print("\nTip: if memory is tight, drop BATCH_SIZE to 32; if fast, you can try 96.")


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m575.4/575.4 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.1/162.1 kB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.2/54.2 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.3/175.3 kB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-metadata 1.17.2 requires protobuf>=4.25.2; python_version >= "3.11", but you have 

  return datetime.utcnow().replace(tzinfo=utc)
  _EPOCH_DATETIME_NAIVE = datetime.datetime.utcfromtimestamp(0)
  return datetime.utcnow().replace(tzinfo=utc)
  from click.parser import split_arg_string
  from click.parser import split_arg_string


[2025-10-31 03:01:37] (2.4.1) PyABSA(2.4.1): If your code crashes on Colab, please use the GPU runtime. Then run "pip install pyabsa[dev] -U" and restart the kernel.
Or if it does not work, you can use v1.x versions, e.g., pip install pyabsa<2.0 -U




Try to downgrade transformers<=4.29.0.






  self.pid = os.fork()
  _warn(f"unclosed running multiprocessing pool {self!r}",


[2025-10-31 03:01:50] (2.4.1) ********** Available ATEPC model checkpoints for Version:2.4.1 (this version) **********
[2025-10-31 03:01:50] (2.4.1) ********** Available ATEPC model checkpoints for Version:2.4.1 (this version) **********
[2025-10-31 03:01:50] (2.4.1) Downloading checkpoint:english 
[2025-10-31 03:01:50] (2.4.1) Notice: The pretrained model are used for testing, it is recommended to train the model on your own custom datasets


Downloading checkpoint: 579MB [00:04, 129.61MB/s]                         


Find zipped checkpoint: ./checkpoints/ATEPC_ENGLISH_CHECKPOINT/fast_lcf_atepc_English_cdw_apcacc_82.36_apcf1_81.89_atef1_75.43.zip, unzipping
Done.
[2025-10-31 03:02:04] (2.4.1) If the auto-downloading failed, please download it via browser: https://huggingface.co/spaces/yangheng/PyABSA/resolve/main/checkpoints/English/ATEPC/fast_lcf_atepc_English_cdw_apcacc_82.36_apcf1_81.89_atef1_75.43.zip 
[2025-10-31 03:02:04] (2.4.1) Load aspect extractor from checkpoints/ATEPC_ENGLISH_CHECKPOINT/fast_lcf_atepc_English_cdw_apcacc_82.36_apcf1_81.89_atef1_75.43
[2025-10-31 03:02:04] (2.4.1) config: checkpoints/ATEPC_ENGLISH_CHECKPOINT/fast_lcf_atepc_English_cdw_apcacc_82.36_apcf1_81.89_atef1_75.43/fast_lcf_atepc.config
[2025-10-31 03:02:04] (2.4.1) state_dict: checkpoints/ATEPC_ENGLISH_CHECKPOINT/fast_lcf_atepc_English_cdw_apcacc_82.36_apcf1_81.89_atef1_75.43/fast_lcf_atepc.state_dict
[2025-10-31 03:02:04] (2.4.1) model: None
[2025-10-31 03:02:04] (2.4.1) tokenizer: checkpoints/ATEPC_ENGLISH_CHECKPO

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/371M [00:00<?, ?B/s]



tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/371M [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

preparing ate inference dataloader: 100%|██████████| 128/128 [00:00<00:00, 844.99it/s]
extracting aspect terms: 100%|██████████| 4/4 [00:04<00:00,  1.02s/it]
preparing apc inference dataloader: 100%|██████████| 268/268 [00:00<00:00, 449.29it/s]
  lcf_cdm_vec = torch.tensor(
classifying aspect sentiments: 100%|██████████| 9/9 [00:06<00:00,  1.30it/s]
preparing ate inference dataloader: 100%|██████████| 2000/2000 [00:01<00:00, 1023.30it/s]
extracting aspect terms: 100%|██████████| 63/63 [00:50<00:00,  1.24it/s]
preparing apc inference dataloader: 100%|██████████| 3863/3863 [00:07<00:00, 506.21it/s]
classifying aspect sentiments: 100%|██████████| 121/121 [01:45<00:00,  1.15it/s]
preparing ate inference dataloader: 100%|██████████| 2000/2000 [00:02<00:00, 986.52it/s] 
extracting aspect terms: 100%|██████████| 63/63 [00:53<00:00,  1.19it/s]
preparing apc inference dataloader: 100%|██████████| 3906/3906 [00:07<00:00, 526.85it/s]
classifying aspect sentiments: 100%|██████████| 123/123 [01:46<

=== CALIBRATION ===
User sample: 2000 in 167.9s -> 11.9 rev/s
Item sample: 2000 in 171.4s -> 11.7 rev/s

=== ESTIMATED RUNTIMES (train only) ===
User reviews: 29,601 → ~41.4 min
Item reviews: 29,601 → ~42.3 min

Tip: if memory is tight, drop BATCH_SIZE to 32; if fast, you can try 96.


In [None]:
# === Full PyABSA extraction (resumable) for train users + train items ===
import os, json, time, pandas as pd, numpy as np, pyarrow as pa, pyarrow.parquet as pq
from pathlib import Path
from collections.abc import Mapping
from pyabsa import ATEPCCheckpointManager

SLICE_DIR   = Path("/content/drive/MyDrive/processed/slice_4k")
USER_TRAIN  = SLICE_DIR/"user_reviews_train.parquet"
ITEM_TRAIN  = SLICE_DIR/"item_reviews_train.parquet"

OUT_USERS_DIR = SLICE_DIR/"aspects_users_train"
OUT_ITEMS_DIR = SLICE_DIR/"aspects_items_train"
OUT_USERS_DIR.mkdir(parents=True, exist_ok=True)
OUT_ITEMS_DIR.mkdir(parents=True, exist_ok=True)

# You can tweak these if you like
BATCH_SIZE = 64     # leave at 64 based on your calibration
CHUNK      = 4000   # reviews per chunk (2k–8k is a good range)
CKPT_NAME  = "english"
TRUNCATE_CHARS = 0  # set e.g. 800 to hard-truncate long reviews for speed (0 = no truncation)

# --- model init ---
os.environ.setdefault("PYABSA_HOME", "/content/pyabsa_ckpt")
os.environ.setdefault("PYABSA_VERBOSE", "1")
extractor = ATEPCCheckpointManager.get_aspect_extractor(
    checkpoint=CKPT_NAME, auto_device=True, force_download=False, batch_size=BATCH_SIZE
)

def _first_result(res):
    if res is None: return None
    if isinstance(res, dict):
        if "aspect" in res or "sentence" in res: return res
        if "result" in res and isinstance(res["result"], (list,tuple)) and res["result"]: return res["result"][0]
        if "results" in res and isinstance(res["results"], (list,tuple)) and res["results"]: return res["results"][0]
        return None
    if isinstance(res, (list, tuple)) and res:
        return res[0] if isinstance(res[0], Mapping) else None
    try:
        seq = list(res);  return seq[0] if seq and isinstance(seq[0], Mapping) else None
    except Exception:
        return None

def _span(tokens, span):
    if not span: return ""
    i0, i1 = (min(span), max(span)) if len(span)>1 else (span[0], span[0])
    return " ".join(tokens[i0:i1+1])

def _window(tokens, span, pad=6):
    if not span: return ""
    start = max(0, min(span) - pad); end = min(len(tokens)-1, max(span) + pad)
    return " ".join(tokens[start:end+1])

def _clean(s):
    return " ".join(str(s).replace("‘","'").replace("’","'").replace("`","'").split())

def run_extract(in_path: Path, out_dir: Path, id_cols):
    df = pd.read_parquet(in_path)
    if TRUNCATE_CHARS and TRUNCATE_CHARS > 0:
        df["text"] = df["text"].astype(str).str.slice(0, TRUNCATE_CHARS)
    N = len(df)
    print(f"\n[run] {in_path.name}: {N} reviews")
    existing = {p.name for p in out_dir.glob("part-*.parquet")}
    t0_all = time.time()
    processed = 0
    smoothed_rps = None

    for start in range(0, N, CHUNK):
        stop = min(N, start+CHUNK)
        part = f"part-{start:09d}.parquet"
        if part in existing:
            processed += (stop - start)
            print(f"[skip] {part}")
            continue

        batch = df.iloc[start:stop].reset_index(drop=True).copy()
        texts = batch["text"].astype(str).tolist()

        if "time" in batch.columns:
            t = pd.to_datetime(batch["time"], errors="coerce", utc=True)
            batch["time_s"] = (t.view("int64") // 10**9).where(t.notna(), None)  # epoch seconds or None

        texts = batch["text"].astype(str).tolist()

        t0 = time.time()
        results = extractor.extract_aspect(
            inference_source=texts, pred_sentiment=True, print_result=False, save_result=False
        )
        dt = time.time() - t0
        rps = len(texts)/max(dt,1e-9)
        smoothed_rps = rps if smoothed_rps is None else (0.7*smoothed_rps + 0.3*rps)

        rows = []
        for i, res in enumerate(results):
            r = _first_result(res)
            if not r:
                continue
            sentence = r.get("sentence", "")
            tokens   = r.get("tokens", [])
            aspects  = r.get("aspect", [])
            pos      = r.get("position", [])
            sents    = r.get("sentiment", [])
            confs    = r.get("confidence", [])

            L = min(len(aspects), len(pos), len(sents), len(confs) if confs else 10**9)
            for k in range(L):
                row = {
                    "global_row":  start + i,
                    "review_text": sentence,
                    "aspect":      _clean(_span(tokens, pos[k])),
                    "sentiment":   sents[k],                       # Positive/Negative/Neutral
                    "confidence":  (confs[k] if confs else None),
                    "evidence":    _clean(_window(tokens, pos[k], 6)),
                    "position":    json.dumps(pos[k]),
                }
                for col in id_cols:
                    row[col] = batch.iloc[i][col]
                rows.append(row)

        pa_tbl = pa.Table.from_pylist(rows)
        pq.write_table(pa_tbl, out_dir/part)
        processed += len(texts)

        elapsed = time.time() - t0_all
        remaining = N - processed
        eta_min = remaining / max(smoothed_rps or rps, 1e-9) / 60
        print(f"[ok] {part}: {len(rows)} aspect rows | {len(texts)} reviews in {dt:.1f}s "
              f"(~{rps:.1f} r/s, smoothed ~{smoothed_rps:.1f}) | "
              f"{processed}/{N} processed | ETA ~{eta_min:.1f} min")

def concat_parts(out_dir: Path, full_path: Path):
    parts = sorted(out_dir.glob("part-*.parquet"))
    if not parts:
        print(f"[warn] no parts found in {out_dir}")
        return
    writer = None
    total = 0
    for p in parts:
        tbl = pq.read_table(p)
        if writer is None:
            writer = pq.ParquetWriter(full_path, tbl.schema)
        writer.write_table(tbl)
        total += tbl.num_rows
    if writer is not None:
        writer.close()
    print(f"[ok] combined {len(parts)} parts ({total:,} rows) -> {full_path}")

# Users (train)
run_extract(USER_TRAIN, OUT_USERS_DIR, id_cols=["user_id","gmap_id","name","rating","time_s"])
concat_parts(OUT_USERS_DIR, SLICE_DIR/"aspects_users_train_full.parquet")

# Items (train warm items)
run_extract(ITEM_TRAIN, OUT_ITEMS_DIR, id_cols=["gmap_id","name","rating","time_s"])
concat_parts(OUT_ITEMS_DIR, SLICE_DIR/"aspects_items_train_full.parquet")

print("\nArtifacts:")
print("-", SLICE_DIR/"aspects_users_train")
print("-", SLICE_DIR/"aspects_users_train_full.parquet")
print("-", SLICE_DIR/"aspects_items_train")
print("-", SLICE_DIR/"aspects_items_train_full.parquet")


[2025-10-31 03:08:10] (2.4.1) ********** Available ATEPC model checkpoints for Version:2.4.1 (this version) **********
[2025-10-31 03:08:10] (2.4.1) ********** Available ATEPC model checkpoints for Version:2.4.1 (this version) **********
[2025-10-31 03:08:10] (2.4.1) Downloading checkpoint:english 
[2025-10-31 03:08:10] (2.4.1) Notice: The pretrained model are used for testing, it is recommended to train the model on your own custom datasets
[2025-10-31 03:08:10] (2.4.1) Checkpoint already downloaded, skip
[2025-10-31 03:08:11] (2.4.1) Load aspect extractor from checkpoints/ATEPC_ENGLISH_CHECKPOINT/fast_lcf_atepc_English_cdw_apcacc_82.36_apcf1_81.89_atef1_75.43
[2025-10-31 03:08:11] (2.4.1) config: checkpoints/ATEPC_ENGLISH_CHECKPOINT/fast_lcf_atepc_English_cdw_apcacc_82.36_apcf1_81.89_atef1_75.43/fast_lcf_atepc.config
[2025-10-31 03:08:11] (2.4.1) state_dict: checkpoints/ATEPC_ENGLISH_CHECKPOINT/fast_lcf_atepc_English_cdw_apcacc_82.36_apcf1_81.89_atef1_75.43/fast_lcf_atepc.state_dict


  return datetime.utcnow().replace(tzinfo=utc)



[run] user_reviews_train.parquet: 29601 reviews
[skip] part-000000000.parquet
[skip] part-000004000.parquet
[skip] part-000008000.parquet
[skip] part-000012000.parquet
[skip] part-000016000.parquet
[skip] part-000020000.parquet
[skip] part-000024000.parquet
[skip] part-000028000.parquet
[ok] combined 8 parts (58,901 rows) -> /content/drive/MyDrive/processed/slice_4k/aspects_users_train_full.parquet


  return datetime.utcnow().replace(tzinfo=utc)



[run] item_reviews_train.parquet: 29601 reviews
[skip] part-000000000.parquet
[skip] part-000004000.parquet
[skip] part-000008000.parquet
[skip] part-000012000.parquet
[skip] part-000016000.parquet
[skip] part-000020000.parquet
[skip] part-000024000.parquet
[skip] part-000028000.parquet
[ok] combined 8 parts (58,606 rows) -> /content/drive/MyDrive/processed/slice_4k/aspects_items_train_full.parquet

Artifacts:
- /content/drive/MyDrive/processed/slice_4k/aspects_users_train
- /content/drive/MyDrive/processed/slice_4k/aspects_users_train_full.parquet
- /content/drive/MyDrive/processed/slice_4k/aspects_items_train
- /content/drive/MyDrive/processed/slice_4k/aspects_items_train_full.parquet


In [None]:
!pip -q install pyarrow fastparquet

import pandas as pd
from pathlib import Path

SLICE = Path("/content/drive/MyDrive/processed/slice_4k")
U_PATH = SLICE/"aspects_users_train_full.parquet"
I_PATH = SLICE/"aspects_items_train_full.parquet"

U = pd.read_parquet(U_PATH)
I = pd.read_parquet(I_PATH)

pd.set_option("display.max_colwidth", 160)
pd.set_option("display.width", 160)

print("Users rows:", len(U), " | Items rows:", len(I))
print("User cols:", list(U.columns))
print("Item cols:", list(I.columns))


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.2/1.8 MB[0m [31m7.9 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.8/1.8 MB[0m [31m26.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[?25hUsers rows: 58901  | Items rows: 58606
User cols: ['global_row', 'review_text', 'aspect', 'sentiment', 'confidence', 'evidence', 'position', 'user_id', 'gmap_id', 'name', 'rating', 'time_s']
Item cols: ['global_row', 'review_text', 'aspect', 'sentiment', 'confidence', 'evidence', 'position', 'gmap_id', 'name', 'rating', 'time_s']


In [None]:
# pick the user_id / gmap_id with the most extracted aspect rows
top_user = U["user_id"].value_counts().index[0]
top_item = I["gmap_id"].value_counts().index[0]
print("Auto-picked:", {"user_id": top_user, "gmap_id": top_item})

# helper views
def view_user(uid, df=U, k=30):
    cols = [c for c in ["gmap_id","name","rating","time_s","aspect","sentiment","confidence","evidence","position","review_text"] if c in df.columns]
    v = (df[df["user_id"]==uid][cols].reset_index(drop=True))
    print(f"\n=== USER {uid} · {len(v)} aspect-rows ===")
    display(v.head(k))

def view_item(iid, df=I, k=30):
    cols = [c for c in ["name","rating","aspect","sentiment","confidence","evidence","position","review_text"] if c in df.columns]
    v = (df[df["gmap_id"]==iid][cols].reset_index(drop=True))
    print(f"\n=== ITEM {iid} · {len(v)} aspect-rows ===")
    display(v.head(k))

view_user(top_user, k=30)
view_item(top_item, k=30)


Auto-picked: {'user_id': '112610368419579311974', 'gmap_id': '0x8091f1c6cc994511:0xe6a7829a8604dca8'}

=== USER 112610368419579311974 · 463 aspect-rows ===


Unnamed: 0,gmap_id,name,rating,time_s,aspect,sentiment,confidence,evidence,position,review_text
0,0x80d9552cc4c0c6e1:0x1d50ac359a7a629c,Brian P,4,,server,Positive,0.9908,in the restaurant . But the server was amazing . The food is,[8],"Typically empty in the restaurant . But the server was amazing . The food is pretty good , could use a little more seasoning . But not bad for a restaurant ..."
1,0x80d9552cc4c0c6e1:0x1d50ac359a7a629c,Brian P,4,,food,Positive,0.8313,"the server was amazing . The food is pretty good , could use",[13],"Typically empty in the restaurant . But the server was amazing . The food is pretty good , could use a little more seasoning . But not bad for a restaurant ..."
2,0x80d9552cc4c0c6e1:0x1d50ac359a7a629c,Brian P,4,,price,Positive,0.9977,in a hotel . And the price was good .,[37],"Typically empty in the restaurant . But the server was amazing . The food is pretty good , could use a little more seasoning . But not bad for a restaurant ..."
3,0x80dc06d20cf5424d:0xe3be4db207191525,Brian P,5,,Service,Positive,0.9992,Service here was fantastic . The girls,[0],Service here was fantastic . The girls at the counter were super helpful and friendly . \n The cookie was so good that I didn ’ t even get a picture of it ....
4,0x80dc06d20cf5424d:0xe3be4db207191525,Brian P,5,,girls,Positive,0.9992,Service here was fantastic . The girls at the counter were super helpful,[6],Service here was fantastic . The girls at the counter were super helpful and friendly . \n The cookie was so good that I didn ’ t even get a picture of it ....
5,0x80dc06d20cf5424d:0xe3be4db207191525,Brian P,5,,bad,Positive,0.9985,. But it wasn ' t bad . The burger tasted great,[63],Service here was fantastic . The girls at the counter were super helpful and friendly . \n The cookie was so good that I didn ’ t even get a picture of it ....
6,0x80dc06d20cf5424d:0xe3be4db207191525,Brian P,5,,.,Positive,0.9992,definitely give this place a try . Everyone really enjoyed the milkshakes,[93],Service here was fantastic . The girls at the counter were super helpful and friendly . \n The cookie was so good that I didn ’ t even get a picture of it ....
7,0x80dc09254b37149f:0xf4879c79898b68b8,Brian P,5,,service,Neutral,0.4537,"unless people had problem with the service , I ' m not sure",[13],"It ’ s Starbucks . . . unless people had problem with the service , I ’ m not sure what people are complaining about . \n \n I found the staff very friendly..."
8,0x80dc09254b37149f:0xf4879c79898b68b8,Brian P,5,,found,Positive,0.9464,complaining about . I found the staff very friendly . The,[29],"It ’ s Starbucks . . . unless people had problem with the service , I ’ m not sure what people are complaining about . \n \n I found the staff very friendly..."
9,0x80dc01eb38c46dc3:0x5808904c07a7b1f2,Brian P,5,,place,Positive,0.9989,This place is one of a kind for,[1],"This place is one of a kind for Pacific Beach . Fun times , good country music , and lots of line dancing ."



=== ITEM 0x8091f1c6cc994511:0xe6a7829a8604dca8 · 49 aspect-rows ===


Unnamed: 0,name,rating,aspect,sentiment,confidence,evidence,position,review_text
0,Dallin Kimble,4,fruit,Positive,0.9988,"novelty sweets , dried and fresh fruit and a variety of nuts ,",[13],"Better than expected with all kinds of novelty sweets , dried and fresh fruit and a variety of nuts , fudge and popcorn . Good travel stop with an old playg..."
1,Dallin Kimble,4,popcorn,Neutral,0.9248,"variety of nuts , fudge and popcorn . Good travel stop with an",[22],"Better than expected with all kinds of novelty sweets , dried and fresh fruit and a variety of nuts , fudge and popcorn . Good travel stop with an old playg..."
2,Kurt Willmon,4,place,Positive,0.9915,Interesting place with a wide selection of products,[1],Interesting place with a wide selection of products but super high prices . Still a good place to stop though .
3,Kurt Willmon,4,prices,Negative,0.8889,selection of products but super high prices . Still a good place to,[11],Interesting place with a wide selection of products but super high prices . Still a good place to stop though .
4,Ray Kings,5,spot,Positive,0.9994,NICE spot to stop on your way thru,[1],NICE spot to stop on your way thru . Stretch them legs and let the kids enjoy the park . And grab some jalapeno pistachios . . .
5,Ray Kings,5,pistachios,Neutral,0.8309,park . And grab some jalapeno pistachios . . .,[24],NICE spot to stop on your way thru . Stretch them legs and let the kids enjoy the park . And grab some jalapeno pistachios . . .
6,Fatima Zafar,3,food,Neutral,0.5239,Very crowded and food was ok .,[3],Very crowded and food was ok .
7,B-dette Liua,5,place,Positive,0.8725,A quaint place that I love love love .,[2],A quaint place that I love love love .
8,Ken Saylor,5,everything,Neutral,0.4754,Must stop excellent everything .,[3],Must stop excellent everything .
9,Crystal Harper,5,place,Positive,0.8351,This place is awesome !,[1],This place is awesome !


In [None]:
from collections import Counter

def summarize_user(uid, df=U, topn=10):
    sub = df[df["user_id"]==uid]
    print(f"[User {uid}] rows={len(sub)}")
    print("sentiment:", dict(Counter(sub["sentiment"])))
    print("top aspects:", Counter(sub["aspect"]).most_common(topn))

def summarize_item(iid, df=I, topn=10):
    sub = df[df["gmap_id"]==iid]
    print(f"[Item {iid}] rows={len(sub)}")
    print("sentiment:", dict(Counter(sub["sentiment"])))
    print("top aspects:", Counter(sub["aspect"]).most_common(topn))

summarize_user(top_user)
summarize_item(top_item)


[User 112610368419579311974] rows=463
sentiment: {'Positive': 312, 'Neutral': 76, 'Negative': 75}
top aspects: [('food', 58), ('Food', 32), ('service', 24), ('Service', 16), ('place', 16), ('staff', 16), ('the', 10), ('price', 8), ('atmosphere', 8), ('server', 7)]
[Item 0x8091f1c6cc994511:0xe6a7829a8604dca8] rows=49
sentiment: {'Positive': 37, 'Neutral': 8, 'Negative': 4}
top aspects: [('place', 10), ('fruit', 4), ('food', 3), ('popcorn', 1), ('prices', 1), ('spot', 1), ('pistachios', 1), ('everything', 1), ('sandwiches', 1), ('casa', 1)]


In [None]:
!pip -q install sentence-transformers pyarrow fastparquet

import os, math, json, time
import numpy as np
import pandas as pd
import pyarrow as pa, pyarrow.parquet as pq
from pathlib import Path
from sentence_transformers import SentenceTransformer


In [None]:
SLICE = Path("/content/drive/MyDrive/processed/slice_4k")

# PyABSA outputs you already have:
U_ABSA = SLICE/"aspects_users_train_full.parquet"
I_ABSA = SLICE/"aspects_items_train_full.parquet"

# Where to write embeddings (row-level) + final entity representations:
U_EMB_DIR = SLICE/"aspects_users_emb"   ; U_EMB_DIR.mkdir(exist_ok=True, parents=True)
I_EMB_DIR = SLICE/"aspects_items_emb"   ; I_EMB_DIR.mkdir(exist_ok=True, parents=True)
U_EMB_FULL = SLICE/"aspects_users_emb_full.parquet"
I_EMB_FULL = SLICE/"aspects_items_emb_full.parquet"

U_REPR = SLICE/"user_repr.parquet"
I_REPR_WARM = SLICE/"item_repr_warm.parquet"

# encoder + batching
MODEL_NAME   = "all-MiniLM-L6-v2"   # 384-dim, fast & good
BATCH_SIZE   = 512                 # sentence-transformers batch; adjust if VRAM is tight
CHUNK_ROWS   = 25000               # parquet read/write chunking
DROP_NEUTRAL = True                 # usually yes for preference learning
CONF_MIN     = 0.0                  # e.g., 0.3 to drop low-confidence extractions


In [None]:
def _iter_parquet_batches(path, chunk_rows=50_000, cols=None):
    """Stream Parquet in row chunks without exploding RAM."""
    import pyarrow.parquet as pq
    pf = pq.ParquetFile(path)
    nrow = pf.metadata.num_rows
    offset = 0
    while offset < nrow:
        n = min(chunk_rows, nrow - offset)
        tbl = pf.read_row_groups(range(offset//pf.metadata.row_group(0).num_rows,
                                       math.ceil((offset+n)/pf.metadata.row_group(0).num_rows)),
                                 columns=cols)
        # slice to exact rows
        start = offset
        stop  = offset + n
        df = tbl.to_pandas()
        if len(df) > n: df = df.iloc[:n].copy()
        yield start, df
        offset += n

def _clean_aspect(s):
    import re
    s = str(s).lower().strip()
    s = re.sub(r"[^a-z0-9\s&/\-]", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

def _text_for_emb(row):
    ev = str(row.get("evidence", "") or "").strip()
    if ev:
        return ev
    a = str(row.get("aspect", "") or "").strip()
    return a

def _concat_parts(out_dir: Path, full_path: Path):
    parts = sorted(out_dir.glob("part-*.parquet"))
    if not parts:
        print(f"[warn] no parts found in {out_dir}")
        return
    writer = None
    tot = 0
    for p in parts:
        tbl = pq.read_table(p)
        if writer is None:
            writer = pq.ParquetWriter(full_path, tbl.schema)
        writer.write_table(tbl)
        tot += tbl.num_rows
    if writer is not None:
        writer.close()
    print(f"[ok] combined {len(parts)} parts ({tot:,} rows) -> {full_path}")


In [None]:
def encode_absa_to_emb(parquet_in: Path, out_dir: Path, out_full: Path, key_cols: list):
    """
    Reads ABSA rows, filters, encodes text with SBERT, writes rows with `emb` (list[float]).
    Keeps all useful columns for later aggregation.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    model = SentenceTransformer(MODEL_NAME)
    wrote_parts = 0
    total_rows  = 0

    # what columns we try to keep if present
    keep_cols = key_cols + [
        "aspect","sentiment","confidence","evidence","position","review_text",
        "gmap_id","name","rating","time_s","user_id"
    ]
    keep_cols = list(dict.fromkeys(keep_cols))  # dedup

    for start, df in _iter_parquet_batches(parquet_in, chunk_rows=CHUNK_ROWS, cols=None):
        # select/clean
        cols_present = [c for c in keep_cols if c in df.columns]
        df = df[cols_present].copy()
        df["aspect"] = df["aspect"].astype(str)
        df["evidence"] = df.get("evidence", "").astype(str)
        if "confidence" in df.columns:
            df = df[df["confidence"].fillna(1.0) >= CONF_MIN]

        if DROP_NEUTRAL and "sentiment" in df.columns:
            df = df[df["sentiment"].isin(["Positive","Negative"])]

        if df.empty:
            continue

        df["aspect_norm"] = df["aspect"].map(_clean_aspect)
        df["sent_num"] = df["sentiment"].map({"Positive":1, "Negative":-1}).fillna(0).astype(int)
        df["text_for_emb"] = df.apply(_text_for_emb, axis=1)

        # encode in mini-batches to avoid OOM
        texts = df["text_for_emb"].tolist()
        embs  = []
        for b in range(0, len(texts), BATCH_SIZE):
            chunk = texts[b:b+BATCH_SIZE]
            Z = model.encode(chunk, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False)
            embs.append(Z)
        Z = np.vstack(embs)
        df["emb"] = [z.astype("float32").tolist() for z in Z]

        # write part
        part = out_dir/f"part-{start:09d}.parquet"
        pq.write_table(pa.Table.from_pandas(df, preserve_index=False), part)
        wrote_parts += 1
        total_rows  += len(df)
        print(f"[ok] {part.name}: {len(df)} rows (total {total_rows:,})")

    _concat_parts(out_dir, out_full)
    print(f"[done] encoded → {out_full}")


In [None]:
# users: key is user_id
encode_absa_to_emb(U_ABSA, U_EMB_DIR, U_EMB_FULL, key_cols=["user_id"])

# items: key is gmap_id
encode_absa_to_emb(I_ABSA, I_EMB_DIR, I_EMB_FULL, key_cols=["gmap_id"])


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

[ok] part-000000000.parquet: 21570 rows (total 21,570)
[ok] part-000025000.parquet: 21476 rows (total 43,046)
[ok] part-000050000.parquet: 7671 rows (total 50,717)
[ok] combined 3 parts (50,717 rows) -> /content/drive/MyDrive/processed/slice_4k/aspects_users_emb_full.parquet
[done] encoded → /content/drive/MyDrive/processed/slice_4k/aspects_users_emb_full.parquet
[ok] part-000000000.parquet: 21439 rows (total 21,439)
[ok] part-000025000.parquet: 21267 rows (total 42,706)
[ok] part-000050000.parquet: 7443 rows (total 50,149)
[ok] combined 3 parts (50,149 rows) -> /content/drive/MyDrive/processed/slice_4k/aspects_items_emb_full.parquet
[done] encoded → /content/drive/MyDrive/processed/slice_4k/aspects_items_emb_full.parquet


In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import pyarrow as pa, pyarrow.parquet as pq

def _parts_count(dir_path: Path) -> int:
    return len(list(dir_path.glob("part-*.parquet")))

def _emb_dim_from_any_row(df: pd.DataFrame) -> int:
    if "emb" not in df.columns or df.empty: return 0
    # emb is stored as list[float]; grab the first non-null
    for v in df["emb"].head(100):
        if isinstance(v, (list, tuple, np.ndarray)) and len(v):
            return int(len(v))
    return 0

def print_emb_summary(full_parquet: Path, key_col: str, label_col: str = "sentiment", head_n: int = 3):
    cols = [c for c in [key_col, "emb", label_col] if c]  # keep only those that exist
    df = pd.read_parquet(full_parquet, columns=None)  # read all so we can show head nicely
    n_rows = len(df)
    n_keys = df[key_col].nunique() if key_col in df.columns else 0
    dim    = _emb_dim_from_any_row(df)

    print(f"\n=== {full_parquet.name} ===")
    print(f"Rows: {n_rows:,} | Unique {key_col}: {n_keys:,} | Embedding dim: {dim}")
    if label_col in df.columns:
        counts = df[label_col].fillna("NA").value_counts()
        print("Label counts:", dict(counts))
    print("\nSample rows:")
    display(df.head(head_n))

# ---------- PRINT RESULTS ----------
print("Users parts:", _parts_count(U_EMB_DIR))
print("Items parts:", _parts_count(I_EMB_DIR))

print_emb_summary(U_EMB_FULL, key_col="user_id", label_col="sentiment", head_n=5)
print_emb_summary(I_EMB_FULL, key_col="gmap_id", label_col="sentiment", head_n=5)

# Totals (explicit print if you want just one-liners):
u_rows = pq.read_table(U_EMB_FULL, columns=["user_id"]).num_rows
i_rows = pq.read_table(I_EMB_FULL, columns=["gmap_id"]).num_rows
print(f"\nTOTALS — Users rows: {u_rows:,} | Items rows: {i_rows:,}")


Users parts: 3
Items parts: 3

=== aspects_users_emb_full.parquet ===
Rows: 50,717 | Unique user_id: 3,426 | Embedding dim: 384
Label counts: {'Positive': np.int64(41683), 'Negative': np.int64(9034)}

Sample rows:


Unnamed: 0,user_id,aspect,sentiment,confidence,evidence,position,review_text,gmap_id,name,rating,time_s,aspect_norm,sent_num,text_for_emb,emb
0,100000587891567535744,taste,Negative,0.9964,Google ) The food has no taste ( Original ) No tiene sabor,[9],( Translated by Google ) The food has no taste ( Original ) No tiene sabor la comida,0x80dc8208f8ad8269:0x930d0b4346b897e1,Mauricio Amaya,1,,taste,-1,Google ) The food has no taste ( Original ) No tiene sabor,"[-0.008173393085598946, -0.049527186900377274, -0.07560265064239502, 0.040245603770017624, -0.034866008907556534, -0.003338743234053254, -0.0279510375112295..."
1,100000587891567535744,comida,Negative,0.9873,Original ) No tiene sabor la comida,[17],( Translated by Google ) The food has no taste ( Original ) No tiene sabor la comida,0x80dc8208f8ad8269:0x930d0b4346b897e1,Mauricio Amaya,1,,comida,-1,Original ) No tiene sabor la comida,"[-0.04850549250841141, 0.018572920933365822, -0.060179196298122406, -0.04928178712725639, -0.0226199459284544, 0.021158922463655472, 0.07908912003040314, -0..."
2,100004922652291933917,service,Positive,0.9993,Excellent service and everyone that works there is,[1],Excellent service and everyone that works there is super friendly and helpful . First time there will definitely be back . Thank you to the staff .,0x80dd315f7ecfdb09:0x80c63d40f1f66536,Greg Bunton,5,,service,1,Excellent service and everyone that works there is,"[-0.07223241031169891, 0.01844174787402153, 0.0424751453101635, -0.04825015738606453, -0.06795754283666611, -0.011735796928405762, 0.02987959049642086, -0.0..."
3,100004922652291933917,everyone,Positive,0.9994,Excellent service and everyone that works there is super friendly,[3],Excellent service and everyone that works there is super friendly and helpful . First time there will definitely be back . Thank you to the staff .,0x80dd315f7ecfdb09:0x80c63d40f1f66536,Greg Bunton,5,,everyone,1,Excellent service and everyone that works there is super friendly,"[-0.05680256336927414, 0.005493049509823322, 0.03200820833444595, -0.011370634660124779, -0.06678265333175659, -0.0023401747457683086, 0.039099372923374176,..."
4,100004922652291933917,staff,Positive,0.999,back . Thank you to the staff .,[25],Excellent service and everyone that works there is super friendly and helpful . First time there will definitely be back . Thank you to the staff .,0x80dd315f7ecfdb09:0x80c63d40f1f66536,Greg Bunton,5,,staff,1,back . Thank you to the staff .,"[-0.011562777683138847, 0.003409998258575797, -0.009822946973145008, -0.012956133112311363, 0.013000115752220154, 0.036505334079265594, 0.005260028410702944..."



=== aspects_items_emb_full.parquet ===
Rows: 50,149 | Unique gmap_id: 16,307 | Embedding dim: 384
Label counts: {'Positive': np.int64(41483), 'Negative': np.int64(8666)}

Sample rows:


Unnamed: 0,gmap_id,aspect,sentiment,confidence,evidence,position,review_text,name,rating,time_s,aspect_norm,sent_num,text_for_emb,emb
0,0x14e01eae3c43bb3b:0x45792563427359ce,flavor,Positive,0.9776,Unique and authentic flavor and the prices are awesome .,[3],Unique and authentic flavor and the prices are awesome .,Dinorah Adams,5,,flavor,1,Unique and authentic flavor and the prices are awesome .,"[-0.11170550435781479, -0.0137782022356987, 0.002261125948280096, 0.03691209852695465, -0.08656179159879684, 0.029408635571599007, 0.02692347951233387, 0.01..."
1,0x14e01eae3c43bb3b:0x45792563427359ce,prices,Positive,0.9745,Unique and authentic flavor and the prices are awesome .,[6],Unique and authentic flavor and the prices are awesome .,Dinorah Adams,5,,prices,1,Unique and authentic flavor and the prices are awesome .,"[-0.11170550435781479, -0.0137782022356987, 0.002261125948280096, 0.03691209852695465, -0.08656179159879684, 0.029408635571599007, 0.02692347951233387, 0.01..."
2,0x14e1783a55591535:0x57d86a6ec07f7a58,staff,Positive,0.9742,Great staff with great food and scenery .,[1],Great staff with great food and scenery . . .,Dan Fager,5,,staff,1,Great staff with great food and scenery .,"[0.03233940526843071, -0.0008123239385895431, 0.053897272795438766, 0.02881188504397869, -0.11186285316944122, 0.027684025466442108, 0.004093879833817482, -..."
3,0x14e1783a55591535:0x57d86a6ec07f7a58,food,Positive,0.9841,Great staff with great food and scenery . . .,[4],Great staff with great food and scenery . . .,Dan Fager,5,,food,1,Great staff with great food and scenery . . .,"[0.0354946106672287, -9.670048893895e-05, 0.053762584924697876, 0.023796088993549347, -0.10315193980932236, 0.014202915132045746, 0.02044086717069149, -0.09..."
4,0x14e1783a55591535:0x57d86a6ec07f7a58,scenery,Positive,0.9768,Great staff with great food and scenery . . .,[6],Great staff with great food and scenery . . .,Dan Fager,5,,scenery,1,Great staff with great food and scenery . . .,"[0.0354946106672287, -9.670048893895e-05, 0.053762584924697876, 0.023796088993549347, -0.10315193980932236, 0.014202915132045746, 0.02044086717069149, -0.09..."



TOTALS — Users rows: 50,717 | Items rows: 50,149


In [None]:
import numpy as np, pandas as pd
from pathlib import Path
from typing import Tuple

SLICE = Path("/content/drive/MyDrive/processed/slice_4k")
U_EMB_FULL = SLICE/"aspects_users_emb_full.parquet"   # has columns: user_id, emb, sentiment
I_EMB_FULL = SLICE/"aspects_items_emb_full.parquet"   # has columns: gmap_id, emb, sentiment

def _l2norm(v: np.ndarray) -> float:
    n = float(np.linalg.norm(v))
    return n

def _mean_norm(X: np.ndarray) -> np.ndarray:
    """Mean then L2-normalize; if empty, return zeros."""
    if X.size == 0:
        return np.zeros((X.shape[1] if X.ndim==2 else 0,), dtype="float32")
    m = X.mean(axis=0)
    n = np.linalg.norm(m)
    return (m / (n + 1e-9)).astype("float32")

def summarize_entity_prefs(emb_full_path: Path, key_col: str, save_path: Path=None) -> pd.DataFrame:
    """
    For each entity (user or item), compute:
      - n_pos, n_neg
      - x_pos, x_neg (normalized)
      - norms of x_pos/x_neg
      - zero_pos/zero_neg flags
    """
    cols = [key_col, "emb", "sentiment"]
    df = pd.read_parquet(emb_full_path, columns=cols).dropna(subset=[key_col, "emb"])
    # normalize label strings
    df["sentiment"] = df["sentiment"].astype(str).str.lower()

    rows = []
    for ent, g in df.groupby(key_col, sort=False):
        # stack embeddings
        embs = np.vstack(g["emb"].to_numpy()).astype("float32")  # [n,d]
        lbls = g["sentiment"].to_numpy()

        pos = embs[lbls == "positive"]
        neg = embs[lbls == "negative"]

        # Aggregation:
        #   x_pos = mean(pos) normalized
        #   x_neg = mean(-neg) normalized  (flip so dislikes point away)
        x_pos = _mean_norm(pos) if pos.size else np.zeros((embs.shape[1],), dtype="float32")
        x_neg = _mean_norm(-neg) if neg.size else np.zeros((embs.shape[1],), dtype="float32")

        rows.append({
            key_col: ent,
            "n_pos": int(len(pos)),
            "n_neg": int(len(neg)),
            "x_pos_norm": _l2norm(x_pos),
            "x_neg_norm": _l2norm(x_neg),
            "zero_pos":  bool(len(pos) == 0 or _l2norm(x_pos) == 0.0),
            "zero_neg":  bool(len(neg) == 0 or _l2norm(x_neg) == 0.0),
        })

    out = pd.DataFrame(rows)

    # Print summary
    n = len(out)
    zp = out["zero_pos"].mean() if n else 0.0
    zn = out["zero_neg"].mean() if n else 0.0
    print(f"Entities: {n:,}")
    print(f"Zero-pos share: {zp:.3f} | Zero-neg share: {zn:.3f}")
    if n:
        print("x_pos_norm mean/median:", float(out["x_pos_norm"].mean()), float(out["x_pos_norm"].median()))
        print("x_neg_norm mean/median:", float(out["x_neg_norm"].mean()), float(out["x_neg_norm"].median()))
        q = out[["x_pos_norm","x_neg_norm"]].quantile([0.05,0.25,0.5,0.75,0.95])
        print("\nQuantiles:\n", q)

    if save_path is not None:
        out.to_parquet(save_path, index=False)
        print(f"[saved] {save_path}")
    return out

# ---- Run for USERS and ITEMS
user_summary = summarize_entity_prefs(U_EMB_FULL, key_col="user_id",
                                      save_path=SLICE/"user_pref_summary.parquet")
item_summary = summarize_entity_prefs(I_EMB_FULL, key_col="gmap_id",
                                      save_path=SLICE/"item_pref_summary.parquet")

# Quick headline numbers you asked for:
print("\n== Headline ==")
print("Users with zero negatives:", int(user_summary["zero_neg"].sum()),
      "/", len(user_summary), "(%.2f%%)" % (100*user_summary["zero_neg"].mean()))
print("Items with zero negatives:", int(item_summary["zero_neg"].sum()),
      "/", len(item_summary), "(%.2f%%)" % (100*item_summary["zero_neg"].mean()))
print("Avg norms — users: xpos=%.3f xneg=%.3f" %
      (user_summary["x_pos_norm"].mean(), user_summary["x_neg_norm"].mean()))
print("Avg norms — items: xpos=%.3f xneg=%.3f" %
      (item_summary["x_pos_norm"].mean(), item_summary["x_neg_norm"].mean()))


Entities: 3,426
Zero-pos share: 0.044 | Zero-neg share: 0.453
x_pos_norm mean/median: 0.9559252642951928 1.0
x_neg_norm mean/median: 0.547285457399945 0.9999999403953552

Quantiles:
       x_pos_norm  x_neg_norm
0.05         1.0         0.0
0.25         1.0         0.0
0.50         1.0         1.0
0.75         1.0         1.0
0.95         1.0         1.0
[saved] /content/drive/MyDrive/processed/slice_4k/user_pref_summary.parquet
Entities: 16,307
Zero-pos share: 0.091 | Zero-neg share: 0.727
x_pos_norm mean/median: 0.9091800982409547 1.0
x_neg_norm mean/median: 0.27295026503502867 0.0

Quantiles:
       x_pos_norm  x_neg_norm
0.05         0.0         0.0
0.25         1.0         0.0
0.50         1.0         0.0
0.75         1.0         1.0
0.95         1.0         1.0
[saved] /content/drive/MyDrive/processed/slice_4k/item_pref_summary.parquet

== Headline ==
Users with zero negatives: 1551 / 3426 (45.27%)
Items with zero negatives: 11856 / 16307 (72.70%)
Avg norms — users: xpos=0.956 xn

In [None]:
import numpy as np, pandas as pd
from pathlib import Path
import pyarrow.parquet as pq

ALPHA, BETA = 1.0, 0.75

def _l2(v):
    v = np.asarray(v, "float32"); n = np.linalg.norm(v)
    return (v/(n+1e-9)).astype("float32") if n>0 else v

def aggregate_entity_repr_weighted_norm(emb_full_path: Path, key_col: str, out_path: Path):
    schema = pq.read_schema(emb_full_path)
    names = set(schema.names)
    use_cols = [c for c in [key_col, "emb", "sentiment", "confidence"] if c in names]
    df = pd.read_parquet(emb_full_path, columns=use_cols)
    if df.empty:
        pd.DataFrame(columns=[key_col,"x_pos","x_neg","repr"]).to_parquet(out_path, index=False)
        print(f"[warn] empty → {out_path}"); return

    # normalize key as STRING (permanent fix)
    df[key_col] = df[key_col].astype(str).str.strip()

    s = df["sentiment"].astype(str).str.lower().values
    conf = (pd.to_numeric(df.get("confidence", 1.0), errors="coerce")
              .fillna(1.0).astype("float32").values)

    w = np.where(s=="positive", +1.0, np.where(s=="negative", -0.75, +0.10)).astype("float32") * conf
    keys = df[key_col].to_numpy()
    embs = np.vstack(df["emb"].to_numpy()).astype("float32")

    order = np.argsort(keys, kind="mergesort")
    keys, embs, w = keys[order], embs[order], w[order]
    splits = np.flatnonzero(keys[1:] != keys[:-1]) + 1
    starts = np.r_[0, splits]; ends = np.r_[splits, len(keys)]

    rows = []
    for a,b in zip(starts, ends):
        k, V, W = keys[a], embs[a:b], w[a:b]
        pos_mask, neg_mask = W>0, W<0

        if pos_mask.any():
            wp = W[pos_mask][:,None]
            x_pos = _l2((V[pos_mask]*wp).sum(0) / (wp.sum()+1e-9))
        else:
            x_pos = np.zeros(V.shape[1], "float32")

        if neg_mask.any():
            wn = (-W[neg_mask])[:,None]
            x_neg = _l2(((-V[neg_mask])*wn).sum(0) / (wn.sum()+1e-9))
        else:
            x_neg = np.zeros(V.shape[1], "float32")

        r = _l2(ALPHA*x_pos - BETA*x_neg) if (np.linalg.norm(ALPHA*x_pos - BETA*x_neg)>0) else x_pos
        rows.append({key_col:k, "x_pos":x_pos.tolist(), "x_neg":x_neg.tolist(), "repr":r.tolist()})

    pd.DataFrame(rows).to_parquet(out_path, index=False)
    print(f"[done] {len(rows):,} entities → {out_path}")

# Run
SLICE = Path("/content/drive/MyDrive/processed/slice_4k")
aggregate_entity_repr_weighted_norm(SLICE/"aspects_users_emb_full.parquet", "user_id", SLICE/"user_repr.parquet")
aggregate_entity_repr_weighted_norm(SLICE/"aspects_items_emb_full.parquet", "gmap_id", SLICE/"item_repr_warm.parquet")


[done] 3,426 entities → /content/drive/MyDrive/processed/slice_4k/user_repr.parquet
[done] 16,307 entities → /content/drive/MyDrive/processed/slice_4k/item_repr_warm.parquet


In [None]:
import pandas as pd
from pathlib import Path
SLICE = Path("/content/drive/MyDrive/processed/slice_4k")

# Inputs
TRAIN_REV  = SLICE/"user_reviews_train.parquet"            # per-interaction (train)
U_ABSA     = SLICE/"aspects_users_train_full.parquet"      # ABSA rows (users)
I_ABSA     = SLICE/"aspects_items_train_full.parquet"      # ABSA rows (items)
U_EMB_FULL = SLICE/"aspects_users_emb_full.parquet"        # ABSA rows + embeddings
I_EMB_FULL = SLICE/"aspects_items_emb_full.parquet"
U_REPR     = SLICE/"user_repr.parquet"                     # aggregated users
I_REPR     = SLICE/"item_repr_warm.parquet"                # aggregated items

# Train universe
train_users = pd.read_parquet(TRAIN_REV, columns=["user_id"]).dropna()
train_items = pd.read_parquet(TRAIN_REV, columns=["gmap_id"]).dropna()
n_u_train = train_users["user_id"].nunique()
n_i_train = train_items["gmap_id"].nunique()

# ABSA universe (before encoding)
u_absa = pd.read_parquet(U_ABSA, columns=["user_id","sentiment"])
i_absa = pd.read_parquet(I_ABSA, columns=["gmap_id","sentiment"])
n_u_absa = u_absa["user_id"].nunique()
n_i_absa = i_absa["gmap_id"].nunique()

# After encoding (row level with emb)
u_emb = pd.read_parquet(U_EMB_FULL, columns=["user_id","sentiment","emb"])
i_emb = pd.read_parquet(I_EMB_FULL, columns=["gmap_id","sentiment","emb"])
n_u_emb = u_emb["user_id"].nunique()
n_i_emb = i_emb["gmap_id"].nunique()

# After aggregation (entity level)
UR = pd.read_parquet(U_REPR)
IR = pd.read_parquet(I_REPR)
n_u_repr = UR["user_id"].nunique()
n_i_repr = IR["gmap_id"].nunique()

# Why some dropped? (only-Neutral users/items)
u_only_neutral = (u_emb
    .assign(s=u_emb["sentiment"].str.lower())
    .groupby("user_id")["s"].apply(lambda x: set(x)).reset_index())
u_only_neutral = u_only_neutral[u_only_neutral["s"].apply(lambda S: S.issubset({"neutral"}))]["user_id"].nunique()

i_only_neutral = (i_emb
    .assign(s=i_emb["sentiment"].str.lower())
    .groupby("gmap_id")["s"].apply(lambda x: set(x)).reset_index())
i_only_neutral = i_only_neutral[i_only_neutral["s"].apply(lambda S: S.issubset({"neutral"}))]["gmap_id"].nunique()

print("=== Entity count funnel ===")
print(f"Users in TRAIN:                 {n_u_train:,}")
print(f"Users with any ABSA rows:       {n_u_absa:,}")
print(f"Users with encoded rows:        {n_u_emb:,}")
print(f"Users aggregated (repr):        {n_u_repr:,}")
print(f"  of which only-neutral users:  {u_only_neutral:,}")

print()
print(f"Items (warm) in TRAIN:          {n_i_train:,}")
print(f"Items with any ABSA rows:       {n_i_absa:,}")
print(f"Items with encoded rows:        {n_i_emb:,}")
print(f"Items aggregated (repr):        {n_i_repr:,}")
print(f"  of which only-neutral items:  {i_only_neutral:,}")


=== Entity count funnel ===
Users in TRAIN:                 4,000
Users with any ABSA rows:       3,863
Users with encoded rows:        3,426
Users aggregated (repr):        3,426
  of which only-neutral users:  0

Items (warm) in TRAIN:          19,585
Items with any ABSA rows:       18,918
Items with encoded rows:        16,307
Items aggregated (repr):        16,307
  of which only-neutral items:  0


In [None]:
# ================================
# Text-Only Contrastive Recommender
# (ABSA repr -> small projection heads -> InfoNCE)
# ================================
import unicodedata, re, math, json, random
import numpy as np, pandas as pd, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast


from pathlib import Path
SLICE = Path("/content/drive/MyDrive/processed/slice_4k")

# ---------- Config ----------
POS_THRESH        = 3            # rating >= POS_THRESH is positive
BATCH             = 1024         # pairs per step (users/items are just rows from repr tables)
EPOCHS            = 10
LR                = 2e-3
WD                = 1e-4
DIM               = 128          # projection dim
EXTRA_NEGS        = 256          # add this many extra item negatives (random/pop-weighted) per batch
USE_POP_SAMPLING  = True         # sample extra negatives with item popularity
AMP               = True         # mixed precision
SEED              = 7
Ks                = [10, 50]     # eval cutoffs

random.seed(SEED); np.random.default_rng(SEED); torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Helpers ----------
def _norm_id(x):
    if pd.isna(x): return ""
    s = str(x)
    s = unicodedata.normalize("NFKC", s)
    s = s.replace("\u200b","").replace("\u200c","").replace("\u200d","")
    s = re.sub(r"\s+", " ", s).strip()
    return s.lower()

def _l2_rows_np(X: np.ndarray) -> np.ndarray:
    n = np.linalg.norm(X, axis=1, keepdims=True) + 1e-9
    return (X / n).astype("float32")

def _l2_rows_t(X: torch.Tensor, eps=1e-9) -> torch.Tensor:
    return X / (X.norm(dim=1, keepdim=True) + eps)

def _dcg_binary(rel: np.ndarray) -> float:
    if rel.size == 0: return 0.0
    return (rel / np.log2(np.arange(2, 2 + rel.size))).sum()

# ---------- Load text embeddings (ABSA aggregated) ----------
UR = pd.read_parquet(SLICE/"user_repr.parquet",  columns=["user_id","repr"]).dropna()
IR = pd.read_parquet(SLICE/"item_repr_warm.parquet", columns=["gmap_id","repr"]).dropna()

UR["user_id"]  = UR["user_id"].map(_norm_id)
IR["gmap_id"]  = IR["gmap_id"].map(_norm_id)

user_ids = UR["user_id"].to_numpy()
item_ids = IR["gmap_id"].to_numpy()
U_base   = _l2_rows_np(np.vstack(UR["repr"].to_numpy()).astype("float32"))
I_base   = _l2_rows_np(np.vstack(IR["repr"].to_numpy()).astype("float32"))

uid2row  = {u:i for i,u in enumerate(user_ids)}
iid2row  = {i:j for j,i in enumerate(item_ids)}

print(f"[repr] users: {len(user_ids)} | items: {len(item_ids)}")

# ---------- Build positive (user,item) pairs from TRAIN (rating) + (optional) ABSA sign ----------
TR = pd.read_parquet(SLICE/"user_reviews_train.parquet", columns=["user_id","gmap_id","rating"]).dropna()
TR["user_id"] = TR["user_id"].map(_norm_id)
TR["gmap_id"] = TR["gmap_id"].map(_norm_id)
TR["rating"]  = TR["rating"].astype(float)

# (Optional union with ABSA sign) — quick version using already-encoded rows to avoid recompute
# If you want only ratings, comment out the ABSA block below.
UA = pd.read_parquet(SLICE/"aspects_users_train_full.parquet", columns=["user_id","gmap_id","sentiment"]).dropna()
UA["user_id"]   = UA["user_id"].map(_norm_id)
UA["gmap_id"]   = UA["gmap_id"].map(_norm_id)
UA["sentiment"] = UA["sentiment"].astype(str).str.title()
ua2 = UA[UA["sentiment"].isin(["Positive","Negative"])].copy()
ua2["s"] = ua2["sentiment"].map({"Positive": 1, "Negative": -1})
sv  = ua2.groupby(["user_id","gmap_id"])["s"].sum().reset_index()

pos_r  = TR[TR["rating"] >= POS_THRESH][["user_id","gmap_id"]]
pos_s  = sv[sv["s"] > 0][["user_id","gmap_id"]]
pos_df = pd.concat([pos_r, pos_s], ignore_index=True).drop_duplicates()

# keep only pairs where both sides have text repr vectors (can generalize beyond train)
pos_df = pos_df[pos_df["user_id"].isin(uid2row) & pos_df["gmap_id"].isin(iid2row)].copy()

# cap positives per user to avoid domination
MAX_POS_PER_USER = 50
pos_df = (pos_df.groupby("user_id", group_keys=False)
              .apply(lambda g: g.sample(n=min(MAX_POS_PER_USER, len(g)), random_state=SEED))
              .reset_index(drop=True))

# convert to row indices into U_base / I_base
pairs = np.array([(uid2row[u], iid2row[i]) for u,i in zip(pos_df["user_id"], pos_df["gmap_id"])],
                 dtype=np.int64)
print(f"[train pairs] {len(pairs)} | unique users: {len(np.unique(pairs[:,0]))}")

# popularity for extra negative sampling
if USE_POP_SAMPLING:
    item_counts = pos_df["gmap_id"].value_counts()
    pop = np.zeros(len(item_ids), dtype=np.float64)
    for it, c in item_counts.items():
        j = iid2row.get(it, None)
        if j is not None: pop[j] = c
    alpha = 0.75
    pop_prob = (np.power(pop + 1.0, alpha))
    pop_prob = pop_prob / pop_prob.sum()
else:
    pop_prob = None

# ---------- Dataset / Loader ----------
class PairDS(Dataset):
    def __init__(self, pairs: np.ndarray):
        self.pairs = pairs
    def __len__(self): return len(self.pairs)
    def __getitem__(self, idx):
        uix, iix = self.pairs[idx]
        return int(uix), int(iix)

loader = DataLoader(PairDS(pairs), batch_size=BATCH, shuffle=True, drop_last=True, num_workers=0)

# ---------- Projection model (small heads on text embeddings) ----------
class TextProjector(nn.Module):
    def __init__(self, d_in, d_out=128):
        super().__init__()
        # lightweight MLPs (you can make them deeper if needed)
        self.Pu = nn.Sequential(
            nn.Linear(d_in, d_out, bias=False),
        )
        self.Pi = nn.Sequential(
            nn.Linear(d_in, d_out, bias=False),
        )
        # learnable log-temperature
        self.log_tau = nn.Parameter(torch.tensor(math.log(0.07)))

        # init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, U_batch, I_batch):
        u = _l2_rows_t(self.Pu(U_batch))
        v = _l2_rows_t(self.Pi(I_batch))
        return u, v

    def tau(self):  # keep temp sane
        return self.log_tau.exp().clamp(1e-3, 1.0)

# freeze base text vectors as tensors on device
U_base_t = torch.tensor(U_base, dtype=torch.float32, device=device)
I_base_t = torch.tensor(I_base, dtype=torch.float32, device=device)

D_in = U_base.shape[1]
model = TextProjector(d_in=D_in, d_out=DIM).to(device)
opt   = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
ce    = nn.CrossEntropyLoss()
# scaler= torch.cuda.amp.GradScaler(enabled=AMP)
scaler = GradScaler(enabled=AMP)

    # --- helpers must be defined BEFORE the training loop ---
ce = torch.nn.CrossEntropyLoss()

def info_nce(u, v, tau):
        # in-batch negatives (user->item and item->user)
        logits = (u @ v.t()) / tau
        target = torch.arange(logits.size(0), device=logits.device)
        return 0.5 * (ce(logits, target) + ce(logits.t(), target))

def add_extra_negs(u, v_pos, extra_js, tau):
        """
        u: (B,d), v_pos: (B,d) in projected space.
        extra_js: (M,) item indices for negatives (int64 on device).
        """
        if extra_js is None or extra_js.numel() == 0:
            return torch.tensor(0.0, device=u.device)
        # PROJECT the negatives to the same space as v_pos
        v_neg = _l2_rows_t(model.Pi(I_base_t[extra_js]))    # Pi: your item projector (Linear)
        v_bank = torch.cat([v_pos, v_neg], dim=0)           # (B+M, d)
        logits = (u @ v_bank.t()) / tau                     # (B, B+M)
        target = torch.arange(u.size(0), device=u.device)   # first B are positives
        return ce(logits, target)



# ---------- Train ----------
for ep in range(1, EPOCHS+1):
    model.train()
    running = 0.0
    for uix, iix in loader:
        uix = uix.to(device, non_blocking=True)
        iix = iix.to(device, non_blocking=True)
        # sample extra item negatives once per step
        if EXTRA_NEGS > 0:
            if pop_prob is not None:
                extra_js_np = np.random.choice(len(item_ids), size=EXTRA_NEGS, replace=False, p=pop_prob)
            else:
                extra_js_np = np.random.choice(len(item_ids), size=EXTRA_NEGS, replace=False)
            extra_js = torch.tensor(extra_js_np, dtype=torch.long, device=device)
        else:
            extra_js = None

        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=AMP):
            u_in = U_base_t[uix]   # (B, d_in)  text-only user repr
            v_in = I_base_t[iix]   # (B, d_in)  text-only item repr
            u, v = model(u_in, v_in)
            tau  = model.tau()

            loss_core  = info_nce(u, v, tau)
            loss_extra = add_extra_negs(u, v, extra_js, tau)
            loss = loss_core + loss_extra

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt); scaler.update()
        running += float(loss.detach())

    print(f"epoch {ep:02d} | loss {running/max(1,len(loader)):.4f}  (tau={model.tau().item():.4f})")

# ---------- Export projected embeddings ----------
with torch.no_grad():
    ZU = _l2_rows_t(model.Pu(U_base_t)).cpu().numpy().astype("float32")
    ZI = _l2_rows_t(model.Pi(I_base_t)).cpu().numpy().astype("float32")

# ---------- Full-catalog Eval (text-only; works for warm & cold as long as repr exists) ----------
TEST_REV = SLICE/"user_reviews_test.parquet"
if not TEST_REV.exists():
    print("[eval] No TEST split parquet found; skipped eval.")
else:
    test = pd.read_parquet(TEST_REV, columns=["user_id","gmap_id","rating"]).dropna()
    test["user_id"] = test["user_id"].map(_norm_id)
    test["gmap_id"] = test["gmap_id"].map(_norm_id)

    # we can evaluate ANY user/item that has a text repr (not just trained ones)
    test = test[test["user_id"].isin(uid2row) & test["gmap_id"].isin(iid2row)].copy()
    test_pos = (test[test["rating"] >= POS_THRESH]
                .groupby("user_id")["gmap_id"].apply(set).to_dict())

    eval_users = [u for u,s in test_pos.items() if len(s)>0]
    print(f"[eval] users with ≥1 positive (and repr present): {len(eval_users)}")

    # batched scoring
    ZU_eval = _l2_rows_np(ZU); ZI_eval = _l2_rows_np(ZI)
    maxK = min(max(Ks), ZI_eval.shape[0])

    hits = {k:0 for k in Ks}; rec={k:0.0 for k in Ks}; ndcg={k:0.0 for k in Ks}; mrr=0.0
    for uid in eval_users:
        uix = uid2row[uid]
        s   = (ZU_eval[uix:uix+1] @ ZI_eval.T)[0]  # (I,)
        top_idx    = np.argpartition(-s, maxK-1)[:maxK]
        top_sorted = top_idx[np.argsort(-s[top_idx])]
        top_items  = [item_ids[j] for j in top_sorted]

        pos_set = test_pos[uid]
        rel = np.fromiter((1 if it in pos_set else 0 for it in top_items), dtype=np.float32, count=len(top_items))

        for K in Ks:
            Kc = min(K, len(top_items))
            rK = rel[:Kc]
            hits[K] += int(rK.sum() > 0)
            rec[K]  += rK.sum() / max(1, len(pos_set))
            if rK.sum() > 0:
                gains = _dcg_binary(rK)
                ideal_len = min(len(pos_set), Kc)
                idcg = _dcg_binary(np.ones(ideal_len, dtype=np.float32))
                ndcg[K] += gains / max(1e-9, idcg)

        ones = np.where(rel == 1)[0]
        if len(ones): mrr += 1.0 / (ones[0] + 1)

    N = max(1, len(eval_users))
    print("\n=== TEXT-ONLY FULL-CATALOG EVAL ===")
    print("evaluated users:", len(eval_users))
    for K in Ks:
        print(f"Hit@{K}: {hits[K]/N:.4f} | Recall@{K}: {rec[K]/N:.4f} | NDCG@{K}: {ndcg[K]/N:.4f}")
    print(f"MRR: {mrr/N:.4f}")

# ---------- (Optional) Save for downstream reranker ----------
# np.save(SLICE/'ZU_textcl.npy', ZU); np.save(SLICE/'ZI_textcl.npy', ZI)
# pd.DataFrame({'user_id': user_ids}).to_parquet(SLICE/'user_ids_textcl.parquet', index=False)
# pd.DataFrame({'gmap_id': item_ids}).to_parquet(SLICE/'item_ids_textcl.parquet', index=False)


[repr] users: 3426 | items: 16307


  .apply(lambda g: g.sample(n=min(MAX_POS_PER_USER, len(g)), random_state=SEED))
  scaler = GradScaler(enabled=AMP)
  with torch.cuda.amp.autocast(enabled=AMP):


[train pairs] 10588 | unique users: 2616
epoch 01 | loss 14.5696  (tau=0.0708)
epoch 02 | loss 13.6559  (tau=0.0712)
epoch 03 | loss 13.2293  (tau=0.0712)
epoch 04 | loss 12.8644  (tau=0.0710)
epoch 05 | loss 12.5109  (tau=0.0706)
epoch 06 | loss 12.2189  (tau=0.0698)
epoch 07 | loss 11.9581  (tau=0.0687)
epoch 08 | loss 11.7094  (tau=0.0674)
epoch 09 | loss 11.4762  (tau=0.0660)
epoch 10 | loss 11.2622  (tau=0.0645)
[eval] users with ≥1 positive (and repr present): 2

=== TEXT-ONLY FULL-CATALOG EVAL ===
evaluated users: 2
Hit@10: 0.0000 | Recall@10: 0.0000 | NDCG@10: 0.0000
Hit@50: 0.5000 | Recall@50: 0.5000 | NDCG@50: 0.1351
MRR: 0.0417


In [None]:
import pandas as pd
from pathlib import Path
SLICE = Path("/content/drive/MyDrive/processed/slice_4k")
U_ABSA = pd.read_parquet(SLICE/"aspects_users_train_full.parquet", columns=["user_id","sentiment"])
I_ABSA = pd.read_parquet(SLICE/"aspects_items_train_full.parquet", columns=["gmap_id","sentiment"])

def only_neutral_count(df, idcol):
    g = (df.assign(s=df["sentiment"].astype(str).str.lower())
           .groupby(idcol)["s"].apply(lambda x: set(x)))
    return (g.apply(lambda S: S and S.issubset({"neutral"}))).sum()

print("users with only-neutral (ABSA):", only_neutral_count(U_ABSA, "user_id"))
print("items with only-neutral (ABSA):", only_neutral_count(I_ABSA, "gmap_id"))


users with only-neutral (ABSA): 101
items with only-neutral (ABSA): 1076


In [None]:
import numpy as np, pandas as pd
from pathlib import Path

SLICE = Path("/content/drive/MyDrive/processed/slice_4k")
U_REPR       = SLICE/"user_repr.parquet"          # columns: user_id, repr
I_REPR_WARM  = SLICE/"item_repr_warm.parquet"     # columns: gmap_id, repr
I_REPR_COLD  = SLICE/"item_repr_cold.parquet"     # (optional)

UR  = pd.read_parquet(U_REPR, columns=["user_id","repr"])
IRw = pd.read_parquet(I_REPR_WARM, columns=["gmap_id","repr"])
IRc = pd.read_parquet(I_REPR_COLD, columns=["gmap_id","repr"]) if I_REPR_COLD.exists() else pd.DataFrame(columns=["gmap_id","repr"])

def to_id_vecs(df, id_col):
    ids = df[id_col].to_numpy()
    vecs = np.vstack(df["repr"].to_numpy()).astype("float32")
    # L2 just in case
    n = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-9
    return ids, (vecs/n)

user_ids, user_vecs = to_id_vecs(UR,  "user_id")
item_ids_w, item_vecs_w = to_id_vecs(IRw, "gmap_id")

if not IRc.empty:
    item_ids_c, item_vecs_c = to_id_vecs(IRc, "gmap_id")
    item_ids_all  = np.concatenate([item_ids_w, item_ids_c], axis=0)
    item_vecs_all = np.vstack([item_vecs_w, item_vecs_c])
else:
    item_ids_all, item_vecs_all = item_ids_w, item_vecs_w

print("users:", len(user_ids), " | warm items:", len(item_ids_w), " | cold items:", (0 if IRc.empty else len(item_ids_c)))
D = user_vecs.shape[1]; print("embedding dim:", D)


users: 3426  | warm items: 16307  | cold items: 4995
embedding dim: 384


In [None]:
import numpy as np, pandas as pd
from pathlib import Path
from sklearn.neighbors import NearestNeighbors

SLICE = Path("/content/drive/MyDrive/processed/slice_4k")
TRAIN_REV = SLICE/"user_reviews_train.parquet"       # user_id, gmap_id, rating
U_ABSA    = SLICE/"aspects_users_train_full.parquet" # user_id, gmap_id, sentiment

UR = pd.read_parquet(SLICE/"user_repr.parquet", columns=["user_id"])
IRw= pd.read_parquet(SLICE/"item_repr_warm.parquet", columns=["gmap_id"])
u_set = set(UR["user_id"].to_numpy()); i_set = set(IRw["gmap_id"].to_numpy())

TR = pd.read_parquet(TRAIN_REV, columns=["user_id","gmap_id","rating"]).dropna()
TR["rating"] = TR["rating"].astype(float)
TR = TR[TR["user_id"].isin(u_set) & TR["gmap_id"].isin(i_set)]  # use only entities with repr

# ABSA net sentiment per (u,i)
UA = pd.read_parquet(U_ABSA, columns=["user_id","gmap_id","sentiment"])
sv = (UA[UA["sentiment"].isin(["Positive","Negative"])]
        .assign(s=UA["sentiment"].map({"Positive":1,"Negative":-1}))
        .groupby(["user_id","gmap_id"])["s"].sum().reset_index())

pos_r = TR[TR["rating"]>=3][["user_id","gmap_id"]]
pos_s = sv[sv["s"]>0][["user_id","gmap_id"]]
pos_df = pd.concat([pos_r, pos_s], ignore_index=True).drop_duplicates()
pos_df = pos_df[pos_df["user_id"].isin(u_set) & pos_df["gmap_id"].isin(i_set)]

# cap max positives per user (prevents over-represented users)
MAX_POS_PER_USER = 50
pos_pairs = (pos_df.groupby("user_id")
             .apply(lambda g: g.sample(n=min(MAX_POS_PER_USER, len(g)), random_state=42))
             .reset_index(drop=True))

# maps
uid2ix = {u:i for i,u in enumerate(user_ids)}
iid2ix_w = {i:j for j,i in enumerate(item_ids_w)}

# training pairs as indices
pairs = []
for u,i in zip(pos_pairs["user_id"].to_numpy(), pos_pairs["gmap_id"].to_numpy()):
    uix = uid2ix.get(u, None); iix = iid2ix_w.get(i, None)
    if uix is not None and iix is not None:
        pairs.append((uix, iix))
pairs = np.array(pairs, dtype=np.int64)
print("train pairs:", len(pairs), " | unique users:", len(np.unique(pairs[:,0])))

# Hard negatives (precompute neighbors among warm items)
nbrs = NearestNeighbors(n_neighbors=50, metric="cosine").fit(item_vecs_w)

# Helper to get hard negatives for a given pos item idx, excluding a mask of seen idxs
def hard_neg_indices(pos_idx, seen_mask, k=2):
    dists, idxs = nbrs.kneighbors(item_vecs_w[pos_idx:pos_idx+1], n_neighbors=50, return_distance=True)
    cand = [j for j in idxs[0] if (j != pos_idx) and (not seen_mask[j])]
    return cand[:k]


train pairs: 10588  | unique users: 2616


  .apply(lambda g: g.sample(n=min(MAX_POS_PER_USER, len(g)), random_state=42))


In [None]:
# --- imports ---
import numpy as np, pandas as pd
from pathlib import Path
from sklearn.neighbors import NearestNeighbors

SLICE     = Path("/content/drive/MyDrive/processed/slice_4k")
TRAIN_REV = SLICE/"user_reviews_train.parquet"        # user_id, gmap_id, rating
U_ABSA    = SLICE/"aspects_users_train_full.parquet"  # user_id, gmap_id, sentiment

# Learned representations (from your weighted aggregator):
U_REPR    = SLICE/"user_repr.parquet"                 # columns: user_id, x_pos, x_neg, repr
I_REPR_W  = SLICE/"item_repr_warm.parquet"            # columns: gmap_id, x_pos, x_neg, repr

# ---------- 1) Load reps and normalize IDs ----------
UR  = pd.read_parquet(U_REPR,  columns=["user_id","repr"]).dropna()
IRw = pd.read_parquet(I_REPR_W, columns=["gmap_id","repr"]).dropna()

UR["user_id"]   = UR["user_id"].astype(str).str.strip()
IRw["gmap_id"]  = IRw["gmap_id"].astype(str).str.strip()

# numpy pools (L2-normalize to be safe)
def _as_mat(col):
    M = np.vstack(col.to_numpy()).astype("float32")
    n = np.linalg.norm(M, axis=1, keepdims=True)
    n[n==0] = 1.0
    return (M / n).astype("float32")

user_ids   = UR["user_id"].to_numpy()
item_ids_w = IRw["gmap_id"].to_numpy()
user_vecs  = _as_mat(UR["repr"])
item_vecs_w= _as_mat(IRw["repr"])

uid2ix     = {u:i for i,u in enumerate(user_ids)}
iid2ix_w   = {i:j for j,i in enumerate(item_ids_w)}

u_set = set(user_ids.tolist())
i_set = set(item_ids_w.tolist())

print("Pools → users:", len(user_ids), " | warm items:", len(item_ids_w))

# ---------- 2) Load train interactions + ABSA; normalize IDs ----------
TR = pd.read_parquet(TRAIN_REV, columns=["user_id","gmap_id","rating"]).dropna()
TR["user_id"]  = TR["user_id"].astype(str).str.strip()
TR["gmap_id"]  = TR["gmap_id"].astype(str).str.strip()
TR["rating"]   = TR["rating"].astype(float)

UA = pd.read_parquet(U_ABSA, columns=["user_id","gmap_id","sentiment"]).dropna()
UA["user_id"]  = UA["user_id"].astype(str).str.strip()
UA["gmap_id"]  = UA["gmap_id"].astype(str).str.strip()
UA["sentiment"]= UA["sentiment"].astype(str).str.title()

# keep only entities we have vectors for
TR = TR[TR["user_id"].isin(u_set) & TR["gmap_id"].isin(i_set)].copy()
UA = UA[UA["user_id"].isin(u_set) & UA["gmap_id"].isin(i_set)].copy()

# ---------- 3) ABSA net sentiment per (u,i) ----------
ua2 = UA[UA["sentiment"].isin(["Positive","Negative"])].copy()
ua2["s"] = ua2["sentiment"].map({"Positive": 1, "Negative": -1})
sv = ua2.groupby(["user_id","gmap_id"])["s"].sum().reset_index()

# Positives: rating >= 3  UNION  ABSA sum > 0
pos_r = TR[TR["rating"] >= 3][["user_id","gmap_id"]]
pos_s = sv[sv["s"] > 0][["user_id","gmap_id"]]
pos_df = pd.concat([pos_r, pos_s], ignore_index=True).drop_duplicates()

# ---------- 4) Cap positives per user (avoid over-represented users) ----------
MAX_POS_PER_USER = 50
pos_pairs = (pos_df.groupby("user_id", group_keys=False)
             .apply(lambda g: g.sample(n=min(MAX_POS_PER_USER, len(g)), random_state=42))
             .reset_index(drop=True))

# Convert to index pairs (u_idx, i_idx)
pairs = []
for u,i in zip(pos_pairs["user_id"].to_numpy(), pos_pairs["gmap_id"].to_numpy()):
    uix = uid2ix.get(u); iix = iid2ix_w.get(i)
    if (uix is not None) and (iix is not None):
        pairs.append((uix, iix))
pairs = np.array(pairs, dtype=np.int64)
print("train pairs:", len(pairs), " | unique users:", len(np.unique(pairs[:,0])))

# ---------- 5) Observed items per user (for negative sampling) ----------
obs_df = TR.drop_duplicates(["user_id","gmap_id"])
obs_by_u = {u: set(g["gmap_id"]) for u, g in obs_df.groupby("user_id")}

# ---------- 6) Hard-negatives via item-item cosine neighbors ----------
# NearestNeighbors with metric='cosine' returns cosine *distance* (1 - cos_sim) — that’s fine.
nbrs = NearestNeighbors(metric="cosine", n_neighbors=50, algorithm="auto").fit(item_vecs_w)

def hard_neg_indices(pos_idx: int, seen_idx_set: set[int], k: int = 2):
    """
    Return up to k hard-negative item indices for a given positive item index,
    excluding itself and any 'seen' item indices.
    """
    dists, idxs = nbrs.kneighbors(item_vecs_w[pos_idx:pos_idx+1], n_neighbors=50, return_distance=True)
    cand = []
    for j in idxs[0]:
        if j == pos_idx:        # skip itself
            continue
        if j in seen_idx_set:   # skip items the user has seen
            continue
        cand.append(j)
        if len(cand) >= k:
            break
    return cand

# Example: build one user’s seen index set & draw hard negatives for their first positive
if len(pairs) > 0:
    sample_uix, sample_pix = pairs[0]
    sample_uid = user_ids[sample_uix]
    seen_ids   = obs_by_u.get(sample_uid, set())
    seen_idx   = { iid2ix_w[i] for i in seen_ids if i in iid2ix_w }

    hn = hard_neg_indices(sample_pix, seen_idx, k=2)
    print("Example hard negs for first pair:", hn)


Pools → users: 3426  | warm items: 16307
train pairs: 10588  | unique users: 2616
Example hard negs for first pair: [np.int64(11949), np.int64(5151)]


  .apply(lambda g: g.sample(n=min(MAX_POS_PER_USER, len(g)), random_state=42))


In [None]:
# ==============================
# Proper Contrastive User–Item Trainer (Two-Tower, InfoNCE)
# ==============================
import math, unicodedata, re, json, random, numpy as np, pandas as pd
from dataclasses import dataclass
from collections import defaultdict, Counter

import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rng = np.random.default_rng(7)
random.seed(7)
torch.manual_seed(7)

# ---------- utils ----------
def _norm_id(x):
    if pd.isna(x): return ""
    s = str(x)
    s = unicodedata.normalize("NFKC", s)
    s = s.replace("\u200b","").replace("\u200c","").replace("\u200d","")
    s = re.sub(r"\s+", " ", s).strip()
    return s.lower()

def l2_rows(X, eps=1e-9):
    return X / (X.norm(dim=1, keepdim=True) + eps)

# ---------- config ----------
@dataclass
class CFG:
    pos_thresh: int = 3                 # rating >= pos_thresh is positive
    dim: int = 128
    batch_users: int = 1024             # users per batch
    extra_negs: int = 256               # sampled negatives per batch (added to in-batch pool)
    lr: float = 2e-3
    wd: float = 1e-4
    epochs: int = 12
    warmup: int = 1                     # warmup epochs
    max_grad_norm: float = 1.0
    item_item_lambda: float = 0.1       # set 0 to turn off auxiliary item-item contrastive
    amp: bool = True

cfg = CFG()

# ---------- load TRAIN ----------
train_df = pd.read_parquet(SLICE/"user_reviews_train.parquet", columns=["user_id","gmap_id","rating"]).dropna()
train_df["user_id"] = train_df["user_id"].map(_norm_id)
train_df["gmap_id"] = train_df["gmap_id"].map(_norm_id)
train_pos = train_df[train_df["rating"] >= cfg.pos_thresh].copy()

# maps
u2ix, i2ix = {}, {}
for u in train_pos["user_id"]:
    if u and u not in u2ix: u2ix[u] = len(u2ix)
for it in train_pos["gmap_id"]:
    if it and it not in i2ix: i2ix[it] = len(i2ix)

n_users, n_items = len(u2ix), len(i2ix)
print(f"n_users: {n_users}  n_items: {n_items}")

# per-user positives
user_pos = defaultdict(list)
for r in train_pos.itertuples(index=False):
    u = u2ix.get(r.user_id); it = i2ix.get(r.gmap_id)
    if u is not None and it is not None:
        user_pos[u].append(it)

# item popularity (for pop-weighted negative sampling)
item_counts = Counter([i for _, items in user_pos.items() for i in items])
pop = np.zeros(n_items, dtype=np.float64)
for j, c in item_counts.items(): pop[j] = c
# log-smoothed popularity distribution
alpha = 0.75
pop_prob = (np.power(pop + 1.0, alpha))
pop_prob = pop_prob / pop_prob.sum()

# ---------- dataset ----------
class UserToOnePosDS(Dataset):
    """
    Yields (u_idx, pos_item_idx). If user has multiple positives, pick one at random each __getitem__,
    so across epochs the model sees different positives per user.
    """
    def __init__(self, user_pos_map, n_users):
        self.user_pos_map = user_pos_map
        self.users = [u for u in range(n_users) if len(user_pos_map.get(u, [])) > 0]

    def __len__(self): return len(self.users)

    def __getitem__(self, idx):
        u = self.users[idx]
        pos_list = self.user_pos_map[u]
        i_pos = random.choice(pos_list) if len(pos_list) > 1 else pos_list[0]
        return u, i_pos

dataset = UserToOnePosDS(user_pos, n_users)
loader  = DataLoader(dataset, batch_size=cfg.batch_users, shuffle=True, drop_last=True, num_workers=0)

# ---------- model ----------
class TwoTower(nn.Module):
    def __init__(self, n_users, n_items, dim=128):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, dim)
        self.item_emb = nn.Embedding(n_items, dim)
        nn.init.normal_(self.user_emb.weight, std=0.02)
        nn.init.normal_(self.item_emb.weight, std=0.02)
        self.log_tau = nn.Parameter(torch.tensor(math.log(0.07)))  # learnable temp

    def forward(self, u_idx, i_idx):
        u = l2_rows(self.user_emb(u_idx))
        v = l2_rows(self.item_emb(i_idx))
        return u, v

    def all_user_vectors(self): return l2_rows(self.user_emb.weight).detach()
    def all_item_vectors(self): return l2_rows(self.item_emb.weight).detach()
    def tau(self): return self.log_tau.exp().clamp(1e-3, 1.0)

model = TwoTower(n_users, n_items, cfg.dim).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
sched = CosineAnnealingLR(opt, T_max=max(1, cfg.epochs - cfg.warmup))

ce = nn.CrossEntropyLoss()

# ---------- losses ----------
def info_nce_user_item(u, v, tau):
    # in-batch negatives
    logits = (u @ v.t()) / tau
    target = torch.arange(logits.size(0), device=logits.device)
    loss_uv = ce(logits, target)
    loss_vu = ce(logits.t(), target)
    return 0.5 * (loss_uv + loss_vu)

def add_sampled_negatives(u, v_pos, extra_neg_idx, item_table, tau):
    """
    Adds extra negatives sampled by popularity to v-bank, returns an augmented loss.
    u: (B,d), v_pos: (B,d), extra_neg_idx: (M,), item_table: Embedding
    """
    if extra_neg_idx is None or extra_neg_idx.numel() == 0:
        return torch.tensor(0.0, device=u.device)
    v_neg = l2_rows(item_table(extra_neg_idx))          # (M,d)
    # user→[pos|negs]
    v_bank = torch.cat([v_pos, v_neg], dim=0)           # (B+M,d)
    logits = (u @ v_bank.t()) / tau                     # (B, B+M)
    target = torch.arange(u.size(0), device=u.device)   # match first B as positives
    return ce(logits, target)

def item_item_aux(v_pos, pos_indices, user_pos_map, item_table, tau, samples=1):
    """
    Optional stabilization: for each anchor positive item, pull it toward another
    positive item from the same user and push away others in batch (SimCLR-ish).
    """
    if cfg.item_item_lambda <= 0:
        return torch.tensor(0.0, device=v_pos.device)
    B = v_pos.size(0)
    # find second positives per user if available
    mates = []
    for u in pos_indices:  # 'pos_indices' here is the user index for each row in batch
        items = user_pos_map.get(int(u), [])
        if len(items) > 1:
            mates.append(random.choice(items))
        else:
            mates.append(None)
    # build mask for those with mates
    pairs = [(i, j) for i, j in enumerate(mates) if j is not None]
    if not pairs: return torch.tensor(0.0, device=v_pos.device)

    mate_idx = torch.tensor([j for _, j in pairs], device=v_pos.device, dtype=torch.long)
    v_mate = l2_rows(item_table(mate_idx))      # (R,d)
    v_anchor = v_pos[[i for i,_ in pairs]]      # (R,d)

    # contrast against all items in-batch (pos) as negatives
    logits = (v_anchor @ v_pos.t()) / tau       # (R,B)
    target = torch.tensor([i for i,_ in pairs], device=v_pos.device, dtype=torch.long)
    return ce(logits, target)

# ---------- training ----------
scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

for epoch in range(1, cfg.epochs+1):
    model.train()
    running = 0.0
    # simple warmup: flat LR for warmup epochs, then cosine
    if epoch > cfg.warmup: sched.step()

    for u_idx, i_pos in loader:
        u_idx = u_idx.to(device, non_blocking=True)
        i_pos = i_pos.to(device, non_blocking=True)

        # sample extra negatives (global, popularity-weighted)
        if cfg.extra_negs > 0:
            extra_js = torch.from_numpy(
                rng.choice(n_items, size=cfg.extra_negs, replace=False, p=pop_prob)
            ).to(device)
        else:
            extra_js = None

        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=cfg.amp):
            u, v_pos = model(u_idx, i_pos)
            tau = model.tau()

            # core symmetric InfoNCE with in-batch negatives
            loss_core = info_nce_user_item(u, v_pos, tau)

            # add extra sampled negatives (user→item)
            loss_extra = add_sampled_negatives(u, v_pos, extra_js, model.item_emb, tau)

            # optional item-item auxiliary
            loss_i2i = item_item_aux(v_pos, u_idx, user_pos, model.item_emb, tau)

            loss = loss_core + loss_extra + cfg.item_item_lambda * loss_i2i

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(opt); scaler.update()

        running += float(loss.detach())

    avg = running / max(1, len(loader))
    print(f"epoch {epoch:02d} | total_loss {avg:.4f}  (tau={model.tau().item():.4f})")

# ---------- export embeddings + maps for your eval ----------
with torch.no_grad():
    UZ = model.all_user_vectors().cpu().numpy()
    IZ = model.all_item_vectors().cpu().numpy()

user_ids_all = [None]*n_users
for u, ix in u2ix.items(): user_ids_all[ix] = u
item_ids_all = [None]*n_items
for it, j in i2ix.items(): item_ids_all[j] = it

# (optional) persist
# np.save(SLICE/"UZ.npy", UZ); np.save(SLICE/"IZ.npy", IZ)
# with open(SLICE/"u2ix.json","w") as f: json.dump(u2ix, f)
# with open(SLICE/"i2ix.json","w") as f: json.dump(i2ix, f)
print("Exported UZ/IZ + id lists — ready for eval.")


n_users: 3860  n_items: 18140
epoch 01 | total_loss 15.7902  (tau=0.0700)


  scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)
  with torch.cuda.amp.autocast(enabled=cfg.amp):


epoch 02 | total_loss 15.7618  (tau=0.0701)
epoch 03 | total_loss 15.0316  (tau=0.0705)
epoch 04 | total_loss 12.8721  (tau=0.0708)
epoch 05 | total_loss 11.1598  (tau=0.0710)
epoch 06 | total_loss 9.9155  (tau=0.0710)
epoch 07 | total_loss 9.0407  (tau=0.0709)
epoch 08 | total_loss 8.4568  (tau=0.0709)
epoch 09 | total_loss 8.2062  (tau=0.0708)
epoch 10 | total_loss 8.1640  (tau=0.0708)
epoch 11 | total_loss 7.7543  (tau=0.0708)
epoch 12 | total_loss 8.0199  (tau=0.0708)
Exported UZ/IZ + id lists — ready for eval.


In [None]:
import numpy as np

# --- Coverage diagnostics for TEST (why only 5 users?) ---
TEST_REV = SLICE / "user_reviews_test.parquet"
if TEST_REV.exists():
    t = pd.read_parquet(TEST_REV, columns=["user_id","gmap_id","rating"]).dropna()
    t["user_id"] = t["user_id"].map(_norm_id)
    t["gmap_id"] = t["gmap_id"].map(_norm_id)

    t["u_in"] = t["user_id"].isin(u2ix)   # user known to the trained model
    t["i_in"] = t["gmap_id"].isin(i2ix)   # item known to the trained model
    # t["reason"] = np.select(
    #     [t["u_in"] & t["i_in"],
    #      ~t["u_in"] & t["i_in"],
    #      t["u_in"] & ~t["i_in"],
    #      ~t["u_in"] & ~t["i_in"]],
    #     ["kept (user+item warm)",
    #      "drop: user cold, item warm",
    #      "drop: user warm, item cold",
    #      "drop: user cold, item cold"]
    # )
    t["reason"] = "(unknown)"
    t.loc[ t["u_in"] &  t["i_in"], "reason"] = "kept (user+item warm)"
    t.loc[~t["u_in"] &  t["i_in"], "reason"] = "drop: user cold, item warm"
    t.loc[ t["u_in"] & ~t["i_in"], "reason"] = "drop: user warm, item cold"
    t.loc[~t["u_in"] & ~t["i_in"], "reason"] = "drop: user cold, item cold"

    print("\n=== TEST COVERAGE (pre-filter) ===")
    print(f"rows: {len(t)} | uniq users: {t['user_id'].nunique()} | uniq items: {t['gmap_id'].nunique()}")
    print(t["reason"].value_counts())
    print("users in train maps:", t.loc[t["u_in"], "user_id"].nunique(),
          "| items in train maps:", t.loc[t["i_in"], "gmap_id"].nunique())

    kept = t[t["reason"]=="kept (user+item warm)"].copy()
    kept_pos = kept[kept["rating"]>=POS_THRESH]
    print("rows kept:", len(kept), "| users kept:", kept["user_id"].nunique(),
          "| users with ≥1 positive:", kept_pos["user_id"].nunique())



=== TEST COVERAGE (pre-filter) ===
rows: 9184 | uniq users: 3552 | uniq items: 6540
reason
drop: user cold, item cold    6037
drop: user cold, item warm    3129
drop: user warm, item cold      10
kept (user+item warm)            8
Name: count, dtype: int64
users in train maps: 8 | items in train maps: 2024
rows kept: 8 | users kept: 6 | users with ≥1 positive: 5


In [None]:
# ===================== Proper Eval for Two-Tower User→Item =====================
import unicodedata, re, numpy as np, pandas as pd

# ---------------- Config ----------------
POS_THRESH = 3
Ks = [10, 50]
VERBOSE_PER_USER = False   # set True to print per-user ranks/top-10
USER_BATCH = 512           # batch scoring size for users (tune for VRAM/CPU)

# ---------------- Helpers ----------------
def _norm_id(x):
    if pd.isna(x): return ""
    s = str(x)
    s = unicodedata.normalize("NFKC", s)
    s = s.replace("\u200b","").replace("\u200c","").replace("\u200d","")
    s = re.sub(r"\s+", " ", s).strip()
    return s.lower()

def _l2_rows(X):
    n = np.linalg.norm(X, axis=1, keepdims=True) + 1e-9
    return X / n

def _dcg_binary(rel):
    # rel: 1D array of 0/1
    if rel.size == 0: return 0.0
    return (rel / np.log2(np.arange(2, 2 + rel.size))).sum()

# ---------------- Hard alignment checks (must pass) ----------------
assert UZ.shape[0] == len(user_ids_all) == len(u2ix), \
    f"mismatch: UZ={UZ.shape[0]}, user_ids_all={len(user_ids_all)}, u2ix={len(u2ix)}"
assert IZ.shape[0] == len(item_ids_all) == len(i2ix), \
    f"mismatch: IZ={IZ.shape[0]}, item_ids_all={len(item_ids_all)}, i2ix={len(i2ix)}"

UZ_eval = _l2_rows(UZ)
IZ_eval = _l2_rows(IZ)

# ---------------- Load TEST ----------------
TEST_REV = SLICE / "user_reviews_test.parquet"
if not TEST_REV.exists():
    print("[info] No TEST split parquet found; skipped eval.")
else:
    test = pd.read_parquet(TEST_REV, columns=["user_id","gmap_id","rating"]).dropna()
    test["user_id"] = test["user_id"].map(_norm_id)
    test["gmap_id"] = test["gmap_id"].map(_norm_id)

    # warm-only filter (users+items must exist in training maps)
    warm_mask = test["user_id"].isin(u2ix) & test["gmap_id"].isin(i2ix)
    test = test[warm_mask].copy()

    # positives per user
    test_pos = (test[test["rating"] >= POS_THRESH]
                .groupby("user_id")["gmap_id"].apply(set).to_dict())

    # users we will evaluate (those with >=1 positive)
    eval_users = [u for u, s in test_pos.items() if len(s) > 0]
    if len(eval_users) == 0:
        print("evaluated users: 0")
        for K in Ks:
            print(f"Hit@{K}: 0.0000 | Recall@{K}: 0.0000 | NDCG@{K}: 0.0000")
        print("MRR: 0.0000")
    else:
        # ------------- Full-catalog ranking (batched users) -------------
        hits = {k: 0 for k in Ks}
        rec  = {k: 0.0 for k in Ks}
        ndcg = {k: 0.0 for k in Ks}
        mrr  = 0.0

        # vectorize: pre-build arrays for fast lookup
        u_indices = np.array([u2ix[u] for u in eval_users], dtype=np.int32)
        maxK = min(max(Ks), IZ_eval.shape[0])

        for start in range(0, len(u_indices), USER_BATCH):
            batch_uix = u_indices[start:start+USER_BATCH]
            # (B, I) scores
            scores = UZ_eval[batch_uix] @ IZ_eval.T

            # for each user in batch, compute metrics
            for bi, uix in enumerate(batch_uix):
                uid = eval_users[start + bi]
                pos_set = test_pos.get(uid, set())
                if not pos_set:
                    continue

                # top-K indices for this user's scores
                s = scores[bi]
                top_idx = np.argpartition(-s, maxK-1)[:maxK]
                top_sorted = top_idx[np.argsort(-s[top_idx])]
                top_items = [item_ids_all[j] for j in top_sorted]

                # rel@K
                rel = np.fromiter((1 if it in pos_set else 0 for it in top_items),
                                  dtype=np.float32, count=len(top_items))

                for K in Ks:
                    Kc = min(K, len(top_items))
                    rK = rel[:Kc]
                    hits[K] += int(rK.sum() > 0)
                    rec[K]  += rK.sum() / max(1, len(pos_set))
                    if rK.sum() > 0:
                        gains = _dcg_binary(rK)
                        ideal_len = min(len(pos_set), Kc)
                        idcg = _dcg_binary(np.ones(ideal_len, dtype=np.float32))
                        ndcg[K] += gains / max(1e-9, idcg)

                # MRR (from top list only)
                ones = np.where(rel == 1)[0]
                if len(ones):
                    mrr += 1.0 / (ones[0] + 1)

                # optional verbose: show a quick per-user preview
                if VERBOSE_PER_USER:
                    print("\n--- USER ---")
                    print(f"user_id: {uid} | uix: {uix}")
                    print(f"#positives in TEST: {len(pos_set)}")
                    print("Top-10 predicted (id, rel, score):")
                    showK = min(10, len(top_sorted))
                    for r in range(showK):
                        j = top_sorted[r]
                        print(f"  {r+1:>2}. {item_ids_all[j]} | rel={1 if item_ids_all[j] in pos_set else 0} | score={s[j]:.4f}")

        N = max(1, len(eval_users))
        print("\n=== FULL-CATALOG SUMMARY ===")
        print("evaluated users:", len(eval_users))
        for K in Ks:
            print(f"Hit@{K}: {hits[K]/N:.4f} | Recall@{K}: {rec[K]/N:.4f} | NDCG@{K}: {ndcg[K]/N:.4f}")
        print(f"MRR: {mrr/N:.4f}")

        # ---------------- 1+100 negative sanity eval ----------------
        rng = np.random.default_rng(7)
        def small_eval(uid, pos_set, negatives=100):
            uix = u2ix.get(uid)
            if uix is None or not pos_set:
                return None
            pos = next(iter(pos_set))
            j_pos = i2ix.get(pos)
            if j_pos is None:
                return None
            all_js = np.arange(IZ_eval.shape[0])
            mask = np.ones_like(all_js, dtype=bool)
            for it in pos_set:
                j = i2ix.get(it, None)
                if j is not None: mask[j] = False
            neg_pool = all_js[mask]
            if len(neg_pool) == 0:
                return None
            neg_js = rng.choice(neg_pool, size=min(negatives, len(neg_pool)), replace=False)
            cand_js = np.concatenate([[j_pos], neg_js])
            scores = (UZ_eval[[uix]] @ IZ_eval[cand_js].T)[0]
            return 1 + int((scores > scores[0]).sum())

        ranks = [r for uid in eval_users if (r := small_eval(uid, test_pos[uid])) is not None]
        if ranks:
            print(f"\n[small-eval] median rank of the positive among 1+100 candidates: {int(np.median(ranks))} (lower is better)")

        # ---------------- Popularity baseline (from TRAIN, warm-only) ----------------
        TRAIN_REV = SLICE / "user_reviews_train.parquet"
        if TRAIN_REV.exists():
            tr = pd.read_parquet(TRAIN_REV, columns=["user_id","gmap_id","rating"]).dropna()
            tr["user_id"] = tr["user_id"].map(_norm_id)
            tr["gmap_id"] = tr["gmap_id"].map(_norm_id)
            tr = tr[tr["user_id"].isin(u2ix) & tr["gmap_id"].isin(i2ix)]
            tr_pos = tr[tr["rating"] >= POS_THRESH]
            # global top-K popular items (train-only)
            pop_counts = tr_pos["gmap_id"].value_counts()
            pop_items = pop_counts.index.tolist()

            hits_b = {k: 0 for k in Ks}
            rec_b  = {k: 0.0 for k in Ks}
            ndcg_b = {k: 0.0 for k in Ks}
            mrr_b  = 0.0

            for uid in eval_users:
                pos_set = test_pos[uid]
                # same pop list for every user
                top_items = pop_items[:max(Ks)]
                rel = np.fromiter((1 if it in pos_set else 0 for it in top_items),
                                  dtype=np.float32, count=len(top_items))
                for K in Ks:
                    Kc = min(K, len(top_items))
                    rK = rel[:Kc]
                    hits_b[K] += int(rK.sum() > 0)
                    rec_b[K]  += rK.sum() / max(1, len(pos_set))
                    if rK.sum() > 0:
                        gains = _dcg_binary(rK)
                        ideal_len = min(len(pos_set), Kc)
                        idcg = _dcg_binary(np.ones(ideal_len, dtype=np.float32))
                        ndcg_b[K] += gains / max(1e-9, idcg)
                ones = np.where(rel == 1)[0]
                if len(ones):
                    mrr_b += 1.0 / (ones[0] + 1)

            print("\n=== TRAIN-POPULARITY BASELINE ===")
            for K in Ks:
                print(f"[pop] Hit@{K}: {hits_b[K]/N:.4f} | Recall@{K}: {rec_b[K]/N:.4f} | NDCG@{K}: {ndcg_b[K]/N:.4f}")
            print(f"[pop] MRR: {mrr_b/N:.4f}")
        else:
            print("\n[baseline] train parquet not found; skipped popularity baseline.")



=== FULL-CATALOG SUMMARY ===
evaluated users: 5
Hit@10: 0.4000 | Recall@10: 0.2000 | NDCG@10: 0.2453
Hit@50: 0.4000 | Recall@50: 0.2000 | NDCG@50: 0.2453
MRR: 0.4000

[small-eval] median rank of the positive among 1+100 candidates: 12 (lower is better)

=== TRAIN-POPULARITY BASELINE ===
[pop] Hit@10: 0.2000 | Recall@10: 0.2000 | NDCG@10: 0.0602
[pop] Hit@50: 0.2000 | Recall@50: 0.2000 | NDCG@50: 0.0602
[pop] MRR: 0.0222


In [None]:
# Who are the 6 warm users and why the rest are cold?
test = pd.read_parquet(TEST_REV, columns=["user_id","gmap_id","rating"]).dropna()
test["user_id"] = test["user_id"].map(_norm_id)
test["gmap_id"] = test["gmap_id"].map(_norm_id)

test["u_in"] = test["user_id"].isin(u2ix)
test["i_in"] = test["gmap_id"].isin(i2ix)
print("warm users list:", sorted(test.loc[test["u_in"], "user_id"].unique())[:10])
print("example cold users:", sorted(test.loc[~test["u_in"], "user_id"].unique())[:10])
print("example cold items:", sorted(test.loc[~test["i_in"], "gmap_id"].unique())[:10])


warm users list: ['100718639385225116518', '101020057175139692923', '102893998785939432128', '108232646927166938414', '109763093167237806406', '114717559460691055995', '116724074294721483116', '117808329473688233635']
example cold users: ['100005985626504485510', '100024364330312454916', '100025111743832471949', '100032470571192285287', '100035184660344460481', '100039483638572224445', '100045004894534276535', '100049036307915726348', '100052956792307367209', '100054775494099126892']
example cold items: ['0x14e19dc4f71fdc05:0x7a6d8c1d34a6d329', '0x15320f854f19320b:0xe97092cc4ac072f4', '0x15326cb1b101a6b9:0x7d08e9b1c9e03f24', '0x54cb90abba160855:0xf995f7dec50e6255', '0x54cdc242c6f5c1c7:0x164e00d745cf6861', '0x54cde6d8ae4f49eb:0x248e5e16becd5ab5', '0x54d15560bcc75ad3:0x2380fe272ebb9d5c', '0x54d155635394ce8f:0x2039b6fdd06aa2ff', '0x54d1556373350d17:0xdaad894bad675c7f', '0x54d1557c0cc636b1:0xd72cf5c459d9795e']
