In [None]:
import os
import torch
import numpy as np
from tqdm import tqdm
import sys
import json

from PIL import Image, ImageOps
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from huggingface_hub import hf_hub_download
from sklearn.manifold import TSNE

sys.path.append('../')
from dataset.hf_imagenet_dataset import HFImageNet

In [None]:
# Directory holding the per-epoch training checkpoints (epoch_<k>.pth.tar) to analyze.
analysis_dir: str = '../weights_CL_imagenet/with_inst_params_lr_0.8_wd_1e-8'

## Load instance-level parameters from all epochs

In [None]:
# Collect the learned instance-level parameters from every per-epoch checkpoint.
# They are stored in log space; the last line of this cell exponentiates them to temperatures.
epoch_instance_level_temp = []  # one 1-D array per loaded epoch; stacked to (nr_epochs, nr_instances) below
loaded_epochs = []  # epochs actually found on disk; rows of the stacked array follow this order
total_epochs = 120

# Loop through the checkpoint files and load each one
for epoch in tqdm(range(total_epochs), desc='Loading instance level parameters from all checkpoints'):
    checkpoint_path = f'{analysis_dir}/epoch_{epoch}.pth.tar'

    if not os.path.isfile(checkpoint_path):
        print(f"Checkpoint file {checkpoint_path} not found.")
        continue

    # map_location='cpu' lets checkpoints written during GPU training load on a CPU-only
    # machine and keeps the parameters out of GPU memory.
    # NOTE(review): torch>=2.6 defaults to weights_only=True; if these checkpoints store
    # non-tensor objects, loading may need weights_only=False (only for trusted files).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    inst_params = checkpoint['inst_parameters']
    if torch.is_tensor(inst_params):
        # numpy cannot convert a tensor that requires grad, so detach first.
        inst_params = inst_params.detach().cpu().numpy()
    # Flatten so every epoch contributes exactly one row, whatever shape was saved.
    epoch_instance_level_temp.append(np.asarray(inst_params).reshape(-1))
    loaded_epochs.append(epoch)
    del checkpoint  # release model/optimizer state before loading the next checkpoint

if not epoch_instance_level_temp:
    raise FileNotFoundError(f'No checkpoints found in {analysis_dir} for epochs 0..{total_epochs - 1}.')

# Convert the entries to temperature and vertically stack across epochs
epoch_instance_level_temp = np.exp(np.vstack(epoch_instance_level_temp))


In [None]:
# Per-instance temperature at the last loaded epoch, treated as the converged value.
temp_convergence = epoch_instance_level_temp[-1]

# Summary statistics of the converged temperatures.
print(f'Mean temperature at convergence: {np.mean(temp_convergence) :0.1f}')
print(f'Max temperature at convergence: {np.max(temp_convergence) :0.1f}')
print(f'Min temperature at convergence: {np.min(temp_convergence) :0.1f}')

## Load labels for the entire dataset

In [None]:
# Training split without transforms; only the labels are needed here (images are read later).
imagenet_dataset = HFImageNet(transform=None, split='train')
# Class id of every training instance, in dataset order.
# NOTE(review): assumed to match the order of the saved instance parameters — confirm against the training code.
train_label_column = imagenet_dataset.dataset['label']
label_instances = np.array(train_label_column)

In [None]:
# Mapping of integer ImageNet-1k label ids to human-readable class names,
# fetched from the Hugging Face "label-files" dataset repository.
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label_path = hf_hub_download(repo_id, filename, repo_type="dataset")
# The context manager closes the file handle. JSON is UTF-8 by spec, so don't rely on the locale encoding.
with open(id2label_path, "r", encoding="utf-8") as id2label_file:
    # JSON object keys are always strings; convert them back to integer class ids.
    id2label = {int(k): v for k, v in json.load(id2label_file).items()}

## Analyze samples of a particular class

In [None]:
class_idx = 145  # ImageNet-1k class id to analyze

# Instance parameters are indexed by dataset position, so both must describe the same instances.
assert label_instances.shape[0] == epoch_instance_level_temp.shape[1], (
    f'{label_instances.shape[0]} labels vs {epoch_instance_level_temp.shape[1]} instance parameters'
)

# Find instances of this class and their temperature value
instance_class_idx = np.flatnonzero(label_instances == class_idx)  # (N,) dataset indices
assert instance_class_idx.size > 0, f'No training instances with label {class_idx}'
instance_class_temp_convergence = temp_convergence[instance_class_idx]  # (N,)
instance_class_temp_all_epochs = epoch_instance_level_temp[:, instance_class_idx]  # (nr_epochs, N)


## Compute a t-SNE embedding of the data

In [None]:
# Each instance is described by its temperature trajectory over all loaded epochs.
features_instances = instance_class_temp_all_epochs.T  # (N, D=nr_epochs)
# t-SNE is stochastic; a fixed seed makes the layout (and the saved figure) reproducible.
TSNE_SEED = 0
# scikit-learn requires perplexity < n_samples; cap it for classes with few instances.
tsne_perplexity = min(50, features_instances.shape[0] - 1)
coordinates = TSNE(
    n_components=2,
    learning_rate='auto',
    init='pca',
    perplexity=tsne_perplexity,
    random_state=TSNE_SEED,
).fit_transform(features_instances)  # (N, 2)

## Plot the t-SNE result as a chart

In [None]:
# Decode the PIL image of every instance of this class, in the same order as instance_class_idx.
list_all_instances_images = []
for dataset_row in instance_class_idx:
    list_all_instances_images.append(imagenet_dataset.dataset[int(dataset_row)]['image'])

In [None]:
# Sanity check, enforced rather than just displayed: the last row of the per-epoch
# trajectories must equal the converged temperatures selected for this class.
assert np.array_equal(instance_class_temp_all_epochs[-1, :], instance_class_temp_convergence)

In [None]:
# Classify each instance from its temperature trajectory; the colour is used for its thumbnail border:
#   red    - temperature still high at convergence (hard / outlier)
#   orange - temperature peaked high but ended low (ignored at the start, learnt towards the end)
#   green  - everything else (easy)
HIGH_TEMP_THRESHOLD = 1.6  # temperatures above this count as "high"
MIN_TEMP_SWING = 0.5       # minimum (max - min) change along the trajectory for "orange"
LOW_FINAL_TEMP = 1.0       # final temperature below this counts as "learnt"

peak_temp = instance_class_temp_all_epochs.max(axis=0)               # (N,)
temp_swing = peak_temp - instance_class_temp_all_epochs.min(axis=0)  # (N,)
final_temp = instance_class_temp_all_epochs[-1, :]                   # (N,)

is_outlier = instance_class_temp_convergence > HIGH_TEMP_THRESHOLD
is_late_learnt = (peak_temp > HIGH_TEMP_THRESHOLD) & (temp_swing > MIN_TEMP_SWING) & (final_temp < LOW_FINAL_TEMP)

# "orange" overrides the red/green decision.
border_labels = np.where(is_late_learnt, 'orange', np.where(is_outlier, 'red', 'green'))
all_border_label = border_labels.tolist()

outliers = (border_labels == 'red').sum()
active_learning = (border_labels == 'orange').sum()
print(f'{outliers} elements classified as outliers')
print(f'{active_learning} elements classified as active-learning candidates')

In [None]:
# Function to plot each image
def plot_image(x, y, image, ax, border_color='white', border_size=10, zoom=0.2):
    """Draw a PIL image on `ax`, centred at data coordinates (x, y), framed by a coloured border.

    Args:
        x, y: Data coordinates of the thumbnail centre.
        image: PIL image to draw.
        ax: Matplotlib axes the thumbnail is added to.
        border_color: Border fill colour (any colour PIL understands, e.g. 'red').
        border_size: Border width in pixels, added on every side of `image`.
        zoom: Scale factor OffsetImage applies when rendering the thumbnail.
    """
    # PIL resolves colour names in the image's own mode. A grayscale ('L') image would
    # therefore get a gray border instead of e.g. red, so convert to RGB first.
    if image.mode not in ('RGB', 'RGBA'):
        image = image.convert('RGB')
    framed = ImageOps.expand(image, border=border_size, fill=border_color)

    thumbnail = OffsetImage(framed, zoom=zoom)
    ax.add_artist(AnnotationBbox(thumbnail, (x, y), frameon=False, pad=0))

In [None]:
fig, ax = plt.subplots(figsize=(25*4, 10*4))
# Number of instances to draw; None draws all of them.
# (-1 would drop the last instance, because `seq[:-1]` excludes the final element.)
subset_num = None
img_size = 512  # thumbnails are resized to (img_size, img_size) pixels before plotting
subset = slice(subset_num)
ax.scatter(coordinates[subset, 0], coordinates[subset, 1])  # Plotting the points (optional)

# Plot each image at its t-SNE coordinates, framed by its easy/hard border colour
n_plotted = len(coordinates[subset])
for coord, image, border_label in tqdm(
    zip(coordinates[subset], list_all_instances_images[subset], all_border_label[subset]),
    total=n_plotted,
):
    plot_image(coord[0], coord[1], image.resize((img_size, img_size)), ax, border_label)

fig.patch.set_facecolor('black')
ax.axis('off')
fig.suptitle(f"Analysis for images of visual category: {id2label[class_idx]}", fontsize=16, fontweight='bold', color='white')
fig.savefig(f"class_{class_idx}_{id2label[class_idx]}.pdf", bbox_inches='tight', pad_inches=0, facecolor=fig.get_facecolor())
plt.show()