In [2]:
import os

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import umap
from ipywidgets import interact
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score

import models
from imbalance_cifar import IMBALANCECIFAR10, IMBALANCECIFAR100



# Load model

## Define

In [3]:
# Which activations to embed. The extraction cell branches on exactly these two
# values: 'last' stores the model's output, 'before_last' stores the pooled
# output of `layer3` captured through a forward hook.
embed_layer = 'last' # last or before_last
# Checkpoint to load. The name of its parent directory becomes `model_name`,
# which is also used as the stem of the exported CSV file.
model_checkpoint = 'checkpoint/cifar10_resnet32_CE_None_exp_0.01_seed_None_0/ckpt.best.pth.tar'
# model_checkpoint = 'checkpoint/cifar10_resnet32_LDAM_DRW_exp_0.01_sam_0.8_sched_none_seed_None_0/ckpt.best.pth.tar'

# Key looked up in `models.__dict__` to obtain the network constructor.
arch = 'resnet32'
# 'cifar100' selects 100 output classes; any other value selects 10.
dataset = 'cifar10'
# 'LDAM' builds the model with `use_norm=True`; any other value with `use_norm=False`.
loss_type = 'CE'
# CUDA device index. None switches to `torch.nn.DataParallel` instead of a single device.
gpu = 0

In [4]:
# Run name = parent directory of the checkpoint file; reused as the CSV stem.
model_name = model_checkpoint.split('/')[-2]

num_classes = 100 if dataset == 'cifar100' else 10
# `use_norm` is forwarded to the model constructor and is only enabled for LDAM.
use_norm = loss_type == 'LDAM'
model = models.__dict__[arch](num_classes=num_classes, use_norm=use_norm)

if gpu is not None:
    torch.cuda.set_device(gpu)
    model = model.cuda(gpu)
else:
    # DataParallel will divide and allocate batch_size to all available GPUs
    model = torch.nn.DataParallel(model).cuda()

# Remap the stored tensors onto the selected device. Without `map_location`
# torch.load restores them onto whatever device they were saved from, which
# fails (or wastes memory on another GPU) when that device differs from `gpu`.
checkpoint = torch.load(
    model_checkpoint,
    map_location=None if gpu is None else f'cuda:{gpu}',
)
model.load_state_dict(checkpoint['state_dict'])
# Inference mode: freezes batch-norm statistics for feature extraction.
model.eval()

print(model)


ResNet_s(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affin

### Dataloader

In [5]:
# Per-channel mean / std applied by both the train and the val pipeline.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Train pipeline: random crop + horizontal flip, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Val pipeline: tensor conversion and normalisation only.
transform_val = transforms.Compose([transforms.ToTensor(), normalize])

# Both loaders iterate in dataset order with identical settings.
loader_options = dict(batch_size=16, shuffle=False, num_workers=12, pin_memory=True)

val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_val)
val_len = len(val_dataset)
print('Length of val dataset: ', val_len)
val_loader = torch.utils.data.DataLoader(val_dataset, **loader_options)

# Settings forwarded to IMBALANCECIFAR10 for the imbalanced training split.
imb_type = 'exp'
imb_factor = 0.01
rand_number = 0

train_dataset = IMBALANCECIFAR10(
    root='./data',
    imb_type=imb_type,
    imb_factor=imb_factor,
    rand_number=rand_number,
    train=True,
    download=True,
    transform=transform_train,
)
train_len = len(train_dataset)
print('Length of train dataset: ', train_len)
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)

Files already downloaded and verified
Length of val dataset:  10000
Files already downloaded and verified
Length of train dataset:  12406


### Get intermediate layer features

In [6]:
# Fail early on a typo: with any other value nothing is ever collected and the
# final np.concatenate would raise an unrelated-looking error.
if embed_layer not in ('last', 'before_last'):
    raise ValueError(f"embed_layer must be 'last' or 'before_last', got {embed_layer!r}")

features = {'train': [], 'val': []}
labels = {'train': [], 'val': []}
loaders = {'train': train_loader, 'val': val_loader}
# Named `splits` (not `datasets`) so the `torchvision.datasets` module imported
# at the top of the notebook is not shadowed by a list.
splits = ['train', 'val']


def get_features(dataset='val'):
    """Return a forward hook that appends pooled activations to `features[dataset]`.

    The hook average-pools the hooked module's output with a kernel equal to
    its last spatial dimension, flattens it to one row per sample, and stores
    it as a NumPy array.
    """
    def hook(module, inputs, output):
        pooled = F.avg_pool2d(output, output.size()[3])
        features[dataset].append(pooled.view(output.size(0), -1).detach().cpu().numpy())
    return hook


# The loop variable is `split`, not `dataset`, so the `dataset` config string
# ('cifar10' / 'cifar100') is not overwritten.
for split in splits:
    # A fresh model per split keeps each split's hook from firing on the other.
    model = models.__dict__[arch](num_classes=num_classes, use_norm=use_norm)
    if gpu is not None:
        torch.cuda.set_device(gpu)
        model = model.cuda(gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()
    checkpoint = torch.load(model_checkpoint)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    if embed_layer == 'before_last':
        # Only the DataParallel wrapper exposes `.module`; on the single-GPU
        # path the network itself owns `layer3`.
        backbone = model.module if isinstance(model, torch.nn.DataParallel) else model
        backbone.layer3.register_forward_hook(get_features(dataset=split))
    with torch.no_grad():
        for images, target in loaders[split]:
            labels[split].extend(target.cpu().numpy())
            if gpu is not None:
                images = images.cuda(gpu, non_blocking=True)
            # compute output (for 'before_last' the hook records the features)
            output = model(images)
            if embed_layer == 'last':
                features[split].append(output.detach().cpu().numpy())

# Collapse the per-batch lists into one array per split.
for split in splits:
    features[split] = np.concatenate(features[split], axis=0)
    labels[split] = np.array(labels[split])

In [7]:
# Stack train rows first and val rows second, for features and labels alike,
# so that row i of one array corresponds to row i of the other.
high_dim_features = np.concatenate((features['train'], features['val']))
high_dim_classes = np.concatenate((labels['train'], labels['val']))
for stacked in (high_dim_features, high_dim_classes):
    print(stacked.shape)

(22406, 10)
(22406,)


In [8]:
# get UMAP embeddings
reducer = umap.UMAP()
umap_results = reducer.fit_transform(high_dim_features)
print('umap shape:',umap_results.shape)
tsne_results = TSNE(n_components=2).fit_transform(high_dim_features,high_dim_classes)
print('tsne shape:',tsne_results.shape)
df = pd.DataFrame(columns=['tsne1', 'tsne2', 'class', 'dataset'])
df['umap1'] = umap_results[:,0]
df['umap2'] = umap_results[:,1]
df['tsne1'] = tsne_results[:,0]
df['tsne2'] = tsne_results[:,1]
df['class'] = high_dim_classes
df['dataset'] = ['train']*train_len + ['val']*val_len
df.to_csv(f'embeddings/{model_name}_{embed_layer}.csv', index=False)

umap shape: (22406, 2)
tsne shape: (22406, 2)


# Plots
## Define

In [9]:
# Selects which exported CSV the plotting cells read: the file name is built
# from the checkpoint's parent directory and `embed_layer`.
# NOTE(review): these values override the ones set in the first "Define" cell,
# so the CSV they point to is not the one written by the extraction cells in a
# single top-to-bottom run; it must already exist under `embeddings/`.
embed_layer = 'before_last' # last or before_last
# model_checkpoint = 'checkpoint/cifar10_resnet32_CE_None_exp_0.01_seed_None_0/ckpt.best.pth.tar'
model_checkpoint = 'checkpoint/cifar10_resnet32_LDAM_DRW_exp_0.01_sam_0.8_sched_none_seed_None_0/ckpt.best.pth.tar'

In [10]:
# Run name = parent directory of the checkpoint path.
model_name  = model_checkpoint.split('/')[-2]

# Load the embedding table for this run / layer. The plotting functions below
# read the columns umap1, umap2, tsne1, tsne2, class and dataset from it.
df = pd.read_csv(f'embeddings/{model_name}_{embed_layer}.csv')
# Scatterplot with seaborn: hue encodes the class, style encodes the dataset.
# Driven by widgets that select the dataset, the sample size and the embedding.
def plot_tsne(dataset, n_samples, embeddings='umap'):
    """Plot a random sample of one split's 2-D embedding and print its silhouette score.

    Args:
        dataset: value of the `dataset` column to keep ('train' or 'val').
        n_samples: number of rows to draw for the scatter plot; capped at the
            number of rows available in the selected split.
        embeddings: 'umap' or 'tsne', selecting which pair of columns to plot.

    Raises:
        ValueError: if `embeddings` is neither 'umap' nor 'tsne'.

    The silhouette score is computed on every row of the split, not only on
    the rows that were sampled for the plot.
    """
    columns = {'umap': ('umap1', 'umap2'), 'tsne': ('tsne1', 'tsne2')}
    if embeddings not in columns:
        raise ValueError(f"embeddings must be one of {sorted(columns)}, got {embeddings!r}")
    x, y = columns[embeddings]

    # Filter once and reuse for both the plot and the score.
    subset = df[df['dataset'] == dataset]

    # Key colours by class label. A plain list is matched positionally to the
    # classes present in the sample, so a rare class missing from a small
    # sample would shift every following class to a different colour.
    classes = sorted(df['class'].unique())
    palette = dict(zip(classes, sns.color_palette("bright", len(classes))))

    # DataFrame.sample raises if asked for more rows than exist.
    sample = subset.sample(n=min(n_samples, len(subset)))
    sns.scatterplot(data=sample, x=x, y=y, hue='class', style='dataset', palette=palette, s=10)
    sil_score = silhouette_score(subset[[x, y]], subset['class'])
    print(f'Silhouette score for {dataset} dataset: {sil_score}')
    plt.show()

# Build the controls for plot_tsne: the two lists become dropdowns and the
# (min, max, step) tuple becomes an integer slider for n_samples.
interact(plot_tsne, dataset=['train', 'val'], n_samples=(100, 10000, 1000), embeddings=['umap', 'tsne'])

interactive(children=(Dropdown(description='dataset', options=('train', 'val'), value='train'), IntSlider(valu…

<function __main__.plot_tsne(dataset, n_samples, embeddings='umap')>

In [11]:
# Plot a scatterplot with the selected class highlighted.

# Dropdown of class labels, offered as the strings '0'..'9'; draw_plot converts
# the selected value back to an int.
dd = widgets.Dropdown(
    options=[str(i) for i in range(10)],
    value='0',
    description='Number:')

def draw_plot(num, dataset='train', embeddings='umap'):
    """Plot one split's 2-D embedding with a single class highlighted in red.

    Args:
        num: class label to highlight; anything `int()` accepts (the dropdown
            passes a string such as '3').
        dataset: value of the `dataset` column to keep ('train' or 'val').
        embeddings: 'umap' or 'tsne', selecting which pair of columns to plot.

    Raises:
        ValueError: if `embeddings` is neither 'umap' nor 'tsne'.
    """
    columns = {'umap': ('umap1', 'umap2'), 'tsne': ('tsne1', 'tsne2')}
    if embeddings not in columns:
        raise ValueError(f"embeddings must be one of {sorted(columns)}, got {embeddings!r}")
    x, y = columns[embeddings]

    highlighted = int(num)
    # Key colours by class label rather than by list position, so the red
    # entry always lands on the requested class even if some label is absent
    # from the plotted rows.
    palette = {
        label: 'red' if label == highlighted else 'grey'
        for label in df['class'].unique()
    }
    sns.scatterplot(data=df[df['dataset'] == dataset], x=x, y=y, hue='class', style='dataset', palette=palette, s=10)
    plt.show()

# Wire the class dropdown plus the split / embedding selectors to draw_plot.
widgets.interactive(draw_plot, num=dd, dataset=['train', 'val'], embeddings=['umap', 'tsne'])

interactive(children=(Dropdown(description='Number:', options=('0', '1', '2', '3', '4', '5', '6', '7', '8', '9…