In [None]:
import gc
import random
import time
from pathlib import Path

import numpy as np
import pandas as pd
import psutil
import torch
import torch.optim as optim
from PIL import Image
from PIL.ImageFile import ImageFile
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from torch.utils.data import Dataset

In [None]:
def get_memory_usage():
    """Return the resident set size (RSS) of the current process in megabytes."""
    rss_bytes = psutil.Process().memory_info().rss
    megabytes = rss_bytes / 1024 / 1024
    return megabytes


def count_parameters(model):
    """Return the total element count of all parameters that require gradients."""
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

In [None]:
def evaluate_embeddings(model, train_loader, device, max_samples=500):
    """Score embedding quality with a silhouette coefficient.

    Anchor images are embedded batch by batch until at least ``max_samples``
    samples have been collected (the last batch is kept whole, so slightly more
    may be used). The embeddings are then labelled in one of two ways:

    * If every usable batch carries a ``"coordinates"`` entry, the coordinates
      are binned on a 10x10 grid computed over *all* collected samples, and the
      bin ids are used as labels.
    * Otherwise the embeddings are clustered with KMeans and the cluster ids
      are used as labels.

    Args:
        model: Embedding network; switched to eval mode by this call.
        train_loader: Iterable of dict batches containing ``"anchor_image"``.
        device: Device the anchor images are moved to before the forward pass.
        max_samples: Stop reading batches once this many samples were collected.

    Returns:
        float: The silhouette score, or ``0.0`` whenever it cannot be computed
        (too few samples, a single label, NaN/Inf embeddings, or any error).
    """
    model.eval()
    embeddings = []
    coordinates = []
    all_batches_have_coords = True

    with torch.no_grad():
        sample_count = 0
        for batch_data in train_loader:
            if sample_count >= max_samples:
                break

            try:
                anchor = batch_data["anchor_image"].to(device)
                anchor_emb = model(anchor).cpu().numpy()

                # Non-finite embeddings would poison both KMeans and the score
                if not np.all(np.isfinite(anchor_emb)):
                    print("Warning: NaN/Inf detected in embeddings, skipping batch")
                    continue

                batch_coords = None
                if "coordinates" in batch_data:
                    # assumes coordinates are (batch, >=2) with lat/lon in the
                    # first two columns — TODO confirm against the data source
                    batch_coords = batch_data["coordinates"]
                    if torch.is_tensor(batch_coords):
                        batch_coords = batch_coords.detach().cpu().numpy()
                    batch_coords = np.asarray(batch_coords)

                # Record the batch only once everything above succeeded, so
                # embeddings and coordinates can never get out of step.
                embeddings.append(anchor_emb)
                if batch_coords is None:
                    all_batches_have_coords = False
                else:
                    coordinates.append(batch_coords)
                sample_count += anchor_emb.shape[0]

            except Exception as e:
                print(f"Warning: Error processing batch in embedding evaluation: {e}")
                continue

    # The silhouette score needs at least two samples (not two batches)
    if sample_count < 2:
        print("Warning: Not enough valid embeddings for silhouette score")
        return 0.0  # Not enough data for silhouette score

    try:
        all_embeddings = np.vstack(embeddings)

        if all_batches_have_coords:
            # Bin over the global coordinate range so that a given bin id means
            # the same spatial cell in every batch.
            coords = np.concatenate(coordinates)
            lat_bins = np.digitize(
                coords[:, 0],
                bins=np.linspace(coords[:, 0].min(), coords[:, 0].max(), 10),
            )
            lon_bins = np.digitize(
                coords[:, 1],
                bins=np.linspace(coords[:, 1].min(), coords[:, 1].max(), 10),
            )
            cluster_labels = lat_bins * 10 + lon_bins
        else:
            # No real spatial labels: fall back to KMeans with 2-5 clusters
            # (one per collected batch, capped at five).
            n_clusters = min(5, max(2, len(embeddings)))
            if len(all_embeddings) < n_clusters:
                return 0.0
            try:
                kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
                cluster_labels = kmeans.fit_predict(all_embeddings)
            except Exception as e:
                print(f"Warning: KMeans clustering failed: {e}")
                return 0.0

        # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1
        n_labels = len(np.unique(cluster_labels))
        if n_labels < 2 or n_labels >= len(all_embeddings):
            return 0.0

        try:
            return float(silhouette_score(all_embeddings, cluster_labels))
        except Exception as e:
            print(f"Warning: Silhouette score calculation failed: {e}")
            return 0.0

    except Exception as e:
        print(f"Warning: Error in embedding evaluation: {e}")
        return 0.0

# Model Definitions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import (
    ConvNeXt_Small_Weights,
    EfficientNet_B0_Weights,
    EfficientNet_V2_M_Weights,
    EfficientNet_V2_S_Weights,
    MobileNet_V3_Large_Weights,
    MobileNet_V3_Small_Weights,
    ResNet50_Weights,
    Swin_S_Weights,
)


class Loc2VecModel(nn.Module):
    """Small from-scratch CNN that maps an image tile to an embedding vector.

    The network is four ``Conv3x3 -> LeakyReLU -> MaxPool2x2`` stages, a
    ``Conv3x3`` and a ``Conv1x1`` (each followed by LeakyReLU), then a
    two-layer fully connected head.

    The first ``Linear`` layer expects 1024 flattened features, which is what
    the unpadded convolutions produce for 128x128 inputs (64 channels x 4 x 4).

    Args:
        input_channels: Number of channels of the input images.
        embedding_dim: Size of the output embedding.
        dropout_rate: Accepted for signature parity with the other models in
            this file; this network contains no dropout layer, so it is unused.
    """

    def __init__(
        self,
        input_channels: int = 3,
        embedding_dim: int = 16,
        dropout_rate: float = 0.5,
    ):
        super().__init__()

        layers = []

        # Four down-sampling stages; channel widths follow (in, out) pairs
        stage_channels = ((input_channels, 32), (32, 32), (32, 64), (64, 64))
        for in_ch, out_ch in stage_channels:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=0, bias=True))
            layers.append(nn.LeakyReLU(inplace=True))
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # Final feature convolutions (no pooling) followed by the dense head
        layers.extend(
            [
                nn.Conv2d(64, 128, kernel_size=3, padding=0, bias=True),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(128, 64, kernel_size=1, padding=0, bias=True),
                nn.LeakyReLU(inplace=True),
                nn.Flatten(),
                nn.Linear(1024, 64, bias=True),
                nn.LeakyReLU(inplace=True),
                nn.Linear(64, embedding_dim, bias=True),
            ]
        )

        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the embeddings produced by the sequential network for ``x``."""
        embedding = self.net(x)
        return embedding


class SoftmaxTripletLoss(nn.Module):
    """
    Triplet Soft-max ratio loss (Ailon et al.) with optional SoftPN variant.
    Minimises  MSE( (d_plus, d_minus), (0, 1) ).

    ``d_plus`` and ``d_minus`` are the soft-max of the anchor-positive and
    anchor-negative distances. The soft-max is evaluated after subtracting the
    per-sample maximum distance, so large distances no longer overflow
    ``exp()`` and turn the loss into NaN.

    Args
    ----
    softpn : bool
        If True, use SoftPN (replace Δ(a,n) by min(Δ(a,n), Δ(p,n))).
    squared : bool
        If True, use squared Euclidean distance; else plain L2.
    reduction : str
        'mean' | 'sum' | 'none'   (mirrors PyTorch's reduction semantics)
    eps : float
        Numerical stabiliser added to denominator.

    Raises
    ------
    ValueError
        If ``reduction`` is not one of the supported values.
    """

    def __init__(
        self,
        softpn: bool = False,
        squared: bool = True,
        reduction: str = "mean",
        eps: float = 1e-8,
    ):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError("reduction must be 'mean', 'sum' or 'none'")
        self.softpn = softpn
        self.squared = squared
        self.reduction = reduction
        self.eps = eps

    @staticmethod
    def _l2(a: torch.Tensor, b: torch.Tensor, squared: bool) -> torch.Tensor:
        """Row-wise (squared) Euclidean distance between ``a`` and ``b``."""
        out = (a - b).pow(2).sum(dim=1)
        # clamp before sqrt: the gradient of sqrt at exactly 0 is infinite
        return out if squared else out.clamp_min(1e-12).sqrt()

    def forward(
        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
    ) -> torch.Tensor:
        """
        anchor, positive, negative  : shape (B, embedding_dim)
        returns scalar loss (or per-sample loss if reduction='none')
        """
        delta_ap = self._l2(anchor, positive, self.squared)
        delta_an = self._l2(anchor, negative, self.squared)

        if self.softpn:
            delta_pn = self._l2(positive, negative, self.squared)
            delta_neg = torch.min(delta_an, delta_pn)
        else:
            delta_neg = delta_an

        # Shift both exponents by their per-sample maximum. The ratio below is
        # unchanged by the shift, but exp() now only sees values <= 0 and cannot
        # overflow to inf (which used to yield inf / inf = NaN). The shift is a
        # constant w.r.t. the ratio, hence detached from the graph.
        shift = torch.maximum(delta_ap, delta_neg).detach()
        exp_ap = torch.exp(delta_ap - shift)
        exp_neg = torch.exp(delta_neg - shift)
        denom = exp_ap + exp_neg + self.eps

        d_plus = exp_ap / denom  # expected → 0
        d_minus = exp_neg / denom  # expected → 1

        loss_vec = (d_plus**2) + ((d_minus - 1) ** 2)  # MSE vs (0,1)

        if self.reduction == "mean":
            return loss_vec.mean()
        elif self.reduction == "sum":
            return loss_vec.sum()
        else:  # 'none'
            return loss_vec


class Loc2VecTripletLoss(nn.Module):
    """Regress anchor-positive / anchor-negative distances towards fixed targets.

    The loss is the batch mean of
    ``(d(a, p) - pos_target)^2 + (d(a, n) - neg_target)^2`` where ``d`` is
    ``F.pairwise_distance``.

    Args:
        pos_target: Desired anchor-positive distance.
        neg_target: Desired anchor-negative distance.
    """

    def __init__(self, pos_target=0.0, neg_target=1.0):
        super().__init__()
        self.pos_target = pos_target
        self.neg_target = neg_target

    def forward(self, anchor_i, anchor_p, anchor_n):
        """Return the scalar loss for a batch of (anchor, positive, negative) embeddings."""
        distance_i_p = F.pairwise_distance(anchor_i, anchor_p)
        distance_i_n = F.pairwise_distance(anchor_i, anchor_n)

        # No .item() calls here: the former logging locals were never used and
        # each one forced a host/device synchronisation on every step.
        return (
            (distance_i_p - self.pos_target) ** 2
            + (distance_i_n - self.neg_target) ** 2
        ).mean()


class EfficientNetLoc2Vec(nn.Module):
    """EfficientNet-B0 backbone whose classifier is replaced by an embedding head.

    Args:
        input_channels: Channels of the input images. Any value other than 3
            swaps the stem convolution for a newly initialised one, so the
            pretrained stem weights are not reused in that case.
        embedding_dim: Size of the output embedding.
        dropout_rate: Dropout probability applied before the final linear layer.
        pretrained: Load the default torchvision weights when True.
    """

    def __init__(
        self,
        input_channels: int = 3,
        embedding_dim: int = 16,
        dropout_rate: float = 0.5,
        pretrained: bool = True,
    ):
        super().__init__()

        backbone = models.efficientnet_b0(
            weights=EfficientNet_B0_Weights.DEFAULT if pretrained else None
        )

        if input_channels != 3:
            # Rebuild the stem with the same geometry but a different input depth
            stem = backbone.features[0][0]
            backbone.features[0][0] = nn.Conv2d(
                input_channels,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )

        head_inputs = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(head_inputs, embedding_dim),
        )

        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the embeddings."""
        embedding = self.backbone(x)
        return embedding


class EfficientNetV2SLoc2Vec(nn.Module):
    """EfficientNetV2-S backbone whose classifier is replaced by an embedding head.

    Args:
        input_channels: Channels of the input images. Any value other than 3
            swaps the stem convolution for a newly initialised one, so the
            pretrained stem weights are not reused in that case.
        embedding_dim: Size of the output embedding.
        dropout_rate: Dropout probability applied before the final linear layer.
        pretrained: Load the default torchvision weights when True.
    """

    def __init__(
        self,
        input_channels: int = 3,
        embedding_dim: int = 16,
        dropout_rate: float = 0.5,
        pretrained: bool = True,
    ):
        super().__init__()

        backbone = models.efficientnet_v2_s(
            weights=EfficientNet_V2_S_Weights.DEFAULT if pretrained else None
        )

        if input_channels != 3:
            # Rebuild the stem with the same geometry but a different input depth
            stem = backbone.features[0][0]
            backbone.features[0][0] = nn.Conv2d(
                input_channels,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )

        head_inputs = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(head_inputs, embedding_dim),
        )

        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the embeddings."""
        embedding = self.backbone(x)
        return embedding


class EfficientNetV2MLoc2Vec(nn.Module):
    """EfficientNetV2-M backbone whose classifier is replaced by an embedding head.

    Args:
        input_channels: Channels of the input images. Any value other than 3
            swaps the stem convolution for a newly initialised one, so the
            pretrained stem weights are not reused in that case.
        embedding_dim: Size of the output embedding.
        dropout_rate: Dropout probability applied before the final linear layer.
        pretrained: Load the default torchvision weights when True.
    """

    def __init__(
        self,
        input_channels: int = 3,
        embedding_dim: int = 16,
        dropout_rate: float = 0.5,
        pretrained: bool = True,
    ):
        super().__init__()

        backbone = models.efficientnet_v2_m(
            weights=EfficientNet_V2_M_Weights.DEFAULT if pretrained else None
        )

        if input_channels != 3:
            # Rebuild the stem with the same geometry but a different input depth
            stem = backbone.features[0][0]
            backbone.features[0][0] = nn.Conv2d(
                input_channels,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )

        head_inputs = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(head_inputs, embedding_dim),
        )

        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the embeddings."""
        embedding = self.backbone(x)
        return embedding


class ResNetLoc2Vec(nn.Module):
    """ResNet50 backbone whose ``fc`` layer is replaced by an embedding head.

    Args:
        input_channels: Channels of the input images. Any value other than 3
            swaps ``conv1`` for a newly initialised convolution, so the
            pretrained ``conv1`` weights are not reused in that case.
        embedding_dim: Size of the output embedding.
        dropout_rate: Dropout probability applied before the final linear layer.
        pretrained: Load the default torchvision weights when True.
    """

    def __init__(
        self,
        input_channels: int = 3,
        embedding_dim: int = 16,
        dropout_rate: float = 0.5,
        pretrained: bool = True,
    ):
        super().__init__()

        backbone = models.resnet50(
            weights=ResNet50_Weights.DEFAULT if pretrained else None
        )

        if input_channels != 3:
            # Rebuild the first convolution with a different input depth
            first_conv = backbone.conv1
            backbone.conv1 = nn.Conv2d(
                input_channels,
                first_conv.out_channels,
                kernel_size=first_conv.kernel_size,
                stride=first_conv.stride,
                padding=first_conv.padding,
                bias=first_conv.bias is not None,
            )

        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(head_inputs, embedding_dim),
        )

        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the embeddings."""
        embedding = self.backbone(x)
        return embedding


class ConvNeXtLoc2Vec(nn.Module):
    """ConvNeXt-Small backbone with a fully replaced classifier head.

    The new head is: global average pool -> flatten -> LayerNorm -> Dropout ->
    Linear. All of its layers are newly created, so none of the original
    classifier's parameters are kept.

    Args:
        input_channels: Channels of the input images. Any value other than 3
            swaps the stem convolution for a newly initialised one.
        embedding_dim: Size of the output embedding.
        dropout_rate: Dropout probability applied before the final linear layer.
        pretrained: Load the default torchvision weights when True.
    """

    def __init__(
        self,
        input_channels: int = 3,
        embedding_dim: int = 16,
        dropout_rate: float = 0.5,
        pretrained: bool = True,
    ):
        super().__init__()

        backbone = models.convnext_small(
            weights=ConvNeXt_Small_Weights.DEFAULT if pretrained else None
        )

        if input_channels != 3:
            # Rebuild the stem with the same geometry but a different input depth
            stem = backbone.features[0][0]
            backbone.features[0][0] = nn.Conv2d(
                input_channels,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )

        # Width of the features entering the original final linear layer
        # (768 according to the original author's note)
        head_inputs = backbone.classifier[2].in_features
        head_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),  # collapse the spatial dimensions
            nn.Flatten(),  # (batch, head_inputs)
            nn.LayerNorm(head_inputs),
            nn.Dropout(p=dropout_rate),
            nn.Linear(head_inputs, embedding_dim),
        ]
        backbone.classifier = nn.Sequential(*head_layers)

        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the embeddings."""
        embedding = self.backbone(x)
        return embedding


class SwinTransformerLoc2Vec(nn.Module):
    """Swin Transformer Small backbone whose head is replaced by an embedding head.

    Args:
        input_channels: Must be 3; other values are rejected.
        embedding_dim: Size of the output embedding.
        dropout_rate: Dropout probability applied before the final linear layer.
        pretrained: Load the default torchvision weights when True.

    Raises:
        NotImplementedError: If ``input_channels`` is not 3.
    """

    def __init__(
        self,
        input_channels: int = 3,
        embedding_dim: int = 16,
        dropout_rate: float = 0.5,
        pretrained: bool = True,
    ):
        super().__init__()

        # Reject unsupported inputs before any weights are built or downloaded
        if input_channels != 3:
            raise NotImplementedError("Swin Transformer only supports 3 input channels")

        backbone = models.swin_s(weights=Swin_S_Weights.DEFAULT if pretrained else None)

        head_inputs = backbone.head.in_features
        backbone.head = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(head_inputs, embedding_dim),
        )

        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the embeddings."""
        embedding = self.backbone(x)
        return embedding


class MobileNetV3Loc2Vec(nn.Module):
    """Transfer learning model using MobileNetV3-Large as backbone for Loc2Vec embeddings.

    torchvision's MobileNetV3 classifier is
    ``Sequential(Linear, Hardswish, Dropout, Linear)``. This wrapper keeps the
    first linear layer, sets the dropout probability of the third layer, and
    replaces the last linear layer with one producing ``embedding_dim`` outputs.

    Args:
        input_channels: Channels of the input images. Any value other than 3
            swaps the stem convolution for a newly initialised one, so the
            pretrained stem weights are not reused in that case.
        embedding_dim: Size of the output embedding.
        dropout_rate: Dropout probability used inside the classifier.
        pretrained: Load the default torchvision weights when True.
    """

    def __init__(
        self,
        input_channels: int = 3,
        embedding_dim: int = 16,
        dropout_rate: float = 0.5,
        pretrained: bool = True,
    ):
        super().__init__()

        weights = MobileNet_V3_Large_Weights.DEFAULT if pretrained else None
        self.backbone = models.mobilenet_v3_large(weights=weights)

        if input_channels != 3:
            original_conv = self.backbone.features[0][0]
            self.backbone.features[0][0] = nn.Conv2d(
                input_channels,
                original_conv.out_channels,
                kernel_size=original_conv.kernel_size,
                stride=original_conv.stride,
                padding=original_conv.padding,
                bias=original_conv.bias is not None,
            )

        num_features = self.backbone.classifier[3].in_features
        self.backbone.classifier[3] = nn.Linear(num_features, embedding_dim)
        # The dropout layer sits at index 2. Index 0 is the first Linear layer;
        # overwriting it (as this code used to) removed that projection and made
        # the forward pass fail with a shape mismatch at the final Linear.
        self.backbone.classifier[2] = nn.Dropout(p=dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the embeddings."""
        return self.backbone(x)


class MobileNetV3SmallLoc2Vec(nn.Module):
    """MobileNetV3-Small backbone with a fully replaced classifier head.

    The new head is ``Linear -> Hardswish -> Dropout -> Linear`` with a hidden
    width of half the backbone's feature width. All head layers are newly
    created.

    Args:
        input_channels: Channels of the input images. Any value other than 3
            swaps the stem convolution for a newly initialised one.
        embedding_dim: Size of the output embedding.
        dropout_rate: Dropout probability applied before the final linear layer.
        pretrained: Load the default torchvision weights when True.
    """

    def __init__(
        self,
        input_channels: int = 3,
        embedding_dim: int = 16,
        dropout_rate: float = 0.5,
        pretrained: bool = True,
    ):
        super().__init__()

        backbone = models.mobilenet_v3_small(
            weights=MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
        )

        if input_channels != 3:
            # Rebuild the stem with the same geometry but a different input depth
            stem = backbone.features[0][0]
            backbone.features[0][0] = nn.Conv2d(
                input_channels,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )

        # Feature width entering the original classifier
        # (576 according to the original author's note)
        head_inputs = backbone.classifier[0].in_features
        hidden_width = head_inputs // 2
        backbone.classifier = nn.Sequential(
            nn.Linear(head_inputs, hidden_width),
            nn.Hardswish(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_width, embedding_dim),
        )

        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return the embeddings."""
        embedding = self.backbone(x)
        return embedding

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm


def train_epoch(model, train_loader, optimizer, loss_fn, device, scheduler=None):
    """
    Run one optimisation pass over ``train_loader`` and report the mean loss.

    Args:
        model (nn.Module): Network that embeds a batch of images
        train_loader (DataLoader): Yields dict batches with the keys
            "anchor_image", "pos_image" and "neg_image"
        optimizer (torch.optim.Optimizer): Optimizer stepped once per batch
        loss_fn (nn.Module): Called as ``loss_fn(anchor, positive, negative)``
        device (torch.device): Device the batch tensors are moved to
        scheduler: Accepted but not used by this function; it is never stepped

    Returns:
        float: Sum of the per-batch losses divided by ``len(train_loader)``
    """
    model.train()
    running_loss = 0.0
    image_keys = ("anchor_image", "pos_image", "neg_image")

    progress = tqdm(train_loader, desc="Training", total=len(train_loader))
    for batch in progress:
        # Transfer the three images of the triplet, then embed them one by one
        anchor, positive, negative = [batch[key].to(device) for key in image_keys]

        anchor_out = model(anchor)
        positive_out = model(positive)
        negative_out = model(negative)

        loss = loss_fn(anchor_out, positive_out, negative_out)

        # Standard update: clear stale gradients, back-propagate, step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    mean_loss = running_loss / len(train_loader)
    return mean_loss

In [None]:
def load_image(filename) -> "Image.Image":
    """Open ``filename`` and return its content as an RGB image.

    ``convert`` returns a new image object, so the file opened by
    ``Image.open`` can be closed as soon as the conversion is done instead of
    being left for the garbage collector.
    """
    with Image.open(filename) as source:
        return source.convert("RGB")


def list_tiles_to_df(tiles_dir, empty_tile_size=103):
    """List the PNG tiles under ``tiles_dir`` as a DataFrame.

    Each file contributes one row: the parent directory name becomes ``x``,
    the file stem becomes ``y``, the grandparent directory name becomes
    ``zoom`` and the resolved path becomes ``filename``.

    Paths are sorted before the frame is built so that row order, and
    therefore every dataset index derived from it, is the same on every run
    regardless of the order in which the file system lists the files.

    Args:
        tiles_dir: Root directory searched recursively for ``*.png`` files.
        empty_tile_size: Files of exactly this size in bytes are treated as
            empty tiles and skipped. Pass ``None`` to keep every file.

    Returns:
        pd.DataFrame: Columns ``["x", "y", "zoom", "filename"]``, all strings.
    """
    tile_paths = sorted(Path(tiles_dir).resolve().glob("**/*.png"))

    records = [
        (path.parent.name, path.stem, path.parent.parent.name, str(path))
        for path in tile_paths
        if path.stat().st_size != empty_tile_size
    ]

    return pd.DataFrame(records, columns=["x", "y", "zoom", "filename"])


class OptimizedTilesDataset(Dataset):
    """Triplet dataset over map tiles with precomputed positive/negative candidates.

    For every tile the constructor precomputes, among tiles of the same zoom:

    * positive candidates: other tiles whose x and y both differ by at most
      ``pos_radius`` (the tile itself is used when it has no such neighbour);
    * negative candidates: other tiles whose x or y differs by more than
      ``neg_radius_min`` (every other tile is used when none qualifies).

    ``__getitem__`` then draws one positive and one negative at random.

    Note that the candidate lists are stored per tile, so memory grows roughly
    quadratically with the number of tiles, also when images are not preloaded.

    Args:
        tiles_root_dir: Directory passed to ``list_tiles_to_df``.
        pos_radius: Maximum x/y offset of a positive tile.
        neg_radius_min: A negative tile must exceed this offset in x or y.
        transform: Optional callable applied to each of the three images.
        preload_images: Load every image into memory up front when True;
            otherwise images are read from disk on access.
    """

    def __init__(
        self,
        tiles_root_dir,
        pos_radius=1,
        neg_radius_min=10,
        transform=None,
        preload_images=True,
    ):
        self.tiles_root_dir = tiles_root_dir
        self.df = list_tiles_to_df(tiles_root_dir)
        self.pos_radius = pos_radius
        self.neg_radius_min = neg_radius_min
        self.transform = transform
        self.preload_images = preload_images

        print(f"Found {len(self.df)} valid tiles")

        # Convert coordinates to integers once
        self.df["x_int"] = self.df["x"].astype(int)
        self.df["y_int"] = self.df["y"].astype(int)

        self.positive_candidates = {}
        self.negative_candidates = {}

        # Preload all images if requested
        if preload_images:
            print("Preloading all images into memory...")
            self.images = {}

            self._preload_images()
            self._precompute_candidates()
            print("Preloading complete!")
        else:
            self.images = None
            self._precompute_candidates_lazy()

    def _preload_images(self):
        """Load all images into memory during initialization."""
        from tqdm import tqdm

        for idx, filename in tqdm(
            zip(self.df.index, self.df["filename"]),
            total=len(self.df),
            desc="Loading images",
        ):
            try:
                self.images[idx] = load_image(filename)
            except Exception as e:
                print(f"Warning: Failed to load {filename}: {e}")
                # Create a blank image as fallback
                self.images[idx] = Image.new("RGB", (256, 256), color="black")

    def _precompute_candidates(self):
        """Precompute positive and negative candidates for each sample."""
        from tqdm import tqdm

        # Pull the columns out once; the loop below then works on plain arrays
        # instead of rebuilding pandas masks for every row.
        index = self.df.index.to_numpy()
        xs = self.df["x_int"].to_numpy()
        ys = self.df["y_int"].to_numpy()
        zooms = self.df["zoom"].to_numpy()
        all_indices = index.tolist()

        positives = {}
        negatives = {}

        for pos, idx in enumerate(tqdm(all_indices, desc="Computing candidates")):
            dx = np.abs(xs - xs[pos])
            dy = np.abs(ys - ys[pos])
            # Same zoom level, and never the tile itself
            comparable = (zooms == zooms[pos]) & (index != idx)

            # Positive candidates (within pos_radius on both axes)
            pos_mask = comparable & (dx <= self.pos_radius) & (dy <= self.pos_radius)
            pos_indices = index[pos_mask].tolist()
            if not pos_indices:
                # No neighbour available: use self as fallback
                pos_indices = [idx]
            positives[idx] = pos_indices

            # Negative candidates (beyond neg_radius_min on at least one axis)
            neg_mask = comparable & (
                (dx > self.neg_radius_min) | (dy > self.neg_radius_min)
            )
            neg_indices = index[neg_mask].tolist()
            if not neg_indices:
                # No suitable negatives: use all others as candidates
                neg_indices = [i for i in all_indices if i != idx]
            negatives[idx] = neg_indices

        self.positive_candidates = positives
        self.negative_candidates = negatives

    def _precompute_candidates_lazy(self):
        """Precompute candidate indices without loading images.

        Candidate computation never touches image data, so this simply
        delegates to ``_precompute_candidates``; the method is kept so that
        existing callers and subclasses continue to work.
        """
        self._precompute_candidates()

    def _get_image(self, idx):
        """Get image either from preloaded cache or load on demand."""
        if self.preload_images:
            return self.images[idx]

        filename = self.df.iloc[idx]["filename"]
        try:
            return load_image(filename)
        except Exception as e:
            print(f"Warning: Failed to load {filename}: {e}")
            return Image.new("RGB", (256, 256), color="black")

    def __len__(self):
        """Number of tiles in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with the anchor, a random positive, a random negative and metadata."""
        # Get anchor image
        anchor_image = self._get_image(idx)

        # Get positive sample (random choice from precomputed candidates)
        pos_idx = random.choice(self.positive_candidates[idx])
        pos_image = self._get_image(pos_idx)

        # Get negative sample (random choice from precomputed candidates)
        neg_idx = random.choice(self.negative_candidates[idx])
        neg_image = self._get_image(neg_idx)

        # Apply transforms if specified
        if self.transform:
            anchor_image = self.transform(anchor_image)
            pos_image = self.transform(pos_image)
            neg_image = self.transform(neg_image)

        # Get metadata
        row = self.df.iloc[idx]

        return {
            "anchor_image": anchor_image,
            "pos_image": pos_image,
            "neg_image": neg_image,
            "x": row["x"],
            "y": row["y"],
            "zoom": row["zoom"],
            "filename": row["filename"],
        }


# Backward compatibility - use optimized version by default
class TilesDataset(OptimizedTilesDataset):
    """``OptimizedTilesDataset`` that always preloads every image into memory."""

    def __init__(self, tiles_root_dir, pos_radius=1, neg_radius_min=10, transform=None):
        super().__init__(
            tiles_root_dir=tiles_root_dir,
            pos_radius=pos_radius,
            neg_radius_min=neg_radius_min,
            transform=transform,
            preload_images=True,  # trade memory for access speed
        )


# Memory-efficient version for very large datasets
class LazyTilesDataset(OptimizedTilesDataset):
    """``OptimizedTilesDataset`` that reads images from disk on access.

    Candidate indices are still precomputed by the parent constructor.
    """

    def __init__(self, tiles_root_dir, pos_radius=1, neg_radius_min=10, transform=None):
        super().__init__(
            tiles_root_dir=tiles_root_dir,
            pos_radius=pos_radius,
            neg_radius_min=neg_radius_min,
            transform=transform,
            preload_images=False,  # images are loaded on demand
        )

# Training and Benchmarking Loc2Vec Models

In [None]:
def benchmark_model(model_class, model_name, train_loader, device, epochs=3):
    """Train ``model_class`` once per optimizer configuration and collect metrics.

    For each configuration a fresh model is built, trained for ``epochs``
    epochs with ``SoftmaxTripletLoss`` and then scored with
    ``evaluate_embeddings``. A configuration that raises is reported on stdout
    and skipped, so the returned list may be shorter than the number of
    configurations (or empty).

    Args:
        model_class: Callable accepting ``input_channels``, ``embedding_dim``
            and ``dropout_rate`` and returning an ``nn.Module``.
        model_name: Label stored in the "Model" field of each result.
        train_loader: Loader used for both training and embedding evaluation.
        device: Device the model is moved to.
        epochs: Number of training epochs per configuration.

    Returns:
        list[dict]: One result record per successfully benchmarked configuration.
    """
    results = []

    base_lr = 1e-4  # Standard LR for pre-trained models

    # Optimizer configurations to compare (no scheduler is used at the moment)
    configs = [
        {"optimizer": "Adam", "lr": base_lr, "scheduler": None},
        {"optimizer": "AdamW", "lr": base_lr, "scheduler": None},
    ]

    # Explicit name -> class mapping, so the optimizer that is built is always
    # the one reported in the results (the former if/else had them swapped).
    optimizer_classes = {"Adam": optim.Adam, "AdamW": optim.AdamW}

    for config in configs:
        try:
            # Clear memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            # Initialize model
            model = model_class(input_channels=3, embedding_dim=16, dropout_rate=0.5)
            model.to(device)

            # Count parameters
            param_count = count_parameters(model)

            # Setup optimizer (an unknown name raises KeyError, handled below)
            optimizer_class = optimizer_classes[config["optimizer"]]
            optimizer = optimizer_class(model.parameters(), lr=config["lr"])

            # Setup loss function
            loss_fn = SoftmaxTripletLoss()

            # Process memory (RSS) before training; this does not include
            # memory allocated on an accelerator device
            memory_before = get_memory_usage()

            # Training
            start_time = time.time()
            epoch_losses = []

            for epoch in range(epochs):
                epoch_loss = train_epoch(
                    model, train_loader, optimizer, loss_fn, device, scheduler=None
                )
                epoch_losses.append(epoch_loss)

            training_time = time.time() - start_time

            # Loss statistics over the epochs
            avg_loss = np.mean(epoch_losses)
            final_loss = epoch_losses[-1] if epoch_losses else 0
            loss_std = np.std(epoch_losses) if len(epoch_losses) > 1 else 0
            min_loss = np.min(epoch_losses) if epoch_losses else 0
            max_loss = np.max(epoch_losses) if epoch_losses else 0
            loss_trend = (
                "Improving"
                if len(epoch_losses) > 1 and epoch_losses[-1] < epoch_losses[0]
                else "Stable/Degrading"
            )
            loss_string = ", ".join([f"{loss:.4f}" for loss in epoch_losses])

            # Memory after training
            memory_after = get_memory_usage()
            memory_used = memory_after - memory_before

            # Evaluate embedding quality
            silhouette_avg = evaluate_embeddings(model, train_loader, device)

            # Record results
            result = {
                "Model": model_name,
                "Optimizer": config["optimizer"],
                "Scheduler": config["scheduler"] or "None",
                "Parameters (M)": param_count / 1e6,
                "Training Time (s)": training_time,
                "Memory Used (MB)": memory_used,
                "Final Loss": final_loss,
                "Avg Loss": avg_loss,
                "Min Loss": min_loss,
                "Max Loss": max_loss,
                "Loss Std": loss_std,
                "Loss Trend": loss_trend,
                "All Losses": loss_string,
                "Silhouette Score": silhouette_avg,
                "Time per Epoch (s)": training_time / epochs,
            }
            results.append(result)

            print(
                f"✓ {model_name} - {config['optimizer']} - {config['scheduler'] or 'None'}"
            )

        except Exception as e:
            # Best effort: report the failure and move on to the next config
            print(
                f"✗ {model_name} - {config['optimizer']} - {config['scheduler'] or 'None'}: {str(e)}"
            )
            continue

    return results

In [None]:
def run_comprehensive_comparison(train_loader, device):
    """Benchmark every model in ``models_to_test`` and print a comparison report.

    Args:
        train_loader: Loader handed to ``benchmark_model`` for each model.
        device: Device handed to ``benchmark_model`` for each model.

    Returns:
        pd.DataFrame: One row per successfully benchmarked (model, optimizer)
        pair. The frame is empty, and no report is printed, when every
        benchmark failed.
    """

    # Model configurations
    models_to_test = [
        (Loc2VecModel, "Custom Loc2Vec"),
        # (EfficientNetLoc2Vec, "EfficientNet B0"),
        # (EfficientNetV2SLoc2Vec, "EfficientNetV2-S"),
        # (EfficientNetV2MLoc2Vec, "EfficientNetV2-M"),
        # (ResNetLoc2Vec, "ResNet50"),
        # (ConvNeXtLoc2Vec, "ConvNeXt-Small"),
        # (SwinTransformerLoc2Vec, "Swin-Small"),
        # (MobileNetV3Loc2Vec, "MobileNetV3-Large"),
        # (MobileNetV3SmallLoc2Vec, "MobileNetV3-Small"),
    ]

    all_results = []

    print("Starting comprehensive model comparison...")
    print(f"Device: {device}")
    print("-" * 50)

    for model_class, model_name in models_to_test:
        print(f"\nTesting {model_name}...")
        try:
            results = benchmark_model(model_class, model_name, train_loader, device)
        except Exception as e:
            print(f"Error benchmarking {model_name}: {e}")
            continue
        else:
            all_results.extend(results)

    # Create DataFrame and sort by performance
    df = pd.DataFrame(all_results)

    # benchmark_model swallows per-configuration failures, so the result list
    # can be empty; the column lookups below would then raise KeyError.
    if df.empty:
        print("\nNo benchmark results were produced; nothing to compare.")
        return df

    print("\n" + "=" * 80)
    print("COMPREHENSIVE COMPARISON RESULTS")
    print("=" * 80)

    # Display results
    print("\n📊 FULL RESULTS TABLE:")
    print(df.round(4).to_string(index=False))

    # Best performers analysis
    print("\n🏆 BEST PERFORMERS BY CATEGORY:")
    print("-" * 40)

    best_speed = df.loc[df["Training Time (s)"].idxmin()]
    print(
        f"⚡ Fastest: {best_speed['Model']} ({best_speed['Optimizer']}) - {best_speed['Training Time (s)']:.2f}s"
    )

    best_memory = df.loc[df["Memory Used (MB)"].idxmin()]
    print(
        f"💾 Most Memory Efficient: {best_memory['Model']} ({best_memory['Optimizer']}) - {best_memory['Memory Used (MB)']:.1f}MB"
    )

    best_final_loss = df.loc[df["Final Loss"].idxmin()]
    print(
        f"🎯 Best Final Loss: {best_final_loss['Model']} ({best_final_loss['Optimizer']}) - {best_final_loss['Final Loss']:.4f}"
    )

    most_stable = df.loc[df["Loss Std"].idxmin()]
    print(
        f"📈 Most Stable Training: {most_stable['Model']} ({most_stable['Optimizer']}) - Std: {most_stable['Loss Std']:.4f}"
    )

    best_silhouette = df.loc[df["Silhouette Score"].idxmax()]
    print(
        f"🎨 Best Embeddings: {best_silhouette['Model']} ({best_silhouette['Optimizer']}) - {best_silhouette['Silhouette Score']:.4f}"
    )

    smallest_model = df.loc[df["Parameters (M)"].idxmin()]
    print(
        f"📦 Smallest Model: {smallest_model['Model']} - {smallest_model['Parameters (M)']:.1f}M params"
    )
    return df

In [None]:
import multiprocessing as mp

import torchvision.transforms as T

# Pick the best available accelerator and fall back to the CPU. Requesting
# "mps" unconditionally fails on machines that have neither CUDA nor MPS.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Tiles are read from the "full" directory (relative to the working directory),
# resized to 128x128 and normalised with fixed per-channel statistics.
dataset = TilesDataset(
    "full",
    pos_radius=1,
    transform=T.Compose(
        [
            T.Resize((128, 128)),
            T.ToTensor(),
            T.Normalize([0.8107, 0.8611, 0.7814], [0.1215, 0.0828, 0.1320]),
        ]
    ),
)

# Prefer "fork" for the DataLoader workers; keep the platform default when the
# start method cannot be changed or "fork" does not exist on this platform.
try:
    mp.set_start_method("fork", force=True)
except (RuntimeError, ValueError):
    pass  # Already set, or "fork" is unavailable

train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    prefetch_factor=10,
    persistent_workers=True,
)

In [None]:
# Run the full benchmark; the function prints its report and returns the
# results DataFrame, which the notebook displays as this cell's output.
run_comprehensive_comparison(train_loader, device)