In [5]:
import os

# Run from the project root so the relative paths used later ('./data' for the
# MNIST download, the 'images' output directory) resolve there. The root
# defaults to /app; set the APP_DIR environment variable to run elsewhere.
os.chdir(os.environ.get("APP_DIR", "/app"))

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import os
import torchvision.utils as vutils

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to images.

    Args:
        z_dim: Length of each latent vector fed to :meth:`forward`.
        img_size: Per-sample output shape without the batch dimension,
            e.g. ``(1, 28, 28)``; the last linear layer emits
            ``prod(img_size)`` values per sample.
    """

    def __init__(self, z_dim, img_size):
        super().__init__()
        self.img_size = img_size
        # NOTE: the second positional argument of nn.BatchNorm1d is ``eps``,
        # not ``momentum``. Writing ``BatchNorm1d(n, 0.8)`` silently set
        # eps=0.8 (vs. the 1e-5 default), which dampens normalization whenever
        # the batch variance is not much larger than 0.8. The default eps is
        # used here; the module layout (and hence state_dict keys) is unchanged.
        self.model = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, int(np.prod(img_size))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images.

        Args:
            z: Latent batch of shape ``(batch, z_dim)``.

        Returns:
            Tensor of shape ``(batch, *img_size)`` with values in [-1, 1]
            (Tanh output).
        """
        img = self.model(z)
        return img.view(img.size(0), *self.img_size)

class Discriminator(nn.Module):
    """MLP discriminator that scores each image in a batch.

    Args:
        img_size: Per-sample input shape without the batch dimension; its
            element count sets the width of the first linear layer.
    """

    def __init__(self, img_size):
        super().__init__()
        in_features = int(np.prod(img_size))
        layers = [
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        # Kept as a single Sequential named `model` so parameter names
        # (model.0.weight, ...) stay stable for saved checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten each sample and return a ``(batch, 1)`` tensor of Sigmoid scores in [0, 1]."""
        flattened = img.view(img.shape[0], -1)
        return self.model(flattened)


In [8]:
# Config (plain-dict form).
# NOTE(review): no visible cell reads CONFIG; training is driven by
# TrainingConfig, which repeats these values (with `img_shape` named
# `image_size`). Keep the two in sync, or drop this dict.
CONFIG = dict(
    img_shape=(1, 28, 28),
    z_dim=100,
    lr=0.0002,
    betas=(0.5, 0.999),
    batch_size=64,
    epochs=200,
    save_interval=1000,
)



In [20]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyperparameters and paths for the MNIST GAN training run.

    Every attribute is annotated, so each one is a real dataclass field and
    can be overridden at construction time, e.g. ``TrainingConfig(epochs=5)``.
    (Without annotations, @dataclass generates no fields at all and such
    keyword arguments raise TypeError.)
    """

    # Per-sample image shape (channels, height, width); MNIST digits are 1x28x28.
    image_size: tuple = (1, 28, 28)
    # Length of the latent noise vector fed to the generator.
    z_dim: int = 100
    # Adam learning rate and (beta1, beta2), shared by both optimizers.
    lr: float = 0.0002
    betas: tuple = (0.5, 0.999)
    batch_size: int = 64
    epochs: int = 200
    # Write a sample grid every `save_interval` batches (counted across epochs).
    save_interval: int = 1000
    # Directory the sample grids are written to.
    output_dir: str = "images"


config = TrainingConfig()
    

In [29]:
from accelerate import Accelerator


def _build_mnist_dataloader(batch_size):
    """Return a shuffled MNIST training DataLoader with pixels scaled to [-1, 1]."""
    # ToTensor yields values in [0, 1]; Normalize(mean=0.5, std=0.5) maps them to
    # [-1, 1], matching the range of the generator's Tanh output.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train(config: TrainingConfig):
    """Train the MLP GAN on MNIST using Hugging Face Accelerate.

    Each batch performs one generator step followed by one discriminator
    step. Losses are logged every 100 batches, and a sample grid is saved
    every ``config.save_interval`` batches (counted across epochs).

    Args:
        config: Reads ``z_dim``, ``image_size``, ``lr``, ``betas``,
            ``batch_size``, ``epochs`` and ``save_interval``; ``save_imgs``
            additionally reads ``output_dir``.

    Side effects:
        Downloads MNIST to ``./data`` if missing and writes PNG grids via
        ``save_imgs``. Logging and image saving happen on the main process
        only, so multi-process launches neither duplicate output nor race on
        the same file names.
    """
    accelerator = Accelerator()
    device = accelerator.device

    # Models
    generator = Generator(config.z_dim, config.image_size).to(device)
    discriminator = Discriminator(config.image_size).to(device)

    # Loss and optimizers (same Adam hyperparameters for both networks)
    adversarial_loss = nn.BCELoss()
    g_optimizer = optim.Adam(generator.parameters(), lr=config.lr, betas=config.betas)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=config.lr, betas=config.betas)

    dataloader = _build_mnist_dataloader(config.batch_size)

    # Let Accelerate handle device placement and any distributed wrapping.
    generator, discriminator, g_optimizer, d_optimizer, dataloader = accelerator.prepare(
        generator, discriminator, g_optimizer, d_optimizer, dataloader
    )

    num_batches = len(dataloader)  # loop-invariant; used for logging and the global step
    for epoch in range(config.epochs):
        for i, (imgs, _) in enumerate(dataloader):
            batch_size = imgs.shape[0]

            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            # Generator step: push D to label generated images as real.
            g_optimizer.zero_grad()
            z = torch.randn(batch_size, config.z_dim, device=device)
            gen_imgs = generator(z)
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            accelerator.backward(g_loss)
            g_optimizer.step()

            # Discriminator step. zero_grad also discards the D gradients left
            # by the generator backward above; detach() keeps this loss from
            # flowing back into the generator.
            d_optimizer.zero_grad()
            real_loss = adversarial_loss(discriminator(imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            accelerator.backward(d_loss)
            d_optimizer.step()

            if i % 100 == 0:
                # accelerator.print prints once per machine instead of once per process.
                accelerator.print(f"[Epoch {epoch}/{config.epochs}] [Batch {i}/{num_batches}] "
                                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

            batches_done = epoch * num_batches + i
            if batches_done % config.save_interval == 0 and accelerator.is_main_process:
                # Unwrap first: a DDP-wrapped module can issue collective ops in
                # forward (e.g. broadcasting BatchNorm buffers), which would
                # deadlock when only the main process runs it.
                save_imgs(accelerator.unwrap_model(generator), config, batches_done, device)

def save_imgs(generator, config, batches_done, device, num_samples=25, nrow=5):
    """Sample images from ``generator`` and save them as one PNG grid.

    Args:
        generator: Module mapping ``(num_samples, config.z_dim)`` noise to
            images; expected to output values in [-1, 1], as the
            Tanh-terminated Generator in this notebook does.
        config: Provides ``z_dim`` and ``output_dir``.
        batches_done: Global batch counter, used in the file name.
        device: Device on which the latent noise is drawn (must match the
            generator's).
        num_samples: Number of images to sample (default 25).
        nrow: Images per grid row (default 5, i.e. a 5x5 grid for 25 samples).

    Side effects:
        Creates ``config.output_dir`` if needed, writes
        ``image_{batches_done}.png`` there and prints the path.
    """
    # Pure sampling: no autograd graph is needed.
    # The generator is left in whatever mode the caller set; in train mode its
    # BatchNorm layers normalize with this sample batch's statistics.
    with torch.no_grad():
        z = torch.randn(num_samples, config.z_dim, device=device)
        gen_imgs = generator(z)

    # Map [-1, 1] to [0, 1] exactly once. Do not also pass normalize=True:
    # that min-max rescales every grid independently, stretching contrast
    # differently at each checkpoint and hiding the true output range.
    gen_imgs = (gen_imgs + 1) / 2

    os.makedirs(config.output_dir, exist_ok=True)
    save_path = os.path.join(config.output_dir, f"image_{batches_done}.png")

    vutils.save_image(gen_imgs.cpu(), save_path, nrow=nrow)
    print(f"Saved image to {save_path}")


In [30]:
# Start a training run with a freshly built default configuration.
config = TrainingConfig()
train(config=config)

[Epoch 0/200] [Batch 0/938] [D loss: 0.6858] [G loss: 0.6631]
Saved image to images/image_0.png
[Epoch 0/200] [Batch 100/938] [D loss: 0.3834] [G loss: 0.7679]
[Epoch 0/200] [Batch 200/938] [D loss: 0.5246] [G loss: 1.7469]
[Epoch 0/200] [Batch 300/938] [D loss: 0.4507] [G loss: 0.9287]
[Epoch 0/200] [Batch 400/938] [D loss: 0.3941] [G loss: 0.9291]
[Epoch 0/200] [Batch 500/938] [D loss: 0.3964] [G loss: 1.1251]
[Epoch 0/200] [Batch 600/938] [D loss: 0.4307] [G loss: 0.9221]
[Epoch 0/200] [Batch 700/938] [D loss: 0.4395] [G loss: 1.2100]


KeyboardInterrupt: 