# Library

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from pydantic import BaseModel
import wandb


In [14]:
class Parameters(BaseModel):
    """Hyperparameters for the GAN run; the same values are logged to wandb as the run config.

    Every field carries an explicit type annotation: pydantic v2 refuses
    un-annotated attributes on a BaseModel, and pydantic v1 merely inferred the
    type from the default value (the annotations below match those inferred types).
    """

    # Evaluated once, when the class body runs: "cuda" if a GPU is visible, else "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    lr: float = 2e-4  # Adam learning rate, shared by the generator and discriminator optimizers
    z_dim: int = 64  # length of the noise vector fed to the generator
    image_size: int = 32  # images are resized to image_size x image_size (32*32)
    batch_size: int = 10
    num_epochs: int = 30


P = Parameters()

# model_dump() is the pydantic v2 API; .dict() is the v1 name (deprecated in v2).
wandb.init(
    project="gan_training",
    save_code=True,
    config=P.model_dump() if hasattr(P, "model_dump") else P.dict(),
)

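Output size of `nn.Conv2d` (the same formula applies to width and height), e.g. $\lfloor (32 - 4 + 2 \times 1) / 2 \rfloor + 1 = 16$:

$$
W_{\text{out}} = \left\lfloor \frac{W_{\text{in}} - \mathrm{kernel\_size} + 2 \times \mathrm{padding}}{\mathrm{stride}} \right\rfloor + 1
$$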

In [15]:
# Discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator: maps an image batch to per-image "real" probabilities.

    Spatial sizes follow W_out = floor((W_in - kernel + 2*padding) / stride) + 1:
    three stride-2 convolutions take 32 -> 16 -> 8 -> 4, and the final 4x4
    convolution without padding takes 4 -> 1. The geometry therefore assumes
    32x32 inputs; other sizes yield a larger output map or fail outright.

    Args:
        img_channels: Number of channels in the input images. Defaults to 1,
            matching the grayscale preprocessing used for training.
        features: Width of the first conv layer; the next two layers use 2x and
            4x this width. The defaults reproduce the 1 -> 64 -> 128 -> 256 -> 1
            architecture (same layer order, so state_dict keys are unchanged).
    """

    def __init__(self, img_channels: int = 1, features: int = 64):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(  # <-- (B, img_channels, 32, 32)
            nn.Conv2d(img_channels, features, kernel_size=4, stride=2, padding=1),  # --> (32-4+2)/2+1 = 16
            nn.LeakyReLU(0.2),

            nn.Conv2d(features, features * 2, kernel_size=4, stride=2, padding=1),  # --> (16-4+2)/2+1 = 8
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2),

            nn.Conv2d(features * 2, features * 4, kernel_size=4, stride=2, padding=1),  # --> (8-4+2)/2+1 = 4
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2),

            nn.Conv2d(features * 4, 1, kernel_size=4, stride=1, padding=0),  # --> (4-4)/1+1 = 1
            # Probabilities in (0, 1), as required by the nn.BCELoss used for training.
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return a (B, 1, 1, 1) tensor of probabilities for a (B, C, 32, 32) image batch."""
        return self.model(img)

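# ConvTranspose2d

Output size of `nn.ConvTranspose2d` with `dilation=1` (the same formula applies to width and height), e.g. $(4 - 1) \times 2 - 2 \times 1 + 4 + 0 = 8$:

$$
W_{\text{out}} = (W_{\text{in}} - 1) \times \mathrm{stride} - 2 \times \mathrm{padding} + \mathrm{kernel\_size} + \mathrm{output\_padding}
$$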

In [16]:
# Generator
class Generator(nn.Module):
    """DCGAN-style generator: maps noise vectors to images in [-1, 1].

    Spatial sizes follow W_out = (W_in - 1)*stride - 2*padding + kernel_size:
    the first transposed conv takes 1 -> 4, and three stride-2 layers take
    4 -> 8 -> 16 -> 32, so the output is 32x32.

    Args:
        z_dim: Number of channels of the input noise tensor.
        img_channels: Number of channels in the generated images. Defaults to 1
            (grayscale), matching the training data preprocessing.
        features: Width of the last hidden layer; earlier layers use 4x and 2x
            this width. The defaults reproduce the z_dim -> 256 -> 128 -> 64 -> 1
            architecture (same layer order, so state_dict keys are unchanged).
    """

    def __init__(self, z_dim, img_channels: int = 1, features: int = 64):
        super(Generator, self).__init__()
        self.model = nn.Sequential(  # <-- (B, z_dim, 1, 1)
            nn.ConvTranspose2d(z_dim, features * 4, kernel_size=4, stride=1, padding=0),  # --> 4
            nn.BatchNorm2d(features * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(features * 4, features * 2, kernel_size=4, stride=2, padding=1),  # --> 8
            nn.BatchNorm2d(features * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(features * 2, features, kernel_size=4, stride=2, padding=1),  # --> 16
            nn.BatchNorm2d(features),
            nn.ReLU(True),

            nn.ConvTranspose2d(features, img_channels, kernel_size=4, stride=2, padding=1),  # --> 32
            # Tanh yields [-1, 1], the same range as the Normalize((0.5,), (0.5,)) real images.
            nn.Tanh()
        )

    def forward(self, z):
        """Return a (B, img_channels, 32, 32) image batch for noise of shape (B, z_dim, 1, 1)."""
        return self.model(z)



In [17]:
# Preprocessing: resize to image_size x image_size, force one grayscale channel,
# convert to a tensor, then Normalize((0.5,), (0.5,)) so pixel values land in
# [-1, 1], the range of the generator's Tanh output.
# The pipeline is bound to `transform`, NOT `transforms`: rebinding `transforms`
# shadowed the torchvision.transforms module, so re-running this cell raised
# AttributeError ('Compose' object has no attribute 'Compose').
transform = transforms.Compose(
    [
        transforms.Resize((P.image_size, P.image_size)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# ImageFolder class labels are ignored by the training loop.
dataset = datasets.ImageFolder(root='./dataset/train_images', transform=transform)
# Batches per epoch = ceil(dataset size / batch size) (the last batch may be smaller).
dataloader = DataLoader(dataset, batch_size=P.batch_size, shuffle=True)

# Initialize the generator and discriminator.
generator = Generator(P.z_dim).to(P.device)
discriminator = Discriminator().to(P.device)

# Loss and optimizers: BCE on the discriminator's sigmoid output; Adam for both networks.
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=P.lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=P.lr, betas=(0.5, 0.999))

# Global step counter used for wandb logging.
step = 0


In [18]:
# Train the GAN
num_batches = len(dataloader)
for epoch in range(P.num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        # Global wandb step, computed BEFORE any logging so each iteration logs to
        # its own strictly increasing step. (Assigning it at the end of the loop
        # made every log lag one iteration: iterations 0 and 1 both logged to
        # step 0, and iteration 0's losses were overwritten.)
        step = epoch * num_batches + i
        batch_size = real_images.size(0)  # the final batch of an epoch may be smaller
        real_images = real_images.to(P.device)

        # Random noise -> generator -> fake images
        z = torch.randn(batch_size, P.z_dim, 1, 1, device=P.device)
        fake_images = generator(z)

        # --- Discriminator update: real images -> target 1, fake images -> target 0 ---
        optimizer_D.zero_grad()
        real_output = discriminator(real_images)
        real_loss = criterion(real_output, torch.ones_like(real_output))

        # detach() keeps the discriminator loss from backpropagating into the generator.
        fake_output = discriminator(fake_images.detach())
        fake_loss = criterion(fake_output, torch.zeros_like(fake_output))

        loss_D = (real_loss + fake_loss) / 2
        loss_D.backward()
        optimizer_D.step()

        # --- Generator update ---
        # Noise -> [generator] -> fake images -> [discriminator] -> probability.
        # The generator is rewarded when D labels its fakes as real (target 1).
        # This backward also leaves gradients on D's parameters; they are cleared
        # by optimizer_D.zero_grad() at the start of the next iteration.
        optimizer_G.zero_grad()
        fake_output = discriminator(fake_images)
        loss_G = criterion(fake_output, torch.ones_like(fake_output))
        loss_G.backward()
        optimizer_G.step()

        wandb.log({"Loss_D": loss_D.item(), "Loss_G": loss_G.item()}, step=step)

        # Print training progress.
        if (i + 1) % 15 == 0:
            print(f"Epoch [{epoch + 1}/{P.num_epochs}], Batch {i+1}/{len(dataloader)}, Loss D: {loss_D.item()}, Loss G: {loss_G.item()}")

        if step % 200 == 0:
            # Detach from the autograd graph and move to CPU before handing tensors to wandb.
            generated_images = [wandb.Image(image) for image in fake_images.detach().cpu()]
            wandb.log({"Generated Images": generated_images}, step=step)


wandb.finish()


Epoch [1/30], Batch 15/100, Loss D: 0.05296773836016655, Loss G: 4.11497688293457
Epoch [1/30], Batch 30/100, Loss D: 0.027065562084317207, Loss G: 4.728980541229248
Epoch [1/30], Batch 45/100, Loss D: 0.004367826506495476, Loss G: 5.584713935852051
Epoch [1/30], Batch 60/100, Loss D: 0.00414587277919054, Loss G: 5.7506632804870605
Epoch [1/30], Batch 75/100, Loss D: 0.0020484505221247673, Loss G: 6.435426235198975
Epoch [1/30], Batch 90/100, Loss D: 0.0014907813165336847, Loss G: 6.783929347991943
Epoch [2/30], Batch 15/100, Loss D: 0.001006641541607678, Loss G: 6.898268222808838
Epoch [2/30], Batch 30/100, Loss D: 0.0012075353879481554, Loss G: 7.024201393127441
Epoch [2/30], Batch 45/100, Loss D: 0.0005432592006400228, Loss G: 7.40516996383667
Epoch [2/30], Batch 60/100, Loss D: 0.0005184882902540267, Loss G: 7.430285930633545
Epoch [2/30], Batch 75/100, Loss D: 0.00037633703323081136, Loss G: 7.673314094543457
Epoch [2/30], Batch 90/100, Loss D: 0.00037225382402539253, Loss G: 7.61

0,1
Loss_D,▁▁▁▁▁▁▁▂▆▄▃▃▄▅▃▂▂▂▂▂▂▂▅▆▁▂▁▂▄▃▂█▂▁▂▁▁▁▁▂
Loss_G,▄▆▇▇██▇▆▂▁▄▂▆▁▂▄▃▄▄▃▃▃▄█▄▄▃▄▃▃▄▃▃▄▅▄▄▄▅▅

0,1
Loss_D,0.1178
Loss_G,5.03343
