In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import torchvision.utils as vutils

device = torch.device("cuda" if not torch.cuda.is_available() else "cpu")
print(device)

cpu


In [14]:
# 데이터셋 클래스 정의
class StyleTransferDataset(Dataset):
    def __init__(self, gothic_dir, handwriting_dir, transform=None):
        self.gothic_images = sorted([os.path.join(gothic_dir, img) for img in os.listdir(gothic_dir)])
        self.handwriting_images = sorted([os.path.join(handwriting_dir, img) for img in os.listdir(handwriting_dir)])
        self.transform = transform

    def __len__(self):
        return len(self.gothic_images)

    def __getitem__(self, idx):
        gothic_image = Image.open(self.gothic_images[idx]).convert("RGB")
        handwriting_image = Image.open(self.handwriting_images[idx]).convert("RGB")
        
        if self.transform:
            gothic_image = self.transform(gothic_image)
            handwriting_image = self.transform(handwriting_image)
        
        return gothic_image, handwriting_image

# 데이터 변환 정의
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# 데이터셋 경로 설정
gothic_dir = '../data/FONT_GODIC_ONE/9'
handwriting_dir = '../data/FONT_ONE/9'

# 데이터셋 및 데이터로더 초기화
dataset = StyleTransferDataset(gothic_dir, handwriting_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [15]:
# 특징 추출기 (Encoder) 정의
class Encoder(nn.Module):
    def __init__(self, latent_dim=128):
        super(Encoder, self).__init__()
        vgg_features = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.encoder = nn.Sequential(*list(vgg_features.children())[:21])  # Conv4_1까지 사용
        
        # 조정된 레이어
        self.conv = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1)  # 추가 Conv 레이어로 크기 조정
        self.fc_mu = nn.Linear(512 * 16 * 16, latent_dim)
        self.fc_logvar = nn.Linear(512 * 16 * 16, latent_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


In [16]:
class Decoder(nn.Module):
    def __init__(self, latent_dim=128):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 512 * 16 * 16)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, z):
        x = self.fc(z)
        x = x.view(x.size(0), 512, 16, 16)
        x = self.decoder(x)
        return x

In [17]:
class VAE(nn.Module):
    def __init__(self, latent_dim=128):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar


In [18]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [22]:
def vae_loss_function(recon_x, x, mu, logvar):
    recon_x = recon_x.clamp(0, 1)  # 값 범위 조정
    x = x.clamp(0, 1)  # 값 범위 조정
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# 모델 초기화
latent_dim = 128
vae = VAE(latent_dim).to(device)
discriminator = Discriminator().to(device)

# 손실 함수
adversarial_loss = nn.BCELoss().to(device)

# 옵티마이저
optimizer_G = optim.Adam(vae.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

In [23]:
def save_generated_images(images, num_images, epoch, idx):
    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.title("Generated Images")
    images = vutils.make_grid(images[:num_images], padding=2, normalize=True)
    images = np.transpose(images.cpu(), (1, 2, 0))
    fname = '../data/output_images/' + str(epoch) + '_' + str(idx) + '.jpg'
    plt.imsave(fname, images.numpy())
    plt.close()

In [None]:
# 훈련 루프
num_epochs = 10
output_dir = 'output_images'
os.makedirs(output_dir, exist_ok=True)

for epoch in range(num_epochs):
    for i, (gothic, handwriting) in enumerate(dataloader):
        gothic, handwriting = gothic.to(device), handwriting.to(device)
        batch_size = gothic.size(0)

        # 판별자 출력 크기에 맞춰 목표 텐서 크기 조정
        valid = torch.ones(discriminator(handwriting).size(), requires_grad=False).to(device)
        fake = torch.zeros(discriminator(handwriting).size(), requires_grad=False).to(device)

        # ---------------------
        #  Train VAE (Generator)
        # ---------------------
        optimizer_G.zero_grad()

        recon_gothic, mu, logvar = vae(gothic)
        # handwriting 이미지를 recon_gothic와 동일한 크기로 변환
        handwriting_resized = torch.nn.functional.interpolate(handwriting, size=recon_gothic.shape[2:], mode='bilinear')

        vae_loss = vae_loss_function(recon_gothic, handwriting_resized, mu, logvar)
        g_adv_loss = adversarial_loss(discriminator(recon_gothic), valid)

        g_loss = vae_loss + 0.1 * g_adv_loss  # 가중치를 조정하여 손실 함수의 균형 맞추기
        g_loss.backward()
        nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)  # Gradient Clipping
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        real_loss = adversarial_loss(discriminator(handwriting), valid)
        fake_loss = adversarial_loss(discriminator(recon_gothic.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient Clipping
        optimizer_D.step()

        if i % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], G Loss: {g_loss.item():.4f}, D Loss: {d_loss.item():.4f}")

        # 결과 이미지 저장
        if (epoch * len(dataloader) + i) % 10 == 0:
            save_generated_images(recon_gothic, 1, epoch=epoch, idx=i)

print("Training completed.")

Epoch [1/10], Step [1/10000], G Loss: 130567.4453, D Loss: 0.7202
Epoch [1/10], Step [11/10000], G Loss: 127425.3359, D Loss: 0.0808
Epoch [1/10], Step [21/10000], G Loss: 87146.9922, D Loss: 0.0541
Epoch [1/10], Step [31/10000], G Loss: 63438.8164, D Loss: 0.0213
Epoch [1/10], Step [41/10000], G Loss: 65542.6094, D Loss: 0.0087
Epoch [1/10], Step [51/10000], G Loss: 63458.8086, D Loss: 0.0050
Epoch [1/10], Step [61/10000], G Loss: 72711.0391, D Loss: 0.0035
Epoch [1/10], Step [71/10000], G Loss: 60661.4531, D Loss: 0.0027
Epoch [1/10], Step [81/10000], G Loss: 62347.8711, D Loss: 0.0060
Epoch [1/10], Step [91/10000], G Loss: 61237.2734, D Loss: 0.0025
Epoch [1/10], Step [101/10000], G Loss: 61564.5039, D Loss: 0.0015
Epoch [1/10], Step [111/10000], G Loss: 57473.9414, D Loss: 0.0011
Epoch [1/10], Step [121/10000], G Loss: 62711.0547, D Loss: 0.0117
Epoch [1/10], Step [131/10000], G Loss: 54518.5820, D Loss: 0.0614
Epoch [1/10], Step [141/10000], G Loss: 55048.2656, D Loss: 0.0015
Epoc