In [None]:
import pandas as pd
import numpy as np
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers

from library.taxonomy_constructors import ManualTaxonomy, CrossPredictionsTaxonomy
from library.models.resnet import ResNetModel
from library.datasets.mnist import MNISTMappedDataModule
from library.datasets.svhn import SVHNMappedDataModule
from library.taxonomy import DomainClass, Relationship

# Reduce the precision of matrix multiplication to speed up training
torch.set_float32_matmul_precision("medium")

In [None]:
# Create ground truth manual taxonomy
# This represents the "perfect" taxonomy where each digit maps to its corresponding digit
# Domain 0: MNIST, Domain 1: SVHN

# Define domain labels for clarity
domain_labels = {
    0: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],  # MNIST digits
    1: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],  # SVHN digits
}

# Create perfect 1:1 mapping relationships between matching digits
# Each MNIST digit should map to the corresponding SVHN digit with confidence 1.0
relationships = []
for digit in range(10):
    # Create domain classes for MNIST and SVHN
    mnist_class = DomainClass((np.intp(0), np.intp(digit)))  # Domain 0 (MNIST), digit
    svhn_class = DomainClass((np.intp(1), np.intp(digit)))  # Domain 1 (SVHN), digit

    # Create bidirectional relationships
    # MNIST digit -> SVHN digit
    relationships.append(Relationship((mnist_class, svhn_class, 1.0)))
    # SVHN digit -> MNIST digit
    relationships.append(Relationship((svhn_class, mnist_class, 1.0)))

# Create the ground truth manual taxonomy using the normal constructor
ground_truth_taxonomy = ManualTaxonomy(
    num_domains=2,
    num_nodes=10,
    relationships=relationships,
    domain_labels=domain_labels,
)

print(
    f"Ground truth taxonomy created with {len(ground_truth_taxonomy.relationships)} relationships"
)
print(
    f"Graph has {ground_truth_taxonomy.graph.number_of_nodes()} nodes and {ground_truth_taxonomy.graph.number_of_edges()} edges"
)

Ground truth taxonomy created with 20 relationships
Graph has 20 nodes and 20 edges


In [3]:
# Create identity mappings for both datasets (all digits 0-9)
mnist_mapping = {i: i for i in range(10)}  # Keep all MNIST digits as-is
svhn_mapping = {i: i for i in range(10)}  # Keep all SVHN digits as-is

print(f"MNIST mapping: {mnist_mapping}")
print(f"SVHN mapping: {svhn_mapping}")
print(f"Using all 10 digits from both datasets")

MNIST mapping: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9}
SVHN mapping: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9}
Using all 10 digits from both datasets


In [4]:
# Configuration
TRAIN = True  # Set to True to train models from scratch


def train_digit_model(datamodule, domain_name, logger_name, model_name):
    """Train a ResNet model for a specific digit domain"""
    tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name)

    trainer = Trainer(
        max_epochs=3,  # Fewer epochs for digit classification
        logger=tb_logger if TRAIN else False,
        callbacks=[
            ModelCheckpoint(
                dirpath="checkpoints",
                monitor="val_loss",
                mode="min",
                save_top_k=1,
                filename=model_name,
                enable_version_counter=False,
            )
        ],
    )

    if TRAIN:
        model = ResNetModel(
            num_classes=10,  # 10 digits
            architecture="resnet50",
            optim="adamw",  # Use AdamW for better convergence on digit data
            optim_kwargs={
                "lr": 0.001,
                "weight_decay": 1e-4,
            },
        )
        trainer.fit(model, datamodule=datamodule)
        results = trainer.test(datamodule=datamodule, ckpt_path="best")
    else:
        model = ResNetModel.load_from_checkpoint(f"checkpoints/{model_name}.ckpt")
        results = trainer.test(model, datamodule=datamodule)

    print(f"{domain_name} Results: {results}")
    return results

In [5]:
# Train MNIST model (Domain 0)
print("Training/Testing MNIST Model:")
mnist_datamodule = MNISTMappedDataModule(mapping=mnist_mapping)
mnist_results = train_digit_model(
    mnist_datamodule,
    "MNIST",
    "digit_mnist",
    "resnet50-mnist-min-val-loss",
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Training/Testing MNIST Model:


/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/bjoern/dev/master-thesis/project/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | ResNet           | 26.3 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
26.3 M    Trainable params
0         Non-trainable params
26.3 M    Total params
105.186   Total estimated model params size (MB)
162       Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 2: 100%|██████████| 750/750 [01:15<00:00,  9.97it/s, v_num=5]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 750/750 [01:15<00:00,  9.97it/s, v_num=5]


Restoring states from the checkpoint path at /home/bjoern/dev/master-thesis/project/checkpoints/resnet50-mnist-min-val-loss.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/bjoern/dev/master-thesis/project/checkpoints/resnet50-mnist-min-val-loss.ckpt
/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 157/157 [00:05<00:00, 30.57it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.9745000004768372
        eval_loss           0.1137707382440567
        hp_metric           0.9745000004768372
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
MNIST Results: [{'eval_loss': 0.1137707382440567, 'eval_accuracy': 0.9745000004768372, 'hp_metric': 0.9745000004768372}]


In [6]:
# Train SVHN model (Domain 1)
print("Training/Testing SVHN Model:")
svhn_datamodule = SVHNMappedDataModule(mapping=svhn_mapping)
svhn_results = train_digit_model(
    svhn_datamodule,
    "SVHN",
    "digit_svhn",
    "resnet50-svhn-min-val-loss",
)

Training/Testing SVHN Model:


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
100%|██████████| 182M/182M [00:29<00:00, 6.07MB/s] 
100%|██████████| 64.3M/64.3M [00:21<00:00, 3.02MB/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | ResNet           | 26.3 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
26.3 M    Trainable params
0         Non-trainable params
26.3 M    Total params
105.186   Total estimated model params size (MB)
162       Modules in train mode
0         Modules in eval mode


Epoch 2: 100%|██████████| 916/916 [02:11<00:00,  6.99it/s, v_num=0]        

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 916/916 [02:12<00:00,  6.92it/s, v_num=0]


Restoring states from the checkpoint path at /home/bjoern/dev/master-thesis/project/checkpoints/resnet50-svhn-min-val-loss.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/bjoern/dev/master-thesis/project/checkpoints/resnet50-svhn-min-val-loss.ckpt


Testing DataLoader 0: 100%|██████████| 407/407 [00:11<00:00, 34.67it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.8778042197227478
        eval_loss           0.44375646114349365
        hp_metric           0.8778042197227478
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
SVHN Results: [{'eval_loss': 0.44375646114349365, 'eval_accuracy': 0.8778042197227478, 'hp_metric': 0.8778042197227478}]


In [7]:
# Configuration for prediction generation
PREDICT = True  # Set to True to generate predictions from scratch

if PREDICT:
    # Load trained models
    mnist_model = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-mnist-min-val-loss.ckpt"
    )
    mnist_model.eval()

    svhn_model = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-svhn-min-val-loss.ckpt"
    )
    svhn_model.eval()

    trainer = Trainer(logger=False, enable_checkpointing=False)

    # Generate cross-domain predictions
    print("Generating cross-domain predictions...")

    # MNIST model predicting on SVHN data
    mnist_on_svhn = trainer.predict(mnist_model, datamodule=svhn_datamodule)
    predictions_mnist_on_svhn = torch.cat(mnist_on_svhn).argmax(dim=1)  # type: ignore

    # SVHN model predicting on MNIST data
    svhn_on_mnist = trainer.predict(svhn_model, datamodule=mnist_datamodule)
    predictions_svhn_on_mnist = torch.cat(svhn_on_mnist).argmax(dim=1)  # type: ignore

    # Get ground truth targets
    mnist_targets = torch.cat(
        [label for _, label in mnist_datamodule.predict_dataloader()]
    )
    svhn_targets = torch.cat(
        [label for _, label in svhn_datamodule.predict_dataloader()]
    )

    # Save predictions
    pd.DataFrame(
        {
            "mnist_targets": mnist_targets,
            "predictions_svhn_on_mnist": predictions_svhn_on_mnist,
        }
    ).to_csv("data/digit_mnist_predictions.csv", index=False)

    pd.DataFrame(
        {
            "svhn_targets": svhn_targets,
            "predictions_mnist_on_svhn": predictions_mnist_on_svhn,
        }
    ).to_csv("data/digit_svhn_predictions.csv", index=False)

    print("Predictions saved to CSV files.")

# Load prediction results
df_mnist = pd.read_csv("data/digit_mnist_predictions.csv")
df_svhn = pd.read_csv("data/digit_svhn_predictions.csv")

print(f"MNIST predictions shape: {df_mnist.shape}")
print(f"SVHN predictions shape: {df_svhn.shape}")
print(f"Sample MNIST predictions: {df_mnist.head()}")
print(f"Sample SVHN predictions: {df_svhn.head()}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Generating cross-domain predictions...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 407/407 [00:10<00:00, 40.00it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 157/157 [00:04<00:00, 34.86it/s]
Predictions saved to CSV files.
MNIST predictions shape: (10000, 2)
SVHN predictions shape: (26032, 2)
Sample MNIST predictions:    mnist_targets  predictions_svhn_on_mnist
0              7                          7
1              2                          2
2              1                          7
3              0                          6
4              4                          2
Sample SVHN predictions:    svhn_targets  predictions_mnist_on_svhn
0             5                          8
1             2                          8
2             1                          8
3             0                          8
4             6                          8


In [8]:
# Construct taxonomy from cross-domain predictions
learned_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=[
        # MNIST model (domain 0) predicting on SVHN data (domain 1)
        (0, 1, np.array(df_svhn["predictions_mnist_on_svhn"], dtype=np.intp)),
        # SVHN model (domain 1) predicting on MNIST data (domain 0)
        (1, 0, np.array(df_mnist["predictions_svhn_on_mnist"], dtype=np.intp)),
    ],
    domain_targets=[
        (0, np.array(df_mnist["mnist_targets"], dtype=np.intp)),  # MNIST targets
        (1, np.array(df_svhn["svhn_targets"], dtype=np.intp)),  # SVHN targets
    ],
    domain_labels=domain_labels,
    relationship_type="mcfp",  # Most Confident Prediction
)

print("Learned taxonomy constructed from cross-domain predictions.")
print(
    f"Learned taxonomy has {learned_taxonomy.graph.number_of_nodes()} nodes and {learned_taxonomy.graph.number_of_edges()} edges"
)

Learned taxonomy constructed from cross-domain predictions.
Learned taxonomy has 20 nodes and 20 edges


In [9]:
# Generate and save taxonomy visualizations
print("Generating taxonomy visualizations...")

# Learned taxonomy visualization
learned_taxonomy.visualize_graph(
    "Learned Digit Taxonomy (MNIST ↔ SVHN)",
    height=800,
    width=1200,
).save_graph("output/digit_learned_taxonomy.html")

# Ground truth taxonomy visualization
ground_truth_taxonomy.visualize_graph(
    "Ground Truth Digit Taxonomy (Perfect 1:1 Mapping)",
    height=800,
    width=1200,
).save_graph("output/digit_ground_truth_taxonomy.html")

print("Taxonomy visualizations saved to output/ directory.")

Generating taxonomy visualizations...
Taxonomy visualizations saved to output/ directory.


In [10]:
# Evaluate learned taxonomy against ground truth
edr = learned_taxonomy.edge_difference_ratio(ground_truth_taxonomy)
precision, recall, f1 = learned_taxonomy.precision_recall_f1(ground_truth_taxonomy)

print("Digit Taxonomy Evaluation Results:")
print("=" * 40)
print(f"Edge Difference Ratio: {edr:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print()
print("Interpretation:")
print(
    f"- EDR of {edr:.4f} means {edr*100:.1f}% of edges differ between learned and ground truth"
)
print(
    f"- Precision of {precision:.4f} means {precision*100:.1f}% of learned relationships are correct"
)
print(
    f"- Recall of {recall:.4f} means {recall*100:.1f}% of true relationships were discovered"
)
print(f"- F1 score of {f1:.4f} balances precision and recall")

Digit Taxonomy Evaluation Results:
Edge Difference Ratio: 0.9022
Precision: 0.2000
Recall: 0.2000
F1 Score: 0.2000

Interpretation:
- EDR of 0.9022 means 90.2% of edges differ between learned and ground truth
- Precision of 0.2000 means 20.0% of learned relationships are correct
- Recall of 0.2000 means 20.0% of true relationships were discovered
- F1 score of 0.2000 balances precision and recall


In [11]:
# Analyze detailed relationship differences
print("Detailed Relationship Analysis:")
print("=" * 40)

# Get relationships from both taxonomies
learned_relationships = set()
for edge in learned_taxonomy.graph.edges(data=True):
    target, source, data = edge
    learned_relationships.add((target, source, round(data["weight"], 3)))

ground_truth_relationships = set()
for edge in ground_truth_taxonomy.graph.edges(data=True):
    target, source, data = edge
    ground_truth_relationships.add((target, source, round(data["weight"], 3)))

# Find correct, missed, and incorrect relationships
correct_relationships = learned_relationships & ground_truth_relationships
missed_relationships = ground_truth_relationships - learned_relationships
incorrect_relationships = learned_relationships - ground_truth_relationships

print(f"Total Ground Truth Relationships: {len(ground_truth_relationships)}")
print(f"Total Learned Relationships: {len(learned_relationships)}")
print(f"Correct Relationships: {len(correct_relationships)}")
print(f"Missed Relationships: {len(missed_relationships)}")
print(f"Incorrect Relationships: {len(incorrect_relationships)}")

# Show some examples of each type
print("\nSample Correct Relationships:")
for i, rel in enumerate(list(correct_relationships)[:5]):
    target, source, weight = rel
    target_str = f"Domain {target.domain_id}, Class {target.class_id}"
    source_str = f"Domain {source.domain_id}, Class {source.class_id}"
    print(f"  {target_str} -> {source_str} (confidence: {weight})")

if missed_relationships:
    print("\nSample Missed Relationships:")
    for i, rel in enumerate(list(missed_relationships)[:5]):
        target, source, weight = rel
        target_str = f"Domain {target.domain_id}, Class {target.class_id}"
        source_str = f"Domain {source.domain_id}, Class {source.class_id}"
        print(f"  {target_str} -> {source_str} (confidence: {weight})")

if incorrect_relationships:
    print("\nSample Incorrect Relationships:")
    for i, rel in enumerate(list(incorrect_relationships)[:5]):
        target, source, weight = rel
        target_str = f"Domain {target.domain_id}, Class {target.class_id}"
        source_str = f"Domain {source.domain_id}, Class {source.class_id}"
        print(f"  {target_str} -> {source_str} (confidence: {weight})")

Detailed Relationship Analysis:
Total Ground Truth Relationships: 20
Total Learned Relationships: 20
Correct Relationships: 0
Missed Relationships: 20
Incorrect Relationships: 20

Sample Correct Relationships:

Sample Missed Relationships:


AttributeError: 'DomainClass' object has no attribute 'domain_id'

In [None]:
# Build universal taxonomies
print("Building universal taxonomies...")

# Build universal taxonomy for learned relationships
learned_taxonomy.build_universal_taxonomy()
learned_taxonomy.visualize_graph(
    "Learned Digit Universal Taxonomy",
    height=800,
    width=1200,
).save_graph("output/digit_learned_universal_taxonomy.html")

# Build universal taxonomy for ground truth
ground_truth_taxonomy.build_universal_taxonomy()
ground_truth_taxonomy.visualize_graph(
    "Ground Truth Digit Universal Taxonomy",
    height=800,
    width=1200,
).save_graph("output/digit_ground_truth_universal_taxonomy.html")

print("Universal taxonomy visualizations saved to output/ directory.")

# Print summary of universal taxonomies
print(f"\nLearned Universal Taxonomy:")
print(f"  Nodes: {learned_taxonomy.graph.number_of_nodes()}")
print(f"  Edges: {learned_taxonomy.graph.number_of_edges()}")

print(f"\nGround Truth Universal Taxonomy:")
print(f"  Nodes: {ground_truth_taxonomy.graph.number_of_nodes()}")
print(f"  Edges: {ground_truth_taxonomy.graph.number_of_edges()}")

In [None]:
# Analyze digit confusion patterns
print("Digit Confusion Analysis:")
print("=" * 40)

# Create confusion matrices for cross-domain predictions
from collections import defaultdict

# MNIST model on SVHN data confusion
mnist_on_svhn_confusion = defaultdict(lambda: defaultdict(int))
for true_digit, pred_digit in zip(
    df_svhn["svhn_targets"], df_svhn["predictions_mnist_on_svhn"]
):
    mnist_on_svhn_confusion[true_digit][pred_digit] += 1

# SVHN model on MNIST data confusion
svhn_on_mnist_confusion = defaultdict(lambda: defaultdict(int))
for true_digit, pred_digit in zip(
    df_mnist["mnist_targets"], df_mnist["predictions_svhn_on_mnist"]
):
    svhn_on_mnist_confusion[true_digit][pred_digit] += 1

print("MNIST model predicting SVHN digits:")
print("True\\Pred", end="")
for i in range(10):
    print(f"{i:>6}", end="")
print()
for true_digit in range(10):
    print(f"{true_digit:>4}    ", end="")
    for pred_digit in range(10):
        count = mnist_on_svhn_confusion[true_digit][pred_digit]
        print(f"{count:>6}", end="")
    print()

print("\nSVHN model predicting MNIST digits:")
print("True\\Pred", end="")
for i in range(10):
    print(f"{i:>6}", end="")
print()
for true_digit in range(10):
    print(f"{true_digit:>4}    ", end="")
    for pred_digit in range(10):
        count = svhn_on_mnist_confusion[true_digit][pred_digit]
        print(f"{count:>6}", end="")
    print()

# Calculate per-digit accuracy
print("\nPer-digit cross-domain accuracy:")
print("Digit | MNIST→SVHN | SVHN→MNIST")
print("------|------------|------------")
for digit in range(10):
    mnist_correct = mnist_on_svhn_confusion[digit][digit]
    mnist_total = sum(mnist_on_svhn_confusion[digit].values())
    mnist_acc = mnist_correct / mnist_total if mnist_total > 0 else 0

    svhn_correct = svhn_on_mnist_confusion[digit][digit]
    svhn_total = sum(svhn_on_mnist_confusion[digit].values())
    svhn_acc = svhn_correct / svhn_total if svhn_total > 0 else 0

    print(f"  {digit}   |   {mnist_acc:.3f}    |   {svhn_acc:.3f}")