In [None]:
import matplotlib.pyplot as plt
import os
# Restrict CUDA to device index 3. This is set before `import torch` below so
# the variable is already in the environment when torch is loaded.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Load the cached checkpoint; the following cell indexes it by string key
# ('features', 'ifeatures', 'labels', 'logits', 'ilogits', 'correct').
# NOTE(review): torch.load unpickles the file, so only point this at trusted
# checkpoints. Later cells call `.numpy()` on the loaded tensors, which
# presumably requires them to have been saved on CPU -- confirm, or pass
# map_location='cpu'.
data = torch.load('/ssd1/tta/imagenet_val_resnet50_lyrfts_full.pth')
# data = torch.load('/ssd1/tta/inc4_resnet50_shf_bn_full.pth')

In [None]:
# Pull the individual entries out of the loaded checkpoint.
# `*_inc` / `*_im` presumably stand for the "INC" and clean ImageNet variants
# (the t-SNE legend further down uses those two names) -- TODO confirm.
# Original author's note on `ifeatures`: 50000, C (2048, 1024, 512, 256)
features_inc, features_im = data['features'], data['ifeatures']
labels = data['labels']
logits_inc, logits_im = data['logits'], data['ilogits']
# Stored correctness entry, cast to torch.int.
correct = data['correct'].type(torch.int)

# Default layer index; the t-SNE cell reassigns it before use.
layer = -1

def cidx(i):
    """Compare the notebook-level `labels` against class id `i`.

    Returns the element-wise result of `labels == i`, which the other cells
    use as a mask to select the rows that belong to class `i`.
    """
    class_mask = labels == i
    return class_mask

In [None]:
# Mean and variance over dim 0 of features_im[2], stacked so that row 0 holds
# the means and row 1 the variances, then written to disk.
# NOTE(review): the feature list is indexed with [2] while the output file is
# named "lyr3" -- confirm the intended layer numbering.
lyr_feats = features_im[2]
v = torch.stack((lyr_feats.mean(dim=0), lyr_feats.var(dim=0)))
print(v.shape)
torch.save(v, '/ssd1/tta/imagenet_val_resnet50_lyr3_stat.pth')

In [None]:
# For each of the 1000 class ids, take the std over dim 0 of that class's rows
# in features_im[2]; then average those per-class stds across classes.
per_class_std = [features_im[2][cidx(c)].std(dim=0) for c in range(1000)]
v = torch.stack(per_class_std).mean(dim=0)
print(v.shape)

# Standardise to zero mean / unit std.
# Alternative kept from the original (min-max scaling):
# v = (v-v.min())/(v.max()-v.min())
v = v.sub(v.mean()).div(v.std())
pd.Series(v.numpy()).hist(bins=100)

# Note: what gets saved is the standardised vector, not the raw stds.
torch.save(v, '/ssd1/tta/imagenet_val_resnet50_lyr3_std.pth')

In [None]:
# Std over dim 0 of all rows of features_im[2] (no per-class split),
# standardised to zero mean / unit std and shown as a 100-bin histogram.
pooled_std = features_im[2].std(dim=0)
var2 = pooled_std.sub(pooled_std.mean()).div(pooled_std.std())
# Alternative post-transforms kept from the original:
# var2 = np.exp(var2)
# var2 = 1 / (1 + np.exp(-var2))
# var2 = np.clip(var2, a_min=0, a_max=10)
pd.Series(var2.numpy()).hist(bins=100)

In [None]:
from openTSNE import TSNE
import openTSNE.callbacks
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from tqdm import tqdm
class ProgressCallback(openTSNE.callbacks.Callback):
    """openTSNE callback that advances a tqdm progress bar on every invocation."""

    def __init__(self, pbar: tqdm, step: int = 1) -> None:
        """Remember the bar to drive and the increment applied per invocation.

        Args:
            pbar: progress bar whose `update` method is called.
            step: amount passed to `pbar.update` each time the callback fires.
        """
        super().__init__()
        self.pbar, self.step = pbar, step

    def __call__(self, iteration, error, embedding):
        """Advance the bar by `step`; the three optimiser arguments are unused.

        Always returns False (openTSNE interprets a truthy return value as a
        request to stop the optimisation).
        """
        self.pbar.update(self.step)
        return False

In [None]:
import matplotlib as mpl
def visualize_tsne(features: np.ndarray, labels: np.ndarray, label_names: list[str]=None,
                   figsize=(10, 10), dimension=2, perplexity=30, legend_nrow=2,
                   n_jobs=8, random_state=None):
    """Embed `features` with openTSNE and scatter-plot the result, one series per label.

    The label values are split into two halves: the lower half is drawn with
    round markers, the upper half with triangles, and label `k` in the upper
    half shares its colour with label `k - ncls` in the lower half. This is
    meant for "same classes, two domains" comparisons, and it assumes the
    labels are the consecutive values 0 .. 2*ncls - 1.

    Args:
        features: array passed unchanged to `TSNE.fit`.
        labels: one label per row of `features`.
        label_names: optional legend names, indexed by `int(label)`; when
            omitted the raw label value is used.
        figsize: matplotlib figure size.
        dimension: number of embedding components; values >= 3 produce a 3-D plot.
        perplexity: t-SNE perplexity.
        legend_nrow: target number of legend rows (sets the column count).
        n_jobs: worker threads handed to openTSNE (previously hard-coded to 8).
        random_state: seed handed to openTSNE; the default `None` keeps the
            original non-deterministic behaviour, pass an int for a
            reproducible embedding.

    Returns:
        `(cluster, fig)`: the embedding as a NumPy array and the matplotlib figure.
    """
    print(f'{features.shape=}, {labels.shape=}')

    # 750 matches openTSNE's default schedule (250 early-exaggeration + 500
    # regular iterations); the callback fires once per iteration.
    with tqdm(total=750) as pbar:
        tsne = TSNE(n_jobs=n_jobs,
                    n_components=dimension,
                    perplexity=perplexity,
                    random_state=random_state,
                    callbacks_every_iters=1,
                    callbacks=ProgressCallback(pbar, 1))
        trained = tsne.fit(features)

    cluster = np.array(trained)

    print('t-SNE computed, waiting for plot...')

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot() if dimension < 3 else fig.add_subplot(projection='3d')

    labels_arr = np.asarray(labels)
    classes = np.unique(labels_arr)
    # At least 1 so that a single-class input does not hit a modulo-by-zero.
    ncls = max(1, len(classes) // 2)
    palette = mpl.color_sequences['tab10']
    n_axes = 2 if dimension < 3 else 3

    for i in classes:
        # A boolean mask keeps each coordinate array 1-D; the np.where tuple
        # used before produced (1, n)-shaped arrays.
        mask = labels_arr == i
        ax_args = dict(
            marker='o' if i < ncls else '^',
            label=i if label_names is None else label_names[int(i)],
            # NOTE(review): the threshold 10 is hard-coded; it presumably
            # mirrors the 10-colour palette rather than `ncls` -- confirm.
            edgecolors='face' if i < 10 else '#000000bb',
            linewidths=0.5,
            # `color=` rather than `c=`: a bare RGB tuple given as `c` is
            # value-mapped when a class happens to contain exactly 3 points
            # and triggers a matplotlib warning otherwise. The palette index
            # wraps so that more than 10 classes per half no longer raises.
            color=palette[int(i % ncls) % len(palette)],
        )
        ax.scatter(*(cluster[mask, d] for d in range(n_axes)), **ax_args)

    ax.autoscale()

    # Keep at least one legend column even when there are fewer classes than rows.
    ax.legend(loc='lower center', ncol=max(1, len(classes) // legend_nrow),
              bbox_to_anchor=(0.5, -0.05))
    ax.axis('off')
    plt.show()

    return cluster, fig




In [None]:
# Alternative kept from the original (random subsample instead of whole classes):
# tsne_num = 2048
# tsne_idx = np.random.choice(len(features_im[-1]), tsne_num)

# Mask and row count for class 0. `.sum()` counts in one vectorised call and
# `int()` yields a plain Python int (the builtin `sum` iterated the whole mask
# element by element and returned a 0-dim tensor).
tsne_idx = cidx(0)
tsne_num = int(tsne_idx.sum())
layer = 3

ncls = 4
# Groups 0 .. ncls-1 come from features_im, groups ncls .. 2*ncls-1 from
# features_inc, both restricted to the first `ncls` class ids.
tsne_data = tuple(features_im[layer][cidx(i)] for i in range(ncls)) + tuple(features_inc[layer][cidx(i)] for i in range(ncls))
tsne_fts = np.concatenate(tsne_data)
# One label per feature row, sized from each group's actual length. The
# previous version reused class 0's count for every group, which misaligns
# labels and features as soon as the groups differ in size.
tsne_labels = np.concatenate(tuple(np.ones(len(grp)) * k for k, grp in enumerate(tsne_data)))
assert len(tsne_labels) == len(tsne_fts), "labels and features must have one entry per row"

t_ncls = ncls
_, _ = visualize_tsne(tsne_fts, tsne_labels,
                      [f"ImageNet(C{i})" for i in range(t_ncls)] + [f"INC(C{i})" for i in range(t_ncls)],
                      perplexity=15, dimension=2, figsize=(10, 5))