In [3]:
import pandas as pd
import numpy as np
import numpy.typing as npt
from torchvision.datasets import Caltech101, Caltech256

# Category-name lists for both datasets, read from the local copies under
# datasets/ (download=False, so nothing is fetched here).
caltech101_labels, caltech256_labels = (
    Caltech101(root="datasets/caltech101", download=False).categories,
    Caltech256(root="datasets/caltech256", download=False).categories,
)

# Previously exported CSVs: the per-dataset target files plus the two
# cross-dataset prediction files (file names suggest "<model>_model_<data>";
# confirm against the script that wrote outputs/).
(
    caltech101_targets,
    caltech256_targets,
    caltech101_model_caltech256,
    caltech256_model_caltech101,
) = map(
    pd.read_csv,
    [
        "outputs/caltech101.csv",
        "outputs/caltech256.csv",
        "outputs/caltech101_model_caltech256.csv",
        "outputs/caltech256_model_caltech101.csv",
    ],
)

In [None]:
from taxonomy import Taxonomy

def save_taxonomy_graph(taxonomy, *, domain_a_labels, domain_b_labels, title: str, path: str):
    """Render ``taxonomy`` as a labelled graph and write it to ``path``.

    Args:
        taxonomy: The ``Taxonomy`` to draw, in whatever state it is currently in.
        domain_a_labels: Category names for domain A (Caltech101 in this notebook).
        domain_b_labels: Category names for domain B (Caltech256 in this notebook).
        title: Title passed to ``Taxonomy.visualize_graph``.
        path: Output file handed to ``save_graph`` (an ``.html`` path here).

    Returns:
        The object returned by ``Taxonomy.visualize_graph`` (presumably a pyvis
        ``Network`` -- confirm), after it has been saved.
    """
    graph = taxonomy.visualize_graph(
        domain_a_labels=domain_a_labels,
        domain_b_labels=domain_b_labels,
        title=title,
    )
    graph.save_graph(path)
    return graph


# Domain A is Caltech101 and domain B is Caltech256. a_to_b_predictions come from
# caltech101_model_caltech256.csv (presumably the Caltech101 model evaluated on
# Caltech256 images -- confirm), and symmetrically for b_to_a_predictions.
taxonomy = Taxonomy(
    a_to_b_predictions=caltech101_model_caltech256["predictions"].to_numpy(),
    a_targets=caltech101_targets["targets"].to_numpy(),
    b_to_a_predictions=caltech256_model_caltech101["predictions"].to_numpy(),
    b_targets=caltech256_targets["targets"].to_numpy(),
)

# Original relationship graph, rendered before build_universal_taxonomy() runs.
# Its title must not say "Universal": this is the pre-merge view.
net = save_taxonomy_graph(
    taxonomy,
    domain_a_labels=caltech101_labels,
    domain_b_labels=caltech256_labels,
    title="Caltech101 <=> Caltech256 Relationship Graph",
    path="outputs/caltech101_caltech256_relationships.html",
)

# The return value of build_universal_taxonomy() is discarded, so it presumably
# updates `taxonomy` in place; the second render then shows the universal taxonomy.
taxonomy.build_universal_taxonomy()
net = save_taxonomy_graph(
    taxonomy,
    domain_a_labels=caltech101_labels,
    domain_b_labels=caltech256_labels,
    title="Caltech101 <=> Caltech256 Universal Taxonomy Graph",
    path="outputs/caltech101_caltech256_universal_taxonomy.html",
)