In [None]:
import math
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import torchvision
#
import einops
from torchvision import datasets
import torchvision.transforms as T
import matplotlib.pyplot as plt
import random
from sklearn.cluster import KMeans
#


In [None]:
class Squash(nn.Module):
    """Squashing non-linearity applied to every vector along dim 2.

    For each vector v = x[b, i, :] the output is

        (1 - 1 / (exp(|v|) + eps)) * v / (|v| + eps)

    i.e. the direction of v scaled by a gain that is ~0 for |v| = 0 and
    approaches 1 as |v| grows. ``eps`` guards both divisions
    (note: the default 10e-21 equals 1e-20).
    """

    def __init__(self, eps=10e-21):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        """
         IN:  (b, n, d)
         OUT: same shape, each length-d vector squashed
        """
        length = x.norm(dim=2, keepdim=True)
        gain = 1 - 1 / (length.exp() + self.eps)
        unit = x / (length + self.eps)
        return gain * unit

In [None]:
xx = np.linspace(-0.5, 1.5, 51)

t_freq = 3
t_symm = 0.5
#
# Soft step: -> 1 for xx well below t_symm, exactly 0.5 at t_symm, -> 0 well above it;
# t_freq controls how sharp the transition is.
yy = (1 - np.tanh(t_freq * (xx - t_symm))) * 0.5
plt.plot(xx, yy)
plt.axvline(t_symm)  # centre of the step
plt.axvline(0)       # ends of the [0, 1] coordinate range the embedding uses
plt.axvline(1)
plt.title(f"tanh step, t_freq={t_freq}, t_symm={t_symm}")
plt.xlabel("normalised coordinate")
plt.ylabel("embedding value")
plt.show()

In [None]:
# NOTE(review): this cell repeats the previous one verbatim (t_freq=3, t_symm=0.5);
# consider deleting it, or changing t_freq to compare step sharpness.
xx = np.linspace(-0.5, 1.5, 51)

t_freq = 3
t_symm = 0.5
#
# Soft step: -> 1 below t_symm, 0.5 at t_symm, -> 0 above it.
yy = (1 - np.tanh(t_freq * (xx - t_symm))) * 0.5
plt.plot(xx, yy)
plt.axvline(t_symm)  # centre of the step
plt.axvline(0)       # ends of the [0, 1] coordinate range
plt.axvline(1)
plt.title(f"tanh step, t_freq={t_freq}, t_symm={t_symm}")
plt.xlabel("normalised coordinate")
plt.ylabel("embedding value")
plt.show()

# Positional embedding (4 channels)

In [None]:
h, w = 28, 28
d = 4  # number of positional channels
#
pos_w = torch.linspace(0, 1, w)  # column coordinate, 0 -> 1
pos_h = torch.linspace(0, 1, h)  # row coordinate, 0 -> 1


def ramp_channels(profile, h, w):
    """Stack four (h, w) positional maps built from a 1-D ``profile``.

    Channel 0: profile(row ramp 0 -> 1), constant along each row.
    Channel 1: profile(row ramp 1 -> 0).
    Channel 2: profile(column ramp 0 -> 1), constant along each column.
    Channel 3: profile(column ramp 1 -> 0).

    Each ramp is sampled with as many points as the axis it varies along,
    so the result has shape (4, h, w) for non-square grids too.
    """
    rows_up = profile(torch.linspace(0, 1, h)).unsqueeze(1).expand(h, w)
    rows_down = profile(torch.linspace(1, 0, h)).unsqueeze(1).expand(h, w)
    cols_up = profile(torch.linspace(0, 1, w)).unsqueeze(0).expand(h, w)
    cols_down = profile(torch.linspace(1, 0, w)).unsqueeze(0).expand(h, w)
    return torch.stack([rows_up, rows_down, cols_up, cols_down])


# linear
pe_linear = ramp_channels(lambda t: t, h, w)

#
# exponential -> not symmetric
l = -310
# NOTE(review): exp(310) exceeds the float32 range, so most entries here are inf;
# kept for reference only, not used below.
pe_exp = ramp_channels(lambda t: torch.exp(t * -l), h, w)


# tanh
t_freq = 2
t_symm = 0.5
pe_tanh = ramp_channels(lambda t: (1 - torch.tanh(t_freq * (t - t_symm))) * 0.5, h, w)

# the tanh variant is the embedding the following cells use
pe = pe_tanh

In [None]:
# display each positional channel as a grayscale image
for c in range(pe.shape[0]):
    plt.imshow(pe[c], cmap="gray")
    plt.show()

In [None]:
E = pe.movedim(0, -1).reshape(h * w, d)  # one d-dim positional vector per pixel, row-major

In [None]:
# Cosine similarity between every pair of pixel embeddings, vectorised:
# after L2-normalising each row of E, one matmul gives every cosine.
# eps matches F.cosine_similarity's default, so all-zero rows score 0.
E_unit = F.normalize(E, dim=1, eps=1e-8)
S = (E_unit @ E_unit.T).reshape(h, w, h, w)
print(S.min(), S.max(), S.shape)

### Plot all per-pixel similarity maps

In [None]:
# Grid of h x w panels: panel (row, col) shows S[row, col], the cosine similarity
# of pixel (row, col) to every pixel. squeeze=False keeps `axes` 2-D for any h, w.
# NOTE(review): h*w subplots is slow; plot_sim_tensor renders the same mosaic as one image.
fig, axes = plt.subplots(h, w, figsize=(15, 15), squeeze=False)
for row in range(h):
    for col in range(w):
        ax = axes[row, col]
        ax.imshow(S[row, col], cmap="gray")
        ax.axis('off')
plt.show()

In [None]:
torch.sum(pe, dim=0)  # per-pixel sum over the positional channels

In [None]:
def pos_tanh_embedding(h, w, t_freq=2, t_symm=0.5):
    """4-channel tanh positional embedding.

    Every channel is the soft step ``(1 - tanh(t_freq * (t - t_symm))) / 2`` of a
    linear ramp ``t``:
        0: row ramp 0 -> 1 (constant along each row)
        1: row ramp 1 -> 0
        2: column ramp 0 -> 1 (constant along each column)
        3: column ramp 1 -> 0
    Each ramp is sampled with as many points as the axis it runs along, so
    non-square grids (h != w) are supported.

    Args:
        h, w: grid height and width.
        t_freq: steepness of the tanh step.
        t_symm: ramp value at which the step crosses 0.5.
    Returns:
        torch.Tensor of shape (4, h, w).
    """
    def step(t):
        return (1 - torch.tanh(t_freq * (t - t_symm))) * 0.5

    pe = torch.zeros(4, h, w)
    pe[0] = step(torch.linspace(0, 1, h)).unsqueeze(1)  # (h, 1) broadcast across columns
    pe[1] = step(torch.linspace(1, 0, h)).unsqueeze(1)
    pe[2] = step(torch.linspace(0, 1, w)).unsqueeze(0)  # (1, w) broadcast across rows
    pe[3] = step(torch.linspace(1, 0, w)).unsqueeze(0)
    return pe

# Similarity and plotting utilities

In [None]:
def pairwise_cosin_sim(E, h=None, w=None, reshape=False):
    """
        Cosine similarity between every pair of rows of E.

        IN
            E (n, d)   tensor, one embedding per row
            h, w       grid size, used only when reshape=True (h * w must equal n)
            reshape    if True, return the similarities as (h, w, h, w)
        OUT
            S (n * n,)      flattened row-major over (i, j) when reshape=False
            S (h, w, h, w)  when reshape=True

        Rows are L2-normalised (eps=1e-8, the F.cosine_similarity default) and
        multiplied in a single matmul; all-zero rows get similarity 0. The result
        is cast to the default float dtype.
    """
    E_unit = F.normalize(E, dim=1, eps=1e-8)
    S = (E_unit @ E_unit.T).reshape(-1).to(torch.get_default_dtype())
    if reshape:
        S = S.reshape(h, w, h, w)
    return S

def sim_tensor_to_image(S):
    """Tile an (h, w, h, w) similarity tensor into one (h*h, w*w) image.

    Tile (i, j), i.e. rows i*h:(i+1)*h and columns j*w:(j+1)*w, is the map S[i, j].
    """
    h1, w1, h2, w2 = S.shape
    # (h1, w1, h2, w2) -> (h1, h2, w1, w2): rows group (i, a), columns group (j, b)
    return S.permute(0, 2, 1, 3).reshape(h1 * h2, w1 * w2)


def plot_sim_tensor(S, figsize=(10, 10)):
    """
        IN
            S (h,w,h,w)
        Shows S as a single grayscale mosaic of per-pixel similarity maps
        (layout given by sim_tensor_to_image).
    """
    plt.figure(figsize=figsize)
    plt.imshow(sim_tensor_to_image(S), cmap="gray")
    plt.axis("off")
    plt.show()

### Embeddings on MNIST

In [None]:
# MNIST splits; images are converted to tensors by T.ToTensor().
MNIST_ROOT = '/mnt/data/pytorch'  # local dataset cache
ds_train = datasets.MNIST(MNIST_ROOT, train=True, transform=T.ToTensor(), download=True)
ds_test = datasets.MNIST(MNIST_ROOT, train=False, transform=T.ToTensor())

In [None]:
sample_idx = 1
x, _ = ds_train[sample_idx]
img = x[0, 20:27, 2:9]  # channel 0, 7x7 crop (rows 20-26, cols 2-8)
print(img.shape)
plt.imshow(img, cmap="gray", vmin=0, vmax=1)

In [None]:
sample_idx = 2
x, _ = ds_train[sample_idx]
img = x[0, :, :]  # drop the channel axis for imshow
print(img.shape)
plt.imshow(img, cmap="gray", vmin=0, vmax=1)

In [None]:
# Two ways to attach pixel intensity to the positional channels, flattened to
# one row per pixel. Names are kept separate so neither result is silently lost.
n_pix = x.shape[-2] * x.shape[-1]
# additive: the single image channel is broadcast onto every positional channel
E_add = (pe + x).permute(1, 2, 0).reshape(n_pix, -1)
print(E_add.shape)
#
# concatenated: positional channels followed by the intensity channel
E_cat = torch.cat([pe, x], dim=0).permute(1, 2, 0).reshape(n_pix, -1)
print(E_cat.shape)
# the concatenated variant is the one used below
E = E_cat

In [None]:
S = pairwise_cosin_sim(E, h=28, w=28, reshape=True)  # (28, 28, 28, 28)
plot_sim_tensor(S)

In [None]:
# One round of unscaled, single-head self-attention over the pixel embeddings
# (queries = keys = values = E). Softmax over the last dim makes each row of the
# attention matrix sum to 1, so every R[i] is a convex combination of the rows of E.
# NOTE(review): this previously normalised over dim=0 (columns), so the weights
# applied to E did not sum to 1 per output pixel; revert if that was intended.
attn = torch.softmax(E @ E.T, dim=-1)
R = attn @ E
S = pairwise_cosin_sim(R, 28, 28, True)
plot_sim_tensor(S)


In [None]:
# each feature channel of R as a 28x28 map, printed with its value range
channel_maps = R.reshape(28, 28, -1).permute(2, 0, 1)
for ch_map in channel_maps:
    plt.imshow(ch_map, cmap="gray")
    print(ch_map.min(), ch_map.max())
    plt.show()

### cluster pixels

In [None]:
n_clusters = 12
#


def kmeans_label_map(features, n_clusters, shape=(28, 28)):
    """Fit KMeans on per-pixel features; return (label map / n_clusters, centres).

    The label map is reshaped to ``shape`` and divided by ``n_clusters`` so its
    values lie in [0, 1) for display.
    """
    model = KMeans(n_clusters=n_clusters, random_state=0).fit(features)
    labels = model.predict(features)
    return labels.reshape(shape) / n_clusters, model.cluster_centers_


Y, ce = kmeans_label_map(E, n_clusters)
plt.imshow(Y)
plt.title(f"KMeans ({n_clusters}) on E")
plt.show()
#
Y, cr = kmeans_label_map(R, n_clusters)
plt.imshow(Y)
plt.title(f"KMeans ({n_clusters}) on R")
plt.show()

In [None]:
n_min = 4
n_max = 16

clusters = list(range(n_min, n_max + 1, 2))

# squeeze=False keeps `axes` 2-D, so indexing works even for a single cluster count
fig, axes = plt.subplots(1, len(clusters), figsize=(len(clusters) * 4, 4), squeeze=False)

for ax, n_clusters in zip(axes[0], clusters):
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(E)
    Y = kmeans.predict(E).reshape(28, 28) / n_clusters
    # NOTE: after the loop `ce` holds the centres of the LAST cluster count;
    # the next cell builds C from it.
    ce = kmeans.cluster_centers_
    ax.imshow(Y)
    ax.set_title(str(n_clusters))
plt.show()

In [None]:
k = 5
# Candidate centres: 10 uniform-random points in [0, 1) plus the KMeans centres `ce`.
# The random block takes its width from `ce`, so the feature dimensions always match.
# NOTE(review): torch.rand is unseeded here, so results vary between runs.
C = torch.cat([torch.rand(10, ce.shape[1]), torch.Tensor(ce)], dim=0)
C.shape, E.shape

In [None]:
temp = 0.01
# temperature-scaled dot products of every pixel embedding with every centre
S = torch.matmul(E, C.T) / temp
# each row of A is a distribution over the centres
A = S.softmax(dim=1)

In [None]:
# scale each centre's column by its average assignment mass times k
centre_weight = A.sum(dim=0) / E.shape[0] * k
AT = A * centre_weight  # broadcast over rows: AT[n, j] = A[n, j] * centre_weight[j]
#
AT = (AT / temp).softmax(dim=1)

In [None]:
Y = torch.argmax(AT, dim=1)  # index of the highest-weight centre per pixel
set(Y.numpy())  # centres that won at least one pixel

In [None]:
plt.imshow(Y.reshape(28, -1))  # back to the 28x28 image grid

### Embeddings on CIFAR-10

In [None]:
ds = torchvision.datasets.CIFAR10(
    root='/mnt/data/pytorch',
    train=True,
    transform=T.ToTensor(),
    download=True,
)

In [None]:
x, _ = ds[11]
# ToTensor yields channels first; imshow expects channels last
plt.imshow(x.movedim(0, -1))
x.shape

### Embedding with position: position × colour-channel cross product

In [None]:
pe = pos_tanh_embedding(h=32, w=32)
print(pe.shape, x.shape)

In [None]:
# Position x colour cross product: channel i * x.shape[0] + l of E is pe[i] * x[l];
# sizes come from the inputs instead of a hard-coded 3*4 x 32 x 32.
E = (pe.unsqueeze(1) * x.unsqueeze(0)).flatten(0, 1)

In [None]:
# Pairwise cosine similarity of all pixel embeddings via the shared helper
# (replaces a copy of its pairwise loop). E stays (C, H, W), so the cell can be re-run.
n_ch, H, W = E.shape
E_flat = E.permute(1, 2, 0).reshape(H * W, n_ch)
S = pairwise_cosin_sim(E_flat, H, W, reshape=True)
print(S.min(), S.max(), S.shape)

In [None]:
# mosaic of per-pixel similarity maps (reuses the shared plotting helper)
plot_sim_tensor(S, figsize=(10, 10))

### Embedding with position: colour channels concatenated with positional channels

In [None]:
pe = pos_tanh_embedding(32, 32)
# colour channels first, then the 4 positional channels
E = torch.cat([x, pe], dim=0)
E.shape

In [None]:
# Pairwise cosine similarity of all pixel embeddings via the shared helper
# (replaces a copy of its pairwise loop). E stays (C, H, W), so the cell can be re-run.
n_ch, H, W = E.shape
E_flat = E.permute(1, 2, 0).reshape(H * W, n_ch)
S = pairwise_cosin_sim(E_flat, H, W, reshape=True)
print(S.min(), S.max(), S.shape)

In [None]:
# mosaic of per-pixel similarity maps (reuses the shared plotting helper)
plot_sim_tensor(S, figsize=(10, 10))

### Embedding without position (colour channels only)

In [None]:
# Baseline: colour channels only, no positional channels
# (the unused positional embedding and commented-out concat were dropped).
E = x
E.shape

In [None]:
# Pairwise cosine similarity of all pixel embeddings via the shared helper
# (replaces a copy of its pairwise loop). E stays (C, H, W), so the cell can be re-run.
n_ch, H, W = E.shape
E_flat = E.permute(1, 2, 0).reshape(H * W, n_ch)
S = pairwise_cosin_sim(E_flat, H, W, reshape=True)
print(S.min(), S.max(), S.shape)

In [None]:
# mosaic of per-pixel similarity maps (reuses the shared plotting helper)
plot_sim_tensor(S, figsize=(10, 10))