In [1]:
import json
import pickle
import pandas as pd
from typing import Dict, List, Set
from datetime import datetime

In [4]:
def extract_user_ids(users_file: str) -> List[str]:
    """
    Read the user preferences CSV and return its user ID column as a list.

    The ID column is looked up by name: 'uuid' is preferred, and the
    misspelled variant 'uudi' is accepted as a fallback.

    Args:
        users_file: Path to user_preferences.csv

    Returns:
        User IDs in the same order as the rows of the CSV

    Raises:
        ValueError: If neither 'uuid' nor 'uudi' is a column of the file
    """
    print(f"Loading users from: {users_file}")
    frame = pd.read_csv(users_file)

    # Candidate column names in priority order; the first one present wins.
    id_column = None
    for candidate in ('uuid', 'uudi'):
        if candidate in frame.columns:
            id_column = candidate
            break

    if id_column is None:
        raise ValueError("No 'uuid' or 'uudi' column found in user preferences file")

    ids = frame[id_column].tolist()
    print(f"  Extracted {len(ids)} user IDs")
    return ids


def extract_poi_ids_from_tree(poi_tree: Dict, level: int) -> List[str]:
    """
    Return the POI IDs stored directly under one level of the POI tree.

    Args:
        poi_tree: Loaded POI tree dictionary, keyed by 'level_<n>'
        level: Tree level whose IDs are wanted

    Returns:
        Keys of poi_tree['level_<level>'] in their stored order, or an
        empty list (after printing a warning) when that level is absent
    """
    level_key = f'level_{level}'

    if level_key in poi_tree:
        return [*poi_tree[level_key].keys()]

    print(f"  Warning: {level_key} not found in POI tree")
    return []


def extract_all_poi_ids_recursive(poi_tree: Dict) -> Dict[int, List[str]]:
    """
    Collect the POI IDs of every level (0-3) of the POI tree.

    An ID is recorded for a level when it is either a key of
    poi_tree['level_<n>'] or is referenced in the 'children' of a node one
    level above. Every node of every level is visited, so a child that is
    only referenced by a level-2 or level-1 node (and has no entry of its
    own) is still captured.

    'children' may be a list of child IDs or a dict mapping child ID to
    nested child data; nested dicts are followed recursively. Node values
    that are not dicts contribute only their own ID.

    Args:
        poi_tree: Loaded POI tree dictionary, keyed by 'level_<n>'

    Returns:
        Dictionary mapping level (0-3) -> sorted list of POI IDs
    """
    num_levels = 4
    level_poi_ids: Dict[int, Set[str]] = {level: set() for level in range(num_levels)}

    def traverse_node(node, node_id: str, current_level: int) -> None:
        """Record node_id at current_level, then descend into its children."""
        level_poi_ids[current_level].add(node_id)

        # Leaf payloads (None, strings, ...) carry no children to follow.
        if not isinstance(node, dict):
            return

        children = node.get('children')
        child_level = current_level - 1  # Children are at finer granularity
        if not children or child_level < 0:
            return

        if isinstance(children, list):
            # Flat list of IDs: the child's own data, if any, is reached
            # when its level is walked in the loop below.
            level_poi_ids[child_level].update(children)
        elif isinstance(children, dict):
            for child_id, child_data in children.items():
                traverse_node(child_data, child_id, child_level)

    # Walk every level, not just the top one, so references held by
    # intermediate nodes are not lost.
    for level in reversed(range(num_levels)):
        for poi_id, poi_data in poi_tree.get(f'level_{level}', {}).items():
            traverse_node(poi_data, poi_id, level)

    # Convert sets to sorted lists for consistent ordering
    return {level: sorted(ids) for level, ids in level_poi_ids.items()}


def create_id_mappings(user_ids: List[str],
                       poi_ids_by_level: Dict[int, List[str]]) -> Dict:
    """
    Build string <-> integer index tables for users and POIs.

    Args:
        user_ids: List of user IDs
        poi_ids_by_level: Dictionary mapping level -> list of POI IDs

    Returns:
        Dictionary with two entries:
          'user': {'id_to_idx', 'idx_to_id', 'count'}
          'poi':  one such table per 'level_<n>' key, plus 'global', which
                  indexes levels 0-3 together under "L<level>_<id>" labels
                  and carries 'idx_to_level_info'
    """
    def _index_table(ids):
        """Return the forward/backward lookup pair and size for ids."""
        forward = {}
        backward = {}
        for position, identifier in enumerate(ids):
            forward[identifier] = position
            backward[position] = identifier
        return {
            'id_to_idx': forward,
            'idx_to_id': backward,
            'count': len(ids)
        }

    poi_tables = {}
    for level, level_ids in poi_ids_by_level.items():
        poi_tables[f'level_{level}'] = _index_table(level_ids)

    # Global table: levels 0..3 concatenated in ascending level order.
    # level_info maps global_idx -> (level, local_idx, poi_id).
    labelled = []
    level_info = {}
    for level in range(4):
        if f'level_{level}' not in poi_tables:
            continue
        for local_idx, poi_id in enumerate(poi_ids_by_level[level]):
            level_info[len(labelled)] = (level, local_idx, poi_id)
            labelled.append(f"L{level}_{poi_id}")

    global_forward = {}
    global_backward = {}
    for position, label in enumerate(labelled):
        global_forward[label] = position
        global_backward[position] = label

    poi_tables['global'] = {
        'id_to_idx': global_forward,
        'idx_to_id': global_backward,
        'idx_to_level_info': level_info,
        'count': len(labelled)
    }

    return {'user': _index_table(user_ids), 'poi': poi_tables}


def generate_metadata(users_file: str,
                      poi_tree_file: str,
                      output_file: str = 'metadata.pkl'):
    """
    Build the ID-mapping metadata and pickle it to disk.

    Reads user IDs from the CSV, loads the POI tree JSON, collects POI IDs
    per level, builds the index tables, writes everything to output_file
    and prints a progress log and summary along the way.

    Args:
        users_file: Path to user_preferences.csv
        poi_tree_file: Path to poi_tree_with_uuids.json
        output_file: Output pickle file path

    Returns:
        The metadata dictionary that was written, with keys 'mappings',
        'counts', 'user_ids', 'poi_ids', 'level_names' and 'info'
    """
    rule = "=" * 60

    print(rule)
    print("Generating Metadata (ID Mappings)")
    print(rule)

    # Step 1: users
    print("\n[Step 1] Extracting User IDs")
    user_ids = extract_user_ids(users_file)

    # Step 2: POI tree
    print(f"\n[Step 2] Loading POI Tree from: {poi_tree_file}")
    with open(poi_tree_file, 'r', encoding='utf-8') as tree_handle:
        poi_tree = json.load(tree_handle)

    # Step 3: POI IDs per level
    print("\n[Step 3] Extracting POI IDs from all levels")
    poi_ids_by_level = extract_all_poi_ids_recursive(poi_tree)
    for lvl, ids in poi_ids_by_level.items():
        print(f"  Level {lvl}: {len(ids)} POIs")

    # Step 4: index tables
    print("\n[Step 4] Creating ID mappings")
    mappings = create_id_mappings(user_ids, poi_ids_by_level)

    # Quick-access counts; a level missing from the extraction counts as 0.
    counts = {'users': len(user_ids)}
    for lvl in range(4):
        counts[f'pois_level_{lvl}'] = len(poi_ids_by_level.get(lvl, []))
    counts['pois_total'] = sum(map(len, poi_ids_by_level.values()))

    level_names = {
        0: 'Building',
        1: 'Street',
        2: 'District',
        3: 'Region'
    }

    provenance = {
        'created_at': datetime.now().isoformat(),
        'source_files': {
            'users': users_file,
            'poi_tree': poi_tree_file
        },
        'version': '1.0'
    }

    metadata = {
        'mappings': mappings,
        'counts': counts,
        # Raw ID lists (for iteration)
        'user_ids': user_ids,
        'poi_ids': poi_ids_by_level,
        'level_names': level_names,
        'info': provenance
    }

    # Step 5: persist
    print(f"\n[Step 5] Saving to: {output_file}")
    with open(output_file, 'wb') as out_handle:
        pickle.dump(metadata, out_handle)

    # Summary
    print("\n" + rule)
    print("METADATA SUMMARY")
    print(rule)
    print(f"\nUsers: {counts['users']}")
    print("\nPOIs by Level:")
    for lvl in range(4):
        print(f"  Level {lvl} ({level_names[lvl]}): {counts[f'pois_level_{lvl}']}")
    print(f"\nTotal POIs: {counts['pois_total']}")
    print(f"\nSaved to: {output_file}")
    print(rule)

    return metadata


def load_metadata(input_file: str = 'metadata.pkl') -> Dict:
    """
    Unpickle and return the metadata stored in input_file.

    Note: unpickling can execute arbitrary code, so only point this at
    files from a trusted source.

    Args:
        input_file: Path to metadata.pkl

    Returns:
        Metadata dictionary
    """
    with open(input_file, 'rb') as handle:
        return pickle.load(handle)


# Utility functions for using the metadata
class IDMapper:
    """
    Convenience wrapper around a loaded metadata dictionary.

    Translates between string IDs and integer indices for users and for
    POIs at a given level, and exposes the stored counts and ID lists.
    Unknown IDs map to -1 and unknown indices map to None.
    """

    def __init__(self, metadata_file: str = 'metadata.pkl'):
        """
        Load the metadata pickle and keep a shortcut to its mappings.

        Args:
            metadata_file: Path to metadata.pkl
        """
        loaded = load_metadata(metadata_file)
        self.metadata = loaded
        self.mappings = loaded['mappings']

    def _poi_tables(self, level: int) -> Dict:
        """Return the mapping tables stored under 'level_<level>'."""
        return self.mappings['poi'][f'level_{level}']

    def user_to_idx(self, user_id: str) -> int:
        """Return the index of user_id, or -1 when it is unknown"""
        forward = self.mappings['user']['id_to_idx']
        return forward.get(user_id, -1)

    def idx_to_user(self, idx: int) -> str:
        """Return the user ID at idx, or None when idx is unknown"""
        backward = self.mappings['user']['idx_to_id']
        return backward.get(idx)

    def poi_to_idx(self, poi_id: str, level: int) -> int:
        """Return the index of poi_id within level, or -1 when it is unknown"""
        forward = self._poi_tables(level)['id_to_idx']
        return forward.get(poi_id, -1)

    def idx_to_poi(self, idx: int, level: int) -> str:
        """Return the POI ID at idx within level, or None when idx is unknown"""
        backward = self._poi_tables(level)['idx_to_id']
        return backward.get(idx)

    def get_user_count(self) -> int:
        """Return the stored number of users"""
        counts = self.metadata['counts']
        return counts['users']

    def get_poi_count(self, level: int) -> int:
        """Return the stored number of POIs at level"""
        counts = self.metadata['counts']
        return counts[f'pois_level_{level}']

    def get_all_user_ids(self) -> List[str]:
        """Return the stored list of user IDs"""
        return self.metadata['user_ids']

    def get_all_poi_ids(self, level: int) -> List[str]:
        """Return the stored POI IDs for level, or [] when the level is absent"""
        ids_by_level = self.metadata['poi_ids']
        return ids_by_level.get(level, [])

    def batch_user_to_idx(self, user_ids: List[str]) -> List[int]:
        """Return the index of each user ID, in input order"""
        indices = []
        for uid in user_ids:
            indices.append(self.user_to_idx(uid))
        return indices

    def batch_poi_to_idx(self, poi_ids: List[str], level: int) -> List[int]:
        """Return the index of each POI ID within level, in input order"""
        indices = []
        for pid in poi_ids:
            indices.append(self.poi_to_idx(pid, level))
        return indices


In [6]:
# Demo / entry point: generate metadata.pkl from the source files, then
# reload it through IDMapper and print a few round-trip lookups.
if __name__ == "__main__":
    # Input locations, relative to the current working directory.
    user_preferences_file = "../../Sources/user_preferences.csv"
    poi_tree_file = "../../Sources/poi_tree_with_uuids.json"
    
    # output_file is left at its default, so the pickle is written to
    # 'metadata.pkl' -- the same path IDMapper is pointed at below.
    metadata = generate_metadata(
        users_file=user_preferences_file,
        poi_tree_file=poi_tree_file
    )
    
    print("\n" + "=" * 60)
    print("DEMO: Using IDMapper")
    print("=" * 60)
    
    # Reload from disk rather than reusing `metadata`, which exercises the
    # pickle that was just written.
    mapper = IDMapper('metadata.pkl')
    
    print("\nExample User Mapping:")
    # Skip the example when no users were extracted.
    if metadata["user_ids"]:
        sample_user = metadata["user_ids"][0]
        user_idx = mapper.user_to_idx(sample_user)
        print(f"  User '{sample_user}' -> Index {user_idx}")
        print(f"  Index {user_idx} -> User '{mapper.idx_to_user(user_idx)}'")
    
    print("\nExample POI Mappings (Level 0):")
    # Skip the example when level 0 is missing or empty.
    if metadata["poi_ids"].get(0):
        # First [0] selects level 0, second [0] the first POI ID in it.
        sample_poi = metadata["poi_ids"][0][0]
        poi_idx = mapper.poi_to_idx(sample_poi, level=0)
        print(f"  POI '{sample_poi}' -> Index {poi_idx}")
        print(f"  Index {poi_idx} -> POI '{mapper.idx_to_poi(poi_idx, level=0)}'")
    
    print("\nCounts:")
    print(f"  Users: {mapper.get_user_count()}")
    for level in range(4):
        print(f"  POIs Level {level}: {mapper.get_poi_count(level)}")


Generating Metadata (ID Mappings)

[Step 1] Extracting User IDs
Loading users from: ../../Sources/user_preferences.csv
  Extracted 21 user IDs

[Step 2] Loading POI Tree from: ../../Sources/poi_tree_with_uuids.json

[Step 3] Extracting POI IDs from all levels
  Level 0: 4696 POIs
  Level 1: 1355 POIs
  Level 2: 44 POIs
  Level 3: 5 POIs

[Step 4] Creating ID mappings

[Step 5] Saving to: metadata.pkl

METADATA SUMMARY

Users: 21

POIs by Level:
  Level 0 (Building): 4696
  Level 1 (Street): 1355
  Level 2 (District): 44
  Level 3 (Region): 5

Total POIs: 6100

Saved to: metadata.pkl

DEMO: Using IDMapper

Example User Mapping:
  User '966592ed-5bfd-4113-9c4d-d93cd3637b40' -> Index 0
  Index 0 -> User '966592ed-5bfd-4113-9c4d-d93cd3637b40'

Example POI Mappings (Level 0):
  POI 'poi_0_Giant' -> Index 0
  Index 0 -> POI 'poi_0_Giant'

Counts:
  Users: 21
  POIs Level 0: 4696
  POIs Level 1: 1355
  POIs Level 2: 44
  POIs Level 3: 5
