In [None]:
import os, time, sys
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [None]:
class Generator(nn.Module):
    """DCGAN generator: turns a noise batch into 3-channel images squashed into (-1, 1) by tanh.

    With the (N, 100, 1, 1) noise used in this notebook, the first transposed conv
    produces a 4x4 map, and each following one doubles the spatial size
    (4 -> 8 -> 16 -> 32 -> 64), while channels shrink d*8 -> d*4 -> d*2 -> d -> 3.

    Args:
        d: base channel width; the widest hidden layer has d*8 channels.
    """

    def __init__(self, d=128):
        super(Generator, self).__init__()
        # Registration order matters: weight_init visits children in this order, and the
        # attribute names become the state_dict keys saved to disk.
        self.deconv1 = nn.ConvTranspose2d(100, d * 8, 4, 1, 0)      # 1x1   -> 4x4
        self.deconv1_bn = nn.BatchNorm2d(d * 8)
        self.deconv2 = nn.ConvTranspose2d(d * 8, d * 4, 4, 2, 1)    # 4x4   -> 8x8
        self.deconv2_bn = nn.BatchNorm2d(d * 4)
        self.deconv3 = nn.ConvTranspose2d(d * 4, d * 2, 4, 2, 1)    # 8x8   -> 16x16
        self.deconv3_bn = nn.BatchNorm2d(d * 2)
        self.deconv4 = nn.ConvTranspose2d(d * 2, d, 4, 2, 1)        # 16x16 -> 32x32
        self.deconv4_bn = nn.BatchNorm2d(d)
        self.deconv5 = nn.ConvTranspose2d(d, 3, 4, 2, 1)            # 32x32 -> 64x64

    def weight_init(self, mean, std):
        """Call normal_init(child, mean, std) on every direct child module, in registration order."""
        for child in self._modules.values():
            normal_init(child, mean, std)

    def forward(self, input):
        """Four (transposed conv -> BatchNorm -> ReLU) stages, then a transposed conv with tanh."""
        h = input
        stages = (
            (self.deconv1, self.deconv1_bn),
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
            (self.deconv4, self.deconv4_bn),
        )
        for deconv, bn in stages:
            h = F.relu(bn(deconv(h)))
        return torch.tanh(self.deconv5(h))

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: maps 3-channel images to one sigmoid probability per image.

    For the 64x64 inputs used in this notebook, each stride-2 conv halves the spatial
    size (64 -> 32 -> 16 -> 8 -> 4), and the final 4x4 conv with no padding collapses the
    map to shape (N, 1, 1, 1).

    Args:
        d: base channel width; channels grow d -> d*2 -> d*4 -> d*8 -> 1.
    """

    def __init__(self, d=128):
        super(Discriminator, self).__init__()
        # Same registration order/names as the saved state_dict and weight_init traversal.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=d, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=d, out_channels=d * 2, kernel_size=4, stride=2, padding=1)
        self.conv2_bn = nn.BatchNorm2d(d * 2)
        self.conv3 = nn.Conv2d(in_channels=d * 2, out_channels=d * 4, kernel_size=4, stride=2, padding=1)
        self.conv3_bn = nn.BatchNorm2d(d * 4)
        self.conv4 = nn.Conv2d(in_channels=d * 4, out_channels=d * 8, kernel_size=4, stride=2, padding=1)
        self.conv4_bn = nn.BatchNorm2d(d * 8)
        self.conv5 = nn.Conv2d(in_channels=d * 8, out_channels=1, kernel_size=4, stride=2, padding=0)

    def weight_init(self, mean, std):
        """Call normal_init(child, mean, std) on every direct child module, in registration order."""
        for child in self._modules.values():
            normal_init(child, mean, std)

    def forward(self, input):
        """LeakyReLU(0.2) conv stack (BatchNorm on all but the first), then a sigmoid output conv."""
        slope = 0.2
        h = F.leaky_relu(self.conv1(input), slope)  # first layer has no BatchNorm
        for conv, bn in ((self.conv2, self.conv2_bn),
                         (self.conv3, self.conv3_bn),
                         (self.conv4, self.conv4_bn)):
            h = F.leaky_relu(bn(conv(h)), slope)
        return torch.sigmoid(self.conv5(h))

In [None]:
def normal_init(m, mean, std):
    """Re-initialise a Conv2d / ConvTranspose2d in place: weights ~ N(mean, std), bias = 0.

    Any other module type (e.g. BatchNorm2d) is left untouched. Convolutions created
    with ``bias=False`` are supported: only their weights are re-drawn.

    Args:
        m: module to initialise.
        mean: mean of the normal distribution used for the weights.
        std: standard deviation of the normal distribution used for the weights.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        # no_grad: these in-place writes must not be tracked by autograd.
        with torch.no_grad():
            m.weight.normal_(mean, std)
            if m.bias is not None:  # bias=False convs have no bias parameter
                m.bias.zero_()

In [None]:
# One fixed batch of 5*5 = 25 latent vectors, fed through G all at once. Because it never
# changes, the "Fixed_results" images show how the same inputs evolve across epochs.
fixed_z_ = torch.randn((5*5, 100)).view(-1, 100, 1, 1)
fixed_z_ = fixed_z_.cuda()


def show_result(num_epoch, show=False, save=False, path='result.png', isFix=False):
    """Plot a 5x5 grid of images produced by the global generator ``G``.

    Args:
        num_epoch: epoch number written in the figure caption.
        show: if True, display the figure with plt.show(); otherwise close it.
        save: if True, write the figure to ``path``; nothing is written otherwise.
        path: output file used when ``save`` is True.
        isFix: if True, use the module-level ``fixed_z_`` noise; otherwise draw fresh noise.
    """
    grid = 5
    device = next(G.parameters()).device
    if isFix:
        z_ = fixed_z_.to(device)
    else:
        z_ = torch.randn((grid * grid, 100)).view(-1, 100, 1, 1).to(device)

    # eval(): BatchNorm uses its running statistics rather than this batch's.
    # no_grad(): sampling needs no autograd graph. finally: G is always put back in train mode.
    G.eval()
    try:
        with torch.no_grad():
            test_images = G(z_)
    finally:
        G.train()

    fig, ax = plt.subplots(grid, grid, figsize=(5, 5))  # figure + (grid x grid) array of axes
    for k in range(grid * grid):
        i, j = divmod(k, grid)
        # CHW tensor in (-1, 1) (tanh output) -> HWC array in (0, 1) for imshow.
        img = (test_images[k].cpu().numpy().transpose(1, 2, 0) + 1) / 2
        ax[i, j].imshow(img)
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    fig.text(0.5, 0.04, 'Epoch {0}'.format(num_epoch), ha='center')
    if save:
        fig.savefig(path)

    if show:
        plt.show()
    else:
        plt.close(fig)

### .eval() .train()
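#### `model.train()` and `model.eval()` make dropout and BatchNorm layers behave differently. In train mode, dropout is active and BatchNorm normalizes with the current batch's statistics (updating its running averages). In eval mode, dropout is disabled and BatchNorm uses the stored running averages.
#### 모델을 훈련할 땐 model.train(), 테스트할 땐 model.eval() 함수를 호출하면 Module 클래스의 훈련 상태(training 플래그)가 바뀌며, 이는 dropout과 BatchNorm 레이어의 동작에 영향을 줍니다.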

In [None]:
def show_train_hist(hist, show=False, save=False, path='Train_hist.png'):
    """Plot discriminator and generator loss curves on a fresh figure.

    Args:
        hist: dict with equal-length 'D_losses' and 'G_losses' sequences; the training
            loop in this notebook appends one mean loss per epoch to each.
        show: if True, display the figure with plt.show(); otherwise close it.
        save: if True, write the figure to ``path``.
        path: output file used when ``save`` is True.
    """
    # An explicit figure keeps these curves from being drawn onto whatever figure
    # happens to be "current" in the pyplot state machine.
    fig, ax = plt.subplots()
    x = range(len(hist['D_losses']))

    ax.plot(x, hist['D_losses'], label='D_loss')
    ax.plot(x, hist['G_losses'], label='G_loss')

    # One point per epoch (see the training loop), so label the axis accordingly.
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')

    ax.legend(loc=4)  # loc=4: lower right
    ax.grid(True)
    fig.tight_layout()

    if save:
        fig.savefig(path)
    if show:
        plt.show()
    else:
        plt.close(fig)

In [None]:
# --- Hyperparameters ---
batch_size = 256
lr = 5e-5            # Adam learning rate shared by G and D
train_epoch = 10000  # each epoch also writes two sample PNGs (see the training loop)

# --- Data: resize real images to 64x64 to match the generator's output size, and map
# ToTensor's [0, 1] range to [-1, 1] ((x - 0.5) / 0.5) to match the generator's tanh range ---
img_size = 64
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,  # number of worker subprocesses used for loading batches
)

In [None]:
# Build both networks first, then re-initialise their conv weights with N(0, 0.02),
# then move them to the GPU (this order fixes the RNG draws each step consumes).
G, D = Generator(128), Discriminator(128)
G.weight_init(mean=0.0, std=0.02)
D.weight_init(mean=0.0, std=0.02)
G.cuda()
D.cuda()

# D ends in a sigmoid, so it outputs probabilities and plain binary cross-entropy applies.
BCE_loss = nn.BCELoss()

# One Adam optimizer per network, same learning rate, beta1 = 0.5.
G_optimizer, D_optimizer = (
    optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999)) for net in (G, D)
)

In [None]:
# Output folders for per-epoch samples. makedirs creates the parent
# 'CIFAR_DCGAN_results' as needed, and exist_ok makes re-running this cell a no-op.
os.makedirs('CIFAR_DCGAN_results/Random_results', exist_ok=True)
os.makedirs('CIFAR_DCGAN_results/Fixed_results', exist_ok=True)

In [None]:
# Training history, filled by the training loop and pickled at the end.
train_hist = {
    'D_losses': [],          # mean discriminator loss per epoch
    'G_losses': [],          # mean generator loss per epoch
    'per_epoch_ptimes': [],  # wall-clock seconds per epoch
    'total_ptime': [],       # total training wall-clock seconds
}
num_iter = 0  # count of mini-batches processed

### training start

In [None]:
start_time = time.time()
for epoch in range(train_epoch):
    D_losses = []
    G_losses = []
    epoch_start_time = time.time()
    for x_, _ in trainloader:
        mini_batch = x_.size(0)
        x_ = x_.cuda()
        y_real_ = torch.ones(mini_batch, device=x_.device)
        y_fake_ = torch.zeros(mini_batch, device=x_.device)

        # ---- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ----
        D.zero_grad()
        # view(-1) rather than squeeze(): squeeze() turns a batch of one into a 0-d
        # tensor, which no longer matches the shape-(1,) target BCELoss expects.
        D_real_loss = BCE_loss(D(x_).view(-1), y_real_)

        z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1).cuda()
        # detach(): this step only updates D, so skip back-propagating into G.
        G_result = G(z_).detach()
        D_fake_loss = BCE_loss(D(G_result).view(-1), y_fake_)

        D_train_loss = D_real_loss + D_fake_loss
        D_train_loss.backward()
        D_optimizer.step()
        D_losses.append(D_train_loss.item())  # store a Python float, not a GPU tensor

        # ---- Generator step: train G so that D labels fresh fakes as real ----
        G.zero_grad()
        z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1).cuda()
        G_result = G(z_)
        G_train_loss = BCE_loss(D(G_result).view(-1), y_real_)
        G_train_loss.backward()
        G_optimizer.step()
        G_losses.append(G_train_loss.item())

        num_iter += 1
    epoch_end_time = time.time()
    per_epoch_ptime = epoch_end_time - epoch_start_time

    # Mean losses over the epoch (NaN if the loader yielded no batches).
    D_loss_mean = torch.tensor(D_losses, dtype=torch.float32).mean().item()
    G_loss_mean = torch.tensor(G_losses, dtype=torch.float32).mean().item()

    print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, D_loss_mean, G_loss_mean))
    p = 'CIFAR_DCGAN_results/Random_results/CIFAR_DCGAN_' + str(epoch + 1) + '.png'
    fixed_p = 'CIFAR_DCGAN_results/Fixed_results/CIFAR_DCGAN_' + str(epoch + 1) + '.png'
    show_result((epoch + 1), save=True, path=p, isFix=False)
    show_result((epoch + 1), save=True, path=fixed_p, isFix=True)
    train_hist['D_losses'].append(D_loss_mean)
    train_hist['G_losses'].append(G_loss_mean)
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)

In [None]:
end_time = time.time()
total_ptime = end_time - start_time
train_hist['total_ptime'].append(total_ptime)

# Epochs that actually finished; this can be fewer than train_epoch if the training cell
# was interrupted. Each finished epoch saved its Fixed_results frame before being recorded.
completed_epochs = len(train_hist['per_epoch_ptimes'])

print("Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), completed_epochs, total_ptime))
print("Training finish!... save training results")
torch.save(G.state_dict(), "CIFAR_DCGAN_results/generator_param.pkl")
torch.save(D.state_dict(), "CIFAR_DCGAN_results/discriminator_param.pkl")
with open('CIFAR_DCGAN_results/train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)  # serialise the history dict to disk

show_train_hist(train_hist, save=True, path='CIFAR_DCGAN_results/CIFAR_DCGAN_train_hist.png')

# Animate the fixed-noise samples, one frame per completed epoch.
# NOTE(review): every frame is held in memory before writing; with many epochs this is large.
images = [
    imageio.imread('CIFAR_DCGAN_results/Fixed_results/CIFAR_DCGAN_' + str(e + 1) + '.png')
    for e in range(completed_epochs)
]
if images:
    imageio.mimsave('CIFAR_DCGAN_results/generation_animation.gif', images, fps=5)