In [1]:
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets

from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from pydantic import BaseModel

import wandb

## Set Hyperparameters

In [2]:
class Settings(BaseModel):
    """Hyperparameters and run configuration for the DCGAN experiment.

    Every field carries an explicit type annotation. Pydantic v1 would infer a
    field from a bare ``name = value`` assignment, but pydantic v2 refuses to
    build a model that contains non-annotated attributes, so the annotations
    keep this class usable on both major versions.
    """

    # Weights & Biases project the run is logged under.
    project_name: str = 'jp50_02'
    # Torch device string that the models and batches are moved to.
    device: str = 'cuda'
    # Number of passes over the training loader.
    epoch: int = 200
    # Mini-batch size handed to the DataLoader.
    batch: int = 8
    # Learning rate shared by both Adam optimizers.
    learning_rate: float = 2e-4
    # Side length of the training images in pixels.
    image_size: int = 64
    # Generated samples are logged every `sample_interval` training steps.
    sample_interval: int = 500

    # Size of z latent vector (i.e. size of generator input)
    nz: int = 100
    # Size of feature maps in generator
    ngf: int = 16
    # Number of channels in the training images. For color images this is 3
    nc: int = 1
    # Size of feature maps in discriminator
    ndf: int = 64
    # Beta1 hyperparam for Adam optimizers
    beta1: float = 0.5
    # Establish convention for real and fake labels during training
    real_label: float = 1.
    fake_label: float = 0.

## Define Discriminator

In [3]:
class Discriminator(nn.Module):
    """Convolutional classifier that scores images as real or generated.

    The network is a stack of stride-2 convolutions followed by a final 4x4
    convolution and a sigmoid, so each sample ends up as one value in [0, 1].

    Args:
        nc: number of channels of the input images.
        ndf: number of feature maps produced by the first convolution; the
            following stages use 2x, 4x and 8x this width.
    """

    def __init__(self, nc, ndf):
        super().__init__()

        def halve(in_ch, out_ch):
            # 4x4 kernel, stride 2, padding 1: halves height and width.
            return nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)

        # Stem has no batch-norm. For a 64x64 input: 64x64 -> 32x32.
        stack = [halve(nc, ndf), nn.LeakyReLU(0.2, inplace=True)]

        # Three widening stages. For a 64x64 input: 32 -> 16 -> 8 -> 4.
        for width in (ndf, ndf * 2, ndf * 4):
            stack.append(halve(width, width * 2))
            stack.append(nn.BatchNorm2d(width * 2))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Head: a 4x4 kernel without padding collapses a 4x4 map to 1x1.
        stack.append(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False))
        stack.append(nn.Sigmoid())

        self.main = nn.Sequential(*stack)

    def forward(self, x):
        """Return the per-sample score map produced by the layer stack."""
        return self.main(x)

## Define Generator

In [4]:
class Generator(nn.Module):
    """Transposed-convolution network that turns latent vectors into images.

    Args:
        nz: number of channels of the latent input (fed as ``nz x 1 x 1``).
        ngf: base number of feature maps; the stages use 8x, 4x, 2x and 1x
            this width on the way up.
        nc: number of channels of the generated images.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()

        def double(in_ch, out_ch):
            # 4x4 kernel, stride 2, padding 1: doubles height and width.
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)

        # Projection of the latent vector: 1x1 -> 4x4 with ngf*8 channels.
        stack = [
            nn.ConvTranspose2d(nz, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
        ]

        # Three narrowing stages: 4x4 -> 8x8 -> 16x16 -> 32x32.
        for in_mult, out_mult in ((8, 4), (4, 2), (2, 1)):
            stack.append(double(ngf * in_mult, ngf * out_mult))
            stack.append(nn.BatchNorm2d(ngf * out_mult))
            stack.append(nn.ReLU(inplace=True))

        # Output stage: 32x32 -> 64x64 with nc channels, squashed to [-1, 1].
        stack.append(double(ngf, nc))
        stack.append(nn.Tanh())

        self.main = nn.Sequential(*stack)

    def forward(self, x):
        """Return the images generated from the latent batch ``x``."""
        return self.main(x)

## Define Dataset (Data Loader)

- Here we use PyTorch's built-in dataset utilities.

In [5]:
from torchvision.datasets import mnist
# Preprocessing applied to every image before it is fed to the networks.
transform = transforms.Compose(
    [
        # Resize to the 64x64 resolution the networks are built for.
        transforms.Resize(size=(64, 64)),
        # Convert to a single-channel grayscale image.
        transforms.Grayscale(num_output_channels=1),
        # Convert the image to a tensor.
        transforms.ToTensor(),
        # Normalize with mean 0.5 and std 0.5.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)


## Build model strategy

In [6]:
class DCGAN():
    """Training harness for the DCGAN experiment.

    Bundles the settings, data loading, both networks with their optimizers,
    the alternating discriminator/generator update and Weights & Biases
    logging. Call ``load_dataset()`` and ``load_model()`` before ``train()``.
    """

    def __init__(self) -> None:
        """Instantiate the default ``Settings`` and start a wandb run with them as config."""
        self.args = Settings()

        wandb.init(project=self.args.project_name, config=self.args.dict(), save_code=True)

    def load_dataset(self):
        """Create ``self.dataset`` and the shuffled ``self.loader`` over it.

        Images are read from ``./dataset/train_images`` and passed through the
        module-level ``transform`` pipeline.
        """
        # ImageFolder assigns each image a class label taken from the name of
        # the sub-directory it is stored in; the training loop ignores them.
        self.dataset = datasets.ImageFolder(root='./dataset/train_images', transform=transform)
        # Training loader; one epoch yields ceil(len(dataset) / batch) batches.
        self.loader = DataLoader(self.dataset, batch_size=self.args.batch, shuffle=True)

    def load_model(self):
        """Build both networks and store each with its loss and optimizer on ``self``."""
        self.g_model, self.g_loss, self.g_optim = self.generator()
        self.d_model, self.d_loss, self.d_optim = self.discriminator()

    def generator(self):
        """Build the generator with its criterion and optimizer.

        Returns:
            tuple: ``(model, criterion, optimizer)`` where the model is on
            ``self.args.device`` and initialised with ``weights_init``.
        """
        model = Generator(nc=self.args.nc, ngf=self.args.ngf, nz=self.args.nz).to(self.args.device)
        model.apply(self.weights_init)
        criterion = nn.BCELoss()
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=self.args.learning_rate,
            betas=(self.args.beta1, 0.999),
        )
        return model, criterion, optimizer

    def discriminator(self):
        """Build the discriminator with its criterion and optimizer.

        Returns:
            tuple: ``(model, criterion, optimizer)`` where the model is on
            ``self.args.device`` and initialised with ``weights_init``.
        """
        model = Discriminator(nc=self.args.nc, ndf=self.args.ndf).to(self.args.device)
        model.apply(self.weights_init)
        criterion = nn.BCELoss()
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=self.args.learning_rate,
            betas=(self.args.beta1, 0.999),
        )
        return model, criterion, optimizer

    def weights_init(self, m):
        """Re-initialise one sub-module in place; intended for ``Module.apply``.

        Modules whose class name contains ``Conv`` get weights drawn from a
        normal distribution with mean 0 and std 0.02. Modules whose class name
        contains ``BatchNorm`` get weights with mean 1 and std 0.02 and a zero
        bias. Every other module is left untouched.

        Args:
            m: the sub-module currently visited by ``apply``.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif 'BatchNorm' in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def train_one_epoch(self, i_epoch: int):
        """Run one pass over ``self.loader``, updating D then G on every batch.

        Resets ``self.metric``, logs both losses to wandb at every step and
        shows them on the progress bar.

        Args:
            i_epoch: zero-based index of the epoch being trained.

        Returns:
            int: always 0.
        """
        self.metric = {
            'train_d': {},
            'train_g': {},
        }
        self.i_epoch = i_epoch
        self.d_model.train(mode=True)
        self.g_model.train(mode=True)

        steps_per_epoch = len(self.loader)
        bar = tqdm(self.loader, unit='batch', leave=True)
        # The class labels produced by the dataset are not used.
        for i_batch, (data, _) in enumerate(bar):
            # Global step counter across epochs, used as the wandb step.
            self.step = i_epoch * steps_per_epoch + i_batch
            data = data.to(self.args.device)

            fake = self.train_d(data)
            self.train_g(fake)
            self.show_result(fake)

            loss = {
                'd_loss': self.metric['train_d']['loss'][-1],
                'g_loss': self.metric['train_g']['loss'][-1],
            }
            wandb.log({**loss, 'i_epoch': self.i_epoch, 'step': self.step}, step=self.step)

            bar.set_description(f'Epoch [{self.i_epoch + 1}/{self.args.epoch}]')
            bar.set_postfix(**loss)
        return 0

    def show_result(self, fake):
        """Log up to 24 generated images to wandb every ``sample_interval`` steps.

        Args:
            fake: batch of generated images in channel-first layout.

        Returns:
            int: always 0.
        """
        if self.step % self.args.sample_interval == 0:
            # Detach once, keep at most the first 24 samples, and move the
            # channel axis last for each image before handing it to wandb.
            samples = fake.detach()[:24]
            wandb.log({
                'fake': [wandb.Image(im.permute(1, 2, 0).cpu().numpy()) for im in samples]
            }, step=self.step)
        return 0

    def train_d(self, data):
        """Run one discriminator update on a real batch plus a generated batch.

        Args:
            data: batch of real images; moved to ``self.args.device``.

        Returns:
            The generated batch, still attached to the generator's graph so
            that ``train_g`` can back-propagate through it.
        """
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        self.d_optim.zero_grad()

        real_inputs = data.to(self.args.device)
        b_size = real_inputs.size(0)

        # ---- all-real batch ----
        real_targets = torch.full(
            (b_size,), self.args.real_label, dtype=torch.float, device=self.args.device
        )
        # Call the module itself instead of .forward() so that any registered
        # hooks are honoured.
        real_scores = self.d_model(real_inputs).view(-1)
        errD_real = self.d_loss(real_scores, real_targets)
        errD_real.backward()

        # ---- all-fake batch ----
        noise = torch.randn(b_size, self.args.nz, 1, 1, device=self.args.device)
        fake = self.g_model(noise)
        # A separate target tensor avoids mutating the real targets in place.
        fake_targets = torch.full(
            (b_size,), self.args.fake_label, dtype=torch.float, device=self.args.device
        )
        # detach(): this pass must not send gradients into the generator.
        fake_scores = self.d_model(fake.detach()).view(-1)
        errD_fake = self.d_loss(fake_scores, fake_targets)
        # These gradients accumulate on top of the ones from the real batch.
        errD_fake.backward()

        self.d_optim.step()

        # Reported D loss is the sum over the real and the fake batch.
        errD = errD_real + errD_fake
        self.metric['train_d'].setdefault('loss', []).append(errD.item())
        return fake

    def train_g(self, fake):
        """Run one generator update using the batch produced in ``train_d``.

        Args:
            fake: generated batch that is still attached to the generator graph.

        Returns:
            int: always 0.
        """
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        self.g_optim.zero_grad()
        # Fake samples are labelled as real for the generator's cost.
        targets = torch.full(
            (fake.size(0),), self.args.real_label, dtype=torch.float, device=self.args.device
        )
        # D was just updated, so score the fake batch again with the new D.
        scores = self.d_model(fake).view(-1)
        # Use the generator's own criterion (same BCE loss type as D's).
        errG = self.g_loss(scores, targets)
        errG.backward()
        self.g_optim.step()
        self.metric['train_g'].setdefault('loss', []).append(errG.item())
        return 0

    def train(self):
        """Train for ``self.args.epoch`` epochs, calling the validation and save hooks after each.

        Returns:
            int: always 0.
        """
        for i_epoch in range(self.args.epoch):
            self.train_one_epoch(i_epoch)
            self.validation()
            self.save_model()
        return 0

    def validation(self):
        """Placeholder hook called after every epoch; currently does nothing and returns 0."""
        return 0

    def test(self):
        """Placeholder for a test routine; currently does nothing and returns 0."""
        return 0

    def save_model(self):
        """Placeholder hook called after every epoch; nothing is saved yet. Returns 0."""
        return 0

## Running Process

In [7]:
# Build the trainer; its constructor also starts the wandb run.
deep_conv_gan = DCGAN()
# The data loader and both networks must exist before training starts.
deep_conv_gan.load_dataset()
deep_conv_gan.load_model()
# Run the full training loop; train() returns 0.
deep_conv_gan.train()

[34m[1mwandb[0m: Currently logged in as: [33mhj6hki123[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch [1/200]: 100%|██████████| 125/125 [00:02<00:00, 43.53batch/s, d_loss=0.0555, g_loss=9.67] 
Epoch [2/200]: 100%|██████████| 125/125 [00:02<00:00, 61.11batch/s, d_loss=0.284, g_loss=15.4]   
Epoch [3/200]: 100%|██████████| 125/125 [00:02<00:00, 58.03batch/s, d_loss=0.2, g_loss=14]      
Epoch [4/200]: 100%|██████████| 125/125 [00:02<00:00, 61.91batch/s, d_loss=0.0118, g_loss=6.94]  
Epoch [5/200]: 100%|██████████| 125/125 [00:02<00:00, 59.80batch/s, d_loss=0.074, g_loss=4.66]  
Epoch [6/200]: 100%|██████████| 125/125 [00:01<00:00, 63.09batch/s, d_loss=0.0463, g_loss=13.5]  
Epoch [7/200]: 100%|██████████| 125/125 [00:01<00:00, 63.03batch/s, d_loss=0.0163, g_loss=7.02] 
Epoch [8/200]: 100%|██████████| 125/125 [00:02<00:00, 61.68batch/s, d_loss=0.00409, g_loss=6.8]  
Epoch [9/200]: 100%|██████████| 125/125 [00:02<00:00, 58.32batch/s, d_loss=0.0628, g_loss=4.82]  
Epoch [10/200]: 100%|██████████| 125/125 [00:02<00:00, 58.43batch/s, d_loss=0.004, g_loss=7.2]   
Epoch [11/200]: 100%|███

0