In [1]:
import torch
import numpy as np
import torch.nn.functional as F
import warnings
import torchvision.transforms as transforms
import torch.nn as nn
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor
from skimage.metrics import structural_similarity as ssim
from skimage.color import rgb2gray
from tqdm import tqdm
import pandas as pd
import scipy.spatial as sp

import clip
from transformers import CLIPProcessor, CLIPModel

from PIL import Image
from scipy.spatial.distance import correlation


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Silence all Python warnings so they do not clutter the cell outputs.
warnings.filterwarnings('ignore')

# Use the GPU with index 2 when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:2")
else:
    device = torch.device("cpu")

In [2]:
import os
import torch
from PIL import Image
import torchvision.transforms as T

# Directories holding the ground-truth images and the generated reconstructions.
save_true_dir = "/srv/nfs-data/sisko/matteoc/monkeys/generative/img_real"
save_recon_dir = "/srv/nfs-data/sisko/matteoc/monkeys/generative/img_gen_linavg"

# Both listings are sorted so that the i-th ground-truth file is paired with the
# i-th reconstruction. NOTE(review): this assumes the two directories use a
# naming scheme that sorts identically -- confirm against the export script.
true_image_files = sorted(os.listdir(save_true_dir))
recon_image_files = sorted(os.listdir(save_recon_dir))

# Every downstream metric compares images pairwise by position, so a count
# mismatch would silently misalign (or later crash) the evaluation.
if len(true_image_files) != len(recon_image_files):
    raise ValueError(
        f"Found {len(true_image_files)} ground-truth images but "
        f"{len(recon_image_files)} reconstructions; the sets must be paired 1:1."
    )

# Reconstructions are only converted to tensors; ground-truth images are also
# resized to 480x480 first.
easy_transform = transforms.ToTensor()
transform_real = transforms.Compose([
    transforms.Resize((480, 480)),
    transforms.ToTensor()
])


def _load_image_tensors(directory, filenames, transform):
    """Load ``filenames`` from ``directory`` as RGB and apply ``transform`` to each.

    Returns the transformed images as a list, in the order of ``filenames``.
    """
    tensors = []
    for filename in tqdm(filenames):
        # The context manager closes the underlying file once it is converted.
        with Image.open(os.path.join(directory, filename)) as img:
            tensors.append(transform(img.convert("RGB")))
    return tensors


true_images = _load_image_tensors(save_true_dir, true_image_files, transform_real)
recon_images = _load_image_tensors(save_recon_dir, recon_image_files, easy_transform)

# torch.stack requires all images within a set to share one shape.
true_images_tensor = torch.stack(true_images)
recon_images_tensor = torch.stack(recon_images)

print("Shape true images tensor:", true_images_tensor.shape)
print("Shape reconstructed images tensor:", recon_images_tensor.shape)

100%|██████████| 100/100 [00:02<00:00, 37.19it/s]
100%|██████████| 100/100 [00:04<00:00, 21.47it/s]


Shape true images tensor: torch.Size([100, 3, 480, 480])
Shape reconstructed images tensor: torch.Size([100, 3, 1024, 1024])


In [8]:
@torch.no_grad()
def two_way_identification(all_brain_recons, all_images, model, preprocess, feature_layer=None, return_avg=True, device='cuda:3'):
    """Score reconstructions with a two-way identification test in feature space.

    Every reconstruction and every ground-truth image is passed through
    ``preprocess`` and ``model``; the resulting features are flattened per
    sample and compared with Pearson correlation (``np.corrcoef``). For each
    reconstruction, the function counts how many ground-truth images correlate
    with it strictly less than its own (same-index) ground-truth image does.

    Args:
        all_brain_recons: Iterable of reconstructed images.
        all_images: Iterable of ground-truth images, in the same order.
        model: Callable mapping a stacked batch to features, or to a mapping of
            features when ``feature_layer`` is given.
        preprocess: Callable applied to each individual image before stacking.
        feature_layer: Key selecting one entry of the model output; ``None``
            means the model output is used directly.
        return_avg: If true, return the mean count divided by ``N - 1``.
        device: Device the stacked batches are moved to before the forward pass.

    Returns:
        A scalar when ``return_avg`` is true; otherwise the tuple
        ``(per_reconstruction_counts, N - 1)``.
    """
    def _embed(batch):
        # Preprocess item by item, run one forward pass, flatten to (N, features).
        stacked = torch.stack([preprocess(item) for item in batch], dim=0).to(device)
        out = model(stacked)
        if feature_layer is not None:
            out = out[feature_layer]
        return out.float().flatten(1).cpu().numpy()

    recon_feats = _embed(all_brain_recons)
    real_feats = _embed(all_images)

    n_images = len(all_images)
    # np.corrcoef stacks both inputs; the top-right quadrant holds the
    # ground-truth (rows) versus reconstruction (columns) correlations.
    cross_corr = np.corrcoef(real_feats, recon_feats)[:n_images, n_images:]
    matched_corr = np.diag(cross_corr)

    # Column j counts the ground-truth images beaten by the true match of reconstruction j.
    wins = np.sum(cross_corr < matched_corr, 0)
    n_distractors = n_images - 1

    if not return_avg:
        return wins, n_distractors
    return np.mean(wins) / n_distractors
    

def encode_image_hf(clip_model, img_batch):
    """Return ``clip_model.get_image_features`` evaluated on ``img_batch``.

    ``img_batch`` is forwarded as the ``pixel_values`` keyword argument.
    """
    image_features = clip_model.get_image_features(pixel_values=img_batch)
    return image_features


def cal_metrics(all_images, all_brain_recons, device):
    """Compute reconstruction-quality metrics averaged over the whole image set.

    Each metric is printed as soon as it is available. The metrics are pixel
    correlation, SSIM, MSE, cosine similarity, two-way identification with
    AlexNet (``features.4`` and ``features.11``), InceptionV3 (``avgpool``) and
    CLIP ViT-L/14 (image embedding), and the mean correlation distance between
    EfficientNet-B1 and SwAV ResNet-50 ``avgpool`` features.

    Args:
        all_images: Batch of ground-truth images, indexed as (N, C, H, W).
        all_brain_recons: Batch of reconstructions in the same order as
            ``all_images``. Values are clamped to [0, 1] and the batch is
            resized to the spatial size of ``all_images``.
        device: Device on which the tensors and all models are placed.

    Returns:
        dict mapping each metric name to a one-element list with its value.
    """
    all_images = all_images[:].to(device)
    all_brain_recons = torch.stack([img for img in all_brain_recons[:]]).to(device).to(all_images.dtype).clamp(0,1).squeeze()

    print("Images shape:", all_images.shape)
    print("Recons shape:", all_brain_recons.shape)

    # Ensure both tensors are the same size for MSE
    resize = transforms.Resize((all_images.size(2), all_images.size(3)), interpolation=transforms.InterpolationMode.BILINEAR)
    all_brain_recons = resize(all_brain_recons)

    print("Images shape after resize:", all_images.shape)
    print("Recons shape after resize:", all_brain_recons.shape)

    # Preprocess
    preprocess = transforms.Compose([
        transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    # Flatten images while keeping the batch dimension (moved to the CPU for NumPy).
    all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu()
    all_brain_recons_flattened = preprocess(all_brain_recons).reshape(len(all_brain_recons), -1).cpu()

    print(all_images_flattened.shape)
    print(all_brain_recons_flattened.shape)

    # PixCorr
    print("\n------calculating pixcorr------")
    corrsum = 0
    for i in tqdm(range(len(all_images))):
        corrsum += np.corrcoef(all_images_flattened[i], all_brain_recons_flattened[i])[0][1]
    pixcorr = corrsum / len(all_images)
    print("PixCorr:", pixcorr)

    # SSIM
    preprocess = transforms.Compose([
        transforms.Resize(625, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu().numpy())
    recon_gray = rgb2gray(preprocess(all_brain_recons).permute((0,2,3,1)).cpu().numpy())
    print("converted, now calculating ssim...")

    # NOTE(review): the inputs are grayscale here, yet `multichannel=True` is
    # passed. The flag's handling appears to depend on the installed
    # scikit-image version -- confirm before comparing SSIM across environments.
    ssim_score=[]
    for im, rec in tqdm(zip(img_gray, recon_gray), total=len(all_images)):
        ssim_score.append(ssim(rec, im, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0))

    ssim_mean = np.mean(ssim_score)
    print("SSIM:", ssim_mean)

    # MSE
    mse = torch.nn.functional.mse_loss(all_brain_recons, all_images).item()
    print("MSE:", mse)

    # Cosine Similarity
    cosine_sim = torch.nn.functional.cosine_similarity(all_brain_recons_flattened, all_images_flattened).mean().item()
    print("Cosine Similarity:", cosine_sim)

    # Feature-based evaluations using different models
    def evaluate_model(model, preprocess, feature_layers, layer_names):
        """Run two-way identification once per requested feature layer."""
        results = {}
        for feature_layer, layer_name in zip(feature_layers, layer_names):
            print(f"\n---{layer_name}---")
            all_per_correct = two_way_identification(all_brain_recons.to(device).float(), all_images,
                                                     model, preprocess, feature_layer, device=device)
            results[layer_name] = np.mean(all_per_correct)
            print(f"2-way Percent Correct: {results[layer_name]:.4f}")
        return results

    # AlexNet
    from torchvision.models import alexnet, AlexNet_Weights
    alex_weights = AlexNet_Weights.IMAGENET1K_V1
    alex_model = create_feature_extractor(alexnet(weights=alex_weights), return_nodes=['features.4', 'features.11']).to(device)
    alex_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    alexnet_results = evaluate_model(alex_model, preprocess, ['features.4', 'features.11'], ['AlexNet(2)', 'AlexNet(5)'])
    del alex_model
    torch.cuda.empty_cache()

    # InceptionV3
    from torchvision.models import inception_v3, Inception_V3_Weights
    inception_weights = Inception_V3_Weights.DEFAULT
    inception_model = create_feature_extractor(inception_v3(weights=inception_weights), return_nodes=['avgpool']).to(device)
    inception_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    inception_results = evaluate_model(inception_model, preprocess, ['avgpool'], ['InceptionV3'])
    del inception_model
    torch.cuda.empty_cache()

    # CLIP (the transform returned by clip.load is not used; tensors are
    # resized and normalized with the pipeline below instead).
    clip_model, _ = clip.load("ViT-L/14", device=device)
    # clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

    preprocess = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                            std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    # `device` must be passed explicitly: without it the helper moves the
    # inputs to its own default device, which need not be where CLIP was loaded.
    all_per_correct = two_way_identification(all_brain_recons, all_images,
                                            clip_model.encode_image, preprocess, None,
                                            device=device) # final layer
    clip_results = np.mean(all_per_correct)
    print("CLIP:", clip_results)
    # Release CLIP like the other backbones before loading the next model.
    del clip_model
    torch.cuda.empty_cache()

    # EfficientNet
    from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
    eff_weights = EfficientNet_B1_Weights.DEFAULT
    eff_model = create_feature_extractor(efficientnet_b1(weights=eff_weights), return_nodes=['avgpool']).to(device)
    eff_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    gt = eff_model(preprocess(all_images))['avgpool']
    gt = gt.reshape(len(gt), -1).cpu().numpy()
    fake = eff_model(preprocess(all_brain_recons))['avgpool']
    fake = fake.reshape(len(fake), -1).cpu().numpy()
    # Correlation distance between matching ground-truth / reconstruction features.
    effnet_distance = np.array([correlation(gt[i], fake[i]) for i in range(len(gt))]).mean()
    print("EffNet Distance:", effnet_distance)
    del eff_model
    torch.cuda.empty_cache()

    # SwAV
    swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50')
    swav_model = create_feature_extractor(swav_model, return_nodes=['avgpool']).to(device)
    swav_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    gt = swav_model(preprocess(all_images))['avgpool']
    gt = gt.reshape(len(gt), -1).cpu().numpy()
    fake = swav_model(preprocess(all_brain_recons))['avgpool']
    fake = fake.reshape(len(fake), -1).cpu().numpy()
    swav_distance = np.array([correlation(gt[i], fake[i]) for i in range(len(gt))]).mean()
    print("SwAV Distance:", swav_distance)
    del swav_model
    torch.cuda.empty_cache()

    # Collect the results
    metrics = {
        'PixCorr': [pixcorr],
        'SSIM': [ssim_mean],
        'MSE': [mse],
        'Cosine Similarity': [cosine_sim],
        'AlexNet(2)': [alexnet_results["AlexNet(2)"]],
        'AlexNet(5)': [alexnet_results["AlexNet(5)"]],
        'InceptionV3': [inception_results["InceptionV3"]],
        'CLIP': [clip_results],
        'EffNet Distance': [effnet_distance],
        'SwAV Distance': [swav_distance]
    }
    return metrics




In [3]:
@torch.no_grad()
def two_way_identification(all_brain_recons, all_images, model, preprocess, feature_layer=None, return_avg=True, device='cuda:3'):
    """Two-way identification in feature space, reported per reconstruction.

    Reconstructions and ground-truth images are each run through ``preprocess``
    and ``model``, flattened per sample, and compared with Pearson correlation
    (``np.corrcoef``). For every reconstruction, the function counts the
    ground-truth images whose correlation with it is strictly lower than that
    of its own (same-index) ground-truth image.

    Args:
        all_brain_recons: Iterable of reconstructed images.
        all_images: Iterable of ground-truth images, in the same order.
        model: Callable mapping a stacked batch to features, or to a mapping of
            features when ``feature_layer`` is given.
        preprocess: Callable applied to each individual image before stacking.
        feature_layer: Key selecting one entry of the model output; ``None``
            means the model output is used directly.
        return_avg: If true, return the per-reconstruction counts divided by
            ``N - 1``. Despite the name, the values are not averaged here, so
            the caller can derive both a mean and a standard deviation.
        device: Device the stacked batches are moved to before the forward pass.

    Returns:
        An array with one fraction per reconstruction when ``return_avg`` is
        true; otherwise the tuple ``(per_reconstruction_counts, N - 1)``.
    """
    recon_batch = torch.stack([preprocess(recon) for recon in all_brain_recons], dim=0)
    recon_out = model(recon_batch.to(device))
    real_batch = torch.stack([preprocess(indiv) for indiv in all_images], dim=0)
    real_out = model(real_batch.to(device))

    # Pick the requested layer when the model returns several feature maps.
    if feature_layer is not None:
        recon_out = recon_out[feature_layer]
        real_out = real_out[feature_layer]
    recon_feats = recon_out.float().flatten(1).cpu().numpy()
    real_feats = real_out.float().flatten(1).cpu().numpy()

    total = len(all_images)
    # Keep only the ground-truth (rows) versus reconstruction (columns) quadrant.
    full_corr = np.corrcoef(real_feats, recon_feats)
    cross_corr = full_corr[:total, total:]
    own_match = np.diag(cross_corr)

    beaten = cross_corr < own_match
    beaten_per_recon = np.sum(beaten, 0)
    n_others = total - 1

    if return_avg:
        # One fraction per reconstruction (all values are returned).
        return beaten_per_recon / n_others
    return beaten_per_recon, n_others

    

def encode_image_hf(clip_model, img_batch):
    """Embed ``img_batch`` via ``clip_model.get_image_features``.

    The batch is handed over as the ``pixel_values`` keyword argument and the
    method's result is returned unchanged.
    """
    embeddings = clip_model.get_image_features(pixel_values=img_batch)
    return embeddings


def cal_metrics_std(all_images, all_brain_recons, device):
    """Compute reconstruction-quality metrics as (mean, std) over the image set.

    Each metric is printed as soon as it is available. The metrics are pixel
    correlation, SSIM, MSE, cosine similarity, two-way identification with
    AlexNet (``features.4`` and ``features.11``), InceptionV3 (``avgpool``) and
    CLIP ViT-L/14 (image embedding), and the correlation distance between
    EfficientNet-B1 and SwAV ResNet-50 ``avgpool`` features.

    All standard deviations are population standard deviations (``np.std``
    with its default ``ddof=0``) taken over the per-image values.

    Args:
        all_images: Batch of ground-truth images, indexed as (N, C, H, W).
        all_brain_recons: Batch of reconstructions in the same order as
            ``all_images``. Values are clamped to [0, 1] and the batch is
            resized to the spatial size of ``all_images``.
        device: Device on which the tensors and all models are placed.

    Returns:
        dict mapping each metric name to a ``(mean, std)`` tuple.
    """
    all_images = all_images[:].to(device)
    all_brain_recons = torch.stack([img for img in all_brain_recons[:]]).to(device).to(all_images.dtype).clamp(0,1).squeeze()

    print("Images shape:", all_images.shape)
    print("Recons shape:", all_brain_recons.shape)

    # Ensure both tensors are the same size for MSE
    resize = transforms.Resize((all_images.size(2), all_images.size(3)), interpolation=transforms.InterpolationMode.BILINEAR)
    all_brain_recons = resize(all_brain_recons)

    print("Images shape after resize:", all_images.shape)
    print("Recons shape after resize:", all_brain_recons.shape)

    # Preprocess
    preprocess = transforms.Compose([
        transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    # Flatten images while keeping the batch dimension (moved to the CPU for NumPy).
    all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu()
    all_brain_recons_flattened = preprocess(all_brain_recons).reshape(len(all_brain_recons), -1).cpu()

    print(all_images_flattened.shape)
    print(all_brain_recons_flattened.shape)

    # PixCorr: one Pearson correlation per image pair, computed in a single pass.
    print("\n------calculating pixcorr------")
    pixcorr_all = [
        np.corrcoef(all_images_flattened[i], all_brain_recons_flattened[i])[0][1]
        for i in tqdm(range(len(all_images)))
    ]
    pixcorr = np.mean(pixcorr_all)
    pixcorr_std = np.std(pixcorr_all)
    print(f"PixCorr: {pixcorr:.4f} ± {pixcorr_std:.4f}")

    # SSIM
    preprocess = transforms.Compose([
        transforms.Resize(625, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu().numpy())
    recon_gray = rgb2gray(preprocess(all_brain_recons).permute((0,2,3,1)).cpu().numpy())
    print("converted, now calculating ssim...")

    # NOTE(review): the inputs are grayscale here, yet `multichannel=True` is
    # passed. The flag's handling appears to depend on the installed
    # scikit-image version -- confirm before comparing SSIM across environments.
    ssim_score=[]
    for im, rec in tqdm(zip(img_gray, recon_gray), total=len(all_images)):
        ssim_score.append(ssim(rec, im, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0))

    ssim_mean = np.mean(ssim_score)
    ssim_std = np.std(ssim_score)
    print(f"SSIM: {ssim_mean:.4f} ± {ssim_std:.4f}")


    # MSE: the mean is taken over every pixel; the spread is taken over the
    # per-image MSE values instead of being reported as a placeholder 0.0.
    mse = torch.nn.functional.mse_loss(all_brain_recons, all_images).item()
    per_image_mse = ((all_brain_recons - all_images) ** 2).flatten(1).mean(1)
    mse_std = float(np.std(per_image_mse.cpu().numpy()))
    print("MSE:", mse)

    # Cosine Similarity: computed once; np.std keeps the spread consistent with
    # the other metrics (torch's default .std() applies Bessel's correction).
    cosine_all = torch.nn.functional.cosine_similarity(all_brain_recons_flattened, all_images_flattened).numpy()
    cosine_sim = float(np.mean(cosine_all))
    cosine_sim_std = float(np.std(cosine_all))
    print(f"Cosine Sim: {cosine_sim:.4f} ± {cosine_sim_std:.4f}")


    # Feature-based evaluations using different models
    def evaluate_model(model, preprocess, feature_layers, layer_names):
        """Run two-way identification once per requested feature layer."""
        results = {}
        for feature_layer, layer_name in zip(feature_layers, layer_names):
            print(f"\n---{layer_name}---")
            all_per_correct = two_way_identification(all_brain_recons.to(device).float(), all_images,
                                                     model, preprocess, feature_layer, device=device)
            results[layer_name] = {
                'mean': np.mean(all_per_correct),
                'std': np.std(all_per_correct)
            }
            # Printed inside the loop so that every layer is reported, not just the last.
            print(f"2-way Percent Correct: {results[layer_name]['mean']:.4f} ± {results[layer_name]['std']:.4f}")
        return results

    # AlexNet
    from torchvision.models import alexnet, AlexNet_Weights
    alex_weights = AlexNet_Weights.IMAGENET1K_V1
    alex_model = create_feature_extractor(alexnet(weights=alex_weights), return_nodes=['features.4', 'features.11']).to(device)
    alex_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    alexnet_results = evaluate_model(alex_model, preprocess, ['features.4', 'features.11'], ['AlexNet(2)', 'AlexNet(5)'])
    del alex_model
    torch.cuda.empty_cache()

    # InceptionV3
    from torchvision.models import inception_v3, Inception_V3_Weights
    inception_weights = Inception_V3_Weights.DEFAULT
    inception_model = create_feature_extractor(inception_v3(weights=inception_weights), return_nodes=['avgpool']).to(device)
    inception_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    inception_results = evaluate_model(inception_model, preprocess, ['avgpool'], ['InceptionV3'])
    del inception_model
    torch.cuda.empty_cache()

    # CLIP (the transform returned by clip.load is not used; tensors are
    # resized and normalized with the pipeline below instead).
    clip_model, _ = clip.load("ViT-L/14", device=device)
    # clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

    preprocess = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                            std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    # `device` must be passed explicitly: without it the helper moves the
    # inputs to its own default device, which need not be where CLIP was loaded.
    all_per_correct = two_way_identification(all_brain_recons, all_images,
                                            clip_model.encode_image, preprocess, None,
                                            device=device) # final layer
    clip_results = np.mean(all_per_correct)
    clip_results_std = np.std(all_per_correct)
    print(f"CLIP: {clip_results:.4f} ± {clip_results_std:.4f}")
    # Release CLIP like the other backbones before loading the next model.
    del clip_model
    torch.cuda.empty_cache()

    # EfficientNet
    from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
    eff_weights = EfficientNet_B1_Weights.DEFAULT
    eff_model = create_feature_extractor(efficientnet_b1(weights=eff_weights), return_nodes=['avgpool']).to(device)
    eff_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    gt = eff_model(preprocess(all_images))['avgpool']
    gt = gt.reshape(len(gt), -1).cpu().numpy()
    fake = eff_model(preprocess(all_brain_recons))['avgpool']
    fake = fake.reshape(len(fake), -1).cpu().numpy()
    # Correlation distance between matching ground-truth / reconstruction features.
    effnet_all = [correlation(gt[i], fake[i]) for i in range(len(gt))]
    effnet_distance = np.mean(effnet_all)
    effnet_std = np.std(effnet_all)
    print(f"EffNet Distance: {effnet_distance:.4f} ± {effnet_std:.4f}")

    del eff_model
    torch.cuda.empty_cache()

    # SwAV
    swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50')
    swav_model = create_feature_extractor(swav_model, return_nodes=['avgpool']).to(device)
    swav_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    gt = swav_model(preprocess(all_images))['avgpool']
    gt = gt.reshape(len(gt), -1).cpu().numpy()
    fake = swav_model(preprocess(all_brain_recons))['avgpool']
    fake = fake.reshape(len(fake), -1).cpu().numpy()
    swav_all = [correlation(gt[i], fake[i]) for i in range(len(gt))]
    swav_distance = np.mean(swav_all)
    swav_std = np.std(swav_all)
    print(f"SwAV Distance: {swav_distance:.4f} ± {swav_std:.4f}")

    del swav_model
    torch.cuda.empty_cache()

    # Collect the results as (mean, std) pairs
    metrics = {
        'PixCorr': (pixcorr, pixcorr_std),
        'SSIM': (ssim_mean, ssim_std),
        'MSE': (mse, mse_std),
        'Cosine Similarity': (cosine_sim, cosine_sim_std),
        'AlexNet(2)': (alexnet_results["AlexNet(2)"]['mean'], alexnet_results["AlexNet(2)"]['std']),
        'AlexNet(5)': (alexnet_results["AlexNet(5)"]['mean'], alexnet_results["AlexNet(5)"]['std']),
        'InceptionV3': (inception_results["InceptionV3"]['mean'], inception_results["InceptionV3"]['std']),
        'CLIP': (clip_results, clip_results_std),
        'EffNet Distance': (effnet_distance, effnet_std),
        'SwAV Distance': (swav_distance, swav_std)
    }

    return metrics




In [None]:
def calculate_subject_wise_metrics(all_images, all_brain_recons, device):
    """Compute all metrics for one set of reconstructions and print a summary.

    Args:
        all_images: Ground-truth images, forwarded to ``cal_metrics_std``.
        all_brain_recons: Reconstructions, forwarded to ``cal_metrics_std``.
        device: Device forwarded to ``cal_metrics_std``.

    Returns:
        dict with the single key ``"subject"`` mapping to the metrics dict
        (metric name -> ``(mean, std)``) returned by ``cal_metrics_std``.
    """
    print("\nProcessing subject...")
    metrics = cal_metrics_std(all_images, all_brain_recons, device)
    subject_results = {"subject": metrics}

    print("\nMetrics:")
    for k, (mean, std) in metrics.items():
        print(f"{k}: {mean:.4f} ± {std:.4f}")

    # Hand the results back so they can be reused instead of only being printed.
    return subject_results


# Assigning the result keeps the dict out of the cell's displayed output.
subject_results = calculate_subject_wise_metrics(true_images_tensor, recon_images_tensor, 'cuda:3')

In [None]:
# METRICS for Linear/TimeAtt:
# PixCorr: 0.1402 ± 0.1626
# SSIM: 0.3638 ± 0.1994
# MSE: 0.1095 ± 0.0000
# Cosine Similarity: 0.8107 ± 0.1075
# AlexNet(2): 0.8876 ± 0.1614
# AlexNet(5): 0.9542 ± 0.0791
# InceptionV3: 0.8679 ± 0.2254
# CLIP: 0.8785 ± 0.2012
# EffNet Distance: 0.7916 ± 0.1441
# SwAV Distance: 0.4915 ± 0.1105

# METRICS for MLP/TimeAtt:
# PixCorr: 0.1514 ± 0.1671
# SSIM: 0.3563 ± 0.2017
# MSE: 0.1106 ± 0.0000
# Cosine Similarity: 0.8146 ± 0.1229
# AlexNet(2): 0.8812 ± 0.1770
# AlexNet(5): 0.9432 ± 0.0954
# InceptionV3: 0.8361 ± 0.2363
# CLIP: 0.8416 ± 0.2289
# EffNet Distance: 0.8223 ± 0.1391
# SwAV Distance: 0.5155 ± 0.1138

# METRICS for Linear/AvgTime:
# PixCorr: 0.1563 ± 0.1630
# SSIM: 0.3416 ± 0.2062
# MSE: 0.1072 ± 0.0000
# Cosine Similarity: 0.8007 ± 0.1003
# AlexNet(2): 0.8762 ± 0.1678
# AlexNet(5): 0.9502 ± 0.0752
# InceptionV3: 0.8079 ± 0.2303
# CLIP: 0.8265 ± 0.2454
# EffNet Distance: 0.8270 ± 0.1367
# SwAV Distance: 0.5281 ± 0.1165


## TOP1-Accuracy

In [None]:
# Bilinear resize with size 425, applied to both image sets before they are
# handed to the ViT evaluation.
preprocess = transforms.Compose(
    [transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR)]
)

true_images_tensor_VIT, recon_images_tensor_VIT = (
    preprocess(true_images_tensor),
    preprocess(recon_images_tensor),
)

In [None]:
import torch
import numpy as np
from torchmetrics.functional import accuracy
from torchvision.models import ViT_H_14_Weights, vit_h_14
from tqdm import tqdm

# Use the GPU with index 3 when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")

# Load ViT-H/14 with its default pretrained weights, move it to the device and
# switch it to evaluation mode.
weights = ViT_H_14_Weights.DEFAULT
vit_model = vit_h_14(weights=weights)
vit_model = vit_model.to(device)
vit_model.eval()

# Input transform that ships with the selected weights.
preprocess = weights.transforms()

def n_way_top_k_acc(pred, class_id, n_way, num_trials=40, top_k=1):
    """Estimate n-way top-k accuracy of ``pred`` for ``class_id`` by random trials.

    In each trial, ``n_way - 1`` other indices of ``pred`` are drawn without
    replacement via ``np.random.choice``. The score at ``class_id`` is placed
    first, followed by the drawn scores, and the multiclass top-k accuracy
    against target index 0 is recorded.

    Args:
        pred: 1-D tensor of per-class scores.
        class_id: Index in ``pred`` treated as the correct class.
        n_way: Number of candidates per trial (the correct one plus distractors).
        num_trials: Number of random trials.
        top_k: The ``top_k`` value forwarded to ``accuracy``.

    Returns:
        Tuple ``(mean, std)`` of the per-trial accuracies.
    """
    distractor_pool = [idx for idx in range(len(pred)) if idx != class_id]

    def _run_trial():
        # The correct class always sits at position 0 of the candidate vector.
        drawn = np.random.choice(distractor_pool, n_way - 1, replace=False)
        candidates = torch.cat([pred[class_id].unsqueeze(0), pred[drawn]])
        target = torch.tensor([0], device=pred.device)
        trial_acc = accuracy(candidates.unsqueeze(0), target,
                             task='multiclass', num_classes=n_way, top_k=top_k)
        return trial_acc.item()

    trial_accs = [_run_trial() for _ in range(num_trials)]
    return np.mean(trial_accs), np.std(trial_accs)

# Seed NumPy's global RNG: n_way_top_k_acc draws its distractor classes with
# np.random.choice, so without a seed the reported accuracy changes on every run.
SEED = 0
np.random.seed(SEED)

# Per-image mean accuracy and per-image standard deviation over the trials.
all_acc = []
all_std = []

# Inference only: no_grad avoids building an autograd graph for every ViT-H
# forward pass.
with torch.no_grad():
    for i in tqdm(range(len(recon_images_tensor_VIT)), desc="Processing images"):
        # Apply the transform bundled with the weights and move to the device.
        image = preprocess(true_images_tensor_VIT[i].unsqueeze(0)).to(device)
        recon_image = preprocess(recon_images_tensor_VIT[i].unsqueeze(0)).to(device)

        # Class probabilities for the reconstruction; the ground-truth class is
        # the model's top prediction on the real image.
        recon_image_out = vit_model(recon_image).squeeze(0).softmax(0)
        gt_class_id = vit_model(image).squeeze(0).softmax(0).argmax().item()

        # 50-way top-1 accuracy estimated over 1000 random trials.
        acc, std = n_way_top_k_acc(recon_image_out, gt_class_id, 50, 1000, 1)
        all_acc.append(acc)
        all_std.append(std)

# Compute and print final results (both are means over the images).
mean_acc = np.mean(all_acc)
mean_std = np.mean(all_std)

print(f"Overall mean acc: {mean_acc:.4f}, Overall mean std: {mean_std:.4f}")


Downloading: "https://download.pytorch.org/models/vit_h_14_swag-80465313.pth" to /home/matteoc/.cache/torch/hub/checkpoints/vit_h_14_swag-80465313.pth
100%|██████████| 2.36G/2.36G [00:06<00:00, 381MB/s]
Processing images: 100%|██████████| 100/100 [00:51<00:00,  1.92it/s]

Overall mean acc: 0.3239, Overall mean std: 0.2301



