In [None]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets as torch_dataset
from torchvision.utils import make_grid

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Data configuration.
img_size = 64  # images are resized/cropped to img_size x img_size; both models assume 3x64x64
batch_size = 25
# Root folder in torchvision ImageFolder layout (one sub-folder per class).
# Set the VULE_DATA_DIR environment variable to point elsewhere instead of
# editing this machine-specific absolute path.
data_dir = os.environ.get('VULE_DATA_DIR', 'C:/HUST/20222/LAB/TestGAN/Vule')
transforms = T.Compose([
    T.Resize(img_size),       # shorter side -> img_size, aspect ratio preserved
    T.CenterCrop(img_size),   # then crop the centre to a square
    T.ToTensor(),             # PIL image -> float tensor in [0, 1]
    # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], the range of the generator's Tanh output
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

vule_dataset = torch_dataset.ImageFolder(root=data_dir, transform=transforms)
dataloader = DataLoader(dataset=vule_dataset, batch_size=batch_size, shuffle=True, num_workers=4)


In [None]:
# Sanity-check the input pipeline: show (up to) 32 images from one shuffled batch.
first_batch = next(iter(dataloader))
img_batch = first_batch[0]  # element 1 holds the ImageFolder class labels, unused here
# normalize=True rescales the grid to [0, 1] so imshow displays it correctly;
# permute turns (C, H, W) into the (H, W, C) layout imshow expects.
combine_img = make_grid(img_batch[:32], padding=2, normalize=True).permute(1, 2, 0)
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(combine_img)
plt.show()

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator scoring flattened images as real (-> 1) or fake (-> 0).

    Args:
        in_features: number of values in one flattened input sample. The
            default 12288 equals 3 * 64 * 64, i.e. one 3x64x64 image.
    """

    def __init__(self, in_features: int = 12288):
        super().__init__()
        self.in_features = in_features
        self.model = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability output, consumed by nn.BCELoss in training
        )

    def forward(self, x):
        """Flatten each sample in ``x`` and return a (batch, 1) tensor of probabilities."""
        # reshape (rather than view) also accepts non-contiguous inputs such as
        # permuted images; for contiguous inputs it is the same zero-copy view.
        x = x.reshape(x.size(0), self.in_features)
        return self.model(x)

In [None]:
# Build the discriminator and move its parameters to the selected device.
discriminator = Discriminator().to(device)

In [None]:
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to images with values in (-1, 1).

    Args:
        latent_dim: length of each input noise vector (default 12288, the size
            the training loop samples).
        img_shape: (channels, height, width) of the produced images
            (default (3, 64, 64)).
    """

    def __init__(self, latent_dim: int = 12288, img_shape: tuple = (3, 64, 64)):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, math.prod(self.img_shape)),
            # Tanh keeps outputs in (-1, 1), the same range as the normalized real images.
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a (batch, latent_dim) noise tensor to a (batch, *img_shape) image tensor."""
        output = self.model(x)
        return output.view(x.size(0), *self.img_shape)

# Build the generator and move its parameters to the selected device.
generator = Generator().to(device)

In [None]:
# Training hyperparameters.
# Take the batch size from the DataLoader instead of repeating its literal:
# the training loop's label tensors must agree with the DataLoader's batches.
batch_size = dataloader.batch_size
lr = 0.0001
num_epochs = 500
# Binary cross-entropy on the discriminator's sigmoid (probability) output.
loss_function = nn.BCELoss()

optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

In [None]:
# Train the GAN. Every 25 epochs the latest losses are printed and a grid of
# generated samples is shown and stored in `img_list`. After the loop,
# `latent_space_samples` holds the last generator noise batch, which the next
# cell reuses for a final sample.
img_list = []  # created once so snapshots accumulate across epochs
for epoch in range(num_epochs):
    for real_samples, _labels in dataloader:  # ImageFolder class labels are not used
        real_samples = real_samples.to(device=device)
        # Size everything from the actual batch: without drop_last the final
        # batch of an epoch can be smaller than `batch_size`, and fixed-size
        # label tensors would then break torch.cat / BCELoss.
        current_batch_size = real_samples.size(0)

        # Data for training the discriminator: real images -> 1, fakes -> 0
        real_samples_labels = torch.ones((current_batch_size, 1), device=device)
        latent_space_samples = torch.randn((current_batch_size, 12288), device=device)
        # This step only updates the discriminator, so don't build a graph
        # through the generator.
        with torch.no_grad():
            generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((current_batch_size, 1), device=device)
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat((real_samples_labels, generated_samples_labels))

        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Data for training the generator
        latent_space_samples = torch.randn((current_batch_size, 12288), device=device)

        # Training the generator: fakes are scored against the "real" labels,
        # so the loss is low when the discriminator is fooled.
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(output_discriminator_generated, real_samples_labels)
        loss_generator.backward()
        optimizer_generator.step()

    # Report once per logging epoch, using the last batch's losses (this used
    # to sit inside the batch loop and printed/plotted for every batch).
    if epoch % 25 == 0:
        print(f"Epoch: {epoch} Loss D.: {loss_discriminator.item()}")
        print(f"Epoch: {epoch} Loss G.: {loss_generator.item()}")
        with torch.no_grad():
            snapshot = generator(latent_space_samples).cpu()
        fake_img = make_grid(snapshot, padding=2, normalize=True)
        img_list.append(fake_img)
        plt.figure(figsize=(10,10))
        plt.imshow(img_list[-1].permute(1,2,0))
        plt.axis("off")
        plt.title(f"Generated samples, epoch {epoch}")
        plt.show()

In [None]:
# Render one final sample grid from the trained generator, reusing the last
# `latent_space_samples` batch left over from the training loop.
with torch.no_grad():  # inference only: skip building an autograd graph
    generated_samples = generator(latent_space_samples).cpu()
fake_img = make_grid(generated_samples, padding=2, normalize=True)
img_list.append(fake_img)
plt.figure(figsize=(10,10))
plt.imshow(img_list[-1].permute(1,2,0))
plt.axis("off")
plt.title("Generated samples after training")
plt.show()