In [None]:

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
import pickle as pkl

%load_ext autoreload
%autoreload 2
from utils import *
from layers import *
from transformer import *
from cifar10 import *

In [None]:
# CIFAR-10 class ids (canonical dataset order):
# 0: airplane
# 1: automobile
# 2: bird
# 3: cat
# 4: deer
# 5: dog
# 6: frog
# 7: horse
# 8: ship
# 9: truck

# Short display names indexed by class id; the order must match the list above
# (in CIFAR-10, id 4 is deer and id 5 is dog).
labels = [
    'Plane',
    'Car',
    'Bird',
    'Cat',
    'Deer',
    'Dog',
    'Frog',
    'Horse',
    'Boat',
    'Truck',
]

In [None]:
# Quick visual check: render the image at dataset index 1.
# getimg/showimg are not defined in this notebook -- presumably they come from
# the star-imported project modules (utils / cifar10).
showimg(getimg(1))

In [None]:
# Display N. It is not defined in this notebook -- presumably the dataset size
# exported by one of the star imports (cifar10?); batchgen draws training
# indices from range(start, N // 5).
N

In [None]:

def batchgen(bsize=32, start=0, stop=None):
    """Yield shuffled (images, labels) minibatches forever.

    Each epoch draws a fresh random permutation of the dataset indices in
    ``[start, stop)`` and yields ``len(indices) // bsize`` full batches; a
    trailing partial batch is dropped. A banner is printed after each epoch.

    Args:
        bsize: Number of samples per batch.
        start: First dataset index (inclusive) to sample from.
        stop: Last dataset index (exclusive). Defaults to ``N // 5``, i.e. only
            the first fifth of the ``N`` indices is used.
            NOTE(review): confirm the 1/5 subset is intentional.

    Yields:
        xs: float64 array of shape (bsize, 3, 32, 32), images in CHW order.
        zs: int array of shape (bsize,), labels from ``getlabel``.

    Raises:
        ValueError: if ``bsize`` is not positive or ``[start, stop)`` holds
            fewer than ``bsize`` indices (no full batch could ever be formed,
            so the generator would otherwise loop forever without yielding).
    """
    if stop is None:
        stop = N // 5
    nsamples = len(range(start, stop))
    if bsize <= 0 or nsamples < bsize:
        raise ValueError(
            f'batchgen needs 0 < bsize <= number of indices; got bsize={bsize} '
            f'for [{start}, {stop}) containing {nsamples} indices')

    ep = 0
    while True:
        inds = np.random.permutation(range(start, stop))
        for k in range(len(inds) // bsize):
            mb = inds[k * bsize:(k + 1) * bsize]
            xs = np.zeros((bsize, 3, 32, 32))
            zs = np.zeros(bsize, dtype=int)
            for i, j in enumerate(mb):
                x = getimg(j).reshape((32, 32, -1))
                xs[i] = x.transpose(2, 0, 1)  # HWC -> CHW
                zs[i] = getlabel(j)
            yield xs, zs
        print(f'========== EPOCH {ep} COMPLETED ==========')
        ep += 1

In [None]:
# Pull one batch from a fresh generator to sanity-check its output, then
# render the first image of the batch. Shapes are asserted and printed
# explicitly (only a cell's last expression is auto-displayed).
bg = batchgen()
xs, zs = next(bg)
assert xs.shape[1:] == (3, 32, 32) and zs.shape == (len(xs),), (xs.shape, zs.shape)
print('xs', xs.shape, 'zs', zs.shape)
showimg(xs[0])

In [None]:

class Encoder(nn.Module):
    """Convolutional, label-conditioned encoder.

    The conditioning vector ``z`` is broadcast over the image plane and
    concatenated to the centred image (``x - 0.5``) as extra input channels.
    Four stride-2 4x4 convolutions halve the spatial size each time; for a
    32x32 input, conv5 (2x2 kernel) then reduces the 2x2 map to 1x1, and the
    1x1 conv6 maps it to ``2 * nlatent`` channels (split into mean /
    log-variance by the VAE wrapper, presumably).

    Args:
        n: Base channel width; stage widths are n, 2n, 4n, 8n, 64n.
        nlatent: Latent size; the final layer outputs 2 * nlatent channels.
        nclasses: Length of the conditioning vector ``z`` (default 10).
        device: Device the module is moved to on construction (default 'cuda').
    """
    def __init__(self, n, nlatent, nclasses=10, device='cuda'):
        super().__init__()
        self.nclasses = nclasses
        # 3 image channels + one broadcast channel per class
        self.conv1 = nn.Conv2d(3 + nclasses, n, 4, 2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(n)
        self.conv2 = nn.Conv2d(n, 2*n, 4, 2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(2*n)
        self.conv3 = nn.Conv2d(2*n, 4*n, 4, 2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(4*n)
        self.conv4 = nn.Conv2d(4*n, 8*n, 4, 2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(8*n)
        self.conv5 = nn.Conv2d(8*n, 4*16*n, 2, 1, padding=0, bias=True)
        self.bn5 = nn.BatchNorm2d(4*16*n)
        self.conv6 = nn.Conv2d(4*16*n, 2*nlatent, 1, 1, padding=0, bias=True)
        self.to(device)

    def forward(self, x, z=None, d=4):
        """Encode ``x`` conditioned on ``z``, stopping after stage ``d``.

        Args:
            x: Image batch of shape (batch, 3, H, W); the shape comments below
               assume n=64 and 32x32 inputs.
            z: (batch, nclasses) conditioning vectors; all zeros when omitted.
            d: 1-4 return the activations after that conv stage, 5 returns
               the conv5 features, any other value returns the conv6 output.
        """
        if z is None:
            # same device/dtype as x so the concatenation below is valid
            z = x.new_zeros(len(x), self.nclasses)
        # broadcast each conditioning vector over the whole image plane
        z = z[:, :, None, None].repeat(1, 1, x.shape[2], x.shape[3])
        x = x - 0.5
        x = torch.cat([x, z], 1)           # input: (3 + nclasses) x 32 x 32
        x = relu(self.bn1(self.conv1(x)))  # 64x16x16
        if d == 1: return x
        x = relu(self.bn2(self.conv2(x)))  # 128x8x8
        if d == 2: return x
        x = relu(self.bn3(self.conv3(x)))  # 256x4x4
        if d == 3: return x
        x = relu(self.bn4(self.conv4(x)))  # 512x2x2
        if d == 4: return x
        x = relu(self.bn5(self.conv5(x)))  # 4096x1x1
        if d == 5: return x
        x = self.conv6(x)                  # (2 * nlatent)x1x1
        return x

In [None]:

class Decoder(nn.Module):
    """Label-conditioned transposed-convolution decoder.

    With ``d >= 6`` the input is a latent of shape (batch, nlatent, 1, 1),
    which is concatenated with ``z`` and expanded to a 3-channel image via
    2x2 -> 4x4 -> 8x8 -> 16x16 -> 32x32 maps. Smaller ``d`` skip the early
    stages, so ``x`` must then be an activation of the matching shape.

    Args:
        n: Base channel width (mirrors the encoder's n).
        nlatent: Latent size expected at the d >= 6 entry point.
        nclasses: Length of the conditioning vector ``z`` (default 10).
        device: Device the module is moved to on construction (default 'cuda').
    """
    def __init__(self, n, nlatent, nclasses=10, device='cuda'):
        super().__init__()
        self.nclasses = nclasses
        self.conv0 = nn.Conv2d(nlatent + nclasses, 4*16*n, 1, 1, bias=False)
        self.bn0 = nn.BatchNorm2d(4*16*n)
        self.conv01 = nn.ConvTranspose2d(4*16*n, 8*n, 2, 1, bias=False)
        self.bn01 = nn.BatchNorm2d(8*n)
        self.conv1 = nn.ConvTranspose2d(8*n, 4*n, 4, 2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(4*n)
        self.conv2 = nn.ConvTranspose2d(4*n, 2*n, 4, 2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(2*n)
        self.conv3 = nn.ConvTranspose2d(2*n, n, 4, 2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(n)
        self.conv4 = nn.ConvTranspose2d(n, 3, 4, 2, padding=1, bias=True)
        self.to(device)

    def forward(self, x, z=None, d=99999):
        """Decode ``x``, running only the stages whose threshold is <= ``d``.

        ``z`` is only used at the d >= 6 entry point; when omitted it is all
        zeros. With d <= 0 the input is returned unchanged.
        """
        if z is None:
            # same device/dtype as x so the concatenation below is valid
            z = x.new_zeros(len(x), self.nclasses)
        if d >= 6:
            x = torch.cat((x, z[:, :, None, None]), 1)  # input: (nlatent + nclasses)x1x1
            x = relu(self.bn0(self.conv0(x)))            # 64n x1x1
        if d >= 5:
            x = relu(self.bn01(self.conv01(x)))          # 8n x2x2
        if d >= 4:
            x = relu(self.bn1(self.conv1(x)))            # 4n x4x4
        if d >= 3:
            x = relu(self.bn2(self.conv2(x)))            # 2n x8x8
        if d >= 2:
            x = relu(self.bn3(self.conv3(x)))            # n x16x16
        if d >= 1:
            x = torch.sigmoid(self.conv4(x))             # 3x32x32, values in [0, 1]
        return x

In [None]:
class Net(nn.Module):
    """Conditional VAE chaining an ``Encoder`` and a ``Decoder``.

    ``d`` selects the depth: for ``d >= 6`` the encoder output is split
    channel-wise into (mu, logvar) and a reparameterised latent is decoded;
    for smaller ``d`` the encoder's activations at that depth go straight to
    the decoder without sampling, and ``logvar`` is reported as the int 0.
    The submodule attribute names ``enc``/``dec`` determine state_dict keys.
    """
    def __init__(self, n, nlatent):
        super().__init__()
        self.enc = Encoder(n, nlatent)
        self.dec = Decoder(n, nlatent)
        self.nlatent = nlatent

    def forward(self, x, z=None, d=6, train=True):
        """Return ``(reconstruction, mu, logvar)``.

        ``train`` enables latent noise; when False the decoder receives
        ``mu + 0 * std``.
        """
        h = self.enc(x, z, d)
        if d < 6:
            # partial-depth autoencoder path: no latent split, no noise
            return self.dec(h, z, d), h, 0

        mu, logvar = h[:, :self.nlatent], h[:, self.nlatent:]
        std = torch.exp(logvar / 2)
        noise = torch.randn_like(std) if train else 0
        return self.dec(mu + noise * std, z, d), mu, logvar

In [None]:
from torchsummary import summary
# Build a throwaway model and print its layer-by-layer output shapes and
# parameter counts for a single 3x32x32 image.
net = Net(64, 20)
summary(net, input_size=(3, 32, 32))

In [None]:
from torch_optimizer import Lookahead, Yogi
# Fresh model with a Lookahead-wrapped Yogi optimizer. Training bookkeeping
# (iteration counter and (iteration, loss) histories) is stored on the module
# itself so it survives re-running the training cell.
net = Net(64, 20)
net.optim = Lookahead(
    Yogi(net.parameters(), lr=3e-3, betas=(0.9, 0.99)),
)
net.iters = 0
net.losses1, net.losses2, net.vlosses = [], [], []
bg = batchgen()

In [None]:
def onehot(zs, nclasses=10):
    """Return a float64 one-hot matrix of shape (len(zs), nclasses).

    Row i holds 1.0 in column ``zs[i]`` and zeros elsewhere. Labels follow
    NumPy indexing rules: negative values count from the last column and
    out-of-range values raise IndexError.
    """
    ys = np.zeros((len(zs), nclasses))
    if len(zs):
        # one vectorised fancy-index assignment instead of a Python loop
        ys[np.arange(len(zs)), zs] = 1.
    return ys

def valloss(bsize=200, d=6):
    """Return the validation loss (0-dim tensor) on dataset images 0..bsize-1.

    ``net`` is evaluated in eval mode (BatchNorm uses its running statistics
    and is not updated by validation data), under ``torch.no_grad()``, and
    without latent sampling (``train=False``); its previous train/eval mode
    is restored afterwards. The value is ``pixel loss + KL-term / 10``, the
    same weighting the training step applies, so the curves are comparable.

    NOTE(review): with its default ``start=0`` batchgen also trains on these
    indices; pass a larger ``start`` to batchgen to hold them out.
    """
    xs = np.zeros((bsize, 3, 32, 32))
    zs = np.zeros(bsize, dtype=int)
    for i in range(bsize):
        x = getimg(i).reshape((32, 32, -1))
        xs[i] = x.transpose(2, 0, 1)  # HWC -> CHW
        zs[i] = getlabel(i)
    xs, zs = np2t(xs, onehot(zs))

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            xs2, mu, logvar = net(xs, zs, d=d, train=False)
            pixelloss = torch.mean((xs - xs2)**2) + torch.mean(torch.abs(xs - xs2)) / 10
            if isinstance(logvar, int):
                # partial-depth runs return logvar as the int 0: no KL term
                klloss = pixelloss.new_zeros(())
            else:
                klloss = 0.5 * torch.mean(-1 - logvar + mu**2/5 + torch.exp(logvar))
    finally:
        net.train(was_training)
    return pixelloss + klloss / 10
    
def loss(d=9999):
    """Draw the next batch from ``bg`` and return its two training loss terms.

    Runs ``net`` at depth ``d`` with latent sampling (``train=True``); the
    caller is responsible for the module's train/eval mode.

    Returns:
        ``(pixelloss, klloss / 10)``, where ``pixelloss`` is MSE + MAE/10
        between the batch and its reconstruction. When ``net`` returns an int
        ``logvar`` (partial-depth runs) the KL term is a zero tensor on the
        same device/dtype as ``pixelloss``.
    """
    xs, zs = next(bg)
    xs, zs = np2t(xs, onehot(zs))
    xs2, mu, logvar = net(xs, zs, d=d, train=True)
    pixelloss = torch.mean((xs - xs2)**2) + torch.mean(torch.abs(xs - xs2)) / 10
    if isinstance(logvar, int):
        klloss = pixelloss.new_zeros(())
    else:
        # NOTE(review): mu**2 is scaled by 1/5 but exp(logvar) is not, so this is
        # not the exact KL to any isotropic Gaussian prior -- confirm intended.
        klloss = 0.5 * torch.mean(-1 - logvar + mu**2/5 + torch.exp(logvar))
    return pixelloss, klloss / 10

# Sanity check: one validation loss and one training batch's (pixel, KL/10)
# terms. Run under no_grad and converted to Python floats so no autograd graph
# or GPU tensor stays alive in IPython's output history.
# NOTE: loss() consumes a batch from `bg` and, if `net` is in train mode,
# still updates BatchNorm running statistics.
with torch.no_grad():
    print(valloss().item(), [part.item() for part in loss()])

In [None]:
def showexample(i=1, d=9999):
    """Show dataset image ``i`` followed by the model's reconstruction of it.

    The image is encoded/decoded at depth ``d`` with its true label as
    conditioning and without latent sampling. ``net`` is switched to eval mode
    (BatchNorm uses running statistics and is not updated by this call) with
    gradients disabled, and its previous mode is restored afterwards. In eval
    mode the output does not depend on other batch members, so only image
    ``i`` is loaded and any valid dataset index may be used.
    """
    xs = np.zeros((1, 3, 32, 32))
    zs = np.zeros(1, dtype=int)
    xs[0] = getimg(i).reshape((32, 32, -1)).transpose(2, 0, 1)  # HWC -> CHW
    zs[0] = getlabel(i)
    xs, zs = np2t(xs, onehot(zs))

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            xs2, mu, logvar = net(xs, zs, d=d, train=False)
    finally:
        net.train(was_training)

    xs, xs2 = t2np(xs, xs2)
    showimg(xs[0])
    plt.show()
    showimg(xs2[0])
    plt.show()
    
# Original vs. reconstruction of image 4 through the full VAE path (d=6).
showexample(i=4, d=6)

In [None]:
# Training loop; runs until interrupted. Model, optimizer, iteration counter
# and loss histories live on `net`, so re-running this cell continues from
# net.iters. Every third iteration trains the full VAE path (d=6); the others
# train the shallower d=5 path.
losses1 = []
losses2 = []
bg = batchgen()

for k in trange(999999):
    d = 6 if net.iters % 3 == 0 else 5
    net.train()  # BatchNorm uses batch statistics while training
    l1, l2 = loss(d)
    (l1 + l2).backward()
    losses1.append(l1.item())
    losses2.append(l2.item())
    net.optim.step()
    net.zero_grad()

    # every 50 steps: store averaged training losses and one validation loss
    if len(losses1) == 50:
        net.vlosses.append((net.iters, valloss().item()))
        net.losses1.append((net.iters, np.mean(losses1)))
        net.losses2.append((net.iters, np.mean(losses2)))
        losses1 = []
        losses2 = []

    if k % 500 == 0:
        plt.plot(*zip(*net.losses1), label='pixel loss (train)')
        plt.plot(*zip(*net.losses2), label='KL term / 10 (train)')
        plt.plot(*zip(*net.vlosses), label='validation loss')
        plt.xlabel('iteration')
        plt.grid()
        if net.losses1:  # avoid the "no labeled artists" warning on an empty plot
            plt.legend()
        plt.show()
        showexample(4, d)

    net.iters += 1

In [None]:
# Persist the model weights only (optimizer state and the loss histories
# stored on `net` are not part of the state_dict).
torch.save(obj=net.state_dict(), f='vae141223.dat')

In [None]:
# Most recent (iteration, averaged pixel loss) records; the full history grows
# by one entry every 50 training steps, so only the tail is displayed.
net.losses1[-10:]

In [None]:
# Image 1 vs. its reconstruction at the default depth (full VAE path).
showexample(i=1)

In [None]:
# Averaged training pixel-loss values without their iteration numbers
# (the model tracks losses1, losses2 and vlosses; it has no `losses` attribute).
np.array(net.losses1)[:, 1]

In [None]:
# 9-point moving average of the recorded training pixel loss, plotted against
# the iteration of the last record in each window. 'valid' mode drops the
# partially-covered edges instead of zero-padding them. Set plt.ylim(...)
# manually to zoom in on the converged region.
loss_hist = np.array(net.losses1)
window = 9
assert len(loss_hist) >= window, f'need at least {window} records, have {len(loss_hist)}'
smoothed = np.convolve(loss_hist[:, 1], np.ones(window) / window, mode='valid')
plt.plot(loss_hist[window - 1:, 0], smoothed)
plt.xlabel('iteration')
plt.ylabel(f'pixel loss ({window}-point moving average)')
plt.grid()