<p align="center">
  <h1 align="center">🍳 Cookbook 03: RecSys & GNN Health Monitoring</h1>
  <p align="center">
    <strong>Diagnosing Embedding Collapse & GNN Over-Smoothing with GradTracer</strong>
  </p>
</p>

---

This recipe targets large-scale recommendation systems using Factorization Machines (FM), Matrix Factorization (MF), and Graph Neural Networks (GNN).

We will demonstrate how GradTracer identifies two critical phenomena:
1. **Embedding Collapse (Zombie / Dead Neurons)** in standard MF/FM models.
2. **Over-Smoothing (Variance Stagnation)** in GNN Message Passing layers.

## 1. Setup & Load MovieLens-1M

In [None]:
# !pip install gradtracer torch pandas numpy scipy requests tqdm

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import urllib.request
import zipfile
from gradtracer import FlowManager, EmbeddingTracker, FlowTracker

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fetch MovieLens-1M only when the ratings file itself is missing. Checking the
# file (not just the directory) means a partially extracted archive triggers a
# fresh download, and the "Downloading" message is printed only when a download
# actually happens. The archive is fetched over HTTPS rather than plain HTTP.
url = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
ratings_path = os.path.join("ml-1m", "ratings.dat")
if not os.path.exists(ratings_path):
    print("Downloading MovieLens-1M for scaled testing...")
    urllib.request.urlretrieve(url, "ml-1m.zip")
    with zipfile.ZipFile("ml-1m.zip", 'r') as zip_ref:
        zip_ref.extractall(".")

# ratings.dat is read with a '::' separator and no header row; a
# multi-character separator requires the slower python parsing engine.
columns = ['user_id', 'item_id', 'rating', 'timestamp']
df = pd.read_csv(ratings_path, sep='::', names=columns, engine='python')

# Implicit Feedback conversion (Rating >= 4 is positive)
df = df[df['rating'] >= 4].copy()
df['rating'] = 1.0

# Re-index ids to contiguous 0..N-1 codes AFTER filtering, so every code that
# remains is actually present in the data.
df['user_id'] = df['user_id'].astype('category').cat.codes
df['item_id'] = df['item_id'].astype('category').cat.codes

# Size both vocabularies the same way (largest code + 1) and convert to plain
# Python ints before adding, so the results are not NumPy scalar types.
num_users = int(df['user_id'].max()) + 1
num_items = int(df['item_id'].max()) + 1

class ImplicitDataset(Dataset):
    """Map-style dataset of (user, item, label) triples for implicit feedback.

    Every row of ``df`` is treated as a positive interaction: the label
    tensor is filled with ones and no column other than ``user_id`` and
    ``item_id`` is read.

    Args:
        df: table exposing ``user_id`` and ``item_id`` columns whose values
            can be converted to ``torch.long``.
    """

    def __init__(self, df):
        user_ids = df['user_id'].values
        item_ids = df['item_id'].values
        self.users = torch.tensor(user_ids, dtype=torch.long)
        self.items = torch.tensor(item_ids, dtype=torch.long)
        # One positive label per row.
        self.labels = torch.ones(len(df), dtype=torch.float32)

    def __len__(self):
        """Return the number of stored interactions."""
        return self.users.shape[0]

    def __getitem__(self, idx):
        """Return the ``(user, item, label)`` entries stored at ``idx``."""
        user = self.users[idx]
        item = self.items[idx]
        label = self.labels[idx]
        return user, item, label

# Train on a random 80% sample of the interactions (the sample is unseeded, so
# it differs between runs), served in shuffled mini-batches of 1024 rows.
train_loader = DataLoader(
    ImplicitDataset(df.sample(frac=0.8)),
    batch_size=1024,
    shuffle=True,
)

## 2. Tracking Matrix Factorization (Embedding Collapse)

In [None]:
class FactorizationMachine(nn.Module):
    """Dot-product scorer over a user embedding table and an item embedding table.

    Args:
        num_users: number of rows in the user embedding table.
        num_items: number of rows in the item embedding table.
        dim: width of each embedding vector.
    """

    def __init__(self, num_users, num_items, dim=64):
        super().__init__()
        self.user_emb = nn.Embedding(num_users, dim)
        self.item_emb = nn.Embedding(num_items, dim)

        # Deliberately bad initialization: overwrite both tables with
        # uniform values in [-0.5, 0.5) to force zombie behavior.
        for table in (self.user_emb, self.item_emb):
            nn.init.uniform_(table.weight, -0.5, 0.5)

    def forward(self, user, item):
        """Return the per-row inner product of user and item embeddings.

        The reduction runs over dimension 1, so ``user`` and ``item`` are
        expected to be 1-D index batches of equal length.
        """
        user_vecs = self.user_emb(user)
        item_vecs = self.item_emb(item)
        return torch.sum(user_vecs * item_vecs, dim=1)

# Model, optimizer and loss for the embedding-collapse demo.
model_fm = FactorizationMachine(num_users, num_items).to(device)
# Extremely high LR to force oscillation
optimizer = torch.optim.Adam(model_fm.parameters(), lr=0.1)
criterion = nn.BCEWithLogitsLoss()

# We track BOTH embeddings: one EmbeddingTracker per table, registered on a
# shared FlowManager under the keys "user" and "item".
manager = FlowManager()
manager.add_tracker(
    "user",
    EmbeddingTracker(model_fm.user_emb, name="U_Emb", auto_fix=True),
)
manager.add_tracker(
    "item",
    EmbeddingTracker(model_fm.item_emb, name="I_Emb", auto_fix=True),
)

model_fm.train()
for step, batch in enumerate(train_loader):
    users, items, labels = (tensor.to(device) for tensor in batch)

    optimizer.zero_grad()
    logits = model_fm(users, items)
    loss = criterion(logits, labels)
    loss.backward()

    # The manager is stepped after backward() and before optimizer.step()
    # (presumably so trackers read this batch's gradients; confirm against
    # the GradTracer docs).
    manager.step()
    optimizer.step()

    # Short demo run: stop once batch index 51 has been processed.
    if step >= 51:
        break

manager.report()

## 3. Detecting GNN Over-Smoothing (Experimental)
For Graph Neural Networks (such as LightGCN or GAT), GradTracer monitors the transformation layers. If the gradient norm or variance of layer $L$ drops to near zero while layer $L-1$ is still active, it suggests that message passing is washing out distinct node representations (over-smoothing).

In [None]:
class SimpleGNN(nn.Module):
    """Toy stand-in for a GNN: a node embedding followed by three linear layers.

    No adjacency structure is used; the stacked ``gcn*`` layers only imitate
    the depth of a message-passing network for demo purposes.

    Args:
        num_nodes: number of rows in the node embedding table.
        dim: width of the embeddings and of every linear layer.
    """

    def __init__(self, num_nodes, dim=64):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, dim)
        self.gcn1 = nn.Linear(dim, dim)
        self.gcn2 = nn.Linear(dim, dim)
        # 3 Layers deep usually starts oversmoothing
        self.gcn3 = nn.Linear(dim, dim)

    def forward(self, nodes):
        """Return one scalar per node: the last layer's output summed over features."""
        # Simulating Message Passing heavily for demo
        hidden = self.emb(nodes)
        for layer in (self.gcn1, self.gcn2):
            hidden = F.relu(layer(hidden))
        # The final layer has no activation.
        final = self.gcn3(hidden)
        return final.sum(dim=-1)

# Build the toy GNN over the user vocabulary and attach a gradient-tracking
# FlowTracker to the whole module.
gnn = SimpleGNN(num_users).to(device)
gnn_tracker = FlowTracker(gnn, track_gradients=True)
optimizer = torch.optim.Adam(gnn.parameters(), lr=0.01)

gnn.train()
for step, batch in enumerate(train_loader):
    # Only the user ids of each batch are fed to the network.
    users = batch[0].to(device)

    optimizer.zero_grad()
    scores = gnn(users)
    # Demo objective: the mean of the raw outputs, with no labels involved.
    loss = scores.mean()
    loss.backward()

    # The tracker is stepped with the scalar loss after backward() and
    # before the optimizer update.
    gnn_tracker.step(loss.item())
    optimizer.step()

    # Short demo run: stop once batch index 26 has been processed.
    if step >= 26:
        break

# Examine the layer variances
print("\n\n=== GNN Over-Smoothing Check ===")
layer_stats = gnn_tracker.summary()
for layer_name, stats in layer_stats.items():
    # Only the weight entries of the gcn* layers are reported.
    if "gcn" not in layer_name or "weight" not in layer_name:
        continue
    print(f"{layer_name}: Health Score = {stats.get('health_score', 'N/A')}")
    # A missing score defaults to 100, so it never triggers the warning.
    if stats.get('health_score', 100) < 50:
        print(f"  ⚠️ Early warning for over-smoothing detected in {layer_name}!")
        