In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import json
from PIL import Image
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import IPython.display as display
%matplotlib inline

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models,transforms, datasets
from torch.utils.data import DataLoader

import sys
sys.path.append("../src/libs/")

from transform import BaseTransform
from dataset import CustomDataset
from loss import LossFunction

# Report the library versions in use so a run can be reproduced later.
for _label, _version in (
    ("Pytorch Version: ", torch.__version__),
    ("Torchvision Version:", torchvision.__version__),
):
    print(_label, _version)

In [None]:
# Directories holding the input images and the paired "cor" images that are
# passed to CustomDataset as `cor_dir`.
img_path = "../data/img/"
cor_path = "../data/cor_img/"

# Original note on the transform: 256*256 (presumably the output resolution
# of BaseTransform -- TODO confirm in ../src/libs/transform.py).
transform = BaseTransform()

dataset = CustomDataset(
    data_dir=img_path,
    cor_dir=cor_path,
    transform=transform,
)

# NOTE(review): shuffle=False keeps a fixed batch order every epoch; confirm
# this is intended for training.
data_loader = DataLoader(dataset=dataset, batch_size=4, shuffle=False)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder of the conditional VAE.

    Two convolutions (kernel 4, stride 2, padding 1) each halve the spatial
    size of the image. The resulting feature map is flattened, concatenated
    with the flattened condition, and projected to the mean and log-variance
    of the latent Gaussian.

    Args:
        latent_dim: Size of the latent vector.
        condition_dim: Number of condition features per sample *after*
            flattening everything but the batch dimension.
        img_size: Height/width of the square input image. Defaults to 256,
            which yields the original ``128 * 64 * 64`` flattened feature size.

    Raises:
        ValueError: If ``img_size`` is smaller than 4 (the second convolution
            would produce an empty feature map).
    """

    def __init__(self, latent_dim, condition_dim, img_size=256):
        super().__init__()
        if img_size < 4:
            raise ValueError(f"img_size must be at least 4, got {img_size}")
        self.condition_dim = condition_dim
        self.img_size = img_size
        # Each stride-2 convolution maps a side of length H to H // 2.
        self.feature_size = img_size // 4
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)
        self.fc1 = nn.Linear(
            128 * self.feature_size * self.feature_size + self.condition_dim, 1024
        )
        self.fc2 = nn.Linear(1024, latent_dim)  # head producing mu
        self.fc3 = nn.Linear(1024, latent_dim)  # head producing logvar

    def forward(self, x, condition):
        """Encode an image batch together with its condition.

        Args:
            x: Image batch; the first convolution expects 3 input channels and
                the spatial size must match ``img_size``.
            condition: Conditioning tensor with the batch on dimension 0; all
                remaining dimensions are flattened and must total
                ``condition_dim`` features.

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.

        Raises:
            ValueError: If the flattened condition does not have
                ``condition_dim`` features per sample.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.reshape(x.size(0), -1)
        # reshape (rather than view) also accepts non-contiguous conditions.
        condition = condition.reshape(condition.size(0), -1)
        if condition.size(1) != self.condition_dim:
            raise ValueError(
                f"condition has {condition.size(1)} features per sample after "
                f"flattening, expected condition_dim={self.condition_dim}"
            )
        x = torch.cat((x, condition), dim=1)
        x = F.relu(self.fc1(x))
        mu = self.fc2(x)
        logvar = self.fc3(x)
        return mu, logvar


In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder of the conditional VAE.

    The latent vector is concatenated with the flattened condition, expanded
    by two fully connected layers into a 128-channel feature map, and
    upsampled by two transposed convolutions (kernel 4, stride 2, padding 1),
    each of which doubles the spatial size. A sigmoid keeps the output in
    ``[0, 1]``.

    Args:
        latent_dim: Size of the latent vector.
        condition_dim: Number of condition features per sample *after*
            flattening everything but the batch dimension.
        img_size: Height/width of the square output image. Defaults to 256,
            which yields the original ``128 * 64 * 64`` feature map.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 4 (the
            output side is always four times the feature-map side).
    """

    def __init__(self, latent_dim, condition_dim, img_size=256):
        super().__init__()
        if img_size < 4 or img_size % 4 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 4, got {img_size}"
            )
        self.condition_dim = condition_dim
        self.img_size = img_size
        self.feature_size = img_size // 4
        self.fc1 = nn.Linear(latent_dim + self.condition_dim, 1024)
        self.fc2 = nn.Linear(1024, 128 * self.feature_size * self.feature_size)
        self.deconv1 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(64, 3, 4, 2, 1)

    def forward(self, z, condition):
        """Decode latent vectors into images.

        Args:
            z: Latent batch of shape ``(batch, latent_dim)``.
            condition: Conditioning tensor with the batch on dimension 0; all
                remaining dimensions are flattened and must total
                ``condition_dim`` features.

        Returns:
            Tensor of shape ``(batch, 3, img_size, img_size)`` with values in
            ``[0, 1]``.

        Raises:
            ValueError: If the flattened condition does not have
                ``condition_dim`` features per sample.
        """
        # Flatten the condition exactly as the encoder does, so the same
        # tensor can be fed to both halves of the model.
        condition = condition.reshape(condition.size(0), -1)
        if condition.size(1) != self.condition_dim:
            raise ValueError(
                f"condition has {condition.size(1)} features per sample after "
                f"flattening, expected condition_dim={self.condition_dim}"
            )
        z = torch.cat((z, condition), dim=1)
        z = F.relu(self.fc1(z))
        z = F.relu(self.fc2(z))
        z = z.view(z.size(0), 128, self.feature_size, self.feature_size)
        z = F.relu(self.deconv1(z))
        z = torch.sigmoid(self.deconv2(z))
        return z

In [None]:
# Conditional VAE model: encoder + reparameterisation + decoder
class ConditionalVAE(nn.Module):
    """Conditional variational auto-encoder built from Encoder and Decoder.

    Args:
        latent_dim: Size of the latent vector shared by both sub-networks.
        condition_dim: Condition size forwarded to both sub-networks.
    """

    def __init__(self, latent_dim, condition_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim, condition_dim)
        self.decoder = Decoder(latent_dim, condition_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        ``std`` is recovered from the log-variance as ``exp(0.5 * logvar)``.
        """
        sigma = (logvar * 0.5).exp()
        noise = torch.randn_like(sigma)
        return noise * sigma + mu

    def forward(self, x, condition):
        """Encode ``x``, sample a latent vector and decode it again.

        Returns:
            Tuple ``(recon_x, mu, logvar)``.
        """
        mu, logvar = self.encoder(x, condition)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent, condition), mu, logvar


In [None]:
# Reconstruction error
def reconstruction_loss(recon_x, x, reduction="mean"):
    """Mean-squared reconstruction error between ``recon_x`` and ``x``.

    Args:
        recon_x: Reconstructed batch.
        x: Target batch, same shape as ``recon_x``.
        reduction: Passed to ``F.mse_loss`` (``"mean"``, ``"sum"`` or
            ``"none"``). The default ``"mean"`` averages over every element.

    Returns:
        The reduced loss tensor.
    """
    return F.mse_loss(recon_x, x, reduction=reduction)


# KL divergence
def kl_divergence(mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and the standard normal.

    Returns:
        Scalar tensor: the divergence summed over every element of
        ``mu``/``logvar`` (i.e. over the batch and the latent dimensions).
    """
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


# Conditional VAE loss
def vae_loss(recon_x, x, mu, logvar, condition=None, beta=1.0):
    """Per-sample negative ELBO (up to constants), averaged over the batch.

    Both terms are reduced the same way -- summed within each sample, then
    averaged over the batch. Mixing an element-wise *mean* reconstruction
    error with a *summed* KL term makes the KL term larger by a factor of
    roughly the number of pixels, which pushes the posterior towards the
    prior and starves the reconstruction objective.

    Args:
        recon_x: Reconstructed batch.
        x: Target batch, same shape as ``recon_x``; dimension 0 is the batch.
        mu: Latent means.
        logvar: Latent log-variances.
        condition: Unused; accepted so existing call sites keep working.
        beta: Weight of the KL term (needs tuning).

    Returns:
        Scalar loss tensor.
    """
    # Guard against a zero-sized batch so the division stays finite.
    batch_size = max(x.size(0), 1)
    recon_term = reconstruction_loss(recon_x, x, reduction="sum") / batch_size
    kl_term = kl_divergence(mu, logvar) / batch_size
    return recon_term + beta * kl_term

In [None]:
# Device used for both the model and the batches in the training loop.
device = torch.device('cpu')

# Seed the global RNG so weight initialisation and the reparameterisation
# noise are reproducible across kernel restarts.
seed = 42
torch.manual_seed(seed)

# Instantiate the model
z_dim = 128  # dimensionality of the latent variable
# Must equal the number of condition features per sample after flattening
# -- TODO confirm against what CustomDataset returns as the condition.
condition_dim = 256
# Move the model to `device` before building the optimizer, so that the
# parameters the optimizer tracks live on the same device as the batches.
vae = ConditionalVAE(z_dim, condition_dim).to(device)

# Optimisation algorithm and learning rate
optimizer = optim.Adam(vae.parameters(), lr=0.001)

In [None]:
# Train the model
num_epochs = 50
vae.train()  # make sure the model is in training mode
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    for batch in data_loader:
        # Each batch is indexed as (image, condition).
        x, condition = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        recon_x, mu, logvar = vae(x, condition)
        loss = vae_loss(recon_x, x, mu, logvar, condition)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        # Without this guard an empty loader surfaces as a NameError on `loss`.
        raise RuntimeError("data_loader produced no batches; nothing to train on")
    # Report the mean over the epoch instead of only the last batch's loss.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / num_batches}')