In [None]:
import pickle 
import torch 
import os 
import numpy as np 
import seaborn as sns 
import matplotlib.pyplot as plt 
from tqdm import tqdm 
import json 
from datasets import get_datasets
import torchvision.transforms as T 
# Root of the ILSVRC2012 validation images. Set the IMAGENET_VAL_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
DATA_PATH = os.environ.get('IMAGENET_VAL_DIR', '/data3/bumjin/bumjin_data/ILSVRC2012_val/')
# Resize + center-crop only (no ToTensor/Normalize): the plotting cells pass
# dataset items straight to `imshow`.
transform = T.Compose([
                T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
                T.CenterCrop(224),
                ])

In [None]:
encoder = "resnet34"
path = f"results/{encoder}"


def _load_pickle(file_name):
    """Unpickle `file_name` from the results directory `path`, closing the file afterwards.

    NOTE: pickle can execute arbitrary code on load; only use it on artefacts
    produced locally by this project.
    """
    with open(os.path.join(path, file_name), 'rb') as f:
        return pickle.load(f)


with open('labels.json', 'rb') as f:
    labels = json.load(f)
gaps = _load_pickle("valid_gap.pkl")  # L x 50000 x C
cls_statistics = _load_pickle("cls_statistics.pkl")  # L C CLS
hiddens = torch.stack(_load_pickle("valid_hidden.pkl"))
_, valid_dataset = get_datasets('imagenet1k', DATA_PATH, transform)


# values 
# Per-layer channel counts, then all layers concatenated along the channel axis.
num_channels = [g.size(1) for g in gaps]
linearized_gaps = torch.concat(gaps, dim=1)
# Concatenate layers along dim 0, then transpose so each row belongs to one class.
linearized_cls_statistics = torch.concat(cls_statistics, dim=0).permute(1,0)
linearized_cls_statistics = torch.nan_to_num(linearized_cls_statistics)

gaps[0].size(), cls_statistics[0].size(), hiddens.size(), linearized_cls_statistics.size()

In [None]:
def cosine_sim(a,b):
    """Cosine similarity of two 1-D tensors, returned as a 0-d tensor.

    The 1e-15 added to the denominator keeps the result at 0 (not NaN) when
    either input is the zero vector.
    """
    numerator = a.dot(b)
    denominator = a.norm() * b.norm() + 1e-15
    return numerator / denominator

def _class_pair_cosine(class_vectors, cls, i):
    """Cosine similarity between rows `cls` and `i` of a per-class matrix.

    Shared implementation behind gap_pure, measure_hidden and gap_stats, which
    previously repeated the same body three times.
    """
    return cosine_sim(class_vectors[cls], class_vectors[i])


def gap_pure(cls_mu_lineared_gaps, cls, i):
    """Cosine similarity of rows `cls` and `i` of `cls_mu_lineared_gaps`.

    Rows are presumably per-class mean linearized GAP vectors.
    """
    return _class_pair_cosine(cls_mu_lineared_gaps, cls, i)


def measure_hidden(cls_mu_hidden, cls, i):
    """Cosine similarity of rows `cls` and `i` of `cls_mu_hidden`.

    Rows are presumably per-class mean hidden-state vectors.
    """
    return _class_pair_cosine(cls_mu_hidden, cls, i)


def gap_stats(linearized_cls_statistics, cls, i ):
    """Cosine similarity of rows `cls` and `i` of `linearized_cls_statistics`.

    Rows are presumably the per-class statistics vectors (one row per class).
    """
    return _class_pair_cosine(linearized_cls_statistics, cls, i)

def get_top_k_labels(v, sims, top_k=5, label_names=None):
    """Return the `top_k` classes most similar to class `v`.

    Args:
        v: Row index (query class) into `sims`.
        sims: 2-D similarity matrix; row `v` holds the similarity of class `v`
            to every class.
        top_k: Number of neighbours to return (fewer if the row is shorter).
        label_names: Sequence mapping class index -> label name. Defaults to
            the notebook-level `labels` loaded from labels.json.

    Returns:
        (names, indices, similarities): lists of label names, class indices
        (ints) and similarity values (0-d tensors), ordered from most to least
        similar. The query class itself is not excluded.
    """
    names_source = labels if label_names is None else label_names
    row = sims[v]
    sorted_values, sorted_indices = torch.sort(row, descending=True)
    top_indices = [int(idx) for idx in sorted_indices[:top_k]]
    # The sorted values are already aligned with sorted_indices, so slice them
    # directly. (Indexing them by class id returned unrelated similarities.)
    top_values = list(sorted_values[:top_k])
    return [names_source[idx] for idx in top_indices], top_indices, top_values

def plot_random_samples(classes, N=5, seed=3, per_class=50):
    """Plot N randomly chosen validation images for each class, one row per class.

    Args:
        classes: Iterable of class indices; each gets one row, titled with its
            label on the first image.
        N: Number of images (columns) per row; must not exceed `per_class`.
        seed: Seed for a private RandomState. The picks match those of
            `np.random.seed(seed)` followed by `np.random.choice`, but the
            global NumPy RNG is left untouched.
        per_class: Images per class. Assumes `valid_dataset` stores each class
            in a contiguous block of this many items, in class order.

    Returns:
        (fig, axes) exactly as returned by `plt.subplots`.
    """
    ratio = 2
    rng = np.random.RandomState(seed)
    classes = list(classes)
    fig, axes = plt.subplots(len(classes), N, figsize=(ratio*N, ratio*len(classes)))
    # np.ravel also covers the single-Axes case (1 class, N=1), where
    # plt.subplots returns a bare Axes that has no `.flat`.
    axes_iter = iter(np.ravel(axes))
    for cls in classes:
        block = list(range(cls*per_class, (cls+1)*per_class))
        picks = rng.choice(block, size=(N,), replace=False)
        for col, sample_idx in enumerate(picks):
            ax = next(axes_iter)
            ax.imshow(valid_dataset[sample_idx][0])
            ax.set_xticks([])
            ax.set_yticks([])
            if col == 0:
                ax.set_title(labels[cls], rotation=0, fontsize=15)
    return fig, axes
    


In [None]:
# Class-by-class similarity matrices, cached in {path}/sims.pkl.
# Each class is summarised by one vector; entry [cls, i] is the cosine
# similarity between the vectors of classes `cls` and `i`:
#   gap_pure  - mean of the linearized GAP features over the class's images
#   gap_stats - the class's row of linearized_cls_statistics
#   hidden    - mean of the hidden states over the class's images
# NOTE: earlier versions of this cell filled 'hidden' from the GAP means. Delete
# any sims.pkl written by them so that it is regenerated.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 1000
SAMPLES_PER_CLASS = 50  # assumes validation rows are ordered by class, 50 per class
SIMS_CACHE = os.path.join(path, "sims.pkl")


def _class_means(per_sample):
    """Average each class's consecutive block of SAMPLES_PER_CLASS rows -> (NUM_CLASSES, D)."""
    return per_sample.reshape(NUM_CLASSES, SAMPLES_PER_CLASS, -1).mean(dim=1)


def _pairwise_cosine(class_vectors):
    """Return the float32 matrix of cosine similarities between all pairs of rows.

    Rows are L2-normalised with a tiny eps, so an all-zero row (e.g. statistics
    zeroed by nan_to_num) gets similarity 0 instead of NaN from a 0/0 division.
    """
    unit = torch.nn.functional.normalize(class_vectors.float(), dim=1, eps=1e-15)
    return unit @ unit.T


if os.path.exists(SIMS_CACHE):
    # pickle executes arbitrary code on load; this file is written by this cell.
    with open(SIMS_CACHE, 'rb') as f:
        sims = pickle.load(f)
else:
    # New local tensors only: the notebook-level inputs are left untouched
    # (same device, not normalised), so re-running this cell is safe.
    sims = {
        'gap_pure'  : _pairwise_cosine(_class_means(linearized_gaps.to(DEVICE))).cpu(),
        'gap_stats' : _pairwise_cosine(linearized_cls_statistics.to(DEVICE)).cpu(),
        'hidden'    : _pairwise_cosine(_class_means(hiddens.to(DEVICE))).cpu(),
    }
    with open(SIMS_CACHE, 'wb') as f:
        pickle.dump(sims, f, pickle.HIGHEST_PROTOCOL)

# Check the results: nearest classes under each similarity measure

In [None]:
CLS=10
# Ids of the 10 classes nearest to CLS under each measure, most similar first.
for key in ['gap_pure', 'gap_stats', 'hidden']:
    top_k_labels, top_k_indices, top_k_sims = get_top_k_labels(CLS, sims[key], top_k=10)
    print(f"{key:>9}: {top_k_indices}")

In [None]:
# Optional: for one query class, show random validation images of its 3 nearest
# classes under each similarity measure. Disabled by default; set to True to run.
# Uses its own query constant so the notebook-level CLS is not overwritten.
SHOW_NEIGHBOUR_GRIDS = False
NEIGHBOUR_QUERY_CLS = 8
if SHOW_NEIGHBOUR_GRIDS:
    for key in ['gap_pure', 'gap_stats', 'hidden']:
        _, neighbour_indices, _ = get_top_k_labels(NEIGHBOUR_QUERY_CLS, sims[key], top_k=3)
        fig, axes = plot_random_samples(neighbour_indices, N=3)
        # fig.savefig(f"results/imgs_{NEIGHBOUR_QUERY_CLS}_{key}.pdf")



In [None]:
# For query class CLS, show one validation image from each of its TOP_K nearest
# classes, one row per similarity measure, and save the figure as a PDF.
CLS=850
TOP_K=10
PLOT_KEYS = ['gap_stats', 'hidden']
SAMPLES_PER_CLASS = 50  # assumes validation rows are ordered by class, 50 per class
SAMPLE_OFFSET = 13      # which image within each class's block to display
fig, axes = plt.subplots(len(PLOT_KEYS), TOP_K, figsize=(2*TOP_K, 2*len(PLOT_KEYS)),
                         squeeze=False)  # always 2-D, even for one row/column
for row, key in enumerate(PLOT_KEYS):
    top_k_labels, top_k_indices, top_k_sims = get_top_k_labels(CLS, sims[key], top_k=TOP_K)
    for col, (cls_idx, title) in enumerate(zip(top_k_indices, top_k_labels)):
        ax = axes[row, col]
        ax.imshow(valid_dataset[cls_idx*SAMPLES_PER_CLASS + SAMPLE_OFFSET][0])
        ax.set_xticks([])
        ax.set_yticks([])
        if len(title) > 10:
            # Keep only the last two words of long labels so titles fit.
            title = " ".join(title.split(" ")[-2:])
        ax.set_title(title, rotation=0, fontsize=12)
    # Label each row with the similarity measure it shows.
    axes[row, 0].set_ylabel(key, fontsize=12)
fig.savefig(f"results/img_{CLS}.pdf")