In [1]:
pip install torch torchvision scikit-learn matplotlib numpy tqdm

Collecting torch
  Downloading torch-2.8.0-cp312-cp312-win_amd64.whl.metadata (30 kB)
Collecting torchvision
  Downloading torchvision-0.23.0-cp312-cp312-win_amd64.whl.metadata (6.1 kB)
Collecting filelock (from torch)
  Downloading filelock-3.19.1-py3-none-any.whl.metadata (2.1 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.5-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Downloading jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.9.0-py3-none-any.whl.metadata (10 kB)
Collecting setuptools (from torch)
  Using cached setuptools-80.9.0-py3-none-any.whl.metadata (6.6 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Downloading markupsafe-3.0.3-cp312-cp312-win_amd64.whl.meta


[notice] A new release of pip is available: 25.0.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
# -*- coding: utf-8 -*-
"""
Hinton & Salakhutdinov (Science, 2006) 재현 미니멀 구현
- RBM(제한 볼츠만 기계) 층별 사전학습(Pretraining)
- Autoencoder(오토인코더) 전개(Unroll) + 미세조정(Fine-tuning)
- Curves(합성 데이터) & MNIST 실험
- PCA(주성분 분석) 비교 (재구성 MSE)

주의: 데모용으로 학습 에폭 수를 줄였음. 결과 재현도를 높이려면 에폭/배치/네트워크/옵티마이저 설정을 늘려서 돌려보세요.
"""

import math
import random
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision import datasets, transforms
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from tqdm import tqdm


# ---------------------------
# Utils
# ---------------------------
def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------------------------
# 1) RBM (Bernoulli-Bernoulli, Gaussian-Bernoulli 지원)
# ---------------------------
class RBM(nn.Module):
    """
    Restricted Boltzmann Machine (RBM)
    - visible_type: 'bernoulli' 또는 'gaussian'
    - hidden_type : 'bernoulli'
    Contrastive Divergence(CD-k)로 학습
    """
    def __init__(self, n_visible, n_hidden, visible_type='bernoulli', k=1):
        super().__init__()
        assert visible_type in ['bernoulli', 'gaussian']
        self.n_v = n_visible
        self.n_h = n_hidden
        self.visible_type = visible_type
        self.k = k

        self.W = nn.Parameter(torch.randn(n_visible, n_hidden) * 0.01)
        self.v_bias = nn.Parameter(torch.zeros(n_visible))
        self.h_bias = nn.Parameter(torch.zeros(n_hidden))

        # 가우시안 가시 유닛용 표준편차(여기선 1로 고정; 학습하려면 파라미터화 가능)
        self.sigma = 1.0

    def sample_h(self, v):
        # P(h=1|v)
        p = torch.sigmoid(F.linear(v, self.W.t(), self.h_bias))
        return p, torch.bernoulli(p)

    def sample_v(self, h):
        if self.visible_type == 'bernoulli':
            p = torch.sigmoid(F.linear(h, self.W, self.v_bias))
            return p, torch.bernoulli(p)
        else:
            # Gaussian visible units: mean = W h + b, sigma=1
            mean = F.linear(h, self.W, self.v_bias)
            v = mean + torch.randn_like(mean) * self.sigma
            return mean, v  # 확률적 샘플도 반환

    def free_energy(self, v):
        vbias_term = (v * self.v_bias).sum(dim=1)
        wx_b = F.linear(v, self.W.t(), self.h_bias)
        hidden_term = torch.log1p(torch.exp(wx_b)).sum(dim=1)
        return -(vbias_term + hidden_term)

    def forward(self, v):
        # reconstruction (one-step)
        ph, h = self.sample_h(v)
        pv, v_sample = self.sample_v(h)
        return pv, ph

    def cd_loss(self, v0):
        # CD-k
        ph0, h0 = self.sample_h(v0)
        vk = v0
        hk = h0
        for _ in range(self.k):
            pvk, vk = self.sample_v(hk)
            phk, hk = self.sample_h(vk)

        # Positive / Negative phase
        w_pos = torch.einsum('bi,bj->ij', v0, ph0)
        w_neg = torch.einsum('bi,bj->ij', vk, phk)

        loss = (self.free_energy(v0).mean() - self.free_energy(vk).mean())
        # weight/bias 업데이트는 optimizer로 처리; 여기선 loss만 반환
        stats = {
            'recon_mse': F.mse_loss(vk, v0).item(),
            'pos_mean_h': ph0.mean().item(),
            'neg_mean_h': phk.mean().item(),
        }
        grads = {
            'W_pos': w_pos / v0.size(0),
            'W_neg': w_neg / v0.size(0),
            'v_bias_pos': v0.mean(dim=0),
            'v_bias_neg': vk.mean(dim=0),
            'h_bias_pos': ph0.mean(dim=0),
            'h_bias_neg': phk.mean(dim=0),
        }
        return loss, stats, grads

    def manual_step(self, grads, lr=1e-3):
        with torch.no_grad():
            self.W += lr * (grads['W_pos'] - grads['W_neg'])
            self.v_bias += lr * (grads['v_bias_pos'] - grads['v_bias_neg'])
            self.h_bias += lr * (grads['h_bias_pos'] - grads['h_bias_neg'])


# ---------------------------
# 2) RBM 스택 사전학습
# ---------------------------
def pretrain_rbm_stack(data_loader, layer_sizes: List[int],
                       visible_type_first='bernoulli', k=1, epochs=10, lr=1e-3):
    """
    layer_sizes: [n_visible, h1, h2, ...] 형태
    """
    rbms = []
    data_rep = None

    for li in range(len(layer_sizes) - 1):
        n_v = layer_sizes[li]
        n_h = layer_sizes[li + 1]
        vtype = visible_type_first if li == 0 else 'bernoulli'
        rbm = RBM(n_v, n_h, visible_type=vtype, k=k).to(device)

        print(f"[Pretrain] RBM {li+1}/{len(layer_sizes)-1}: {n_v} -> {n_h} ({vtype} -> bernoulli)")
        for ep in range(1, epochs + 1):
            meters = []
            for batch in data_loader:
                v0 = batch[0].to(device)
                # 첫 층이 아니라면, 이전 층 RBM의 은닉 확률을 '데이터'로 사용
                if data_rep is not None:
                    with torch.no_grad():
                        v0 = data_rep(v0)

                loss, stats, grads = rbm.cd_loss(v0)
                rbm.manual_step(grads, lr=lr)
                meters.append(stats['recon_mse'])
            print(f"  Epoch {ep:02d} | recon_mse={np.mean(meters):.4f}")

        # 다음 층 입력으로 사용할 함수(은닉 확률)
        def make_rep(rbm_):
            def rep(v):
                p_h, _ = rbm_.sample_h(v)
                return p_h
            return rep

        rep_fn = make_rep(rbm)
        # data_rep를 체인으로 누적
        data_rep = (lambda f_old, f_new: (lambda x: f_new(f_old(x))))(data_rep, rep_fn) if data_rep else rep_fn
        rbms.append(rbm)

    return rbms


# ---------------------------
# 3) RBM 가중치로 Autoencoder 초기화 + 미세조정
# ---------------------------
class DeepAutoencoder(nn.Module):
    """
    RBM 스택을 Encoder로 사용하고, Transpose Weight(가중치 전치)를 Decoder에 공유하지는 않지만
    초기값을 대칭으로 설정. (원 논문 컨셉: 전개 후 역전파 미세조정)
    """
    def __init__(self, layer_sizes: List[int], code_linear=False):
        super().__init__()
        enc_layers = []
        for i in range(len(layer_sizes) - 1):
            enc_layers += [nn.Linear(layer_sizes[i], layer_sizes[i+1])]
            if i < len(layer_sizes) - 2:
                enc_layers += [nn.Sigmoid()]  # 로지스틱 유닛
            else:
                enc_layers += [nn.Identity() if code_linear else nn.Sigmoid()]
        self.encoder = nn.Sequential(*enc_layers)

        dec_layers = []
        for i in range(len(layer_sizes) - 1, 0, -1):
            dec_layers += [nn.Linear(layer_sizes[i], layer_sizes[i-1])]
            if i-1 > 0:
                dec_layers += [nn.Sigmoid()]
            else:
                dec_layers += [nn.Sigmoid()]  # [0,1] 범위 재구성
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat, z

def init_from_rbms(model: DeepAutoencoder, rbms: List[RBM]):
    # 1) Encoder: v->h 이므로 Linear(out=n_h, in=n_v) = RBM.W.T
    e_idx = 0
    for m in model.encoder.modules():
        if isinstance(m, nn.Linear):
            W = rbms[e_idx].W.detach().t().clone()      # [n_h, n_v]
            b = rbms[e_idx].h_bias.detach().clone()     # hidden bias
            m.weight.data.copy_(W)
            m.bias.data.copy_(b)
            e_idx += 1

    # 2) Decoder: h->v 이므로 Linear(out=n_v, in=n_h) = RBM.W  (전치 금지!)
    d_linears = [m for m in model.decoder.modules() if isinstance(m, nn.Linear)]
    for i, lin in enumerate(d_linears):
        rbm = rbms[-(i+1)]                              # 역순으로 매칭
        W = rbm.W.detach().clone()                      # [n_v, n_h] → Linear(out=n_v, in=n_h)
        b = rbm.v_bias.detach().clone()                 # visible bias
        # ✅ 기존의 .t() 를 제거하세요
        lin.weight.data.copy_(W)
        lin.bias.data.copy_(b)


def fine_tune_autoencoder(model, data_loader, epochs=20, lr=1e-3, loss_type='bce'):
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    history = []

    for ep in range(1, epochs + 1):
        losses = []
        for (x,) in data_loader:
            x = x.to(device)
            x_hat, _ = model(x)
            if loss_type == 'bce':
                loss = F.binary_cross_entropy(x_hat, x)
            else:
                loss = F.mse_loss(x_hat, x)
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        mean_loss = float(np.mean(losses))
        history.append(mean_loss)
        print(f"[Fine-tune] Epoch {ep:02d} | {loss_type.upper()}={mean_loss:.4f}")
    return history


# ---------------------------
# 4) 데이터셋: Curves(합성), MNIST
# ---------------------------
def make_curves_dataset(n_train=20000, n_test=10000, img_size=28):
    """
    원 논문 'Curves' 합성 아이디어를 단순화:
    - 2D에 랜덤 제어점 3개로 베지어/폴리라인 곡선 생성 후, 래스터라이즈
    - 결과는 [0,1] intensity의 28x28 이미지
    """
    def draw_curve(seed=None):
        if seed is not None:
            np.random.seed(seed)
        p = np.random.rand(3, 2) * 0.8 + 0.1  # 0.1~0.9 범위
        # 샘플 포인트
        t = np.linspace(0, 1, 60)
        curve = (1-t)[:,None]*(1-t)[:,None]*p[0] + 2*(1-t)[:,None]*t[:,None]*p[1] + t[:,None]*t[:,None]*p[2]
        # 렌더링
        img = np.zeros((img_size, img_size))
        xy = (curve * (img_size-1)).astype(int)
        for (x, y) in xy:
            img[y, x] = 1.0
        # 두께/블러
        from scipy.ndimage import gaussian_filter
        img = gaussian_filter(img, sigma=0.8)
        img = np.clip(img, 0, 1)
        return img

    def build(n):
        xs = []
        for _ in tqdm(range(n), desc="Synth Curves"):
            xs.append(draw_curve())
        X = np.stack(xs).astype(np.float32)
        X = X.reshape(len(X), -1)
        return X

    Xtr = build(n_train)
    Xte = build(n_test)
    return Xtr, Xte


def make_mnist_dataset(binarize=False):
    tfms = transforms.Compose([
        transforms.ToTensor(),  # [0,1]
    ])
    train = datasets.MNIST(root='./data', train=True, download=True, transform=tfms)
    test  = datasets.MNIST(root='./data', train=False, download=True, transform=tfms)

    def to_matrix(ds):
        X = ds.data.numpy().astype(np.float32) / 255.0
        X = X.reshape(len(X), -1)
        if binarize:
            X = (X > 0.5).astype(np.float32)
        return X, ds.targets.numpy()

    Xtr, ytr = to_matrix(train); Xte, yte = to_matrix(test)
    return Xtr, ytr, Xte, yte


# ---------------------------
# 5) 평가: 재구성 MSE & 2D 시각화
# ---------------------------
def recon_mse(model, X, batch=512):
    model.eval()
    with torch.no_grad():
        tot = 0.0; n = 0
        for i in range(0, len(X), batch):
            xb = torch.from_numpy(X[i:i+batch]).to(device)
            xh, _ = model(xb)
            tot += F.mse_loss(xh, xb, reduction='sum').item()
            n += xb.numel()
        return tot / (len(X)*X.shape[1])

def pca_recon_mse(X_train, X_test, k):
    pca = PCA(n_components=k)
    pca.fit(X_train)
    Xh = pca.inverse_transform(pca.transform(X_test))
    return ((X_test - Xh) ** 2).mean()

def plot_2d_codes(model, X, y=None, nsamp=5000, title="2D codes"):
    model.eval()
    idx = np.random.choice(len(X), size=min(nsamp, len(X)), replace=False)
    with torch.no_grad():
        xb = torch.from_numpy(X[idx]).to(device)
        _, z = model(xb)
        z = z[:, :2].detach().cpu().numpy() if z.shape[1] > 2 else z.detach().cpu().numpy()
    plt.figure()
    if y is not None:
        yt = y[idx]
        scatter = plt.scatter(z[:,0], z[:,1], c=yt, s=5, alpha=0.7)
        plt.legend(*scatter.legend_elements(num=10), title="class", loc="upper right")
    else:
        plt.scatter(z[:,0], z[:,1], s=5, alpha=0.7)
    plt.title(title)
    plt.xlabel("z1"); plt.ylabel("z2"); plt.tight_layout()
    plt.show()


# ---------------------------
# 6) 메인 실험 루틴
# ---------------------------
@dataclass
class ExpConfig:
    dataset: str = "mnist"       # 'mnist' or 'curves'
    layer_sizes: Tuple[int,...] = (784, 1000, 500, 250, 30)  # 논문 예시
    batch_size: int = 128
    rbm_epochs: int = 5          # 데모용 (논문처럼 하려면 30+ 추천)
    rbm_k: int = 1
    rbm_lr: float = 1e-3
    ft_epochs: int = 10          # 데모용
    ft_lr: float = 1e-3
    loss_type: str = "bce"       # 'bce' for [0,1] logistic, 'mse' otherwise
    binarize_mnist: bool = False
    code_linear: bool = True     # 상위 코드 선형(논문 비교 편의)


def run_experiment(cfg: ExpConfig):
    set_seed(7)

    # 데이터 준비
    if cfg.dataset == "mnist":
        Xtr, ytr, Xte, yte = make_mnist_dataset(binarize=cfg.binarize_mnist)
        input_dim = 784
        assert cfg.layer_sizes[0] == input_dim, "layer_sizes 입력 차원(784) 확인"
        train_loader = DataLoader(TensorDataset(torch.from_numpy(Xtr)), batch_size=cfg.batch_size, shuffle=True)
        test_loader  = DataLoader(TensorDataset(torch.from_numpy(Xte)), batch_size=cfg.batch_size, shuffle=False)
        visible_type_first = 'bernoulli'  # 픽셀 [0,1] → 로지스틱
    elif cfg.dataset == "curves":
        Xtr, Xte = make_curves_dataset(n_train=20000, n_test=10000, img_size=28)
        input_dim = 784
        assert cfg.layer_sizes[0] == input_dim
        train_loader = DataLoader(TensorDataset(torch.from_numpy(Xtr)), batch_size=cfg.batch_size, shuffle=True)
        test_loader  = DataLoader(TensorDataset(torch.from_numpy(Xte)), batch_size=cfg.batch_size, shuffle=False)
        visible_type_first = 'bernoulli'  # 곡선 intensity [0,1] → 로지스틱
    else:
        raise ValueError("dataset must be 'mnist' or 'curves'")

    # RBM 스택 사전학습
    rbms = pretrain_rbm_stack(
        data_loader=train_loader,
        layer_sizes=list(cfg.layer_sizes),
        visible_type_first=visible_type_first,
        k=cfg.rbm_k, epochs=cfg.rbm_epochs, lr=cfg.rbm_lr
    )

    # Autoencoder 구성 & 초기화
    dae = DeepAutoencoder(list(cfg.layer_sizes), code_linear=cfg.code_linear)
    init_from_rbms(dae, rbms)

    # 미세조정
    print("[Fine-tune] start")
    fine_tune_autoencoder(dae, train_loader, epochs=cfg.ft_epochs, lr=cfg.ft_lr, loss_type=cfg.loss_type)

    # 평가: 재구성 MSE (테스트)
    mse_dae = recon_mse(dae, Xte, batch=512)
    mse_pca = pca_recon_mse(Xtr, Xte, k=cfg.layer_sizes[-1])
    print(f"[Test] Recon MSE — DAE: {mse_dae:.6f} | PCA(k={cfg.layer_sizes[-1]}): {mse_pca:.6f}")

    # 2D 시각화 (코드 차원이 2면)
    if cfg.layer_sizes[-1] == 2:
        if cfg.dataset == "mnist":
            plot_2d_codes(dae, Xtr, ytr, title="MNIST 2D codes (Autoencoder)")
        else:
            plot_2d_codes(dae, Xtr, None, title="Curves 2D codes (Autoencoder)")

    return dae


# ---------------------------
# Entry
# ---------------------------
if __name__ == "__main__":
    # 1) MNIST: 784-1000-500-250-30 (논문 표준)
    cfg_mnist = ExpConfig(
        dataset="mnist",
        layer_sizes=(784, 1000, 500, 250, 30),
        rbm_epochs=5,   # 재현 높이려면 30+ 로 늘리기
        ft_epochs=10,   # 30~100 추천
        loss_type="bce",
        code_linear=True
    )
    dae_mnist = run_experiment(cfg_mnist)

    # 2) Curves 합성: 784-400-200-100-50-25-6 (논문 예시)
    # 곡선 생성에 scipy가 필요합니다: pip install scipy
    cfg_curves = ExpConfig(
        dataset="curves",
        layer_sizes=(784, 400, 200, 100, 50, 25, 6),
        rbm_epochs=5,
        ft_epochs=10,
        loss_type="bce",
        code_linear=True
    )
    dae_curves = run_experiment(cfg_curves)


[Pretrain] RBM 1/4: 784 -> 1000 (bernoulli -> bernoulli)
  Epoch 01 | recon_mse=0.1812
  Epoch 02 | recon_mse=0.1538
  Epoch 03 | recon_mse=0.1509
  Epoch 04 | recon_mse=0.1474
  Epoch 05 | recon_mse=0.1421
[Pretrain] RBM 2/4: 1000 -> 500 (bernoulli -> bernoulli)
  Epoch 01 | recon_mse=0.2489
  Epoch 02 | recon_mse=0.2467
  Epoch 03 | recon_mse=0.2466
  Epoch 04 | recon_mse=0.2463
  Epoch 05 | recon_mse=0.2453
[Pretrain] RBM 3/4: 500 -> 250 (bernoulli -> bernoulli)
  Epoch 01 | recon_mse=0.2344
  Epoch 02 | recon_mse=0.2182
  Epoch 03 | recon_mse=0.2156
  Epoch 04 | recon_mse=0.2145
  Epoch 05 | recon_mse=0.2138
[Pretrain] RBM 4/4: 250 -> 30 (bernoulli -> bernoulli)
  Epoch 01 | recon_mse=0.2835
  Epoch 02 | recon_mse=0.2724
  Epoch 03 | recon_mse=0.2648
  Epoch 04 | recon_mse=0.2582
  Epoch 05 | recon_mse=0.2523
[Fine-tune] start
[Fine-tune] Epoch 01 | BCE=0.2407
[Fine-tune] Epoch 02 | BCE=0.2319
[Fine-tune] Epoch 03 | BCE=0.2252
[Fine-tune] Epoch 04 | BCE=0.2189
[Fine-tune] Epoch 05 

Synth Curves: 100%|██████████| 20000/20000 [00:01<00:00, 10644.82it/s]
Synth Curves: 100%|██████████| 10000/10000 [00:00<00:00, 10901.45it/s]


[Pretrain] RBM 1/6: 784 -> 400 (bernoulli -> bernoulli)
  Epoch 01 | recon_mse=0.2419
  Epoch 02 | recon_mse=0.1148
  Epoch 03 | recon_mse=0.0783
  Epoch 04 | recon_mse=0.0616
  Epoch 05 | recon_mse=0.0526
[Pretrain] RBM 2/6: 400 -> 200 (bernoulli -> bernoulli)
  Epoch 01 | recon_mse=0.2504
  Epoch 02 | recon_mse=0.2494
  Epoch 03 | recon_mse=0.2492
  Epoch 04 | recon_mse=0.2491
  Epoch 05 | recon_mse=0.2491
[Pretrain] RBM 3/6: 200 -> 100 (bernoulli -> bernoulli)
  Epoch 01 | recon_mse=0.2524
  Epoch 02 | recon_mse=0.2486
  Epoch 03 | recon_mse=0.2471
  Epoch 04 | recon_mse=0.2462
  Epoch 05 | recon_mse=0.2456
[Pretrain] RBM 4/6: 100 -> 50 (bernoulli -> bernoulli)
  Epoch 01 | recon_mse=0.2622
  Epoch 02 | recon_mse=0.2540
  Epoch 03 | recon_mse=0.2495
  Epoch 04 | recon_mse=0.2469
  Epoch 05 | recon_mse=0.2449
[Pretrain] RBM 5/6: 50 -> 25 (bernoulli -> bernoulli)
  Epoch 01 | recon_mse=0.2667
  Epoch 02 | recon_mse=0.2597
  Epoch 03 | recon_mse=0.2547
  Epoch 04 | recon_mse=0.2517
  E