In [None]:
import torch
import torch.nn.functional as F
import numpy as np

# Prefer the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Current device: {device}")

# Make the selected device the default for tensors created without an explicit
# `device=` argument from here on.
torch.set_default_device(device)

import shutil
from pathlib import Path

# MIS.py and metric.py live one directory above this notebook. Copy them next
# to it so the imports below resolve. This uses shutil instead of the `!cp`
# shell magic, so the cell also runs as a plain Python script and on platforms
# without a `cp` binary.
for helper_name in ("MIS.py", "metric.py"):
    helper_source = Path("..") / helper_name
    # A missing source is skipped rather than raised, matching the old
    # behaviour where a failed `cp` did not stop the cell; the import below
    # then either picks up an existing local copy or fails with
    # ModuleNotFoundError.
    if helper_source.is_file():
        shutil.copy(helper_source, helper_name)

import MIS
from MIS import task_config, run_psychophysics

In [None]:
from torchvision.datasets import CIFAR10

# Load the CIFAR-10 training split, downloading it into the current directory
# if it is not already there.
train_CIFAR10 = CIFAR10(root="./", train=True, download=True)

# Every dataset item is an (image, label) pair; only the images are kept.
CIFAR10_data = [image for image, _label in train_CIFAR10]

# The dataset wrapper is not needed once the images have been pulled out.
del train_CIFAR10

In [None]:
# Load the stored activations and the axes matrix straight onto the active
# device. `weights_only=True` restricts unpickling to tensors and plain containers.
activations_original = torch.load(
    "example-activations.pt", map_location=device, weights_only=True
)
neuron_axes = torch.load("example-axes.pt", map_location=device, weights_only=True)

# Apply `neuron_axes` to every row of the activations:
# (neuron_axes @ X^T)^T, which is algebraically X @ neuron_axes^T.
# NOTE(review): assumes both tensors are 2-D with compatible inner dimensions — confirm.
activations_new = (neuron_axes @ activations_original.t()).t()

# Build one task per activation set over the same image list so the two can be
# compared under identical settings.
CIFAR10_task_old = task_config(device, CIFAR10_data, activations_original)
CIFAR10_task_new = task_config(device, CIFAR10_data, activations_new)

In [None]:
# Settings shared by both runs. They are forwarded unchanged to
# `run_psychophysics`; their exact meaning is defined in MIS.py (not shown here).
metric_type = "dreamsim"
K, N = 9, 100
quantile = 0.25
alpha = None

# Collect the keyword arguments once so both tasks are evaluated with exactly
# the same configuration.
psychophysics_settings = dict(
    metric_type=metric_type,
    K=K,
    N=N,
    quantile=quantile,
    alpha=alpha,
)

MIS_old = run_psychophysics(CIFAR10_task_old, **psychophysics_settings)
MIS_new = run_psychophysics(CIFAR10_task_new, **psychophysics_settings)

In [None]:
from matplotlib import pyplot as plt

# Move the scores to host memory once; matplotlib works on NumPy arrays.
mis_scores = {
    "old": MIS_old.detach().cpu().numpy(),
    "new": MIS_new.detach().cpu().numpy(),
}

# Use one common set of 40 bin edges spanning the pooled scores. Passing
# `bins=40` to each `hist` call separately would derive edges from each array's
# own range, giving the two overlaid histograms different bin widths and making
# their bar heights incomparable.
shared_bins = np.histogram_bin_edges(
    np.concatenate([values.ravel() for values in mis_scores.values()]),
    bins=40,
)

fig, ax = plt.subplots()
for series_label, series_scores in mis_scores.items():
    ax.hist(series_scores, bins=shared_bins, alpha=0.5, label=series_label)
ax.legend()
ax.set_xlabel("MIS")
ax.set_ylabel("Frequency")
ax.set_title(f"MIS Distribution for {metric_type} metric")
plt.show()
