In [1]:
import pandas as pd
import numpy as np
import json
import networkx as nx
from typing import Dict, List, Tuple
from scipy.spatial.distance import cosine
from sklearn.metrics.pairwise import cosine_similarity
import math
import pickle

In [None]:
class InteractionBasedRepresentationLearning:
    def __init__(self, 
                 embeddings_file: str,
                 interactions_file: str,
                 poi_tree_file: str,
                 users_file: str):
        """
        Initialize interaction-based representation learning
        
        Args:
            embeddings_file: Path to embeddings.pkl (from attribute-based learning)
            interactions_file: Path to user-POI interactions CSV
            poi_tree_file: Path to POI tree JSON
            users_file: Path to user preferences CSV
        """
        # Load pre-computed embeddings
        with open(embeddings_file, 'rb') as f:
            embeddings = pickle.load(f)
        
        self.X_A = embeddings['X_A']  # Explicit user features
        self.X_T = embeddings['X_T']  # Implicit user features
        self.user_embeddings = embeddings['user_embeddings']
        self.poi_embeddings = embeddings['poi_embeddings']
        self.user_id_to_idx = embeddings['user_id_to_idx']
        
        # Load raw data
        self.interactions_df = pd.read_csv(interactions_file)
        self.users_df = pd.read_csv(users_file)
        
        with open(poi_tree_file, 'r') as f:
            self.poi_tree = json.load(f)
        
        # Will store computed components
        self.Theta_u = None  # Trainable user parameters
        self.A_l_p = {}      # Inter-level POI features for each level
        self.G_l = {}        # POI context graphs for each level
        self.attention_params = {}  # Attention network parameters
        self.P_l = {}        # Complete user representation per level
        self.Q_l = {}        # Complete POI representation per level
        self.S_l = {}        # Feature-based check-in matrix per level
        self.U_l_g = {}      # Graph-based check-in representation per level
        
        print("Initialized Interaction-based Representation Learning")
        print(f"Users: {len(self.users_df)}")
        print(f"Interactions: {len(self.interactions_df)}")
        print(f"User embedding dim: {self.X_A.shape[1] + self.X_T.shape[1]}")
    
    # ========================================================================
    # STEP 2A: Initialize Trainable Parameters
    # ========================================================================
    
    def initialize_trainable_parameters(self, 
                                       user_embedding_dim: int = 16,
                                       attention_hidden_dim: int = 32,
                                       seed: int = 42):
        """
        Initialize trainable parameters:
        - Θ_u: Additional user embedding parameters
        - Attention network parameters (W_1, b_1, W_2, b_2)
        
        Args:
            user_embedding_dim: Dimension of trainable user embeddings
            attention_hidden_dim: Hidden layer size for attention network
            seed: Random seed for reproducibility
        """
        print("\n" + "="*60)
        print("STEP 2A: Initializing Trainable Parameters")
        print("="*60)
        
        np.random.seed(seed)
        
        num_users = len(self.users_df)
        implicit_poi_dim = self.X_T.shape[1]  # Should be 32
        
        # 1. Trainable user parameters Θ_u
        self.Theta_u = np.random.randn(num_users, user_embedding_dim) * 0.01
        print(f"Initialized Θ_u: shape {self.Theta_u.shape}")
        
        # 2. Attention network parameters (shared across all levels)
        self.attention_params = {
            'W_1': np.random.randn(attention_hidden_dim, implicit_poi_dim) * 0.01,
            'b_1': np.zeros(attention_hidden_dim),
            'W_2': np.random.randn(1, attention_hidden_dim) * 0.01,
            'b_2': np.zeros(1)
        }
        
        print(f"Initialized attention network:")
        print(f"  W_1: {self.attention_params['W_1'].shape}")
        print(f"  b_1: {self.attention_params['b_1'].shape}")
        print(f"  W_2: {self.attention_params['W_2'].shape}")
        print(f"  b_2: {self.attention_params['b_2'].shape}")
        
        return self.Theta_u, self.attention_params
    
    # ========================================================================
    # STEP 2B: Compute Inter-level POI Features with Attention
    # ========================================================================
    
    def _relu(self, x):
        """ReLU activation"""
        return np.maximum(0, x)
    
    def _sigmoid(self, x):
        """Sigmoid activation"""
        return 1 / (1 + np.exp(-np.clip(x, -500, 500)))
    
    def _softmax(self, x):
        """Softmax normalization"""
        exp_x = np.exp(x - np.max(x))
        return exp_x / exp_x.sum()
    
    def compute_attention_weight(self, child_embedding: np.ndarray) -> float:
        """
        Compute attention weight for a child POI
        
        α = Sigmoid(W_2 * ReLU(W_1 * h_child + b_1) + b_2)
        
        Args:
            child_embedding: Implicit feature embedding of child POI
        
        Returns:
            Attention weight (scalar)
        """
        W_1 = self.attention_params['W_1']
        b_1 = self.attention_params['b_1']
        W_2 = self.attention_params['W_2']
        b_2 = self.attention_params['b_2']
        
        # Hidden layer with ReLU
        z = self._relu(W_1 @ child_embedding + b_1)  # Shape: (attention_hidden_dim,)
        
        # Output layer with Sigmoid
        alpha = self._sigmoid(W_2 @ z + b_2)  # Shape: (1,)
        
        return float(alpha[0])
    
    def compute_inter_level_features(self, level: int) -> Dict[str, np.ndarray]:
        """
        Compute inter-level POI features A^l_p for level l
        
        For each parent POI at level l:
        1. Get children POIs at level l-1
        2. Compute attention weights for each child
        3. Aggregate child embeddings using attention weights
        
        Args:
            level: Target level (must be >= 1, as level 0 has no children)
        
        Returns:
            Dictionary mapping parent_poi_id -> aggregated_embedding
        """
        if level == 0:
            print(f"Level 0 has no children, skipping inter-level features")
            return {}
        
        print(f"\nComputing A^{level}_p (Inter-level features for level {level})")
        print("-" * 60)
        
        parent_level_key = f'level_{level}'
        child_level_key = f'level_{level-1}'
        
        parent_pois = self.poi_tree[parent_level_key]
        child_pois_data = self.poi_tree[child_level_key]
        
        # Get child POI embeddings (implicit features Y_T)
        child_Y_T = self.poi_embeddings[child_level_key]['Y_T']
        child_poi_ids = self.poi_embeddings[child_level_key]['poi_ids']
        child_id_to_idx = {pid: idx for idx, pid in enumerate(child_poi_ids)}
        
        # Get child POI full embeddings (Y = [Y_A | Y_T])
        child_Y_full = self.poi_embeddings[child_level_key]['embeddings']
        
        A_l_p = {}
        aggregation_dim = child_Y_full.shape[1]  # Full embedding dimension
        
        processed = 0
        for parent_id, parent_data in parent_pois.items():
            children_ids = parent_data.get('children', [])
            
            if len(children_ids) == 0:
                # No children, use zero vector
                A_l_p[parent_id] = np.zeros(aggregation_dim)
                continue
            
            attention_weights = []
            child_embeddings_implicit = []
            child_embeddings_full = []
            
            for child_id in children_ids:
                if child_id not in child_id_to_idx:
                    continue
                
                child_idx = child_id_to_idx[child_id]
                
                # Get child implicit embedding (for attention computation)
                h_child_implicit = child_Y_T[child_idx]  # Shape: (32,)
                
                # Get child full embedding (for aggregation)
                h_child_full = child_Y_full[child_idx]
                
                # Compute attention weight using implicit features
                alpha = self.compute_attention_weight(h_child_implicit)
                
                attention_weights.append(alpha)
                child_embeddings_implicit.append(h_child_implicit)
                child_embeddings_full.append(h_child_full)
            
            if len(attention_weights) == 0:
                A_l_p[parent_id] = np.zeros(aggregation_dim)
                continue
            
            # Normalize attention weights using softmax
            attention_weights = self._softmax(np.array(attention_weights))
            
            # Aggregate child full embeddings using attention weights
            aggregated = sum(w * emb for w, emb in zip(attention_weights, child_embeddings_full))
            A_l_p[parent_id] = aggregated
            
            processed += 1
            if processed % 100 == 0:
                print(f"  Processed {processed}/{len(parent_pois)} parent POIs")
        
        print(f"Completed A^{level}_p: {len(A_l_p)} parent POIs")
        print(f"  Aggregated embedding dimension: {aggregation_dim}")
        
        return A_l_p
    
    def compute_all_inter_level_features(self, levels: List[int] = [1, 2, 3]):
        """
        Compute inter-level features for all specified levels
        """
        print("\n" + "="*60)
        print("STEP 2B: Computing Inter-level POI Features with Attention")
        print("="*60)
        
        for level in levels:
            self.A_l_p[f'level_{level}'] = self.compute_inter_level_features(level)
        
        print(f"\nCompleted inter-level features for levels: {levels}")
    
    # ========================================================================
    # STEP 2C: Build POI Context Graphs
    # ========================================================================
    
    def _haversine_distance(self, lat1: float, lon1: float, 
                           lat2: float, lon2: float) -> float:
        """Calculate distance between two coordinates in km"""
        R = 6371
        lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2])
        dlat = lat2 - lat1
        dlon = lon2 - lon1
        a = math.sin(dlat/2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon/2)**2
        c = 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))
        return R * c
    
    def compute_co_occurrence_matrix(self, 
                                    level: int, 
                                    interaction_type: str = 'visit',
                                    window_size: int = 1) -> np.ndarray:
        """
        Compute co-occurrence matrix for POIs at a given level
        
        Two POIs co-occur if a user interacts with both within a time window
        
        Args:
            level: POI tree level
            interaction_type: 'visit' or 'search'
            window_size: Time window in days
        
        Returns:
            Co-occurrence matrix (symmetric)
        """
        level_key = f'level_{level}'
        poi_ids = self.poi_embeddings[level_key]['poi_ids']
        poi_to_idx = {pid: idx for idx, pid in enumerate(poi_ids)}
        n_pois = len(poi_ids)
        
        co_matrix = np.zeros((n_pois, n_pois))
        
        # Filter interactions by type
        interactions = self.interactions_df[
            self.interactions_df['interaction_type'] == interaction_type
        ].copy()
        
        # Convert timestamp to datetime
        interactions['timestamp'] = pd.to_datetime(interactions['timestamp'])
        
        # Group by user
        for user_id, user_interactions in interactions.groupby('user_id'):
            user_interactions = user_interactions.sort_values('timestamp')
            
            poi_list = []
            for _, row in user_interactions.iterrows():
                poi_id = row['poi_id']
                
                # Map to current level if needed
                if level > 0:
                    poi_id = self._get_parent_at_level(poi_id, level)
                
                if poi_id in poi_to_idx:
                    poi_list.append(poi_id)
            
            # Count co-occurrences
            for i in range(len(poi_list)):
                for j in range(i+1, min(i+window_size+1, len(poi_list))):
                    poi_i, poi_j = poi_list[i], poi_list[j]
                    idx_i, idx_j = poi_to_idx[poi_i], poi_to_idx[poi_j]
                    
                    co_matrix[idx_i, idx_j] += 1
                    co_matrix[idx_j, idx_i] += 1  # Symmetric
        
        # Normalize
        max_val = co_matrix.max()
        if max_val > 0:
            co_matrix /= max_val
        
        return co_matrix
    
    def _get_parent_at_level(self, poi_id: str, target_level: int) -> str:
        """Get parent node of poi_id at target_level"""
        current_level = 0
        current_id = poi_id
        
        while current_level < target_level:
            level_key = f'level_{current_level}'
            if current_id in self.poi_tree[level_key]:
                parent = self.poi_tree[level_key][current_id].get('parent')
                if parent:
                    current_id = parent
                    current_level += 1
                else:
                    break
            else:
                break
        
        return current_id
    
    def build_poi_context_graph(self, 
                                level: int,
                                distance_threshold: float = 2.0,
                                co_search_weight: float = 0.3,
                                co_visit_weight: float = 0.5,
                                geo_weight: float = 0.2,
                                edge_threshold: float = 0.1) -> nx.Graph:
        """
        Build POI context graph G^l for level l

        G^l = <V^l, E^l>
        - V^l: POIs at level l (the keys of poi_tree['level_<l>'])
        - E^l: Edges whose weight is a weighted sum of the co-search score,
               the co-visit score and a geo-proximity score 1 / (1 + km)
               (0 when the POIs are at least distance_threshold apart)

        Args:
            level: POI tree level
            distance_threshold: Max distance (km) for geo-proximity edges
            co_search_weight: Weight for co-search edges
            co_visit_weight: Weight for co-visit edges
            geo_weight: Weight for geo-proximity edges
            edge_threshold: Minimum weight to create edge

        Returns:
            NetworkX graph with a 'weight' attribute on every edge
        """
        import ast  # local import: only needed to parse string-encoded coordinates

        print(f"\nBuilding POI context graph G^{level}")
        print("-" * 60)

        level_key = f'level_{level}'
        level_nodes = self.poi_tree[level_key]
        pois = list(level_nodes.keys())
        n_pois = len(pois)

        print(f"Number of POIs: {n_pois}")

        # Initialize graph
        G = nx.Graph()
        G.add_nodes_from(pois)

        # 1. Compute co-search matrix
        print("Computing co-search relationships...")
        co_search = self.compute_co_occurrence_matrix(level, 'search', window_size=1)

        # 2. Compute co-visit matrix
        print("Computing co-visit relationships...")
        co_visit = self.compute_co_occurrence_matrix(level, 'visit', window_size=3)

        # The co-occurrence matrices are indexed by the order of
        # poi_embeddings[level]['poi_ids'], which need not match the key
        # order of the tree. Look rows up through that ordering; a tree POI
        # with no embedding entry simply has no co-occurrence signal.
        matrix_row = {pid: idx for idx, pid in
                      enumerate(self.poi_embeddings[level_key]['poi_ids'])}
        rows = [matrix_row.get(pid) for pid in pois]

        # Parse every coordinate once instead of inside the O(n^2) pair loop.
        # literal_eval only accepts Python literals, so a malformed or
        # malicious string in the tree file cannot execute code.
        coords = []
        for poi_id in pois:
            spatial = level_nodes[poi_id]['spatial']
            if isinstance(spatial, str):
                spatial = ast.literal_eval(spatial)
            lat, lon = spatial
            coords.append((lat, lon))

        # 3. Add edges
        print("Building edges...")
        for i, poi_i in enumerate(pois):
            lat_i, lon_i = coords[i]
            row_i = rows[i]

            for j in range(i+1, n_pois):
                lat_j, lon_j = coords[j]
                row_j = rows[j]

                # Geo-proximity score: closer = higher, cut off at the threshold
                dist = self._haversine_distance(lat_i, lon_i, lat_j, lon_j)
                geo_score = 1 / (1 + dist) if dist < distance_threshold else 0

                if row_i is None or row_j is None:
                    search_score = visit_score = 0.0
                else:
                    search_score = co_search[row_i, row_j]
                    visit_score = co_visit[row_i, row_j]

                edge_weight = (co_search_weight * search_score +
                               co_visit_weight * visit_score +
                               geo_weight * geo_score)

                # Add edge if weight exceeds threshold
                if edge_weight > edge_threshold:
                    G.add_edge(poi_i, pois[j], weight=edge_weight)

            if (i + 1) % 100 == 0:
                print(f"  Processed {i+1}/{n_pois} POIs")

        n_nodes = G.number_of_nodes()
        # Guard the average for a level with no POIs
        avg_degree = sum(dict(G.degree()).values()) / n_nodes if n_nodes else 0.0

        print(f"Graph G^{level} constructed:")
        print(f"  Nodes: {n_nodes}")
        print(f"  Edges: {G.number_of_edges()}")
        print(f"  Density: {nx.density(G):.4f}")
        print(f"  Avg degree: {avg_degree:.2f}")

        return G
    
    def build_all_poi_context_graphs(self, 
                                     levels: List[int] = [0, 1, 2, 3],
                                     distance_thresholds: Dict[int, float] = None):
        """
        Build POI context graphs for all specified levels
        
        Args:
            levels: List of levels to build graphs for
            distance_thresholds: Dict mapping level -> distance threshold
                                Default: level 0=2km, level 1=5km, level 2=10km, level 3=20km
        """
        print("\n" + "="*60)
        print("STEP 2C: Building POI Context Graphs")
        print("="*60)
        
        if distance_thresholds is None:
            distance_thresholds = {
                0: 2.0,   # 2km for individual POIs
                1: 5.0,   # 5km for containers/streets
                2: 10.0,  # 10km for districts
                3: 20.0   # 20km for regions
            }
        
        for level in levels:
            threshold = distance_thresholds.get(level, 5.0)
            self.G_l[f'level_{level}'] = self.build_poi_context_graph(
                level=level,
                distance_threshold=threshold
            )
        
        print(f"\nCompleted POI context graphs for levels: {levels}")
    
    # ========================================================================
    # STEP 2D: Compute Graph-based Check-in Representation
    # ========================================================================
    
    def get_user_history(self, user_id: str, level: int = 0) -> List[str]:
        """
        Get list of POIs user has visited at given level
        
        Args:
            user_id: User ID
            level: POI tree level
        
        Returns:
            List of POI IDs
        """
        user_interactions = self.interactions_df[
            (self.interactions_df['user_id'] == user_id) &
            (self.interactions_df['interaction_type'] == 'visit')
        ]
        
        history = []
        for _, row in user_interactions.iterrows():
            poi_id = row['poi_id']
            
            # Map to current level if needed
            if level > 0:
                poi_id = self._get_parent_at_level(poi_id, level)
            
            if poi_id not in history:
                history.append(poi_id)
        
        return history
    
    def compute_graph_based_checkin_single(self,
                                          user_id: str,
                                          candidate_poi: str,
                                          level: int) -> np.ndarray:
        """
        Compute graph-based check-in representation U^l_g for a single user-POI pair
        
        U^l_g[u, p_i] = (1/|H_u|) * Σ w(p_i, p_j) * Y^l[p_j]
                         for all p_j in H_u (user's history)
        
        Args:
            user_id: User ID
            candidate_poi: Candidate POI ID to recommend
            level: POI tree level
        
        Returns:
            Graph-based representation vector
        """
        level_key = f'level_{level}'
        G = self.G_l.get(level_key)
        
        if G is None:
            raise ValueError(f"Graph for level {level} not built yet")
        
        # Get user's historical visited POIs
        history = self.get_user_history(user_id, level)
        
        if len(history) == 0:
            # No history, return zero vector
            poi_embedding_dim = self.poi_embeddings[level_key]['embeddings'].shape[1]
            return np.zeros(poi_embedding_dim)
        
        # Get POI embeddings
        poi_ids = self.poi_embeddings[level_key]['poi_ids']
        poi_embeddings = self.poi_embeddings[level_key]['embeddings']
        poi_to_idx = {pid: idx for idx, pid in enumerate(poi_ids)}
        
        # Aggregate geo-spatial influence from history
        geo_influence = np.zeros(poi_embeddings.shape[1])
        
        for visited_poi in history:
            if visited_poi not in poi_to_idx:
                continue
            
            # Check if edge exists in graph
            if G.has_edge(candidate_poi, visited_poi):
                edge_weight = G[candidate_poi][visited_poi]['weight']
                visited_poi_idx = poi_to_idx[visited_poi]
                visited_poi_embedding = poi_embeddings[visited_poi_idx]
                
                geo_influence += edge_weight * visited_poi_embedding
        
        # Average over history
        geo_influence /= len(history)
        
        return geo_influence
    
    def compute_graph_based_checkin_matrix(self, level: int) -> np.ndarray:
        """
        Compute graph-based check-in matrix U^l_g for all user-POI pairs at level l
        
        Args:
            level: POI tree level
        
        Returns:
            U^l_g matrix: (num_users, num_pois, embedding_dim)
        """
        print(f"\nComputing U^{level}_g (Graph-based check-in matrix for level {level})")
        print("-" * 60)
        
        level_key = f'level_{level}'
        poi_ids = self.poi_embeddings[level_key]['poi_ids']
        user_ids = self.users_df['uudi'].tolist()
        
        n_users = len(user_ids)
        n_pois = len(poi_ids)
        embedding_dim = self.poi_embeddings[level_key]['embeddings'].shape[1]
        
        print(f"Computing for {n_users} users × {n_pois} POIs")
        
        U_l_g = np.zeros((n_users, n_pois, embedding_dim))
        
        for u_idx, user_id in enumerate(user_ids):
            for p_idx, poi_id in enumerate(poi_ids):
                U_l_g[u_idx, p_idx, :] = self.compute_graph_based_checkin_single(
                    user_id, poi_id, level
                )
            
            if (u_idx + 1) % 5 == 0:
                print(f"  Processed {u_idx+1}/{n_users} users")
        
        print(f"U^{level}_g computed: shape {U_l_g.shape}")
        
        return U_l_g
    
    def compute_all_graph_based_checkin_matrices(self, levels: List[int] = [0, 1, 2, 3]):
        """
        Compute graph-based check-in matrices for all levels
        """
        print("\n" + "="*60)
        print("STEP 2D: Computing Graph-based Check-in Representations")
        print("="*60)
        
        for level in levels:
            self.U_l_g[f'level_{level}'] = self.compute_graph_based_checkin_matrix(level)
        
        print(f"\nCompleted graph-based check-in matrices for levels: {levels}")
    
    # ========================================================================
    # Build Complete User & POI Representations P^l and Q^l
    # ========================================================================
    
    def build_complete_representations(self, level: int):
        """
        Build complete user and POI representations for level l:
        
        P^l = [X_A | X_T | Θ_u]
        Q^l = [Y_A^l | Y_T^l | A^l_p]
        
        Args:
            level: POI tree level
        """
        print(f"\nBuilding complete representations for level {level}")
        print("-" * 60)
        
        level_key = f'level_{level}'
        
        # Build P^l (user representation)
        P_l = np.hstack([self.X_A, self.X_T, self.Theta_u])
        self.P_l[level_key] = P_l
        
        print(f"P^{level} shape: {P_l.shape}")
        print(f"  X_A: {self.X_A.shape[1]} dims")
        print(f"  X_T: {self.X_T.shape[1]} dims")
        print(f"  Θ_u: {self.Theta_u.shape[1]} dims")
        
        # Build Q^l (POI representation)
        Y_A_l = self.poi_embeddings[level_key]['Y_A']
        Y_T_l = self.poi_embeddings[level_key]['Y_T']
        
        if level == 0:
            # Level 0 has no inter-level features
            Q_l = np.hstack([Y_A_l, Y_T_l])
        else:
            # Levels 1+ have inter-level features
            A_l_p_dict = self.A_l_p[level_key]
            poi_ids = self.poi_embeddings[level_key]['poi_ids']
            
            # Convert dict to matrix
            A_l_p_matrix = np.array([A_l_p_dict[pid] for pid in poi_ids])
            Q_l = np.hstack([Y_A_l, Y_T_l, A_l_p_matrix])
        
        self.Q_l[level_key] = Q_l
        
        print(f"Q^{level} shape: {Q_l.shape}")
        print(f"  Y_A^{level}: {Y_A_l.shape[1]} dims")
        print(f"  Y_T^{level}: {Y_T_l.shape[1]} dims")
        if level > 0:
            print(f"  A^{level}_p: {A_l_p_matrix.shape[1]} dims")
    
    def build_all_complete_representations(self, levels: List[int] = [0, 1, 2, 3]):
        """Build P^l and Q^l for all levels"""
        print("\n" + "="*60)
        print("Building Complete User & POI Representations")
        print("="*60)
        
        for level in levels:
            self.build_complete_representations(level)
    
    # ========================================================================
    # Compute Feature-based Check-in Matrix S^l
    # ========================================================================
    
    def compute_feature_based_checkin_matrix(self, level: int) -> np.ndarray:
        """
        Compute feature-based check-in matrix S^l
        
        S^l = P^l @ (Q^l)^T
        
        Args:
            level: POI tree level
        
        Returns:
            S^l: (num_users, num_pois) affinity matrix
        """
        print(f"\nComputing S^{level} (Feature-based check-in matrix)")
        print("-" * 60)
        
        level_key = f'level_{level}'
        P_l = self.P_l[level_key]
        Q_l = self.Q_l[level_key]
        
        # Pad to match dimensions if needed
        if P_l.shape[1] != Q_l.shape[1]:
            max_dim = max(P_l.shape[1], Q_l.shape[1])
            
            if P_l.shape[1] < max_dim:
                padding = np.zeros((P_l.shape[0], max_dim - P_l.shape[1]))
                P_l = np.hstack([P_l, padding])
            
            if Q_l.shape[1] < max_dim:
                padding = np.zeros((Q_l.shape[0], max_dim - Q_l.shape[1]))
                Q_l = np.hstack([Q_l, padding])
        
        # Compute affinity matrix
        S_l = P_l @ Q_l.T
        
        print(f"S^{level} shape: {S_l.shape}")
        print(f"  Min score: {S_l.min():.4f}")
        print(f"  Max score: {S_l.max():.4f}")
        print(f"  Mean score: {S_l.mean():.4f}")
        
        self.S_l[level_key] = S_l
        
        return S_l
    
    def compute_all_feature_based_checkin_matrices(self, levels: List[int] = [0, 1, 2, 3]):
        """Compute S^l for all levels"""
        print("\n" + "="*60)
        print("Computing Feature-based Check-in Matrices")
        print("="*60)
        
        for level in levels:
            self.compute_feature_based_checkin_matrix(level)
    
    # ========================================================================
    # Save & Load
    # ========================================================================
    
    def save_all(self, output_file: str = 'interaction_learning.pkl'):
        """Save all computed components"""
        data = {
            'Theta_u': self.Theta_u,
            'attention_params': self.attention_params,
            'A_l_p': self.A_l_p,
            'G_l': self.G_l,
            'P_l': self.P_l,
            'Q_l': self.Q_l,
            'S_l': self.S_l,
            'U_l_g': self.U_l_g
        }
        
        with open(output_file, 'wb') as f:
            pickle.dump(data, f)
        
        print(f"\n{'='*60}")
        print(f"All components saved to: {output_file}")
        print(f"{'='*60}")
    
    def load_all(self, input_file: str = 'interaction_learning.pkl'):
        """Load all components"""
        with open(input_file, 'rb') as f:
            data = pickle.load(f)
        
        self.Theta_u = data['Theta_u']
        self.attention_params = data['attention_params']
        self.A_l_p = data['A_l_p']
        self.G_l = data['G_l']
        self.P_l = data['P_l']
        self.Q_l = data['Q_l']
        self.S_l = data['S_l']
        self.U_l_g = data['U_l_g']
        
        print(f"All components loaded from: {input_file}")


In [3]:
if __name__ == "__main__":
    # End-to-end driver: runs steps 2A-2D, builds P^l / Q^l / S^l, saves the
    # results and prints the top-10 scored POIs for one example user.
    ALL_LEVELS = [0, 1, 2, 3]
    PARENT_LEVELS = [1, 2, 3]  # levels that have children to aggregate

    print("="*60)
    print("INTERACTION-BASED REPRESENTATION LEARNING")
    print("="*60)
    
    # Initialize
    learner = InteractionBasedRepresentationLearning(
        embeddings_file='embeddings.pkl',
        interactions_file='user_poi_interactions.csv',
        poi_tree_file='poi_tree_with_uuids.json',
        users_file='user_preferences.csv'
    )
    
    # STEP 2A: Initialize trainable parameters
    learner.initialize_trainable_parameters(
        user_embedding_dim=16,
        attention_hidden_dim=32
    )
    
    # STEP 2B: Compute inter-level POI features with attention
    learner.compute_all_inter_level_features(levels=PARENT_LEVELS)
    
    # STEP 2C: Build POI context graphs
    learner.build_all_poi_context_graphs(levels=ALL_LEVELS)
    
    # STEP 2D: Compute graph-based check-in representations
    learner.compute_all_graph_based_checkin_matrices(levels=ALL_LEVELS)
    
    # Build complete representations P^l and Q^l
    learner.build_all_complete_representations(levels=ALL_LEVELS)
    
    # Compute feature-based check-in matrices S^l
    learner.compute_all_feature_based_checkin_matrices(levels=ALL_LEVELS)
    
    # Save everything
    learner.save_all('interaction_learning.pkl')
    
    print("\n" + "="*60)
    print("INTERACTION-BASED LEARNING COMPLETE!")
    print("="*60)
    print("\nGenerated components:")
    print("  ✓ Θ_u: Trainable user parameters")
    print("  ✓ A^l_p: Inter-level POI features (levels 1-3)")
    print("  ✓ G^l: POI context graphs (levels 0-3)")
    print("  ✓ U^l_g: Graph-based check-in matrices (levels 0-3)")
    print("  ✓ P^l: Complete user representations (levels 0-3)")
    print("  ✓ Q^l: Complete POI representations (levels 0-3)")
    print("  ✓ S^l: Feature-based check-in matrices (levels 0-3)")
    
    # Example: Get recommendation scores for a user
    print("\n" + "="*60)
    print("EXAMPLE: Recommendation Scores")
    print("="*60)
    
    user_idx = 0
    level = 0
    # Derive the key from `level` so the example follows that one setting
    level_key = f'level_{level}'
    
    S_level = learner.S_l[level_key]
    user_scores = S_level[user_idx]  # Scores for all POIs
    
    # Top 10 recommendations, highest score first
    top_10_indices = np.argsort(user_scores)[-10:][::-1]
    level_poi_ids = learner.poi_embeddings[level_key]['poi_ids']
    top_10_poi_ids = [level_poi_ids[i] for i in top_10_indices]
    
    print(f"\nTop 10 POI recommendations for user {learner.users_df.iloc[user_idx]['name']}:")
    for rank, (poi_id, score) in enumerate(zip(top_10_poi_ids, user_scores[top_10_indices]), 1):
        poi_name = learner.poi_tree[level_key][poi_id]['name']
        print(f"  {rank}. {poi_name} (score: {score:.4f})")

INTERACTION-BASED REPRESENTATION LEARNING
Initialized Interaction-based Representation Learning
Users: 21
Interactions: 529
User embedding dim: 82

STEP 2A: Initializing Trainable Parameters
Initialized Θ_u: shape (21, 16)
Initialized attention network:
  W_1: (32, 32)
  b_1: (32,)
  W_2: (1, 32)
  b_2: (1,)

STEP 2B: Computing Inter-level POI Features with Attention

Computing A^1_p (Inter-level features for level 1)
------------------------------------------------------------
  Processed 100/1355 parent POIs
  Processed 200/1355 parent POIs
  Processed 300/1355 parent POIs
  Processed 400/1355 parent POIs
  Processed 500/1355 parent POIs
  Processed 600/1355 parent POIs
  Processed 700/1355 parent POIs
  Processed 800/1355 parent POIs
  Processed 900/1355 parent POIs
Completed A^1_p: 1355 parent POIs
  Aggregated embedding dimension: 86

Computing A^2_p (Inter-level features for level 2)
------------------------------------------------------------
Completed A^2_p: 44 parent POIs
  Ag