In [None]:
import os
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tv

from data import SimCLRAugment
from models import Encoder, Projector

%config InlineBackend.figure_format = 'retina'
plt.rcParams['figure.figsize'] = (20,20)
sns.set()
np.random.seed(2020)

In [None]:
# Restore the trained encoder and projection head from a saved checkpoint.
# NOTE(review): `md5` is not defined in any cell visible in this notebook, so
# this cell raises NameError on a fresh kernel -- define the checkpoint hash
# in a config cell above before running.
filename = f'./checkpoints/{md5}.pkl'
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load(filename)

# The checkpoint dict carries the hyper-parameters used to build each network
# ('hparams') next to the saved weights.
model = Encoder(checkpoint['hparams'])
model.load_state_dict(checkpoint['encoder_state_dict'])
model.cuda()
# NOTE(review): the networks are never switched to eval() here; if Encoder or
# Projector contain BatchNorm/Dropout layers they will run in training mode
# during the analysis below -- confirm this is intended.

proj = Projector(checkpoint['hparams'])
proj.load_state_dict(checkpoint['projector_state_dict'])
proj.cuda()

# Last expression of the cell: displayed as the cell output.
# NOTE(review): if `as_dict` is a method rather than a property this shows the
# bound method instead of the hyper-parameters -- verify against hparams' class.
checkpoint['hparams'].as_dict

In [None]:
from torchvision import datasets

# Every dataset below shares the same SimCLR augmentation pipeline.
# NOTE(review): presumably `batchsize` is the number of augmented views the
# transform produces per image -- confirm against SimCLRAugment.
batchsize = 512
train_transform = SimCLRAugment(checkpoint['hparams'], batchsize)

# Resolve the dataset root once. os.environ[...] raises a KeyError naming
# SLURM_TMPDIR when the variable is unset, instead of the opaque
# "NoneType + str" TypeError produced by os.getenv(...) + '/data'.
data_root = os.path.join(os.environ['SLURM_TMPDIR'], 'data')

# download=False: the data must already be staged under data_root.
cifar_train_data = datasets.CIFAR10(data_root, train=True, transform=train_transform, download=False)
cifar_test_data = datasets.CIFAR10(data_root, train=False, transform=train_transform, download=False)

svhn_train_data = datasets.SVHN(os.path.join(data_root, 'svhn'), split='train', transform=train_transform, download=False)
# 4000 synthetic 3x32x32 images.
fake_data = datasets.FakeData(4000, (3, 32, 32), transform=train_transform)

In [None]:
@torch.no_grad()
def get_cos_dis(x, use_all=False, use_proj=False):
  """Summarise the scaled pairwise cosine similarities of one batch.

  The batch is moved to the GPU and encoded with the notebook-level `model`
  (then passed through `proj` when `use_proj` is true). Each embedding row is
  L2-normalised and the full similarity matrix is divided by 0.5.

  Args:
    x: Batch accepted by `model`.
      NOTE(review): presumably the stack of augmented views returned by the
      dataset transform -- confirm against SimCLRAugment.
    use_all: If true, use every distinct pair of rows (strict upper triangle
      of the similarity matrix). Otherwise use only row 0, i.e. the
      similarities of the first embedding to every embedding, itself included.
    use_proj: If true, measure similarities in the projector's output space
      instead of the encoder's.

  Returns:
    Tuple `(dist, mean, var)`: the selected similarities as a 1-D tensor,
    followed by their mean and variance as computed by `torch.var_mean`.
  """
  x = x.cuda()
  z = model(x)
  if use_proj:
    z = proj(z)

  znorm = z / torch.norm(z, 2, dim=1, keepdim=True)
  # NOTE(review): 0.5 looks like a softmax temperature -- confirm against hparams.
  cos_sim = torch.einsum('id,jd->ij', znorm, znorm) / 0.5

  if use_all:
    # Size the index grid from the matrix itself rather than the notebook
    # global `batchsize`, so a batch of any size is indexed correctly.
    n = cos_sim.shape[0]
    rows, cols = torch.triu_indices(n, n, 1, device=cos_sim.device)
    dist = cos_sim[rows, cols]
  else:
    dist = cos_sim[0, :]

  # Reuse the already-gathered values instead of indexing a second time.
  var, mean = torch.var_mean(dist)
  return dist, mean, var


def plot_density(dataset, name, count, axes, color_idx):
  """Draw similarity statistics of `count` random samples onto a grid of axes.

  One row of `axes` is filled per (use_all, use_proj) setting, in the order
  (False, False), (False, True), (True, False), (True, True). Within a row,
  column 0 gets a stacked histogram of the raw similarities, column 1 the
  distribution of per-sample means and column 2 the distribution of
  per-sample variances, all as returned by `get_cos_dis`.

  Args:
    dataset: Indexable dataset; element 0 of each item is fed to `get_cos_dis`.
    name: Legend label for this dataset.
    count: Number of random indices drawn (with NumPy's global RNG) per row.
    axes: 2-D array of matplotlib axes with at least 4 rows and 3 columns.
    color_idx: Index into the current seaborn palette for the histogram colour.
  """
  settings = product((False, True), (False, True))
  for row, (use_all, use_proj) in enumerate(settings):
    ax_dist, ax_mean, ax_var = axes[row, 0], axes[row, 1], axes[row, 2]

    samples, sample_means, sample_vars = [], [], []
    # Fresh random indices are drawn for every row.
    for sample_id in np.random.randint(0, len(dataset), count):
      sims, mu, sigma2 = get_cos_dis(dataset[sample_id][0], use_all, use_proj)
      samples.append(sims.cpu().numpy())
      sample_means.append(mu.cpu().item())
      sample_vars.append(sigma2.cpu().item())

    # Every stacked series shares the same palette colour.
    bar_colors = [sns.color_palette()[color_idx] for _ in samples]

    ax_dist.set_title('Distributions')
    ax_dist.hist(
      samples,
      histtype='bar',
      alpha=0.5,
      bins=100,
      stacked=True,
      density=True,
      color=bar_colors,
      lw=0,
      label=name,
    )
    ax_dist.set_xlim(0, 2)
    ax_dist.legend()

    ax_mean.set_title('Means')
    sns.distplot(sample_means, label=name, ax=ax_mean)
    ax_mean.legend()

    ax_var.set_title('Variances')
    sns.distplot(sample_vars, label=name, ax=ax_var)
    ax_var.legend()

In [None]:
# CIFAR-10 train plotted against itself: two independent random draws from the
# same dataset overlaid on one 4x3 grid.
# NOTE(review): both series carry the same legend label -- confirm this is a
# deliberate baseline and not a copy-paste slip.
_, axes = plt.subplots(nrows=4, ncols=3)
plot_density(dataset=cifar_train_data, name='CIFAR-10 Train', count=20, axes=axes, color_idx=0)
plot_density(dataset=cifar_train_data, name='CIFAR-10 Train', count=20, axes=axes, color_idx=1)

In [None]:
# CIFAR-10 train vs. CIFAR-10 test, overlaid on one 4x3 grid.
_, axes = plt.subplots(nrows=4, ncols=3)
plot_density(dataset=cifar_train_data, name='CIFAR-10 Train', count=20, axes=axes, color_idx=0)
plot_density(dataset=cifar_test_data, name='CIFAR-10 Test', count=20, axes=axes, color_idx=1)

In [None]:
# CIFAR-10 train vs. synthetic FakeData images, overlaid on one 4x3 grid.
_, axes = plt.subplots(nrows=4, ncols=3)
plot_density(dataset=cifar_train_data, name='CIFAR-10 Train', count=20, axes=axes, color_idx=0)
plot_density(dataset=fake_data, name='Fake Data', count=20, axes=axes, color_idx=1)

In [None]:
# CIFAR-10 train vs. SVHN train, overlaid on one 4x3 grid.
_, axes = plt.subplots(nrows=4, ncols=3)
plot_density(dataset=cifar_train_data, name='CIFAR-10 Train', count=20, axes=axes, color_idx=0)
plot_density(dataset=svhn_train_data, name='SVHN Train', count=20, axes=axes, color_idx=1)