# 微分幾何的にフラットなパラメータ表現を用いたVAE

### まだバグがある．＜＝ 学習アルゴリズムをちゃんと導出して，修正する必要あり！！！

In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA



In [23]:
# ハイパーパラメータ
batch_size = 128
latent_dim = 1  # フラットなパラメータ表現では1次元
epochs = 50
learning_rate = 1e-3

# MNISTデータセットの読み込みとデータローダーの作成
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)



In [24]:
# VAEモデルの定義（フラットなパラメータ表現を使用）
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(7*7*32, 256),
            nn.ReLU()
        )
        self.fc_theta1 = nn.Linear(256, latent_dim)
        self.fc_theta2 = nn.Sequential(
            nn.Linear(256, latent_dim),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 7*7*32),
            nn.ReLU(),
            nn.Unflatten(1, (32, 7, 7)),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        theta1 = self.fc_theta1(h)
        theta2 = -self.fc_theta2(h)
        return theta1, theta2

    def reparameterize(self, theta1, theta2):
        # ここは，exponential分布のパラメータを使用してzを生成する
        # exp()
        # フラットなパラメータ表現から平均と分散に変換
        print("theta1:", theta1)
        print("theta2:", theta2)
        # theta2は分散の逆数であるため，分散を計算する
        mu = theta1 / (-2 * theta2)
        sigma = torch.sqrt(-1 / (2 * theta2))
        eps = torch.randn_like(mu)
        z = mu + eps * sigma
        return z

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        theta1, theta2 = self.encode(x)
        z = self.reparameterize(theta1, theta2)
        x_recon = self.decode(z)
        return x_recon, theta1, theta2



In [25]:
# モデル、損失関数、最適化手法の定義
model = VAE(latent_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

def loss_function(recon_x, x, theta1, theta2):
#    BCE = nn.functional.binary_cross_entropy(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')
    MSE = nn.functional.mse_loss(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')
    # KLダイバージェンスの計算（フラットなパラメータ表現に対応）
    mu = theta1 / (-2 * theta2)
    sigma = torch.sqrt(-1 / (2 * theta2))
    KLD = -0.5 * torch.sum(1 + 2 * torch.log(sigma) - mu.pow(2) - sigma.pow(2))
    return MSE + KLD

# 学習ループ
train_losses = []
for epoch in range(epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        recon_batch, theta1, theta2 = model(data)
        loss = loss_function(recon_batch, data, theta1, theta2)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    train_losses.append(train_loss / len(train_loader.dataset))
    print(f'Epoch {epoch+1}, Loss: {train_losses[-1]:.4f}')

# 学習曲線のプロット
plt.plot(train_losses)
plt.title('VAE Training Loss (Flat Parametrization)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()



theta1: tensor([[-0.0370],
        [-0.0494],
        [-0.0346],
        [-0.0571],
        [-0.0477],
        [-0.0423],
        [-0.0526],
        [-0.0535],
        [-0.0532],
        [-0.0402],
        [-0.0441],
        [-0.0488],
        [-0.0613],
        [-0.0462],
        [-0.0623],
        [-0.0502],
        [-0.0531],
        [-0.0545],
        [-0.0419],
        [-0.0545],
        [-0.0397],
        [-0.0435],
        [-0.0357],
        [-0.0520],
        [-0.0509],
        [-0.0465],
        [-0.0470],
        [-0.0463],
        [-0.0387],
        [-0.0546],
        [-0.0403],
        [-0.0292],
        [-0.0577],
        [-0.0441],
        [-0.0617],
        [-0.0507],
        [-0.0446],
        [-0.0660],
        [-0.0480],
        [-0.0563],
        [-0.0491],
        [-0.0437],
        [-0.0515],
        [-0.0424],
        [-0.0426],
        [-0.0474],
        [-0.0556],
        [-0.0495],
        [-0.0332],
        [-0.0514],
        [-0.0545],
        [-0.0439],
    

KeyboardInterrupt: 

In [None]:
# 中間層の出力の可視化 (PCA適用)
model.eval()
z_list = []
labels_list = []
with torch.no_grad():
    for data, labels in train_loader:
        theta1, theta2 = model.encode(data)
        z = model.reparameterize(theta1, theta2)
        z_list.extend(z.cpu().numpy())
        labels_list.extend(labels.cpu().numpy())
z_list = np.array(z_list)
labels_list = np.array(labels_list)

pca = PCA(n_components=2)
z_pca = pca.fit_transform(z_list)

plt.figure(figsize=(10, 8))
plt.scatter(z_pca[:, 0], z_pca[:, 1], c=labels_list, cmap='viridis', s=5)
plt.colorbar()
plt.title('Latent Space Visualization (PCA, Flat Parametrization)')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.show()

# 入力画像と再構成画像の比較
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    with torch.no_grad():
        sample = train_dataset[i][0].unsqueeze(0)
        recon_sample, _, _ = model(sample)
    plt.subplot(2, n, i+1)
    plt.imshow(sample.squeeze().numpy(), cmap='gray')
    plt.title('Input')
    plt.axis('off')
    plt.subplot(2, n, i+n+1)
    plt.imshow(recon_sample.squeeze().numpy(), cmap='gray')
    plt.title('Reconstruction')
    plt.axis('off')
plt.show()