In [9]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from typing import List, Tuple
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# Custom dataset class for the CelebA images
class CelebADataset(Dataset):
    """Image-only dataset backed by a torchvision ``ImageFolder``.

    The wrapped ``ImageFolder`` is built from ``root_dir`` with the given
    ``transform``; indexing this dataset returns only the first element of
    the wrapped item, discarding whatever follows it.
    """

    def __init__(self, root_dir, transform=None):
        """Store the arguments and build the underlying ``ImageFolder``.

        Args:
            root_dir: Directory handed to ``ImageFolder`` as ``root``.
            transform: Optional callable handed to ``ImageFolder`` as
                ``transform``.
        """
        self.root_dir, self.transform = root_dir, transform
        self.dataset = ImageFolder(root=root_dir, transform=transform)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        wrapped = self.dataset
        return len(wrapped)

    def __getitem__(self, idx):
        """Return element 0 of the wrapped dataset's item at ``idx``."""
        item = self.dataset[idx]
        return item[0]

# VAE model definition
class VAE_CelebA(nn.Module):
    """Convolutional variational autoencoder for square images.

    The encoder is a stack of stride-2 ``Conv2d -> BatchNorm2d -> LeakyReLU``
    stages (one per entry of ``hidden_dims``), followed by two linear heads
    producing the latent mean and log-variance. The decoder mirrors the
    encoder with stride-2 transposed convolutions and ends in a ``Tanh``
    whose output is rescaled to ``[0, 1]`` in :meth:`decode`.

    Each encoder stage halves the spatial size and each decoder stage doubles
    it, so the reconstruction only has the same spatial size as the input
    when ``img_size`` is divisible by ``2 ** len(hidden_dims)``.
    """

    def __init__(self, in_channels: int, latent_dim: int,
                 hidden_dims: List[int] = None, img_size: int = 64):
        """Build encoder, latent heads and decoder.

        Args:
            in_channels: Number of channels of the input images; also used as
                the number of channels of the reconstruction.
            latent_dim: Size of the latent vector.
            hidden_dims: Output channels of each encoder stage. Defaults to
                ``[32, 64, 128, 256, 512]``. The list is copied, never
                modified in place.
            img_size: Side length of the (square) input images, used to work
                out the flattened encoder output size.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.in_channels = in_channels
        self.img_size = img_size

        # Work on a private copy so the caller's list is never mutated.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)

        # Build the encoder
        modules = []
        prev_channels = in_channels
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            prev_channels = h_dim

        self.encoder = nn.Sequential(*modules)

        # Probe the encoder once to learn its (C, H, W) output shape; decode()
        # uses the same shape to un-flatten the latent projection.
        self._feature_shape = self._encoder_output_shape(img_size)
        channels, height, width = self._feature_shape
        self.feature_size = channels * height * width

        # Fully connected heads sized from the encoder output
        self.fc_mu = nn.Linear(self.feature_size, latent_dim)
        self.fc_var = nn.Linear(self.feature_size, latent_dim)

        # Build the decoder (mirror image of the encoder)
        modules = []
        self.decoder_input = nn.Linear(latent_dim, self.feature_size)
        decoder_dims = hidden_dims[::-1]

        for i in range(len(decoder_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(decoder_dims[i],
                                       decoder_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(decoder_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        # Last upsampling stage plus a conv mapping back to image channels.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(decoder_dims[-1],
                               decoder_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(decoder_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(decoder_dims[-1], out_channels=in_channels,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def _encoder_output_shape(self, img_size: int) -> Tuple[int, ...]:
        """Return the encoder's per-sample output shape for ``img_size``.

        The probe runs in eval mode under ``no_grad`` so that the dummy
        all-zero batch neither updates BatchNorm running statistics nor
        builds an autograd graph; the previous train/eval mode is restored.
        """
        was_training = self.encoder.training
        self.encoder.eval()
        try:
            with torch.no_grad():
                # Put the probe on the encoder's device/dtype when it has
                # parameters, so this also works after ``.to(device)``.
                ref = next(self.encoder.parameters(), None)
                if ref is None:
                    probe = torch.zeros(1, self.in_channels, img_size, img_size)
                else:
                    probe = ref.new_zeros(1, self.in_channels, img_size, img_size)
                out = self.encoder(probe)
        finally:
            self.encoder.train(was_training)
        return tuple(out.shape[1:])

    def _get_feature_size(self, img_size):
        """Return the flattened encoder output size for ``img_size`` inputs."""
        size = 1
        for dim in self._encoder_output_shape(img_size):
            size *= dim
        return size

    def encode(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images into latent mean and log-variance.

        Returns:
            ``(mu, log_var)``, each of shape ``(batch, latent_dim)``.
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors into images with values in ``[0, 1]``."""
        result = self.decoder_input(z)
        # Un-flatten to the encoder's output shape recorded at construction.
        result = result.view(-1, *self._feature_shape)
        result = self.decoder(result)
        result = self.final_layer(result)
        # Tanh yields [-1, 1]; shift and scale to [0, 1].
        result = (result + 1.) / 2.
        return result

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run encode -> reparameterize -> decode.

        Returns:
            ``(reconstruction, input, mu, log_var)``.
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), input, mu, log_var

    def loss_function(self,
                      recons: torch.Tensor,
                      input: torch.Tensor,
                      mu: torch.Tensor,
                      log_var: torch.Tensor,
                      kld_weight: float = 0.00025) -> torch.Tensor:
        """Return the VAE loss: MSE reconstruction + weighted KL divergence.

        Args:
            recons: Reconstructed images.
            input: Original images, same shape as ``recons``.
            mu: Latent means, shape ``(batch, latent_dim)``.
            log_var: Latent log-variances, same shape as ``mu``.
            kld_weight: Multiplier applied to the KL term.

        Returns:
            Scalar tensor ``mse + kld_weight * kld``.
        """
        recons_loss = F.mse_loss(recons, input)
        # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and averaged
        # over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        loss = recons_loss + kld_weight * kld_loss
        return loss

# Data transform definition: resize the shorter side to IMAGE_SIZE, take a
# centered IMAGE_SIZE x IMAGE_SIZE crop, and convert to a tensor.
IMAGE_SIZE = 64
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, antialias=True),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor()
])

# Create the dataset and data loader
data_dir = '/home/data/hnakai/CelebA'
dataset = CelebADataset(root_dir=data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

# Set up the model and optimizer (the loss lives on the model itself)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VAE_CelebA(in_channels=3, latent_dim=512).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 100
model.train()
epoch_losses = []  # mean loss of each epoch

for epoch in range(num_epochs):
    epoch_loss = 0.0  # running sum of batch losses for this epoch
    for batch_idx, images in enumerate(dataloader):
        images = images.to(device)
        optimizer.zero_grad()
        # Bind the echoed input to `inputs`, not `input`, so the builtin
        # input() is not shadowed at module / notebook scope.
        recons, inputs, mu, log_var = model(images)
        loss = model.loss_function(recons, inputs, mu, log_var)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()  # accumulate the per-batch loss
    # Store the mean batch loss of this epoch
    epoch_losses.append(epoch_loss / len(dataloader))
    print(f'Epoch [{epoch}/{num_epochs}] Loss: {epoch_losses[-1]:.4f}')

# Save the model once training has finished
torch.save(model.state_dict(), 'vae_celeba.pth')

# Plot the loss curve
plt.figure()
plt.plot(range(num_epochs), epoch_losses, label='Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.show()

# Evaluate the model by reconstructing input images
model.eval()
with torch.no_grad():
    # Only the first batch is needed
    images = next(iter(dataloader)).to(device)
    recons, _, _, _ = model(images)

    # Show up to five originals next to their reconstructions; the bound
    # guards against a batch holding fewer than five images.
    to_pil = transforms.ToPILImage()
    for i in range(min(5, len(images))):
        original = to_pil(images[i].cpu())
        reconstructed = to_pil(recons[i].cpu())

        fig, axes = plt.subplots(1, 2)
        axes[0].imshow(original)
        axes[0].set_title("Original")
        axes[1].imshow(reconstructed)
        axes[1].set_title("Reconstructed")
        plt.show()


Epoch [0/100] Loss: 0.0506
Epoch [1/100] Loss: 0.0380
Epoch [2/100] Loss: 0.0361
