In [1]:
import os, sys, time
import itertools
import imageio
import math
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
from dataset import get_data
from scipy.misc import imsave
import warnings
warnings.filterwarnings("ignore")

In [2]:
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.utils as utils
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torchvision.utils as v_utils
from torch.autograd import Variable

In [3]:
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [12]:
class arguments():
    def __init__(self):
        self.dataset = 'MNIST'
        self.dataroot = '/data/jehyuk/imgdata'
        self.workers = 2
        self.n_gpu = 1
        self.batchsize = 64
        self.maxepoch = 100
        self.imagesize = 64
        self.lrG = 0.0002
        self.lrD = 0.0002
        self.channel_bunch = 64
        self.use_cuda = True
        self.n_z = 64
        self.result_dir = '/home/jehyuk/GenerativeModels/GAN/results/AAE/' + self.dataset
        self.save_dir = '/home/jehyuk/GenerativeModels/GAN/models/AAE/' + self.dataset
        self.n_sample = 16

opt = arguments()

In [13]:
torch.manual_seed(20)
torch.cuda.manual_seed_all(20)

In [14]:
def load_dataset(dataroot = opt.dataroot, dataset=opt.dataset):
    data_folder = os.path.join(dataroot, dataset)
    if not os.path.exists(data_folder):
        os.makedirs(data_folder)
    transform = transforms.Compose([transforms.Scale(opt.imagesize),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    if dataset == 'MNIST':
        trn_data = dset.MNIST(data_folder, train=True, transform=transform, download=True)
        tst_data = dset.MNIST(data_folder, train=False, transform=transform, download=True)
        n_channels = 1
    elif dataset == 'Fashion-MNIST':
        trn_data = dset.FashionMNIST(data_folder, train=True, transform=transform, download=True)
        tst_data = dset.FashionMNIST(data_folder, train=False, transform=transform, download=True)
        n_channels = 1
    elif dataset == 'CIFAR10':
        trn_data = dset.cifar.CIFAR10(data_folder, train=True, transform=transform, download=True)
        tst_data = dset.cifar.CIFAR10(data_folder, train=False, transform=transform, download=True)
        n_channels = 3
    elif dataset == 'CelebA':
        trn_data = get_data(data_folder, split='train', image_size=opt.imagesize)
        tst_data = get_data(data_folder, split='test', image_size=opt.imagesize)
        n_channels = 3
    trn_loader = utils.data.DataLoader(trn_data, batch_size=opt.batchsize, shuffle=True, num_workers=opt.workers, drop_last=True)
    tst_loader = utils.data.DataLoader(tst_data, batch_size=opt.batchsize, shuffle=False, num_workers=opt.workers, drop_last=True)
    return trn_loader, tst_loader, n_channels

In [15]:
def initialize_weights(m):
    classname = m.__class__.__name__
    if classname.find('LInear') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)


In [37]:
class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
        self.config = config
        
        layers = nn.Sequential()
        layers.add_module("Linear1", nn.Linear(self.config.imagesize*self.config.imagesize, 1000))
        layers.add_module("Activation1", nn.ReLU(inplace=True))
        layers.add_module("Linear2", nn.Linear(1000, 1000))
        layers.add_module("Activation2", nn.ReLU(inplace=True))
        layers.add_module("Linear3", nn.Linear(1000, self.config.n_z))
        self.layers = layers
        
    def forward(self, x):
        x = x.view(-1, self.config.imagesize * self.config.imagesize)
        x = self.layers(x)
        return x

class Decoder(nn.Module):
    def __init__(self, config):
        super(Decoder, self).__init__()
        self.config = config
        
        layers = nn.Sequential()
        layers.add_module("Linear1", nn.Linear(config.n_z, 1000))
        layers.add_module("Activation1", nn.ReLU(inplace=True))
        layers.add_module("Linear2", nn.Linear(1000, 1000))
        layers.add_module("Activation2", nn.ReLU(inplace=True))
        layers.add_module("Linear3", nn.Linear(1000, self.config.imagesize*self.config.imagesize))
        layers.add_module("Activation3", nn.Sigmoid())
        self.layers = layers
        
    def forward(self, z):
        x = self.layers(z)
        x = x.view(-1, self.config.imagesize, self.config.imagesize)
        return x
    
class Discriminator(nn.Module):
    def __init__(self, config):
        super(Discriminator, self).__init__()
        self.config = config
        
        layers = nn.Sequential()
        layers.add_module("Linear1", nn.Linear(config.n_z, 1000))
        layers.add_module("Activation1", nn.ReLU(inplace=True))
        layers.add_module("Linear2", nn.Linear(1000, 1000))
        layers.add_module("Activation2", nn.ReLU(inplace=True))
        layers.add_module("Linear3", nn.Linear(1000, 1))
        layers.add_module("Activation3", nn.Sigmoid())
        self.layers = layers
    
    def forward(self, z):
        return self.layers(z)
        

In [35]:
class AAE(nn.Module):
    def __init__(self, config):
        super(AAE, self).__init__()
        self.config = config
        self.trn_loader, self.tst_loader, _ = load_dataset(self.config.dataroot, self.config.dataset)
        self.is_cuda = torch.cuda.is_available()
        
        self.encoder = Encoder(self.config)
        self.decoder = Decoder(self.config)
        self.discriminator = Discriminator(self.config)
        self.encoder.apply(initialize_weights)
        self.decoder.apply(initialize_weights)

        self.discriminator.apply(initialize_weights)
        self.sample_z = Variable(torch.randn((self.config.n_sample, self.config.n_z)), volatile=True)
        
        if self.is_cuda and self.config.use_cuda:
            selfencoder, self.decoder, self.discriminator = self.encoder.cuda(), self.decoder.cuda(), self.discriminator.cuda()
            self.sample_z = self.sample_z.cuda()
        
        self.optim_encoder = torch.optim.Adam(params=self.encoder.parameters(), lr=self.config.lrG, betas=(0.5, 0.999))
        self.optim_decoder = torch.optim.Adam(params=self.decoder.parameters(), lr=self.config.lrG, betas=(0.5, 0.999))
        self.optim_generator = torch.optim.Adam(params=self.encoder.parameters(), lr = self.config.lrG, betas=(0.5, 0.999))
        self.optim_discriminator = torch.optim.Adam(params=self.discriminator.parameters(), lr=self.config.lrD, betas=(0.5, 0.999))
        self.BCEloss = nn.BCELoss()
        self.MSEloss = nn.MSELoss()
        
    def train(self):
        self.loss_dict=dict()
        self.loss_dict['recon_loss'] = list()
        self.loss_dict['D_fake_loss'], self.loss_dict['D_real_loss'], self.loss_dict['D_loss'] = list(), list(), list()
        
        print('------------------Start training------------------')
        for epoch in range(self.config.maxepoch):
            print(">>>>Epoch: {}".format(epoch+1))
            start_time = time.time()
            for iter_num, (image, label) in enumerate(self.trn_loader):
                # Train the autoencoder
                self.encoder.train()
                self.decoder.train()
                
                self.encoder.zero_grad()
                self.decoder.zero_grad()
                x = Variable(image)
                if self.is_cuda:
                    x = x.cuda()
                x_recon = self.decoder.forward(self.encoder.forward(x))
                recon_loss = self.MSEloss(x_recon, x)
                recon_loss.backward()
                self.optim_encoder.step()
                self.optim_decoder.step()
                
                # Train the discriminator
                self.discriminator.train()
                self.encoder.eval()
                fake_z, fake_z_label = self.encoder.forward(x), Variable(torch.zeros(self.config.batchsize, 1))
                real_z, real_z_label = Variable(torch.randn(self.config.batch_size, self.n_z)), Variable(torch.ones(self.config.batchsize, 1))
                if self.is_cuda:
                    real_z = real_z.cuda()
                

In [34]:
aae = AAE(opt)

TypeError: optimizer can only optimize Variables, but one of the params is tuple