In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# 1. Device selection: use the GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 2. Dataset (CIFAR-10 training split).
#    ToTensor yields values in [0, 1]; Normalize with mean=std=0.5 maps each
#    channel to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
cifar10 = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset=cifar10, batch_size=64, shuffle=True)

# 3. Generator definition
class Generator(nn.Module):
    """Conditional generator mapping (noise, class label) to an RGB image.

    The label is embedded into a ``num_classes``-dim vector and concatenated
    with the noise. The result is linearly projected to a 128-channel
    ``img_size // 4`` square feature map. That map is upsampled twice (x2 each)
    back to ``img_size`` and ends in a Tanh, so pixel values lie in [-1, 1].

    Args:
        latent_dim: Length of the noise vector passed to ``forward``.
        num_classes: Number of conditioning classes.
        img_size: Output height/width in pixels; must be divisible by 4.

    Raises:
        ValueError: If ``img_size`` is not divisible by 4.
    """

    def __init__(self, latent_dim, num_classes, img_size):
        super(Generator, self).__init__()
        if img_size % 4 != 0:
            # Two x2 upsamplings from img_size // 4 only land exactly on
            # img_size when it is a multiple of 4.
            raise ValueError(f"img_size must be divisible by 4, got {img_size}")
        self.label_embed = nn.Embedding(num_classes, num_classes)
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim + num_classes, 128 * self.init_size**2))
        # NOTE: BatchNorm2d's second positional argument is ``eps``, so the
        # previous ``BatchNorm2d(C, 0.8)`` set eps=0.8 (not momentum); the
        # defaults are used here.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate images.

        Args:
            noise: Float tensor of shape (batch, latent_dim).
            labels: Integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 3, img_size, img_size) with values in [-1, 1].
        """
        label_input = self.label_embed(labels)  # (batch, num_classes)
        gen_input = torch.cat((noise, label_input), -1)
        out = self.l1(gen_input)
        out = out.view(out.size(0), 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

# 4. Critic (discriminator) definition
class Critic(nn.Module):
    """Conditional WGAN critic that assigns a real-valued score to (image, label) pairs.

    The label embedding (a ``num_classes``-dim vector) is tiled over every pixel
    and stacked onto the RGB channels. The stack then passes through a strided
    conv stack with no final activation, so the score is unbounded.

    Args:
        num_classes: Number of conditioning classes.
        img_size: Accepted but unused; the conv stack below reduces a 32x32
            input to a single 1x1 score.
    """

    def __init__(self, num_classes, img_size):
        super().__init__()
        self.label_embed = nn.Embedding(num_classes, num_classes)
        in_channels = 3 + num_classes  # RGB + tiled label embedding
        self.model = nn.Sequential(
            # 32x32 -> 16x16 (for a 32x32 input)
            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> 1x1 single-channel score
            nn.Conv2d(256, 1, 4, stride=1, padding=0),
        )

    def forward(self, img, labels):
        """Score a batch of images under their labels.

        Args:
            img: Image tensor of shape (batch, 3, height, width).
            labels: Integer class indices of shape (batch,).

        Returns:
            Flattened tensor of critic scores (shape (batch,) for 32x32 input).
        """
        batch, _, height, width = img.shape
        # (batch, num_classes) -> (batch, num_classes, 1, 1) -> (batch, num_classes, H, W)
        label_maps = self.label_embed(labels)[:, :, None, None]
        label_maps = label_maps.expand(batch, -1, height, width)
        stacked = torch.cat([img, label_maps], dim=1)  # join along the channel axis
        scores = self.model(stacked)
        return scores.flatten()

# 5. Hyperparameters. RMSprop with lr=5e-5, n_critic=5 and clip_value=0.01
#    match the defaults listed in Algorithm 1 of the WGAN paper.
latent_dim = 100   # length of the generator's noise vector
num_classes = 10   # number of label classes used for conditioning
img_size = 32      # generated image height/width
lr = 5e-5
n_epochs = 50
n_critic = 5       # generator is updated once every n_critic critic steps
clip_value = 0.01  # critic weights are clamped to [-clip_value, clip_value]

# Instantiate both networks and move them to the selected device.
generator = Generator(latent_dim, num_classes, img_size).to(device)
critic = Critic(num_classes, img_size).to(device)

# One RMSprop optimizer per network, both using the same learning rate.
optimizer_G, optimizer_C = (
    optim.RMSprop(net.parameters(), lr=lr) for net in (generator, critic)
)

# 6. Training loop (WGAN with weight clipping)
for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        imgs, labels = imgs.to(device), labels.to(device)

        # --- Critic update: maximize mean(C(real)) - mean(C(fake)) by
        #     minimizing its negation.
        optimizer_C.zero_grad()
        z = torch.randn(imgs.size(0), latent_dim).to(device)
        # The critic loss never backpropagates into the generator, so avoid
        # building (and then detaching) the generator's autograd graph.
        with torch.no_grad():
            fake_imgs = generator(z, labels)

        real_loss = critic(imgs, labels).mean()       # mean critic score on real images
        fake_loss = critic(fake_imgs, labels).mean()  # mean critic score on generated images
        critic_loss = -(real_loss - fake_loss)
        critic_loss.backward()
        optimizer_C.step()

        # Weight clipping: crude enforcement of the WGAN Lipschitz constraint.
        # Clamp in place outside autograd rather than by mutating ``.data``.
        with torch.no_grad():
            for p in critic.parameters():
                p.clamp_(-clip_value, clip_value)

        # --- Generator update, once every n_critic iterations (i == 0 included,
        #     so generator_loss is defined before the first print).
        if i % n_critic == 0:
            optimizer_G.zero_grad()
            gen_imgs = generator(z, labels)  # reuses this iteration's noise and labels
            generator_loss = -critic(gen_imgs, labels).mean()
            generator_loss.backward()
            optimizer_G.step()

    # Reported values are the last critic step's and last generator step's losses.
    print(f"[Epoch {epoch+1}/{n_epochs}] [C Loss: {critic_loss.item():.4f}] [G Loss: {generator_loss.item():.4f}]")

# 7. Inspect results
def show_generated_images(n_images=9, nrow=3, labels=None):
    """Sample class-conditioned images from ``generator`` and show them as a grid.

    Draws noise (and, unless given, random labels) and runs the generator in
    eval mode without gradients. The Tanh output is mapped from [-1, 1] to
    [0, 1], and the grid is displayed with matplotlib. The generator's
    train/eval mode is restored afterwards.

    Args:
        n_images: Number of images to sample when ``labels`` is None.
        nrow: Number of images per grid row.
        labels: Optional 1-D sequence/tensor of class indices in
            ``[0, num_classes)``. When given, one image is generated per label
            and ``n_images`` is ignored.
    """
    n = n_images if labels is None else len(labels)
    z = torch.randn(n, latent_dim).to(device)
    if labels is None:
        labels = torch.randint(0, num_classes, (n,))
    labels = torch.as_tensor(labels, dtype=torch.long, device=device)

    # In train mode BatchNorm would normalize with this small batch's own
    # statistics and overwrite its running estimates (even under no_grad).
    # Eval mode uses the running estimates and leaves the model unchanged.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            gen_imgs = generator(z, labels)
    finally:
        generator.train(was_training)  # restore the caller's mode

    gen_imgs = (gen_imgs * 0.5) + 0.5  # map [-1, 1] to [0, 1]
    grid = vutils.make_grid(gen_imgs, nrow=nrow)
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.title("Generated Images")
    plt.axis("off")
    plt.show()


show_generated_images()


Files already downloaded and verified


KeyboardInterrupt: 