In [None]:
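# PixelTransformer: Ähnlich wie bei PixelCNN und PixelRNN geht es darum,
# die Pixelwerte (autoregressiv, Wert für Wert) vorherzusagen.
# (English: like PixelCNN and PixelRNN, the goal is to predict the pixel
# values autoregressively, one value at a time.)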

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
import pickle as pkl

%load_ext autoreload
%autoreload 2
from utils import *
from layers import *
from transformer import *
from cifar10 import *

SX = 32

In [None]:
# Sanity check: show one training image, its label, and an 8x8 crop of it
# (s=8 is the crop size batchgen() trains on).
showimg(getimg(1))
# A bare expression in the middle of a cell is never displayed by Jupyter,
# so print the label explicitly.
print('label:', getlabel(1))

showimg(randomcrop(getimg(1), s=8))

In [None]:
from collections import Counter

#counter = Counter(list((randomcrop(getimg(4), s=8)*255).astype(int).flatten()))
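#sum([ -np.log(n/192)*n/192 for p, n in counter.items() ]) # <-- empirical entropy of the 192 channel values of one 8x8 crop: ca. 4.5 per channel (originally noted as "bits/Kanal", but np.log is the natural log, so this value is in nats; use np.log2 for bits)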

In [None]:

def batchgen(bsize=16, start=500):
    """Endless generator of (inputs, targets) minibatches for next-value prediction.

    Every epoch shuffles the image indices start..N-1 (N comes from the star
    imports), cuts them into full minibatches of `bsize` (an incomplete tail
    batch is dropped) and prints a marker line once all batches were yielded.

    Each sample is `randomcrop(getimg(i), s=8)` flattened, prefixed with the
    start value 257/256, then scaled by 255 and truncated to int.

    NOTE(review): 255 * 257/256 = 255.996..., which astype(int) truncates to
    255, the same id as the brightest pixel value, not a separate token. Only
    the position-0 encoding tells the model it is the start. valloss() uses the
    identical encoding, so change both together if a dedicated id (e.g. 256,
    the embedding has 266 rows) is wanted.

    Yields:
        xs: int array, each sequence without its last value (model inputs).
        ys: int array, the same sequences shifted left by one (targets).
    """
    ep = 0
    while True:
        order = np.random.permutation(range(start, N))
        n_batches = len(order) // bsize
        for b in range(n_batches):
            rows = [
                [257/256] + list(randomcrop(getimg(idx), s=8).flatten())  # start "token"
                for idx in order[b*bsize:(b+1)*bsize]
            ]
            seqs = (255*np.array(rows)).astype(int)
            yield seqs[:, :-1], seqs[:, 1:]
        print(f'========== EPOCH {ep} COMPLETED ==========')
        ep += 1

In [None]:
# Smoke-test the generator. xs are the model inputs (start value + all but the
# last crop value); ys are the targets (the same sequences shifted left by one).
bg = batchgen()
xs, ys = next(bg)
# Targets index the 256-way output layer of Net, so they must lie in [0, 255];
# this catches a pixel-scale mismatch in getimg()/randomcrop() early.
assert 0 <= ys.min() and ys.max() <= 255, (ys.min(), ys.max())
xs.shape, ys.shape

In [None]:

from layers import *


def _Att(q, k, v, mask=None, bias=None):
    b, h, i, m = q.shape
    b, h, j, m = k.shape
    b, h, j, n = v.shape
    
    beta = torch.einsum('bhim, bhjm -> bhij', q, k) / np.sqrt(m)
    if bias is not None:
        beta = beta + bias

    if mask is not None:
        beta = beta.masked_fill(mask == 0, -1e12)

    beta = F.softmax(beta, dim=-1)
    
    if mask is not None:
        beta = beta.masked_fill(mask == 0, 0)  # make sure its really closed

    o = torch.einsum('bhij,bhjn->bhin', beta, v)
    
    return o, beta


class MultiHeadAttention(nn.Module):
    """Multi-head attention with a learned additive (per head, per position pair) logit bias.

    Args:
        n: input/output width; also the total value width, split across heads
            (so n must be divisible by nh).
        m: total query/key width, split across heads (must be divisible by nh).
        nh: number of heads.
        p: sequence length the learned bias table of shape (1, nh, p, p) is
            sized for; query and key lengths must match p for the bias to broadcast.

    After each forward, `self.beta` holds the attention weights (b, nh, p, p)
    for inspection.
    """

    def __init__(self, n, m, nh, p):
        super().__init__()
        self.q = nn.Linear(n, m)
        self.k = nn.Linear(n, m)
        self.v = nn.Linear(n, n)
        self.p = nn.Linear(n, n)
        self.bias = nn.Parameter(torch.zeros(1, nh, p, p))
        self.nh = nh

    def forward(self, x, y, z, mask=None):
        """Attend from queries x to keys y / values z; returns a tensor shaped like x (width n)."""
        q = rearrange(self.q(x), 'b p (h n) -> b h p n', h=self.nh)
        k = rearrange(self.k(y), 'b p (h n) -> b h p n', h=self.nh)
        v = rearrange(self.v(z), 'b p (h n) -> b h p n', h=self.nh)

        # The (1, nh, p, p) bias broadcasts over the batch dimension inside
        # _Att, so no per-sample copy via .repeat() is needed.
        x, self.beta = _Att(q, k, v, mask, self.bias)
        x = rearrange(x, 'b h p n -> b p (h n)', h=self.nh)

        return self.p(x)


class EncoderBlock(nn.Module):
    """One encoder layer: LayerNorm -> multi-head self-attention -> dropout,
    added back onto the input, followed by FeedForwardLayer.

    Args:
        n: model width.
        m: total query/key width of the attention (split across nh heads).
        nh: number of attention heads.
        p: sequence length the attention's learned bias is sized for.
    """

    def __init__(self, n, m, nh, p):
        super().__init__()
        # Submodule creation order is kept as-is (it determines init RNG usage
        # and parameter order).
        self.mha = MultiHeadAttention(n, m, nh, p)
        self.ln = LayerNorm(n)
        self.ff = FeedForwardLayer(n)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        normed = self.ln(x)
        attended = self.mha(normed, normed, normed, mask)
        with_residual = x + self.dropout(attended)
        # NOTE(review): no residual connection or norm is added around self.ff
        # here; presumably FeedForwardLayer (from layers) handles that
        # internally. Confirm.
        return self.ff(with_residual)


class Net(nn.Module):
    """Autoregressive pixel transformer producing 256-way logits per position.

    Pipeline: token embedding (256 + 10 ids) -> positional encoding -> dropout
    -> LayerNorm -> causally masked EncoderBlocks -> LayerNorm -> Linear(n, 256).

    Args:
        n: model width.
        nh: number of attention heads per block.
        p: sequence length the blocks' learned attention bias is sized for
            (8*8*3 = 192 in this notebook); inputs should have p positions.
        depth: how many encoder blocks forward() runs, taken in order from
            enc1..enc6. Defaults to 4, which is what forward() always did. All six
            blocks are still constructed, so parameter names and state_dict keys
            are unchanged (enc5/enc6 are only used when depth > 4).

    The module moves itself to CUDA on construction.
    """

    def __init__(self, n, nh, p, depth=4):
        super().__init__()
        if not 0 <= depth <= 6:
            raise ValueError(f'depth must be between 0 and 6, got {depth}')

        self.emb = nn.Embedding(256+10, n)

        self.posenc = PositionalEncoding(n)

        self.ln1 = LayerNorm(n)
        self.enc1 = EncoderBlock(n, n, nh, p)
        self.enc2 = EncoderBlock(n, n, nh, p)
        self.enc3 = EncoderBlock(n, n, nh, p)
        self.enc4 = EncoderBlock(n, n, nh, p)
        self.enc5 = EncoderBlock(n, n, nh, p)
        self.enc6 = EncoderBlock(n, n, nh, p)
        self.ln2 = LayerNorm(n)

        self.dense = nn.Linear(n, 256)

        self.dropout = nn.Dropout(0.1)
        self.n = n
        self.depth = depth
        self.cuda()

    def _active_blocks(self):
        """Encoder blocks forward() runs, in order (the first `depth` of enc1..enc6)."""
        # getattr keeps instances unpickled from before `depth` existed working.
        depth = getattr(self, 'depth', 4)
        return (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5, self.enc6)[:depth]

    def forward(self, x):
        """Map integer token ids x to logits with last dimension 256.

        x is indexed as (batch, seq): x.shape[1] is used as the sequence length.
        """
        seq_len = x.shape[1]
        # Causal mask: position i may attend to positions <= i (lower triangle
        # including the diagonal), built directly on the input's device.
        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))[None]

        x = self.emb(x)
        x = self.posenc(x)
        x = self.dropout(x)

        x = self.ln1(x)
        for block in self._active_blocks():
            x = block(x, mask)
        x = self.ln2(x)

        return self.dense(x)

In [None]:
# Model: width 256, 4 heads, one flattened 8x8 RGB crop (8*8*3 = 192 positions).
net = Net(256, 4, p=8*8*3)

from torch_optimizer import Lookahead, Yogi
# Optimizer and training bookkeeping live as attributes on the model object;
# the history lists collect (iteration, value) pairs.
net.optim = Lookahead(Yogi(net.parameters(), lr=3e-3))
net.iters = 0
net.losses, net.vlosses, net.vaccs = [], [], []
bg = batchgen()

from torchsummary import summary
#summary(net, (8*8*3,))

In [None]:
def valloss(n_images=100, n_tail=12):
    """Validation loss of `net` in bits per predicted value, computed without autograd.

    Builds one sequence per image 0..n_images-1 exactly like batchgen() does
    (randomcrop(..., s=8), flattened, start value 257/256 prepended, scaled by
    255 and truncated to int). Only the last `n_tail` positions of each sequence
    are scored. batchgen() draws from indices >= 500 by default, so these
    images are held out from training.

    NOTE(review): the crops come from randomcrop() on every call, so the
    validation curve presumably includes crop-sampling noise.

    Args:
        n_images: number of validation images (indices 0..n_images-1).
        n_tail: number of trailing positions per sequence that are scored (>= 1).

    Returns:
        0-dim tensor without grad: mean cross-entropy over the scored
        positions, converted from nats to bits.

    Raises:
        ValueError: if n_tail < 1.

    Side effect: leaves `net` in eval mode (loss() switches it back to train mode).
    """
    if n_tail < 1:
        # yp[:, -0:] would silently select the whole sequence.
        raise ValueError(f'n_tail must be >= 1, got {n_tail}')
    net.eval()
    rows = []
    for i in range(n_images):
        crop = randomcrop(getimg(i), s=8)
        rows.append([257/256] + list(crop.flatten()))  # start "token"
    seqs = (255*np.array(rows)).astype(int)
    xs, ys = np2t(seqs[:, :-1], seqs[:, 1:])

    # Evaluation only: without no_grad the full autograd graph for all
    # n_images sequences would be built and kept alive for nothing.
    with torch.no_grad():
        yp = net(xs.long())
        yp = yp[:, -n_tail:].reshape(-1, 256)
        yt = ys[:, -n_tail:].reshape(-1).long()
        return F.nll_loss(F.log_softmax(yp, dim=1), yt) / np.log(2)
    
def loss():
    """Training loss of `net` on the next batch from the global generator `bg`.

    Switches `net` to train mode, scores every position of the batch and
    returns the mean cross-entropy in bits (nats divided by ln 2) as a tensor
    that still carries the autograd graph, ready for .backward().
    """
    net.train()
    inputs, targets = np2t(*next(bg))
    logits = net(inputs.long()).reshape(-1, 256)
    log_probs = F.log_softmax(logits, dim=1)
    return F.nll_loss(log_probs, targets.reshape(-1).long()) / np.log(2)

# Smoke test: both loss functions should run without error. Only loss(), the
# last expression of the cell, is displayed; valloss()'s result is discarded.
valloss()
loss()

In [None]:
# Training loop. Every LOG_EVERY steps, record the mean training loss and a
# validation loss (both in bits per predicted value, see loss()/valloss()) and
# redraw ONE live figure in place, instead of appending a new figure to the
# cell output every time (which bloats the notebook over ~1e6 iterations).
LOG_EVERY = 50

losses = []
fig, ax = plt.subplots()
# `display` is an IPython builtin in Jupyter; the handle lets us update the
# figure in place without clearing the tqdm progress bar.
plot_handle = display(fig, display_id=True)
plt.close(fig)  # avoid the inline backend rendering an extra static copy

for k in trange(999999):
    l = loss()  # loss() switches the net to train mode itself
    l.backward()
    losses.append(l.item())
    net.optim.step()
    net.zero_grad()

    if len(losses) == LOG_EVERY:
        net.vlosses.append((net.iters, valloss().item()))
        net.losses.append((net.iters, np.mean(losses)))
        losses = []

        ax.clear()
        ax.plot(*zip(*net.losses), label='train')
        ax.plot(*zip(*net.vlosses), label='validation (last 12 values)')
        ax.set_xlabel('iteration')
        ax.set_ylabel('loss [bits per value]')
        ax.grid(True)
        ax.legend()
        plot_handle.update(fig)

    net.iters += 1