Three Hidden Layers in both Classical Discriminator and Quantum Generator

main_lat_10_mid_4_8_16_size_16_13_10_4_1_class0

In [1]:
import pennylane as qml
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torchsummary import summary


import numpy as np
import math
import random
import os
import sys

sys.path.append('./QuantumGAN_utils')
from DataLoading import MNIST_DataLoading
from new_QGAN_threeMid_add_lastBN import QuantumGenerator

In [2]:
%matplotlib inline
# Display tensors with the "full" print profile instead of the abbreviated default.
torch.set_printoptions(profile="full")


# Seed Python, NumPy and PyTorch with one shared value for reproducibility.
mySeed = 100
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(mySeed)

# Restrict PyTorch to deterministic algorithms (needed for reproducible results).
torch.use_deterministic_algorithms(True)

# Settings & Hyperparameters

In [3]:
# ---- Training schedule ----
epoch_size, batch_size = 20, 64

# ---- Adam learning rates (discriminator / generator) ----
lr_D = 5e-4
lr_G = 5e-3

# ---- Quantum generator circuit options (forwarded to QuantumGenerator) ----
KernelAncilla = 0
N_layer = 10

# ---- Classical discriminator output: one value per image ----
final_dim, final_size = 1, (1, 1)

# ---- Quantum generator stages, listed from the latent input to the image ----
# Each stage is described by (channels, spatial size, kernel size).
# kernel_size_4 is the kernel applied to the latent input (generator stage 1).
latent_dim, latent_size, kernel_size_4 = 10, (1, 1), (4, 4)

middle_dim_3, middle_size_3, kernel_size_3 = 16, (4, 4), (7, 7)

middle_dim_2, middle_size_2, kernel_size_2 = 8, (10, 10), (4, 4)

middle_dim_1, middle_size_1, kernel_size_1 = 4, (13, 13), (4, 4)

# ---- Final image: single channel, 16 x 16 ----
image_dim, image_size = 1, (16, 16)




In [4]:
# All computation in this notebook runs on the CPU.
device = "cpu"

# Root folder of this experiment, keyed by the latent dimension.
saving_path = f"QGAN_ClassicalDis_lat_{latent_dim}_size_16_13_10_4_1"

# Sub-folder suffix shared by every artefact type, keyed by the hidden channel widths.
_mid_tag = f"mid_{middle_dim_1}_{middle_dim_2}_{middle_dim_3}"

# One directory per artefact type: preview images, checkpoints, and the raw
# discriminator / generator tensors dumped during training.
image_file = f"./{saving_path}/image/{_mid_tag}"
ckpt_file = f"./{saving_path}/ckpt/{_mid_tag}"
discriminator_tensor = f"./{saving_path}/discriminator_tensor/{_mid_tag}"
generator_tensor = f"./{saving_path}/generator_tensor/{_mid_tag}"

# Create the directories up front; existing ones are left untouched.
for _output_dir in (image_file, ckpt_file, discriminator_tensor, generator_tensor):
    os.makedirs(_output_dir, exist_ok=True)

# Data Loading

In [5]:
# Build the MNIST loaders and datasets through the project helper.
# image_size is square, so only its first entry is passed on.
# data_label=[0] presumably restricts the data to digit class 0 -- TODO confirm in DataLoading.
train_loader, test_loader, train_dataset, test_dataset = MNIST_DataLoading(
    saving_location="./mnist",
    image_size=image_size[0],
    batch_size=batch_size,
    data_label=[0],
)

# Model Preparation

In [6]:
class Discriminator(nn.Module):
    """Convolutional real/fake classifier for single-channel images.

    The network is a stack of stride-1, unpadded convolutions. Every hidden
    convolution is followed by BatchNorm2d and LeakyReLU(0.2); the last
    convolution maps to one channel and is followed by a sigmoid.

    With the notebook's default configuration the feature maps are
    1*16*16 --> 4*13*13 --> 8*10*10 --> 16*4*4 --> 1*1*1.

    Args:
        in_channels: Number of channels of the input image.
        hidden_dims: Output channels of each hidden convolution. Defaults to
            ``(middle_dim_1, middle_dim_2, middle_dim_3)`` from the
            hyperparameter cell.
        kernel_sizes: Kernel size of every convolution, hidden ones first and
            the output convolution last. Defaults to
            ``(kernel_size_1, kernel_size_2, kernel_size_3, kernel_size_4)``
            from the hyperparameter cell, i.e. 4, 4, 7, 4.

    Raises:
        ValueError: If ``kernel_sizes`` does not hold exactly one entry more
            than ``hidden_dims``.
    """

    def __init__(self, in_channels=1, hidden_dims=None, kernel_sizes=None):
        super().__init__()

        # Fall back to the notebook-level configuration so that a plain
        # Discriminator() builds the same network as before, while the layer
        # sizes can no longer drift away from the hyperparameter cell.
        if hidden_dims is None:
            hidden_dims = (middle_dim_1, middle_dim_2, middle_dim_3)
        if kernel_sizes is None:
            kernel_sizes = (kernel_size_1, kernel_size_2, kernel_size_3, kernel_size_4)

        hidden_dims = tuple(hidden_dims)
        kernel_sizes = tuple(kernel_sizes)
        if len(kernel_sizes) != len(hidden_dims) + 1:
            raise ValueError(
                "kernel_sizes needs one entry per hidden layer plus one for the "
                "output layer: got {0} kernel sizes for {1} hidden layers".format(
                    len(kernel_sizes), len(hidden_dims)
                )
            )

        # Layers are appended in the same order as the original hand-written
        # Sequential, which keeps the state_dict keys (main.0 ... main.10 by
        # default) and the seeded weight initialisation unchanged.
        layers = []
        previous_dim = in_channels
        for hidden_dim, kernel_size in zip(hidden_dims, kernel_sizes[:-1]):
            layers.append(nn.Conv2d(previous_dim, hidden_dim, kernel_size, 1, 0, bias=False))
            layers.append(nn.BatchNorm2d(hidden_dim))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            previous_dim = hidden_dim

        # Output head: one channel squashed to (0, 1) as required by BCELoss.
        layers.append(nn.Conv2d(previous_dim, 1, kernel_sizes[-1], 1, 0, bias=False))
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the sigmoid score map produced by the stack for ``input``."""
        return self.main(input)
    

In [7]:
# Instantiate both networks. The discriminator is created before the generator,
# so the seeded weight initialisation consumes random numbers in the same order.

# Discriminator feature maps: 1*16*16 --> 4*13*13 --> 8*10*10 --> 16*4*4 --> 1*1*1
discriminator = Discriminator()
discriminator = discriminator.to(device)

# Generator feature maps: 10*1*1 --> 16*4*4 --> 8*10*10 --> 4*13*13 --> 1*16*16
# Every stage uses stride 1 and a padding equal to its kernel size minus one.
generator_kwargs = dict(
    KernelAncilla=KernelAncilla,
    layer_num=N_layer,
    # stage 1: latent vector --> middle_3
    stage_1_dim=latent_dim,
    stage_1_input_size=latent_size,
    stage_1_kernel_size=kernel_size_4,
    stage_1_stride=1,
    stage_1_padding=3,
    # stage 2: middle_3 --> middle_2
    stage_2_dim=middle_dim_3,
    stage_2_input_size=middle_size_3,
    stage_2_kernel_size=kernel_size_3,
    stage_2_stride=1,
    stage_2_padding=6,
    # stage 3: middle_2 --> middle_1
    stage_3_dim=middle_dim_2,
    stage_3_input_size=middle_size_2,
    stage_3_kernel_size=kernel_size_2,
    stage_3_stride=1,
    stage_3_padding=3,
    # stage 4: middle_1 --> image
    stage_4_dim=middle_dim_1,
    stage_4_input_size=middle_size_1,
    stage_4_kernel_size=kernel_size_1,
    stage_4_stride=1,
    stage_4_padding=3,
    # generated image
    final_dim=image_dim,
    final_size=image_size,
)
generator = QuantumGenerator(**generator_kwargs)
generator = generator.to(device)




In [8]:
# Print the module tree and the per-layer torchsummary table for each network,
# discriminator first, using one example input shape (channels, height, width) each.
for _model, _input_shape in (
    (discriminator, (1, 16, 16)),
    (generator, (latent_dim, 1, 1)),
):
    print(_model)
    summary(_model, _input_shape)

Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 4, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(4, 8, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(8, 16, kernel_size=(7, 7), stride=(1, 1), bias=False)
    (7): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Conv2d(16, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (10): Sigmoid()
  )
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 4, 13, 13]              64
       BatchNorm2d-2            [-1, 4, 13

In [9]:
# Optional: resume from previously saved weights by uncommenting the two lines below.
# NOTE(review): the example paths ('./ckpt/mid_4/1/...') do not follow the ckpt_file
# layout built earlier in this notebook -- point them at files under ckpt_file before use.
# discriminator.load_state_dict(torch.load('./ckpt/mid_4/1/discriminator_8_0.pt'))
# generator.load_state_dict(torch.load('./ckpt/mid_4/1/generator_8_0.pt'))

In [10]:
# Binary cross-entropy between the discriminator's sigmoid output and the 0/1 targets.
criterion = nn.BCELoss()

# One Adam optimiser per network; both share the same (beta1, beta2) pair
# but use their own learning rate.
b1 = 0.5
b2 = 0.999
opt_D = optim.Adam(discriminator.parameters(), lr=lr_D, betas=(b1, b2))
opt_G = optim.Adam(generator.parameters(), lr=lr_G, betas=(b1, b2))

# Constant target vectors, one entry per sample of a full batch:
# 1.0 marks "real", 0.0 marks "fake".
real_label = torch.ones(batch_size, dtype=torch.float32, device=device)
fake_label = torch.zeros(batch_size, dtype=torch.float32, device=device)

# 64 latent vectors drawn once and reused for every preview image,
# which the training loop lays out with nrow=8.
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)

# Training

In [11]:
# Adversarial training loop.
#
# For every batch the discriminator is updated first (real images against
# detached fakes), then the generator is updated through the freshly updated
# discriminator. Every 10th batch the losses are recorded and a snapshot
# (preview image, both state dicts, diagnostic tensors) is written to disk.

# Loss histories; one entry per logged batch (every 10th batch), not per batch.
D_train_loss_saving = []
G_train_loss_saving = []


for epoch in range(epoch_size):
    for batch_index, (real_image, _) in enumerate(train_loader):

        # ---------------- Update Discriminator, Fixed Generator ----------------

        opt_D.zero_grad()

        real_image = real_image.to(device)               # (current_batch, 1, image_size[0], image_size[1])

        # Size everything by the batch actually received: the pre-built labels
        # hold batch_size entries, so a shorter final batch would otherwise make
        # BCELoss fail on mismatched shapes.
        current_batch = real_image.size(0)
        real_target = real_label[:current_batch]
        fake_target = fake_label[:current_batch]

        # Score the real images.
        output_real_after_sigmoid = discriminator(real_image)
        output_real = output_real_after_sigmoid.view(-1) # flattened to one score per image

        # Generate fakes; the generator also returns two intermediate tensors
        # that are not needed for the training step.
        noise = torch.randn(current_batch, latent_dim, 1, 1, device=device)
        fake_image, _, _ = generator(noise)

        # detach() keeps the discriminator loss from propagating into the generator.
        output_fake_after_sigmoid = discriminator(fake_image.detach())
        output_fake = output_fake_after_sigmoid.view(-1)

        # Real images should score 1, fakes should score 0.
        loss_real = criterion(output_real, real_target)
        loss_fake = criterion(output_fake, fake_target)

        # Two separate backward passes accumulate into the same gradients.
        loss_real.backward()
        loss_fake.backward()

        # Summed only for reporting.
        loss_D = loss_real + loss_fake

        opt_D.step()

        # ---------------- Update Generator, Fixed Discriminator ----------------

        opt_G.zero_grad()

        # Re-score the same fakes with the updated discriminator, this time
        # keeping the graph back to the generator.
        output_fake_after_sigmoid = discriminator(fake_image)
        output_fake = output_fake_after_sigmoid.view(-1)

        # The generator is rewarded when its fakes are scored as real.
        loss_G = criterion(output_fake, real_target)

        loss_G.backward()

        opt_G.step()

        if batch_index % 10 == 0:
            print(f"[Epoch {epoch}/{epoch_size}] [Batch {batch_index}/{len(train_loader)}] [D loss: {loss_D.item()}] [G loss: {loss_G.item()}]")

            D_train_loss_saving.append(loss_D.item())
            G_train_loss_saving.append(loss_G.item())

            # Preview on the fixed noise. No gradients are needed here, so the
            # forward pass runs without recording an autograd graph. The
            # generator is left in its current (training) mode, as before.
            with torch.no_grad():
                fixed_image, fake_tensor_beforeTanh_BN, fake_tensor_afterBN = generator(fixed_noise)

            # Common "<epoch>_<batch>" suffix of every file written for this snapshot.
            snapshot_tag = f"{epoch}_{batch_index}"

            save_image(fixed_image, os.path.join(image_file, f"{snapshot_tag}.png"), nrow=8, normalize=True)

            torch.save(discriminator.state_dict(), os.path.join(ckpt_file, f"discriminator_{snapshot_tag}.pt"))
            torch.save(generator.state_dict(), os.path.join(ckpt_file, f"generator_{snapshot_tag}.pt"))

            # Discriminator scores of this batch, detached so only values are stored.
            # The fake scores are the ones computed during the generator update.
            torch.save(output_real_after_sigmoid.detach(), os.path.join(discriminator_tensor, f"output_real_after_sigmoid_{snapshot_tag}.pt"))
            torch.save(output_fake_after_sigmoid.detach(), os.path.join(discriminator_tensor, f"output_fake_after_sigmoid_{snapshot_tag}.pt"))

            # Generator outputs for the fixed noise: the two intermediate tensors
            # and the final image.
            torch.save(fake_tensor_beforeTanh_BN, os.path.join(generator_tensor, f"fake_tensor_beforeTanh_BN_{snapshot_tag}.pt"))
            torch.save(fake_tensor_afterBN, os.path.join(generator_tensor, f"fake_tensor_afterBN_{snapshot_tag}.pt"))
            torch.save(fixed_image, os.path.join(generator_tensor, f"fake_tensor_afterTanh_{snapshot_tag}.pt"))

            print("saved images and state")





[Epoch 0/20] [Batch 0/92] [D loss: 1.478825330734253] [G loss: 0.7675439119338989]
saved images and state
[Epoch 0/20] [Batch 10/92] [D loss: 0.9659445285797119] [G loss: 0.8994162678718567]
saved images and state
[Epoch 0/20] [Batch 20/92] [D loss: 0.9473181366920471] [G loss: 0.9160370826721191]
saved images and state
[Epoch 0/20] [Batch 30/92] [D loss: 0.9179482460021973] [G loss: 1.1494485139846802]
saved images and state
[Epoch 0/20] [Batch 40/92] [D loss: 1.1194690465927124] [G loss: 0.9893391132354736]
saved images and state
[Epoch 0/20] [Batch 50/92] [D loss: 0.9693689942359924] [G loss: 1.1957731246948242]
saved images and state
[Epoch 0/20] [Batch 60/92] [D loss: 0.9932500123977661] [G loss: 1.2388792037963867]
saved images and state
[Epoch 0/20] [Batch 70/92] [D loss: 0.816138744354248] [G loss: 1.44775390625]
saved images and state
[Epoch 0/20] [Batch 80/92] [D loss: 0.8229727745056152] [G loss: 1.430643081665039]
saved images and state
[Epoch 0/20] [Batch 90/92] [D loss: 0

In [12]:
# Run notes kept from the original author:
#   - Deeper Structure
#   - Run ??-hr/2787-min on mac mini (now)
#   - 17-min 10-batch (now)
#   - 11/6 17:34 ~ 11/9 15:58

# Store the sampled discriminator / generator loss histories next to the checkpoints.
for _history_name, _history in (
    ("D_train_loss_saving", D_train_loss_saving),
    ("G_train_loss_saving", G_train_loss_saving),
):
    torch.save(_history, os.path.join(ckpt_file, _history_name + ".pt"))
