In [33]:
import pandas as pd
import numpy as np
import torch
import matplotlib

matplotlib.use("pgf")
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torchvision.datasets import Caltech256, Caltech101, CIFAR100
from torch.utils.data import DataLoader
import torch.nn.functional as F

from library.taxonomy import Taxonomy
from library.models.universal_resnet import UniversalResNetModel
from library.datasets.caltech101 import Caltech101DataModule
from library.datasets.caltech256 import Caltech256DataModule
from library.datasets.cifar100 import CIFAR100ScaledDataModule
from library.datasets.util import CombinedDataModule
from library.analysis import UniversalModelAnalyzer

# Figure text rendering: LaTeX-typeset labels, EB Garamond at 11pt,
# and lualatex as the TeX engine used by the pgf backend
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "EB Garamond"
plt.rcParams["font.size"] = 11
plt.rcParams["pgf.texsystem"] = "lualatex"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Allow lower-precision float32 matrix multiplications in exchange for speed
torch.set_float32_matmul_precision("medium")

Using device: cuda


In [34]:
# Two-domain taxonomies (files are prefixed caltech256_caltech101_), one per
# taxonomy-construction method.
# NOTE(review): the .pkl extension suggests pickle-based loading — only load
# taxonomy files from a trusted source.
hypothesis_taxonomy = Taxonomy.load(
    "taxonomies/caltech256_caltech101_hypothesis.pkl"
)
mcfp_taxonomy = Taxonomy.load(
    "taxonomies/caltech256_caltech101_mcfp.pkl"
)
mcfp_binary_taxonomy = Taxonomy.load(
    "taxonomies/caltech256_caltech101_mcfp_binary.pkl"
)
density_threshold_taxonomy = Taxonomy.load(
    "taxonomies/caltech256_caltech101_density_threshold.pkl"
)
naive_threshold_taxonomy = Taxonomy.load(
    "taxonomies/caltech256_caltech101_naive_threshold.pkl"
)

# Three-domain counterparts of the taxonomies above
three_domain_hypothesis_taxonomy = Taxonomy.load(
    "taxonomies/three_domain_hypothesis.pkl"
)
three_domain_mcfp_taxonomy = Taxonomy.load(
    "taxonomies/three_domain_mcfp.pkl"
)
three_domain_mcfp_binary_taxonomy = Taxonomy.load(
    "taxonomies/three_domain_mcfp_binary.pkl"
)
three_domain_density_threshold_taxonomy = Taxonomy.load(
    "taxonomies/three_domain_density_threshold.pkl"
)
three_domain_naive_threshold_taxonomy = Taxonomy.load(
    "taxonomies/three_domain_naive_threshold.pkl"
)

# Model configurations (same as in universal_model_training.ipynb).
# Each entry pairs a taxonomy with the checkpoint name of the model trained on
# it and the label used in figure titles. Two- and three-domain variants share
# display names on purpose; the figure titles add the domain count.
taxonomies_config = {
    # --- two-domain models ---
    "hypothesis": dict(
        taxonomy=hypothesis_taxonomy,
        model_name="universal-resnet50-hypothesis-multi-domain-min-val-loss",
        display_name="Hypothesis",
    ),
    "mcfp": dict(
        taxonomy=mcfp_taxonomy,
        model_name="universal-resnet50-mcfp-multi-domain-min-val-loss",
        display_name="MCFP",
    ),
    "mcfp_binary": dict(
        taxonomy=mcfp_binary_taxonomy,
        model_name="universal-resnet50-mcfp-binary-multi-domain-min-val-loss",
        display_name="MCFP Binary",
    ),
    "density_threshold": dict(
        taxonomy=density_threshold_taxonomy,
        model_name="universal-resnet50-density-threshold-multi-domain-min-val-loss",
        display_name="Density Threshold",
    ),
    "naive_threshold": dict(
        taxonomy=naive_threshold_taxonomy,
        model_name="universal-resnet50-naive-threshold-multi-domain-min-val-loss",
        display_name="Naive Threshold",
    ),
    # --- three-domain models ---
    "three_domain_hypothesis": dict(
        taxonomy=three_domain_hypothesis_taxonomy,
        model_name="universal-resnet50-three-domain-hypothesis-min-val-loss",
        display_name="Hypothesis",
    ),
    "three_domain_mcfp": dict(
        taxonomy=three_domain_mcfp_taxonomy,
        model_name="universal-resnet50-three-domain-mcfp-min-val-loss",
        display_name="MCFP",
    ),
    "three_domain_mcfp_binary": dict(
        taxonomy=three_domain_mcfp_binary_taxonomy,
        model_name="universal-resnet50-three-domain-mcfp-binary-min-val-loss",
        display_name="MCFP Binary",
    ),
    "three_domain_density_threshold": dict(
        taxonomy=three_domain_density_threshold_taxonomy,
        model_name="universal-resnet50-three-domain-density-threshold-min-val-loss",
        display_name="Density Threshold",
    ),
    "three_domain_naive_threshold": dict(
        taxonomy=three_domain_naive_threshold_taxonomy,
        model_name="universal-resnet50-three-domain-naive-threshold-min-val-loss",
        display_name="Naive Threshold",
    ),
}

In [35]:
# One data module per source dataset, all with a batch size of 32
caltech101_dm = Caltech101DataModule(batch_size=32)
caltech256_dm = Caltech256DataModule(batch_size=32)
cifar100_dm = CIFAR100ScaledDataModule(batch_size=32)

# Run the test-stage setup on every individual module before combining them
for _single_dm in (caltech101_dm, caltech256_dm, cifar100_dm):
    _single_dm.setup(stage="test")

# Combined module over two datasets: Caltech-101 gets domain id 0 and
# Caltech-256 gets domain id 1
combined_2domain_dm = CombinedDataModule(
    dataset_modules=[caltech101_dm, caltech256_dm],
    domain_ids=[0, 1],
    batch_size=32,
    num_workers=4,
)
combined_2domain_dm.setup(stage="test")

# Combined module over three datasets: same ids as above, plus CIFAR-100
# as domain id 2
combined_3domain_dm = CombinedDataModule(
    dataset_modules=[caltech101_dm, caltech256_dm, cifar100_dm],
    domain_ids=[0, 1, 2],
    batch_size=32,
    num_workers=4,
)
combined_3domain_dm.setup(stage="test")


# Create sample dataloaders for visualization (smaller subset for faster processing)
def create_sample_dataloader(
    dataset_module, num_samples=200, random_seed=42, batch_size=32
):
    """Build a DataLoader over a reproducible random subset of a test dataset.

    Args:
        dataset_module: Either an object exposing a ``test_dataset`` attribute,
            or one exposing ``test_dataloader()`` whose returned loader has a
            ``dataset`` attribute (the combined data modules take this path).
        num_samples: Maximum number of samples to draw. The whole dataset is
            used when it holds fewer samples than this.
        random_seed: Seed for the subset selection. The same seed and dataset
            length always give the same indices.
        batch_size: Batch size of the returned loader (defaults to 32).

    Returns:
        A non-shuffling ``DataLoader`` over a ``torch.utils.data.Subset`` of
        the test dataset.
    """
    if hasattr(dataset_module, "test_dataset"):
        # Individual data module: the test dataset is directly available
        dataset = dataset_module.test_dataset
    else:
        # Combined data module: reach the dataset through its test dataloader
        dataset = dataset_module.test_dataloader().dataset

    sample_count = min(num_samples, len(dataset))

    # A private RandomState draws the same indices as seeding the global
    # NumPy RNG would, but leaves the process-wide random state untouched.
    rng = np.random.RandomState(random_seed)
    # Sample without replacement; hand Subset plain Python ints
    indices = rng.choice(len(dataset), sample_count, replace=False).tolist()

    subset = torch.utils.data.Subset(dataset, indices)
    return DataLoader(subset, batch_size=batch_size, shuffle=False)


# Fixed-size random subsets used for feature extraction:
# 400 samples from the two-domain mix, 600 from the three-domain mix
sample_2domain_loader = create_sample_dataloader(
    dataset_module=combined_2domain_dm, num_samples=400
)
sample_3domain_loader = create_sample_dataloader(
    dataset_module=combined_3domain_dm, num_samples=600
)

In [36]:
# Color and marker configuration for visualization
def get_class_colors(num_classes):
    """Generate one visually distinct RGBA color per class.

    Args:
        num_classes: Number of colors to produce.

    Returns:
        A list of ``num_classes`` RGBA tuples. Up to 20 classes are drawn from
        the qualitative ``tab20`` colormap; larger counts are spread evenly
        around the ``hsv`` hue circle.
    """
    # plt.get_cmap replaces plt.cm.get_cmap, which is deprecated and was
    # removed in Matplotlib 3.9.
    if num_classes <= 20:
        cmap = plt.get_cmap("tab20")
        # tab20 has 20 discrete entries; evenly spaced positions over [0, 1]
        # land on distinct entries for any count up to 20.
        positions = np.linspace(0, 1, num_classes)
    else:
        cmap = plt.get_cmap("hsv")
        # hsv is cyclic (both ends are red), so the endpoint is excluded to
        # keep the first and last class from sharing a color.
        positions = np.linspace(0, 1, num_classes, endpoint=False)

    return [cmap(position) for position in positions]

In [37]:
# Generate t-SNE visualizations for all universal models
print("Generating t-SNE visualizations for universal models...")

# Per-model results for the 2-domain models, keyed by taxonomies_config key.
# Each value holds the extracted features, their 2-D t-SNE embedding, the
# labels and class names returned by the analyzer, and the display name.
features_2domain = {}
embeddings_2domain = {}  # not populated here; embeddings live in features_2domain

print("Processing 2-domain models...")
for model_name in [
    "hypothesis",
    "mcfp",
    "mcfp_binary",
    "density_threshold",
    "naive_threshold",
]:
    config = taxonomies_config[model_name]

    print(f"Loading model: {config['model_name']}")
    model = UniversalResNetModel.load_from_checkpoint(
        f"checkpoints/{config['model_name']}.ckpt", taxonomy=config["taxonomy"]
    )
    model = model.to(device)
    # Switch to evaluation mode so layers such as BatchNorm/Dropout behave
    # deterministically during feature extraction. Loading a checkpoint does
    # not guarantee this; the call is a no-op if the analyzer already does it.
    model.eval()

    # Create analyzer
    analyzer = UniversalModelAnalyzer(model, device=device)

    print(f"Extracting universal features for {model_name}...")
    # Extract universal features from output layer; max_samples matches the
    # size of the sampled 2-domain loader (400)
    universal_features, labels, class_names = analyzer.extract_universal_features(
        sample_2domain_loader, max_samples=400
    )

    print(f"Applying t-SNE for {model_name}...")
    # Project the features to two dimensions with t-SNE
    embeddings = analyzer.apply_tsne(universal_features, perplexity=30.0)

    features_2domain[model_name] = {
        "features": universal_features,
        "embeddings": embeddings,
        "labels": labels,
        "class_names": class_names,
        "display_name": config["display_name"],
    }

    print(f"Completed {model_name}: {universal_features.shape[0]} samples")

print("Completed 2-domain processing")

Generating t-SNE visualizations for universal models...
Processing 2-domain models...
Loading model: universal-resnet50-hypothesis-multi-domain-min-val-loss
Extracting universal features for hypothesis...
Extracting universal features for hypothesis...
Applying t-SNE for hypothesis...
Applying t-SNE to 400 samples with 667 features...
[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 400 samples in 0.000s...
[t-SNE] Computed neighbors for 400 samples in 0.012s...
[t-SNE] Computed conditional probabilities for sample 400 / 400
[t-SNE] Mean sigma: 7.308889
Applying t-SNE for hypothesis...
Applying t-SNE to 400 samples with 667 features...
[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 400 samples in 0.000s...
[t-SNE] Computed neighbors for 400 samples in 0.012s...
[t-SNE] Computed conditional probabilities for sample 400 / 400
[t-SNE] Mean sigma: 7.308889
[t-SNE] KL divergence after 250 iterations with early exaggeration: 68.336426
[t-SNE] KL divergence after 250 itera

In [38]:
# Extract features and create visualizations for 3-domain models.
# Per-model results are keyed by taxonomies_config key; each value holds the
# extracted features, their 2-D t-SNE embedding, the labels and class names
# returned by the analyzer, and the display name.
features_3domain = {}
embeddings_3domain = {}  # not populated here; embeddings live in features_3domain

print("Processing 3-domain models...")
for model_name in [
    "three_domain_hypothesis",
    "three_domain_mcfp",
    "three_domain_mcfp_binary",
    "three_domain_density_threshold",
    "three_domain_naive_threshold",
]:
    config = taxonomies_config[model_name]

    print(f"Loading model: {config['model_name']}")
    model = UniversalResNetModel.load_from_checkpoint(
        f"checkpoints/{config['model_name']}.ckpt", taxonomy=config["taxonomy"]
    )
    model = model.to(device)
    # Switch to evaluation mode so layers such as BatchNorm/Dropout behave
    # deterministically during feature extraction. Loading a checkpoint does
    # not guarantee this; the call is a no-op if the analyzer already does it.
    model.eval()

    # Create analyzer
    analyzer = UniversalModelAnalyzer(model, device=device)

    print(f"Extracting universal features for {model_name}...")
    # Extract universal features from output layer; max_samples matches the
    # size of the sampled 3-domain loader (600)
    universal_features, labels, class_names = analyzer.extract_universal_features(
        sample_3domain_loader, max_samples=600
    )

    print(f"Applying t-SNE for {model_name}...")
    # Project the features to two dimensions with t-SNE
    embeddings = analyzer.apply_tsne(universal_features, perplexity=30.0)

    features_3domain[model_name] = {
        "features": universal_features,
        "embeddings": embeddings,
        "labels": labels,
        "class_names": class_names,
        "display_name": config["display_name"],
    }

    print(f"Completed {model_name}: {universal_features.shape[0]} samples")

print("Completed 3-domain processing")

Processing 3-domain models...
Loading model: universal-resnet50-three-domain-hypothesis-min-val-loss
Extracting universal features for three_domain_hypothesis...
Extracting universal features for three_domain_hypothesis...
Applying t-SNE for three_domain_hypothesis...
Applying t-SNE to 600 samples with 1236 features...
[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 600 samples in 0.000s...
[t-SNE] Computed neighbors for 600 samples in 0.025s...
[t-SNE] Computed conditional probabilities for sample 600 / 600
[t-SNE] Mean sigma: 17.502026
Applying t-SNE for three_domain_hypothesis...
Applying t-SNE to 600 samples with 1236 features...
[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 600 samples in 0.000s...
[t-SNE] Computed neighbors for 600 samples in 0.025s...
[t-SNE] Computed conditional probabilities for sample 600 / 600
[t-SNE] Mean sigma: 17.502026
[t-SNE] KL divergence after 250 iterations with early exaggeration: 64.094063
[t-SNE] KL divergence after 250 itera

In [39]:
# Create comprehensive visualization for 2-domain models
print("Creating 2-domain feature visualization...")

# 2 x 3 grid of panels: five models are drawn, the sixth panel is hidden below
fig_2domain, axes_2domain = plt.subplots(2, 3, figsize=(18, 12))

# Define the order of models for 2-domain (also the panel order, row by row)
model_order_2domain = [
    "hypothesis",
    "mcfp",
    "mcfp_binary",
    "density_threshold",
    "naive_threshold",
]

# Marker shape encodes the source domain; color encodes the class id
extended_dataset_markers = {
    0: "o",  # Caltech-101
    1: "*",  # Caltech-256
    2: "v",  # CIFAR-100
}

# Plot each model using custom visualization
for idx, model_name in enumerate(model_order_2domain):
    # Fill the grid row by row, three panels per row
    row, col = divmod(idx, 3)
    ax = axes_2domain[row, col]

    data = features_2domain[model_name]

    # Each label is indexed as (domain id, class id)
    domain_ids = [label[0] for label in data["labels"]]
    class_ids = [label[1] for label in data["labels"]]

    # Debug: Print unique domain and class IDs
    unique_domains = sorted(set(domain_ids))
    unique_classes = sorted(set(class_ids))
    print(
        f"Model {model_name}: Domains = {unique_domains}, Classes = {unique_classes[:10]}..."
    )  # Show first 10 classes

    # One color per class id, shared across domains
    class_colors = get_class_colors(len(unique_classes))
    color_by_class = dict(zip(unique_classes, class_colors))

    # Plot each combination of class and domain
    for domain_id in unique_domains:
        # Use extended markers, fall back to circle if domain_id not found
        marker = extended_dataset_markers.get(domain_id, "o")

        domain_indices = [i for i, d in enumerate(domain_ids) if d == domain_id]
        domain_embeddings = data["embeddings"][domain_indices]

        # Group this domain's points by class in a single pass instead of
        # rescanning the whole domain once per class
        positions_by_class = {}
        for position, sample_index in enumerate(domain_indices):
            positions_by_class.setdefault(class_ids[sample_index], []).append(position)

        # Draw classes in ascending id order (keeps the original draw order)
        for class_id in sorted(positions_by_class):
            class_embeddings = domain_embeddings[positions_by_class[class_id]]

            ax.scatter(
                class_embeddings[:, 0],
                class_embeddings[:, 1],
                c=[color_by_class[class_id]],
                marker=marker,
                s=30,
                alpha=0.6,
                edgecolors="black",
                linewidth=0.5,
            )

    ax.set_title(f"{data['display_name']} (2 Domains)", fontsize=10)
    ax.set_xlabel("t-SNE Component 1")
    ax.set_ylabel("t-SNE Component 2")
    ax.grid(True, alpha=0.3)

# Hide the third subplot in the second row (since we only have 5 models)
axes_2domain[1, 2].axis("off")

# Add main title
fig_2domain.suptitle("t-SNE Visualizations: 2-Domain Taxonomy Universal Models")

# Lay out and save through the figure object rather than pyplot's implicit
# "current figure", so the right figure is always the one written to disk
fig_2domain.tight_layout()
fig_2domain.savefig(
    "../thesis/figures/universal_features_2domain_tsne.pgf", bbox_inches="tight"
)
# The pgf backend is non-interactive, so plt.show() only emitted a warning.
# Release the figure from pyplot instead; fig_2domain itself remains usable.
plt.close(fig_2domain)

print("Completed 2-domain visualization")

Creating 2-domain feature visualization...
Model hypothesis: Domains = [0, 1], Classes = [0, 1, 2, 3, 4, 5, 7, 8, 9, 10]...
Model mcfp: Domains = [0, 1], Classes = [0, 1, 2, 3, 4, 5, 7, 8, 9, 10]...
Model mcfp_binary: Domains = [0, 1], Classes = [0, 1, 2, 3, 4, 5, 7, 8, 9, 10]...


  cmap = plt.cm.get_cmap("hsv")


Model density_threshold: Domains = [0, 1], Classes = [0, 1, 2, 3, 4, 5, 7, 8, 9, 10]...
Model naive_threshold: Domains = [0, 1], Classes = [0, 1, 2, 3, 4, 5, 7, 8, 9, 10]...
Completed 2-domain visualization
Completed 2-domain visualization


  plt.show()


In [40]:
# Create comprehensive visualization for 3-domain models
print("Creating 3-domain feature visualization...")

# 2 x 3 grid of panels: five models are drawn, the sixth panel is hidden below
fig_3domain, axes_3domain = plt.subplots(2, 3, figsize=(18, 12))

# Define the order of models for 3-domain (also the panel order, row by row)
model_order_3domain = [
    "three_domain_hypothesis",
    "three_domain_mcfp",
    "three_domain_mcfp_binary",
    "three_domain_density_threshold",
    "three_domain_naive_threshold",
]

# Marker shape encodes the source domain; color encodes the class id
extended_dataset_markers = {
    0: "o",  # Caltech-101
    1: "*",  # Caltech-256
    2: "v",  # CIFAR-100
}

# Plot each model using custom visualization
for idx, model_name in enumerate(model_order_3domain):
    # Fill the grid row by row, three panels per row
    row, col = divmod(idx, 3)
    ax = axes_3domain[row, col]

    data = features_3domain[model_name]

    # Each label is indexed as (domain id, class id)
    domain_ids = [label[0] for label in data["labels"]]
    class_ids = [label[1] for label in data["labels"]]

    # Debug: Print unique domain and class IDs
    unique_domains = sorted(set(domain_ids))
    unique_classes = sorted(set(class_ids))
    print(
        f"Model {model_name}: Domains = {unique_domains}, Classes = {unique_classes[:10]}..."
    )  # Show first 10 classes

    # One color per class id, shared across domains
    class_colors = get_class_colors(len(unique_classes))
    color_by_class = dict(zip(unique_classes, class_colors))

    # Plot each combination of class and domain
    for domain_id in unique_domains:
        # Use extended markers, fall back to circle if domain_id not found
        marker = extended_dataset_markers.get(domain_id, "o")

        domain_indices = [i for i, d in enumerate(domain_ids) if d == domain_id]
        domain_embeddings = data["embeddings"][domain_indices]

        # Group this domain's points by class in a single pass instead of
        # rescanning the whole domain once per class
        positions_by_class = {}
        for position, sample_index in enumerate(domain_indices):
            positions_by_class.setdefault(class_ids[sample_index], []).append(position)

        # Draw classes in ascending id order (keeps the original draw order)
        for class_id in sorted(positions_by_class):
            class_embeddings = domain_embeddings[positions_by_class[class_id]]

            ax.scatter(
                class_embeddings[:, 0],
                class_embeddings[:, 1],
                c=[color_by_class[class_id]],
                marker=marker,
                s=30,
                alpha=0.6,
                edgecolors="black",
                linewidth=0.5,
            )

    ax.set_title(f"{data['display_name']} (3 Domains)", fontsize=10)
    ax.set_xlabel("t-SNE Component 1")
    ax.set_ylabel("t-SNE Component 2")
    ax.grid(True, alpha=0.3)

# Hide the third subplot in the second row (since we only have 5 models)
axes_3domain[1, 2].axis("off")

# Add main title
fig_3domain.suptitle("t-SNE Visualizations: 3-Domain Taxonomy Universal Models")

# Lay out and save through the figure object rather than pyplot's implicit
# "current figure", so the right figure is always the one written to disk
fig_3domain.tight_layout()
fig_3domain.savefig(
    "../thesis/figures/universal_features_3domain_tsne.pgf", bbox_inches="tight"
)
# The pgf backend is non-interactive, so plt.show() only emitted a warning.
# Release the figure from pyplot instead; fig_3domain itself remains usable.
plt.close(fig_3domain)

print("Completed 3-domain visualization")

print("All visualizations completed successfully!")

  cmap = plt.cm.get_cmap("hsv")


Creating 3-domain feature visualization...
Model three_domain_hypothesis: Domains = [0, 1, 2], Classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]...
Model three_domain_mcfp: Domains = [0, 1, 2], Classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]...
Model three_domain_mcfp: Domains = [0, 1, 2], Classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]...
Model three_domain_mcfp_binary: Domains = [0, 1, 2], Classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]...
Model three_domain_density_threshold: Domains = [0, 1, 2], Classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]...
Model three_domain_mcfp_binary: Domains = [0, 1, 2], Classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]...
Model three_domain_density_threshold: Domains = [0, 1, 2], Classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]...
Model three_domain_naive_threshold: Domains = [0, 1, 2], Classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]...
Model three_domain_naive_threshold: Domains = [0, 1, 2], Classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]...
Completed 3-domain visualization
All visualizations completed successfully!
Com

  plt.show()
