In [2]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [3]:
import torch
import torch.nn as nn
import torch.utils as utils
import torch.nn.init as init

import torchvision.utils as v_utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict

In [4]:
epoch = 50
batch_size = 512
learning_rate = 0.0002
num_gpus = 1

In [5]:
mnist_train = dset.MNIST("drive/MyDrive", train=True,
                         transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          # transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
                                          transforms.Normalize((0.5), (0.5))
                         ]),
                         target_transform=None,
                         download=False)

train_loader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=batch_size,
                                           shuffle=True, drop_last=True)

In [6]:
class Generator(nn.Module):
  def __init__(self):
    super(Generator,self).__init__()
    self.layer1 = nn.Sequential(
        nn.Linear(100, 7*7*256),
        nn.BatchNorm1d(7*7*256),
        nn.ReLU()
    )
    self.layer2 = nn.Sequential(OrderedDict([
                          ('conv1', nn.ConvTranspose2d(256,128,3,2,1,1)),
                          ('bn1', nn.BatchNorm2d(128)),
                          ('relu1', nn.LeakyReLU()),
                          ('conv2', nn.ConvTranspose2d(128,64,3,1,1)),
                          ('bn2', nn.BatchNorm2d(64)),
                          ('relu2', nn.LeakyReLU())
    ]))
    self.layer3 = nn.Sequential(OrderedDict([
                          ('conv3', nn.ConvTranspose2d(64,16,3,1,1)),
                          ('bn3', nn.BatchNorm2d(16)),
                          ('relu3', nn.LeakyReLU()),
                          ('conv4', nn.ConvTranspose2d(16,1,3,2,1,1)),
                          ('relu4', nn.Tanh())
    ]))

  def forward(self,z):
    out = self.layer1(z)
    out = out.view(batch_size//num_gpus,256,7,7)
    out = self.layer2(out)
    out = self.layer3(out)
    return out

In [20]:
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator,self).__init__()
    self.layer1 = nn.Sequential(OrderedDict([
                          ('conv1', nn.Conv2d(1,8,3,padding=1)),
                          ('bn1', nn.BatchNorm2d(8)),
                          ('relu1', nn.LeakyReLU()),
                          ('conv2', nn.Conv2d(8,16,3,padding=1)),
                          ('bn2', nn.BatchNorm2d(16)),
                          ('relu2', nn.LeakyReLU()),
                          ('max1', nn.MaxPool2d(2,2))
    ]))
    self.layer2 = nn.Sequential(OrderedDict([
                          ('conv3', nn.Conv2d(16,32,3,padding=1)),
                          ('bn3', nn.BatchNorm2d(32)),
                          ('relu3', nn.LeakyReLU()),
                          ('max2', nn.MaxPool2d(2,2)),
                          ('conv4', nn.Conv2d(32,64,3,padding=1)),
                          ('bn4', nn.BatchNorm2d(64)),
                          ('relu4', nn.LeakyReLU())
    ]))
    self.fc = nn.Sequential(
        nn.Linear(64*7*7, 1),
        nn.Sigmoid()
    )

  def forward(self,x):
    out = self.layer1(x)
    print(out.size())
    out = self.layer2(out)
    print(out.size())
    out = out.view(batch_size//num_gpus, -1)
    # out = out.view(1, -1)
    print(out.size())
    out = self.fc(out)
    return out

In [21]:
!pip install torchinfo



In [22]:
from torchinfo import summary

In [23]:
dis = Discriminator().cuda()

In [24]:
summary(dis, (batch_size,1,28,28))

torch.Size([512, 16, 14, 14])
torch.Size([512, 64, 7, 7])
torch.Size([512, 3136])


Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [512, 16, 14, 14]         --
|    └─Conv2d: 2-1                       [512, 8, 28, 28]          80
|    └─BatchNorm2d: 2-2                  [512, 8, 28, 28]          16
|    └─LeakyReLU: 2-3                    [512, 8, 28, 28]          --
|    └─Conv2d: 2-4                       [512, 16, 28, 28]         1,168
|    └─BatchNorm2d: 2-5                  [512, 16, 28, 28]         32
|    └─LeakyReLU: 2-6                    [512, 16, 28, 28]         --
|    └─MaxPool2d: 2-7                    [512, 16, 14, 14]         --
├─Sequential: 1-2                        [512, 64, 7, 7]           --
|    └─Conv2d: 2-8                       [512, 32, 14, 14]         4,640
|    └─BatchNorm2d: 2-9                  [512, 32, 14, 14]         64
|    └─LeakyReLU: 2-10                   [512, 32, 14, 14]         --
|    └─MaxPool2d: 2-11                   [512, 32, 7, 7]           --
|    └─Co

In [29]:
gen = Generator().cuda()

In [32]:
summary(gen, (batch_size,100))

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [512, 12544]              --
|    └─Linear: 2-1                       [512, 12544]              1,266,944
|    └─BatchNorm1d: 2-2                  [512, 12544]              25,088
|    └─ReLU: 2-3                         [512, 12544]              --
├─Sequential: 1-2                        [512, 64, 14, 14]         --
|    └─ConvTranspose2d: 2-4              [512, 128, 14, 14]        295,040
|    └─BatchNorm2d: 2-5                  [512, 128, 14, 14]        256
|    └─LeakyReLU: 2-6                    [512, 128, 14, 14]        --
|    └─ConvTranspose2d: 2-7              [512, 64, 14, 14]         73,792
|    └─BatchNorm2d: 2-8                  [512, 64, 14, 14]         128
|    └─LeakyReLU: 2-9                    [512, 64, 14, 14]         --
├─Sequential: 1-3                        [512, 1, 28, 28]          --
|    └─ConvTranspose2d: 2-10             [512, 16, 14, 14]     

In [25]:
generator = nn.DataParallel(Generator()).cuda()
discriminator = nn.DataParallel(Discriminator()).cuda()

In [26]:
gen_params = generator.state_dict().keys()
dis_params = discriminator.state_dict().keys()

for i in gen_params:
  print(i)

module.layer1.0.weight
module.layer1.0.bias
module.layer1.1.weight
module.layer1.1.bias
module.layer1.1.running_mean
module.layer1.1.running_var
module.layer1.1.num_batches_tracked
module.layer2.conv1.weight
module.layer2.conv1.bias
module.layer2.bn1.weight
module.layer2.bn1.bias
module.layer2.bn1.running_mean
module.layer2.bn1.running_var
module.layer2.bn1.num_batches_tracked
module.layer2.conv2.weight
module.layer2.conv2.bias
module.layer2.bn2.weight
module.layer2.bn2.bias
module.layer2.bn2.running_mean
module.layer2.bn2.running_var
module.layer2.bn2.num_batches_tracked
module.layer3.conv3.weight
module.layer3.conv3.bias
module.layer3.bn3.weight
module.layer3.bn3.bias
module.layer3.bn3.running_mean
module.layer3.bn3.running_var
module.layer3.bn3.num_batches_tracked
module.layer3.conv4.weight
module.layer3.conv4.bias


In [27]:
loss_func = nn.MSELoss()
gen_optim = torch.optim.Adam(generator.parameters(), lr=5*learning_rate, betas=(0.5, 0.999))
dis_optim = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

ones_label = torch.ones(batch_size,1).cuda()
zeros_label = torch.zeros(batch_size,1).cuda()

def image_check(gen_fake):
  img = gen_fake.data.numpy()
  for i in range(10):
    plt.imshow(img[i][0], cmap='gray')
    plt.show()

In [None]:
for i in range(epoch):
  for j,(image,label) in enumerate(train_loader):
    image = image.cuda()




    gen_optim.zero_grad()

    z = init.normal(torch.Tensor(batch_size,100), mean=0, std=0.1).cuda()

    gen_fake = generator.forward(z)
    dis_fake = discriminator.forward(gen_fake)

    gen_loss = torch.sum(loss_func(dis_fake, ones_label))

    # gen_loss.backward(retain_variables=True)
    gen_loss.backward()
    gen_optim.step()




    dis_optim.zero_grad()

    z = init.normal(torch.Tensor(batch_size,100), mean=0, std=0.1).cuda()

    gen_fake = generator.forward(z)
    dis_fake = discriminator.forward(gen_fake)
    dis_real = discriminator.forward(image)
    dis_loss = torch.sum(loss_func(dis_fake, zeros_label)) + torch.sum(loss_func(dis_real,ones_label))
    dis_loss.backward()
    dis_optim.step()





    if j % 50 ==0:
      print(gen_loss, dis_loss)
      torch.save([generator,discriminator], './model/dcgan.pkl')

      print("{}th iteration gen_loss: {} dis_loss: {}".format(i, gen_loss.data, dis_loss.data))
      v_utils.save_image(gen_fake.data[0:25], "./result/gen_{}_{}.png".format(i,j), nrow=5)

  image_check(gen_fake.cpu())