In [1]:
%run Config.ipynb
%run Discriminator.ipynb
%run Generator.ipynb

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [3]:
def gradient_penalty(critic, labels,real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(Config.device)
    interpolated_images = real * alpha + fake * (1 - alpha)
    mixed_scores = critic(interpolated_images, labels)
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,)[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [4]:
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)


In [5]:
transforms = transforms.Compose([
    transforms.Resize((Config.IMAGE_SIZE,Config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(Config.CHANNELS_IMG)],[0.5 for _ in range(Config.CHANNELS_IMG)]),
])

In [6]:
dataset = datasets.ImageFolder(root="C:\\Users\\User\\pythonProject\\GAN\\DCGAN\\Doctorimage", transform=transforms)
loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE, shuffle=True)

In [7]:
gen = Generator(Config.Z_DIM, Config.CHANNELS_IMG, Config.FEATURES_GEN, Config.NUM_CLASSES,Config.IMAGE_SIZE, Config.GEN_EMBEDDIN).to(Config.device)
disc = Discriminator(Config.CHANNELS_IMG, Config.FEATURES_GEN, Config.NUM_CLASSES, Config.IMAGE_SIZE).to(Config.device)
initialize_weights(gen)
initialize_weights(disc)

In [8]:
opt_gen = optim.Adam(gen.parameters(), lr=Config.LEARNING_RATE, betas=(0.0, 0.9))
opt_disc = optim.Adam(disc.parameters(), lr=Config.LEARNING_RATE, betas=(0.0, 0.9))
step = 0
gen.train()
disc.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(4, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(128, 1, kernel_size=(4, 4), stride=(2, 2))
  )
  (embed): Embedding(10, 4096

In [9]:
for epoch in range(Config.NUM_EPOCHS):
    for batch_idx, (real, labels) in enumerate(tqdm(loader)):
        real = real.to(Config.device)
        cur_batch_size = real.shape[0]
        labels = labels.to(Config.device)

        #max E[critic(real)] - E[critic(fake)]
        for _ in range(Config.CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Config.Z_DIM, 1, 1).to(Config.device)
            fake = gen(noise, labels)
            disc_real = disc(real, labels).reshape(-1)
            disc_fake = disc(fake, labels).reshape(-1)
            gp = gradient_penalty(disc, labels, real, fake, device=Config.device)
            loss_disc = (
                -(torch.mean(disc_real) - torch.mean(disc_fake)) + Config.LAMBDA_GP * gp
            )
            disc.zero_grad()
            loss_disc.backward(retain_graph=True)
            opt_disc.step()

        #min -E[critic(gen_fake)]
        gen_fake = disc(fake, labels).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

 67%|██████▋   | 2/3 [00:52<00:26, 26.43s/it]


KeyboardInterrupt: 