In [None]:
import os
import torch
import numpy as np
from tqdm import tqdm
import sys
import json
# For generating PDFs
from PIL import Image
import matplotlib.pyplot as plt

sys.path.append('../')
from dataset.hf_imagenet_dataset import HFImageNet

from huggingface_hub import hf_hub_download

In [None]:
# Directory holding the per-epoch checkpoints (epoch_<n>.pth.tar) analysed below
analysis_dir = "../weights_CL_imagenet/with_inst_params_lr_0.8_wd_1e-8"

## Load instance-level parameters from all epochs

In [None]:
# Collect the instance-level parameters of every available checkpoint, in epoch order.
# The list is stacked into an array of shape (nr_found_epochs, nr_instances) below.
epoch_instance_level_temp = []
total_epochs = 120

# Loop through the checkpoint files and load each one
for epoch in tqdm(range(total_epochs), desc='Loading instance level parameters from all checkpoints'):
    checkpoint_path = f'{analysis_dir}/epoch_{epoch}.pth.tar'

    # Missing epochs are reported and skipped rather than aborting the whole load
    if not os.path.isfile(checkpoint_path):
        print(f"Checkpoint file {checkpoint_path} not found.")
        continue

    # map_location='cpu' lets checkpoints written from a GPU run load on a CPU-only machine
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    inst_parameters = checkpoint['inst_parameters']

    # NOTE(review): the stored type is not visible here; if it is a torch tensor,
    # detach it so NumPy can stack it even when it requires grad.
    if torch.is_tensor(inst_parameters):
        inst_parameters = inst_parameters.detach().cpu().numpy()

    epoch_instance_level_temp.append(inst_parameters)

# Fail with an actionable message instead of np.vstack's "need at least one array"
if not epoch_instance_level_temp:
    raise FileNotFoundError(
        f"No checkpoints of the form epoch_<n>.pth.tar were found in {analysis_dir!r}"
    )

# Convert the entries to temperature (exp of the stored parameter) and vertically stack across epochs
epoch_instance_level_temp = np.exp(np.vstack(epoch_instance_level_temp))


In [None]:
# The last stacked row holds the temperatures of the final loaded checkpoint
temp_convergence = epoch_instance_level_temp[-1]

# Report summary statistics of the converged temperatures, one line per statistic
for stat_name, stat_fn in (('Mean', np.mean), ('Max', np.max), ('Min', np.min)):
    print(f'{stat_name} temperature at convergence: {stat_fn(temp_convergence):0.1f}')



## Load labels for entire dataset 

In [None]:
# Training split without any transform, so the stored images are returned as-is
imagenet_dataset = HFImageNet(transform=None, split='train')

# Class id of every instance, read from the 'label' column of the wrapped dataset
label_instances = np.array(imagenet_dataset.dataset['label'])

In [None]:
# Mapping of label-ids to name of classes
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"

# Download (or reuse the cached copy of) the label file, then read it through a
# context manager so the file handle is closed deterministically.
id2label_path = hf_hub_download(repo_id, filename, repo_type="dataset")
with open(id2label_path, "r", encoding="utf-8") as id2label_file:
    id2label = json.load(id2label_file)

# JSON object keys are always strings; convert them to int so they can be
# looked up directly with integer class ids.
id2label = {int(label_id): label_name for label_id, label_name in id2label.items()}

## Analyze samples of particular class

In [None]:
class_idx = 107

# Dataset positions of every instance labelled with this class
instance_class_idx = np.nonzero(label_instances == class_idx)[0]

# Converged temperature of those instances, still in dataset order
unsorted_class_temp = temp_convergence[instance_class_idx]

# Ascending temperature order: lowest temperature first, highest last
sort_idx = np.argsort(unsorted_class_temp)
instance_class_idx, instance_class_temp_convergence = (
    instance_class_idx[sort_idx],
    unsorted_class_temp[sort_idx],
)

# Sanity check: re-indexing the global array with the sorted ids must give the sorted temperatures
temps_match = temp_convergence[instance_class_idx] == instance_class_temp_convergence
assert temps_match.all(), 'Both results should match'

## Capture the top-k easy and hard images

In [None]:
topk = 10

# (dataset index, temperature) pairs, already ordered from lowest to highest temperature
ranked_instances = list(zip(instance_class_idx, instance_class_temp_convergence))

# Start of the highest-temperature tail; clamped so a class with fewer than topk samples yields all of them
hard_start = max(len(instance_class_idx) - topk, 0)

# The topk lowest-temperature samples of the class ("easy")
easy_images = [
    (imagenet_dataset.dataset[int(instance_idx)]['image'], instance_temp)
    for instance_idx, instance_temp in ranked_instances[:topk]
]

# The topk highest-temperature samples of the class ("hard"), still in ascending order
hard_images = [
    (imagenet_dataset.dataset[int(instance_idx)]['image'], instance_temp)
    for instance_idx, instance_temp in ranked_instances[hard_start:]
]

## Plot these images in a PDF and save them

In [None]:
def plot_images(images, row_label, start_subplot, ax_line=False):
    """Draw one row of images on the current figure, with a label at the left margin.

    The figure is treated as a 2-row grid with ``len(images) + 1`` columns; the
    first column of the row is left empty so the row label has room.

    Args:
        images: Sequence of ``(image, temperature)`` pairs. Each image is passed
            to ``imshow`` and its temperature is shown as the subplot title.
        row_label: Text written at the left edge of the figure for this row.
        start_subplot: Offset added to the subplot position. ``0`` places the
            images in the top row; any other value places the label at the
            height of the bottom row.
        ax_line: Accepted for compatibility with existing callers; not used.
    """
    n_cols = len(images) + 1
    # Subplot positions are 1-based and the first slot of the row is skipped
    first_slot = start_subplot + 2

    for slot, (img, temp) in enumerate(images, start=first_slot):
        ax = plt.subplot(2, n_cols, slot)
        ax.imshow(img)
        ax.set_title(f"{temp:.2f}", fontsize=16, color='white')
        ax.axis('off')  # Hide axis

    # Vertical centre of the label in figure coordinates: top row vs. bottom row
    label_y = 0.25 if start_subplot != 0 else 0.75
    plt.figtext(0.01, label_y, row_label, va='center', ha='left', fontsize=12, weight='bold', color='white')

In [None]:
# Requires easy_images and hard_images from the top-k cell above.
# Black canvas so the white titles and labels drawn by plot_images are readable.
plt.figure(figsize=(20, 8), facecolor='black')

# Top row: easy samples; bottom row: hard samples (offset by one full row of slots)
plot_images(easy_images, "Easy Images\n(score < 1.0)", 0, ax_line=True)
plot_images(hard_images, "Hard Images\n(score > 1.0)", len(easy_images)+1)

category_name = id2label[class_idx]
plt.suptitle(
    f"Analysis for images of visual category: {category_name}",
    fontsize=16,
    fontweight='bold',
    color='white',
)
plt.tight_layout()

# The PDF is written to the current working directory
plt.savefig(f"class_{class_idx}_{category_name}.pdf")

## Automated loop over all classes

In [None]:
# Repeat the easy-vs-hard analysis for every 10th class id (100 classes in total).
topk = 5

# savefig does not create missing directories, so make sure the output folder
# exists before the first figure is written.
output_dir = "easy_vs_hard"
os.makedirs(output_dir, exist_ok=True)

for class_idx in range(0, 1000, 10):

    # Find instances of this class
    instance_class_idx = np.where(label_instances == class_idx)[0]

    # Temperature at convergence for those instances
    instance_class_temp_convergence = temp_convergence[instance_class_idx]

    # Sort the data points based on temperature (low to high)
    sort_idx = np.argsort(instance_class_temp_convergence)

    instance_class_idx = instance_class_idx[sort_idx]
    instance_class_temp_convergence = instance_class_temp_convergence[sort_idx]

    easy_images, hard_images = [], []

    for print_idx, (instance_idx, instance_temp) in enumerate(zip(instance_class_idx, instance_class_temp_convergence)):

        # The first topk entries have the lowest temperature ("easy")
        if (print_idx < topk):
            easy_images.append((imagenet_dataset.dataset[int(instance_idx)]['image'], instance_temp))

        # The last topk entries have the highest temperature ("hard")
        if (print_idx >= len(instance_class_idx) - topk):
            hard_images.append((imagenet_dataset.dataset[int(instance_idx)]['image'], instance_temp))

    # One figure per class: easy samples in the top row, hard samples in the bottom row
    fig = plt.figure(figsize=(20, 8), facecolor='black')  # Adjust the figure size as needed
    plot_images(easy_images, "Easy Images\n(score < 1.0)", 0, ax_line=True)
    plot_images(hard_images, "Hard Images\n(score > 1.0)", len(easy_images)+1)
    plt.suptitle(f"Analysis for images of visual category: {id2label[class_idx]}", fontsize=16, fontweight='bold', color='white')
    plt.tight_layout()
    plt.savefig(f"{output_dir}/class_{class_idx}_{id2label[class_idx]}.pdf")

    # Close explicitly so 100 figures do not accumulate in memory
    plt.close(fig)