In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from tqdm import tqdm

# Define the DIP model
class DIP(nn.Module):
    """Six-layer fully convolutional network mapping a 1-channel image to a 1-channel image.

    Channel widths go 1 -> 32 -> 64 -> 128 -> 64 -> 32 -> 1. Every layer is a
    3x3 convolution with padding 1 (default stride), so the spatial size of
    the input is preserved. ReLU follows every convolution except the last,
    leaving the output unbounded.

    NOTE(review): "DIP" presumably stands for Deep Image Prior; the file only
    ever trains this network as an image-to-image reconstructor.
    """

    def __init__(self):
        super().__init__()
        # Widening half of the network.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        # Narrowing half, back down to a single output channel.
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the five ReLU-activated layers, then the linear output layer."""
        activated_layers = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        features = x
        for conv in activated_layers:
            features = self.relu(conv(features))
        # No activation on the final layer.
        return self.conv6(features)

# Define the DDPM model
class DDPM(nn.Module):
    """Forward (noising) step of a diffusion process with a linear beta schedule.

    The module has no learnable parameters. It precomputes, for ``T`` steps:

    * ``beta``      -- ``T`` values linearly spaced from ``beta_start`` to ``beta_end``
    * ``alpha``     -- ``1 - beta``
    * ``alpha_hat`` -- cumulative product of ``alpha``

    The schedule tensors are registered as non-persistent buffers so that
    ``.to(device)``, ``.cuda()`` and ``.double()`` move/cast them together
    with the module, while ``state_dict()`` stays empty (as it was when they
    were plain attributes).

    Args:
        T: Number of diffusion steps.
        beta_start: First value of the beta schedule.
        beta_end: Last value of the beta schedule.
    """

    def __init__(self, T, beta_start=0.0001, beta_end=0.02):
        super().__init__()
        self.T = T
        self.register_buffer("beta", torch.linspace(beta_start, beta_end, T), persistent=False)
        self.register_buffer("alpha", 1.0 - self.beta, persistent=False)
        self.register_buffer("alpha_hat", torch.cumprod(self.alpha, 0), persistent=False)

    def forward(self, x0):
        """Return a noised copy of ``x0`` at a random step per sample.

        For each sample along dim 0 a step ``t`` is drawn uniformly from
        ``[0, T)`` and Gaussian noise ``epsilon`` is drawn with the shape of
        ``x0``; the result is
        ``sqrt(alpha_hat[t]) * x0 + sqrt(1 - alpha_hat[t]) * epsilon``.

        Only the noised tensor is returned; the sampled ``t`` and ``epsilon``
        are discarded.

        Args:
            x0: Tensor whose first dimension is the batch dimension; any
                number of trailing dimensions is accepted.

        Returns:
            Tensor with the same shape as ``x0``.
        """
        # Draw t first, then epsilon: keeps RNG consumption order stable.
        t = torch.randint(0, self.T, (x0.size(0),), device=x0.device)
        epsilon = torch.randn_like(x0)
        # One coefficient per sample, reshaped to broadcast over all trailing dims.
        broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)
        alpha_hat_t = self.alpha_hat[t].view(broadcast_shape)
        return torch.sqrt(alpha_hat_t) * x0 + torch.sqrt(1 - alpha_hat_t) * epsilon

# Load the data
# Each image is converted to a tensor and then normalized with mean 0.5 and
# std 0.5 on its single channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
# FashionMNIST training split, downloaded into ./data when missing.
dataset = FashionMNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
# Reshuffled batches of 128 samples.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Train the DIP model
def train_DIP(dip_model, dataloader, epochs=2, lr=0.001):
    """Train ``dip_model`` to reconstruct its own input with an MSE loss.

    Each batch is fed through the model and the output is compared against
    the same batch (autoencoder-style target). After every epoch the output
    for the last batch of that epoch is written to
    ``output_DIP_epoch<N>.png`` in the current working directory.

    Args:
        dip_model: Module to optimise; must have at least one parameter.
        dataloader: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        The same ``dip_model`` instance, trained in place.
    """
    optimizer = optim.Adam(dip_model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    # Adam has already rejected a parameter-less model, so next() cannot fail.
    device = next(dip_model.parameters()).device
    dip_model.train()
    for epoch in range(epochs):
        progress_bar = tqdm(dataloader, desc=f'Epoch [{epoch+1}/{epochs}]')
        # Stays None when the dataloader yields nothing, so the snapshot
        # below is skipped instead of raising NameError.
        output = None
        for img, _ in progress_bar:
            # Keep the batch on the same device as the model's weights.
            img = img.to(device)
            optimizer.zero_grad()
            output = dip_model(img)
            loss = criterion(output, img)
            loss.backward()
            optimizer.step()
            progress_bar.set_postfix(loss=loss.item())
        if output is not None:
            # NOTE(review): save_image is called without normalize=True, so
            # values outside [0, 1] are presumably clipped in the PNG; the
            # file's data is normalised with mean/std 0.5 -- confirm whether
            # the snapshot should be rescaled.
            save_image(output.detach(), f'output_DIP_epoch{epoch+1}.png')
    return dip_model

# Use the DIP model's output as the initial prior for the DDPM
def train_DDPM(ddpm_model, dip_model, dataloader, epochs=2):
    """Pass DIP reconstructions through ``ddpm_model`` and snapshot the result.

    For every batch, ``dip_model`` is evaluated without gradient tracking and
    its output is handed to ``ddpm_model``. No loss is computed and no
    optimiser step is taken here, so despite the name nothing is trained.
    After each epoch the ``ddpm_model`` output for the last batch is written
    to ``output_DDPM_epoch<N>.png`` in the current working directory.

    Args:
        ddpm_model: Callable applied to the DIP output.
        dip_model: Module producing the tensor fed to ``ddpm_model``.
        dataloader: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        epochs: Number of passes over ``dataloader``.

    Returns:
        None.
    """
    # Used only to find the device the DIP weights live on (None if it has none).
    first_param = next(dip_model.parameters(), None)
    for epoch in range(epochs):
        progress_bar = tqdm(dataloader, desc=f'Epoch [{epoch+1}/{epochs}]')
        # Stays None when the dataloader yields nothing.
        ddpm_output = None
        # Iterating the tqdm wrapper already advances the bar; no manual
        # update() call is needed.
        for img, _ in progress_bar:
            if first_param is not None:
                img = img.to(first_param.device)
            # The DIP output is used as fixed data; skip building its graph.
            with torch.no_grad():
                dip_output = dip_model(img)
            ddpm_output = ddpm_model(dip_output)
        # Write the snapshot once per epoch rather than once per batch; the
        # file ends up holding the last batch either way.
        if ddpm_output is not None:
            save_image(ddpm_output, f'output_DDPM_epoch{epoch+1}.png')

# Build the DIP network and train it on the dataset
dip_model = train_DIP(DIP(), dataloader, epochs=2)

# Run the trained DIP model's outputs through the DDPM as its initial prior
ddpm_model = DDPM(T=500)
train_DDPM(ddpm_model, dip_model, dataloader, epochs=2)


Epoch [1/2]: 100%|██████████| 469/469 [13:03<00:00,  1.67s/it, loss=0.00062]
Epoch [2/2]: 100%|██████████| 469/469 [12:41<00:00,  1.62s/it, loss=0.000806]
Epoch [1/2]: 100%|██████████| 469/469 [04:46<00:00,  1.64it/s]
Epoch [2/2]: 100%|██████████| 469/469 [04:46<00:00,  1.64it/s]
