In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Training hyperparameters.
batch_size, latent_dim, epochs = 64, 100, 50  # samples per step, noise-vector length, passes over the data
lr, beta1 = 2e-4, 0.5  # Adam learning rate and its first-moment decay rate (beta1)

# MNIST training split. ToTensor yields pixel values in [0, 1]; normalizing the
# single channel with mean 0.5 / std 0.5 maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

mnist_data = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,  # fetches the files into ./data on first run
)
data_loader = DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)

# Generator network definition.
class Generator(nn.Module):
    """MLP generator that maps a noise vector to an image with values in [-1, 1].

    Args:
        noise_dim: Length of the input noise vector. When omitted, the
            module-level ``latent_dim`` hyperparameter is used (read at
            construction time), which preserves the original ``Generator()``
            behavior.
        img_shape: ``(channels, height, width)`` of the generated images.
            Defaults to MNIST's ``(1, 28, 28)``.
    """

    def __init__(self, noise_dim=None, img_shape=(1, 28, 28)):
        super().__init__()
        import math  # local import keeps the notebook's import cell unchanged

        # Fall back to the notebook-level hyperparameter only when no explicit
        # size is given, so the class can also be used outside this script.
        noise_dim = latent_dim if noise_dim is None else noise_dim
        self.img_shape = tuple(img_shape)
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, math.prod(self.img_shape)),
            nn.Tanh(),  # bounds every pixel to [-1, 1]
        )

    def forward(self, z):
        """Map noise ``z`` of shape ``(batch, noise_dim)`` to images of shape ``(batch, *img_shape)``."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)

# Discriminator network definition.
class Discriminator(nn.Module):
    """MLP discriminator that scores images as real (toward 1) or generated (toward 0).

    Args:
        img_shape: ``(channels, height, width)`` of the input images. Only its
            element count is used, to size the first layer. Defaults to
            MNIST's ``(1, 28, 28)``.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        import math  # local import keeps the notebook's import cell unchanged

        self.model = nn.Sequential(
            nn.Linear(math.prod(img_shape), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability output, as nn.BCELoss expects
        )

    def forward(self, img):
        """Flatten each image in the batch and return a ``(batch, 1)`` tensor of probabilities."""
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)

# Network initialization.
generator = Generator()
discriminator = Discriminator()

# Loss function and optimizer setup.
# BCELoss expects probabilities, which the discriminator's final Sigmoid provides.
adversarial_loss = nn.BCELoss()
# Both networks use Adam with the same learning rate and betas=(beta1, 0.999).
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Training loop. Each mini-batch updates the generator first, then the
# discriminator. Statement order matters here (see the zero_grad note below)
# and also fixes the sequence of torch.randn draws, so reorder with care.
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(data_loader):  # class labels are unused
        
        # Target labels: 1 for real images, 0 for generated images. They are
        # sized from the current batch, so a shorter final batch still matches.
        real_labels = torch.ones((imgs.size(0), 1))
        fake_labels = torch.zeros((imgs.size(0), 1))

        # Generator step: sample noise of shape (batch, latent_dim), generate
        # images, and train the generator so the discriminator labels them real.
        optimizer_G.zero_grad()
        z = torch.randn((imgs.size(0), latent_dim))
        generated_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(generated_imgs), real_labels)
        # NOTE: the discriminator is part of this graph, so this backward also
        # accumulates gradients into its parameters. optimizer_D.zero_grad()
        # below discards them before the discriminator's own update.
        g_loss.backward()
        optimizer_G.step()

        # Discriminator step. It scores the fakes produced above, i.e. by the
        # generator *before* optimizer_G.step(). detach() keeps this loss from
        # backpropagating into the generator.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), real_labels)
        fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake_labels)
        # Mean of the real-image and fake-image loss terms.
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Print intermediate progress every 100 batches.
        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] Batch {i}/{len(data_loader)} Loss D: {d_loss.item()}, Loss G: {g_loss.item()}")

    # At the end of each epoch, draw 64 samples for visually inspecting progress.
    with torch.no_grad():  # inference only; no autograd graph is recorded
        z = torch.randn(64, latent_dim)
        generated_imgs = generator(z)
        generated_imgs = (generated_imgs + 1) / 2.0  # rescale from the generator's Tanh range [-1, 1] to [0, 1]
        # Image-saving code is omitted, so these samples are currently discarded.

# Prints "Training complete!".
print("학습 완료!")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 9912422/9912422 [00:02<00:00, 3606950.95it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 146289.08it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 1648877/1648877 [00:01<00:00, 1280805.07it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 2280408.04it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

Epoch [1/50] Batch 0/938 Loss D: 0.7455304861068726, Loss G: 0.7299652099609375
Epoch [1/50] Batch 100/938 Loss D: 0.47361522912979126, Loss G: 0.853972852230072
Epoch [1/50] Batch 200/938 Loss D: 0.6459413170814514, Loss G: 0.6012499332427979
Epoch [1/50] Batch 300/938 Loss D: 0.48337578773498535, Loss G: 0.7055453062057495
Epoch [1/50] Batch 400/938 Loss D: 0.6094355583190918, Loss G: 0.5253724455833435
Epoch [1/50] Batch 500/938 Loss D: 0.36225876212120056, Loss G: 0.9337050318717957
Epoch [1/50] Batch 600/938 Loss D: 0.45116013288497925, Loss G: 1.1629369258880615
Epoch [1/50] Batch 700/938 Loss D: 0.4891432523727417, Loss G: 1.0086379051208496
Epoch [1/50] Batch 800/938 Loss D: 0.3125303387641907, Loss G: 1.065537929534912
Epoch [1/50] Batch 900/938 Loss D: 0.21333864331245422, Loss G: 1.8150198459625244
Epoch [2/50] Batch 0/938 Loss D: 0.3555065989494324, Loss G: 0.7941721081733704
Epoch [2/50] Batch 100/9