In [1]:
import pandas as pd
import numpy as np
import numpy.typing as npt
from torchvision.datasets import Caltech101, Caltech256

# Class-name lists for both datasets. download=False means the data must
# already be present under datasets/; nothing is fetched here.
caltech101_labels, caltech256_labels = (
    Caltech101(root="datasets/caltech101", download=False).categories,
    Caltech256(root="datasets/caltech256", download=False).categories,
)

# Per-dataset tables (the "targets" column is consumed later in the notebook).
caltech101_targets, caltech256_targets = (
    pd.read_csv("output/caltech101.csv"),
    pd.read_csv("output/caltech256.csv"),
)

# Cross-dataset tables: one dataset's model evaluated on the other dataset
# (the "predictions" column is consumed later in the notebook).
caltech101_model_caltech256, caltech256_model_caltech101 = (
    pd.read_csv("output/caltech101_model_caltech256.csv"),
    pd.read_csv("output/caltech256_model_caltech101.csv"),
)

In [2]:
from taxonomy import Taxonomy

# Domain ids used throughout this cell: 0 = Caltech101, 1 = Caltech256.
domain_labels = dict(enumerate((caltech101_labels, caltech256_labels)))

# One (domain id, targets array) pair per dataset, in domain-id order.
domain_targets = [
    (domain_id, frame["targets"].to_numpy())
    for domain_id, frame in enumerate((caltech101_targets, caltech256_targets))
]

# Each entry looks like (model's domain id, predicted dataset's domain id,
# predictions array) -- NOTE(review): confirm the tuple order against Taxonomy.
cross_domain_predictions = [
    # Caltech101 model predicting Caltech256
    (0, 1, caltech101_model_caltech256["predictions"].to_numpy()),
    # Caltech256 model predicting Caltech101
    (1, 0, caltech256_model_caltech101["predictions"].to_numpy()),
]

taxonomy = Taxonomy(
    cross_domain_predictions=cross_domain_predictions,
    domain_targets=domain_targets,
    domain_labels=domain_labels,
)

# Snapshot 1: the taxonomy graph as constructed, before any universal-taxonomy
# building. The title deliberately omits "Universal" so this file can be told
# apart from the second snapshot (both previously carried the same title).
net = taxonomy.visualize_graph(
    title="Caltech101 <=> Caltech256 Taxonomy Graph",
)
net.save_graph("output/caltech101_caltech256_relationships.html")

# Snapshot 2: the graph after build_universal_taxonomy() has run. Its return
# value is discarded, so it presumably updates `taxonomy` in place -- confirm.
taxonomy.build_universal_taxonomy()
net = taxonomy.visualize_graph(
    title="Caltech101 <=> Caltech256 Universal Taxonomy Graph",
)
net.save_graph("output/caltech101_caltech256_universal_taxonomy.html")

In [3]:
from synthetic_taxonomy import SyntheticTaxonomy

# Build a two-domain synthetic taxonomy. Arguments are grouped as
# overall size, per-domain class count (mean / variance), and concept
# cluster size (mean / variance).
# NOTE(review): no seed is passed or set here; if SyntheticTaxonomy samples
# randomly, the saved graphs will differ between runs -- confirm.
synthetic_taxonomy = SyntheticTaxonomy(
    num_atomic_concepts=50, num_domains=2,
    domain_class_count_mean=10, domain_class_count_variance=3,
    concept_cluster_size_mean=2, concept_cluster_size_variance=1,
)

# Graph as generated, before universal-taxonomy building.
synthetic_net = synthetic_taxonomy.visualize_graph(title="Synthetic Taxonomy Graph")
synthetic_net.save_graph("output/synthetic_taxonomy_graph.html")

# Graph after build_universal_taxonomy() has run.
synthetic_taxonomy.build_universal_taxonomy()
synthetic_net = synthetic_taxonomy.visualize_graph(title="Synthetic Universal Taxonomy Graph")
synthetic_net.save_graph("output/synthetic_universal_taxonomy_graph.html")