In [49]:
import json
import numpy as np
import tensorflow as tf
from collections import defaultdict
import pickle

## Step 1: Load POI Tree 

In [50]:
class POITreeLoader:
    """Load a POI hierarchy from JSON and build per-level index mappings.

    The JSON document must contain the keys ``level_0`` .. ``level_3``. Each
    maps POI ids to a record; a record may carry a ``parent`` entry naming the
    POI it belongs to.

    Attributes:
        poi_tree: the parsed JSON document.
        level_names: label for each level number.
        num_levels: number of levels that are indexed (4).
        poi_to_idx: ``{level: {poi_id: row_index}}``; indices follow the order
            in which ids appear in the JSON document.
        idx_to_poi: inverse of ``poi_to_idx``.
        parent_child: ``{parent_id: [child_id, ...]}`` gathered over all levels.
        poi_level_counts: ``{level: number_of_pois}``.
    """

    def __init__(self, poi_tree_path):
        """Read ``poi_tree_path`` and index its contents.

        Raises:
            KeyError: if one of the ``level_<n>`` keys is missing.
        """
        # Decode explicitly as UTF-8 instead of relying on the platform default
        with open(poi_tree_path, 'r', encoding='utf-8') as f:
            self.poi_tree = json.load(f)

        self.level_names = {
            0: 'individual_poi',
            1: 'container_poi',
            2: 'street_poi',
            3: 'district_poi'
        }

        self.num_levels = 4
        self._build_mappings()

    def _build_mappings(self):
        """Build POI ID mappings and parent-child relationships"""
        self.poi_to_idx = {}
        self.idx_to_poi = {}
        self.parent_child = defaultdict(list)
        self.poi_level_counts = {}

        for level in range(self.num_levels):
            level_key = f'level_{level}'
            if level_key not in self.poi_tree:
                raise KeyError(f"POI tree is missing required key '{level_key}'")
            pois = self.poi_tree[level_key]

            id_to_index = {}
            index_to_id = {}
            for idx, (poi_id, poi_data) in enumerate(pois.items()):
                id_to_index[poi_id] = idx
                index_to_id[idx] = poi_id

                # A missing or empty parent marks a POI with no recorded parent
                parent = poi_data.get('parent')
                if parent:
                    self.parent_child[parent].append(poi_id)

            self.poi_to_idx[level] = id_to_index
            self.idx_to_poi[level] = index_to_id
            self.poi_level_counts[level] = len(pois)

        print("POI Tree Mappings Built:")
        for level in range(self.num_levels):
            print(f"  Level {level} ({self.level_names[level]}): {self.poi_level_counts[level]} POIs")

# Instantiate the loader; the relative path is resolved against the kernel's working directory.
poi_loader = POITreeLoader("../Sources/Files/poi_tree_with_uuids.json")

POI Tree Mappings Built:
  Level 0 (individual_poi): 4696 POIs
  Level 1 (container_poi): 1355 POIs
  Level 2 (street_poi): 44 POIs
  Level 3 (district_poi): 5 POIs


## Step 2: Load All Embedding Files for Training

In [None]:
class MultiLevelDataPreparation:
    def __init__(self, poi_loader, user_embeddings, poi_embeddings, 
                interactions, metadata, target_poi_dim=20):
        self.poi_loader = poi_loader
        self.num_levels = poi_loader.num_levels
        self.num_users = metadata['counts']['users']
        self.target_poi_dim = target_poi_dim 
        
        # Load embeddings
        self.user_embeddings = user_embeddings
        self.poi_embeddings = poi_embeddings
        self.interactions = interactions
        self.metadata = metadata
        
        self._prepare_attribute_matrices()
        self._prepare_interaction_matrices()
    
    # def _prepare_interaction_matrices(self):
    #     """
    #     Parse interaction data from nested structure with 'edges' key
    #     """
    #     print("\nPreparing Interaction Matrices for L2...")
        
    #     self.positive_samples = {}
        
    #     # Step 1: Extract positive samples for all levels
    #     for level in range(self.num_levels):
    #         level_key = f'level_{level}'
            
    #         pos_samples = self._extract_samples_from_level(
    #             self.interactions['interactions'], 
    #             level_key, 
    #             'positive'
    #         )
    #         self.positive_samples[level] = pos_samples
        
    #     # Step 2: Generate negative samples for all levels at once
    #     self._generate_negative_samples(num_negatives_per_positive=1)
        
    #     # Step 3: Print summary
    #     for level in range(self.num_levels):
    #         print(f"  Level {level}: {self.positive_samples[level].shape[0]} positive, "
    #             f"{self.negative_samples[level].shape[0]} negative samples")
            
    # def _extract_samples_from_level(self, data_dict, level_key, sample_type):
    #     """
    #     Extract user-POI pairs from the nested structure
        
    #     Expected structure:
    #     data_dict[level_key] = {
    #         'edges': {
    #             'user_indices': array([...]),
    #             'poi_indices': array([...]),
    #             ...
    #         },
    #         'matrices': {...},
    #         'user_to_pois': {...}
    #     }
        
    #     âœ… Returns NumPy array (not TensorFlow tensor)
    #     """
    #     if level_key not in data_dict:
    #         print(f"  Warning: {level_key} not found in {sample_type} samples")
    #         return np.array([], dtype=np.int32).reshape(0, 2)
        
    #     level_data = data_dict[level_key]
        
    #     # Check if it's the nested structure with 'edges'
    #     if isinstance(level_data, dict) and 'edges' in level_data:
    #         edges = level_data['edges']
            
    #         # Extract indices from edges
    #         if 'user_indices' in edges and 'poi_indices' in edges:
    #             user_indices = edges['user_indices']
    #             poi_indices = edges['poi_indices']
                
    #             # Convert to numpy if needed
    #             if not isinstance(user_indices, np.ndarray):
    #                 user_indices = np.array(user_indices, dtype=np.int32)
    #             if not isinstance(poi_indices, np.ndarray):
    #                 poi_indices = np.array(poi_indices, dtype=np.int32)
                
    #             # Ensure they have the same length
    #             if len(user_indices) != len(poi_indices):
    #                 print(f"  Warning: Mismatched lengths for level {level_key}: "
    #                     f"users={len(user_indices)}, pois={len(poi_indices)}")
    #                 min_len = min(len(user_indices), len(poi_indices))
    #                 user_indices = user_indices[:min_len]
    #                 poi_indices = poi_indices[:min_len]
                
    #             # Stack into [N, 2] array
    #             samples = np.column_stack([user_indices, poi_indices])
    #             return samples  # âœ… Return NumPy array
    #         else:
    #             print(f"  Warning: 'user_indices' or 'poi_indices' not found in {level_key} edges")
    #             return np.array([], dtype=np.int32).reshape(0, 2)
        
    #     # If it's directly a dict, list, or array
    #     elif isinstance(level_data, (list, np.ndarray)):
    #         try:
    #             data_array = np.array(level_data, dtype=np.int32)
    #             if len(data_array.shape) == 2 and data_array.shape[1] == 2:
    #                 return data_array
    #         except Exception as e:
    #             print(f"  Warning: Could not parse {level_key} as array: {e}")
        
    #     print(f"  Warning: Unknown structure for {level_key} in {sample_type} samples")
    #     return np.array([], dtype=np.int32).reshape(0, 2)
    
    # def _generate_negative_samples(self, num_negatives_per_positive=1):
    #     """
    #     Generate multiple negative samples per positive sample
    #     âœ… Works on all levels at once
    #     """
    #     self.negative_samples = {}
        
    #     for level in range(self.num_levels):
    #         pos_samples = self.positive_samples[level]
            
    #         if pos_samples.shape[0] == 0:
    #             self.negative_samples[level] = np.array([], dtype=np.int32).reshape(0, 2)
    #             continue
            
    #         neg_samples = []
    #         num_pois = self.poi_loader.poi_level_counts[level]
            
    #         # âœ… FIX: Ensure pos_samples is NumPy array (not TensorFlow tensor)
    #         if isinstance(pos_samples, tf.Tensor):
    #             pos_samples = pos_samples.numpy()
            
    #         # Build user -> positive POIs mapping
    #         user_positive_pois = {}
    #         for user_idx, poi_idx in pos_samples:
    #             user_idx = int(user_idx)  # âœ… Convert to Python int
    #             poi_idx = int(poi_idx)
                
    #             if user_idx not in user_positive_pois:
    #                 user_positive_pois[user_idx] = set()
    #             user_positive_pois[user_idx].add(poi_idx)
            
    #         # Generate negative samples
    #         for i in range(pos_samples.shape[0]):
    #             user_idx = int(pos_samples[i, 0])
    #             pos_poi_set = user_positive_pois.get(user_idx, set())
                
    #             # Sample N negative POIs for this user
    #             for _ in range(num_negatives_per_positive):
    #                 neg_poi_idx = np.random.randint(0, num_pois)
                    
    #                 # Ensure it's not a positive POI
    #                 max_attempts = 100
    #                 attempts = 0
    #                 while neg_poi_idx in pos_poi_set and attempts < max_attempts:
    #                     neg_poi_idx = np.random.randint(0, num_pois)
    #                     attempts += 1
                    
    #                 neg_samples.append([user_idx, neg_poi_idx])
            
    #         self.negative_samples[level] = np.array(neg_samples, dtype=np.int32)
            
    #         print(f"  Level {level}: Generated {len(neg_samples)} negative samples "
    #             f"({num_negatives_per_positive} per positive)")
            
    def _prepare_attribute_matrices(self):
        """Build the raw attribute matrices and store standardised tensors.

        Sets ``self.X_raw`` / ``self.Y_raw[level]`` (NumPy) and the column-wise
        z-scored ``self.X`` / ``self.Y[level]`` (float32 tensors). A level
        whose overall standard deviation is below 1e-6 gets small Gaussian
        noise added before it is standardised.
        """
        print("\nPreparing Attribute Matrices for L1...")

        # --- raw matrices -------------------------------------------------
        self.X_raw = self._build_user_attribute_matrix()
        self.Y_raw = {}
        for lvl in range(self.num_levels):
            self.Y_raw[lvl] = self._build_poi_attribute_matrix(lvl)

            # Report a width that differs from the configured target
            n_features = self.Y_raw[lvl].shape[1]
            if n_features != self.target_poi_dim:
                print(f"  WARNING: Level {lvl} has {n_features} features, "
                    f"expected {self.target_poi_dim}")

        # --- users: column-wise z-score ------------------------------------
        user_mean = np.mean(self.X_raw, axis=0, keepdims=True)
        user_std = np.std(self.X_raw, axis=0, keepdims=True) + 1e-8  # epsilon avoids division by zero
        self.X = tf.constant((self.X_raw - user_mean) / user_std, dtype=tf.float32)

        user_values = self.X.numpy()
        print(f"  User attribute matrix X:")
        print(f"    Shape: {self.X.shape}")
        print(f"    Normalized: mean={user_values.mean():.4f}, std={user_values.std():.4f}")

        # --- POIs: column-wise z-score, one matrix per level ---------------
        self.Y = {}
        for lvl in range(self.num_levels):
            level_raw = self.Y_raw[lvl]

            # A (near-)constant matrix is perturbed slightly before scaling
            if level_raw.std() < 1e-6:
                print(f"  WARNING: Level {lvl} has near-zero variance! Adding small noise.")
                level_raw = level_raw + np.random.randn(*level_raw.shape).astype(np.float32) * 0.01

            level_mean = np.mean(level_raw, axis=0, keepdims=True)
            level_std = np.std(level_raw, axis=0, keepdims=True) + 1e-8
            Y_normalized = (level_raw - level_mean) / level_std

            self.Y[lvl] = tf.constant(Y_normalized, dtype=tf.float32)

            print(f"  POI attribute matrix Y^{lvl}:")
            print(f"    Shape: {self.Y[lvl].shape}")
            print(f"    Normalized: mean={Y_normalized.mean():.4f}, std={Y_normalized.std():.4f}")

    def _build_user_attribute_matrix(self):
        """Build user attribute matrix"""
        user_ids = self.metadata['user_ids']
        features_list = []
        
        for user_id in user_ids:
            if user_id in self.user_embeddings['user_embeddings']:
                user_data = self.user_embeddings['user_embeddings'][user_id]
                features = self._extract_user_features(user_data)
            else:
                features = np.zeros(self._get_user_feature_dim())
            
            features_list.append(features)
        
        X = np.array(features_list, dtype=np.float32)
        return X
    
    def _get_user_feature_dim(self):
        """Get dimension of user features"""
        if self.user_embeddings['user_embeddings']:
            sample_user_id = list(self.user_embeddings['user_embeddings'].keys())[0]
            sample_data = self.user_embeddings['user_embeddings'][sample_user_id]
            sample_features = self._extract_user_features(sample_data)
            return len(sample_features)
        return 10
    
    def _extract_user_features(self, user_data):
        """Extract feature vector from user data"""
        features = []
        
        if isinstance(user_data, (np.ndarray, list)):
            return np.array(user_data, dtype=np.float32).flatten()
        
        if isinstance(user_data, dict):
            for key in sorted(user_data.keys()):
                value = user_data[key]
                if isinstance(value, (int, float, np.number)):
                    features.append(float(value))
                elif isinstance(value, (list, np.ndarray)):
                    arr = np.array(value).flatten()
                    features.extend(arr.tolist())
        
        if not features:
            features = np.zeros(10)
        
        return np.array(features, dtype=np.float32)
    
    def _parse_price(self, price_str):
        """Parse price string"""
        try:
            price_str = str(price_str).replace('$', '').strip()
            parts = price_str.split(' - ')
            if len(parts) == 2:
                avg_price = (float(parts[0]) + float(parts[1])) / 2.0
                return avg_price / 50.0
            elif len(parts) == 1:
                return float(parts[0]) / 50.0
        except:
            pass
        return 0.0
    
    def _build_poi_attribute_matrix(self, level):
        """Build Y^l with consistent dimensions across all levels"""
        level_key = f'level_{level}'
        pois = self.poi_loader.poi_tree[level_key]
        
        poi_features = []
        for poi_id in sorted(pois.keys()):
            poi_data = pois[poi_id]['data']
            features = self._extract_poi_features_by_level(poi_data, level)
            
            # Pad or truncate to target dimension
            if len(features) < self.target_poi_dim:
                # Pad with zeros
                padding = np.zeros(self.target_poi_dim - len(features), dtype=np.float32)
                features = np.concatenate([features, padding])
            elif len(features) > self.target_poi_dim:
                # Truncate
                features = features[:self.target_poi_dim]
            
            poi_features.append(features)
        
        Y_l = np.array(poi_features, dtype=np.float32)
        return Y_l
    
    def _extract_poi_features_by_level(self, poi_data, level):
        """Extract features specific to each level"""
        
        if level in [0, 1]:  # individual_poi, container_poi
            return self._extract_detailed_poi_features(poi_data)
        elif level == 2:  # street_poi
            return self._extract_street_features(poi_data)
        elif level == 3:  # district_poi
            return self._extract_district_features(poi_data)
        
    def _extract_detailed_poi_features(self, poi_data):
        """Features for individual POIs and containers"""
        features = []
        
        # Normalized coordinates
        if 'latitude' in poi_data and 'longitude' in poi_data:
            lat_norm = (float(poi_data['latitude']) - 1.2) / 0.3
            lon_norm = (float(poi_data['longitude']) - 103.6) / 0.4
            features.extend([np.clip(lat_norm, 0, 1), np.clip(lon_norm, 0, 1)])
        else:
            features.extend([0.5, 0.5])
        
        # Category
        categories = ['supermarket', 'restaurant', 'cafe', 'shop', 'entertainment', 
                    'hotel', 'mall', 'school', 'hospital', 'atm', 'shopping_mall', 
                    'convenience_store', 'other']
        category = poi_data.get('category', 'other')
        category_encoding = [1.0 if cat == category else 0.0 for cat in categories]
        features.extend(category_encoding)
        
        # Price (normalized)
        price_val = self._parse_price(poi_data.get('price', '0'))
        features.append(price_val)
        
        # Popularity (normalized)
        popularity = float(poi_data.get('popularity', 0)) / 5.0
        features.append(popularity)
        
        # Characteristics
        char_str = str(poi_data.get('characteristic', '')).lower()
        characteristics = ['budget', 'premium', 'family', 'essentials', 'luxury']
        char_encoding = [1.0 if char in char_str else 0.0 for char in characteristics]
        features.extend(char_encoding)
        
        return np.array(features, dtype=np.float32)
    
    def _extract_street_features(self, poi_data):
        """Features for streets (Level 2) - same dimension as detailed features"""
        features = []
        
        # 1. Coordinates (2 features)
        if 'latitude' in poi_data and 'longitude' in poi_data:
            lat_norm = (float(poi_data['latitude']) - 1.2) / 0.3
            lon_norm = (float(poi_data['longitude']) - 103.6) / 0.4
            features.extend([np.clip(lat_norm, 0, 1), np.clip(lon_norm, 0, 1)])
        else:
            features.extend([0.5, 0.5])
        
        # 2. District/Area one-hot (8 features - to match category dimension)
        district = poi_data.get('district', 'OTHER')
        districts = [
            "Orchard Road",
            "Scotts Road",
            "Bras Basah Road",
            "Bugis Street",
            "Victoria Street",
            "North Bridge Road",
            "Beach Road",
            "Arab Street",
            "Haji Lane",
            "Chinatown Point Road",
            "South Bridge Road",
            "Eu Tong Sen Street",
            "New Bridge Road",
            "Tanjong Pagar Road",
            "Anson Road",
            "Shenton Way",
            "Raffles Place",
            "Collyer Quay",
            "Marina Boulevard",
            "Marina Bay Sands Drive",
            "Cecil Street",
            "Robinson Road",
            "Telok Ayer Street",
            "Amoy Street",
            "Keong Saik Road",
            "Bukit Timah Road",
            "Holland Road",
            "Clementi Road",
            "Upper Thomson Road",
            "Serangoon Road",
            "Balestier Road",
            "Thomson Road",
            "Geylang Road",
            "Paya Lebar Road",
            "East Coast Road",
            "Tampines Avenue 1",
            "Pasir Ris Drive 1",
            "Ang Mo Kio Avenue 3",
            "Yishun Ring Road",
            "Woodlands Avenue 6",
            "Jurong West Street 41",
            "Boon Lay Way",
            "Choa Chu Kang Avenue 4",
            "Bukit Batok Road",
            "Sengkang East Way",
            "Punggol Central"
        ]

        district_encoding = [1.0 if d == district else 0.0 for d in districts]
        features.extend(district_encoding)
        
        # 3. Average price level (1 feature)
        features.append(float(poi_data.get('avg_price', 0)))
        
        # 4. Average popularity (1 feature)
        features.append(float(poi_data.get('avg_popularity', 0)) / 5.0)
        
        # 5. Street characteristics (5 features - to match characteristics dimension)
        street_type = str(poi_data.get('type', '')).lower()
        street_types = ['commercial', 'residential', 'mixed', 'tourist', 'other']
        type_encoding = [1.0 if st in street_type else 0.0 for st in street_types]
        features.extend(type_encoding)
        
        # Total: 2 + 8 + 1 + 1 + 5 = 17 features
        return np.array(features, dtype=np.float32)
    
    def _extract_district_features(self, poi_data):
        """Features for districts (Level 3) - same dimension as detailed features"""
        features = []
        
        # 1. District center coordinates (2 features)
        if 'latitude' in poi_data and 'longitude' in poi_data:
            lat_norm = (float(poi_data['latitude']) - 1.2) / 0.3
            lon_norm = (float(poi_data['longitude']) - 103.6) / 0.4
            features.extend([np.clip(lat_norm, 0, 1), np.clip(lon_norm, 0, 1)])
        else:
            features.extend([0.5, 0.5])
        
        # 2. District name one-hot (8 features)
        district_name = poi_data.get('name', 'OTHER')
        districts = [
            "ANG MO KIO",
            "BEDOK",
            "BISHAN",
            "BOON LAY",
            "BUKIT BATOK",
            "BUKIT MERAH",
            "BUKIT PANJANG",
            "BUKIT TIMAH",
            "CENTRAL WATER CATCHMENT",
            "CHANGI",
            "CHANGI BAY",
            "CHOA CHU KANG",
            "CLEMENTI",
            "DOWNTOWN CORE",
            "GEYLANG",
            "HOUGANG",
            "JURONG EAST",
            "JURONG WEST",
            "KALLANG",
            "LIM CHU KANG",
            "MANDAI",
            "MARINA EAST",
            "MARINA SOUTH",
            "MARINA BAY",
            "NOVENA",
            "ORCHARD",
            "OUTRAM",
            "PASIR RIS",
            "PAYA LEBAR",
            "PIONEER",
            "PUNGGOL",
            "QUEENSTOWN",
            "RIVER VALLEY",
            "ROCHOR",
            "SELETAR",
            "SEMBAWANG",
            "SENGKANG",
            "SERANGOON",
            "SIMPANG",
            "SOUTHERN ISLANDS",
            "STRAITS VIEW",
            "SUNGEI KADUT",
            "TAMPINES",
            "TANGLIN",
            "TENGAH",
            "TOA PAYOH",
            "TUAS",
            "WESTERN ISLANDS",
            "WESTERN WATER CATCHMENT",
            "WOODLANDS",
            "YISHUN"
        ]

        district_encoding = [1.0 if d == district_name else 0.0 for d in districts]
        features.extend(district_encoding)
        
        # 3. District density (1 feature)
        density = float(poi_data.get('num_pois', 0)) / 1000.0  # Normalize
        features.append(density)
        
        # 4. District average popularity (1 feature)
        features.append(float(poi_data.get('avg_popularity', 0)) / 5.0)
        
        # 5. Region one-hot (5 features)
        region = poi_data.get('region', 'OTHER')
        regions = ['CENTRAL', 'NORTH', 'SOUTH', 'EAST', 'WEST']
        region_encoding = [1.0 if r == region else 0.0 for r in regions]
        features.extend(region_encoding)
        
        # Total: 2 + 8 + 1 + 1 + 5 = 17 features
        return np.array(features, dtype=np.float32)
    

In [52]:
class JointLossComputation:
    def __init__(self, data_prep, embedding_dim=32):
        self.data_prep = data_prep
        self.embedding_dim = embedding_dim
        self.num_levels = data_prep.num_levels
        self.num_users = data_prep.num_users

        self.L1_scale = None
        self.L2_scale = None
        self.calibration_epochs = 10
        self.epoch_counter = 0
        
        # Initialize learnable parameters
        self._initialize_parameters()
    
    def _initialize_parameters(self):
        """Create the trainable latent and transformation matrices.

        Every matrix has ``embedding_dim`` columns and is drawn from a normal
        distribution with ``stddev = sqrt(2 / (rows + embedding_dim))``:

        * ``U_u``          - one row per user
        * ``U_p[level]``   - one row per POI of that level
        * ``V_u``          - one row per user feature
        * ``V_p[level]``   - one row per POI feature of that level
        """
        dim = self.embedding_dim

        def scaled_normal(rows, name):
            # Shared initialiser: spread shrinks as the matrix grows
            stddev = np.sqrt(2.0 / (rows + dim))
            return tf.Variable(
                tf.random.normal([rows, dim], stddev=stddev),
                name=name,
                dtype=tf.float32
            )

        # User latent matrix
        self.U_u = scaled_normal(self.num_users, 'U_u')

        # POI latent matrices, one per level
        self.U_p = {}
        for level in range(self.num_levels):
            level_size = self.data_prep.poi_loader.poi_level_counts[level]
            self.U_p[level] = scaled_normal(level_size, f'U_p_level_{level}')

        # User transformation matrix
        self.V_u = scaled_normal(self.data_prep.X.shape[1], 'V_u')

        # Level-specific 2-D transformation matrices: [poi_feature_dim, embedding_dim]
        self.V_p = {}
        for level in range(self.num_levels):
            poi_feature_dim = self.data_prep.Y[level].shape[1]
            self.V_p[level] = scaled_normal(poi_feature_dim, f'V_p_level_{level}')

            print(f"  V_p^{level}: shape {self.V_p[level].shape} (feature_dim={poi_feature_dim}, embed_dim={self.embedding_dim})")

        # Shape summary
        print(f"\nInitialized Parameters:")
        print(f"  U_u: {self.U_u.shape}")
        for level in range(self.num_levels):
            print(f"  U_p^{level}: {self.U_p[level].shape}")
        print(f"  V_u: {self.V_u.shape}")
        for level in range(self.num_levels):
            print(f"  V_p^{level}: {self.V_p[level].shape}")

    def compute_L1_loss(self):
        """Attribute reconstruction loss (mean squared error).

        Users: ``U_u @ V_u^T`` is compared with ``data_prep.X``. POIs: for each
        level ``U_p[l] @ V_p[l]^T`` is compared with ``data_prep.Y[l]``.

        Returns:
            ``(L1_total, L1_user, L1_poi, poi_losses)`` where ``L1_poi`` is the
            sum of the per-level losses, ``L1_total = L1_user + L1_poi`` and
            ``poi_losses`` maps level to that level's loss.
        """
        def reconstruction_mse(latent, transform, target):
            # Mean (not sum) of the squared reconstruction error
            rebuilt = tf.matmul(latent, transform, transpose_b=True)
            return tf.reduce_mean(tf.square(rebuilt - target))

        L1_user = reconstruction_mse(self.U_u, self.V_u, self.data_prep.X)

        poi_losses = {}
        L1_poi = tf.constant(0.0, dtype=tf.float32)
        for level in range(self.num_levels):
            poi_losses[level] = reconstruction_mse(
                self.U_p[level], self.V_p[level], self.data_prep.Y[level]
            )
            L1_poi += poi_losses[level]

        return L1_user + L1_poi, L1_user, L1_poi, poi_losses
    
    def compute_L2_loss(self):
        """
        Compute L2 (BPR) loss for ranking.

        L2 = - sum_l sum_i log sigmoid(u_i . v_pos^l - u_i . v_neg^l)

        For every level, row ``k`` of ``data_prep.positive_samples[level]`` and
        row ``k`` of ``data_prep.negative_samples[level]`` form one training
        pair. The user index is taken from the positive row and used for both
        scores; only the POI column of the negative row is read.

        Returns:
            ``(L2_total, level_losses)`` - the summed loss and a dict mapping
            each level to its own loss (0.0 for levels without positives).

        Raises:
            ValueError: if a level's positive and negative arrays differ in
                length.
        """
        L2_total = tf.constant(0.0, dtype=tf.float32)
        level_losses = {}

        for level in range(self.num_levels):
            pos_samples = tf.constant(self.data_prep.positive_samples[level], dtype=tf.int32)
            neg_samples = tf.constant(self.data_prep.negative_samples[level], dtype=tf.int32)

            if pos_samples.shape[0] == 0:
                level_losses[level] = tf.constant(0.0, dtype=tf.float32)
                continue

            # pos_samples: [user_idx, pos_poi_idx]
            # neg_samples: [user_idx, neg_poi_idx]  (same user, row for row)
            user_indices = pos_samples[:, 0]
            pos_poi_indices = pos_samples[:, 1]
            neg_poi_indices = neg_samples[:, 1]

            # The pairing is positional, so both arrays must have the same length
            if pos_poi_indices.shape[0] != neg_poi_indices.shape[0]:
                raise ValueError(
                    f"Level {level}: Positive and negative samples must have same length. "
                    f"Got pos={pos_poi_indices.shape[0]}, neg={neg_poi_indices.shape[0]}"
                )

            # Embedding lookups
            u = tf.gather(self.U_u, user_indices)
            v_pos = tf.gather(self.U_p[level], pos_poi_indices)
            v_neg = tf.gather(self.U_p[level], neg_poi_indices)

            # Dot-product scores, one per pair
            scores_pos = tf.reduce_sum(u * v_pos, axis=1)
            scores_neg = tf.reduce_sum(u * v_neg, axis=1)
            diff = scores_pos - scores_neg

            # log_sigmoid is evaluated stably; log(sigmoid(x) + eps) caps the
            # loss and loses its gradient once sigmoid(x) drops below eps.
            loss = -tf.reduce_sum(tf.math.log_sigmoid(diff))

            L2_total += loss
            level_losses[level] = loss

        return L2_total, level_losses
    
    def compute_regularization(self):
        """Sum of squared entries over every parameter matrix (un-normalised).

        NOTE: a second ``compute_regularization`` is defined further down in
        this class body. Because it is bound later it replaces this one, so
        this version is never called as the class stands.
        """
        # Same accumulation order as the parameters are created: users first,
        # then U_p / V_p level by level.
        matrices = [self.U_u, self.V_u]
        for level in range(self.num_levels):
            matrices += [self.U_p[level], self.V_p[level]]

        reg = tf.reduce_sum(tf.square(matrices[0]))
        for matrix in matrices[1:]:
            reg += tf.reduce_sum(tf.square(matrix))

        return reg
    
    def _compute_L2_for_level(self, level):
        """Compute the BPR loss for a single level.

        Positive and negative rows are paired by position after both arrays
        are cut to the shorter length. Unlike ``compute_L2_loss``, the user
        index of the negative score is read from the negative row.

        Args:
            level: hierarchy level whose samples and POI embeddings are used.

        Returns:
            Scalar loss ``-sum(log sigmoid(score_pos - score_neg))``; 0.0 when
            either sample array is empty.
        """
        pos_samples = self.data_prep.positive_samples[level]
        neg_samples = self.data_prep.negative_samples[level]

        # Nothing to rank without both kinds of samples
        if pos_samples.shape[0] == 0 or neg_samples.shape[0] == 0:
            return tf.constant(0.0, dtype=tf.float32)

        # Pair rows positionally, up to the shorter array
        min_samples = min(pos_samples.shape[0], neg_samples.shape[0])
        pos_samples = pos_samples[:min_samples]
        neg_samples = neg_samples[:min_samples]

        # Embedding lookups, each [n_samples, embedding_dim]
        u_pos = tf.gather(self.U_u, pos_samples[:, 0])
        v_pos = tf.gather(self.U_p[level], pos_samples[:, 1])
        u_neg = tf.gather(self.U_u, neg_samples[:, 0])
        v_neg = tf.gather(self.U_p[level], neg_samples[:, 1])

        # Row-wise dot products, each [n_samples]
        scores_pos = tf.reduce_sum(u_pos * v_pos, axis=1)
        scores_neg = tf.reduce_sum(u_neg * v_neg, axis=1)

        # log_sigmoid is evaluated stably; log(sigmoid(x) + eps) caps the loss
        # and loses its gradient once sigmoid(x) drops below eps.
        diff = scores_pos - scores_neg
        loss = -tf.reduce_sum(tf.math.log_sigmoid(diff))

        return loss
    
    def compute_regularization(self):
        """Mean squared parameter value over all trainable matrices.

        Covers ``U_u``, ``V_u`` and, for every level in
        ``range(self.num_levels)``, ``U_p[level]`` and ``V_p[level]`` - the
        same set that ``get_trainable_variables`` returns.

        Returns:
            Scalar tensor: sum of squared entries divided by the total number
            of entries.
        """
        matrices = [self.U_u, self.V_u]
        for level in range(self.num_levels):
            matrices.append(self.U_p[level])
            matrices.append(self.V_p[level])

        squared_sum = tf.constant(0.0, dtype=tf.float32)
        total_params = tf.constant(0.0, dtype=tf.float32)
        for matrix in matrices:
            squared_sum += tf.reduce_sum(tf.square(matrix))
            # tf.size returns an integer count; cast it for the float division
            total_params += tf.cast(tf.size(matrix), tf.float32)

        # Normalise by the parameter count so the term does not grow with model size
        return squared_sum / total_params
        
    def compute_total_loss(self, lambda1, lambda2, reg_weight):
        """Combine attribute loss, ranking loss and regularisation.

        During the first ``calibration_epochs`` calls, ``L1_scale`` and
        ``L2_scale`` are set from the current loss values (first call) and then
        updated as an exponential running average (weight 0.9 on the old
        value). Both losses are divided by their scale, floored at 1e-6,
        before being weighted.

        Args:
            lambda1: weight of the normalised L1 (attribute) loss.
            lambda2: weight of the normalised L2 (interaction) loss.
            reg_weight: weight of the regularisation term.

        Returns:
            ``(total_loss, loss_components)``; the dict carries the
            un-normalised component values for logging.
        """
        L1_total, L1_user, L1_poi, L1_poi_levels = self.compute_L1_loss()
        L2_total, L2_levels = self.compute_L2_loss()
        L_reg = self.compute_regularization()

        # Calibration: track the typical magnitude of each loss
        if self.epoch_counter < self.calibration_epochs:
            if self.L1_scale is None:
                self.L1_scale = float(L1_total.numpy())
                self.L2_scale = float(L2_total.numpy())
            else:
                keep = 0.9  # share of the previous scale kept by the running average
                self.L1_scale = keep * self.L1_scale + (1 - keep) * float(L1_total.numpy())
                self.L2_scale = keep * self.L2_scale + (1 - keep) * float(L2_total.numpy())

            self.epoch_counter += 1

        # Normalisation: bring both losses to comparable ranges
        L1_normalized = L1_total / max(self.L1_scale, 1e-6)
        L2_normalized = L2_total / max(self.L2_scale, 1e-6)

        # Weighting
        weighted_l1 = lambda1 * L1_normalized
        weighted_l2 = lambda2 * L2_normalized
        weighted_reg = reg_weight * L_reg
        total_loss = weighted_l1 + weighted_l2 + weighted_reg

        # Un-normalised values for logging
        loss_components = dict(
            total_loss=total_loss,
            L1_total=L1_total,
            L1_user=L1_user,
            L1_poi=L1_poi,
            L1_poi_levels=L1_poi_levels,
            L2_total=L2_total,
            L2_levels=L2_levels,
            regularization=L_reg,
        )

        return total_loss, loss_components
    
    def get_trainable_variables(self):
        """Return list of trainable variables"""
        trainable_vars = [self.U_u, self.V_u]
        
        for level in range(self.num_levels):
            trainable_vars.append(self.U_p[level])
            trainable_vars.append(self.V_p[level]) 
        
        return trainable_vars
    

In [56]:
class MultiLevelTrainer:
    """
    Gradient-based trainer for a joint multi-level loss.

    Each epoch takes one Adam step on the loss returned by
    ``loss_computer.compute_total_loss`` with respect to the variables from
    ``loss_computer.get_trainable_variables()``, and records the loss
    components in ``self.history``.
    """

    # Scalar entries of the loss-component dict that are tracked per epoch
    _SCALAR_KEYS = ('total_loss', 'L1_total', 'L1_user', 'L1_poi',
                    'L2_total', 'regularization')
    # Entries of the loss-component dict that map level -> scalar loss
    _PER_LEVEL_KEYS = ('L1_poi_levels', 'L2_levels')

    def __init__(self, loss_computer, learning_rate=0.001, clipnorm=1.0):
        """
        Args:
            loss_computer: object providing ``compute_total_loss(lambda1,
                lambda2, reg_weight)``, ``get_trainable_variables()``,
                ``num_levels`` and ``data_prep.poi_loader.level_names``.
            learning_rate: learning rate passed to the Adam optimizer.
            clipnorm: global-norm threshold applied to the gradients before
                every update; ``None`` disables clipping.
        """
        self.loss_computer = loss_computer
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.clipnorm = clipnorm
        self.history = defaultdict(list)

    def train(self, num_epochs=100, lambda1=0.5, lambda2=0.5, reg_weight=0.01,
            verbose=True, log_interval=10):
        """
        Train the model with joint optimization across all POI levels

        Args:
            num_epochs: number of gradient steps to run.
            lambda1: weight forwarded to ``compute_total_loss`` for L1.
            lambda2: weight forwarded to ``compute_total_loss`` for L2.
            reg_weight: regularization weight forwarded to ``compute_total_loss``.
            verbose: when True, print progress for logged epochs.
            log_interval: print every ``log_interval``-th epoch (the last
                epoch is always printed); 0 or None prints only the last one.
        """
        print(f"\n{'='*80}")
        print(f"STARTING JOINT OPTIMIZATION OVER {self.loss_computer.num_levels} POI LEVELS")
        print(f"{'='*80}")
        print("Hyperparameters:")
        print(f"  λ1 (Attribute loss weight): {lambda1}")
        print(f"  λ2 (Interaction loss weight): {lambda2}")
        print(f"  Regularization weight: {reg_weight}")
        print(f"  Learning rate: {self.optimizer.learning_rate.numpy()}")
        print(f"  Epochs: {num_epochs}\n")

        for epoch in range(num_epochs):
            # Forward pass under the tape so gradients can be taken afterwards
            with tf.GradientTape() as tape:
                total_loss, loss_components = self.loss_computer.compute_total_loss(
                    lambda1, lambda2, reg_weight
                )

            # Compute gradients
            trainable_vars = self.loss_computer.get_trainable_variables()
            gradients = tape.gradient(total_loss, trainable_vars)

            # Honor the clipnorm given at construction time
            if self.clipnorm is not None:
                gradients, _ = tf.clip_by_global_norm(gradients, clip_norm=self.clipnorm)

            # Apply gradients
            self.optimizer.apply_gradients(zip(gradients, trainable_vars))

            # Log history
            self._log_history(loss_components)

            # Print progress; a zero/None interval must not trigger a modulo by zero
            on_interval = bool(log_interval) and epoch % log_interval == 0
            is_last_epoch = epoch == num_epochs - 1
            if verbose and (on_interval or is_last_epoch):
                self._print_progress(epoch, loss_components)

        print(f"\n{'='*80}")
        print("TRAINING COMPLETED")
        print(f"{'='*80}\n")

    def _log_history(self, loss_components):
        """
        Append the current loss components to ``self.history``.

        Scalar entries are stored as lists of Python floats. The per-level
        entries (``'L1_poi_levels'``, ``'L2_levels'``) are stored as
        ``{level: [float, ...]}`` dictionaries.
        """
        # Convert tensor scalars to Python floats
        for key in self._SCALAR_KEYS:
            self.history[key].append(float(loss_components[key].numpy()))

        for key in self._PER_LEVEL_KEYS:
            per_level = self.history.get(key)
            # history is a defaultdict(list): a stray read of this key leaves
            # an empty list behind, so anything that is not a dict is replaced
            if not isinstance(per_level, dict):
                per_level = {
                    level: [] for level in range(self.loss_computer.num_levels)
                }
                self.history[key] = per_level

            for level, loss_tensor in loss_components[key].items():
                per_level.setdefault(level, []).append(float(loss_tensor.numpy()))

    def _print_level_losses(self, level_losses, prefix):
        """Print one line per level present in ``level_losses``, led by ``prefix``."""
        level_names = self.loss_computer.data_prep.poi_loader.level_names
        for level in range(self.loss_computer.num_levels):
            if level in level_losses:
                level_name = level_names[level]
                loss_val = level_losses[level].numpy()
                print(f"{prefix} Level {level} ({level_name:15s}): {loss_val:12.4f}")

    def _print_progress(self, epoch, loss_components):
        """Print training progress"""
        print(f"\nEpoch {epoch:4d}:")
        print(f"  Total Loss: {loss_components['total_loss'].numpy():12.4f}")
        print(f"  ├─ L1 (Attribute):     {loss_components['L1_total'].numpy():12.4f}")
        print(f"  │   ├─ User:          {loss_components['L1_user'].numpy():12.4f}")
        print(f"  │   └─ POI (total):   {loss_components['L1_poi'].numpy():12.4f}")

        # Print L1 POI losses per level
        self._print_level_losses(loss_components['L1_poi_levels'], "  │       ├─")

        print(f"  ├─ L2 (Interaction):   {loss_components['L2_total'].numpy():12.4f}")

        # Print L2 losses per level
        self._print_level_losses(loss_components['L2_levels'], "  │   ├─")

        print(f"  └─ Regularization:     {loss_components['regularization'].numpy():12.4f}")

    def plot_training_history(self, save_path='training_history.png'):
        """
        Plot training curves in a 2x2 grid: total loss, L1 components,
        L2 total with per-level curves, and regularization.

        Args:
            save_path: file the figure is saved to (300 dpi) before it is
                shown; pass ``None`` to skip saving.
        """
        # Imported here so that matplotlib is only required for plotting
        import matplotlib.pyplot as plt

        # .get() keeps this read-only: indexing the defaultdict would create keys
        def series(key):
            return self.history.get(key, [])

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Total loss
        ax = axes[0, 0]
        ax.plot(series('total_loss'))
        ax.set_title('Total Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.grid(True)

        # L1 components
        ax = axes[0, 1]
        ax.plot(series('L1_total'), label='L1 Total')
        ax.plot(series('L1_user'), label='L1 User', alpha=0.7)
        ax.plot(series('L1_poi'), label='L1 POI', alpha=0.7)
        ax.set_title('L1 Loss Components')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)

        # L2 total plus one curve per level
        ax = axes[1, 0]
        ax.plot(series('L2_total'), label='L2 Total', linewidth=2)
        l2_levels = self.history.get('L2_levels') or {}
        for level in range(self.loss_computer.num_levels):
            if level in l2_levels:
                level_name = self.loss_computer.data_prep.poi_loader.level_names[level]
                ax.plot(l2_levels[level],
                        label=f'Level {level} ({level_name})', alpha=0.7)
        ax.set_title('L2 Loss Per Level')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)

        # Regularization
        ax = axes[1, 1]
        ax.plot(series('regularization'))
        ax.set_title('Regularization Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.grid(True)

        plt.tight_layout()
        if save_path is not None:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

# Directory holding the pickled inputs produced by the preprocessing step
EMBEDDINGS_DIR = '../Sources/Embeddings'


def load_pickle(path):
    """
    Deserialize and return the object stored in the pickle file at `path`.

    NOTE: pickle.load can execute arbitrary code contained in the file; only
    use this on artifacts generated by this project, never on untrusted input.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def describe_indices(step, name, values):
    """Print the type, shape (or length) and first five entries of `values`."""
    size = values.shape if hasattr(values, 'shape') else len(values)
    print(f"\n{step}. {name}:")
    print(f"   Type: {type(values)}")
    print(f"   Shape/Length: {size}")
    print(f"   First 5 values: {values[:5]}")


user_embeddings = load_pickle(f'{EMBEDDINGS_DIR}/user_embeddings.pkl')
poi_embeddings = load_pickle(f'{EMBEDDINGS_DIR}/poi_embeddings.pkl')
interactions = load_pickle(f'{EMBEDDINGS_DIR}/interactions.pkl')
metadata = load_pickle(f'{EMBEDDINGS_DIR}/metadata.pkl')

data_prep = MultiLevelDataPreparation(
    poi_loader,
    user_embeddings,
    poi_embeddings,
    interactions,
    metadata
)

# Initialize loss computation
loss_computer = JointLossComputation(data_prep, embedding_dim=16)

# loss_computer.validate_tensors()

# Initialize trainer and run training
trainer = MultiLevelTrainer(loss_computer, learning_rate=0.001)

# Debug inspection of the raw interaction structure for one level
print("\n" + "="*80)
print("DETAILED INTERACTION STRUCTURE INSPECTION")
print("="*80)

level_key = 'level_0'

print(f"\n1. Checking interactions['{level_key}']:")
if level_key in interactions['interactions']:
    level_data = interactions['interactions'][level_key]
    print(f"   Type: {type(level_data)}")
    print(f"   Keys: {list(level_data.keys()) if isinstance(level_data, dict) else 'Not a dict'}")

    # Only descend when the entry really is a dict; the line above already
    # reports the 'Not a dict' case
    if isinstance(level_data, dict) and 'edges' in level_data:
        edges = level_data['edges']
        print(f"\n2. Checking edges:")
        print(f"   Type: {type(edges)}")
        print(f"   Keys: {list(edges.keys()) if isinstance(edges, dict) else 'Not a dict'}")

        if isinstance(edges, dict):
            if 'user_indices' in edges:
                describe_indices(3, 'user_indices', edges['user_indices'])

            if 'poi_indices' in edges:
                describe_indices(4, 'poi_indices', edges['poi_indices'])

print("="*80)

trainer.train(
    num_epochs=500,
    lambda1=1.0,
    lambda2=1.0,
    reg_weight=0.001,
    verbose=True,
    log_interval=10
)

# Plot results
trainer.plot_training_history()

  poi_embeddings = pickle.load(f)
  interactions = pickle.load(f)



Preparing Attribute Matrices for L1...
  User attribute matrix X:
    Shape: (21, 71)
    Normalized: mean=-0.0000, std=0.9929
  POI attribute matrix Y^0:
    Shape: (4696, 20)
    Normalized: mean=0.0000, std=0.7416
  POI attribute matrix Y^1:
    Shape: (1355, 20)
    Normalized: mean=0.0000, std=0.6708
  POI attribute matrix Y^2:
    Shape: (44, 20)
    Normalized: mean=0.0000, std=0.0000
  POI attribute matrix Y^3:
    Shape: (5, 20)
    Normalized: mean=0.0000, std=0.0000

Preparing Interaction Matrices for L2...
  Level 0: Generated 257 negative samples (1 per positive)
  Level 1: Generated 231 negative samples (1 per positive)
  Level 2: Generated 180 negative samples (1 per positive)
  Level 3: Generated 71 negative samples (1 per positive)
  Level 0: 257 positive, 257 negative samples
  Level 1: 231 positive, 231 negative samples
  Level 2: 180 positive, 180 negative samples
  Level 3: 71 positive, 71 negative samples
  V_p^0: shape (20, 16) (feature_dim=20, embed_dim=16)
  V

KeyboardInterrupt: 