In [None]:
# Emerging Properties in Self-Supervised Vision Transformers (DINO): eine spezielle self-supervised Methode,
# die beim Vision Transformer zu emergenten Effekten führen soll, nämlich zu Segmentierungs-Attention-Maps.

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
import pickle as pkl

%load_ext autoreload
%autoreload 2
from utils import *
from layers import *
from transformer import *
from cifar10 import *

In [None]:
# Display one labelled image; the cell output is its integer class label.
i = 30
showimg(getimg(i))
getlabel(i)

# CIFAR-10 class indices:
# 0: airplane
# 1: automobile
# 2: bird
# 3: cat
# 4: deer
# 5: dog
# 6: frog
# 7: horse
# 8: ship
# 9: truck

In [None]:
# Display one image from the pool that ubatchgen() draws its unlabelled crops from.
showimg(getunsupervised(1))

In [None]:


def batchgen(bsize=32, start=500, stop=None):
    """Yield shuffled labelled minibatches ``(xs, ys)`` forever.

    Each epoch draws a fresh permutation of the indices in ``[start, stop)`` and
    yields every full minibatch; a trailing partial minibatch is dropped. A
    message is printed after each completed epoch.

    Args:
        bsize: minibatch size.
        start: first index used. Lower indices stay out of training; valloss()
            and the t-SNE cell in this notebook read indices 0..199.
        stop: one past the last index used. Defaults to ``N // 5``, the value
            that used to be hard-coded here (presumably the first fifth of the
            labelled set).

    Yields:
        ``(xs, ys)``: numpy arrays stacked from ``getimg(i)`` / ``getlabel(i)``.

    Raises:
        ValueError: on the first ``next()``, if ``[start, stop)`` holds fewer
            than ``bsize`` indices. Previously the generator spun forever
            without yielding in that case.
    """
    if stop is None:
        stop = N // 5
    if bsize < 1 or stop - start < bsize:
        raise ValueError(
            f'batchgen: need at least bsize={bsize} indices in [{start}, {stop})')
    n_batches = (stop - start) // bsize
    ep = 0
    while True:
        inds = np.random.permutation(range(start, stop))
        for k in range(n_batches):
            mb = inds[k*bsize:(k+1)*bsize]
            xs = np.array([getimg(i) for i in mb])
            ys = np.array([getlabel(i) for i in mb])
            yield xs, ys
        print(f'========== EPOCH {ep} COMPLETED ==========')
        ep += 1


def ubatchgen(bsize=32, start=0, crop=24):
    """Yield pairs of randomly cropped views of unlabelled images forever.

    For every index in a minibatch the image is fetched twice via
    ``getunsupervised`` (kept as two calls, in case that helper is not pure).
    Each copy is cropped independently with ``randomcrop``. The two views of
    sample ``i`` land at row ``i`` of ``xs1`` and ``xs2`` respectively.

    Args:
        bsize: minibatch size; a trailing partial minibatch is dropped.
        start: first unlabelled index used (up to ``N_UNSUPERVISED``).
        crop: side length of the square crops (previously hard-coded 24).

    Yields:
        ``(xs1, xs2)``: float arrays of shape ``(bsize, crop, crop, 3)``.

    Raises:
        ValueError: on the first ``next()``, if fewer than ``bsize`` indices
            are available. Previously the generator spun forever without
            yielding in that case.
    """
    if bsize < 1 or N_UNSUPERVISED - start < bsize:
        raise ValueError(
            f'ubatchgen: need at least bsize={bsize} indices in [{start}, {N_UNSUPERVISED})')
    n_batches = (N_UNSUPERVISED - start) // bsize
    ep = 0
    while True:
        inds = np.random.permutation(range(start, N_UNSUPERVISED))
        for k in range(n_batches):
            mb = inds[k*bsize:(k+1)*bsize]
            xs1 = np.zeros((bsize, crop, crop, 3))
            xs2 = np.zeros((bsize, crop, crop, 3))
            for i, j in enumerate(mb):
                x1 = getunsupervised(j).reshape((32, 32, -1))
                x2 = getunsupervised(j).reshape((32, 32, -1))
                xs1[i] = randomcrop(x1, s=crop)
                xs2[i] = randomcrop(x2, s=crop)
            yield xs1, xs2
        print(f'========== UNSUPERVISED EPOCH {ep} COMPLETED ==========')
        ep += 1

In [None]:
# Sanity-check both generators: one labelled minibatch and one pair of cropped unlabelled minibatches.
bg = batchgen()
xs, ys = next(bg)
xs.shape  # NOTE: not the cell's last expression, so Jupyter does not display it

ubg = ubatchgen()
xs1, xs2 = next(ubg)
xs1.shape  # NOTE: not displayed either; only a cell's last expression is shown
showimg(xs1[0])

In [None]:
from layers import *



class Center(nn.Module):
    """Running-mean centering of outputs (applied to the teacher in ssloss).

    Keeps a ``(1, n)`` center ``c``, an exponential moving average of the
    *batch mean* of the inputs. ``forward`` subtracts the current center from
    every row, then folds this batch's mean into the center (the DINO order).

    Args:
        n: feature width of the inputs.
        beta: momentum of the moving average (weight kept on the old center).
    """

    def __init__(self, n, beta=0.9):
        super().__init__()
        # Non-persistent buffer: it follows .cuda()/.to() with the owning module
        # (instead of being hard-wired to CUDA), but it stays out of state_dict(),
        # so existing checkpoints still load with strict=True.
        self.register_buffer('c', torch.zeros(1, n), persistent=False)
        self.beta = beta

    def forward(self, x):
        """Return ``x - c`` for ``x`` of shape (batch, n).

        Afterwards ``c <- beta * c + (1 - beta) * mean(x, dim=0)``.
        """
        out = x - self.c
        with torch.no_grad():
            # Average over the batch. The previous version mixed the whole (batch, n)
            # tensor into c, turning it into per-sample state whose shape
            # depended on (and broke with) the batch size.
            self.c.mul_(self.beta).add_(x.mean(dim=0, keepdim=True), alpha=1 - self.beta)
        return out


class Net(nn.Module):
    """Set-transformer-style image encoder; the notebook uses it for both student and teacher.

    Every pixel becomes a token. Three project-defined ISAB2 blocks update the
    pixel tokens `x` together with learned seed tokens `y` created by `Seed`.
    The first seed token, layer-normalised, is projected to 100 outputs.

    Args:
        n: token/embedding width.
        nh: passed to every ISAB2 block, presumably the number of attention heads.
        M: passed to Seed, presumably the number of seed tokens.

    The module moves itself to the GPU at the end of __init__ (requires CUDA).
    """
    
    def __init__(self, n, nh, M):
        super().__init__()
        self.dense1 = nn.Linear(3, n)  # per-pixel RGB -> n-dim token
        self.posenc2d = PositionalEncoding2d(n)
        self.ln1 = LayerNorm(n)
        self.seed = Seed(n, M)
        self.isab1 = ISAB2(n, n, nh)
        self.isab2 = ISAB2(n, n, nh)
        self.isab3 = ISAB2(n, n, nh)
        #self.isab4 = ISAB2(n, n, nh)
        #self.isab5 = ISAB2(n, n, nh)
        #self.isab6 = ISAB2(n, n, nh)
        self.ln2 = LayerNorm(n)
        # 100 outputs: class logits in the supervised loss(), the distribution compared by H in ssloss().
        self.dense2 = nn.Linear(n, 100)
        self.center = Center(100)  # only applied when forward(..., center=True)
        self.cuda()
    
    def forward(self, x, center=False):
        """Encode a channels-last image batch x of shape (b, h, w, 3) into (b, 100) outputs.

        If `center` is True, the outputs also pass through self.center (the
        teacher path in ssloss). Side effect: stores y[:, 0, :], the first seed
        token after ln2 and before dense2, in self.prelast. The t-SNE cell reads it.
        """
        x = self.dense1(x)
        # NOTE(review): x is (b, h, w, n) here, so x.shape[2] is the image width w,
        # not the embedding width n. If 1/sqrt(n) scaling was intended, this should
        # be x.shape[-1]. Left unchanged because trained checkpoints depend on it.
        x = x / np.sqrt(x.shape[2])
        x = self.posenc2d(x)
        x = rearrange(x, 'b h w c -> b (h w) c')  # flatten pixels into a token sequence
        x = self.ln1(x)
        y = self.seed(x)
        x, y = self.isab1(x, y)
        x, y = self.isab2(x, y)
        x, y = self.isab3(x, y)
        #x, y = self.isab4(x, y)
        #x, y = self.isab5(x, y)
        #x, y = self.isab6(x, y)
        y = self.ln2(y)
        y = y[:,0,:]  # keep only the first seed token
        self.prelast = y
        y = self.dense2(y)
        if center: y = self.center(y)
        return y

In [None]:
# Student network (trained by gradient descent) and its optimizer.
net = Net(128, 4, 10)
from torch_optimizer import Lookahead, Yogi
net.optim = Lookahead(Yogi(net.parameters(), lr=3e-3))
# Training bookkeeping lives on the module: an iteration counter and (iteration, value) histories.
net.iters = 0
net.losses = []
net.vlosses = []
net.slosses = []
net.vaccs = []
bg = batchgen()

# Resume from the saved student checkpoint (fails if the file does not exist; drop this line to start fresh).
net.load_state_dict(torch.load('vt_emerging_properties_111223.dat'))

# Teacher: same architecture, initialised as a copy of the student; afterwards it only changes via update_mt.
teacher = Net(128, 4, 10)
teacher.load_state_dict(net.state_dict());

In [None]:
def valloss(n=200):
    """Evaluate `net` on the first `n` labelled images and return (loss, accuracy).

    batchgen() does not draw these indices (its default start is 500), so they
    act as a validation set. The forward pass runs under torch.no_grad(), so no
    autograd graph is kept for the whole validation batch. The net's previous
    train/eval mode is restored afterwards, even if an exception is raised.

    Args:
        n: number of leading labelled images to evaluate (default 200).

    Returns:
        (mean cross-entropy as a Python float, accuracy as a numpy float in [0, 1]).
    """
    was_training = net.training
    net.eval()
    try:
        xs = np.array([getimg(i) for i in range(n)])
        yt = np.array([getlabel(i) for i in range(n)])
        xs, yt = np2t(xs, yt)
        yt = yt.long()
        with torch.no_grad():
            yp = net(xs)
            # Same value as nll_loss(log_softmax(...)), in one call.
            loss = F.cross_entropy(yp, yt).item()
        yp, yt = t2np(yp, yt)
        acc = np.mean(yp.argmax(-1) == yt)
    finally:
        net.train(was_training)
    return loss, acc
    
def loss():
    """Supervised cross-entropy on the next labelled minibatch from the global `bg`.

    Switches `net` to training mode and returns the (differentiable) loss tensor.
    """
    net.train()
    images, labels = np2t(*next(bg))
    logits = net(images)
    # cross_entropy == nll_loss(log_softmax(logits, dim=1), labels)
    return F.cross_entropy(logits, labels.long())
    
def ssloss():
    """Self-supervised cross-view loss on the next pair of crops from the global `ubg`.

    Both views go through the student (with gradients) and through the teacher
    (without gradients, centered). Each student view is scored against the
    teacher's output for the *other* view via H, and the two terms are averaged.
    """
    net.train()
    view_a, view_b = next(ubg)
    batch = np2t(np.concatenate((view_a, view_b)))
    student_out = net(batch)
    with torch.no_grad():
        teacher_out = teacher(batch, center=True)
    s_a, s_b = student_out.chunk(2)
    t_a, t_b = teacher_out.chunk(2)
    return H(s_a, t_b) / 2 + H(s_b, t_a) / 2

def H(s, t, tps=0.1, tpt=0.01):
    """Cross-entropy between the sharpened teacher and student distributions.

    Args:
        s: student logits, softmaxed over dim 1 at temperature `tps`.
        t: teacher logits of the same shape, softmaxed over dim 1 at temperature
            `tpt`. Detached, so no gradient reaches the teacher.
        tps: student temperature (default 0.1).
        tpt: teacher temperature (default 0.01; 0.02 was noted as an alternative).

    Returns:
        Scalar tensor: mean over all elements of ``-log p_s * p_t``, i.e. the
        per-sample cross-entropy averaged over the batch and divided by the
        number of outputs. This scaling is kept so logged losses stay comparable.
    """
    t = t.detach()
    # log_softmax is exact and stable. The former log(softmax(s) + 1e-12) clamped
    # saturated log-probabilities at ~-27.6 and suppressed their gradients.
    log_ps = F.log_softmax(s / tps, dim=1)
    pt = F.softmax(t / tpt, dim=1)
    return torch.mean(-log_ps * pt)

# Smoke test: one self-supervised loss evaluation (consumes a batch from `ubg` and updates the teacher's center).
ssloss()

In [None]:
# Self-supervised training: the student `net` is optimised on ssloss(), and the
# teacher follows it through update_mt after every step. Runs until interrupted.
losses = []
#slosses = []
ubg = ubatchgen()

for k in trange(999999):  # effectively "until interrupted"
    net.train()
    l = ssloss()
    l.backward()
    losses.append(l.item())
    net.optim.step()
    net.zero_grad()
    # Teacher update; tau is presumably the EMA momentum (the trailing note suggests 0.99 was also tried).
    update_mt(teacher, net, tau=0.9) ## 99

    # Every 50 steps, log (iteration, mean loss). NOTE: net.losses is shared with the supervised loop below.
    if len(losses) == 50:
        net.losses.append((net.iters, np.mean(losses)))
        losses = []
        #slosses = []

    if k % 500 == 0:
        # Loss curve so far.
        plt.plot(*zip(*net.losses))
        plt.grid()
        plt.show()

        # Attention of the last ISAB block on a random crop (size 24) of one of the first 1000 labelled images.
        i = np.random.randint(1000)
        x = randomcrop(getimg(i), 24)
        xs = np2t([x])
        yp = net(xs)
        #beta = t2np(net.isab3.mab1.mha.beta)[0,0,0,:].reshape(24,24)
        #plt.imshow(0*beta, alpha=.95-beta/beta.max()*.95, extent=(0, 1, 0, 1), cmap='gray')
        # Presumably beta is [batch, head, seed, pixel]: first head, first 3 seeds shown as RGB channels.
        beta = t2np(net.isab3.mab1.mha.beta)[0,0,:3,:].reshape(3,24,24).transpose(1, 2, 0)
        beta /= beta.max()
        showimg(beta)
        plt.show()
        showimg(x)
        plt.show()
        # Raw 100-dim student output for that crop.
        plt.plot(t2np(yp)[0])
        plt.show()

    net.iters += 1

In [None]:
# Overwrite the checkpoint with the student's weights only (teacher, optimizer state and loss histories are not saved).
torch.save(net.state_dict(), 'vt_emerging_properties_111223.dat')

In [None]:
# Inspect the attention of the last ISAB block for one full-size labelled image.
# NOTE: leaves `net` in eval mode for later cells.
i = 1
net.eval()
x = getimg(i)
xs = np2t([x])
yp = net(xs)

showimg(x)
plt.show()

# Presumably beta is [batch, head, seed, pixel]: loop over 10 seeds (M=10) and 4 heads (nh=4).
# NOTE: the loop variable `i` overwrites the image index above.
for i in range(10):
    for j in range(4):
        beta = t2np(net.isab3.mab1.mha.beta)[0,j,i,:].reshape(32,32)
        #showimg(x)
        # Mask overlay: nearly opaque where attention is low, fully transparent at the maximum.
        plt.imshow(0*beta, alpha=.95-beta/beta.max()*.95, extent=(0, 1, 0, 1), cmap='gray')
        plt.show()

In [None]:

# Apply the saved model to an external example image and show first-seed attention for three heads as RGB.
net.load_state_dict(torch.load('vt_emerging_properties_111223.dat'))

# Keep only the first three channels (drops an alpha channel if present).
# NOTE(review): the reshape to 32x32 below assumes dog.jpg is 32x32. plt.imread returns
# uint8 (0-255) for JPEGs; confirm this matches the value range getimg() feeds the model.
img = plt.imread('examples/dog.jpg')[:,:,:3]
img_big = plt.imread('examples/dog hr.jpg')  # higher-resolution copy, only displayed
xs = np2t([img])
net(xs)
# Presumably beta is [batch, head, seed, pixel]: heads 0-2 of seed 0, shown as RGB channels.
beta = t2np(net.isab3.mab1.mha.beta)[0,0:3,0,:].reshape(3,32,32).transpose(1, 2, 0)
beta /= beta.max()
showimg(img_big)
plt.show()
showimg(beta)

In [None]:
# Restore the saved student weights (discards any unsaved training progress).
net.load_state_dict(torch.load('vt_emerging_properties_111223.dat'))

In [None]:
# Supervised training on labelled minibatches via loss(), validating every 50 steps. Runs until interrupted.
losses = []
#slosses = []

for k in trange(999999):  # effectively "until interrupted"
    net.train()
    l = loss()
    #sl = sloss()
    l.backward()
    losses.append(l.item())
    #slosses.append(sl.item())
    net.optim.step()
    net.zero_grad()

    # Every 50 steps, record (iteration, value) for validation loss, validation accuracy and mean training loss.
    # NOTE: net.losses is shared with the self-supervised loop, so both runs end up in one history.
    if len(losses) == 50:
        vl, vacc = valloss()
        net.vlosses.append((net.iters, vl))
        net.vaccs.append((net.iters, vacc))
        net.losses.append((net.iters, np.mean(losses)))
        #net.slosses.append((net.iters, np.mean(slosses)))
        losses = []
        #slosses = []

    # Training loss, validation loss and validation accuracy share one axis.
    if k % 50 == 0:
        plt.plot(*zip(*net.losses))
        plt.plot(*zip(*net.vlosses))
        plt.plot(*zip(*net.vaccs))
        #plt.plot(*zip(*net.slosses))
        plt.grid()
        plt.show()

    net.iters += 1

In [None]:
# Lowest validation loss and highest validation accuracy recorded so far.
np.array(net.vlosses)[:,1].min(), np.array(net.vaccs)[:,1].max()

In [None]:
# Restore the saved student weights before extracting features.
net.load_state_dict(torch.load('vt_emerging_properties_111223.dat'))

In [None]:
from sklearn.manifold import TSNE

# Embed the first 200 labelled images (not drawn by batchgen, which starts at 500)
# and project the student's pre-projection features (net.prelast) to 2-D.
net.eval()  # inference mode, so training-only layer behaviour cannot perturb the features
xs = np.array([getimg(i) for i in range(200)])
yt = np.array([getlabel(i) for i in range(200)])
xs = np2t(xs)
with torch.no_grad():  # features only; no autograd graph needed
    yp = net(xs)

X = t2np(net.prelast)
X_embedded = TSNE(n_components=2, random_state=0).fit_transform(X)  # fixed seed for reproducible layouts

In [None]:
# Colour the t-SNE embedding by CIFAR-10 class; without a legend the colours carry no meaning.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
for i in range(10):
    mask = yt==i
    plt.scatter(X_embedded[mask][:,0], X_embedded[mask][:,1], s=4, label=class_names[i])
plt.legend(markerscale=3, fontsize='small')
plt.title('t-SNE of student features (net.prelast)')
plt.show()