In [None]:
'''
References:
    https://github.com/jzbontar/pixelcnn-pytorch
    https://github.com/singh-hrituraj/PixelCNN-Pytorch/blob/master/MaskedCNN.py
    https://github.com/singh-hrituraj/PixelCNN-Pytorch/blob/master/Model.py
    https://github.com/singh-hrituraj/PixelCNN-Pytorch/blob/master/train.py
    https://github.com/singh-hrituraj/PixelCNN-Pytorch/blob/master/generate.py
    https://stackoverflow.com/questions/65172786/how-to-load-one-type-of-image-in-cifar10-or-stl10-with-pytorch
    https://github.com/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb

Notes:
    The code is left as it was when I last ran it on STL-10 at 96x96; running it on
    CIFAR-10 requires changes to some of the functions.
    Not mentioned in my report: gated activation functions were added to the PixelCNN.
'''

import math
import random
from os.path import exists
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.nn.utils import weight_norm as wn
#from matplotlib.animation import FuncAnimation
%matplotlib inline

# A fresh seed is drawn on every run and printed, so a run can be reproduced by
# hard-coding the printed value here. Every RNG the notebook imports is seeded.
manualSeed = np.random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Datasets and checkpoints live under 'drive/My Drive/training' (Colab only).
from google.colab import drive
drive.mount('/content/drive')

def cycle(iterable):
    """Yield the items of ``iterable`` forever, re-iterating it after each pass.

    Unlike ``itertools.cycle`` nothing is cached: each pass calls ``iter`` on
    ``iterable`` again, so e.g. a DataLoader is iterated afresh every pass.

    Stops (the generator simply ends) if a complete pass yields no items. This
    covers an empty iterable or an already-exhausted one-shot iterator, where
    looping again would spin forever without ever yielding.
    """
    while True:
        produced_any = False
        for x in iterable:
            produced_any = True
            yield x
        if not produced_any:
            return

def stl_classes(data_set, label):
    """Return, in ascending order, every position ``i`` of ``data_set`` whose
    item's second element (``data_set[i][1]``) equals ``label``.

    Each item is fetched in full, so for an image dataset every sample goes
    through its transform pipeline.
    """
    return [idx for idx in range(len(data_set)) if data_set[idx][1] == label]

def get_data(dataset):
    """Build the training loader plus two small curated loaders (birds, horses).

    Reads the module-level globals ``img_dim`` (images are resized to
    ``img_dim x img_dim``) and ``batch_size``; both must be defined first.

    Args:
        dataset: ``'cifar10'`` or ``'stl10'``. Data is downloaded to / read from
            ``drive/My Drive/training/<dataset>``.

    Returns:
        ``(train_loader, len(train_loader), class_names, bird_loader,
        horse_loader, data_varience)``. ``bird_loader`` and ``horse_loader``
        yield hand-picked images of one class, one image per batch, in a fixed
        order. ``data_varience`` is fixed at 1.0.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    trans = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((img_dim, img_dim)),
        # mean 0.5 / std 1.0 maps [0, 1] pixels to [-0.5, 0.5]; the plotting
        # code adds 0.5 back before display.
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0)),
    ])

    if dataset == 'cifar10':
        train_data_set = torchvision.datasets.CIFAR10('drive/My Drive/training/cifar10', train=True, download=True, transform=trans)
        labels = np.asarray(train_data_set.targets)
        bird_label, horse_label = 2, 7
        # Positions *within* the bird / horse class, not indices into the dataset.
        indices_sub_birds = [366, 443, 613, 626, 765, 1126, 1171, 2063, 2250, 2372, 2490, 3093, 3139, 3248, 3408, 3448, 4722]
        # Alternative bird selection: [154, 200, 228, 558, 567, 591, 738, 782, 999, 1016, 1033, 1037, 1095, 1126, 1131, 1176, 1408, 1424, 1443, 1476, 1560, 1581, 1705, 1718, 1736, 1835, 1901, 2013, 2098, 2106, 2157, 2243, 2250, 2405, 2438, 2489, 2505, 2588, 2595, 2653, 2673, 2682, 2722, 2918, 3072, 3139, 2158, 3248, 3319, 3335, 3339, 3351, 3352, 3418, 3448, 3486, 4722]
        indices_sub_horses = [6, 15, 19, 30, 86, 106, 107, 114, 177, 210, 228, 237, 292, 315, 353, 354, 502, 527, 547, 548, 656, 470, 670, 799, 800, 805, 814, 816, 823, 903, 910, 934, 948, 992, 993, 1013, 1066, 1076, 1103, 1124, 1191, 1201, 1251, 1313, 1340, 1343, 1384, 1439, 1463, 1485, 1528, 1529, 1574, 1677, 1772]
        print(len(indices_sub_birds), len(indices_sub_horses))
        class_names = ['airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    elif dataset == 'stl10':
        train_data_set = torchvision.datasets.STL10('drive/My Drive/training/stl10', split='train+unlabeled', download=True, transform=trans)
        # Read the label array directly (unlabeled images are -1) instead of
        # loading and transforming every image just to look up its label.
        labels = np.asarray(train_data_set.labels)
        bird_label, horse_label = 1, 6
        # Positions *within* the bird / horse class, not indices into the dataset.
        indices_sub_birds = [30, 41, 128, 156, 159, 256, 360]
        indices_sub_horses = [3, 4, 9, 10, 18, 38, 53, 56, 86, 179, 184, 192, 201, 215, 227, 246, 286, 300, 302, 319, 320]
        class_names = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck'] # these are slightly different to CIFAR-10
    else:
        raise ValueError("unsupported dataset %r; expected 'cifar10' or 'stl10'" % (dataset,))

    bird_class = torch.utils.data.dataset.Subset(train_data_set, np.flatnonzero(labels == bird_label).tolist())
    bird_data_set = torch.utils.data.dataset.Subset(bird_class, indices_sub_birds)
    horse_class = torch.utils.data.dataset.Subset(train_data_set, np.flatnonzero(labels == horse_label).tolist())
    horse_data_set = torch.utils.data.dataset.Subset(horse_class, indices_sub_horses)

    bird_loader = torch.utils.data.DataLoader(bird_data_set,
        shuffle=False, batch_size=1, drop_last=True, pin_memory=True)
    horse_loader = torch.utils.data.DataLoader(horse_data_set,
        shuffle=False, batch_size=1, drop_last=True, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(train_data_set,
        shuffle=False, batch_size=batch_size, drop_last=True, pin_memory=True)

    data_varience = 1.0  # previously np.var(train_data_set.data / 255.0)
    return train_loader, len(train_loader), class_names, bird_loader, horse_loader, data_varience

class Quantise(nn.Module):
    """Vector-quantisation layer of the VQ-VAE.

    Each spatial feature vector of the encoder output is replaced by its nearest
    (squared-L2) entry in the learned codebook ``e_table``.

    Args:
        num_v: number of codebook entries (default 256).
        v_width: length of each codebook entry; must equal the channel count of
            the tensor passed to ``forward`` (default 128).
        beta: weight of the commitment term ``mse(quantized.detach(), z)`` in
            the returned loss (default 0.25).
    """

    def __init__(self, num_v=256, v_width=128, beta=0.25):
        super(Quantise, self).__init__()
        self.v_width = v_width
        self.num_v = num_v
        self.mse = lambda x, y: ((x - y)**2).mean()

        self.e_table = nn.Embedding(self.num_v, self.v_width)
        self.e_table.weight.data.uniform_(-1/self.num_v, 1/self.num_v)
        self.beta = beta

    def forward(self, z_features):
        """Quantise ``z_features`` of shape (B, v_width, H, W).

        Returns:
            (loss, quantized, perplexity, indices):
            - ``loss``: codebook loss plus ``beta`` times the commitment loss.
            - ``quantized``: has the input's shape; its gradient flows
              straight through to ``z_features``.
            - ``perplexity``: exp(entropy) of codebook usage in this batch.
            - ``indices``: (B, 1, H, W) int64 codebook indices.
        """
        z_features = z_features.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)
        latent_shape = z_features.shape
        feature_vect = z_features.view(-1, self.v_width)

        # ||z||^2 + ||e||^2 - 2 z.e for every (vector, codebook entry) pair.
        distances = (torch.sum(feature_vect**2, dim=1, keepdim=True)
                    + torch.sum(self.e_table.weight**2, dim=1)
                    - 2 * torch.matmul(feature_vect, self.e_table.weight.t()))

        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.num_v, device=z_features.device)
        encodings.scatter_(1, encoding_indices, 1)

        quantized = torch.matmul(encodings, self.e_table.weight).view(latent_shape)
        encoding_indices = encoding_indices.view(latent_shape[0], latent_shape[1], latent_shape[2], 1)

        e_latent_loss = self.mse(quantized.detach(), z_features)
        q_latent_loss = self.mse(quantized, z_features.detach())
        loss = q_latent_loss + self.beta * e_latent_loss

        # Straight-through estimator: forward value is `quantized`, gradient goes to z.
        quantized = z_features + (quantized - z_features).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encoding_indices.permute(0, 3, 1, 2).contiguous()

class EncodeBlock(nn.Module):
    """Encoder stage: ``nn.Conv2d`` followed by an in-place ReLU.

    Args:
        in_f, out_f: input / output channel counts.
        k_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_f, out_f, k_size, stride, padding=0):
        super().__init__()
        conv = nn.Conv2d(in_f, out_f, kernel_size=k_size, stride=stride, padding=padding)
        # Keep the sub-module name ``f_pass`` (conv at index 0) so saved state
        # dicts keep loading.
        self.f_pass = nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, x):
        out = self.f_pass(x)
        return out

class Residual(nn.Module):
    """Pre-activation residual block: ``x + conv1x1(relu(conv3x3(relu(x))))``.

    ``num_hiddens`` must equal ``in_channels`` for the skip addition to be
    shape-compatible. Both convolutions are bias-free.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            # Must NOT be in-place: an in-place ReLU here overwrites the caller's
            # ``x``, so ``x + self._block(x)`` would add relu(x) instead of x and
            # silently mutate the input. VQ-VAE weights trained with the in-place
            # version learned that relu(x) skip and may behave differently now.
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),  # safe: acts on the fresh conv output
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)


class ResidualStack(nn.Module):
    """``num_residual_layers`` ``Residual`` blocks applied in order, then a ReLU.

    The blocks are stored in the ``nn.ModuleList`` ``_layers`` (state-dict keys
    ``_layers.<i>.*``).
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        self._num_residual_layers = num_residual_layers
        blocks = []
        for _ in range(num_residual_layers):
            blocks.append(Residual(in_channels, num_hiddens, num_residual_hiddens))
        self._layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for layer in self._layers:
            out = layer(out)
        return F.relu(out)

class DecodeBlock(nn.Module):
    """Decoder stage: ``nn.ConvTranspose2d`` followed by an in-place ReLU.

    Args:
        in_f, out_f: input / output channel counts.
        k_size, stride, padding: forwarded unchanged to ``nn.ConvTranspose2d``.
    """

    def __init__(self, in_f, out_f, k_size, stride, padding=0):
        super().__init__()
        deconv = nn.ConvTranspose2d(in_f, out_f, kernel_size=k_size, stride=stride, padding=padding)
        # Keep the sub-module name ``f_pass`` (deconv at index 0) so saved state
        # dicts keep loading.
        self.f_pass = nn.Sequential(deconv, nn.ReLU(inplace=True))

    def forward(self, x):
        out = self.f_pass(x)
        return out

class MaskedCNN(nn.Conv2d):
    """2-D convolution with a causal (PixelCNN) kernel mask.

    The mask zeroes every kernel row below the centre row and, on the centre
    row, every tap to the right of the centre. Mask type 'A' also zeroes the
    centre tap itself; type 'B' keeps it. All remaining ``*args`` / ``**kwargs``
    go to ``nn.Conv2d``.
    """

    def __init__(self, mask_type, *args, **kwargs):
        self.mask_type = mask_type
        assert mask_type in ['A', 'B'], "Unknown Mask Type"
        super(MaskedCNN, self).__init__(*args, **kwargs)

        _, _, k_h, k_w = self.weight.size()
        mask = torch.ones_like(self.weight.data)
        # First hidden column on the centre row: the centre for 'A', one past it for 'B'.
        first_hidden_col = k_w // 2 if mask_type == 'A' else k_w // 2 + 1
        mask[:, :, k_h // 2, first_hidden_col:] = 0
        mask[:, :, k_h // 2 + 1:, :] = 0
        # Buffer: moves with .to(device) and is stored in the state dict as 'mask'.
        self.register_buffer('mask', mask)

    def forward(self, x):
        # Re-apply the mask to the stored weights before every convolution.
        self.weight.data.mul_(self.mask)
        return super(MaskedCNN, self).forward(x)

class Gated_Act(nn.Module):
    """Parameter-free self-gated activation ``tanh(x) * sigmoid(x)``.

    NOTE(review): both factors see the same input, so this is an element-wise
    nonlinearity rather than the split-channel gate
    ``tanh(W_f * x) * sigmoid(W_g * x)`` of Gated PixelCNN.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return torch.tanh(x) * gate

class PixelBlock(nn.Module):
    """One PixelCNN layer: bias-free masked conv -> BatchNorm -> ``Gated_Act``.

    Stride 1 with padding ``k_size // 2`` keeps H and W unchanged for odd
    ``k_size``. Sub-modules live in ``f_pass`` (indices 0, 1, 2).
    """

    def __init__(self, mask_type, in_f, out_f, k_size):
        super().__init__()
        pad = k_size // 2
        layers = [
            MaskedCNN(mask_type, in_f, out_f, k_size, 1, pad, bias=False),
            nn.BatchNorm2d(out_f),
            Gated_Act(),
        ]
        self.f_pass = nn.Sequential(*layers)

    def forward(self, x):
        out = self.f_pass(x)
        return out

class PixelCNN(nn.Module):
    """PixelCNN prior over maps of VQ-VAE codebook indices.

    One mask-'A' layer, ``n_b_layers`` mask-'B' layers, then a 1x1 conv that
    produces ``n_classes`` logits per position. The defaults reproduce the
    original network, including its state-dict keys ``main.0`` ... ``main.8``.

    Args:
        filters: channels of every hidden layer.
        k_size: masked-conv kernel size (odd, so the padding keeps H, W).
        n_b_layers: number of mask-'B' layers after the first layer.
        n_classes: logits per position; should equal the VQ-VAE codebook
            size (``Quantise.num_v``).
        in_channels: channels of the input index map.
    """

    def __init__(self, filters=128, k_size=7, n_b_layers=7, n_classes=256, in_channels=1):
        super(PixelCNN, self).__init__()
        layers = [PixelBlock('A', in_channels, filters, k_size)]
        layers += [PixelBlock('B', filters, filters, k_size) for _ in range(n_b_layers)]
        layers.append(nn.Conv2d(filters, n_classes, 1))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-position logits of shape (B, n_classes, H, W)."""
        return self.main(x)

def show_latent(image1, image2):
    """Plot the first 32 single-channel latent maps of two batches, one grid each.

    Reads the global ``batch_size``. NOTE(review): the 8x8 latent size is
    hard-coded; with the encoder's three stride-2 stages it matches only one
    input resolution. TODO: derive it from the tensor.
    """
    print("lat")
    for latents in (image1, image2):
        plt.rcParams['figure.dpi'] = 101
        grid = torchvision.utils.make_grid(latents.view(batch_size, 1, 8, 8)[:32]).detach().cpu()
        plt.imshow(grid.permute(1, 2, 0), cmap=plt.cm.binary)
        plt.show()

def show_quantised(image1, image2):
    """Plot channels 0-2 of the first 32 quantised latents of two batches as RGB grids.

    Reads the global ``batch_size``. NOTE(review): the 8x8 latent size is
    hard-coded; with the encoder's three stride-2 stages it matches only one
    input resolution. TODO: derive it from the tensor.
    """
    print("quan")
    for latents in (image1, image2):
        plt.rcParams['figure.dpi'] = 101
        rgb = latents[:, 0:3].view(batch_size, 3, 8, 8)[:32]
        grid = torchvision.utils.make_grid(rgb).detach().cpu()
        plt.imshow(grid.permute(1, 2, 0), cmap=plt.cm.binary)
        plt.show()

class VQ_VAE(nn.Module):
    """VQ-VAE: conv encoder -> ``Quantise`` -> transposed-conv decoder.

    The encoder's three 4x4 / stride-2 / padding-1 convolutions each halve H
    and W; a residual stack follows, then a 1x1 projection to
    ``Quantise().v_width`` channels. The decoder mirrors this back to
    ``n_channels`` (a module-level global) channels.

    Args:
        hidden_dim, latent_dim: unused; kept for call compatibility.
        latent_channels: channel width of the encoder / decoder trunk.
    """

    def __init__(self, hidden_dim=1024, latent_dim=32, latent_channels = 128):
        super(VQ_VAE, self).__init__()
        self.mse = lambda x, y: ((x - y)**2).mean()
        self.bce = nn.BCEWithLogitsLoss(reduction='sum')  # unused by this class

        self.quantise = Quantise()

        self.encode = nn.Sequential(
            EncodeBlock(n_channels, latent_channels//2, 4, 2, 1),
            EncodeBlock(latent_channels//2, latent_channels, 4, 2, 1),
            EncodeBlock(latent_channels, latent_channels, 4, 2, 1),
            nn.Conv2d(in_channels=latent_channels, out_channels=latent_channels,kernel_size=3,stride=1, padding=1),
            ResidualStack(in_channels=latent_channels,
                        num_hiddens=latent_channels,
                        num_residual_layers=2,
                        num_residual_hiddens=32),
            nn.Conv2d(latent_channels, self.quantise.v_width, kernel_size=1, stride=1),
        )

        self.decode = nn.Sequential(
            nn.Conv2d(self.quantise.v_width, latent_channels, kernel_size=3, stride=1, padding=1),
            ResidualStack(in_channels=latent_channels,
                        num_hiddens=latent_channels,
                        num_residual_layers=2,
                        num_residual_hiddens=32),
            DecodeBlock(latent_channels, latent_channels, 4, 2, 1),
            DecodeBlock(latent_channels, latent_channels//2, 4, 2, 1),
            nn.ConvTranspose2d(in_channels=latent_channels//2, out_channels=n_channels, kernel_size=4, stride=2, padding=1)
        )

    def interpolate(self, z_x, z_y, per):
        """Blend two latent batches linearly: ``per * z_x + (1 - per) * z_y``.

        Args:
            z_x, z_y: tensors of identical shape, batch dimension first.
            per: weight of ``z_x``. Either a scalar applied to every sample, or
                a sequence / 1-D tensor with one weight per sample.

        Returns:
            A new tensor with the shape, dtype and device of ``z_x``.
        """
        weights = torch.as_tensor(per, dtype=torch.float64)
        # Complement in float64, as ``1 - per`` on a Python float would be.
        complements = 1 - weights
        if weights.dim() == 1:
            # One weight per sample: broadcast across all non-batch dimensions.
            bshape = (-1,) + (1,) * (z_x.dim() - 1)
            weights, complements = weights.view(bshape), complements.view(bshape)
        weights = weights.to(device=z_x.device, dtype=z_x.dtype)
        complements = complements.to(device=z_x.device, dtype=z_x.dtype)
        return z_x * weights + z_y * complements

    def forward(self, x, detached=False):
        """Encode, quantise and decode ``x``.

        ``detached`` is unused. Returns
        ``(vq_loss, recon_x, int_x, indices, quantised)``; ``int_x`` is the
        same tensor as ``recon_x``.
        """
        z = self.encode(x)

        loss, quantised, perplexity, indices = self.quantise(z)

        recon_x = self.decode(quantised)
        int_x = recon_x

        return loss, recon_x, int_x, indices, quantised

    def forward_interpolate(self, x, y, per):
        """Decode the quantised blend of the latents of ``x`` and ``y`` (weight ``per`` on ``x``)."""
        z_x = self.encode(x)
        z_y = self.encode(y)

        z_int = self.interpolate(z_x, z_y, per)

        _, quantised_z, _, indices_z = self.quantise(z_int.detach())

        int_x = self.decode(quantised_z)

        return int_x

def decode_sample(sample, model):
    """Replace each codebook index in ``sample`` with its codebook vector.

    Args:
        sample: (B, 1, H, W) tensor of codebook indices; any numeric dtype
            (values are converted to int64).
        model: object exposing ``model.quantise.e_table`` (an ``nn.Embedding``
            codebook) and ``model.quantise.v_width``.

    Returns:
        (B, v_width, H, W) tensor on the codebook's device, ready for
        ``model.decode``.
    """
    codebook = model.quantise.e_table
    indices = sample.type(torch.int64).permute(0, 2, 3, 1)  # (B, H, W, 1)
    indices = indices.to(codebook.weight.device)
    n, h, w = indices.shape[:3]
    # Embedding lookup == one_hot(indices) @ codebook, without building the
    # (N, num_v) one-hot matrix.
    quantized = codebook(indices).view(n, h, w, model.quantise.v_width)
    return quantized.permute(0, 3, 1, 2).contiguous()

def fit_VAE(model, data_loader, optimiser):
    """Run 400 VQ-VAE optimisation steps and return the per-step losses.

    Batches come from the module-level cycling ``train_iterator``, so successive
    calls continue through the data instead of restarting; ``data_loader`` is
    not used. Also reads the globals ``device``, ``len_train``,
    ``data_varience`` and ``batch_size``.

    Returns:
        np.ndarray with one entry per step: (reconstruction MSE / data_varience
        + VQ loss) divided by ``batch_size``.
    """
    model.train()
    step_losses = []
    print(len_train)

    for _ in range(400):
        images, _ = next(train_iterator)
        images = images.to(device)
        optimiser.zero_grad()

        vq_loss, recon, _, _, _ = model(images)
        total = model.mse(recon, images) / data_varience + vq_loss

        total.backward()
        optimiser.step()

        step_losses.append(total.item() / batch_size)
    return np.array(step_losses)

def validate_VAE(model, data_loader):
    """Evaluate the VQ-VAE on the next 10 batches of the global ``train_iterator``.

    ``data_loader`` is not used. Reads the globals ``device``, ``data_varience``
    and ``batch_size``.

    Returns:
        (losses, x, recon_x, int_x): per-batch losses (divided by
        ``batch_size``), plus the last batch, its reconstruction and ``int_x``
        (which the model returns equal to the reconstruction).
    """
    model.eval()

    step_losses = []
    with torch.no_grad():
        for _ in range(10):
            x, _ = next(train_iterator)
            x = x.to(device)
            vq_loss, recon_x, int_x, _, _ = model(x)

            total = model.mse(recon_x, x) / data_varience + vq_loss
            step_losses.append(total.item() / batch_size)
    return np.array(step_losses), x, recon_x, int_x

def fit_Pixel(model, model_p, horse_iterator, bird_iterator, optimiser, optimiser_p, interpolations):
    """Train the PixelCNN prior for 100 steps on VQ-VAE codebook indices.

    Each batch is ``batch_size - len(interpolations)`` horse images (at least
    one) from ``horse_iterator``, followed by the fixed ``interpolations``. The
    VQ-VAE (not updated here) maps the batch to index maps. The PixelCNN is then
    trained with cross-entropy to predict each index from the index map.

    ``bird_iterator`` and ``optimiser`` are accepted for call compatibility but
    unused. Reads the globals ``device``, ``batch_size`` and ``len_train``.

    Returns:
        np.ndarray of per-step losses, each divided by ``batch_size``.
    """
    model.train()
    model_p.train()

    criterion = nn.CrossEntropyLoss()
    step_losses = []
    n_horses = max(1, batch_size - interpolations.size(0))
    print(len_train)
    for _ in range(100):
        # Collect the images first and concatenate once (no repeated re-copying).
        horses = [next(horse_iterator)[0].to(device) for _ in range(n_horses)]
        comb = torch.cat(horses + [interpolations.to(device)], dim=0)

        # Only the PixelCNN is optimised, and argmin indices carry no gradient,
        # so run the VQ-VAE without building an autograd graph.
        with torch.no_grad():
            _, _, _, indices, _ = model(comb)
        indices = indices.to(device).float()  # (B, 1, H, W) as PixelCNN input

        output = model_p(indices)
        loss = criterion(output, indices[:, 0, :, :].long())

        optimiser_p.zero_grad()
        loss.backward()
        optimiser_p.step()

        step_losses.append(loss.item() / batch_size)
    return np.array(step_losses)

       
def validate_Pixel(model, model_p, horse_iterator, bird_iterator, interpolations):
    """Evaluate the PixelCNN on 10 batches, then draw one batch of samples from it.

    Batches are built as in ``fit_Pixel``. Sampling fills a
    ``(batch_size, 1, H, W)`` index map one position at a time in raster order,
    drawing each index from the PixelCNN's softmax. The map is then decoded to
    images through the VQ-VAE codebook and decoder.

    ``bird_iterator`` is unused. Reads the globals ``device`` and ``batch_size``.

    Returns:
        (losses, comb, recon_x, samples): per-batch losses (divided by
        ``batch_size``), the last validation batch, its VQ-VAE reconstruction,
        and the decoded samples.
    """
    model.eval()
    model_p.eval()

    criterion = nn.CrossEntropyLoss()
    step_losses = []
    n_horses = max(1, batch_size - interpolations.size(0))

    with torch.no_grad():
        for _ in range(10):
            horses = [next(horse_iterator)[0].to(device) for _ in range(n_horses)]
            comb = torch.cat(horses + [interpolations.to(device)], dim=0)

            _, recon_x, _, indices, _ = model(comb)
            indices = indices.to(device).float()

            output = model_p(indices)
            loss = criterion(output, indices[:, 0, :, :].long())
            step_losses.append(loss.item() / batch_size)

        lat_h, lat_w = indices.size(2), indices.size(3)
        sample = torch.zeros(batch_size, 1, lat_h, lat_w, device=device)
        for row in range(lat_h):
            for col in range(lat_w):
                probs = F.softmax(model_p(sample)[:, :, row, col], dim=-1)
                sample[:, :, row, col] = torch.multinomial(probs, 1).float()

        q_latent = decode_sample(sample, model)
        sample = model.decode(q_latent)
    return np.array(step_losses), comb, recon_x, sample

#Add plot of latent space in generator
def show_stats(loss_train, loss_test):
    """Plot training and validation loss curves, each against its own entry index.

    Args:
        loss_train, loss_test: 1-D numpy arrays of loss values.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(np.arange(loss_train.size), loss_train, label='train')
    ax.plot(np.arange(loss_test.size), loss_test, label='validation')
    ax.set_xlabel('batch')
    ax.set_ylabel('loss')
    ax.legend()
    plt.show()
    # Called twice per epoch in a long loop; release the figure explicitly.
    plt.close(fig)

def show_example(original, reconstruction, interpolated):
    """Plot three image grids: the first 8 originals, the first 8 reconstructions,
    and the whole ``interpolated`` batch.

    Each tensor is reshaped with ``view(batch_size, n_channels, img_dim, img_dim)``
    using module-level globals, so each must hold exactly that many values.
    """
    plt.rcParams['figure.dpi'] = 101
    for images, limit in ((original, 8), (reconstruction, 8), (interpolated, None)):
        batch = images.view(batch_size, n_channels, img_dim, img_dim)[:limit]
        grid = torchvision.utils.make_grid(batch).detach().cpu()
        plt.grid(False)
        plt.imshow(grid.permute(1, 2, 0), cmap=plt.cm.binary)
        plt.show()

    plt.pause(0.0001)

def show_interpolation(images):
    """Plot up to the first 32 images of ``images`` as one grid (dpi 101).

    ``images`` is reshaped to (N, n_channels, img_dim, img_dim) using the
    module-level globals.
    """
    plt.rcParams['figure.dpi'] = 101
    first = images.view(images.size(0), n_channels, img_dim, img_dim)[:32]
    grid = torchvision.utils.make_grid(first).detach().cpu()
    plt.imshow(grid.permute(1, 2, 0), cmap=plt.cm.binary)
    plt.show()

def show_batch(original, reconstruction, interpolated):
    """Plot the whole ``interpolated`` batch as one grid at dpi 150.

    ``original`` and ``reconstruction`` are accepted (matching
    ``show_example``'s signature) but not displayed.
    """
    plt.rcParams['figure.dpi'] = 150
    plt.grid(False)
    grid = torchvision.utils.make_grid(interpolated).detach().cpu()
    plt.imshow(grid.permute(1, 2, 0), cmap=plt.cm.binary)
    plt.show()

def set_up_model():
    """Create the VQ-VAE and PixelCNN with their optimisers, restoring them from
    Drive checkpoints when those files exist.

    Reads the module-level globals ``device``, ``VQ_lr``, ``Pixel_lr``,
    ``lr_decay`` and ``img_dim``.

    Returns:
        (model, epoch, optimiser, scheduler, model_p, optimiser_p, epoch_p).
        ``epoch`` / ``epoch_p`` are 0 unless restored from a checkpoint.
    """
    vq_path = 'drive/My Drive/training/VQ-VAE-' + str(img_dim) + '.chkpt'
    pixel_path = 'drive/My Drive/training/PixelCNN-' + str(img_dim) + '.chkpt'

    epoch = 0
    model = VQ_VAE().to(device)
    optimiser = torch.optim.Adam(model.parameters(), lr=VQ_lr)

    epoch_p = 0
    model_p = PixelCNN().to(device)
    optimiser_p = torch.optim.Adam(model_p.parameters(), lr=Pixel_lr)
    scheduler = lr_scheduler.StepLR(optimiser_p, step_size=1, gamma=lr_decay)

    # map_location lets a checkpoint written on a GPU runtime be restored on a
    # CPU-only one (and vice versa).
    if exists(vq_path):
        params_vq = torch.load(vq_path, map_location=device)
        model.load_state_dict(params_vq['VQVAE'])
        optimiser.load_state_dict(params_vq['optimiser'])
        epoch = params_vq['epoch']

    if exists(pixel_path):
        params = torch.load(pixel_path, map_location=device)
        model_p.load_state_dict(params['model_p'])
        scheduler.load_state_dict(params['scheduler'])
        optimiser_p.load_state_dict(params['optimiser_p'])
        epoch_p = params['epoch']
    return model, epoch, optimiser, scheduler, model_p, optimiser_p, epoch_p

def train(model, model_p, epoch, optimiser, optimiser_p, epoch_p):
    """Run the full pipeline in three phases:

    1. Train the VQ-VAE until 40 epochs are done.
    2. Train the PixelCNN prior until 103 epochs are done.
    3. Repeatedly sample from the prior and display the samples.

    Every epoch plots loss curves and example images, then checkpoints to
    Drive. The stored ``'epoch'`` is the number of *completed* epochs in both
    phases, so feeding it back through ``set_up_model`` resumes at the next one.

    Reads the module-level globals ``train_loader``, ``horse_iterator``,
    ``bird_iterator``, ``img_dim`` and ``scheduler``.
    NOTE(review): ``scheduler`` is saved in the PixelCNN checkpoint but never
    stepped here; confirm whether ``scheduler.step()`` was intended.
    """
    vq_path = 'drive/My Drive/training/VQ-VAE-' + str(img_dim) + '.chkpt'
    pixel_path = 'drive/My Drive/training/PixelCNN-' + str(img_dim) + '.chkpt'

    # Phase 1: VQ-VAE.
    loss_train, loss_test = np.zeros(0), np.zeros(0)
    while epoch < 40:
        avg_batch_loss_t = fit_VAE(model, train_loader, optimiser)
        avg_batch_loss_v, original, reconstruction, interpolated = validate_VAE(model, train_loader)

        print('train_loss ' + str(avg_batch_loss_t.mean()))
        print('val_loss ' + str(avg_batch_loss_v.mean()))

        loss_train = np.concatenate([loss_train, avg_batch_loss_t])
        loss_test = np.concatenate([loss_test, avg_batch_loss_v])

        show_stats(loss_train, loss_test)
        show_stats(avg_batch_loss_t, avg_batch_loss_v)
        show_example(original+0.5, reconstruction+0.5, interpolated+0.5)
        print(epoch)
        # Count this epoch as finished *before* saving (as the PixelCNN phase
        # does); saving the old value made every resume repeat the last epoch.
        epoch = epoch + 1
        torch.save({'VQVAE': model.state_dict(),
                    'optimiser': optimiser.state_dict(),
                    'epoch': epoch},
                   vq_path)

    # Phase 2: PixelCNN prior over the VQ-VAE codes.
    interpolations = interpolate_birds_horses(model)
    # Start a fresh history: the PixelCNN cross-entropy is a different quantity
    # (and scale) from the VQ-VAE loss, so one shared curve would be misleading.
    loss_train, loss_test = np.zeros(0), np.zeros(0)
    while epoch_p < 103:
        avg_batch_loss_t = fit_Pixel(model, model_p, horse_iterator, bird_iterator, optimiser, optimiser_p, interpolations)
        avg_batch_loss_v, original, reconstruction, recon_z = validate_Pixel(model, model_p, horse_iterator, bird_iterator, interpolations)

        print('train_loss ' + str(avg_batch_loss_t.mean()))
        print('val_loss ' + str(avg_batch_loss_v.mean()))

        loss_train = np.concatenate([loss_train, avg_batch_loss_t])
        loss_test = np.concatenate([loss_test, avg_batch_loss_v])

        show_stats(loss_train, loss_test)
        show_stats(avg_batch_loss_t, avg_batch_loss_v)

        show_example(original+0.5, reconstruction+0.5, recon_z+0.5)
        print(epoch_p)
        epoch_p = epoch_p + 1
        torch.save({'model_p': model_p.state_dict(),
                    'optimiser_p': optimiser_p.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'epoch': epoch_p},
                   pixel_path)

    # Phase 3: keep sampling from the trained prior and display the samples.
    for _ in range(250):
        avg_batch_loss_v, original, reconstruction, recon_z = validate_Pixel(model, model_p, horse_iterator, bird_iterator, interpolations)
        show_batch(original+0.5, reconstruction+0.5, recon_z+0.5)


def interpolate_birds_horses(model, n_horses=21, n_birds=7, bird_index=4, per=0.66):
    """Build the fixed batch of bird/horse latent interpolations for the PixelCNN phase.

    For each of ``n_horses`` horses drawn from the global ``horse_iterator``,
    ``n_birds`` birds are drawn from the global ``bird_iterator``. This advances
    both cycling iterators exactly as before, but only the bird at position
    ``bird_index`` is used. That bird gets contrast x2 and gamma 1.5, and the
    horse is blended with it, and then with its horizontal flip, via
    ``model.forward_interpolate(horse, bird, per)``.

    Reads the global ``device``.

    Returns:
        Tensor of ``2 * n_horses`` decoded images, ordered horse 0 (bird,
        flipped bird), horse 1 (bird, flipped bird), and so on.

    Raises:
        ValueError: if ``n_horses < 1`` or ``bird_index`` is outside
            ``range(n_birds)`` (nothing would be generated).
    """
    if n_horses < 1 or not 0 <= bird_index < n_birds:
        raise ValueError("need n_horses >= 1 and 0 <= bird_index < n_birds")

    interps = []
    # The results are only used as fixed PixelCNN inputs, so no autograd graph.
    with torch.no_grad():
        for hind in range(n_horses):
            h = next(horse_iterator)[0].to(device)
            for bind in range(n_birds):
                bird = next(bird_iterator)[0]
                if bind != bird_index:
                    continue  # drawn only to keep the iterator's cycle position
                # NOTE(review): images are normalised to [-0.5, 0.5], but torchvision's
                # adjust_contrast / adjust_gamma assume float images in [0, 1] and
                # clamp to that range, so negative pixels end up at 0. Confirm intent.
                b = torchvision.transforms.functional.adjust_contrast(bird.to(device), 2)
                b = torchvision.transforms.functional.adjust_gamma(b, 1.5)
                print("h:   ", hind, "b:   ", bind)
                interps.append(model.forward_interpolate(h, b, per))
                b = torchvision.transforms.functional.hflip(b)
                interps.append(model.forward_interpolate(h, b, per))

    interps = torch.cat(interps, dim=0)
    print(interps.size())
    print("------------------------------------------------------ \n ---------------------------------------------------------- \n -------------------------------------------------")
    return interps

Random Seed:  2401
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
if __name__ == "__main__":
    # Run configuration. These names become module-level globals that the
    # functions in the previous cell read directly.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    batch_size = 64
    img_dim = 96        # images are resized to img_dim x img_dim
    n_channels = 3

    VQ_lr = 0.001
    Pixel_lr = 0.0001
    lr_decay = 0.999995   # StepLR gamma for the PixelCNN optimiser

    dataset = 'stl10'
    (train_loader, len_train, class_names,
     bird_loader, horse_loader, data_varience) = get_data(dataset)

    # Endless stream of training batches shared by fit_VAE / validate_VAE.
    train_iterator = iter(cycle(train_loader))

cuda
Files already downloaded and verified


In [None]:
# Endless iterators over the hand-picked horse / bird subsets.
horse_iterator = iter(cycle(horse_loader))
bird_iterator = iter(cycle(bird_loader))

# `scheduler` must remain a module-level name: train() reads it as a global.
(model, epoch, optimiser, scheduler,
 model_p, optimiser_p, epoch_p) = set_up_model()
train(model, model_p, epoch, optimiser, optimiser_p, epoch_p)
# interpolate_birds_horses(model)  # alternative: only build the interpolations
