# Item–item collaborative filtering with EASE (API reference: https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.ease.recom_ease)

In [1]:

import pandas as pd
import numpy as np
import cornac
from cornac.data import Dataset
from cornac.eval_methods import RatioSplit
from cornac.hyperopt import GridSearch, Discrete
from cornac.metrics import Recall, NDCG, Precision, MAP, AUC
from cornac.models.ease import EASE
from tqdm.auto import trange


# Record the installed cornac version alongside the results.
print("cornac:", cornac.__version__)

cornac: 2.3.5


In [2]:
import pandas as pd

def load_filtered_transactions(
    trans_path="/workspace/data/processed/transactions_clean.parquet",
    avail_path="/workspace/data/processed/articles_for_recs.parquet",
    bad_ids={"12025DK", "12025FI", "12025NO", "12025SE", "970300", "459978"},
    cols=("shopUserId", "orderId", "groupId"),
):
    """Load transactions and keep only rows whose item is available and not blacklisted.

    Parameters
    ----------
    trans_path : str
        Parquet file with the transactions; only `cols` are read.
    avail_path : str
        Parquet file whose "groupId" column lists the items that may be kept.
    bad_ids : set of str
        Group ids to exclude even when present in the availability file.
        The default set is only read, never mutated.
    cols : sequence of str
        Columns to read from `trans_path`; must include "groupId".

    Returns
    -------
    pd.DataFrame
        Filtered transactions with a fresh 0..n-1 index.

    Notes
    -----
    Group ids are compared as stripped strings on both sides; the returned
    frame keeps the "groupId" values exactly as stored in the file.
    The function has no side effects on notebook globals: the caller binds
    the result explicitly.
    """
    transactions = pd.read_parquet(trans_path, columns=list(cols))
    # Only the id column is needed from the availability file.
    available = pd.read_parquet(avail_path, columns=["groupId"])

    gid = transactions["groupId"].astype(str).str.strip()
    avail_ids = set(available["groupId"].astype(str).str.strip().unique())

    keep = gid.isin(avail_ids) & ~gid.isin(bad_ids)
    return transactions.loc[keep].reset_index(drop=True)

# Load the filtered transactions using the default paths and blacklist.
df = load_filtered_transactions()

In [3]:
# Display the filtered transactions frame.
df

Unnamed: 0,shopUserId,orderId,groupId
0,943483,902721,260257
1,943480,902718,280034
2,943480,902718,290150
3,943480,902718,291294
4,943480,902718,292359
...,...,...,...
262693,110507,166445,240012
262694,252853,166428,260345
262695,252853,166428,239301
262696,252844,166420,263855


Aggregate to user–item level.

In [4]:
import pandas as pd

def make_user_item_pairs(
    df: pd.DataFrame,
    user_col: str = "shopUserId",
    order_col: str = "orderId",
    item_col: str = "groupId",
    pref_value: float = 1.0,
) -> pd.DataFrame:
    """
    Collapse transaction rows into one row per (user, item).

    Rows are first de-duplicated per (user, order, item) and then per
    (user, item); the first occurrence is kept each time, so the surviving
    rows keep their original index labels.

    Returns a new DataFrame with columns [user_col, item_col, "pref"], where
    "pref" is the constant `pref_value` cast to float. `df` is not modified.
    """
    once_per_order = df.drop_duplicates(subset=[user_col, order_col, item_col])
    once_per_user = once_per_order.drop_duplicates(subset=[user_col, item_col])

    result = once_per_user.loc[:, [user_col, item_col]].copy()
    # Implicit feedback: every observed (user, item) gets the same weight.
    result["pref"] = float(pref_value)
    return result

# Reduce the transactions to unique (user, item) pairs with a constant preference of 1.0.
pairs = make_user_item_pairs(df)

In [5]:
# Display the user–item pairs.
pairs

Unnamed: 0,shopUserId,groupId,pref
0,943483,260257,1.0
1,943480,280034,1.0
2,943480,290150,1.0
3,943480,291294,1.0
4,943480,292359,1.0
...,...,...,...
262691,252879,430037,1.0
262692,252874,261706,1.0
262693,110507,240012,1.0
262694,252853,260345,1.0


In [6]:
from itertools import combinations
import pandas as pd

def product_pair_user_counts(
    pairs: pd.DataFrame,
    user_col: str = "shopUserId",
    item_col: str = "groupId",
) -> pd.DataFrame:
    """
    Count, for every unordered item pair, how many distinct users have both items.

    Parameters
    ----------
    pairs : pd.DataFrame
        Must contain `user_col` and `item_col`; duplicate rows are ignored.
    user_col, item_col : str
        Column names for the user id and the item id.

    Returns
    -------
    pd.DataFrame
        Columns [f"{item_col}_a", f"{item_col}_b", "distinct_users"] (i.e.
        'groupId_a', 'groupId_b', 'distinct_users' with the defaults), where
        `_a` sorts before `_b`. Rows are ordered by "distinct_users"
        descending; ties keep the (a, b) key order. The frame is empty (with
        the same columns) when no user has two or more items.
    """
    col_a, col_b = f"{item_col}_a", f"{item_col}_b"

    # One row per (user, item), so each pair can appear at most once per user
    # and a plain row count per pair equals the number of distinct users.
    ui = pairs[[user_col, item_col]].drop_duplicates()

    pair_rows = [
        combo
        for _, items in ui.groupby(user_col)[item_col]
        for combo in combinations(sorted(items), 2)
    ]
    # Passing explicit columns keeps the schema intact even with zero pairs.
    combos = pd.DataFrame(pair_rows, columns=[col_a, col_b])

    product_pair_counts = (
        combos.groupby([col_a, col_b])
              .size()
              .reset_index(name="distinct_users")
              .sort_values("distinct_users", ascending=False, kind="stable")
              .reset_index(drop=True)
    )
    return product_pair_counts


# Count distinct users per unordered item pair; used below to decide which items to keep.
product_pairs_df = product_pair_user_counts(pairs)


In [7]:
import pandas as pd

def filter_pairs_by_popular_pairs(
    pairs: pd.DataFrame,
    product_pairs_df: pd.DataFrame,
    *,
    user_col: str = "shopUserId",
    item_col: str = "groupId",
    min_distinct_users: int = 25,
    require_min_items_per_user: int | None = 2,
    inplace: bool = True,
) -> pd.DataFrame:
    """
    Keep only rows in `pairs` whose item is part of any (item,item) pair with
    at least `min_distinct_users` distinct users. Optionally require each user
    to retain at least `require_min_items_per_user` items (None to skip).
    
    If `inplace=True`, mutates `pairs` (drops rows) and returns it.
    If `inplace=False`, returns a filtered copy.
    """
    qual = product_pairs_df.loc[
        product_pairs_df["distinct_users"] > min_distinct_users,
        [f"{item_col}_a", f"{item_col}_b"]
    ]

    allowed_items = pd.unique(pd.concat([qual[f"{item_col}_a"], qual[f"{item_col}_b"]], ignore_index=True))

    filtered = pairs[pairs[item_col].isin(allowed_items)].copy()

    if require_min_items_per_user is not None:
        users_keep = (
            filtered[[user_col, item_col]].drop_duplicates()
            .groupby(user_col)[item_col].nunique()
            .loc[lambda s: s >= require_min_items_per_user].index
        )
        filtered = filtered[filtered[user_col].isin(users_keep)].copy()

    if inplace:
        pairs.drop(index=pairs.index.difference(filtered.index), inplace=True)
        return pairs
    else:
        return filtered

# Filters `pairs` in place (rows are dropped from the existing frame, so re-running
# earlier cells is needed to get the unfiltered pairs back); the returned frame is
# shown as the cell output.
filter_pairs_by_popular_pairs(pairs, product_pairs_df, inplace=True)

Unnamed: 0,shopUserId,groupId,pref
0,943483,260257,1.0
1,943480,280034,1.0
2,943480,290150,1.0
3,943480,291294,1.0
4,943480,292359,1.0
...,...,...,...
262672,79866,260513,1.0
262681,927364,260230,1.0
262682,927364,210338,1.0
262684,927364,217467,1.0


In [8]:
def filter_pairs_by_item_frequency(
    pairs, item_col="groupId", q_low=0, q_high=0.96, inclusive="both", return_stats=False
):
    """Keep rows whose item's row count lies between two quantiles of the per-item counts.

    Items are keyed by their stripped string form. The `q_low` and `q_high`
    quantiles are taken over the per-item row counts, and a row is kept when
    its item's count falls between them (`inclusive` is forwarded to
    `Series.between`).

    Returns the filtered frame with a fresh index, or `(filtered, stats)` when
    `return_stats` is True. `stats` holds the thresholds ("low", "high"), the
    item totals ("groups_total", "groups_kept"), the row totals
    ("rows_total", "rows_kept") and the raw per-item counts ("gid_counts").
    """
    item_keys = pairs[item_col].astype(str).str.strip()
    per_item = item_keys.value_counts()
    low, high = per_item.quantile([q_low, q_high])

    # Broadcast each item's count back onto its rows, then threshold.
    rows_in_band = item_keys.map(per_item).between(low, high, inclusive=inclusive)
    kept = pairs[rows_in_band].reset_index(drop=True)

    if return_stats:
        items_in_band = per_item.between(low, high, inclusive=inclusive)
        summary = {
            "low": float(low),
            "high": float(high),
            "groups_total": int(per_item.size),
            "groups_kept": int(items_in_band.sum()),
            "rows_total": int(len(pairs)),
            "rows_kept": int(rows_in_band.sum()),
            "gid_counts": per_item,  # raw per-item counts, for custom reporting
        }
        return kept, summary

    return kept

# Apply the item-frequency quantile filter with its default bounds; `pairs` is rebound
# to the filtered frame and `s` holds the stats dict.
pairs, s = filter_pairs_by_item_frequency(pairs, return_stats=True)

In [9]:


# Summarize the item-frequency filter from the stats returned with the filtered pairs.
gid_counts = s["gid_counts"]
low, high = s["low"], s["high"]

# Items at or above the low threshold, minus those above the high threshold.
groups_kept = (gid_counts >= low).sum() - (gid_counts > high).sum()

summary_lines = [
    f"Groups kept: {groups_kept} (of {gid_counts.size})",
    f"Rows kept: {s['rows_kept']} (of {s['rows_total']})",
    f"Count thresholds: low={low:.0f}, high={high:.0f}",
]
print("\n".join(summary_lines))


Groups kept: 294 (of 307)
Rows kept: 134249 (of 163796)
Count thresholds: low=59, high=1485


In [10]:
# Cornac (user, item, rating) triplets
def to_user_item_ratings(pairs, user_col="shopUserId", item_col="groupId", pref_col="pref"):
    """Turn the pairs frame into a list of (user, item, rating) tuples.

    User and item ids are cast to str and the preference to float; the tuples
    follow the row order of `pairs`.
    """
    users = pairs[user_col].astype(str)
    items = pairs[item_col].astype(str)
    ratings = pairs[pref_col].astype(float)
    return [(user, item, rating) for user, item, rating in zip(users, items, ratings)]

# Build the (user, item, rating) triplets and show how many there are.
uir = to_user_item_ratings(pairs)
len(uir)

134249

In [11]:
def make_splits(uir, test_size=0.1, seed=42, verbose=True, exclude_unknowns=True, print_stats=True):
    """Create a cornac `RatioSplit` over the (user, item, rating) triplets.

    `test_size`, `seed`, `verbose` and `exclude_unknowns` are forwarded to
    `RatioSplit` unchanged. When `print_stats` is True, the number of users
    and items in the training set is printed. Returns the `RatioSplit`.
    """
    split = RatioSplit(
        data=uir,
        test_size=test_size,
        exclude_unknowns=exclude_unknowns,
        seed=seed,
        verbose=verbose,
    )
    if not print_stats:
        return split

    train = split.train_set
    print("Users:", train.num_users, "Items:", train.num_items)
    return split

# Split with the defaults: test_size=0.1, seed=42, unknown users/items excluded.
rs = make_splits(uir)

rating_threshold = 1.0
exclude_unknowns = True
---
Training data:
Number of users = 35731
Number of items = 294
Number of ratings = 120824
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 35731
Number of items = 294
Number of ratings = 12735
Number of unknown users = 0
Number of unknown items = 0
---
Total users = 35731
Total items = 294
Users: 35731 Items: 294


In [12]:
import tempfile, shutil

def run_experiment(rs, k=20, verbose=False, seed=42):
    """Compare MostPop, BPR and EASE on the given evaluation split.

    Parameters
    ----------
    rs : cornac evaluation method
        The split to evaluate on (e.g. the `RatioSplit` from `make_splits`).
    k : int
        Cut-off used for Recall, Precision and NDCG; MAP is computed without a cut-off.
    verbose : bool
        Forwarded to `cornac.Experiment`.
    seed : int or None
        Passed to BPR's `seed` argument so its training is repeatable across
        runs. NOTE(review): multi-threaded BPR training may still show small
        run-to-run differences — confirm against cornac if exact
        reproducibility matters.

    Returns
    -------
    cornac.Experiment
        The experiment object after `run()` has completed.
    """
    metrics = [Recall(k=k), Precision(k=k), NDCG(k=k), MAP()]
    models = [
        cornac.models.MostPop(),
        cornac.models.BPR(
            learning_rate=0.1,
            lambda_reg=0.1,
            max_iter=500,
            name="BPR-default",
            seed=seed,
        ),
        EASE(name="EASE", verbose=False),  # use default parameters
    ]

    # Whatever cornac writes to save_dir goes to a throwaway directory that is
    # removed when the block exits, even if the run raises.
    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        exp = cornac.Experiment(
            eval_method=rs,
            models=models,
            metrics=metrics,
            verbose=verbose,
            show_validation=False,
            save_dir=tmpdir,
        )
        exp.run()
    return exp


# Train and evaluate the three models on the split with the default cut-off (k=20).
exp = run_experiment(rs)



[MostPop] Training started!

[MostPop] Evaluation started!


Ranking:   0%|          | 0/10182 [00:00<?, ?it/s]


[BPR-default] Training started!

[BPR-default] Evaluation started!


Ranking:   0%|          | 0/10182 [00:00<?, ?it/s]


[EASE] Training started!

[EASE] Evaluation started!


Ranking:   0%|          | 0/10182 [00:00<?, ?it/s]


TEST:
...
            |    MAP | NDCG@20 | Precision@20 | Recall@20 | Train (s) | Test (s)
----------- + ------ + ------- + ------------ + --------- + --------- + --------
MostPop     | 0.0501 |  0.0687 |       0.0106 |    0.1709 |    0.0040 |   1.4483
BPR-default | 0.0896 |  0.1298 |       0.0183 |    0.2994 |    2.1026 |   5.5003
EASE        | 0.1329 |  0.1786 |       0.0216 |    0.3517 |    0.0299 |   1.9086



# Train EASE on the full dataset
# NOTE(review): `finalize_ease` is not defined in any cell visible in this notebook and this
# cell has no execution count; confirm the helper exists elsewhere (or restore its definition),
# otherwise this line raises NameError on a fresh kernel.
final_ease_full, rs_full = finalize_ease(uir)


In [13]:
import numpy as np
import pandas as pd
from cornac.data import Dataset
from cornac.models import EASE

def build_ease_topk_wide(
    uir,
    rel_min: float = 0.50,
    k_min: int = 4,
    k_max: int = 10,
    out_path: str | None = "/workspace/data/processed/basket_completion.parquet",
) -> pd.DataFrame:
    """Train EASE on (user, item, pref) triplets and return a wide Top-K table.

    For each item (one row of the learned `model.B` matrix with its diagonal
    masked), keeps the neighbors whose score is positive and at least
    `rel_min` times the row maximum, capped at the `k_max` highest. Items with
    fewer than `k_min` such neighbors are dropped.

    NOTE(review): this assumes rows/columns of `model.B` are aligned with
    `train_set.item_ids` — confirm against the cornac EASE implementation.

    Parameters
    ----------
    uir : list of (user, item, rating) tuples
        Training triplets passed to `Dataset.from_uir`.
    rel_min : float
        Relative score threshold with respect to each row's maximum.
    k_min, k_max : int
        Minimum number of neighbors an item needs to be kept, and the maximum
        number of neighbors stored per item.
    out_path : str or None
        If truthy, the table is also written to this Parquet file.

    Returns
    -------
    pd.DataFrame
        Columns "Product ID", then "Top i"/"Score i" for i in 1..k_max
        ("Top" as pandas string dtype, "Score" as Float32). Unused slots are
        missing values. The frame is empty, with the same columns, when no
        item has enough neighbors.

    Raises
    ------
    ValueError
        If `k_min` exceeds `k_max` (no item could ever qualify).
    """
    if k_min > k_max:
        raise ValueError(f"k_min ({k_min}) must not exceed k_max ({k_max})")

    train_set = Dataset.from_uir(uir)
    model = EASE(verbose=True)
    model.fit(train_set)

    item_ids = train_set.item_ids
    B = model.B.astype(np.float32, copy=True)
    # Mask self-similarity so an item is never recommended for itself.
    np.fill_diagonal(B, np.nan)
    B_df = pd.DataFrame(B, index=item_ids, columns=item_ids)

    top_cols = [f"Top {i}" for i in range(1, k_max + 1)]
    score_cols = [f"Score {i}" for i in range(1, k_max + 1)]
    cols = ["Product ID"] + [c for pair in zip(top_cols, score_cols) for c in pair]

    records = []
    for product_id, weights in B_df.iterrows():
        candidates = weights.dropna()
        mx = candidates.max()
        best = candidates[(candidates > 0) & (candidates >= mx * rel_min)].nlargest(k_max)
        if len(best) < k_min:
            continue
        record = {"Product ID": product_id}
        for rank, (neighbor, score) in enumerate(best.items(), start=1):
            record[f"Top {rank}"] = neighbor
            record[f"Score {rank}"] = float(score)
        records.append(record)

    # Building against the full column list fixes the schema up front, so the
    # result is well-formed even when `records` is empty or no item reaches
    # k_max neighbors.
    wide = pd.DataFrame(records, columns=cols)

    wide[top_cols] = wide[top_cols].astype("string")
    wide[score_cols] = wide[score_cols].astype("Float32")

    if out_path:
        wide.to_parquet(out_path, index=False)
    return wide


# Usage: train EASE on all triplets, build the Top-K table and write it to `out_path`
# (an existing file at that path is overwritten by `to_parquet`).
wide = build_ease_topk_wide(
    uir,
    rel_min=0.50,
    k_min=4,
    k_max=10,
    out_path="/workspace/data/processed/basket_completion.parquet",
)


In [14]:
# Display the wide Top-K recommendation table.
wide

Unnamed: 0,Product ID,Top 1,Score 1,Top 2,Score 2,Top 3,Score 3,Top 4,Score 4,Top 5,...,Top 6,Score 6,Top 7,Score 7,Top 8,Score 8,Top 9,Score 9,Top 10,Score 10
0,260257,261924,0.063252,260182,0.058383,260930,0.052023,261040,0.046962,,...,,,,,,,,,,
1,280034,261436,0.025692,265041,0.025252,261699,0.016718,281675,0.014822,218982,...,210727,0.013326,280055,0.013153,,,,,,
2,291294,290232,0.037895,290246,0.034699,291278,0.027567,290150,0.027049,290153,...,291252,0.0209,290149,0.019764,,,,,,
3,264275,260223,0.032345,260949,0.029868,260463,0.026803,261427,0.023683,261595,...,260620,0.019878,218982,0.016718,261476,0.016186,,,,
4,260976,260935,0.052049,260950,0.04956,260968,0.039501,260949,0.026409,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
126,240012,241091,0.029768,240108,0.025992,241687,0.02328,241653,0.019384,242511,...,210730,0.017399,,,,,,,,
127,240108,241091,0.045856,210727,0.036786,242511,0.033488,240279,0.028137,240191,...,242214,0.023407,,,,,,,,
128,270586,218982,0.027422,261426,0.025542,261012,0.023191,270312,0.022186,242289,...,261699,0.014271,260223,0.014194,,,,,,
129,261612,261610,0.048616,261618,0.048519,261616,0.036385,261608,0.031624,261990,...,261620,0.027608,,,,,,,,
