In [None]:
import os, random
import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Seed Python's RNG and PyTorch's CPU and CUDA generators with the same value
# so repeated runs of the notebook start from the same random state.
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

# Presumably a workaround for the "duplicate OpenMP runtime" abort seen on some
# installs. NOTE(review): this is set after torch/matplotlib were imported; if
# it is meant to affect library loading it may need to move above the imports.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Run on the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: ``in_dim -> hidden`` (ReLU), followed by two linear heads that
    output the latent mean and log-variance. Decoder: ``d -> hidden`` (ReLU)
    ``-> in_dim`` with a sigmoid on the output.

    Args:
        d: Size of the latent vector.
        in_dim: Size of each flattened input vector (default 784).
        hidden: Width of the hidden layer used by both encoder and decoder
            (default 400).

    The defaults reproduce the previously hard-coded 784-400-d-400-784 layout.
    Layer attribute names and creation order are unchanged, so ``state_dict``
    keys and seeded weight initialisation stay the same.
    """

    def __init__(self, d=20, in_dim=784, hidden=400):
        super().__init__()
        # Encoder trunk and its two heads (mean / log-variance).
        self.fc1 = nn.Linear(in_dim, hidden)
        self.mu = nn.Linear(hidden, d)
        self.lv = nn.Linear(hidden, d)
        # Decoder.
        self.fc3 = nn.Linear(d, hidden)
        self.fc4 = nn.Linear(hidden, in_dim)

    def encode(self, x):
        """Map inputs whose last dimension is ``in_dim`` to ``(mu, lv)``."""
        h = torch.relu(self.fc1(x))
        return self.mu(h), self.lv(h)

    def reparam(self, mu, lv):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, 1)``.

        ``lv`` is treated as a log-variance, hence ``std = exp(0.5 * lv)``.
        Noise is sampled on every call, in both train and eval mode.
        """
        std = torch.exp(0.5 * lv)
        return mu + torch.randn_like(std) * std

    def decode(self, z):
        """Map latent vectors to reconstructions in (0, 1) via a sigmoid."""
        return torch.sigmoid(self.fc4(torch.relu(self.fc3(z))))

    def forward(self, x):
        """Encode, sample and decode; returns ``(xh, mu, lv)``."""
        mu, lv = self.encode(x)
        z = self.reparam(mu, lv)
        xh = self.decode(z)
        return xh, mu, lv
def loss_fn(xh, x, mu, lv):
    """Return the summed VAE objective: reconstruction BCE plus KL divergence.

    Args:
        xh: Reconstruction; expected to hold probabilities for
            ``binary_cross_entropy``.
        x: Target tensor with the same shape as ``xh``.
        mu: Latent means.
        lv: Latent log-variances.

    Returns:
        Scalar tensor. Both terms are summed over every element (no averaging).
    """
    reconstruction = F.binary_cross_entropy(input=xh, target=x, reduction="sum")
    # Closed-form KL between N(mu, exp(lv)) and a standard normal, summed.
    kl_terms = 1 + lv - mu.pow(2) - lv.exp()
    divergence = -0.5 * kl_terms.sum()
    return reconstruction + divergence

In [None]:
# Hyperparameters: batch size, number of epochs, latent size, Adam learning rate.
batch = 128
epochs = 5
d = 20
lr = 1e-3

# MNIST train/test splits, converted to tensors and downloaded on first use.
tr = datasets.MNIST("/content/mnist", train=True, download=True, transform=transforms.ToTensor())
te = datasets.MNIST("/content/mnist", train=False, download=True, transform=transforms.ToTensor())

# Pinned host memory only benefits host-to-GPU copies; requesting it on a
# CPU-only machine makes DataLoader emit a warning, so enable it only for CUDA.
pin = device == "cuda"
trL = DataLoader(tr, batch_size=batch, shuffle=True, num_workers=2, pin_memory=pin)
teL = DataLoader(te, batch_size=batch, shuffle=False, num_workers=2, pin_memory=pin)

m = VAE(d).to(device)
opt = optim.Adam(m.parameters(), lr=lr)

for e in range(1, epochs + 1):
    m.train()
    tot = 0.0  # running sum of the sum-reduced batch losses
    for x, _ in trL:
        # Flatten each image to a vector for the fully connected model.
        x = x.to(device).view(x.size(0), -1)
        opt.zero_grad()
        xh, mu, lv = m(x)
        L = loss_fn(xh, x, mu, lv)
        L.backward()
        opt.step()
        tot += L.item()
    # loss_fn sums over samples, so dividing by the dataset size yields a
    # per-sample average for the epoch.
    print(f"Epoch {e}, Average loss: {tot/len(tr):.4f}")
print("訓練完成！")

In [None]:
m.eval()

# Take the first batch from the test loader; the labels are not needed.
x, _ = next(iter(teL))
x = x.to(device)

# Reconstruct without tracking gradients; the model takes flattened images.
with torch.no_grad():
    xh, _, _ = m(x.view(x.size(0), -1))

# Back to image layout on the CPU for plotting.
xh = xh.view(-1, 1, 28, 28).cpu().clamp(0, 1)
x = x.cpu()

# Top row: originals. Bottom row: reconstructions. First 8 samples of the batch.
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for i in range(8):
    for row, images in enumerate((x, xh)):
        ax = axes[row, i]
        ax.imshow(images[i, 0], cmap="gray")
        ax.axis("off")

plt.tight_layout()
plt.savefig("reconstruction.png", dpi=150)
plt.show()
print("已儲存重建圖檔 reconstruction.png")