In [None]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import tqdm
import tqdm.notebook  # `import tqdm` alone does not load the `tqdm.notebook` submodule used below
from torchvision import transforms

# NOTE(review): machine-specific path to the dltranz checkout; adjust before running elsewhere.
sys.path.insert(0, '/mnt/data/molchanov/dltranz')

from domyshnik.models import *
from domyshnik.data import *
from domyshnik.constants import *
from domyshnik.utils import *

# Kept after the star imports (as before) so these names cannot be shadowed by them.
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.colors as mcolors


def draw(imgs):
    """Show a row of images side by side in a single figure.

    Args:
        imgs: either a list of tensors (stacked with ``torch.stack``) or an
            already-stacked tensor/array whose first axis indexes the images.
            Each image is handed to ``plt.imshow`` as-is.
    """
    batch = torch.stack(imgs) if isinstance(imgs, list) else imgs
    n_images = batch.shape[0]
    fig = plt.figure()
    for pos, image in enumerate(batch, start=1):
        fig.add_subplot(1, n_images, pos)
        plt.imshow(image)
    plt.show()
        
def draw_rgb(imgs):
    """Show a row of channel-first images side by side in a single figure.

    Args:
        imgs: either a list of (C, H, W) tensors (stacked with ``torch.stack``)
            or an already-stacked (N, C, H, W) tensor.
    """
    if isinstance(imgs, list):
        imgs = torch.stack(imgs)
    n_images = imgs.shape[0]
    fig = plt.figure()
    for col in range(n_images):
        fig.add_subplot(1, n_images, col + 1)
        # plt.imshow wants channels last: (C, H, W) -> (H, W, C).
        plt.imshow(imgs[col].permute(1, 2, 0))
    plt.show()
    
        
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training split, no label augmentation. n_augments=-1 is a domyshnik loader option
# (presumably "no extra augmented views" — confirm in domyshnik.data).
loader = get_cifar10_train_loader(batch_size=1024, n_augments=-1, augment_labels=False)
sample = next(iter(loader))

In [None]:
model = get_cifar10_centroids_model(31)
model.eval()
model.to(DEVICE)

# Embed the whole loader once; embeddings[k] and labels[k] come from the same sample.
embeddings = []
labels = []

# Inference only: under no_grad() outputs carry no autograd graph, so no .detach() is needed.
with torch.no_grad():
    for data in tqdm.notebook.tqdm(loader, total=len(loader)):
        out = model(data[0].to(DEVICE))
        embeddings.append(out.cpu())
        labels.append(data[1].cpu())

embeddings = torch.cat(embeddings, dim=0).numpy()
labels = torch.cat(labels, dim=0).numpy()

print(f'embeddings shape {embeddings.shape}')
print(f'labels shape {labels.shape}')

In [None]:
# Reduce to at most 50 dims with PCA before t-SNE. PCA cannot keep more components
# than min(n_samples, n_features), so clamp for small embeddings / small samples.
n_pca_components = min(50, *embeddings.shape)
pca = PCA(n_components=n_pca_components)
pca.fit(embeddings)
embeddings_pca = pca.transform(embeddings)

# Project the learned centroids with the SAME PCA fit so they share the space.
centroids = model.centroids.data.cpu().numpy()
centroids_pca = pca.transform(centroids)
n_centroids = centroids_pca.shape[0]

# t-SNE has no transform() for new points, so centroids and samples are embedded
# jointly (centroids first). A fixed seed keeps the 2-D layout reproducible.
tsne = TSNE(n_components=2, random_state=0)
joint_2d = tsne.fit_transform(np.concatenate((centroids_pca, embeddings_pca), axis=0))

X_centroids = joint_2d[:n_centroids]
X_embedded = joint_2d[n_centroids:]

In [None]:
# Map each integer label to a TABLEAU colour name (label k -> k-th colour).
colors = list(mcolors.TABLEAU_COLORS)
cl = [colors[label_id] for label_id in labels]

plt.figure(figsize=(80,80))
ax1 = plt.gca()
# Samples: large translucent points coloured by their true label.
ax1.scatter(X_embedded[:, 0], X_embedded[:, 1], c=cl, s=700, alpha=0.2)
# Learned centroids: black crosses drawn on top of the samples.
ax1.scatter(X_centroids[:, 0], X_centroids[:, 1], c='black', s=3700, marker='X')
plt.show()

In [None]:
clas_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
              'dog', 'frog', 'horse', 'ship', 'truck']

# Colour legend for the scatter plots: bar k uses the colour assigned to label k.
# Rebuilt here so this cell does not depend on the previous cell having run.
colors = list(mcolors.TABLEAU_COLORS)

plt.figure(figsize=(20,10))
plt.bar(clas_names, height=1, color=colors[:len(clas_names)])
plt.title('Label colour legend')
plt.show()

In [None]:
# One figure per class: that class's samples plus all centroids (black crosses).
colors = list(mcolors.TABLEAU_COLORS)
cl = [colors[lbl] for lbl in labels]
cl_arr = np.array(cl)  # built once instead of once per class

for label_id, class_name in enumerate(clas_names):
    in_class = labels == label_id  # boolean mask over samples
    mas = X_embedded[in_class]
    cc = cl_arr[in_class]

    plt.figure(figsize=(80,80))
    ax1 = plt.subplot()
    ax1.scatter(mas[:, 0], mas[:, 1], s=700, c=cc)
    ax1.scatter(X_centroids[:, 0], X_centroids[:, 1], s=2700, c='black', marker='X')

    plt.title(class_name, fontsize=140)
    plt.show()

In [None]:
# Switch to a second centroids model. The integer (31 earlier, 50 here) is passed straight to
# the domyshnik factory — presumably a checkpoint/epoch selector; confirm in domyshnik.models.
model = get_cifar10_centroids_model(50)
model.to(DEVICE)
model.eval()

# Distances from one query batch to all other training samples

In [None]:
loader = get_cifar10_train_loader(batch_size=128, n_augments=-1, augment_labels=False)

# Take the query batch and the reference batches from ONE iterator. Re-iterating the
# loader and skipping its first batch only excludes the query batch when the loader
# does not shuffle; with shuffling it would drop an unrelated batch instead and
# compare the query batch against itself.
batches = iter(loader)
sample = next(batches)
print(sample[0].size())
with torch.no_grad():
    out = model(sample[0].to(DEVICE))
lbl = sample[1]


def calc_hist(s1, s2):
    """Return the pairwise distances between embeddings ``s1`` and ``s2`` as a NumPy array.

    Despite the name, no histogram is computed here; callers histogram the result.
    Delegates to domyshnik's ``outer_pairwise_distance`` (result presumably shaped
    (len(s1), len(s2)) — confirm in domyshnik.utils).
    """
    d = outer_pairwise_distance(s1, s2)
    return d.detach().cpu().numpy()


dist_chunks = []
LBL = []
with torch.no_grad():
    for data in tqdm.notebook.tqdm(batches, total=len(loader) - 1):
        out1 = model(data[0].to(DEVICE))
        dist_chunks.append(calc_hist(out, out1))
        LBL += data[1].numpy().tolist()

# Concatenate once (column blocks = reference batches) instead of re-copying D every batch.
if dist_chunks:
    D = np.concatenate(dist_chunks, axis=-1)
else:
    # Loader had only the query batch: no reference samples.
    D = np.empty((len(lbl), 0))

D.shape

In [None]:
# For each query sample, histogram the true labels of reference samples lying within
# DIST_THRESHOLD of it. No sorting by distance is needed: a histogram ignores order.
DIST_THRESHOLD = 0.3
N_CLASSES = 10  # CIFAR-10
all_labels = np.array(LBL)  # built once, not once per query
label_bins = np.arange(N_CLASSES + 1) - 0.5  # one bin centred on each integer label

for i in range(D.shape[0]):
    near_labels = all_labels[D[i] < DIST_THRESHOLD]

    plt.hist(near_labels, bins=label_bins, color='red')
    plt.xticks(range(N_CLASSES))
    plt.title(lbl[i].item())
    plt.show()
    plt.close()
    print('-------------------------------------------------------')

In [None]:
# Nearest neighbours of each query image within the query batch.
N_NEIGHBOURS = 5
dist = outer_pairwise_distance(out).detach().cpu().numpy()
imgs = sample[0]
lb = sample[1]
for i in range(dist.shape[0]):
    print('\n--------------------------------------')
    plt.hist(dist[i], bins=100)
    plt.show()
    plt.close()

    order = torch.Tensor(dist[i]).argsort()
    # Exclude the query itself (presumably distance ~0 to itself), so it is not
    # reported as its own nearest neighbour.
    order = order[order != i][:N_NEIGHBOURS]
    # Query label and neighbour labels are printed separately, never added together.
    print(f'query label: {lb[i].item()}, neighbour labels: {lb[order].tolist()}')
    draw_rgb([imgs[i]] + list(imgs.index_select(0, order)))

In [None]:
# Pairwise distances within the query batch; the cell displays how many rows there are.
dist = outer_pairwise_distance(out).cpu().detach().numpy()
len(dist)

# Distance histograms: augmentations of one image vs. other images in the batch

In [None]:
N_AUGMENTS = 10
loader = get_cifar10_train_loader(batch_size=128, n_augments=N_AUGMENTS, augment_labels=False)
sample = next(iter(loader))
print(sample[0].size())
with torch.no_grad():
    out = model(sample[0].to(DEVICE))
# Group embeddings per source image: (n_images, N_AUGMENTS + 1 views, embedding_dim).
# Presumably 1 original + N_AUGMENTS augmented views, stored contiguously per image —
# TODO confirm in domyshnik.data. The batch dimension is inferred, so a short batch also works.
out = out.view(-1, N_AUGMENTS + 1, out.size(-1))

In [None]:
# Compare distances among the views of image 0 (blue) with distances between
# image 0's views and image i's views (red).
CLOSE_THRESHOLD = 0.6  # show the pair if any cross-image distance falls below this
n_images, n_views = out.size(0), out.size(1)

# Same-image distances, computed once; self-pairs (the diagonal) are dropped
# rather than shifted out of range.
same_image = outer_pairwise_distance(out[0]).detach().cpu().numpy().reshape(n_views, n_views)
d1 = same_image[~np.eye(n_views, dtype=bool)]

for i in range(1, n_images):
    print('\n----------------------------------')
    d2 = outer_pairwise_distance(out[0], out[i]).view(-1).detach().cpu().numpy()

    # Show the two images when image i lands close to image 0 or shares its label.
    if d2.min() < CLOSE_THRESHOLD or sample[1][0].item() == sample[1][i].item():
        print(f'{sample[1][0].item()} vs {sample[1][i].item()}')
        draw_rgb([sample[0][0, 0], sample[0][i, 0]])

    plt.hist(d1, color='blue', alpha=0.5, label='views of image 0')
    plt.hist(d2, color='red', alpha=0.5, label=f'image 0 vs image {i}')
    plt.legend()
    plt.show()
    plt.close()

# Centroid distance histograms

In [None]:
# NOTE(review): `dcc` was not defined anywhere in this notebook, so this cell raised
# NameError on a fresh kernel. It is reconstructed here as centroid-to-centroid
# distances (matching the section title) — confirm this was the intended quantity.
with torch.no_grad():
    dcc = outer_pairwise_distance(model.centroids)

for i in range(model.centroids.size(0)):
    plt.hist(dcc[i].detach().cpu().numpy())
    plt.title(f'centroid {i}')
    plt.show()
    plt.close()

In [None]:
def cifar_torch_augmentation(p=1):
    """Build a random augmentation pipeline for a single CIFAR-sized image tensor.

    Pipeline: ToPILImage -> RandomResizedCrop (applied with prob. 0.5)
    -> ColorJitter (prob. 0.5) -> RandomHorizontalFlip (gated with prob. 0.7)
    -> ToTensor. No Normalize step is applied.

    Args:
        p: probability of applying the random part of the pipeline at all.
            The default ``1`` always applies it, exactly as before; smaller values
            wrap the random steps in one extra ``RandomApply``.

    Returns:
        A ``transforms.Compose`` callable. It presumably takes a (C, H, W) image tensor,
        as the preview cell passes one — confirm for other callers.
    """
    random_steps = [
        transforms.RandomApply([
            transforms.RandomResizedCrop(size=32, scale=(0.5, 1.0)),
        ], p=0.5),
        transforms.RandomApply([
            # hue=0.5 is the largest hue jitter torchvision accepts.
            transforms.ColorJitter(brightness=0.3,
                                   contrast=0.3,
                                   saturation=(0.3, 0.5),
                                   hue=0.5),
        ], p=0.5),
        # RandomHorizontalFlip itself flips with prob. 0.5, so the net flip rate is 0.7 * 0.5.
        transforms.RandomApply([
            transforms.RandomHorizontalFlip(),
        ], p=0.7),
    ]
    if p < 1:
        # Gate the whole random part; with p >= 1 nothing is wrapped, so behaviour is unchanged.
        random_steps = [transforms.RandomApply(random_steps, p=p)]

    return transforms.Compose([transforms.ToPILImage(), *random_steps, transforms.ToTensor()])

# Preview: each test image next to one random augmentation of it.
aug = cifar_torch_augmentation()

# n_augments=-2 is a domyshnik loader option; its meaning is not visible here.
loader = get_cifar10_test_loader(batch_size=128, n_augments=-2, augment_labels=False)
sample = next(iter(loader))
images = sample[0]
print(images.size())
for img in images:
    print('\n----------------------------------------')
    draw_rgb([img, aug(img)])
    