In [12]:
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import make_grid

from IPython.display import display, Image
from IPython.core.debugger import set_trace
import os


In [13]:
DATA_PATH = '../datasets/CIFAR10/'
EPOCHES = 10      # number of passes over the training set
BATCH_SIZE = 64
NOISE_DIM = 100   # length of the generator's input noise vector
SEED = 0

# Seed torch's global RNG (this also seeds CUDA devices) so weight init, noise
# sampling and DataLoader shuffling repeat across "Restart & Run All"
# (up to nondeterministic GPU kernels).
torch.manual_seed(SEED)

In [24]:
# Let torchvision decide whether a download is needed: it verifies the file
# checksums and skips the download when the dataset is already present. A bare
# `os.path.exists(DATA_PATH)` check would treat an empty or partially
# downloaded directory as complete and then fail to load it.
# Set to False to forbid any network access.
DOWNLOAD = True

# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps each
# channel to [-1, 1], the same range as the generator's tanh output.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_data = datasets.CIFAR10(root=DATA_PATH, train=True, transform=trans, download=DOWNLOAD)
test_data = datasets.CIFAR10(root=DATA_PATH, train=False, transform=trans, download=DOWNLOAD)

# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Each image tensor has shape [3, 32, 32].

In [25]:
class Discriminator(nn.Module):
    """Real/fake classifier for 3x32x32 images.

    Two conv stages (5x5 'same' conv -> LeakyReLU(0.2) -> 2x2 average pool)
    shrink the input from [batch, 3, 32, 32] to [batch, 64, 8, 8]. A two-layer
    MLP then maps the flattened features to one sigmoid probability per image,
    so the output has shape [batch, 1].
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def conv_stage(in_ch, out_ch):
            # padding=2 keeps H and W for a 5x5 kernel; the pool halves them.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 5, padding=2),
                nn.LeakyReLU(0.2),
                nn.AvgPool2d(2, 2),
            )

        self.conv1 = conv_stage(3, 32)    # [b, 3, 32, 32]  -> [b, 32, 16, 16]
        self.conv2 = conv_stage(32, 64)   # [b, 32, 16, 16] -> [b, 64, 8, 8]

        hidden = 1024  # TODO(review): width is untuned; a narrower layer may suffice
        self.fc = nn.Sequential(
            nn.Linear(64 * 8 * 8, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

class Generator(nn.Module):
    """Maps a batch of noise vectors to 3x32x32 images with values in [-1, 1].

    The noise is projected by a linear layer to a 3x64x64 map, batch-normalised
    and rectified, refined by two shape-preserving 3x3 conv stages
    (3 -> 50 -> 25 channels), then reduced to 3x32x32 by a stride-2 2x2 conv
    followed by tanh.

    Args:
        input_size: length of each noise vector (default 100).
    """

    def __init__(self, input_size=100):
        super(Generator, self).__init__()
        # [b, input_size] -> [b, 3*64*64]; forward() reshapes this to
        # [b, 3, 64, 64], after which self.br applies BatchNorm + ReLU.
        self.fc = nn.Linear(input_size, 3 * 64 * 64)
        self.br = nn.Sequential(nn.BatchNorm2d(3), nn.ReLU())

        # Despite their names, downsample1/2 keep the 64x64 spatial size
        # (3x3 conv with padding=1); only downsample3 halves it.
        self.downsample1 = nn.Sequential(
            nn.Conv2d(3, 50, 3, padding=1),   # -> [b, 50, 64, 64]
            nn.BatchNorm2d(50),
            nn.ReLU(),
        )
        self.downsample2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, padding=1),  # -> [b, 25, 64, 64]
            nn.BatchNorm2d(25),
            nn.ReLU(),
        )
        # Output layer: tanh directly after the conv, no BatchNorm
        # (the usual DCGAN choice for the generator's last layer).
        self.downsample3 = nn.Sequential(
            nn.Conv2d(25, 3, 2, stride=2),    # -> [b, 3, 32, 32]
            nn.Tanh(),
        )

    def forward(self, x):
        out = self.fc(x).view(-1, 3, 64, 64)
        for stage in (self.br, self.downsample1, self.downsample2, self.downsample3):
            out = stage(out)
        return out

# Build both networks (discriminator first, then generator, so the weight
# initialisation consumes random numbers in the same order) and move them to
# the GPU when CUDA is available.
d_model, g_model = Discriminator(), Generator(NOISE_DIM)
if torch.cuda.is_available():
    d_model, g_model = d_model.cuda(), g_model.cuda()

In [26]:
# The discriminator ends in a Sigmoid, so plain BCELoss (not the logits variant) applies.
criterion = nn.BCELoss()
# One Adam optimiser per network, both with the same learning rate.
d_optim, g_optim = [optim.Adam(net.parameters(), lr=3e-4) for net in (d_model, g_model)]

In [27]:
def check_generator(model, num_images=64, noise_dim=None):
    """Display a grid of images sampled from a generator, for visual inspection.

    Args:
        model: generator mapping noise of shape (num_images, noise_dim) to an
            output that can be reshaped to (-1, 3, 32, 32), with values that
            are presumably in [-1, 1] (tanh output).
        num_images: number of samples to draw. The default of 64 gives an 8x8
            grid with make_grid's default of 8 images per row.
        noise_dim: length of each noise vector; defaults to the global NOISE_DIM.

    The model's previous train/eval mode is restored afterwards, even if the
    forward pass raises.
    """
    def to_img(x):
        # Undo Normalize(mean=0.5, std=0.5): map [-1, 1] back to [0, 1].
        out = 0.5 * (x + 1)
        out = out.clamp(0, 1)
        return out.view(-1, 3, 32, 32)

    if noise_dim is None:
        noise_dim = NOISE_DIM

    was_training = model.training
    model.eval()  # use BatchNorm running statistics rather than batch statistics
    try:
        # A preview needs no gradients, so skip building an autograd graph.
        with torch.no_grad():
            z = torch.randn(num_images, noise_dim)
            if torch.cuda.is_available():
                z = z.cuda()
            fake_img = model(z)
    finally:
        model.train(was_training)

    to_pil = transforms.ToPILImage()
    display(to_pil(make_grid(to_img(fake_img.cpu()))))

In [None]:
# Training: per batch, one discriminator update followed by one generator update.

for epoch in range(EPOCHES):
    print('*'*10)
    print('Epoch: {}'.format(epoch))
    d_tloss = 0.0
    g_tloss = 0.0
    num_batches = 0
    for img, _ in train_loader:
        data_num = img.size(0)  # the last batch can be smaller than BATCH_SIZE

        # --- Discriminator step: real images -> 1, generated images -> 0 ---
        for p in d_model.parameters():
            p.requires_grad = True

        x_real = img
        x_fake = torch.randn(data_num, NOISE_DIM)  # generator input noise
        # Targets are (batch, 1) to match the discriminator's (batch, 1) output;
        # BCELoss requires input and target to have the same size.
        y_real = torch.ones(data_num, 1)
        y_fake = torch.zeros(data_num, 1)
        if torch.cuda.is_available():
            x_real, x_fake = x_real.cuda(), x_fake.cuda()
            y_real, y_fake = y_real.cuda(), y_fake.cuda()

        d_out_real = d_model(x_real)
        d_loss_real = criterion(d_out_real, y_real)
        g_out_fake = g_model(x_fake)
        # detach(): the discriminator loss must not backpropagate into the generator.
        d_out_fake = d_model(g_out_fake.detach())
        d_loss_fake = criterion(d_out_fake, y_fake)
        d_loss = d_loss_real + d_loss_fake

        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()
        d_tloss += d_loss.item()

        # --- Generator step: make the updated discriminator predict 1 on fakes ---
        # Freeze D so g_loss.backward() does not compute gradients for its parameters.
        for p in d_model.parameters():
            p.requires_grad = False
        # Reuse the generated batch. Its graph only involves generator parameters,
        # so it is still valid after the D update. Re-running G on the same noise
        # would also update its BatchNorm running statistics a second time.
        d_out_fake = d_model(g_out_fake)
        g_loss = criterion(d_out_fake, y_real)
        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()
        g_tloss += g_loss.item()

        num_batches += 1

    # Report mean loss per batch so values are comparable across epoch sizes.
    num_batches = max(num_batches, 1)  # guard against an empty loader
    print('Generator_Loss: {:.4f}'.format(g_tloss / num_batches))
    print('Discriminator_Loss: {:.4f}'.format(d_tloss / num_batches))
    check_generator(g_model)