#code referenced from https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#where-to-go-next
DC-GAN tutorial from the official pytorch tutorials page

In [None]:
from dataclasses import dataclass
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from tqdm import tqdm

In [2]:
@dataclass
class Config:
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    model_name = "DC_GAN"
    batch_size = 64
    epochs = 400
    learn_rate = 1e-4
    image_size = 80
    num_samples = 4  # Number of samples to generate
    load_from_checkpoint = True  # Start training from checkpoint if one exists for the given model name

In [3]:
def create_checkpoint_dir(output_folder, epoch):
    checkpoint_dir = os.path.join(output_folder, f'checkpoint_{epoch}')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    return checkpoint_dir



def save_checkpoint(generator, discriminator, g_optimizer, d_optimizer, epoch, filename):
    """
    Save GAN model checkpoints.

    Parameters:
    - generator (Module): The generator model.
    - discriminator (Module): The discriminator model.
    - g_optimizer (Optimizer): Generator's optimizer.
    - d_optimizer (Optimizer): Discriminator's optimizer.
    - epoch (int): Current epoch number.
    - filename (str): File path and name to save checkpoint without extension.
    """
    # Create checkpoint dictionary
    checkpoint = {
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'g_optimizer_state_dict': g_optimizer.state_dict(),
        'd_optimizer_state_dict': d_optimizer.state_dict(),
        'epoch': epoch
    }
    # Save checkpoint
    torch.save(checkpoint, f"{filename}.pt")


def load_checkpoint(checkpoint_path, generator, discriminator, g_optimizer, d_optimizer):
    """
    Load GAN model checkpoints.

    Parameters:
    - checkpoint_path (str): Path to the checkpoint file.
    - generator (Module): The generator model to be updated.
    - discriminator (Module): The discriminator model to be updated.
    - g_optimizer (Optimizer): Generator's optimizer to be updated.
    - d_optimizer (Optimizer): Discriminator's optimizer to be updated.

    Returns:
    - epoch (int): The epoch number to resume training from.
    """
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path)
    # Update models and optimizers
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
    d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
    # Return the epoch number to resume training
    return checkpoint['epoch']


def get_dataloader(data_dir, batch_size, image_size):
    # Define image transformations
    # TODO: Review if this can be improved.
    transform = transforms.Compose([
        transforms.Resize(round(image_size * 5/4)),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize from [0, 1] to [-1, 1]
    ])

    # Create an instance of the ImageFolder dataset
    dataset = ImageFolder(root=data_dir, transform=transform)

    # Create a DataLoader to batch and shuffle the data
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return dataloader, dataset

In [4]:
from torch.nn import Module, Sequential, ConvTranspose2d, BatchNorm2d, ReLU, Tanh, Conv2d, Sigmoid, LeakyReLU, BCELoss

class Generator(Module):
    def __init__(self):
        super().__init__()
 
        self.gen = Sequential(
            ConvTranspose2d(in_channels=100, out_channels=512, kernel_size=4, stride=1, padding=0, bias=False),
            BatchNorm2d(num_features=512),
            ReLU(inplace=True),

            ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
            BatchNorm2d(num_features=256),
            ReLU(inplace=True),

            ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
            BatchNorm2d(num_features=128),
            ReLU(inplace=True),

            ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
            BatchNorm2d(num_features=64),
            ReLU(inplace=True),

            ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
            Tanh()
        )
 
    def forward(self, input):
        return self.gen(input)
    

# Define the Discriminator network
class Discriminator(Module):
    def __init__(self):
        super().__init__()
        self.dis = Sequential(
            # input is (3, 64, 64)
            Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
            # output from above layer is b_size, 64, 32, 32
            LeakyReLU(0.2, inplace=True),
 
            Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
            # output from above layer is b_size, 128, 16, 16
            BatchNorm2d(128),
            LeakyReLU(0.2, inplace=True),
 
            Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
            # output from above layer is b_size, 256, 8, 8
            BatchNorm2d(256),
            LeakyReLU(0.2, inplace=True),
 
            Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
            # output from above layer is b_size, 512, 4, 4
            BatchNorm2d(512),
            LeakyReLU(0.2, inplace=True),

            Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            # output from above layer is b_size, 1, 1, 1
            Sigmoid()
        )
     
    def forward(self, input):
        return self.dis(input)




In [5]:
# creating gen and disc
netG = Generator().to(Config.device)
netD = Discriminator().to(Config.device)

In [6]:


def init_weights(m):
    if type(m) == ConvTranspose2d:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif type(m) == BatchNorm2d:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

In [7]:
netD.apply(init_weights)
netG.apply(init_weights)

Generator(
  (gen): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

In [66]:
opt_D = optim.Adam(netD.parameters(), lr = Config.learn_rate, betas= (0.5, 0.999))
opt_G = optim.Adam(netG.parameters(), lr = Config.learn_rate, betas= (0.5, 0.999))
loss = BCELoss()

In [67]:
dataloader, dataset = get_dataloader('/Users/dheerajanikar/Desktop/EE_641_project/images/S2TLD_extracted/unconditional', Config.batch_size, 64)

In [None]:
nz = 100
fixed_noise = torch.randn(Config.batch_size, nz, 1, 1, device=Config.device)
real_label = 1
fake_label = 0

out_folder = '/Users/dheerajanikar/Desktop/EE_641_project/output_iter2'


import torchvision.utils as vutils


for epoch in range(Config.epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(Config.device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label,
                           dtype=real_cpu.dtype, device=Config.device)
                           
        print(len(output))

        output = netD(real_cpu)
        output = output.view(-1)
    
        errD_real = loss(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=Config.device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())

        output = output.view(-1)
        errD_fake = loss(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        opt_D.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        output = output.view(-1)
        errG = loss(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        opt_G.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, Config.epochs, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    if epoch%10 == 0:
        checkpoint_dir = create_checkpoint_dir('/Users/dheerajanikar/Desktop/EE_641_project/DCGAN/', epoch)

        for image_num in range(10):

            noise = torch.randn(1, nz, 1, 1, device=Config.device)
                
            fake_image = netG(noise)

            
            image_filename = os.path.join(checkpoint_dir, f'image_epoch_{epoch}_num_{image_num}.png')
            vutils.save_image(fake_image, image_filename, normalize=True)
        vutils.save_image(real_cpu, os.path.join('/Users/dheerajanikar/Desktop/EE_641_project/DCGAN/', 'real_samples.png'), normalize=True)


        
    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (out_folder, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (out_folder, epoch))

In [20]:
state_dict_G = torch.load('/Users/dheerajanikar/Desktop/EE_641_project/output_iter2/netG_epoch_199.pth')
state_dict_D = torch.load('/Users/dheerajanikar/Desktop/EE_641_project/output_iter2/netD_epoch_199.pth')
netG.load_state_dict(state_dict_G)
netD.load_state_dict(state_dict_D)

<All keys matched successfully>

In [21]:
netG

Generator(
  (gen): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

In [22]:
nz = 100
fixed_noise = torch.randn(Config.batch_size, nz, 1, 1, device=Config.device)
real_label = 1
fake_label = 0

out_folder = '/Users/dheerajanikar/Desktop/EE_641_project/output_iter2'


import torchvision.utils as vutils
errD_fake_list = []
errD_real_list = []
errG_list = []

for epoch in range(200,Config.epochs+200):
    print(epoch)
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(Config.device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label,
                           dtype=real_cpu.dtype, device=Config.device)
                           


        output = netD(real_cpu)
        output = output.view(-1)
    
        errD_real = loss(output, label)
        errD_real_list.append(errD_real)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=Config.device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())

        output = output.view(-1)
        errD_fake = loss(output, label)
        errD_fake_list.append(errD_fake)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        opt_D.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        output = output.view(-1)
        errG = loss(output, label)
        errG_list.append(errG)
        errG.backward()
        D_G_z2 = output.mean().item()
        opt_G.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, Config.epochs+200, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    if epoch%10 == 0:
        checkpoint_dir = create_checkpoint_dir('/Users/dheerajanikar/Desktop/EE_641_project/DCGAN/', epoch)

        for image_num in range(10):

            noise = torch.randn(1, nz, 1, 1, device=Config.device)
                
            fake_image = netG(noise)

            
            image_filename = os.path.join(checkpoint_dir, f'image_epoch_{epoch}_num_{image_num}.png')
            vutils.save_image(fake_image, image_filename, normalize=True)
        vutils.save_image(real_cpu, os.path.join('/Users/dheerajanikar/Desktop/EE_641_project/DCGAN/', 'real_samples.png'), normalize=True)


        
    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (out_folder, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (out_folder, epoch))

200
[200/400][0/130] Loss_D: 0.0089 Loss_G: 6.3659 D(x): 0.9969 D(G(z)): 0.0058 / 0.0031
[200/400][1/130] Loss_D: 0.0082 Loss_G: 6.5542 D(x): 0.9978 D(G(z)): 0.0059 / 0.0035
[200/400][2/130] Loss_D: 0.1664 Loss_G: 2.2243 D(x): 0.8635 D(G(z)): 0.0039 / 0.2425
[200/400][3/130] Loss_D: 0.4574 Loss_G: 12.4607 D(x): 0.9999 D(G(z)): 0.2866 / 0.0001
[200/400][4/130] Loss_D: 0.4355 Loss_G: 4.8588 D(x): 0.7440 D(G(z)): 0.0009 / 0.0251
[200/400][5/130] Loss_D: 0.0707 Loss_G: 3.0195 D(x): 0.9676 D(G(z)): 0.0304 / 0.1042
[200/400][6/130] Loss_D: 0.2406 Loss_G: 8.5574 D(x): 0.9944 D(G(z)): 0.1637 / 0.0004
[200/400][7/130] Loss_D: 0.1501 Loss_G: 5.2977 D(x): 0.8934 D(G(z)): 0.0156 / 0.0195
[200/400][8/130] Loss_D: 0.1106 Loss_G: 6.2684 D(x): 0.9743 D(G(z)): 0.0633 / 0.0048
[200/400][9/130] Loss_D: 0.0317 Loss_G: 6.2266 D(x): 0.9867 D(G(z)): 0.0172 / 0.0052
[200/400][10/130] Loss_D: 0.0593 Loss_G: 6.0545 D(x): 0.9642 D(G(z)): 0.0049 / 0.0080
[200/400][11/130] Loss_D: 0.0392 Loss_G: 5.7482 D(x): 0.991

In [34]:
errG_list[0]

tensor(6.3659, device='mps:0', grad_fn=<BinaryCrossEntropyBackward0>)

In [15]:
state_dict_G = torch.load('/Users/dheerajanikar/Downloads/netG_epoch_106.pth')
state_dict_D = torch.load('/Users/dheerajanikar/Desktop/EE_641_project/output_iter1/netD_epoch_199.pth')
netG.load_state_dict(state_dict_G)
netD.load_state_dict(state_dict_D)

<All keys matched successfully>

In [14]:
noise = torch.randn(1, 100, 1, 1, device=Config.device)
                
fake_image = netG(noise)

In [20]:

import torchvision.utils as vutils
checkpoint_dir = create_checkpoint_dir('/Users/dheerajanikar/Desktop',10)

for image_num in range(10):

    noise = torch.randn(1, 100, 1, 1, device=Config.device)
                
    fake_image = netG(noise)

            
    image_filename = os.path.join(checkpoint_dir, f'image_num_{image_num}.png')
    vutils.save_image(fake_image, image_filename, normalize=True)