# GANs for Image Generation tasks

## 2. Conditional GANs - AC-GAN

### Prepare DataLoader for MNIST dataset

In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

# fix manual seed (torch.manual_seed also seeds the CUDA generators).
torch.manual_seed(1234)

# set batch size.
BATCH_SIZE = 256

# prepare dataloader.
# ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's Tanh output.
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
# download=True only fetches the files when they are not already present under
# root, so a fresh checkout can "Restart & Run All" without a manual download.
train_dataset = MNIST(root='./datasets', train=True, download=True, transform=tf)
# drop_last=True guarantees every batch has exactly BATCH_SIZE samples; the
# training cell allocates its real/fake label tensors with that fixed size.
loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

#### Define Generator

In [2]:
import torch.nn as nn

class Generator(nn.Module):
    """Conditional generator for AC-GAN on 28x28 MNIST.

    Concatenates a latent vector ``z`` with a one-hot class label ``y``, then
    maps it through fully connected layers to a 128x7x7 feature map. Two
    transposed convolutions upsample that map to a 1x28x28 image in [-1, 1]
    (Tanh output).
    """

    def __init__(self):
        super(Generator, self).__init__()

        self.z_dim = 64
        self.num_class = 10
        self.hidden_dim = 256  # not used by the layers below; kept as a public attribute
        self.img_dim = 28 * 28

        # (B, z_dim + num_class) -> (B, 1024) -> (B, 128*7*7)
        self.fc = nn.Sequential(
            nn.Linear(self.z_dim + self.num_class, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(),
        )

        # kernel 4 / stride 2 / padding 1 doubles the spatial size:
        # (B, 128, 7, 7) -> (B, 64, 14, 14) -> (B, 1, 28, 28)
        self.upconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, z, y):
        """Generate images conditioned on class labels.

        Args:
            z: latent noise of shape (B, z_dim).
            y: one-hot class labels of shape (B, num_class). Integer (e.g. long)
               one-hot tensors are accepted and cast to ``z``'s dtype before
               concatenation.

        Returns:
            Tensor of shape (B, 1, 28, 28) with values in [-1, 1].
        """
        # Cast y so a long one-hot cannot produce a dtype mismatch in torch.cat.
        out = torch.cat((z, y.to(z.dtype)), dim=1)
        out = self.fc(out)
        out = out.view(-1, 128, 7, 7)
        out = self.upconv(out)
        return out



#### Define Discriminator

In [3]:
class Discriminator(nn.Module):
    """AC-GAN discriminator with a shared trunk and two heads.

    A 1x28x28 image goes through two strided convolutions and a fully
    connected layer to a 1024-d feature. That feature feeds both heads:
      * ``fc_disc``: real/fake probability (Sigmoid output, used with BCELoss).
      * ``fc_cls``: class logits over ``num_class`` classes (no softmax; the
        training loop feeds them to CrossEntropyLoss).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.img_dim = 28 * 28
        self.hidden_dim = 256  # not used by the layers below; kept as a public attribute
        self.num_class = 10

        # kernel 4 / stride 2 / padding 1 halves the spatial size:
        # (B, 1, 28, 28) -> (B, 64, 14, 14) -> (B, 128, 7, 7)
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        # (B, 128*7*7) -> (B, 1024) shared feature for both heads.
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
        )
        # Real/fake head: (B, 1024) -> (B, 1) probability.
        self.fc_disc = nn.Sequential(
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )
        # Auxiliary classification head: (B, 1024) -> (B, num_class) logits.
        self.fc_cls = nn.Sequential(
            nn.Linear(1024, self.num_class),
        )

    def forward(self, x):
        """Score images as real/fake and classify them.

        Args:
            x: images of shape (B, 1, 28, 28).

        Returns:
            (out_disc, out_cls): ``out_disc`` has shape (B, 1) with values in
            (0, 1); ``out_cls`` has shape (B, num_class) with raw logits.
        """
        out = self.conv(x)
        # flatten(1) keeps the batch dimension intact. view(-1, 128*7*7) would
        # silently fold a wrong-sized input into extra "samples"; flatten(1)
        # makes the Linear layer reject it instead.
        out = out.flatten(1)
        out = self.fc(out)
        out_disc = self.fc_disc(out)
        out_cls = self.fc_cls(out)

        return out_disc, out_cls


#### Prepare GAN model and Optimizers

In [4]:
# weight initialization function.
def weights_init(net):
    """Re-initialize every conv, transposed-conv and linear layer inside ``net``.

    ``net.modules()`` is traversed recursively, so this handles a whole model in
    one call. Matching layers get weights drawn from a normal distribution with
    mean 0 and std 0.02, and biases (when present) set to zero. All other modules
    (e.g. BatchNorm) are left untouched.
    """
    target_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    for layer in net.modules():
        if not isinstance(layer, target_layers):
            continue
        nn.init.normal_(layer.weight, 0, 0.02)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

# define GAN model.
G = Generator().cuda()
D = Discriminator().cuda()

# Weight initialization. weights_init already walks every sub-module through
# net.modules(), so it is called once per network. Wrapping it in net.apply()
# would invoke it again on every sub-module and redraw each layer's weights
# once per ancestor.
weights_init(G)
weights_init(D)

# define optimizer. Here we use Adam with lr=2e-4 and betas=(0.5, 0.999),
# the settings commonly used for DCGAN-style training.
optimizer_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))


#### Start training GAN

In [5]:
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
import torch.nn.functional as F

# Hyper-parameters. 
# ====== You don't need to change here ===== #
EPOCHS = 50
Z_DIM = 64
NUM_CLASS = 10
# ========================================== #

# logger for tensorboard (with no log_dir it writes a new run directory under ./runs).
logger = SummaryWriter()

# Fixed latent variable z and float one-hot label y for visualization: 100
# samples whose labels cycle through 0..NUM_CLASS-1. With NUM_CLASS = 10 and
# nrow=10 below, each column of the grid shows a single class.
FIXED_Z = torch.randn(size=(100, Z_DIM)).cuda()
FIXED_Y = F.one_hot(torch.arange(100) % NUM_CLASS, num_classes=NUM_CLASS).float().cuda()

# GT labels for the binary cross entropy loss. Their fixed BATCH_SIZE length
# relies on the DataLoader dropping the last incomplete batch (drop_last=True).
real_label = torch.ones(size=(BATCH_SIZE, 1)).cuda()
fake_label = torch.zeros(size=(BATCH_SIZE, 1)).cuda()

# BCE for D's real/fake head (it outputs sigmoid probabilities);
# cross-entropy for D's auxiliary classification head (it outputs raw logits).
BCE_criterion = torch.nn.BCELoss()
CE_criterion = torch.nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    # Set both models to train modes.
    G.train()
    D.train()

    # Running sums for the per-epoch averages logged to tensorboard.
    loss_G_total, loss_D_total = 0., 0.

    for batch_idx, (data, label) in enumerate(loader):
        real_img = data.cuda()
        label = label.cuda()

        # Condition G on this batch's real labels (so fake and real batches share
        # the same class distribution). y is a float (BATCH_SIZE, NUM_CLASS)
        # one-hot, the same dtype as z and FIXED_Y.
        y = F.one_hot(label, num_classes=NUM_CLASS).float()

        # ================= Update D ================== #
        # D must separate real from fake AND classify both. fake_img is detached
        # so this step does not backpropagate into G.
        z = torch.randn(size=(BATCH_SIZE, Z_DIM)).cuda()
        fake_img = G(z, y).detach()

        prob_real, logit_cls_real = D(real_img)
        prob_fake, logit_cls_fake = D(fake_img)
        disc_real = BCE_criterion(prob_real, real_label)
        disc_fake = BCE_criterion(prob_fake, fake_label)
        cls_real = CE_criterion(logit_cls_real, label)
        cls_fake = CE_criterion(logit_cls_fake, label)
        loss_D = disc_real + disc_fake + cls_real + cls_fake

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ================= Update G ================== #
        # G wants D to call its fresh samples real AND to classify them as the
        # requested class (AC-GAN's auxiliary classification loss).
        z = torch.randn(size=(BATCH_SIZE, Z_DIM)).cuda()
        fake_img = G(z, y)
        prob_fake, logit_cls_fake = D(fake_img)
        loss_G = BCE_criterion(prob_fake, real_label) + CE_criterion(logit_cls_fake, label)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        loss_D_total += loss_D.item()
        loss_G_total += loss_G.item()

        # print current states
        if batch_idx % 100 == 0:
            print('Epoch : {} || {}/{} || loss_G={:.3f} loss_D={:.3f}'.format(
                epoch, batch_idx, len(loader), loss_G.item(), loss_D.item()
            ))

    loss_G_total /= len(loader)
    loss_D_total /= len(loader)

    # ================= Generate example samples ================== #
    # eval() makes BatchNorm use its running statistics, so the fixed
    # visualization batch neither uses nor updates batch statistics.
    # no_grad() skips building an autograd graph. G.train() at the top of
    # the next epoch restores training mode.
    G.eval()
    with torch.no_grad():
        fake_img = G(FIXED_Z, FIXED_Y)
    fake_img = fake_img.view(fake_img.shape[0], 1, 28, 28)
    fake_img = (fake_img + 1) * 0.5  # Tanh range [-1, 1] -> [0, 1] for display
    fake_img = make_grid(fake_img, nrow=10)

    # Log per-epoch average losses and the sample grid.
    logger.add_scalar('loss_G', loss_G_total, epoch)
    logger.add_scalar('loss_D', loss_D_total, epoch)
    logger.add_image('Generated_samples', fake_img, epoch)

    # print current states
    print('Epoch : {} has done. AVG loss : loss_G={:.3f} loss_D={:.3f}'.format(
        epoch, loss_G_total, loss_D_total
    ))

# Flush and close the event file so TensorBoard sees every logged epoch.
logger.close()


Epoch : 0 || 0/234 || loss_G=3.057 loss_D=6.106
Epoch : 0 || 100/234 || loss_G=1.001 loss_D=1.260
Epoch : 0 || 200/234 || loss_G=1.253 loss_D=0.925
Epoch : 0 has done. AVG loss : loss_G=1.163 loss_D=1.404
Epoch : 1 || 0/234 || loss_G=1.269 loss_D=0.868
Epoch : 1 || 100/234 || loss_G=1.349 loss_D=0.889
Epoch : 1 || 200/234 || loss_G=1.134 loss_D=1.036
Epoch : 1 has done. AVG loss : loss_G=1.306 loss_D=0.964
Epoch : 2 || 0/234 || loss_G=1.053 loss_D=1.127
Epoch : 2 || 100/234 || loss_G=1.108 loss_D=1.093
Epoch : 2 || 200/234 || loss_G=0.989 loss_D=1.194
Epoch : 2 has done. AVG loss : loss_G=1.056 loss_D=1.150
Epoch : 3 || 0/234 || loss_G=1.090 loss_D=1.125
Epoch : 3 || 100/234 || loss_G=1.012 loss_D=1.252
Epoch : 3 || 200/234 || loss_G=0.937 loss_D=1.171
Epoch : 3 has done. AVG loss : loss_G=0.984 loss_D=1.184
Epoch : 4 || 0/234 || loss_G=1.013 loss_D=1.149
Epoch : 4 || 100/234 || loss_G=0.956 loss_D=1.185
Epoch : 4 || 200/234 || loss_G=0.916 loss_D=1.241
Epoch : 4 has done. AVG loss : l

In [7]:
# Check Tensorboard.
# List the run directories under ./runs (the listing below shows one directory per training run).
%ls runs
# Load the TensorBoard notebook extension and embed it, reading every run under ./runs.
# --samples_per_plugin images=100 raises how many logged images per tag TensorBoard keeps,
# so the per-epoch sample grids are not downsampled away.
# NOTE(review): 8888 is also Jupyter's usual default server port; pick another port if it collides.
%load_ext tensorboard
%tensorboard --logdir runs --port 8888 --samples_per_plugin images=100

 C 드라이브의 볼륨에는 이름이 없습니다.
 볼륨 일련 번호: 32F1-D8BD

 c:\Users\AI_15\momo\python_prj\CV_Generative_models\Lab2 GAN\runs 디렉터리

2022-07-12  오후 04:31    <DIR>          .
2022-07-12  오후 04:31    <DIR>          ..
2022-07-12  오후 02:59    <DIR>          Jul12_14-59-27_LAPTOP-ASLN0PJR
2022-07-12  오후 03:39    <DIR>          Jul12_15-39-13_LAPTOP-ASLN0PJR
2022-07-12  오후 03:39    <DIR>          Jul12_15-39-55_LAPTOP-ASLN0PJR
2022-07-12  오후 03:40    <DIR>          Jul12_15-40-20_LAPTOP-ASLN0PJR
2022-07-12  오후 03:43    <DIR>          Jul12_15-43-43_LAPTOP-ASLN0PJR
2022-07-12  오후 03:44    <DIR>          Jul12_15-44-27_LAPTOP-ASLN0PJR
2022-07-12  오후 03:44    <DIR>          Jul12_15-44-55_LAPTOP-ASLN0PJR
2022-07-12  오후 03:46    <DIR>          Jul12_15-46-30_LAPTOP-ASLN0PJR
2022-07-12  오후 03:53    <DIR>          Jul12_15-53-54_LAPTOP-ASLN0PJR
2022-07-12  오후 03:55    <DIR>          Jul12_15-55-48_LAPTOP-ASLN0PJR
2022-07-12  오후 03:58    <DIR>          Jul12_15-58-49_L

Reusing TensorBoard on port 8888 (pid 23112), started 0:06:50 ago. (Use '!kill 23112' to kill it.)