In [1]:
# Loc2Vec: Learning Location Embeddings with Triplet Loss Networks
# Improved implementation following the paper and modern PyTorch practices
#
# Key Features:
# - Follows the original Sentiance paper architecture
# - Supports both custom CNN and transfer learning (ResNet, DenseNet)
# - Optimized for Kyiv tile dataset with coordinate normalization
# - Modern PyTorch 2.0+ features (torch.compile, AMP, etc.)
#
# Important Notes for Apple Silicon (M1/M2) Users:
# - torch.compile() is automatically disabled on MPS due to backend limitations
# - Mixed precision (AMP) is disabled on MPS as it's not yet supported
# - Set num_workers=0 to avoid potential multiprocessing issues on macOS
# - Performance is still excellent on Apple Silicon despite these limitations
#
# Usage:
# 1. Basic training: model, history, df_full = run_pipeline(CONFIG)
# 2. Custom config: CONFIG.model_type = 'resnet50'; CONFIG.batch_size = 32
# 3. Inference: embedding = model(image_tensor) with torch.no_grad()

import os
import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import DataLoader, Dataset

warnings.filterwarnings('ignore')

########################################
# CONFIGURATION (with type hints)
########################################

@dataclass
class Config:
    """Hyperparameters, data locations and output options for a Loc2Vec run.

    Every field has a default, so ``Config()`` is usable as-is and single
    values can be overridden by keyword, e.g. ``Config(num_epochs=2)``.
    """
    # Model
    # 'cnn' builds Loc2VecCNN; any other value is handed to Loc2VecTransferLearning
    model_type: str = 'cnn'  # 'cnn', 'resnet18', 'resnet50', 'densenet121'
    use_pretrained: bool = True  # Forwarded to the transfer-learning model only
    freeze_backbone: bool = False  # Forwarded to the transfer-learning model only
    embedding_dim: int = 16  # Paper uses 16D
    image_size: int = 128   # Paper uses 128x128

    # Triplet loss
    margin: float = 1.0  # Passed to the loss constructor
    distance_threshold_km: float = 1.0  # 1km threshold for positive samples
    use_softpn_loss: bool = True  # True -> SoftPNTripletLoss, False -> TripletLoss

    # Training
    batch_size: int = 20  # Paper: 20 locations per batch
    # NOTE(review): no code in this section reads pairs_per_location - confirm where it is used
    pairs_per_location: int = 5  # Paper: 5 positive pairs per location
    num_epochs: int = 1
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    patience: int = 15  # Epochs without validation improvement before early stopping
    # NOTE(review): no code in this section reads use_amp - confirm where it is applied
    use_amp: bool = True  # Mixed precision training (disabled for MPS)
    gradient_clip: float = 1.0  # Max gradient norm; values <= 0 turn clipping off

    # Data
    data_file: str = 'tiles/tiles.csv'  # CSV read with pandas; needs 'service', 'path', 'x', 'y' columns
    train_ratio: float = 0.8  # Fraction of tiles used for training
    val_ratio: float = 0.1  # Fraction of tiles intended for validation
    random_seed: int = 42  # Seeds torch / numpy / random and the data split
    num_workers: int = 0  # Set to 0 for debugging, increase for performance

    # Kyiv-specific settings
    # The centre is subtracted from the 'y' (latitude) and 'x' (longitude) columns
    kyiv_center_lat: float = 50.4501  # Kyiv center coordinates
    kyiv_center_lon: float = 30.5234
    coordinate_normalize: bool = True  # Normalize coordinates for better learning

    # Output
    save_model: bool = True  # Write a checkpoint after training
    model_save_path: str = 'checkpoints/loc2vec_model.pth'
    plot_training: bool = True  # Plot and save the loss curves in the script block

CONFIG = Config()

########################################
# DEVICE AND SEED SETUP
########################################

def setup_environment(config: Config) -> torch.device:
    """Pick the compute device, seed every RNG and create the output folders.

    Device preference is CUDA, then Apple Silicon (MPS), then CPU.

    Args:
        config: Run configuration; only ``random_seed`` is read here.

    Returns:
        The selected ``torch.device``.

    Side effects:
        Prints the chosen device, seeds ``torch`` / ``numpy`` / ``random`` and
        creates the ``checkpoints`` and ``plots`` directories if missing.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
        # Reproducibility wins over autotuning: with benchmark=True cuDNN may
        # pick a different convolution algorithm on every run, which defeats
        # the seeding below. Both flags are needed for repeatable results.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
        print("Using Apple Silicon GPU (MPS)")
    else:
        device = torch.device('cpu')
        print("Using CPU")

    # Reproducibility: seed every random number generator the pipeline touches
    torch.manual_seed(config.random_seed)
    np.random.seed(config.random_seed)
    random.seed(config.random_seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed_all(config.random_seed)

    # Output directories used for checkpoints and figures
    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('plots', exist_ok=True)

    return device

device = setup_environment(CONFIG)

Using Apple Silicon GPU (MPS)


In [2]:
class Loc2VecCNN(nn.Module):
    """
    Compact convolutional encoder that maps an RGB tile to a unit-length embedding.

    Layout:
    - five 3x3 conv stages (3->32->64->128->256->512 channels), each followed
      by batch norm and LeakyReLU(0.1)
    - 2x2 max pooling after stages 1-4, global average pooling after stage 5
    - dropout-regularised two-layer head (512 -> 128 -> embedding_dim)
    - L2 normalisation of the output
    """

    def __init__(self, embedding_dim: int = 16):
        super().__init__()

        # Build the backbone stage by stage; the flat layer order (and hence
        # the state_dict keys features.0 ... features.19) is fixed by this loop.
        channel_plan = (3, 32, 64, 128, 256, 512)
        final_stage = len(channel_plan) - 2
        stage_layers = []
        for stage, (c_in, c_out) in enumerate(zip(channel_plan[:-1], channel_plan[1:])):
            stage_layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            stage_layers.append(nn.BatchNorm2d(c_out))
            stage_layers.append(nn.LeakyReLU(0.1, inplace=True))
            if stage == final_stage:
                # Last stage collapses the spatial grid to 1x1
                stage_layers.append(nn.AdaptiveAvgPool2d(1))
            else:
                stage_layers.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*stage_layers)

        # Projection head from pooled features to the embedding space
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, embedding_dim),
        ]
        self.embedder = nn.Sequential(*head_layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise conv (Kaiming normal), batch-norm (1/0) and linear (N(0, 0.01)) layers."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into L2-normalised embeddings."""
        pooled = self.features(x)
        flat = pooled.view(pooled.size(0), -1)  # drop the 1x1 spatial dims
        raw_embedding = self.embedder(flat)
        return F.normalize(raw_embedding, p=2, dim=1)

In [3]:
########################################
# TRANSFER LEARNING MODELS
########################################

class Loc2VecTransferLearning(nn.Module):
    """Embedding model built on a torchvision backbone plus a small projection head.

    The backbone's classification layer is replaced by ``nn.Identity`` so the
    backbone returns its pooled feature vector, which the head maps to an
    L2-normalised embedding.
    """

    def __init__(self, model_type: str = 'resnet18', pretrained: bool = True,
                 embedding_dim: int = 16, freeze_backbone: bool = False):
        """
        Args:
            model_type: One of 'resnet18', 'resnet50' or 'densenet121'.
            pretrained: Load ImageNet weights when True, random init otherwise.
            embedding_dim: Size of the output embedding.
            freeze_backbone: When True, backbone parameters get
                ``requires_grad = False``; only the head is trainable.

        Raises:
            ValueError: If ``model_type`` is not one of the supported names.
        """
        super().__init__()

        # Select backbone with proper weight handling for PyTorch 2.0+
        if model_type == 'resnet18':
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.resnet18(weights=weights)
            num_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()

        elif model_type == 'resnet50':
            weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
            self.backbone = models.resnet50(weights=weights)
            num_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()

        elif model_type == 'densenet121':
            weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.densenet121(weights=weights)
            num_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        # Freeze backbone if requested
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Embedding head
        self.embedder = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 256),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, embedding_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return L2-normalised embeddings for a batch of images."""
        features = self.backbone(x)
        embeddings = self.embedder(features)
        return F.normalize(embeddings, p=2, dim=1)

    def unfreeze_backbone(self, unfreeze_ratio: float = 1.0):
        """Make the last ``unfreeze_ratio`` share of backbone children trainable.

        Args:
            unfreeze_ratio: Fraction of the backbone's direct children to
                unfreeze, counted from the output end. Values are clamped to
                ``[0, 1]``; a ratio that rounds down to zero layers leaves the
                backbone untouched.
        """
        layers = list(self.backbone.children())
        num_to_unfreeze = int(len(layers) * unfreeze_ratio)
        num_to_unfreeze = max(0, min(len(layers), num_to_unfreeze))
        # Slice from an explicit start index: ``layers[-0:]`` would select
        # every layer and silently unfreeze the whole backbone.
        for layer in layers[len(layers) - num_to_unfreeze:]:
            for param in layer.parameters():
                param.requires_grad = True

In [4]:
class TripletDataset(Dataset):
    """Yields (anchor, positive, negative) image triplets chosen by map distance.

    The frame must provide a 'path' column (image file) and 'x' / 'y' columns,
    which are paired with the configured Kyiv longitude / latitude centre.
    Positives are neighbours closer than ``distance_threshold_km``; negatives
    are neighbours farther than five times that threshold. Both are searched
    only among the nearest neighbours of the anchor.
    """

    # Number of nearest neighbours inspected when building one triplet
    MAX_QUERY_NEIGHBORS = 50

    def __init__(self, df: pd.DataFrame, transform=None, config: Config = CONFIG):
        """
        Args:
            df: Tile table with 'path', 'x' and 'y' columns.
            transform: Optional callable applied to every loaded PIL image.
            config: Run configuration (thresholds, centre, image size).
        """
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.config = config

        # Normalize coordinates if using Kyiv data
        if config.coordinate_normalize:
            self.coords = self._normalize_coordinates(df[['x', 'y']].values)
        else:
            self.coords = df[['x', 'y']].values

        # Convert distance threshold from km to normalized units
        self.pos_threshold = self._km_to_normalized(config.distance_threshold_km)
        self.neg_threshold = self.pos_threshold * 5  # Negative samples are 5x farther

        # Build spatial index
        self.spatial_index = NearestNeighbors(
            n_neighbors=min(100, len(self.df)),
            metric='euclidean',
            n_jobs=-1
        )
        self.spatial_index.fit(self.coords)

        # kneighbors() raises when asked for more neighbours than fitted
        # samples, so cap the per-item query size for small datasets.
        self.num_query_neighbors = min(self.MAX_QUERY_NEIGHBORS, len(self.df))

        print(f"Dataset: {len(self.df)} samples")
        print(f"Positive threshold: {self.pos_threshold:.4f} (normalized)")

    def _normalize_coordinates(self, coords: np.ndarray) -> np.ndarray:
        """Centre coordinates on the configured Kyiv centre and divide each axis by its std."""
        centered = coords - np.array([self.config.kyiv_center_lon, self.config.kyiv_center_lat])
        # Each axis is scaled independently, so x and y end up in different units per km
        scaled = centered / np.std(centered, axis=0)
        return scaled

    def _km_to_normalized(self, km: float) -> float:
        """Convert a distance in km to coordinate units (approximate).

        Uses ~111 km per degree corrected by cos(latitude), i.e. the
        longitude scale at Kyiv. When normalisation is enabled the result is
        additionally divided by the std of the 'x' column only.
        NOTE(review): the 'y' axis is scaled by its own std, so the threshold
        is only exact along x - confirm this approximation is acceptable.
        """
        lat_factor = np.cos(np.radians(self.config.kyiv_center_lat))
        degrees_per_km = 1.0 / (111.0 * lat_factor)

        if self.config.coordinate_normalize:
            # Account for normalization scaling
            std_lon = np.std(self.df['x'].values - self.config.kyiv_center_lon)
            return km * degrees_per_km / std_lon
        else:
            return km * degrees_per_km

    def __len__(self) -> int:
        return len(self.df)

    def _load_image(self, idx: int) -> torch.Tensor:
        """Load and transform the image at row ``idx``.

        Best-effort: if the file cannot be opened the error is printed and a
        blank ``image_size`` x ``image_size`` RGB image is used instead.
        """
        try:
            img_path = self.df.iloc[idx]['path']
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading {idx}: {e}")
            image = Image.new('RGB', (self.config.image_size, self.config.image_size))

        if self.transform:
            image = self.transform(image)
        return image

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return an (anchor, positive, negative) image triplet for row ``idx``."""
        anchor_img = self._load_image(idx)

        # Find neighbors (sorted by increasing distance, anchor itself first)
        distances, indices = self.spatial_index.kneighbors(
            [self.coords[idx]], n_neighbors=self.num_query_neighbors
        )
        distances = distances[0]
        indices = indices[0]

        # Positive: nearby location (not self)
        pos_mask = (distances > 0) & (distances < self.pos_threshold)
        pos_candidates = indices[pos_mask]

        if len(pos_candidates) == 0:
            # Fallback: use closest non-self
            pos_candidates = indices[1:6]

        pos_idx = np.random.choice(pos_candidates)

        # Negative: far location
        neg_mask = distances > self.neg_threshold
        neg_candidates = indices[neg_mask]

        if len(neg_candidates) == 0:
            # Fallback: use farthest available, but never the anchor or the
            # positive that was just drawn (possible when few neighbours exist
            # or all of them lie inside the positive radius).
            neg_candidates = indices[-10:]
            distinct = neg_candidates[(neg_candidates != idx) & (neg_candidates != pos_idx)]
            if len(distinct) > 0:
                neg_candidates = distinct

        neg_idx = np.random.choice(neg_candidates)

        return anchor_img, self._load_image(pos_idx), self._load_image(neg_idx)

In [5]:

########################################
# TRIPLET LOSS VARIANTS
########################################

class TripletLoss(nn.Module):
    """Hinge-style triplet loss: mean of max(d(a, p) - d(a, n) + margin, 0)."""

    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor,
                negative: torch.Tensor) -> torch.Tensor:
        """Average the margin violation over the batch using Euclidean distances."""
        d_anchor_pos = F.pairwise_distance(anchor, positive, p=2)
        d_anchor_neg = F.pairwise_distance(anchor, negative, p=2)

        violation = d_anchor_pos - d_anchor_neg + self.margin
        return torch.relu(violation).mean()


class SoftPNTripletLoss(nn.Module):
    """
    Soft-margin triplet loss (simplified stand-in for SoftPN).

    Computes ``softplus((d(a, p) - d(a, n)) / temperature)`` averaged over the
    batch. A full SoftPN implementation would require batch-wise mining, which
    is not done here.

    Note: ``margin`` is stored for interface parity with ``TripletLoss`` but
    is not used by ``forward``.
    """

    def __init__(self, margin: float = 1.0, temperature: float = 0.1):
        super().__init__()
        self.margin = margin
        self.temperature = temperature

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor,
                negative: torch.Tensor) -> torch.Tensor:
        """Return the mean soft-margin loss for a batch of embedding triplets."""
        pos_dist = F.pairwise_distance(anchor, positive, p=2)
        neg_dist = F.pairwise_distance(anchor, negative, p=2)

        # softplus(z) == log(1 + exp(z)) but stays finite for large z, where a
        # literal log1p(exp(z)) overflows to inf (notably in reduced precision).
        scaled_gap = (pos_dist - neg_dist) / self.temperature
        return F.softplus(scaled_gap).mean()

In [6]:
def create_transforms(config: Config) -> Tuple[transforms.Compose, transforms.Compose]:
    """Build the (train, validation) image pipelines.

    Both resize to ``config.image_size`` squared, convert to a tensor and
    normalise with ImageNet statistics; the training pipeline additionally
    applies random flips, rotation (up to 45 degrees) and colour jitter.
    """
    target_size = (config.image_size, config.image_size)
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def to_normalized_tensor():
        # Fresh instances per pipeline so the two Compose objects share nothing
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]

    # Random augmentations are applied on the PIL image, before tensor conversion
    augmentations = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=45),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]

    train_steps = [transforms.Resize(target_size)] + augmentations + to_normalized_tensor()
    eval_steps = [transforms.Resize(target_size)] + to_normalized_tensor()

    return transforms.Compose(train_steps), transforms.Compose(eval_steps)

########################################
# TRAINING WITH MODERN FEATURES
########################################

def train_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                optimizer: optim.Optimizer, device: torch.device,
                config: Config) -> float:
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch is a triplet of image tensors; the three branches are embedded
    by separate forward passes of the same model. Gradients are clipped to
    ``config.gradient_clip`` when that value is positive.
    """
    model.train()
    running_loss = 0.0

    progress = tqdm.tqdm(loader, desc="Training")
    for anchors, positives, negatives in progress:
        anchors, positives, negatives = (
            branch.to(device, non_blocking=True)
            for branch in (anchors, positives, negatives)
        )

        # set_to_none avoids writing zeros into every gradient buffer
        optimizer.zero_grad(set_to_none=True)

        # Arguments are evaluated left to right: anchor, positive, negative
        loss = criterion(model(anchors), model(positives), model(negatives))
        loss.backward()

        if config.gradient_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)

        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': f"{batch_loss:.4f}"})

    return running_loss / len(loader)


@torch.no_grad()
def validate_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                   device: torch.device) -> Tuple[float, torch.Tensor]:
    """Evaluate the model on ``loader`` without gradient tracking.

    Returns:
        A tuple of (mean batch loss, anchor embeddings of every batch
        concatenated on the CPU along dim 0).
    """
    model.eval()
    loss_sum = 0.0
    anchor_chunks = []

    for anchors, positives, negatives in tqdm.tqdm(loader, desc="Validation"):
        anchors, positives, negatives = (
            branch.to(device, non_blocking=True)
            for branch in (anchors, positives, negatives)
        )

        anchor_emb = model(anchors)
        batch_loss = criterion(anchor_emb, model(positives), model(negatives))
        loss_sum += batch_loss.item()

        # Keep only the anchor embeddings, moved off the accelerator
        anchor_chunks.append(anchor_emb.cpu())

    stacked = torch.cat(anchor_chunks, dim=0)
    return loss_sum / len(loader), stacked

########################################
# MAIN TRAINING FUNCTION
########################################

def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader,
                config: Config, device: torch.device) -> Dict:
    """Train with Adam, cosine warm restarts and early stopping on validation loss.

    Args:
        model: Network to optimise (modified in place).
        train_loader: Loader yielding (anchor, positive, negative) batches.
        val_loader: Loader used for the per-epoch validation pass.
        config: Supplies loss choice, learning rate, weight decay, epoch
            count and early-stopping patience.
        device: Device the batches are moved to.

    Returns:
        Dict with per-epoch ``'train_loss'`` and ``'val_loss'`` lists and the
        scalar ``'best_val_loss'``. On return the model holds the weights of
        the best validation epoch, if any epoch improved on ``inf``.
    """
    criterion = SoftPNTripletLoss(margin=config.margin) if config.use_softpn_loss else TripletLoss(margin=config.margin)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    # Training state
    best_val_loss = float('inf')
    best_state = None  # stays None if no epoch ever improves (e.g. NaN losses, zero epochs)
    patience_counter = 0
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(config.num_epochs):
        print(f"\nEpoch {epoch+1}/{config.num_epochs}")

        # Train
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device, config)
        history['train_loss'].append(train_loss)

        # Validate
        val_loss, _ = validate_epoch(model, val_loader, criterion, device)
        history['val_loss'].append(val_loss)

        # Update scheduler
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"LR: {scheduler.get_last_lr()[0]:.2e}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() tensors share storage with the live parameters, so
            # a shallow dict copy would keep tracking later updates. Clone
            # every tensor to get a real snapshot of the best weights.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= config.patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Load best model (skipped when no snapshot was ever taken)
    if best_state is not None:
        model.load_state_dict(best_state)
    history['best_val_loss'] = best_val_loss

    return history

In [7]:

########################################
# INFERENCE UTILITIES
########################################

@torch.no_grad()
def get_embedding(model: nn.Module, image_path: str, transform: transforms.Compose,
                  device: torch.device) -> np.ndarray:
    """Embed one image file.

    Switches the model to eval mode, loads the image as RGB, applies
    ``transform``, runs a batch of one through the model and returns the
    result as a squeezed NumPy array.
    """
    model.eval()

    rgb_image = Image.open(image_path).convert('RGB')
    # unsqueeze(0) adds the batch dimension the model expects
    single_batch = transform(rgb_image).unsqueeze(0).to(device)

    return model(single_batch).cpu().numpy().squeeze()


def find_similar_tiles(query_embedding: np.ndarray, embeddings_db: np.ndarray,
                       coords_db: np.ndarray, top_k: int = 5,
                       query_coords: np.ndarray = None) -> List[Dict]:
    """Return the ``top_k`` database tiles most similar to a query embedding.

    Similarity is the dot product, which equals cosine similarity when the
    embeddings are L2-normalised.

    Args:
        query_embedding: 1-D embedding of the query tile.
        embeddings_db: 2-D array with one embedding per row.
        coords_db: Coordinates aligned row-by-row with ``embeddings_db``.
        top_k: Number of matches to return; values <= 0 yield an empty list.
        query_coords: Optional coordinates of the query tile. When omitted,
            distances are measured from the best match (which is the query
            itself whenever the query tile is part of the database).

    Returns:
        List of dicts ordered by decreasing similarity with keys ``'index'``,
        ``'similarity'``, ``'coordinates'`` and ``'distance_km'`` (a rough
        conversion using 111 km per coordinate unit).
    """
    similarities = embeddings_db @ query_embedding

    # Guard explicitly: a slice of [-0:] would return the whole database
    if top_k <= 0 or len(similarities) == 0:
        return []

    # Get top-k indices, best match first
    top_indices = np.argsort(similarities)[-top_k:][::-1]

    if query_coords is None:
        reference = coords_db[top_indices[0]]
    else:
        reference = np.asarray(query_coords)

    return [
        {
            'index': idx,
            'similarity': similarities[idx],
            'coordinates': coords_db[idx],
            'distance_km': np.linalg.norm(coords_db[idx] - reference) * 111,  # Rough km conversion
        }
        for idx in top_indices
    ]


def create_embedding_database(model: nn.Module, df: pd.DataFrame,
                            transform: transforms.Compose, device: torch.device,
                            batch_size: int = 32) -> Tuple[np.ndarray, np.ndarray]:
    """Embed every tile in ``df`` and return the embeddings with their coordinates.

    Args:
        model: Trained embedding model (switched to eval mode).
        df: Tile table with 'path', 'x' and 'y' columns.
        transform: Transform applied to each image; must produce tensors.
        device: Device used for inference.
        batch_size: Number of tiles embedded per forward pass.

    Returns:
        ``(embeddings, coords)`` - two arrays aligned row-by-row with ``df``:
        the model outputs and the raw 'x' / 'y' values.
    """
    model.eval()
    embeddings = []

    # TripletDataset is used only for its image loading (including the
    # blank-image fallback); the module-level CONFIG supplies its settings.
    dataset = TripletDataset(df, transform=transform, config=CONFIG)

    print("Creating embedding database...")
    with torch.no_grad():
        # `tqdm` is imported as a module in this file, so the progress bar
        # class is tqdm.tqdm. Anchors are loaded directly: iterating the
        # dataset would also load a positive and a negative image per tile
        # only to throw them away.
        for start in tqdm.tqdm(range(0, len(dataset), batch_size)):
            stop = min(start + batch_size, len(dataset))
            anchors = torch.stack([dataset._load_image(i) for i in range(start, stop)])
            emb = model(anchors.to(device))
            embeddings.append(emb.cpu().numpy())

    embeddings = np.vstack(embeddings)
    # Rows are processed in frame order, so coordinates align without batching
    coords = df[['x', 'y']].values

    return embeddings, coords


########################################
# COMPLETE PIPELINE
########################################

def run_pipeline(config: Config = CONFIG) -> Tuple[nn.Module, Dict, pd.DataFrame]:
    """Run the complete pipeline: load tiles, split, build loaders, train, save.

    Uses the module-level ``device`` chosen by ``setup_environment``.

    Args:
        config: Run configuration. ``train_ratio`` and ``val_ratio`` are
            fractions of the whole dataset; whatever remains becomes the test
            split. ``val_ratio`` must be positive.

    Returns:
        ``(model, history, df_full)`` - the trained model, the history dict
        from ``train_model`` and the filtered tile table (rows with
        ``service == 'full'`` whose image file exists).
    """
    print("Starting Loc2Vec Training Pipeline")
    print("=" * 50)

    # Load data
    print("\n1. Loading data...")
    df = pd.read_csv(config.data_file)
    df_full = df[df['service'] == 'full'].copy()

    # Check file existence
    df_full['exists'] = df_full['path'].apply(lambda x: Path(x).exists())
    df_full = df_full[df_full['exists']].drop(columns=['exists'])
    print(f"Found {len(df_full)} valid tiles")

    # Split data: train first, then divide the hold-out between val and test
    train_df, temp_df = train_test_split(df_full, test_size=1-config.train_ratio, random_state=config.random_seed)

    # Share of the hold-out that goes to validation so that val_ratio is
    # honoured as a fraction of the full dataset (0.1 / 0.2 -> 0.5 by default).
    # Rounded to absorb float noise such as 0.5000000000000001.
    val_share = round(config.val_ratio / (1 - config.train_ratio), 6)
    if val_share >= 1.0:
        # Nothing left over for a test split
        val_df, test_df = temp_df, temp_df.iloc[:0]
    else:
        val_df, test_df = train_test_split(temp_df, train_size=val_share, random_state=config.random_seed)

    print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")

    # Create datasets and loaders
    print("\n2. Creating data loaders...")
    train_transform, val_transform = create_transforms(config)

    train_dataset = TripletDataset(train_df, transform=train_transform, config=config)
    val_dataset = TripletDataset(val_df, transform=val_transform, config=config)

    # Pinned host memory only speeds up transfers to CUDA devices
    use_pinned_memory = device.type == 'cuda'
    keep_workers_alive = config.num_workers > 0

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=use_pinned_memory,
        persistent_workers=keep_workers_alive
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=use_pinned_memory,
        persistent_workers=keep_workers_alive
    )

    # Create model
    print(f"\n3. Creating {config.model_type} model...")
    if config.model_type == 'cnn':
        model = Loc2VecCNN(embedding_dim=config.embedding_dim)
    else:
        model = Loc2VecTransferLearning(
            model_type=config.model_type,
            pretrained=config.use_pretrained,
            embedding_dim=config.embedding_dim,
            freeze_backbone=config.freeze_backbone
        )

    model = model.to(device)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total_params:,} | Trainable: {trainable_params:,}")

    # Train model
    print("\n4. Training model...")
    history = train_model(model, train_loader, val_loader, config, device)

    # Save model
    if config.save_model:
        print("\n5. Saving model...")
        save_path = Path(config.model_save_path)
        # parents=True so nested checkpoint directories are created as needed
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # NOTE(review): the Config object itself is pickled into the
        # checkpoint, so loading it presumably requires the Config class to
        # be importable under the same name - confirm against the loader.
        torch.save({
            'model_state_dict': model.state_dict(),
            'config': config,
            'history': history,
            'pytorch_version': torch.__version__,
        }, save_path)

        print(f"Model saved to {save_path}")

    return model, history, df_full


# Run the pipeline
if __name__ == "__main__":
    banner = "=" * 50

    # Informational notice when running on Apple Silicon
    if device.type == 'mps':
        print("\n" + banner)
        for notice_line in (
            "Running on Apple Silicon (MPS)",
            "Optimizations applied:",
            "- torch.compile disabled (not supported on MPS)",
            "- Mixed precision disabled (not supported on MPS)",
            "- num_workers set to 0 (recommended for macOS)",
        ):
            print(notice_line)
        print(banner + "\n")

    # Short run: two epochs are enough to smoke-test the pipeline
    test_config = Config(num_epochs=2)

    try:
        model, history, df_full = run_pipeline(test_config)

        # Loss curves, shown inline and written to plots/training_curves.png
        if test_config.plot_training:
            plt.figure(figsize=(10, 6))
            for history_key, curve_label, curve_marker in (
                ('train_loss', 'Train Loss', 'o'),
                ('val_loss', 'Val Loss', 's'),
            ):
                plt.plot(history[history_key], label=curve_label, marker=curve_marker)
            plt.xlabel('Epoch')
            plt.ylabel('Triplet Loss')
            plt.title('Loc2Vec Training Progress')
            plt.legend()
            plt.grid(True, alpha=0.3)

            plot_path = Path('plots/training_curves.png')
            plot_path.parent.mkdir(exist_ok=True)
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            print(f"\nTraining curves saved to {plot_path}")
            plt.show()

        print("\n✅ Training completed successfully!")
        print(f"Best validation loss: {history['best_val_loss']:.4f}")

        # Usage walkthrough for the trained model
        print("\n" + banner)
        print("EXAMPLE: Using the trained model")
        print(banner)

        # Deterministic (non-augmenting) transform for inference
        _, val_transform = create_transforms(test_config)

        # Example 1: embed the first tile of the filtered table
        print("\n1. Getting embedding for a single tile:")
        sample_path = df_full.iloc[0]['path']
        if Path(sample_path).exists():
            embedding = get_embedding(model, sample_path, val_transform, device)
            print(f"   Embedding shape: {embedding.shape}")
            print(f"   Embedding sample: {embedding[:5]}...")

        # Example 2: printed instructions only, nothing is computed here
        for similar_tiles_line in (
            "\n2. Finding similar tiles:",
            "   (This would require building an embedding database first)",
            "   embeddings_db, coords_db = create_embedding_database(model, df, transform, device)",
            "   similar = find_similar_tiles(query_embedding, embeddings_db, coords_db)",
        ):
            print(similar_tiles_line)

    except Exception as e:
        # Show hints, then re-raise so the original traceback is not lost
        print(f"\n❌ Error during training: {e}")
        for tip_line in (
            "\nTroubleshooting tips:",
            "1. Ensure tiles.csv and image files exist",
            "2. Check file paths in the CSV",
            "3. Try reducing batch_size if out of memory",
            "4. Set num_workers=0 if multiprocessing issues",
        ):
            print(tip_line)
        raise


Running on Apple Silicon (MPS)
Optimizations applied:
- torch.compile disabled (not supported on MPS)
- Mixed precision disabled (not supported on MPS)
- num_workers set to 0 (recommended for macOS)

Starting Loc2Vec Training Pipeline

1. Loading data...
Found 16790 valid tiles
Train: 13432 | Val: 1679 | Test: 1679

2. Creating data loaders...
Dataset: 13432 samples
Positive threshold: 0.1220 (normalized)
Dataset: 1679 samples
Positive threshold: 0.1237 (normalized)

3. Creating cnn model...
Total params: 1,638,288 | Trainable: 1,638,288

4. Training model...

Epoch 1/2

❌ Error during training: 'module' object is not callable. Did you mean: 'tqdm.tqdm(...)'?

Troubleshooting tips:
1. Ensure tiles.csv and image files exist
2. Check file paths in the CSV
3. Try reducing batch_size if out of memory
4. Set num_workers=0 if multiprocessing issues


TypeError: 'module' object is not callable. Did you mean: 'tqdm.tqdm(...)'?

In [None]:
# ========================================
# LEGACY VISUALIZATION CODE (REFERENCE ONLY)
# ========================================

# Note: This cell is a placeholder for the original standalone visualization
# code. It defines nothing: running it only prints the three pointers below.
# The visualization helpers used by the main workflow live in the next cell.
# NOTE(review): create_embeddings_visualization() is named in the last message
# but is not defined in the part of the notebook visible here - confirm it exists.

print("📚 Legacy visualization code loaded (reference only)")
print("🎨 Main visualization functions are in the next cell")
print("✨ Use create_embeddings_visualization() after training completes")

In [None]:
# ========================================
# EMBEDDING VISUALIZATION FOR TENSORBOARD
# ========================================

# Note: SummaryWriter import is needed for TensorBoard visualization
from torch.utils.tensorboard import SummaryWriter

class InferenceDataset(Dataset):
    """Inference-time dataset: yields one transformed image plus a metadata dict per tile."""

    def __init__(self, df: pd.DataFrame, transform=None, config: Config = None):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.config = config or Config()

        # Keep raw or normalised coordinates depending on the configuration
        raw_xy = df[['x', 'y']].values
        if self.config.coordinate_normalize:
            self.coords = self._normalize_coordinates(raw_xy)
        else:
            self.coords = raw_xy

    def _normalize_coordinates(self, coords: np.ndarray) -> np.ndarray:
        """Subtract the configured Kyiv centre, then divide each axis by its standard deviation."""
        origin = np.array([self.config.kyiv_center_lon, self.config.kyiv_center_lat])
        offsets = coords - origin
        return offsets / np.std(offsets, axis=0)

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return ``(image, metadata)`` for row ``idx``.

        If the image cannot be opened the error is printed and a solid red
        placeholder of ``image_size`` x ``image_size`` is used instead.
        """
        record = self.df.iloc[idx]
        tile_path = record['path']

        try:
            image = Image.open(tile_path).convert('RGB')
        except Exception as e:
            print(f"Error loading {tile_path}: {e}")
            placeholder_size = (self.config.image_size, self.config.image_size)
            image = Image.new('RGB', placeholder_size, color='red')

        if self.transform:
            image = self.transform(image)

        raw_pair = [record['x'], record['y']]
        if hasattr(self, 'coords'):
            normalized_pair = self.coords[idx].tolist()
        else:
            normalized_pair = [record['x'], record['y']]

        metadata = {
            'coordinates': raw_pair,
            'normalized_coords': normalized_pair,
            'zoom': record.get('zoom', 0),  # 0 when the table has no 'zoom' column
            'path': tile_path,
            'index': idx
        }

        return image, metadata


def custom_collate_fn(batch):
    """Collate ``(image, metadata)`` samples: stack the images, keep the metadata dicts as a list."""
    image_tensors = []
    metadata_list = []
    for sample in batch:
        image_tensors.append(sample[0])
        metadata_list.append(sample[1])  # left as-is, one dict per sample
    return torch.stack(image_tensors), metadata_list


@torch.no_grad()
def generate_embeddings_and_thumbnails(
    model: nn.Module,
    dataset: InferenceDataset,
    config: Config,
    device: torch.device,
    max_tiles: int = 2000
) -> Tuple[torch.Tensor, List[List[str]], torch.Tensor, List[dict], List[str]]:
    """Generate embeddings, metadata, and thumbnails for the TensorBoard projector.

    The model is switched to eval mode for the duration of the pass (so layers
    such as BatchNorm/Dropout behave deterministically and running statistics
    are not updated) and its previous train/eval mode is restored afterwards.

    Args:
        model: Network mapping an image batch to embeddings.
        dataset: Source of ``(image, metadata)`` samples.
        config: Run configuration; only ``batch_size`` is read here.
        device: Device the model lives on.
        max_tiles: Upper bound on the number of tiles to process.

    Returns:
        ``(embeddings, metadata_rows, thumbnails, metadata_dicts, metadata_header)``
        where ``metadata_rows`` holds one list of strings per tile, matching
        ``metadata_header`` column by column.

    Raises:
        RuntimeError: If no tile was processed (empty dataset or ``max_tiles <= 0``).
    """

    loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0,  # Use 0 for macOS compatibility
        pin_memory=device.type == 'cuda',  # pinned host memory only helps CUDA transfers
        collate_fn=custom_collate_fn  # keeps metadata as a list of dicts
    )

    all_embeddings = []
    all_metadata = []
    all_thumbnails = []
    all_metadata_dicts = []

    # Tensor -> PIL -> 32x32 -> tensor; built once instead of once per image.
    to_thumbnail = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((32, 32)),  # Small thumbnails
        transforms.ToTensor(),
    ])

    # Denormalization constants, hoisted out of the loop.
    # NOTE(review): assumes the dataset transform normalizes with the ImageNet
    # mean/std - confirm against create_transforms.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    print(f"Generating embeddings for visualization...")

    was_training = model.training
    model.eval()
    try:
        for images, metadata_batch in tqdm.tqdm(loader, desc="Processing tiles"):
            remaining = max_tiles - len(all_metadata)
            if remaining <= 0:
                break

            # Trim the final batch so no embeddings are computed just to be dropped.
            if len(metadata_batch) > remaining:
                images = images[:remaining]
                metadata_batch = metadata_batch[:remaining]

            embeddings = model(images.to(device))
            all_embeddings.append(embeddings.cpu())

            # The loader yields CPU tensors, so thumbnails are built from them directly.
            denormalized = torch.clamp(images.cpu() * std + mean, 0, 1)

            for img_tensor, metadata in zip(denormalized, metadata_batch):
                coords = metadata['coordinates']
                norm_coords = metadata['normalized_coords']

                # One row of strings per tile, ordered like metadata_header below.
                all_metadata.append([
                    f"{coords[0]:.6f}",
                    f"{coords[1]:.6f}",
                    f"{norm_coords[0]:.4f}",
                    f"{norm_coords[1]:.4f}",
                    str(metadata['zoom']),
                    str(metadata['index'])
                ])
                all_metadata_dicts.append(metadata)
                all_thumbnails.append(to_thumbnail(img_tensor))
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    if not all_embeddings:
        raise RuntimeError(
            "No tiles were processed: the dataset is empty or max_tiles <= 0"
        )

    # Combine all data
    embeddings_tensor = torch.cat(all_embeddings, dim=0)
    thumbnails_tensor = torch.stack(all_thumbnails)

    # Define header for metadata
    metadata_header = ["lon", "lat", "norm_lon", "norm_lat", "zoom", "index"]

    print(f"✅ Generated {len(embeddings_tensor)} embeddings")

    return embeddings_tensor, all_metadata, thumbnails_tensor, all_metadata_dicts, metadata_header


def create_projector_visualization(
    embeddings: torch.Tensor,
    metadata: List[List[str]],
    thumbnails: torch.Tensor,
    metadata_dicts: List[dict],
    metadata_header: List[str],
    log_dir: str = "runs/loc2vec_projector"
):
    """Write embeddings, metadata and thumbnails as a TensorBoard projector run.

    Args:
        embeddings: 2-D tensor with one embedding per row.
        metadata: One list of strings per embedding, matching ``metadata_header``.
        thumbnails: Image tensor shown as per-point sprite in the projector.
        metadata_dicts: Per-tile dicts; only ``'coordinates'`` is read, to
            report the covered longitude/latitude range.
        metadata_header: Column names for ``metadata``.
        log_dir: Directory the TensorBoard event files are written to.
    """

    print(f"Creating TensorBoard projector in {log_dir}...")

    # Create log directory
    Path(log_dir).mkdir(parents=True, exist_ok=True)

    # Collect coordinates once; tolerate an empty list instead of failing in min().
    lons = [m['coordinates'][0] for m in metadata_dicts]
    lats = [m['coordinates'][1] for m in metadata_dicts]
    if lons:
        lon_range = f"{min(lons):.4f} to {max(lons):.4f}"
        lat_range = f"{min(lats):.4f} to {max(lats):.4f}"
    else:
        lon_range = lat_range = "n/a"

    # Built line by line without leading indentation: markdown renders text
    # indented by four spaces as a code block rather than as headings/lists.
    info_text = "\n".join([
        "## Loc2Vec Embedding Visualization",
        "",
        f"- **Total tiles**: {len(embeddings)}",
        f"- **Embedding dimension**: {embeddings.shape[1]}",
        "- **Coordinate range**:",
        f"    - Longitude: {lon_range}",
        f"    - Latitude: {lat_range}",
        "",
        "## How to use:",
        "",
        "1. Select the \"Loc2Vec_Embeddings\" in the projector",
        "2. Try different projection methods (PCA, t-SNE, UMAP)",
        "3. Color points by metadata (coordinates, zoom level)",
        "4. Hover over points to see thumbnail images",
        "5. Search for specific tiles by index",
    ])

    writer = SummaryWriter(log_dir)
    try:
        # Add embeddings with metadata and thumbnails
        writer.add_embedding(
            mat=embeddings,
            metadata=metadata,
            label_img=thumbnails,
            tag="Loc2Vec_Embeddings",
            metadata_header=metadata_header
        )

        # Add some statistics
        writer.add_text("Dataset_Info", info_text)
    finally:
        # Always flush and release the event file, even if writing failed.
        writer.close()

    print("✅ Projector visualization created!")
    print(f"\n🚀 To view the visualization, run:")
    print(f"   ./run_visualization.sh")
    # Point at the directory actually written, not a hard-coded "runs".
    print(f"   OR manually: tensorboard --logdir={log_dir}")
    print(f"\n   Then open your browser to http://localhost:6006")
    print(f"   Navigate to the 'PROJECTOR' tab")


# ========================================
# CREATE VISUALIZATION AFTER TRAINING
# ========================================

def create_embeddings_visualization(model, df_data, config, device, max_tiles=2000,
                                    log_dir="runs/loc2vec_projector"):
    """Create the TensorBoard embeddings visualization for a trained model.

    Keeps the rows of ``df_data`` whose ``service`` is ``'full'`` and whose
    ``path`` exists on disk, samples at most ``max_tiles`` of them, embeds them
    and writes a projector run to ``log_dir``.

    Args:
        model: Trained embedding network.
        df_data: Tile table with at least ``service`` and ``path`` columns
            (plus whatever ``InferenceDataset`` reads).
        config: Run configuration (``random_seed`` is used for sampling).
        device: Device the model lives on.
        max_tiles: Maximum number of tiles to visualize; must be >= 1.
        log_dir: Directory for the TensorBoard projector files.

    Returns:
        ``(embeddings, metadata_dicts)`` as produced by
        ``generate_embeddings_and_thumbnails``.

    Raises:
        ValueError: If ``max_tiles`` is < 1 or no valid tile is available.
    """

    if max_tiles < 1:
        raise ValueError(f"max_tiles must be >= 1, got {max_tiles}")

    print("\n" + "="*60)
    print("🎨 CREATING TENSORBOARD EMBEDDINGS VISUALIZATION")
    print("="*60)

    # Filter valid tiles: 'full' service only, and the image must exist on disk.
    # An explicit boolean mask avoids adding/dropping a temporary column and
    # stays well-typed when the frame is empty.
    df_full_service = df_data[df_data['service'] == 'full']
    on_disk = np.array([Path(p).exists() for p in df_full_service['path']], dtype=bool)
    df_valid = df_full_service[on_disk]

    print(f"Found {len(df_valid)} valid tiles for visualization")

    # Fail early with an actionable message rather than deep inside torch.
    if df_valid.empty:
        raise ValueError(
            "No tiles with service == 'full' and an existing image file were found; "
            "nothing to visualize"
        )

    # Limit for performance
    if len(df_valid) > max_tiles:
        print(f"Limiting to {max_tiles} tiles for performance")
        df_valid = df_valid.sample(n=max_tiles, random_state=config.random_seed)

    # Create dataset (validation transform: second element of create_transforms)
    _, val_transform = create_transforms(config)
    dataset = InferenceDataset(df_valid, transform=val_transform, config=config)

    # Generate embeddings and thumbnails
    embeddings, metadata, thumbnails, metadata_dicts, metadata_header = generate_embeddings_and_thumbnails(
        model, dataset, config, device, max_tiles
    )

    # Create projector visualization
    create_projector_visualization(
        embeddings, metadata, thumbnails, metadata_dicts, metadata_header, log_dir=log_dir
    )

    return embeddings, metadata_dicts

# Announce that the cell's helpers are defined and remind the reader of the entry point.
print(
    "🎨 Embedding visualization functions loaded!",
    "Use create_embeddings_visualization(model, df, config, device) to generate TensorBoard projector data",
    sep="\n",
)


In [None]:
# ========================================
# GENERATE EMBEDDINGS VISUALIZATION
# ========================================

# The export needs four objects created by earlier cells: the trained model,
# the tile table, the run configuration and the target device. Check all of
# them up front so a fresh kernel gets a clear hint instead of a NameError.
_required_names = ('model', 'df_full', 'CONFIG', 'device')
_missing_names = [_name for _name in _required_names if _name not in globals()]

if not _missing_names:
    print("\n" + "="*60)
    print("🎨 GENERATING EMBEDDINGS VISUALIZATION")
    print("="*60)

    try:
        # Create embeddings visualization
        embeddings, metadata_dicts = create_embeddings_visualization(
            model=model,
            df_data=df_full,
            config=CONFIG,
            device=device,
            max_tiles=2000  # Adjust this number based on your needs
        )

        print(f"\n✅ Visualization created with {len(embeddings)} tile embeddings!")
        print("\n🎯 Next steps:")
        print("1. Run: ./run_visualization.sh")
        print("2. Open http://localhost:6006 in your browser")
        print("3. Navigate to the PROJECTOR tab")
        print("4. Explore your Loc2Vec embeddings!")

        print("\n🔍 What to look for:")
        print("- Geographic clustering: Similar locations should cluster together")
        print("- Urban patterns: Different development types should separate")
        print("- Smooth transitions: Neighboring areas should have similar embeddings")
        print("- Scale consistency: Similar zoom levels should cluster")

    except Exception as e:
        # Keep the notebook running, but do not hide where the failure happened.
        import traceback

        print(f"❌ Error creating visualization: {e}")
        traceback.print_exc()
        print("You can still create it manually later using:")
        print("create_embeddings_visualization(model, df_full, CONFIG, device)")

else:
    print("🎨 Visualization functions ready!")
    print(f"Not defined in this session yet: {', '.join(_missing_names)}")
    print("Train your model first, then run:")
    print("create_embeddings_visualization(model, df_full, CONFIG, device)")
