In [97]:
import os
import math
import random

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch.utils.data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [96]:
reset

Once deleted, variables cannot be recovered. Proceed (y/[n])? y


In [98]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [99]:
image_size = 32
batch_size = 4

dataset = datasets.CIFAR10(root="./data/cifar", 
                           download=True,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader = torch.utils.data.DataLoader(dataset, 
                                         batch_size=batch_size,
                                         shuffle=True)

Files already downloaded and verified


In [100]:
def weights_init(m):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            m.weight.data.normal_(mean=1.0, std=0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)

In [101]:
class Discriminator(nn.Module):
    def __init__(self, num_classes):
        super(Discriminator, self).__init__()
        
        self.conv_1 = nn.Conv2d(3, 64, kernel_size = 3, stride = 2, padding = 1)
        self.batch_norm_1 = nn.BatchNorm2d(64)
        
        self.conv_2 = nn.Conv2d(64, 128, kernel_size = 4, stride = 2, padding = 1)
        self.batch_norm_2 = nn.BatchNorm2d(128)
        
        self.conv_3 = nn.Conv2d(128, 256, kernel_size = 4, stride = 2, padding = 1)
        self.batch_norm_3 = nn.BatchNorm2d(256)
        
        self.conv_4 = nn.Conv2d(256, 512, kernel_size = 4, stride = 2, padding = 1)
        self.batch_norm_4 = nn.BatchNorm2d(512)
        
        self.conv_5 = nn.Conv2d(512, num_classes, kernel_size = 2, stride = 1, padding = 0)

        
    def forward(self, x):
        '''
        Inputs:
            x: (batch x 3 x 32 x 32)
        Outputs:
            prob: (batch x 1)
        '''
        x = self.conv_1(x)
        x = self.batch_norm_1(x)
        x = F.leaky_relu(x, 0.2)
        
        x = self.conv_2(x)
        x = self.batch_norm_2(x)
        x = F.leaky_relu(x, 0.2)
        
        x = self.conv_3(x)
        x = self.batch_norm_3(x)
        x = F.leaky_relu(x, 0.2)
        
        x = self.conv_4(x)
        x = self.batch_norm_4(x)
        x = F.leaky_relu(x, 0.2)
        
        x = self.conv_5(x)
        
        x = x.view(x.size(0), -1)
    
        return x

          
class Generator(nn.Module):
    def __init__(self, z_size):
        super(Generator, self).__init__()
        
        self.conv_1 = nn.ConvTranspose2d(z_size, 512, kernel_size = 4, stride = 1, padding = 0, bias=False)     
        
        self.conv_2 = nn.ConvTranspose2d(512, 256, kernel_size = 4, stride = 2, padding = 1, bias=False)
        self.batch_norm_2 = nn.BatchNorm2d(256)
        
        self.conv_3 = nn.ConvTranspose2d(256, 128, kernel_size = 4, stride = 2, padding = 1, bias=False)
        self.batch_norm_3 = nn.BatchNorm2d(128)
        
        self.conv_4 = nn.ConvTranspose2d(128, 3, kernel_size = 4, stride = 2, padding = 1, bias=False)

    
    def forward(self, noise):
        '''
        Inputs:
            noise: (batch x z_size)
        Outputs:
            image: (batch x 3 x 32 x 32)
        '''
        #code here
        image = noise.view(noise.size(0), 100, 1, 1)

        
        image = self.conv_1(image)
        image = F.leaky_relu(image, 0.2)
        
        image = self.conv_2(image)
        image = self.batch_norm_2(image)
        image = F.leaky_relu(image, 0.2)
        
        image = self.conv_3(image)
        image = self.batch_norm_3(image)
        image = F.leaky_relu(image, 0.2)
        
        image = self.conv_4(image)
        image = torch.tanh(image)
        
        return image

In [18]:
disc = Discriminator(11)
xx = disc(torch.zeros(10, 3, 32, 32))


In [19]:
xx.size()

torch.Size([10, 11])

In [102]:
class NoiseLoss(nn.Module):
    def __init__(self, scale=None, observed=None):
        super(NoiseLoss, self).__init__()
        
        self.scale    = 1
        self.observed = observed
        
        if scale is not None:
            self.scale = scale
            
    def forward(self, params):
        noise_loss = 0.0
        for param in params:
            noise = torch.empty_like(param).normal_(mean=0.0, std=1.0)
            noise_loss += self.scale * (noise * param).sum()
        noise_loss = noise_loss / self.observed
        return noise_loss
    
class PriorLoss(nn.Module):
    def __init__(self, prior_std=1.0, observed=None):
        super(PriorLoss, self).__init__()
        
        self.prior_std = prior_std
        self.observed  = observed
        
    def forward(self, params):
        prior_loss = 0.0
        for param in params:
            prior_loss += (param.pow(2) / self.prior_std ** 2).sum()
        prior_loss = prior_loss / self.observed
        return prior_loss

In [103]:
def adversarial_loss(logits):
    '''
    Inputs:
        logits: (batch x num_classes)
    Outputs:
        loss: scalar
    '''
    probs = F.softmax(logits, dim = -1)

    log_probs = (1 - probs[:, 0] + 1e-4).log()

    loss = torch.mean(-log_probs)

    return loss
    

In [66]:
xx = adversarial_loss(torch.ones(28, 11).to(device))

torch.Size([28])
tensor(0.0952, device='cuda:0')


In [48]:
xx.item()

0.09520021826028824

In [104]:
num_classes = 11
discriminator = Discriminator(num_classes).to(device)



latent_size = 100
num_z    = 1
num_mcmc = 4

generators = []
for generator in range(num_mcmc):
    generator = Generator(latent_size).to(device)
    generators.append(generator)


criterion = nn.CrossEntropyLoss()

lr    = 0.0002
beta1 = 0.5
beta2 = 0.999

disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))
    
gen_optimizers = []
for generator in generators:
    optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
    gen_optimizers.append(optimizer)
    
    
gen_noise_alpha  = 0.0001
disc_noise_alpha = 0.0001

gen_prior_criterion  = PriorLoss(prior_std=1., observed=1000.)
gen_noise_criterion  = NoiseLoss(scale=math.sqrt(2 * gen_noise_alpha / lr), observed=1000.)
disc_prior_criterion = PriorLoss(prior_std=1., observed=50000.)
disc_noise_criterion = NoiseLoss(scale=math.sqrt(2 * disc_noise_alpha * lr), observed=50000.)


epoch      = 0
num_epochs = 25

dis_losses = []
gen_losses = []

In [105]:
fixed_noise = torch.randn(batch_size, latent_size).to(device)

In [106]:
def generate_images(generators):
    '''
    Inputs: 
        generators: list of generators
    Outputs:
        generated images: (batch * num_generators x 3 x 32 x 32)
    '''
       
    noise = torch.randn(batch_size, latent_size).to(device)
    
    outputs = []
    for generator in generators:
        outputs.append(generator(noise))
        
    return torch.cat(outputs)

In [31]:
output  = generate_images(generators)
print(output.size())

torch.Size([128, 3, 32, 32])


In [112]:
while epoch < num_epochs:
    epoch += 1
    
    for batch_idx, (image, _) in enumerate(dataloader):
        image      = image.to(device) 
        batch_size = image.size(0)
        
        
        #################
        #Generate Images
        #################
                
        generated_images = generate_images(generators)
        
        
        #################
        #Train Generator
        #################
        
        for generator in generators:
            generator.zero_grad()
           
        #####################
        #code here###########
        output = discriminator(generated_images)
        gen_loss = adversarial_loss(output)
        
        for generator in generators:
            gen_loss += gen_prior_criterion(generator.parameters())
            gen_loss += gen_noise_criterion(generator.parameters())
            
        gen_loss.backward()

        for optimizer in gen_optimizers:
            optimizer.step()
        
        
        #################################
        #Train Discriminator on Real Data
        #################################
    
        discriminator.zero_grad()
        ########################
        #code here##############
        real_logits = discriminator(image)
        real_loss   = adversarial_loss(real_logits)
        real_loss.backward()
        
        #################################
        #Train Discriminator on Fake Data
        #################################
        
        fake_logits = discriminator(generated_images.detach())
        fake_labels = torch.zeros(fake_logits.size(0)).long().to(device)
        
        ########################
        #code here##############
        fake_loss   = criterion(fake_logits, fake_labels)
        fake_loss.backward()
        
        
        ###############################
        #Train Supervised Discriminator
        ###############################
        
        
        
        for input_supervised, target_supervised in dataloader:
            input_supervised, target_supervised = input_supervised.to(device), target_supervised.to(device)
            break
        
        ##############################
        #code here####################
        logits = discriminator(input_supervised)
        target_supervised = target_supervised + 1
        loss_supervised   = criterion(logits, target_supervised)
        loss_supervised.backward()
        
        disc_prior_criterion(discriminator.parameters()).backward()
        disc_noise_criterion(discriminator.parameters()).backward()
        
        disc_optimizer.step()
            
            
        if batch_idx % 100 == 0:
            for i, generator in enumerate(generators):
                    torchvision.utils.save_image(generator(fixed_noise).data,
                                          '%s/%d_%d.png' % ('./bayes_gan_images', epoch, i),
                                          normalize=True)

In [111]:
fake_logits.size()

torch.Size([16, 11])

<h1>Домашнее задание</h1>
<h3>1. Попробуйте прочитать самостоятельно статью и понять что такое Prior и Noise Losses</h3>
<h3>2. В статье сказано что Bayes GAN умеет учиться классифицировать изображения в semi-supervised сеттинге. Создайте еще
один даталоудер с датасетом 4000 картинок и обучите ган. Посчитайте accuracy</h3>

In [None]:
for input_supervised, target_supervised in !!!!SEMI_SUPERVISED_DATALOADER!!!!!:
    input_supervised, target_supervised = input_supervised.to(device), target_supervised.to(device)
    break