In [14]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from data import get_dataloaders
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import pickle

def load_cifar10_labels(batch_file):
    """Read a pickled CIFAR-10 batch file and return its labels.

    Args:
        batch_file: Path to the pickled batch file (for example the
            ``test_batch`` file of the python-version CIFAR-10 archive).

    Returns:
        np.ndarray: Array built from the ``'labels'`` entry of the
        unpickled object.

    Raises:
        OSError: If ``batch_file`` cannot be opened.
        KeyError: If the unpickled object has no ``'labels'`` entry.
    """
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    with open(batch_file, mode='rb') as handle:
        contents = pickle.load(handle, encoding='latin1')
    labels = contents['labels']
    return np.array(labels)

def compute_metrics(y_true, y_pred, average='macro'):
    """Score predictions against ground-truth labels.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels, aligned element-wise with ``y_true``.
        average: Averaging mode forwarded to the scikit-learn precision,
            recall and F1 scorers (``'macro'`` by default).

    Returns:
        tuple: ``(accuracy, precision, recall, f1)``.
    """
    accuracy = accuracy_score(y_true, y_pred)
    # The three remaining scorers share a signature, so evaluate them in a
    # single pass, in the order they are returned.
    precision, recall, f1 = (
        scorer(y_true, y_pred, average=average)
        for scorer in (precision_score, recall_score, f1_score)
    )
    return accuracy, precision, recall, f1

# Load CIFAR dataset
def load_cifar_labels():
    """Return the ground-truth labels of the CIFAR-10 test split.

    The dataset is opened through ``torchvision.datasets.CIFAR10`` rooted at
    ``./data`` and is downloaded there first if it is not already present.

    Returns:
        np.ndarray: One integer class label per test image, in dataset order.
    """
    # No image transform is needed: only the labels are read.
    cifar_test = datasets.CIFAR10(root='./data', train=False, download=True)
    # Read the labels straight from the dataset's ``targets`` list. Iterating
    # the dataset instead would decode and transform every image only to
    # throw it away.
    return np.array(cifar_test.targets)

# Load the HMC output table from disk. No delimiter is passed, so
# np.genfromtxt splits each line on whitespace (its default).
cifar_hmc = np.genfromtxt("HMC_cifar_probs.csv")
# Sanity check: show the row at index 1 (the recorded output has 10 values).
print(cifar_hmc[1])
# Predicted class = column index of the largest value in each of the first
# 10000 rows.
# NOTE(review): 10000 presumably matches the CIFAR-10 test-set size, and the
# CSV rows are assumed to be in the same order as y_true — confirm.
y_pred = np.argmax(cifar_hmc[:10000], axis=1)
print(y_pred)

# Path to the CIFAR-10 dataset
cifar10_path = "./data/cifar-10-batches-py/"

# Load test batch labels directly from the raw pickled batch file.
# NOTE(review): test_labels is only printed here; the metrics cell below
# uses y_true instead.
test_labels = load_cifar10_labels(cifar10_path + "test_batch")
print(test_labels)

# Alternative label source, currently disabled:
#data = np.load("evaluation_phase.npz")
#y_true = data["y_test"]

# Ground-truth labels used for scoring, loaded through torchvision.
y_true = load_cifar_labels()
print(y_true)

[6.03314996e-01 1.45048290e-01 6.58268854e-03 5.31147420e-03
 2.76604253e-20 2.76627577e-29 0.00000000e+00 6.75450265e-02
 9.72232446e-02 7.49743432e-02]
[7 0 6 ... 8 6 0]
[3 8 8 ... 5 1 7]
Files already downloaded and verified
[3 8 8 ... 5 1 7]


In [9]:
# Compute metrics
# Scores y_pred against y_true with macro averaging; both arrays are
# notebook globals created by the loading cell.
# NOTE(review): the recorded execution count of this cell (9) is lower than
# that of the loading cell above it (14), so the printed numbers may come
# from an earlier kernel state — re-run top to bottom to confirm.
acc, precision, recall, f1 = compute_metrics(y_true, y_pred, average='macro')

# Print metrics, each rounded to four decimal places
print(f"Accuracy: {acc:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

Accuracy: 0.8797
Precision: 0.8793
Recall: 0.8797
F1 Score: 0.8794
