In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
from scipy.sparse import csr_matrix
import pandas as pd
import json
import warnings
warnings.filterwarnings('ignore')

class JointMultiLevelPOIModel(nn.Module):
    """Joint factorisation of users and a multi-level POI hierarchy.

    Learnable matrices:
      * ``U_u``            (n_users, k)          user latent factors
      * ``U_p_levels[l]``  (n_pois_l, k)         POI latent factors per level
      * ``V_u``            (user_attr_dim, k)    user attribute loadings
      * ``V_p_levels[l]``  (poi_attr_dims[l], k) POI attribute loadings

    The objective returned by :meth:`forward` is
    ``gamma * L1 + alpha * L2 + beta * L3``. Here L1 is attribute
    reconstruction, L2 is BPR ranking on interactions and L3 is alignment of
    ``U_u @ U_p_levels[l].T`` with S^l.

    Args:
        n_users: number of users.
        n_pois_per_level: POI count for each level (its length sets n_levels).
        user_attr_dim: number of user attribute columns.
        poi_attr_dims: number of POI attribute columns for each level.
        latent_dim: latent dimensionality k.
        use_s_matrix: when False, L3 is always zero.
    """

    def __init__(self, n_users, n_pois_per_level, user_attr_dim, poi_attr_dims,
                 latent_dim=64, use_s_matrix=True):
        super().__init__()
        self.use_s_matrix = use_s_matrix
        self.n_levels = len(n_pois_per_level)
        self.latent_dim = latent_dim

        # All factors start as standard-normal draws scaled by 0.01.
        self.U_u = nn.Parameter(torch.randn(n_users, latent_dim) * 0.01)
        self.U_p_levels = nn.ParameterList([
            nn.Parameter(torch.randn(n_pois, latent_dim) * 0.01)
            for n_pois in n_pois_per_level
        ])
        self.V_u = nn.Parameter(torch.randn(user_attr_dim, latent_dim) * 0.01)
        self.V_p_levels = nn.ParameterList([
            nn.Parameter(torch.randn(dims, latent_dim) * 0.01)
            for dims in poi_attr_dims
        ])

    def compute_L1(self, X, Y_list, reg_lambda=0.01):
        """Attribute reconstruction loss L1.

        Sum of squared Frobenius errors ``||U_u V_u^T - X||^2`` and
        ``||U_p^l V_p^l^T - Y^l||^2`` (levels whose ``Y`` is None are skipped),
        plus ``reg_lambda`` times the squared norms of all V matrices.
        """
        X_pred = self.U_u @ self.V_u.t()
        loss_user = torch.norm(X_pred - X, p='fro') ** 2

        loss_poi = torch.tensor(0.0, device=X.device)
        for l, Y in enumerate(Y_list):
            if Y is not None:
                Y_pred = self.U_p_levels[l] @ self.V_p_levels[l].t()
                loss_poi = loss_poi + torch.norm(Y_pred - Y, p='fro') ** 2

        reg_loss = reg_lambda * (torch.norm(self.V_u) ** 2 +
                                  sum(torch.norm(V) ** 2 for V in self.V_p_levels))

        return loss_user + loss_poi + reg_loss

    def compute_L2(self, R_list, num_negatives=5, max_pairs=5000):
        """BPR ranking loss L2 over observed user-POI interactions.

        For each level ``l`` whose ``R_list[l]`` is a non-empty scipy sparse
        matrix (rows index users, columns index level-l POIs), the steps are:

        1. At most ``max_pairs`` nonzero (user, POI) pairs are sampled.
        2. Each pair is contrasted ``num_negatives`` times with a uniformly
           drawn POI.
        3. Drawn POIs that are nonzero in ``R`` for that user are discarded,
           so only true negatives contribute.

        Returns:
            The mean, over all negative rounds that kept at least one pair, of
            ``softplus(score_neg - score_pos).mean()``. Always a tensor;
            0.0 when there is no usable signal.
        """
        device = self.U_u.device
        if R_list is None:
            return torch.tensor(0.0, device=device)

        total_loss = torch.tensor(0.0, device=device)
        n_rounds = 0

        for l, R in enumerate(R_list):
            if R is None or R.nnz == 0:
                continue

            rows, cols = R.nonzero()
            if len(rows) == 0:
                continue

            # Subsample positives to bound the per-step cost.
            if len(rows) > max_pairs:
                idx = np.random.choice(len(rows), max_pairs, replace=False)
                rows, cols = rows[idx], cols[idx]

            users = torch.from_numpy(rows).long().to(device)
            pos_pois = torch.from_numpy(cols).long().to(device)

            u_factors = self.U_u[users]
            scores_pos = torch.sum(u_factors * self.U_p_levels[l][pos_pois], dim=1)

            n_pois_l = self.U_p_levels[l].shape[0]
            R_csr = R.tocsr()  # row/column fancy indexing for the positivity check

            for _ in range(num_negatives):
                neg_np = np.random.randint(0, n_pois_l, size=len(rows))
                # Drop draws that are observed interactions: they are not negatives.
                is_pos = np.asarray(R_csr[rows, neg_np]).ravel() != 0
                keep = ~is_pos
                if not keep.any():
                    continue

                neg_pois = torch.from_numpy(neg_np).long().to(device)
                keep_t = torch.from_numpy(keep).to(device)
                scores_neg = torch.sum(u_factors * self.U_p_levels[l][neg_pois], dim=1)

                diff = (scores_pos - scores_neg)[keep_t]
                total_loss = total_loss + torch.nn.functional.softplus(-diff).mean()
                n_rounds += 1

        if n_rounds == 0:
            return torch.tensor(0.0, device=device)

        return total_loss / n_rounds

    def compute_L3(self, S_list, reg_lambda=0.01):
        """Feature-alignment loss L3.

        Mean over levels with a non-None ``S`` of ``||U_u U_p^l^T - S^l||_F^2``,
        plus ``reg_lambda`` times the squared norms of all U matrices. Returns
        a zero tensor when ``use_s_matrix`` is False, ``S_list`` is None, or no
        level has an ``S``.
        """
        device = self.U_u.device
        if not self.use_s_matrix or S_list is None:
            return torch.tensor(0.0, device=device)

        loss = torch.tensor(0.0, device=device)
        count = 0

        for l, S in enumerate(S_list):
            if S is not None:
                O_pred = self.U_u @ self.U_p_levels[l].t()
                if S.device != O_pred.device:
                    S = S.to(O_pred.device)
                loss = loss + torch.norm(O_pred - S, p='fro') ** 2
                count += 1

        if count == 0:
            return torch.tensor(0.0, device=device)

        reg_loss = reg_lambda * (torch.norm(self.U_u) ** 2 +
                                  sum(torch.norm(U) ** 2 for U in self.U_p_levels))

        return loss / count + reg_loss

    def forward(self, X, Y_list, R_list=None, S_list=None, alpha=0.5, beta=0.1, gamma=1.0):
        """Return ``(total, L1, L2, L3)`` with ``total = gamma*L1 + alpha*L2 + beta*L3``."""
        device = X.device

        def as_tensor(term):
            # Defensive: the compute_* methods already return tensors.
            if isinstance(term, torch.Tensor):
                return term
            return torch.tensor(float(term), device=device)

        L1 = as_tensor(self.compute_L1(X, Y_list))
        L2 = as_tensor(self.compute_L2(R_list))
        L3 = as_tensor(self.compute_L3(S_list))

        total = gamma * L1 + alpha * L2 + beta * L3
        return total, L1, L2, L3


def _load_user_attributes(path):
    """Load the user attribute matrix X = [X_A | X_T] and the list of user ids."""
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as f:
        user_data = pickle.load(f)
    X = np.concatenate([user_data['X_A'], user_data['X_T']], axis=1)
    X = torch.from_numpy(X).float()
    n_users = X.shape[0]
    user_ids = user_data.get('user_ids', [f"user_{i}" for i in range(n_users)])
    print(f"  Users: {n_users}, attributes: {X.shape[1]}")
    return X, user_ids


def _load_poi_attributes(poi_data, n_levels):
    """Build per-level POI attribute matrices Y^l = [Y_A | Y_T | A_lp]."""
    # Widths of the zero-filled A_lp block used when the pickle provides none
    # for a level; presumably matching the upstream A_lp widths — TODO confirm.
    fallback_a_lp_dims = {0: 221, 1: 171, 2: 125, 3: 105}
    a_lp_all = poi_data.get('A_lp')

    Y_list, n_pois_list, poi_attr_dims = [], [], []
    for level in range(n_levels):
        level_key = f'level_{level}'
        level_emb = poi_data['poi_embeddings'][level_key]
        Y_A, Y_T = level_emb['Y_A'], level_emb['Y_T']

        # A_lp may be keyed by 'level_<l>' or by the integer level.
        A_lp = None
        if a_lp_all is not None:
            if level_key in a_lp_all:
                A_lp = a_lp_all[level_key]
            elif isinstance(a_lp_all, dict) and level in a_lp_all:
                A_lp = a_lp_all[level]
        if A_lp is None:
            A_lp = np.zeros((Y_A.shape[0], fallback_a_lp_dims[level]))

        Y = np.concatenate([Y_A, Y_T, A_lp], axis=1)
        Y_list.append(torch.from_numpy(Y).float())
        n_pois_list.append(Y.shape[0])
        poi_attr_dims.append(Y.shape[1])
        print(f"  Level {level}: {Y.shape[0]} POIs, {Y.shape[1]} attributes")
    return Y_list, n_pois_list, poi_attr_dims


def _build_interaction_matrices(df, poi_tree, poi_data, user_ids, n_users, n_levels):
    """Build one sparse count matrix R^l (users x level-l POIs) per level, or None."""
    # Flat raw-key -> UUID map over every 'level_*' section of the tree.
    # Assumes raw keys are unique across levels (a repeated key keeps the last UUID).
    poi_key_to_uuid = {
        str(pkey): pinfo['uuid']
        for level_key, level_data in poi_tree.items()
        if level_key.startswith('level_')
        for pkey, pinfo in level_data.items()
        if 'uuid' in pinfo
    }

    user_to_idx = {uid: i for i, uid in enumerate(user_ids)}
    df['user_idx'] = df['user_id'].map(user_to_idx)
    # JSON object keys are always strings, so look the CSV key up as a string.
    df['poi_uuid'] = df['poi_id'].astype(str).map(poi_key_to_uuid)

    valid_df = df[df['user_idx'].notna() & df['poi_uuid'].notna()].copy()
    valid_df['user_idx'] = valid_df['user_idx'].astype(int)

    print(f"  CSV rows: {len(df)}, Valid rows: {len(valid_df)}")
    print(f"  Unique users in CSV: {df['user_id'].nunique()}")
    print(f"  Users mapped: {valid_df['user_idx'].nunique()}/{n_users}")

    uuid_keys = valid_df['poi_uuid'].astype(str)
    raw_keys = valid_df['poi_id'].astype(str)
    user_idx = valid_df['user_idx'].to_numpy()
    matched_any = np.zeros(len(valid_df), dtype=bool)

    R_list = []
    for level in range(n_levels):
        poi_ids = poi_data['poi_embeddings'][f'level_{level}']['poi_ids']
        poi_to_idx = {str(pid): i for i, pid in enumerate(poi_ids)}

        # A row belongs to R^l when its POI is one of level l's embedding ids.
        # Match by UUID first, then by the raw CSV key, because the embedding
        # ids may use either form. The tree's own level numbering is not used;
        # only membership in this level's ids matters for the column index.
        poi_idx = uuid_keys.map(poi_to_idx).fillna(raw_keys.map(poi_to_idx))
        hit = poi_idx.notna().to_numpy()
        matched_any |= hit

        if not hit.any():
            R_list.append(None)
            print(f"  Level {level}: No interactions")
            continue

        pairs = pd.DataFrame({
            'user_idx': user_idx[hit],
            'poi_idx': poi_idx.to_numpy()[hit].astype(int),
        })
        # Repeated (user, POI) rows are summed into an interaction count.
        counts = pairs.groupby(['user_idx', 'poi_idx']).size().reset_index(name='count')

        R = csr_matrix(
            (counts['count'].to_numpy(dtype=float),
             (counts['user_idx'].to_numpy(dtype=int),
              counts['poi_idx'].to_numpy(dtype=int))),
            shape=(n_users, len(poi_ids)),
        )
        R_list.append(R)
        print(f"  Level {level}: R{R.shape}, {R.nnz} interactions")

    n_unmatched = int((~matched_any).sum())
    if n_unmatched:
        print(f"  Warning: {n_unmatched} valid rows matched no level's poi_ids")
    return R_list


def _coerce_s_matrix(S_raw, level, expected_shape):
    """Convert one raw S^l entry to a float tensor; None if missing or unusable."""
    if S_raw is None:
        return None
    if isinstance(S_raw, torch.Tensor):
        S_tensor = S_raw.float()
    elif isinstance(S_raw, np.ndarray):
        S_tensor = torch.from_numpy(S_raw).float()
    else:
        print(f"  Level {level}: Unknown type {type(S_raw)}")
        return None
    # compute_L3 compares S^l with U_u @ U_p^l.T, i.e. (n_users, n_pois_l).
    if tuple(S_tensor.shape) != tuple(expected_shape):
        print(f"  Level {level}: S shape {tuple(S_tensor.shape)} != "
              f"expected {tuple(expected_shape)}; ignoring")
        return None
    print(f"  S^{level}: shape {S_tensor.shape}")
    return S_tensor


def _load_s_matrices(path, n_users, n_pois_list):
    """Load S^l matrices as a list aligned with levels (None entries allowed), or None."""
    try:
        # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as f:
            S_data = pickle.load(f)
        S_by_level = S_data.get('S_matrices', {})
        # Always one entry per level so that S_list[l] refers to level l.
        S_list = [
            _coerce_s_matrix(S_by_level.get(f'level_{level}'), level, (n_users, n_pois))
            for level, n_pois in enumerate(n_pois_list)
        ]
    except Exception as e:
        # Best effort: without S matrices training simply runs without L3.
        print(f"  Warning: Could not load S matrices: {e}")
        return None

    if all(s is None for s in S_list):
        return None
    return S_list


def load_joint_data(paths):
    """Load user/POI attributes, interaction matrices R^l and optional S^l.

    Args:
        paths: dict with keys 'user_emb', 'poi_emb' (pickles), 'csv'
            (interaction log with 'user_id' and 'poi_id' columns) and
            'poi_tree' (JSON). The 's_matrix' pickle key is optional.

    Returns:
        dict with keys 'X', 'Y_list', 'R_list' (scipy CSR or None per level),
        'S_list' (list aligned with levels, or None), 'n_users',
        'n_pois_list', 'poi_attr_dims' and 'user_ids'.
    """
    print("Loading data...")
    n_levels = 4

    X, user_ids = _load_user_attributes(paths['user_emb'])
    n_users = X.shape[0]

    # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
    with open(paths['poi_emb'], 'rb') as f:
        poi_data = pickle.load(f)
    Y_list, n_pois_list, poi_attr_dims = _load_poi_attributes(poi_data, n_levels)

    print("\nBuilding R^l from CSV...")
    df = pd.read_csv(paths['csv'])
    with open(paths['poi_tree'], 'r', encoding='utf-8') as f:
        poi_tree = json.load(f)
    R_list = _build_interaction_matrices(df, poi_tree, poi_data, user_ids,
                                         n_users, n_levels)

    S_list = None
    if 's_matrix' in paths:
        S_list = _load_s_matrices(paths['s_matrix'], n_users, n_pois_list)

    return {
        'X': X, 'Y_list': Y_list, 'R_list': R_list, 'S_list': S_list,
        'n_users': n_users, 'n_pois_list': n_pois_list,
        'poi_attr_dims': poi_attr_dims, 'user_ids': user_ids,
    }


def run_joint_optimization(paths, latent_dim=64, alpha=0.3, beta=0.2,
                          gamma=1.0, lr=0.005, epochs=150, device='cpu'):
    """Load the data, fit JointMultiLevelPOIModel with Adam and export the factors.

    Args:
        paths: input paths consumed by ``load_joint_data``. If the dict has an
            'output' key, the results dict is pickled to that path.
        latent_dim: latent dimensionality k.
        alpha: weight of L2 (BPR ranking).
        beta: weight of L3 (S-alignment).
        gamma: weight of L1 (attribute reconstruction).
        lr: Adam learning rate.
        epochs: number of full-batch optimisation steps.
        device: torch device the model and dense tensors are moved to.

    Returns:
        ``(model, results, history)``. ``results`` holds numpy copies of
        U_u, U_p_levels, V_u and V_p_levels, plus the loss history and the
        config. ``history`` maps 'total'/'L1'/'L2'/'L3' to per-epoch floats.
    """
    data = load_joint_data(paths)

    # Which optional training signals exist: R^l feeds L2, S^l feeds L3.
    R_list = data['R_list']  # scipy sparse; stays on the host, indexed per step
    has_R = bool(R_list) and any(r is not None and r.nnz > 0 for r in R_list)
    has_S = data.get('S_list') is not None

    if not has_R and not has_S:
        print("\nWARNING: No interaction data (R^l) or feature matrices (S^l) found!")
        print("Training will only reconstruct attributes (L1 only)")

    model = JointMultiLevelPOIModel(
        n_users=data['n_users'],
        n_pois_per_level=data['n_pois_list'],
        user_attr_dim=data['X'].shape[1],
        poi_attr_dims=data['poi_attr_dims'],
        latent_dim=latent_dim,
        use_s_matrix=has_S
    ).to(device)

    X = data['X'].to(device)
    Y_list = [y.to(device) for y in data['Y_list']]
    S_list = None
    if has_S:
        S_list = [s.to(device) if s is not None else None for s in data['S_list']]

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    history = {'total': [], 'L1': [], 'L2': [], 'L3': []}
    print(f"\nTraining: k={latent_dim}, α={alpha}, β={beta}, γ={gamma}")
    print(f"Device: {device}")

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        total_loss, L1, L2, L3 = model(
            X, Y_list, R_list, S_list,
            alpha=alpha, beta=beta, gamma=gamma
        )

        total_loss.backward()
        # The losses are unnormalised sums of squares; clipping keeps steps bounded.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        for key, value in (('total', total_loss), ('L1', L1), ('L2', L2), ('L3', L3)):
            history[key].append(value.item())

        if (epoch + 1) % 20 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d} | "
                  f"Total: {history['total'][-1]:.2f} | "
                  f"L1: {history['L1'][-1]:.2f} | "
                  f"L2: {history['L2'][-1]:.4f} | "
                  f"L3: {history['L3'][-1]:.2f}")

    def to_numpy(tensor):
        # Parameters require grad; detach explicitly instead of relying on the
        # ambient grad mode for .numpy() to be allowed.
        return tensor.detach().cpu().numpy()

    model.eval()
    results = {
        'U_u': to_numpy(model.U_u),
        'U_p_levels': [to_numpy(up) for up in model.U_p_levels],
        'V_u': to_numpy(model.V_u),
        'V_p_levels': [to_numpy(vp) for vp in model.V_p_levels],
        'history': history,
        'config': {'latent_dim': latent_dim, 'alpha': alpha,
                   'beta': beta, 'gamma': gamma}
    }

    if 'output' in paths:
        with open(paths['output'], 'wb') as f:
            pickle.dump(results, f)
        print(f"\nSaved to {paths['output']}")

    return model, results, history


if __name__ == "__main__":
    # Input locations and the pickle the optimised factors are written to.
    files_dir = '../Sources/Files'
    emb_dir = '../Sources/Embeddings v3'
    paths = {
        'csv': f'{files_dir}/user_poi_interactions.csv',
        'poi_tree': f'{files_dir}/poi_tree_with_uuids.json',
        'user_emb': f'{emb_dir}/user_embeddings.pkl',
        'poi_emb': f'{emb_dir}/poi_embeddings.pkl',
        's_matrix': f'{emb_dir}/S_matrices_feature.pkl',
        'output': f'{emb_dir}/joint_optimized_final.pkl',
    }

    # Loss weights: gamma -> L1 (attributes), alpha -> L2 (BPR), beta -> L3 (S-alignment).
    config = dict(
        latent_dim=64,
        alpha=0.3,
        beta=0.2,
        gamma=1.0,
        lr=0.005,
        epochs=200,
        device='cpu',
    )

    model, results, history = run_joint_optimization(paths, **config)

    rule = "=" * 60
    print("\n" + rule)
    print("FINAL RESULTS")
    print(rule)
    # (label, history key, format spec) for the last-epoch value of each loss term.
    for label, key, spec in (
        ("L1 (Attribute): ", 'L1', '.2f'),
        ("L2 (BPR):       ", 'L2', '.4f'),
        ("L3 (S-align):   ", 'L3', '.2f'),
    ):
        print(f"{label}{history[key][-1]:{spec}}")

Loading data...
  Users: 21, attributes: 71
  Level 0: 4696 POIs, 442 attributes
  Level 1: 1355 POIs, 342 attributes
  Level 2: 44 POIs, 250 attributes
  Level 3: 5 POIs, 210 attributes

Building R^l from CSV...
  CSV rows: 567, Valid rows: 567
  Unique users in CSV: 21
  Users mapped: 21/21
  Level 0: No interactions
  Level 1: No interactions
  Level 2: No interactions
  Level 3: No interactions
  S^0: shape torch.Size([21, 4696])
  S^1: shape torch.Size([21, 1355])
  S^2: shape torch.Size([21, 44])
  S^3: shape torch.Size([21, 5])

Training: k=64, α=0.3, β=0.2, γ=1.0
Device: cpu
Epoch   1 | Total: 298813.16 | L1: 293260.59 | L2: 0.0000 | L3: 27762.76
Epoch  20 | Total: 271485.09 | L1: 266209.50 | L2: 0.0000 | L3: 26378.03
Epoch  40 | Total: 209999.73 | L1: 205365.41 | L2: 0.0000 | L3: 23171.62
Epoch  60 | Total: 132633.59 | L1: 128838.66 | L2: 0.0000 | L3: 18974.64
Epoch  80 | Total: 72768.44 | L1: 69875.20 | L2: 0.0000 | L3: 14466.23
Epoch 100 | Total: 38359.95 | L1: 36212.89 | L2