In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import tqdm.notebook  # `import tqdm` alone does not load the notebook submodule

# Make the project packages (`domyshnik`, ...) importable. The checkout location
# defaults to the original author's path; override it with DLTRANZ_ROOT.
sys.path.insert(0, os.environ.get('DLTRANZ_ROOT', '/mnt/data/molchanov/dltranz'))

from domyshnik.models import *
from domyshnik.data import *
from domyshnik.constants import *
from domyshnik.utils import *

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.colors as mcolors


def draw(imgs):
    """Display several images next to each other in a single-row figure.

    imgs: a list of equally shaped tensors (stacked with torch.stack first) or
    an already stacked tensor/array. Each entry along the first axis is passed
    unchanged to plt.imshow in its own subplot.
    """
    batch = torch.stack(imgs) if isinstance(imgs, list) else imgs
    count = batch.shape[0]
    figure = plt.figure()
    for position, image in enumerate(batch, start=1):
        figure.add_subplot(1, count, position)
        plt.imshow(image)
    plt.show()
        
def draw_rgb(imgs):
    """Display several channel-first colour images in a single-row figure.

    imgs: a list of equally shaped 3-D tensors (stacked with torch.stack first)
    or an already stacked 4-D tensor. Each image's first axis (presumably the
    channels) is moved last, giving the (H, W, C) layout plt.imshow expects.
    """
    if isinstance(imgs, list):
        imgs = torch.stack(imgs)
    n_images = imgs.shape[0]
    fig = plt.figure()
    for i in range(n_images):
        fig.add_subplot(1, n_images, i + 1)
        plt.imshow(imgs[i].permute(1, 2, 0))
    plt.show()
    
        
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Metric Learning MNIST Test 

In [None]:
# Alternative model constructor, kept for reference:
#model = get_mnist_metriclearning_model()
# Build the per-sample metric-learning model; the bare `model` line displays its repr.
model = get_mnist_metriclearning_persample_model()
model

In [None]:
# Take one batch from the augmented MNIST test loader. `data` merges every
# leading axis of sample[0], leaving one image (H, W) per row.
data_loader = get_mnist_test_loader(BATCH_SIZE, n_augments=N_AUGMENTS)
sample = next(iter(data_loader))
data = sample[0].view(-1, *sample[0].shape[-2:])
draw(sample[0][0])

In [None]:
# Inference only: disable autograd so no computation graph is kept alive for
# `embeds` (the values are identical; only gradient tracking is skipped).
with torch.no_grad():
    embeds = model(sample[0])

In [None]:
# Pairwise dot products between embeddings. Despite the name this is a
# similarity: larger means closer (the next cell sorts it in descending order).
distances = torch.matmul(embeds, embeds.transpose(0, 1))
# Mask the diagonal with -inf so an embedding never ranks as its own neighbour;
# a finite sentinel such as -100 can be outranked by real scores when the
# embeddings are not normalised. The trailing ';' hides the returned tensor.
distances.fill_diagonal_(float('-inf'));

In [None]:
# For every embedding, draw its own image followed by the images of its
# N_AUGMENTS + 1 most similar embeddings (highest scores first).
for i in range(distances.shape[0]):
    vals, idx = distances[i].sort(descending=True)
    idx = [i, *idx[:N_AUGMENTS + 1].tolist()]
    imgs = data[idx]
    draw(imgs)

# Domyshnik MNIST Test

In [None]:
import torch.nn.functional as F

In [None]:
# Same as the metric-learning batch above, but with augment_labels=True. `data`
# merges every leading axis of sample[0], leaving one image (H, W) per row.
data_loader = get_mnist_test_loader(BATCH_SIZE, n_augments=N_AUGMENTS, augment_labels=True)
sample = next(iter(data_loader))
data = sample[0].view(-1, *sample[0].shape[-2:])
draw(sample[0][0])

In [None]:
# Display the second element of the batch (presumably the labels produced with
# augment_labels=True -- TODO confirm their shape/meaning).
sample[1]

# Experiment 1 (Domyshnik net)

In [None]:
# Experiment 1 draws its batch from the test loader bundled in this info object.
info = mnist_domyshnik_lunch_info
sample = next(iter(info.test_loader))

In [None]:
# Build the domyshnik model; the bare `model` line displays its repr.
model = get_mnist_domyshnik_model()
model

In [None]:
# Run the model, regroup its outputs as (BATCH_SIZE, N_AUGMENTS + 1, n_outputs)
# and softmax the last axis so every row becomes a probability vector.
out = model(sample[0])
out = F.softmax(out.view(BATCH_SIZE, N_AUGMENTS + 1, out.shape[-1]), dim=-1)
out.shape


In [None]:
# Show the images stored at sample[0][0], one subplot per entry along its first axis.
draw(sample[0][0])

In [None]:
# Dot products between every pair of probability rows belonging to item 0
# (each value lies in [0, 1]; 1 only for identical one-hot rows).
out0 = out[0]
dists = out0 @ out0.T
'max', dists.max(), 'min', dists.min()

In [None]:
# Show the images stored at sample[0][10], a different item of the same batch.
draw(sample[0][10])

In [None]:
# Dot products between item 0's probability rows and another item's rows.
# NOTE(review): the cell above draws sample[0][10], but this uses out[1]
# (item 1); out[10] may have been intended -- confirm.
out1 = out[1]
dists = out0 @ out1.T
'max', dists.max(), 'min', dists.min()

In [None]:
# Print individual probability rows out[item, slot] from the softmaxed outputs.
# First, slots 0-3 of item 0:
print(out[0, 0])
print(out[0, 1])
print(out[0, 2])
print(out[0, 3])

# Then single slots of other items, for comparison.
# NOTE(review): out[110, 1] raises IndexError unless BATCH_SIZE > 110; the other
# item indices are <= 23, so 110 may be a typo (11?) -- confirm.
print(out[10, 0])
print(out[110, 1])
print(out[23, 2])
print(out[4, 3])

# Experiment 2: metric-learning similarity

In [None]:
# Alternative model constructor, kept for reference:
#m_model = get_mnist_metriclearning_model()
# Build the per-sample metric-learning model; the bare `m_model` line displays its repr.
m_model = get_mnist_metriclearning_persample_model()
m_model

In [None]:
# Embed the batch and regroup it as (BATCH_SIZE, N_AUGMENTS + 1, embedding_dim).
m_out = m_model(sample[0])
m_out = m_out.view(BATCH_SIZE, N_AUGMENTS + 1, m_out.shape[-1])
m_out.shape

### Distance to negatives (rows of a different item)

In [None]:
# Row-wise distance between row j of item 0 and row j of item 10.
dists = F.pairwise_distance(m_out[0], m_out[10])
'max', dists.max().item(), 'min', dists.min().item()

### Distance to self (rows of the same item)

In [None]:
# Row-wise distances between item 0's rows and the same rows in a random order.
# A random permutation can leave some rows in place, so a few distances are ~0.
dists = F.pairwise_distance(m_out[0], m_out[0][torch.randperm(m_out[0].size(0))])
# Only a cell's last expression is auto-displayed, so print the summary explicitly.
print('max', dists.max().item(), 'min', dists.min().item())
dists

In [None]:
# Largest distance, over the whole batch, between an item's rows and a random
# permutation of that same item's rows.
t = -1
for i in range(BATCH_SIZE):
    # Size the permutation from m_out[i] itself (the old code referenced the
    # undefined name m_out0 and raised NameError).
    perm = torch.randperm(m_out[i].size(0))
    dists = F.pairwise_distance(m_out[i], m_out[i][perm])
    t = max(t, dists.max().item())
'max', t

In [None]:
# Largest row-wise distance between item 0 and every item of the batch
# (row j of item 0 against row j of item i; i = 0 compares item 0 with itself).
t = -1
for i in range(BATCH_SIZE):
    dists = F.pairwise_distance(m_out[0], m_out[i])
    t = max(t, dists.max().item())
'max', t

In [None]:
# Instantiate MnistDomyshnikNetNet3 directly (not through a get_* helper) and display it.
m = MnistDomyshnikNetNet3()
m

In [None]:
# Forward the current batch through `m`; note this rebinds `out` from experiment 1.
out = m(sample[0])
out.size()

## Check Metric Space

In [None]:
# Use the first GPU when available, otherwise fall back to CPU so this section
# still runs on machines without CUDA. (May shadow a DEVICE from the star imports.)
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
loader = get_cifar10_test_loader_without_augmentation(batch_size=1024)
sample = next(iter(loader))

In [None]:
# Embed every batch of the CIFAR-10 test loader with the metric-learning model
# (eval mode, no gradients) and collect embeddings and labels as numpy arrays.
model = get_cifar10_metriclearning_persample_model_cated()
model.eval()
model.to(DEVICE)

embeddings = []
labels = []

with torch.no_grad():
    # `batch` (not `data`) so the global image tensor from earlier cells is not clobbered.
    for batch in tqdm.notebook.tqdm(loader, total=len(loader)):
        x, y = batch[0].to(DEVICE), batch[1]
        out = model(x)
        # Move results to CPU immediately so device memory does not grow with the dataset;
        # no .detach() needed because no graph is built under no_grad.
        embeddings.append(out.cpu())
        labels.append(y.cpu())

embeddings = torch.cat(embeddings, dim=0).numpy()
labels = torch.cat(labels, dim=0).numpy()

print(f'embeddings shape {embeddings.shape}')
print(f'labels shape {labels.shape}')

In [None]:
# Reduce to at most 50 dimensions with PCA (it cannot keep more components than
# min(n_samples, n_features)), then to 2-D with t-SNE for plotting.
# Both steps are seeded so the scatter plots below are reproducible.
X_embedded = PCA(n_components=min(50, *embeddings.shape), random_state=0).fit_transform(embeddings)
X_embedded = TSNE(n_components=2, random_state=0).fit_transform(X_embedded)

In [None]:
# Scatter the 2-D projection of all points, coloured by class label
# (label k gets the k-th Tableau colour).
colors = list(mcolors.TABLEAU_COLORS)
cl = [colors[lbl] for lbl in labels]

plt.figure(figsize=(80, 80))
ax1 = plt.subplot()
ax1.scatter(*X_embedded[:, :2].T, s=700, c=cl)
plt.show()

In [None]:
# CIFAR-10 class names; position k is the name of label k.
clas_names = 'airplane automobile bird cat deer dog frog horse ship truck'.split()



In [None]:
# Colour legend for the scatter plots: one bar per class, drawn in that class's colour.
# `data=` was dropped: it only resolves string arguments, so passing the name list
# there did nothing. Colours are recomputed here so the cell does not rely on
# another cell's state.
colors = list(mcolors.TABLEAU_COLORS)
plt.figure(figsize=(20, 10))
plt.bar(clas_names, height=1, color=colors[:len(clas_names)])
plt.title('Class colours')
plt.show()



In [None]:
# one class plots
colors = [col for col in mcolors.TABLEAU_COLORS]
cl = [colors[lbl] for lbl in labels]
for l in range(10):
    idx = (labels == l).astype('int').nonzero()
    mas = X_embedded[idx]
    cc = np.array(cl)[idx]
    
    plt.figure(figsize=(80,80))
    ax1 = plt.subplot()
    ax1.scatter(mas[:, 0], mas[:, 1], s=700, c=cc)


    plt.title(clas_names[l], fontsize=140)
    plt.show()

In [None]:
# Scatter only the points whose label is one of `lbls`, coloured by class.
lbls = [0, 1]
colors = [col for col in mcolors.TABLEAU_COLORS]
cl = [colors[lbl] for lbl in labels]
# `labels.any(lbls)` passed lbls as the `axis` argument and raised; np.isin builds
# the intended "label is one of lbls" mask.
idx = np.isin(labels, lbls).nonzero()
mas = X_embedded[idx]
cc = np.array(cl)[idx]

plt.figure(figsize=(80, 80))
ax1 = plt.subplot()
ax1.scatter(mas[:, 0], mas[:, 1], s=700, c=cc)

# Title lists every selected class, so lbls may hold any number of labels.
plt.title('/'.join(clas_names[l] for l in lbls), fontsize=140)
plt.show()