In [1]:
import pandas as pd
import numpy as np
import numpy.typing as npt
from torchvision.datasets import Caltech101, Caltech256

# Category names for each dataset, taken from the local torchvision copies.
# download=False means both datasets must already exist under datasets/.
caltech101_labels, caltech256_labels = (
    dataset_cls(root=dataset_root, download=False).categories
    for dataset_cls, dataset_root in (
        (Caltech101, "datasets/caltech101"),
        (Caltech256, "datasets/caltech256"),
    )
)

# Per-dataset ground-truth targets, followed by the two cross-domain
# prediction files (model trained on one dataset, evaluated on the other).
(
    caltech101_targets,
    caltech256_targets,
    caltech101_model_caltech256,
    caltech256_model_caltech101,
) = map(
    pd.read_csv,
    (
        "output/caltech101.csv",
        "output/caltech256.csv",
        "output/caltech101_model_caltech256.csv",
        "output/caltech256_model_caltech101.csv",
    ),
)

In [2]:
from taxonomy import Taxonomy

# Domain ids used throughout this cell: 0 = Caltech101, 1 = Caltech256.

# Cross-domain predictions as [(model_domain_id, dataset_domain_id, predictions)].
cross_domain_predictions = [
    # Caltech101 model predicting on the Caltech256 dataset
    (0, 1, caltech101_model_caltech256["predictions"].to_numpy()),
    # Caltech256 model predicting on the Caltech101 dataset
    (1, 0, caltech256_model_caltech101["predictions"].to_numpy()),
]

# Ground-truth targets per domain as [(domain_id, targets)].
domain_targets = [
    (0, caltech101_targets["targets"].to_numpy()),  # Caltech101 targets
    (1, caltech256_targets["targets"].to_numpy()),  # Caltech256 targets
]

taxonomy = Taxonomy(
    cross_domain_predictions=cross_domain_predictions,
    domain_targets=domain_targets,
)

# Category names per domain id, passed to the graph renderer for labelling.
domain_labels = {
    0: caltech101_labels,  # Domain 0 = Caltech101
    1: caltech256_labels,  # Domain 1 = Caltech256
}


def _render_taxonomy_graph(tax, labels, title, html_path):
    """Render ``tax`` with ``visualize_graph`` and save it as an HTML file.

    Args:
        tax: Taxonomy instance to render (in whatever state it is currently in).
        labels: Mapping of domain id -> category names, forwarded as
            ``domain_labels``.
        title: Title passed to ``visualize_graph``.
        html_path: Destination path passed to ``save_graph``.

    Returns:
        The object returned by ``tax.visualize_graph``.
    """
    graph = tax.visualize_graph(domain_labels=labels, title=title)
    graph.save_graph(html_path)
    return graph


# Cross-domain relationship graph, rendered before the universal taxonomy is
# built. Its title previously said "Universal Taxonomy Graph", which made it
# indistinguishable from the second render below.
net = _render_taxonomy_graph(
    taxonomy,
    domain_labels,
    "Caltech101 <=> Caltech256 Relationship Graph",
    "output/caltech101_caltech256_relationships.html",
)

# build_universal_taxonomy() is called for its effect on `taxonomy` (its return
# value is unused), so the next render shows the universal taxonomy.
taxonomy.build_universal_taxonomy()
net = _render_taxonomy_graph(
    taxonomy,
    domain_labels,
    "Caltech101 <=> Caltech256 Universal Taxonomy Graph",
    "output/caltech101_caltech256_universal_taxonomy.html",
)