In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import numpy as np
from tqdm import tqdm
import os

# --- 0. ハイパーパラメータ ---
BATCH_SIZE = 128
NOISE_DIM = 100
IMG_SIZE = 28
IMG_DIM = IMG_SIZE * IMG_SIZE
N_EPOCHS = 50
LR = 0.001
N_SAMPLES = 60000  # MNIST訓練データセットの総数
DELTA = 0.05       # 95%の信頼度 (1 - delta)
PAC_WEIGHT = 0.1   # PAC-Bayesペナルティの重み（要調整）

# デバイス設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 出力ディレクトリ
os.makedirs("pac_gan_images", exist_ok=True)

# --- 1. データの準備 ---
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)) # -1から1に正規化
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = DataLoader(mnist_dataset, batch_size=BATCH_SIZE, shuffle=True)


# --- 2. Stochastic Generator (PAC-Bayes) ---
# のバウンドを目的関数に使う
class StochasticGenerator(nn.Module):
    def __init__(self, noise_dim, output_dim, hidden_dim=256):
        super().__init__()
        
        self.noise_dim = noise_dim
        
        # --- 重みの「平均 (mu)」を定義 ---
        self.mu_params = nn.ModuleDict({
            'fc1_mu': nn.Linear(noise_dim, hidden_dim),
            'fc2_mu': nn.Linear(hidden_dim, hidden_dim),
            'fc3_mu': nn.Linear(hidden_dim, output_dim),
        })
        
        # --- 重みの「標準偏差 (sigma)」の対数 (rho) を定義 ---
        self.rho_params = nn.ModuleDict({
            'fc1_rho': nn.Linear(noise_dim, hidden_dim),
            'fc2_rho': nn.Linear(hidden_dim, hidden_dim),
            'fc3_rho': nn.Linear(hidden_dim, output_dim),
        })

        # --- 事前分布 p(w) = N(0, 1) ---
        self.prior_mu = 0.0
        self.prior_sigma = 1.0

    def sample_weights(self):
        """Reparameterization Trick で重みをサンプリング"""
        sampled_weights = {}
        for name, mu_layer in self.mu_params.items():
            rho_layer = self.rho_params[name.replace('_mu', '_rho')]
            
            mu_w, mu_b = mu_layer.weight, mu_layer.bias
            rho_w, rho_b = rho_layer.weight, rho_layer.bias
            
            # sigma = log(1 + exp(rho)) (Softplus)
            sigma_w = F.softplus(rho_w)
            sigma_b = F.softplus(rho_b)
            
            # epsilon ~ N(0, 1)
            eps_w = torch.randn_like(mu_w)
            eps_b = torch.randn_like(mu_b)
            
            # w = mu + sigma * epsilon
            w = mu_w + sigma_w * eps_w
            b = mu_b + sigma_b * eps_b
            
            sampled_weights[name.replace('_mu', '')] = (w, b)
            
        return sampled_weights

    def forward(self, z, weights):
        """サンプリングされた重みで順伝播"""
        w1, b1 = weights['fc1']
        x = F.relu(F.linear(z, w1, b1))
        
        w2, b2 = weights['fc2']
        x = F.relu(F.linear(x, w2, b2))
        
        w3, b3 = weights['fc3']
        x = torch.tanh(F.linear(x, w3, b3)) # -1 ~ 1 の画像ピクセル
        return x

    def calculate_kl(self):
        """KL(q || p) を解析的に計算"""
        kl_total = 0.0
        
        for name, mu_layer in self.mu_params.items():
            rho_layer = self.rho_params[name.replace('_mu', '_rho')]
            
            for mu, rho in [(mu_layer.weight, rho_layer.weight), (mu_layer.bias, rho_layer.bias)]:
                
                sigma = F.softplus(rho)
                q_dist = torch.distributions.Normal(mu, sigma)
                p_dist = torch.distributions.Normal(self.prior_mu, self.prior_sigma)
                
                # GPU/CPUの互換性のために、事前分布を q_dist と同じデバイス、型にする
                p_dist = torch.distributions.Normal(
                    torch.full_like(mu, self.prior_mu), 
                    torch.full_like(sigma, self.prior_sigma)
                )

                kl_div = torch.distributions.kl.kl_divergence(q_dist, p_dist).sum()
                kl_total += kl_div
                
        return kl_total

# --- 3. Discriminator ---
class Discriminator(nn.Module):
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid() # 0 (Fake) ~ 1 (Real)
        )
    
    def forward(self, x):
        return self.model(x)

# --- 4. モデルと最適化の初期化 ---

G = StochasticGenerator(NOISE_DIM, IMG_DIM).to(device)
D = Discriminator(IMG_DIM).to(device)

# D は通常通り学習
optimizer_D = optim.Adam(D.parameters(), lr=LR)
# G は「重みの分布 (mu, rho)」を学習
optimizer_G = optim.Adam(
    list(G.mu_params.parameters()) + list(G.rho_params.parameters()), 
    lr=LR
)

bce_loss = nn.BCELoss()

# McAllesterのバウンドの定数項
# log(2 * sqrt(N) / delta)
log_term = np.log(2 * np.sqrt(N_SAMPLES) / DELTA)

# --- 5. 学習ループ ---

print("PAC-Bayes GAN の学習を開始します...")
for epoch in range(N_EPOCHS):
    pbar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{N_EPOCHS}")
    for i, (real_imgs, _) in enumerate(pbar):
        
        batch_size = real_imgs.shape[0]
        real_imgs = real_imgs.view(batch_size, -1).to(device)
        
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # --- (1) Discriminator の学習 ---
        optimizer_D.zero_grad()
        
        # Real 画像
        D_real_output = D(real_imgs)
        loss_D_real = bce_loss(D_real_output, real_labels)
        
        # Fake 画像 (Generatorの重みをサンプリングして生成)
        noise = torch.randn(batch_size, NOISE_DIM).to(device)
        with torch.no_grad(): # Gの勾配は不要
            sampled_weights_D_step = G.sample_weights()
            fake_imgs = G(noise, sampled_weights_D_step) # <-- 修正済み
        
        D_fake_output = D(fake_imgs)
        loss_D_fake = bce_loss(D_fake_output, fake_labels)
        
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        # --- (2) Generator の学習 (PAC-Bayes BBVI) ---
        optimizer_G.zero_grad()
        
        # 1. 重みをサンプリング (Reparameterization Trick)
        sampled_weights_G_step = G.sample_weights()
        
        # 2. 経験リスク R_n(q) の計算 (GANのG損失)
        noise_G = torch.randn(batch_size, NOISE_DIM).to(device)
        fake_imgs_G = G(noise_G, sampled_weights_G_step)
        D_output_G = D(fake_imgs_G)
        
        # R_n は G の経験損失 (D を騙せなかった度合い)
        # Gは D_output_G が 1 (real_labels) になることを目指す
        R_n = bce_loss(D_output_G, real_labels)
        
        # 3. KLペナルティ KL(q || p) の計算
        KL = G.calculate_kl()
        
        # 4. PAC-Bayes (McAllester) のバウンド
        # Bound = R_n(q) + sqrt( (KL(q||p) + log(...)) / 2n )
        
        complexity_penalty = torch.sqrt( (KL + log_term) / (2 * N_SAMPLES) )
        
        # 5. 最終的な PAC-Bayes 目的関数 (これを最小化)
        # PAC_WEIGHT でペナルティの強さを調整
        pac_bayes_loss = R_n + (PAC_WEIGHT * complexity_penalty)
        
        pac_bayes_loss.backward()
        optimizer_G.step()
        
        if i % 100 == 0:
            pbar.set_postfix({
                "D Loss": f"{loss_D.item():.4f}",
                "G (R_n)": f"{R_n.item():.4f}",
                "KL": f"{KL.item():.2f}",
                "Penalty": f"{complexity_penalty.item():.4f}"
            })

    # --- エポック終了時に画像を保存 ---
    G.eval() # 評価モード
    with torch.no_grad():
        fixed_noise = torch.randn(64, NOISE_DIM).to(device)
        sample_weights_for_eval = G.sample_weights()
        gen_imgs = G(fixed_noise, sample_weights_for_eval)
        gen_imgs = gen_imgs.view(-1, 1, IMG_SIZE, IMG_SIZE)
        save_image(gen_imgs, f"pac_gan_images/epoch_{epoch+1}.png", normalize=True)
    G.train() # 学習モードに戻す

print("学習完了。")

100%|██████████| 9.91M/9.91M [00:01<00:00, 7.61MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 242kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 2.09MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.15MB/s]


PAC-Bayes GAN の学習を開始します...


Epoch 1/50: 100%|██████████| 469/469 [00:11<00:00, 42.47it/s, D Loss=0.0030, G (R_n)=8.0962, KL=31638.93, Penalty=0.5136] 
Epoch 2/50: 100%|██████████| 469/469 [00:11<00:00, 41.01it/s, D Loss=0.0734, G (R_n)=15.5193, KL=31888.12, Penalty=0.5156]
Epoch 3/50: 100%|██████████| 469/469 [00:11<00:00, 41.53it/s, D Loss=0.2154, G (R_n)=9.3300, KL=32390.41, Penalty=0.5196] 
Epoch 4/50: 100%|██████████| 469/469 [00:11<00:00, 40.94it/s, D Loss=0.1825, G (R_n)=5.5684, KL=32740.10, Penalty=0.5224]
Epoch 5/50: 100%|██████████| 469/469 [00:12<00:00, 38.59it/s, D Loss=0.1063, G (R_n)=11.4651, KL=33125.97, Penalty=0.5255]
Epoch 6/50: 100%|██████████| 469/469 [00:11<00:00, 42.15it/s, D Loss=0.1043, G (R_n)=3.0560, KL=33516.70, Penalty=0.5286] 
Epoch 7/50: 100%|██████████| 469/469 [00:10<00:00, 42.89it/s, D Loss=0.3473, G (R_n)=12.6352, KL=33853.89, Penalty=0.5312]
Epoch 8/50: 100%|██████████| 469/469 [00:11<00:00, 42.60it/s, D Loss=0.1380, G (R_n)=8.9974, KL=34131.94, Penalty=0.5334] 
Epoch 9/50: 100%|

学習完了。
