In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
from core.ds import RatingsDS
from core.rec import MF, log_all_embeddings, log_items_with_metadata
from core.utils import benchmark_loader, print_top

In [3]:
# Debugging aid: CUDA_LAUNCH_BLOCKING=1 asks the CUDA runtime to execute kernel launches
# synchronously, so a device-side error is reported at the call that caused it.
# NOTE(review): this normally slows GPU training — consider removing it once debugging is done.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Read CSV and index-encode IDs

In [67]:

# ---------- 1) Load the raw tables ----------
# rating.csv holds the (user_id, anime_id, rating) rows; anime.csv holds per-title metadata.
anime_path = "../data/anime-rec/anime.csv"
csv_path = "../data/anime-rec/rating.csv"
anime_df, df = pd.read_csv(anime_path), pd.read_csv(csv_path)


In [65]:
# Show column dtypes and the memory footprint of the raw ratings table.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7813737 entries, 0 to 7813736
Data columns (total 3 columns):
 #   Column    Dtype
---  ------    -----
 0   user_id   int64
 1   anime_id  int64
 2   rating    int64
dtypes: int64(3)
memory usage: 178.8 MB


In [69]:
# Keep explicit ratings only: rows with rating == -1 are dropped, and so is every row
# for anime_id 30913.
# NOTE(review): the reason for excluding anime_id 30913 is not recorded — presumably it has
# no metadata row in anime.csv; confirm and document.
explicit_df = df[df["rating"].ne(-1) & df["anime_id"].ne(30913)].copy()
(len(df), len(explicit_df))

(7813737, 6337239)

In [70]:
# Summary statistics of the filtered table (id ranges and the rating distribution).
explicit_df.describe()

Unnamed: 0,user_id,anime_id,rating
count,6337239.0,6337239.0,6337239.0
mean,36747.91,8902.859,7.808497
std,21013.41,8881.992,1.572496
min,1.0,1.0,1.0
25%,18984.0,1239.0,7.0
50%,36815.0,6213.0,8.0
75%,54873.0,14075.0,9.0
max,73516.0,34475.0,10.0


In [71]:

# Re-encode raw IDs as contiguous 0..N-1 codes.
# `user_uniques` / `item_uniques` hold the raw ID behind each code (position == code).
explicit_df["user_idx"], user_uniques = explicit_df["user_id"].factorize()
explicit_df["item_idx"], item_uniques = explicit_df["anime_id"].factorize()

num_users, num_items = len(user_uniques), len(item_uniques)

# Plain NumPy arrays for the dataset. Ratings are only cast to float32 — no rescaling is applied.
users = explicit_df["user_idx"].to_numpy(dtype=np.int64)
items = explicit_df["item_idx"].to_numpy(dtype=np.int64)
ratings = explicit_df["rating"].to_numpy(dtype=np.float32)


In [72]:
# Cache of the raw-id -> position table. It is rebuilt whenever a different source
# object is used (e.g. after the factorize cell is re-run and `item_uniques` is rebound).
_MAP_ID_IDX_CACHE = {"source": None, "lookup": None}


def map_id_idx(id, uniques=None):
    """Return the encoded position of a raw anime id.

    The lookup table is built once per source sequence and reused, so mapping a
    whole column costs one pass over ``uniques`` instead of one pass per row.

    Args:
        id: Raw id to look up (the name shadows the builtin but is kept so
            existing keyword callers keep working).
        uniques: Sequence of known raw ids where position == encoded index.
            Defaults to the notebook-level ``item_uniques``. The sequence is
            assumed not to be mutated in place between calls.

    Returns:
        The position of the first occurrence of ``id`` in ``uniques``, or
        ``len(uniques) + 1`` (the index used for unknown items) when ``id`` is
        absent or unhashable.
    """
    source = item_uniques if uniques is None else uniques

    if _MAP_ID_IDX_CACHE["source"] is not source:
        lookup = {}
        for position, raw_id in enumerate(source):
            # setdefault keeps the first occurrence, matching list.index().
            lookup.setdefault(raw_id, position)
        _MAP_ID_IDX_CACHE["lookup"] = lookup
        _MAP_ID_IDX_CACHE["source"] = source

    # Embedding idx of unk.
    # NOTE(review): this leaves index len(uniques) itself unused — confirm that is intended.
    unk = len(source) + 1
    try:
        return _MAP_ID_IDX_CACHE["lookup"].get(id, unk)
    except TypeError:
        # An unhashable id can never be a known item.
        return unk

In [73]:
# Encode every anime.csv row with map_id_idx; ids that are not in item_uniques all receive
# the shared "unknown" index that map_id_idx returns.
anime_df["anime_idx"] = anime_df["anime_id"].map(map_id_idx)

In [74]:
# Sanity check: rating rows whose item_idx does not occur in anime_df["anime_idx"].
# An empty frame means every rated item has a metadata row.
explicit_df[~explicit_df["item_idx"].isin(anime_df["anime_idx"])]

Unnamed: 0,user_id,anime_id,rating,user_idx,item_idx


In [76]:
# NOTE(review): 2368 is a hard-coded count — presumably the number of anime.csv rows that
# received the "unknown" index. Derive it from the data instead of hard-coding; TODO confirm.
len(anime_df) - 2368

9926

In [77]:
# Number of distinct encoded items that actually occur in the ratings.
explicit_df["item_idx"].nunique()

9926

In [78]:
# Size of the item vocabulary produced by factorize (len(item_uniques)).
num_items

9926

In [79]:
# (number of distinct items, number of distinct users) in the explicit ratings.
len(item_uniques), len(user_uniques)

(9926, 69600)

In [40]:
# Metadata rows for titles that have no explicit rating, i.e. rows that map_id_idx tagged
# with the "unknown" index. That index is computed as len(item_uniques) + 1 (the value
# map_id_idx returns) instead of a hard-coded number, which goes stale whenever the
# filtering upstream changes the item count.
anime_df.loc[
    (~anime_df["anime_id"].isin(item_uniques))
    & (anime_df["anime_idx"] == len(item_uniques) + 1),
    :,
].sort_values(by="anime_idx")

Unnamed: 0,anime_id,name,genre,type,episodes,rating,members,anime_idx
31,32983,Natsume Yuujinchou Go,"Drama, Fantasy, Shoujo, Slice of Life, Superna...",TV,13,8.76,38865,9928
62,32995,Yuri!!! on Ice,"Comedy, Sports",TV,12,8.61,103178,9928
74,21,One Piece,"Action, Adventure, Comedy, Drama, Fantasy, Sho...",TV,Unknown,8.58,504862,9928
76,31933,JoJo no Kimyou na Bouken: Diamond wa Kudakenai,"Action, Adventure, Comedy, Drama, Shounen, Sup...",TV,39,8.57,74074,9928
140,10937,Mobile Suit Gundam: The Origin,"Action, Mecha, Military, Sci-Fi, Shounen, Space",OVA,6,8.42,15420,9928
...,...,...,...,...,...,...,...,...
12282,34388,Shikkoku no Shaga The Animation,Hentai,OVA,Unknown,,195,9928
12283,29992,Silent Chaser Kagami,Hentai,OVA,1,4.95,112,9928
12284,26031,Super Erotic Anime,Hentai,OVA,2,4.45,118,9928
12285,34399,Taimanin Asagi 3,"Demons, Hentai, Supernatural",OVA,Unknown,,485,9928


In [81]:
# Metadata restricted to titles present in item_uniques, ordered by their encoded index.
anime_meta_df = anime_df[anime_df["anime_id"].isin(item_uniques)].sort_values("anime_idx")

In [82]:
# Should equal num_items: one metadata row per encoded item.
len(anime_meta_df)

9926

# Dataset

In [83]:
batch_size = 4 * 8192
dataset = RatingsDS(users, items, ratings)

# Hold out 10% of the samples (at least one) for validation; the seeded generator
# makes the split reproducible.
val_size = max(1, int(0.1 * len(dataset)))
train_size = len(dataset) - val_size
train_ds, val_ds = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Both loaders share the same worker / pinning settings; only the training loader shuffles.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=6,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

# Training

In [84]:
# ---------- 4) Training ----------
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 32 is MF's third argument — presumably the latent (embedding) dimension; verify in core.rec.
model = MF(num_users, num_items, 32).to(device)

In [85]:
# Squared-error objective on the raw ratings (1-10 in this dataset); Adam's weight_decay
# supplies the regularization.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=5e-3,
    weight_decay=1e-5,
)

def _sum_squared_error(loader, net, dev):
    """Return ``(sum of squared errors, number of ratings)`` over all batches of ``loader``.

    Puts ``net`` in eval mode (it is left in eval mode afterwards) and runs without
    gradient tracking. Each batch must unpack as ``(users, items, ratings)`` tensors.
    """
    net.eval()
    total_se, count = 0.0, 0
    with torch.no_grad():
        for u, i, r in loader:
            u, i, r = u.to(dev), i.to(dev), r.to(dev)
            total_se += torch.sum((net(u, i) - r) ** 2).item()
            count += r.numel()
    return total_se, count


def rmse(loader, net=None, dev=None):
    """Return the root-mean-squared error of the model over ``loader``.

    Thin wrapper around :func:`eval_mse_rmse` so both functions share one code path
    and agree on edge cases: an empty loader yields ``0.0`` instead of raising
    ``ZeroDivisionError``.

    Args:
        loader: Iterable of ``(users, items, ratings)`` tensor batches.
        net: Model to evaluate; defaults to the notebook-level ``model``.
        dev: Device to run on; defaults to the notebook-level ``device``.
    """
    return eval_mse_rmse(loader, net=net, dev=dev)[1]


def eval_mse_rmse(loader, net=None, dev=None):
    """Return ``(mse, rmse)`` of the model over ``loader``.

    Args:
        loader: Iterable of ``(users, items, ratings)`` tensor batches.
        net: Model to evaluate; defaults to the notebook-level ``model``.
        dev: Device to run on; defaults to the notebook-level ``device``.

    Returns:
        Tuple ``(mse, rmse)`` as Python floats; ``(0.0, 0.0)`` for an empty loader.

    Side effects:
        The evaluated model is switched to eval mode and stays there.
    """
    net = model if net is None else net
    dev = device if dev is None else dev
    se, n = _sum_squared_error(loader, net, dev)
    mse = se / max(1, n)  # max() guards against an empty loader
    return mse, mse ** 0.5


In [86]:
# Smoke test: push only the first training batch through the model and show its raw predictions.
for batch in train_loader:
    u, i, r = (tensor.to(device) for tensor in batch)
    pred = model(u, i)
    print(pred)
    break

tensor([ 4.3594,  2.1768,  5.6247,  ..., -1.6068, -6.0129,  1.7296],
       device='cuda:0', grad_fn=<AddBackward0>)


In [87]:
from torch.utils.tensorboard import SummaryWriter
import time

# ---- before training loop ----
# Put the embedding size in the run name so runs can be told apart in TensorBoard.
# It is read from the model instead of being hard-coded: the previous literal 64 did not
# match the MF(num_users, num_items, 32) model this notebook builds.
embed_dim = model.P.weight.shape[1]  # assumes P.weight is (num_users, dim) — TODO confirm
run_name = f"MF_d{embed_dim}_{int(time.time())}"   # tweak as you like
writer = SummaryWriter(log_dir=f"runs/MF/{run_name}")

In [88]:
# Log the model graph once, right after creating `writer` and before the loop.
# Graph tracing is best-effort (it can fail on some setups), so a failure must not stop
# the notebook — but it is reported rather than silently swallowed.
try:
    dummy_u = torch.randint(0, num_users, (1024,), dtype=torch.long, device=device)
    dummy_i = torch.randint(0, num_items, (1024,), dtype=torch.long, device=device)
    writer.add_graph(model, (dummy_u, dummy_i))
except Exception as exc:
    print(f"[tensorboard] add_graph skipped: {exc!r}")


In [None]:

# ---- training configuration ----
global_step = 0   # optimizer steps taken so far, across epochs (x-axis of per-step logs)
epochs = 10
proj_every = 1    # log embeddings for the projector every `proj_every` epochs

# try/finally guarantees the writer is flushed and closed even when training is
# interrupted or fails part-way through.
try:
    for ep in range(1, epochs + 1):
        model.train()
        running_se, n_seen = 0.0, 0

        bar = tqdm(train_loader, desc=f"Epoch {ep}", leave=False)

        # Iterate over the batches directly: the former enumerate() counter was unused and
        # shared the name `i` with the item-index tensor.
        for u, i, r in bar:
            u, i, r = u.to(device), i.to(device), r.to(device)
            pred = model(u, i)
            loss = criterion(pred, r)
            optimizer.zero_grad()
            loss.backward()
            # per-step logging (optional but nice to have)
            writer.add_scalar("Loss/train_step", loss.item(), global_step)
            global_step += 1
            optimizer.step()
            # accumulate for epoch metrics
            with torch.no_grad():
                running_se += torch.sum((pred - r) ** 2).item()
                n_seen += r.numel()

        # ---- end of epoch: compute metrics ----
        train_mse = running_se / max(1, n_seen)
        val_mse, val_rmse = eval_mse_rmse(val_loader)

        # scalars
        writer.add_scalar("Loss/train_epoch", train_mse, ep)
        writer.add_scalar("Loss/val_epoch",   val_mse,   ep)
        writer.add_scalar("RMSE/val_epoch",   val_rmse,  ep)

        # (optional) learning rate
        writer.add_scalar("LR", optimizer.param_groups[0]["lr"], ep)

        # histograms of params (embeddings + biases)
        writer.add_histogram("embeddings/user_factors", model.P.weight, ep)
        writer.add_histogram("embeddings/item_factors", model.Q.weight, ep)
        if getattr(model, "user_bias", None) is not None:
            writer.add_histogram("bias/user_bias",  model.user_bias.weight, ep)
            writer.add_histogram("bias/item_bias",  model.item_bias.weight, ep)

        # (optional) weight norms (quick health check)
        with torch.no_grad():
            writer.add_scalar("Norms/user_factors_L2", model.P.weight.norm(p=2).item(), ep)
            writer.add_scalar("Norms/item_factors_L2", model.Q.weight.norm(p=2).item(), ep)

        # Projector embeddings. This check runs once per epoch, so the cadence is keyed on
        # the epoch counter; the step passed to the loggers is still global_step.
        if ep % proj_every == 0:
            log_all_embeddings(writer, model, global_step)
            log_items_with_metadata(writer, model, anime_meta_df, global_step)

        print(f"Epoch {ep:02d} | train MSE {train_mse:.4f} | val RMSE {val_rmse:.4f}")
finally:
    # ---- after training (or on interruption) ----
    writer.flush()
    writer.close()


Epoch 1:   0%|          | 0/175 [00:00<?, ?it/s]

Epoch 1:  20%|██        | 35/175 [00:06<00:17,  8.19it/s]

In [53]:
# ! tensorboard --logdir runs/MF --port 6006

In [None]:

# ---------- 5) Inference helper ----------
def _raw_id_to_index(uniques, raw_id, kind):
    """Return the position of ``raw_id`` in ``uniques``; raise ValueError naming the id if absent."""
    matches = np.where(uniques == raw_id)[0]
    if len(matches) == 0:
        raise ValueError(f"Unknown {kind}: {raw_id!r}")
    return int(matches[0])


def predict(user_id, item_id):
    """Predict the rating for a raw ``(user_id, item_id)`` pair.

    Args:
        user_id: Raw user id; must be one of ``user_uniques``.
        item_id: Raw item (anime) id; must be one of ``item_uniques``.

    Returns:
        The model output for the pair, as a Python float.

    Raises:
        ValueError: If either id is unknown; the message says which one.

    Side effects:
        Switches the notebook-level ``model`` to eval mode.
    """
    model.eval()
    # map raw ids -> indices
    u_idx = _raw_id_to_index(user_uniques, user_id, "user_id")
    i_idx = _raw_id_to_index(item_uniques, item_id, "item_id")
    u = torch.tensor([u_idx], dtype=torch.long, device=device)
    i = torch.tensor([i_idx], dtype=torch.long, device=device)
    with torch.no_grad():
        return model(u, i).item()

# Example (raw ids are the integer user_id / anime_id values from rating.csv):
# print(predict(user_id=1, item_id=1))


## Benchmark

In [None]:
# Multiprocessing context handed to the benchmark: only Windows (os.name == "nt") gets an
# explicit "spawn" context; every other platform passes mp_ctx=None.
# NOTE(review): macOS is not special-cased here — confirm whether benchmark_loader needs it.
mp_ctx = multiprocessing.get_context("spawn") if os.name == "nt" else None

# Sweep the DataLoader settings listed below and collect the measurements.
results = benchmark_loader(
    model=model,
    device=device,
    dataset=dataset,
    batch_sizes=(4096, 8192, 16384, 32768),
    worker_choices=(0, 2, 4, 8, 12),
    pin_memory_choices=(True, False),
    prefetch_choices=(2, 4, 8),
    persistent=True,
    mp_ctx=mp_ctx,
)


In [None]:


print_top(results, top=8)

# NOTE(review): assumes benchmark_loader returns its results best-first — confirm.
best = results[0]
uses_workers = best["num_workers"] > 0

print("\nRecommended DataLoader kwargs:")
print({
    "batch_size": best["batch_size"],
    "num_workers": best["num_workers"],
    "pin_memory": best["pin_memory"],
    # DataLoader rejects an explicit prefetch_factor when num_workers == 0, so the
    # recommendation only carries one when worker processes are actually used.
    "prefetch_factor": best["prefetch"] if uses_workers else None,
    "persistent_workers": uses_workers,
    "drop_last": True
})