In [None]:
# Emerging Properties in Self-Supervised Vision Transformers (DINO): eine spezielle Self-Supervised-Methode,
# die beim Vision Transformer zu emergenten Effekten, nämlich Segmentierungs-Attention-Maps, führen soll
#
# Version 2: Teacher hat global view, es gibt mehrere heads, sinkhorn statt centering

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
import pickle as pkl

%load_ext autoreload
%autoreload 2
from utils import *
from layers import *
from transformer import *
from cifar10 import *

In [None]:
# Preview: labeled image 1 passed through the project's resize helper (default size).
showimg(resize(getimg(1)))

In [None]:
# Inspect one labeled image and its class index.
i = 30
showimg(getimg(i))
getlabel(i)

# Class index -> name (CIFAR-10 order):
# 0: airplane
# 1: automobile
# 2: bird
# 3: cat
# 4: deer
# 5: dog
# 6: frog
# 7: horse
# 8: ship
# 9: truck

In [None]:
# Preview sample 1 returned by getunsupervised (the pool ubatchgen draws from).
showimg(getunsupervised(1))

In [None]:


def batchgen(bsize=32, start=500):
    """Yield labeled minibatches forever, reshuffling at the start of each epoch.

    Indices are drawn from range(start, N//5). Only full minibatches are
    emitted, so the last len % bsize shuffled indices of an epoch are skipped.

    Args:
        bsize: number of samples per minibatch.
        start: first labeled index used for training (indices below it are
            left for validation/visualisation).

    Yields:
        (images, labels): np.array of getimg(j) results and np.array of
        getlabel(j) results for the bsize indices of the batch.
    """
    epoch = 0
    while True:
        order = np.random.permutation(range(start, N//5))
        n_batches = len(order) // bsize
        for b in range(n_batches):
            batch = order[b*bsize:(b+1)*bsize]
            # Load image and label per index, in the same interleaved order.
            pairs = [(getimg(j), getlabel(j)) for j in batch]
            images = np.array([img for img, _ in pairs])
            labels = np.array([lab for _, lab in pairs])
            yield images, labels
        print(f'========== EPOCH {epoch} COMPLETED ==========')
        epoch += 1


def ubatchgen(bsize=32, start=0):
    """Yield pairs of 24x24 views of unlabeled images forever.

    For each image, the first view is randomcrop(..., s=24) of the 32x32 image
    (the student input in ssloss). The second is resize(..., s=24) of the
    whole image (the teacher's global view). Only full minibatches are emitted.

    Args:
        bsize: number of images per minibatch.
        start: first index into the getunsupervised pool.

    Yields:
        (crops, views): float arrays of shape (bsize, 24, 24, 3).
    """
    epoch = 0
    while True:
        order = np.random.permutation(range(start, N_UNSUPERVISED))
        for b in range(len(order) // bsize):
            crops = np.zeros((bsize, 24, 24, 3))
            views = np.zeros((bsize, 24, 24, 3))
            for slot, j in enumerate(order[b*bsize:(b+1)*bsize]):
                # Two separate loads, so randomcrop and resize each get their own array.
                first = getunsupervised(j).reshape((32, 32, -1))
                second = getunsupervised(j).reshape((32, 32, -1))
                crops[slot] = randomcrop(first, s=24)
                views[slot] = resize(second, s=24)
            yield crops, views
        print(f'========== UNSUPERVISED EPOCH {epoch} COMPLETED ==========')
        epoch += 1

In [None]:
# Smoke-test both generators: one labeled batch and one pair of unlabeled views.
bg = batchgen()
xs, ys = next(bg)
xs.shape  # not echoed: only a cell's last expression is displayed

ubg = ubatchgen()
xs1, xs2 = next(ubg)
xs1.shape  # likewise not echoed
showimg(xs2[0])  # first resized (global) view of the batch

In [None]:
from layers import *


class Net(nn.Module):
    """Pixel-token set transformer used both as the student (`net`) and the teacher.

    Every pixel becomes a token: its 3 channels are embedded by dense1, a 2-D
    positional encoding is applied, and the (h, w) grid is flattened into a
    sequence. Three ISAB2 blocks update the pixel tokens `x` together with the
    tokens `y` produced by Seed. After LayerNorm, token 0 of `y` is the image
    representation.

    Args:
        n: embedding width.
        nh: passed to each ISAB2 block (presumably the number of attention heads).
        M: passed to Seed (presumably the number of seed tokens).

    Attributes set by forward:
        prelast: (batch, n) representation taken from seed token 0. It is read
            by the KoLeo regularizer and by the supervised `dense2` head.
    """
    
    def __init__(self, n, nh, M):
        super().__init__()
        self.dense1 = nn.Linear(3, n)  # per-pixel RGB -> n features
        self.posenc2d = PositionalEncoding2d(n)
        self.ln1 = LayerNorm(n)
        self.seed = Seed(n, M)
        self.isab1 = ISAB2(n, n, nh)
        self.isab2 = ISAB2(n, n, nh)
        self.isab3 = ISAB2(n, n, nh)
        self.ln2 = LayerNorm(n)
        # Supervised 10-class head; not used in forward, applied to self.prelast externally.
        self.dense2 = nn.Linear(n, 10)
        # Ten self-supervised heads with 100 outputs each.
        self.heads = nn.ModuleList([ nn.Linear(n, 100) for _ in range(10) ])
        self.cuda()  # moves all parameters to the GPU at construction time
    
    def forward(self, x):
        """Return a list of 10 head outputs, each of shape (batch, 100).

        x: image batch laid out as 'b h w c' with c == 3 (see dense1 and the
        rearrange below). Side effect: stores the pooled representation in
        self.prelast.
        """
        x = self.dense1(x)
        # NOTE(review): x is still (b, h, w, n) here, so x.shape[2] is the image
        # width, not the embedding width n. Confirm sqrt(width) scaling is
        # intended; a 1/sqrt(n) scale would use x.shape[-1].
        x = x / np.sqrt(x.shape[2])
        x = self.posenc2d(x)
        x = rearrange(x, 'b h w c -> b (h w) c')  # flatten pixels into a token sequence
        x = self.ln1(x)
        y = self.seed(x)
        x, y = self.isab1(x, y)
        x, y = self.isab2(x, y)
        x, y = self.isab3(x, y)
        y = self.ln2(y)
        y = y[:,0,:]  # keep only seed token 0
        self.prelast = y
        ys = [ h(y) for h in self.heads ]
        return ys
        

In [None]:
# Build the student network, its optimizer and the logging lists.
net = Net(128, 4, 10)  # n=128, nh=4, M=10
from torch_optimizer import Lookahead, Yogi
net.optim = Lookahead(Yogi(net.parameters(), lr=3e-3, weight_decay=0.0))
net.iters = 0      # global step counter, incremented by both training loops
net.losses1 = []   # (iter, mean self-distillation loss) pairs
net.losses2 = []   # (iter, mean KoLeo loss) pairs
net.losses = []    # (iter, mean supervised training loss) pairs
net.vlosses = []   # (iter, validation loss) pairs
net.vaccs = []     # (iter, validation accuracy) pairs
bg = batchgen()

# Optional: resume from an earlier checkpoint.
#net.load_state_dict(torch.load('vt_emerging_properties2_211223.dat'))

# Teacher: same architecture, initialised as an exact copy of the student.
teacher = Net(128, 4, 10)
teacher.load_state_dict(net.state_dict());

In [None]:
    
def ssloss():
    """Compute the self-supervised losses for one unlabeled batch.

    The student (`net`) sees the random crops x1 from `ubg`. The teacher sees
    the resized global views x2, under no_grad. Each teacher head's logits are
    turned into soft targets with sinkhorn.

    Returns:
        (loss1, loss2):
            loss1: sum over heads of mean(-Q * log p_student). The mean runs
                over all batch x output entries (per-sample cross-entropy / K);
                this scale is kept so the KoLeo weight stays the same.
            loss2: keleoRegularizer on the student's pooled features.
    """
    net.train()
    x1, x2 = next(ubg)
    x1, x2 = np2t(x1, x2)
    ss = net(x1)
    tau = 1
    # log_softmax is the numerically stable form of log(softmax(.)); unlike
    # log(p + 1e-12) it does not clamp the loss/gradient for tiny probabilities.
    log_ps = [ F.log_softmax(s / tau, dim=1) for s in ss ]
    with torch.no_grad():
        ts = teacher(x2)
        Qs = [ sinkhorn(t) for t in ts ]
    loss1 = sum(torch.mean(-log_p * Q) for log_p, Q in zip(log_ps, Qs))
    loss2 = keleoRegularizer(net.prelast)
    return loss1, loss2
    

def sinkhorn(t, eps=0.1, n_iters=3):
    """Sinkhorn-Knopp normalisation of teacher logits into soft targets.

    Args:
        t: 2-D tensor of logits, presumably (batch, outputs); each teacher head
           in Net.forward yields (batch, 100).
        eps: temperature applied before exponentiation.
        n_iters: number of alternating column/row normalisation rounds
           (the default of 3 matches the previous fixed count).

    Returns:
        Tensor of the same shape as t with non-negative entries. When
        n_iters >= 1, each row sums to 1 (up to the 1e-12 guard); column sums
        are only approximately balanced after a finite number of rounds.
    """
    if t.numel() == 0:
        return torch.empty_like(t)
    # Subtract each column's max before exp to avoid overflow. exp(t/eps) hits
    # inf for large logits; the column sums then overflow and Q collapses to 0.
    # The shift is exact because the first step divides every column by its
    # own sum, which cancels a per-column constant factor.
    Q = torch.exp((t - t.max(dim=0, keepdim=True).values) / eps)
    Q = torch.nan_to_num(Q)
    for _ in range(n_iters):
        Q = Q / (Q.sum(0, keepdim=True) + 1e-12)
        Q = Q / (Q.sum(1, keepdim=True) + 1e-12)
    return Q

def keleoRegularizer(y):
    """KoLeo-style spreading penalty on a batch of feature vectors.

    Rows of y are L2-normalised (norm + 1e-12). For each row, the squared
    distance to its nearest other row is found. The result is the mean of
    -log of that distance, divided by 100.

    Args:
        y: 2-D tensor, one feature vector per row.

    Returns:
        Scalar tensor (differentiable w.r.t. y).
    """
    norms = y.pow(2).sum(dim=1, keepdim=True).sqrt()
    unit = y / (norms + 1e-12)
    diff = unit.unsqueeze(1) - unit.unsqueeze(0)
    sq_dist = diff.pow(2).sum(dim=-1)
    # Self-distances are exactly 0. Lift them far above the largest possible
    # squared distance between (sub-)unit vectors (4) so the row-min skips them.
    n_rows = unit.shape[0]
    sq_dist = sq_dist + 10000 * torch.eye(n_rows).to(unit.device)
    nearest = sq_dist.min(dim=1).values
    return (-torch.log(nearest + 1e-12)).mean() / 100

# Smoke test: evaluate (loss1, loss2) on one unlabeled batch.
ssloss()

In [None]:
# Self-supervised training: the student `net` matches Sinkhorn-normalised
# teacher targets (plus the KoLeo term); the teacher then follows the student
# via update_mt (presumably an EMA / mean-teacher update; TODO confirm in utils).
losses1 = []
losses2 = []
ubg = ubatchgen()

for k in trange(999999):
    net.train()
    l1, l2 = ssloss()
    (l1+l2).backward()
    losses1.append(l1.item())
    losses2.append(l2.item())
    net.optim.step()
    net.zero_grad()
    # NOTE(review): a previous note ("## 99") hints that tau=0.99 was also tried.
    update_mt(teacher, net, tau=0.9)

    # Every 50 steps: record the mean of both losses at the current global step.
    if len(losses1) == 50:
        net.losses1.append((net.iters, np.mean(losses1)))
        net.losses2.append((net.iters, np.mean(losses2)))
        losses1 = []
        losses2 = []

    # Every 200 steps: plot the loss curves and visualise attention on a random
    # labeled image resized to the 24x24 student input size.
    if k % 200 == 0:
        plt.plot(*zip(*net.losses1))
        plt.plot(*zip(*net.losses2))
        plt.grid()
        plt.show()

        i = np.random.randint(1000)
        x = resize(getimg(i), s=24)
        xs = np2t([x])
        # Visualisation only: no autograd graph needed.
        with torch.no_grad():
            yp = net(xs)
        # Attention of the last ISAB2 block, index [0, 0, :3] (presumably head 0,
        # seeds 0-2), shown as the R/G/B channels of one 24x24 image.
        beta = t2np(net.isab3.mab1.mha.beta)[0,0,:3,:].reshape(3,24,24).transpose(1, 2, 0)
        beta /= beta.max()
        showimg(beta)
        plt.show()
        showimg(x)
        plt.show()
        # Output distribution of head 0 for this image (1-D, so dim=-1 == dim=0).
        plt.plot(t2np(F.softmax(yp[0][0], dim=-1)))
        plt.show()
    # Observation: after 5000 iterations the loss fell by ~10% of the plateau (0.23 -> 0.21).
    net.iters += 1

In [None]:
# Checkpoint the student weights after self-supervised training.
torch.save(net.state_dict(), 'vt_emerging_properties2_221223.dat')

In [None]:
# Visualise last-block attention for one labeled image (1024 tokens -> 32x32 maps).
i = 1
net.eval()
x = getimg(i)
xs = np2t([x])
yp = net(xs)

showimg(x)
plt.show()

# One figure per (i, j) index pair, presumably (seed, head) given M=10 and nh=4.
# NOTE: the loop variable reuses `i`, overwriting the image index above.
for i in range(10):
    for j in range(4):
        beta = t2np(net.isab3.mab1.mha.beta)[0,j,i,:].reshape(32,32)
        #showimg(x)
        # alpha is 0 at the highest-attention pixel and .95 where attention is 0.
        plt.imshow(0*beta, alpha=.95-beta/beta.max()*.95, extent=(0, 1, 0, 1), cmap='gray')
        plt.show()

In [None]:
# Raw values of the last attention map computed in the previous cell.
beta

In [None]:

# Attention maps on an external example image.
# NOTE(review): this loads '..._111223.dat' (no "2" in the name, unlike the
# checkpoint saved by this notebook) and overwrites the current weights;
# confirm it is the intended model and that its keys match this Net.
net.load_state_dict(torch.load('vt_emerging_properties_111223.dat'))

img = plt.imread('examples/dog.jpg')[:,:,:3]  # keep the first 3 channels only
img_big = plt.imread('examples/dog hr.jpg')    # higher-resolution copy, display only
# NOTE(review): plt.imread typically returns uint8 (0-255) for JPEGs; confirm
# this matches the value range getimg() feeds the network.
xs = np2t([img])
net(xs)
# Index [0, 0:3, 0] (presumably heads 0-2, seed 0) shown as one RGB 32x32 image.
beta = t2np(net.isab3.mab1.mha.beta)[0,0:3,0,:].reshape(3,32,32).transpose(1, 2, 0)
beta /= beta.max()
showimg(img_big)
plt.show()
showimg(beta)

In [None]:
def valloss(n_val=200):
    """Evaluate the supervised `dense2` head on the first `n_val` labeled images.

    These indices lie below batchgen's default start (500), so they are disjoint
    from the default training batches when n_val <= 500. Leaves `net` in eval
    mode; the training loop calls net.train() again.

    Args:
        n_val: number of validation images (default 200, as before).

    Returns:
        (loss, acc): mean cross-entropy as a Python float, and accuracy.
    """
    net.eval()
    xs = np.array([getimg(i) for i in range(n_val)])
    yt = np.array([getlabel(i) for i in range(n_val)])
    xs, yt = np2t(xs, yt)
    yt = yt.long()
    # Evaluation only: avoid building an autograd graph for the whole batch.
    with torch.no_grad():
        net(xs)
        yp = net.dense2(net.prelast)
        loss = F.cross_entropy(yp, yt).item()
    yp, yt = t2np(yp, yt)
    acc = np.mean(yp.argmax(-1) == yt)
    return loss, acc
    
def loss():
    """Supervised loss for one labeled batch from `bg`.

    Runs the network to refresh net.prelast, applies the `dense2` head and
    returns the mean cross-entropy as a scalar tensor (with grad).
    """
    net.train()
    images, labels = np2t(*next(bg))
    labels = labels.long()
    net(images)
    logits = net.dense2(net.prelast)
    return F.cross_entropy(logits, labels)

In [None]:
# Fresh optimizer for the supervised phase (discards the previous optimizer state).
net.optim = Lookahead(Yogi(net.parameters(), lr=3e-3, weight_decay=0.0))

In [None]:
# Restore the self-supervised checkpoint saved by this notebook before supervised training.
net.load_state_dict(torch.load('vt_emerging_properties2_221223.dat'))

In [None]:
# Supervised phase: all network parameters are in net.optim, so the whole model
# is fine-tuned through the dense2 head. Every 50 steps, log the mean training
# loss plus validation loss and accuracy.
losses = []

for k in trange(999999):
    net.train()
    l = loss()
    l.backward()
    losses.append(l.item())
    net.optim.step()
    net.zero_grad()

    if len(losses) == 50:
        vl, vacc = valloss()
        net.vlosses.append((net.iters, vl))
        net.vaccs.append((net.iters, vacc))
        net.losses.append((net.iters, np.mean(losses)))
        losses = []

    if k % 50 == 0:
        plt.plot(*zip(*net.losses))   # training loss
        plt.plot(*zip(*net.vlosses))  # validation loss
        plt.plot(*zip(*net.vaccs))    # validation accuracy
        plt.grid()
        plt.show()

    net.iters += 1

In [None]:
# Best (lowest) validation loss and best (highest) validation accuracy logged so far.
np.array(net.vlosses)[:,1].min(), np.array(net.vaccs)[:,1].max()

In [None]:
# Exclusive upper bound of the labeled index range used by batchgen.
N//5

In [None]:
# NOTE(review): reloads '..._111223.dat' (no "2" in the name, unlike the
# checkpoint saved by this notebook); confirm this is the model intended for t-SNE.
net.load_state_dict(torch.load('vt_emerging_properties_111223.dat'))

In [None]:
from sklearn.manifold import TSNE

# Embed the pooled features (net.prelast) of the first 200 labeled images in
# 2-D with t-SNE; yt keeps their labels for colouring the scatter plot.
xs = np.array([getimg(i) for i in range(200)])
yt = np.array([getlabel(i) for i in range(200)])
xs = np2t(xs)
# Feature extraction only: inference mode and no autograd graph, independent
# of whichever mode the previous cell left the network in.
net.eval()
with torch.no_grad():
    yp = net(xs)

X = t2np(net.prelast)
# Fixed seed so the embedding is reproducible across runs.
X_embedded = TSNE(n_components=2, random_state=0).fit_transform(X)

In [None]:
# Scatter the 2-D t-SNE embedding, one colour (matplotlib cycle) per class label.
for i in range(10):
    mask = yt==i
    plt.scatter(X_embedded[mask][:,0], X_embedded[mask][:,1], s=4)