In [1]:
import math
import time
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

from torch import nn, optim
import torch.nn.functional as F
import torch
import torchvision
from torchvision import models

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from PIL import Image

import torch.nn.functional as fnn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
import os

import yaml

from sklearn.linear_model import LogisticRegressionCV
from firelab.config import Config

from visualize import random_interpolation, uniform_interpolation, visualization
from utils import load_dataset
from evaluation import fit_FC
from modules import Autoencoder, Critic
from train import train_acai, train_baseline

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


# Train ACAI (Adversarially Constrained Autoencoder Interpolation) on CelebA

# Set `PATH` to the folder that contains the raw CelebA images, and `DEVICE` to the torch device to train on

In [None]:
# User configuration — both values are `...` placeholders and must be filled in.
PATH = ...    # directory holding the raw CelebA images (e.g. '.../img_align_celeba/')
DEVICE = ...  # torch device string to train on, probably 'cuda:0'

In [2]:
class CelebA(Dataset):
    """CelebA images read directly from a folder, split into train/test by file order.

    Files are listed in sorted filename order. The first ``split_index`` files form
    the ``'train'`` part; any other ``part`` value selects the remaining files.
    The default 182637 is presumably the official CelebA train + validation count,
    which is only meaningful with sorted filenames. TODO confirm.

    ``self.transform`` is attached externally (see ``make_dataloader``). It receives
    a PIL RGB image. While it is ``None``, ``__getitem__`` returns the PIL image itself.
    """

    def __init__(self, path='/root/data/CelebA/img_align_celeba/', part='train', split_index=182637):
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sort them so the train/test split is identical across machines and runs.
        all_paths = [os.path.join(path, name) for name in sorted(os.listdir(path))]
        self.data = all_paths[:split_index] if part == 'train' else all_paths[split_index:]
        # Replaced by make_dataloader; None means "return raw PIL images".
        self.transform = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Close the file promptly, so DataLoader workers do not pile up open handles.
        # convert('RGB') forces the pixels to load before the file is closed, and
        # guarantees 3 channels.
        with Image.open(self.data[idx]) as img:
            img = img.convert('RGB')
        if self.transform is None:
            return img
        return self.transform(img)
    
def make_dataloader(dataset, batch_size, image_size=4):
    """Attach the preprocessing pipeline to ``dataset`` and wrap it in a DataLoader.

    Side effect: overwrites ``dataset.transform``. The pipeline does three things:
    it resizes to ``image_size`` x ``image_size`` (non-square inputs are stretched,
    not cropped), applies a random horizontal flip, and converts to a tensor.

    Returns a shuffled DataLoader with 4 worker processes. It drops the last
    incomplete batch, so every batch has exactly ``batch_size`` samples.
    """
    pipeline = [
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    dataset.transform = transforms.Compose(pipeline)

    loader_options = dict(shuffle=True, batch_size=batch_size, num_workers=4, drop_last=True)
    return DataLoader(dataset, **loader_options)

In [3]:
# 64x64 CelebA training images. Batch size 64 is kept equal to args['batch_size'] defined below.
train_loader = make_dataloader(
    CelebA(path=PATH),
    batch_size=64,
    image_size=64,
)

In [4]:
# Create the output directory. exist_ok=True keeps the cell re-runnable:
# plain os.mkdir raises FileExistsError on the second run.
os.makedirs('CelebA64_256_v2', exist_ok=True)

In [4]:
# Training hyper-parameters shared by the model, optimizer and training-loop cells.
args = {
    # NOTE(review): training data is CelebA (see train_loader). 'dataset' and
    # 'n_classes' are not read by any cell shown here; confirm before relying on them.
    'dataset': 'MNIST',
    'eval_each': 10,
    'epochs': 101,
    'log_dir': 'CelebA64_256_v2/',
    # Must match the device VGGLoss moves its extractor to (DEVICE). A separately
    # hard-coded GPU id causes a cross-device error in the perceptual loss.
    'device': DEVICE,
    'weight_decay': 1e-05,
    'depth': 16,
    'gamma': 0.2,
    'lmbda': 0.5,
    'batch_norm': False,
    'batch_size': 64,
    'colors': 3,
    'latent_width': 4,  # Bottleneck HW
    # scales = round(log2(width // latent_width)) = log2(128 // 4) = 5 scales.
    # NOTE(review): images are 64x64 (make_dataloader(..., image_size=64)). If each
    # scale halves the resolution, the bottleneck ends at 2x2 rather than
    # latent_width=4. Confirm whether width should be 64.
    'width': 128,
    'latent': 32,  # Bottleneck channels
    'n_classes': 10,
    'advdepth': 16,
    'lr': 0.0001,
}

In [5]:
# Number of resolution scales between `width` and the latent spatial size `latent_width`.
scales = round(math.log2(args['width'] // args['latent_width']))

autoencoder = Autoencoder(
    scales=scales,
    depth=args['depth'],
    latent=args['latent'],
    colors=args['colors'],
).to(args['device'])

# The critic mirrors the encoder's scales but has its own depth (advdepth).
critic = Critic(
    scales=scales,
    depth=args['advdepth'],
    latent=args['latent'],
    colors=args['colors'],
).to(args['device'])

In [6]:
# Define optimizers
# Autoencoder and critic use separate Adam optimizers with identical hyper-parameters.
adam_settings = {'lr': args['lr'], 'weight_decay': args['weight_decay']}
opt_ae = optim.Adam(autoencoder.parameters(), **adam_settings)
opt_c = optim.Adam(critic.parameters(), **adam_settings)

# NOTE(review): not populated by any cell shown here.
losses = defaultdict(list)

# Define the VGG-19 perceptual loss

In [None]:
def tanh2sigmoid(batch):
    """Affinely map values from the tanh range [-1, 1] onto the sigmoid range [0, 1]."""
    return batch * 0.5 + 0.5

class VGGExtractor(nn.Module):
    """Split a VGG ``features`` stack into five consecutive stages.

    ``forward(x, level)`` normalizes ``x`` with ImageNet channel statistics. It then
    runs the first ``level`` stages in sequence and returns each stage's output as
    a list (at most five entries).

    For torchvision's VGG-19 ``features``, the stage slices end at relu1_1, relu2_1,
    relu3_1, relu4_1 and relu5_1. The attribute names below are offset by one from
    that convention.
    """

    def __init__(self, vgg):
        super().__init__()

        # ImageNet per-channel statistics, shaped (1, 3, 1, 1) to broadcast over NCHW input.
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        self.relu0_1 = vgg[:2]
        self.relu1_1 = vgg[2:7]
        self.relu2_1 = vgg[7:12]
        self.relu3_1 = vgg[12:21]
        self.relu4_1 = vgg[21:30]

    def forward(self, x, level=1):
        # Assumes x is an RGB batch in [0, 1], which is what ImageNet normalization
        # expects. TODO confirm for decoder outputs.
        h = (x - self.mean) / self.std
        stages = (self.relu0_1, self.relu1_1, self.relu2_1, self.relu3_1, self.relu4_1)
        outputs = []
        for stage in stages[:level]:
            h = stage(h)
            outputs.append(h)
        return outputs
    
class VGGLoss(nn.Module):
    """Multi-level perceptual loss on frozen, ImageNet-pretrained VGG-19 features.

    ``forward(inp1, inp2, level)`` returns a list with one scalar MSE per extracted
    VGG stage, comparing the features of ``inp1`` and ``inp2``. The list has
    ``level`` entries, capped at the five stages VGGExtractor provides.

    Args:
        device: where to place the extractor. Defaults to the notebook-level ``DEVICE``.
    """

    def __init__(self, device=None):
        super(VGGLoss, self).__init__()
        if device is None:
            device = DEVICE
        self.feature_exctracor = VGGExtractor(models.vgg19(pretrained=True).features).to(device)
        self.feature_exctracor.eval()
        # VGG is a fixed feature extractor and no optimizer owns its weights. Freeze
        # them so each backward pass does not compute and accumulate gradients for them.
        # Gradients still flow to the inputs.
        for param in self.feature_exctracor.parameters():
            param.requires_grad_(False)

    def forward(self, inp1, inp2, level):
        features1 = self.feature_exctracor(inp1, level)
        features2 = self.feature_exctracor(inp2, level)
        # zip (instead of range(level)) stays in bounds when level exceeds the stage count.
        return [(f1 - f2).pow(2).mean() for f1, f2 in zip(features1, features2)]

vgg_loss = VGGLoss()

# Training loop

In [None]:
LOG_DIR = args['log_dir']  # 'CelebA64_256_v2/'

for epoch in range(args['epochs']):
    for X in tqdm(train_loader, total=len(train_loader), leave=False, desc=f'Epoch: {epoch}'):
        X = X.to(args['device'])
        # Read the batch size from the tensor so the shapes below always match the batch.
        batch_size = X.size(0)

        # One interpolation coefficient per sample, uniform in [0, 0.5).
        alpha = 0.5 * torch.rand(batch_size, 1, 1, 1, device=args['device'])

        latent_code = autoencoder.encoder(X)
        reconstruction = autoencoder.decoder(latent_code)

        # Pair sample i with sample i-1 (sample 0 with the last one), then move
        # alpha of the way from each latent code towards its partner's code.
        partner_code = torch.roll(latent_code, shifts=1, dims=0)
        interpolated_code = latent_code + alpha * (partner_code - latent_code)
        reconstruction_interpolated = autoencoder.decoder(interpolated_code)

        # ---------------- Autoencoder update ----------------
        # Term 1: reconstruction = pixel MSE + VGG perceptual terms weighted 1, 1/5, 1/10.
        # Term 2: fool the critic by driving its predicted alpha on decoded
        #         interpolations towards 0.
        reconstruction_loss = F.mse_loss(reconstruction, X)
        perceptual1, perceptual2, perceptual3 = vgg_loss(X, reconstruction, 3)
        perceptual_loss = perceptual1 + perceptual2 / 5 + perceptual3 / 10
        alpha_reconstruction = critic(reconstruction_interpolated).reshape(batch_size, 1, 1, 1)
        # NOTE(review): summed over the batch while the other terms are means, so this
        # term's effective weight grows with batch size. Confirm this is intended.
        critic_fooling_loss = (alpha_reconstruction ** 2).sum()
        ae_loss = reconstruction_loss + perceptual_loss + args['lmbda'] * critic_fooling_loss

        opt_ae.zero_grad()
        ae_loss.backward()
        opt_ae.step()

        # ---------------- Critic update ----------------
        # Detach the autoencoder outputs for two reasons:
        #  - the critic step must only produce critic gradients;
        #  - backpropagating into the decoder after opt_ae.step() has changed its
        #    weights in place is invalid (recent PyTorch raises an in-place
        #    modification error).
        # opt_ae does not touch the critic, so the critic sees the same weights as above.
        detached_interpolated = reconstruction_interpolated.detach()
        detached_reconstruction = reconstruction.detach()

        # Term 1: the critic tries to recover the true alpha from decoded interpolations.
        # Term 2 (regularizer): drive the critic's output to 0 on data-space blends of
        #         each image with its reconstruction. With a perfect autoencoder this
        #         reduces to critic(X) -> 0.
        alpha_guess = critic(detached_interpolated).reshape(batch_size, 1, 1, 1)
        alpha_guessing_loss = F.mse_loss(alpha_guess, alpha)
        blended = args['gamma'] * X + (1 - args['gamma']) * detached_reconstruction
        realistic_loss = (critic(blended) ** 2).sum()
        critic_loss = alpha_guessing_loss + realistic_loss

        # opt_c.zero_grad() also discards the critic grads left over from ae_loss.backward().
        opt_c.zero_grad()
        critic_loss.backward()
        opt_c.step()

    # Once per epoch: save an interpolation visualization and overwrite the AE checkpoint.
    # NOTE(review): args['eval_each'] is not used here; visualization runs every epoch.
    uniform_interpolation(autoencoder, train_loader, N=11, savepath=f'{LOG_DIR}{epoch}.png')
    torch.save(autoencoder.state_dict(), LOG_DIR + 'ae.pt')

HBox(children=(FloatProgress(value=0.0, description='Epoch: 0', max=2853.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 1', max=2853.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 2', max=2853.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 3', max=2853.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 4', max=2853.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 5', max=2853.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 6', max=2853.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 7', max=2853.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 8', max=2853.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 9', max=2853.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 10', max=2853.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 11', max=2853.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 12', max=2853.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 13', max=2853.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 14', max=2853.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 15', max=2853.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 16', max=2853.0, style=ProgressStyle(description_w…

HBox(children=(FloatProgress(value=0.0, description='Epoch: 17', max=2853.0, style=ProgressStyle(description_w…

In [None]:
# Final checkpoint holding both networks' weights.
checkpoint = {
    'AE': autoencoder.state_dict(),
    'Critic': critic.state_dict(),
}
torch.save(checkpoint, 'CelebA64_256_v2/acai.pt')