## 2. 代码部分

In [None]:
## 导包部分
import torch.nn as nn
import torch
from torch.utils.data import DataLoader,Dataset
from tqdm import tqdm
import numpy as np
import torchvision.transforms as transforms
import os
import torch.nn.functional as F
import datetime
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from IPython.display import HTML
# 路径：
project_path = os.getcwd()
data_path = os.path.join(project_path, 'data')
result_path = os.path.join(project_path, 'results')
model_path = os.path.join(result_path, 'model')
save_dir = os.path.join(result_path, 'figures')
print(f'[Info] setting project path: \n{project_path}')


### 2.1 网络构建

#### 2.1.1 定义 Residual Conv 网络
Residual Conv 是指利用残差连接将输入和输出跳跃连接起来, 使得梯度可以回传到早期层，防止梯度消失。

In [4]:
class ResidualConvBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size = 3,
            stride = 1,
            padding = 1,
            # 是否启用残差连接
            is_res: bool = False
    ):
        super().__init__()
        # 检查输入和输出的channel数是否相同
        self.same_channels = in_channels == out_channels
        # 是否启用残差连接
        self.is_res = is_res

        # 第一层卷积
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, 
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )

        # 第二层卷积
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=out_channels, 
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )

        # 如果输入和输出的通道数不同，则使用 1x1 卷积来调整输入的通道数
        if not self.same_channels:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        
    def forward(self, x):

        if self.is_res:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            if self.same_channels:
                out = x + x2
            else:
                # 如果输入和输出不是相同的维度, 则通过一个1x1的卷积层将输入和输出的维度保持一致
                shortcut_x = self.shortcut(x)
                out = shortcut_x + x2
            # 如果进行残差连接，假设x和x2是独立同分布的，那么相加之后，新分布out的方差是原来的两倍
            # 需要将新的输出除以根号2，来降低out的方差
            return out / torch.sqrt(torch.tensor(2.0))
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2
            
    def get_out_channels(self):
        return self.conv2[0].out_channels
    
        # Method to set the number of output channels for this block
    def set_out_channels(self, out_channels):
        self.conv1[0].out_channels = out_channels
        self.conv2[0].in_channels = out_channels
        self.conv2[0].out_channels = out_channels
                

#### 2.1.2 定义 UnetDown 网络
下采样网络，将高纬的空间信息通过pool的操作提取特征

In [5]:
class UnetDown(nn.Module):
    # 默认不需要残差连接
    def __init__(self, in_channels, out_channels, is_res = False, pool_kernel_size = 2):
        super().__init__()
        layers = [
            ResidualConvBlock(in_channels, out_channels, is_res=is_res),
            ResidualConvBlock(out_channels, out_channels, is_res=is_res),
            nn.MaxPool2d(kernel_size=pool_kernel_size)
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

#### 2.1.3 定义 UnetDown-embed-UnetUP 中 embed 网络
在标准Unet网络当中，是不需要嵌入的操作的，但是在扩散模型当中会对输入进行逐步的去噪，这个过程中需要结合当前的时间步信息，这种时间步信息通常需要嵌入到一个适合网络处理的维度中，EmbedFC 这样的嵌入层可以用于将时间步 t 嵌入到一个更高维的向量空间中，然后作为额外的条件信息输入到网络中。

In [6]:
# input_dim 定义为 1
class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super().__init__()

        self.input_dim = input_dim

        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # input:(10,256,10,10)
        x = x.view(-1, self.input_dim)
        return self.model(x)


#### 2.1.3 定义 UnetUP 网络


In [7]:
class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels, is_res = False, upsample_size = 2):
        super().__init__()
        layers = [
            # 使用 ConvTranspose2d 对特征图进行上采样，将空间尺寸扩展为原来的两倍。
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=upsample_size, stride=upsample_size),
            ResidualConvBlock(out_channels, out_channels, is_res=is_res),
            ResidualConvBlock(out_channels, out_channels, is_res=is_res)
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        # 将上采样后的特征图与 skip connection 进行拼接
        # 会在特征纬度上进行拼接
        x = torch.cat((x, skip), dim=1)
        
        return self.model(x)

In [8]:
class ContextUnet(nn.Module):
    def __init__(self,
                 in_channels,
                 n_feat = 256,
                 n_cfeat = 10,
                 height = 28,
                 time_input_dim = 1,
                 group_norm_group = 8,
                 out_kernel_size = 3,
                 out_stride = 1,
                 out_padding = 1,
                 # 压缩比例,to_vec 中将特征图的高度压缩的比例
                 # 在up0当中，再将特征图的高度扩展的比例，保持一致
                 compress_ratio = 4):
        super().__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height

        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
        # 通过down1 的操作，将特征图的高度减半，但是特征图的通道不变，类似于特征提取
        self.down1 = UnetDown(n_feat, n_feat)
        # 通过down2 的操作，将特征图的高度减半，但是特征图的通道增加一倍，类似于特征压缩过程
        self.down2 = UnetDown(n_feat, 2 * n_feat)
        # 生成一个形状为 (10, 256, 10, 10) 的张量
        # x = torch.randn(10, 256, 10, 10)
        # print(x.shape) torch.Size([10, 256, 10, 10])
        # reshape_x = nn.AvgPool2d((10))(x)
        # print(reshape_x.shape) torch.Size([10, 256, 1, 1])
        # out = nn.GELU()(reshape_x)
        # print(out.shape) torch.Size([10, 256, 1, 1])
        self.to_vec = nn.Sequential(nn.AvgPool2d((compress_ratio)), nn.GELU())

        self.timeembed1 = EmbedFC(time_input_dim, 2 * n_feat)
        self.timeembed2 = EmbedFC(time_input_dim, 1 * n_feat)

        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)
        # 1. convtranspose2d
        # nn.ConvTranspose2d 是 PyTorch 中用于实现反卷积（也叫转置卷积）的操作。其主要用于上采样输入特征图，使得特征图的空间尺寸增大。
	    # 如果 self.h = 28，那么 kernel_size = 7 和 stride = 7。这种设置会使得输出特征图的空间尺寸大约是输入特征图的七倍。
        # 2. GroupNorm
        #   nn.GroupNorm 并不是计算所有组的均值和方差后统一应用，而是对每一组内的通道进行独立的均值和方差计算，之后用来标准化该组的通道。
        #   每组的均值和方差是独立计算的，标准化也是对每组单独完成的，保证每一组的特征在标准化后均值为 0，方差为 1。
        #   这种分组归一化的方式使得每组的特征能够相互依赖，但不同组之间的归一化操作是独立的。这在保持特征的稳定性的同时，不依赖大批量数据，因此在小批量训练或者单个样本情况下表现很好。
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, compress_ratio, compress_ratio),
            nn.GroupNorm(group_norm_group, 2 * n_feat),
            nn.ReLU()
        )

        self.up1 = UnetUp( 4 * n_feat, n_feat)
        self.up2 = UnetUp( 2 * n_feat, n_feat)

        self.out = nn.Sequential(
            nn.Conv2d( 2 * n_feat, n_feat, 
                      kernel_size=out_kernel_size, 
                      stride=out_stride,
                        padding=out_padding),
            nn.GroupNorm(group_norm_group, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, 
                      self.in_channels, 
                      kernel_size=out_kernel_size,
                      stride=out_stride,
                      padding=out_padding)
        )

    def forward(self, x, t, c=None):
        """
        x : (batch, in_channels, h, w) : input image
        t : (batch, n_cfeat)      : time step
        c : (batch, n_classes)    : context label
        """
        # x : (batch, in_channels, h, w) -> (batch, n_feat, h, w)
        x = self.init_conv(x)
        # pool_size: 2 x: (batch, n_feat, h, w) -> (batch, n_feat, h/2, w/2)
        down1 = self.down1(x)
        # pool_size: 2 x: (batch, n_feat, h/2, w/2) -> (batch, 2*n_feat, h/4, w/4)
        down2 = self.down2(down1)
        # pool_size: 4 x: (batch, 2*n_feat, h/4, w/4) -> (batch, 2*n_feat, h/16, h/16)
        hiddenvec = self.to_vec(down2)
        # context_mask 的概念用于控制是否使用上下文信息。
        # 上下文信息可以是各种条件信息，例如类别标签、额外的特征描述、控制变量等，
        # 在这里通常用来调节网络的行为以适应不同的条件输入。
        # 在 ContextUnet 代码中，context_mask 的作用是动态地选择是否使用上下文条件，
        # 即条件信息是否参与模型的计算。
        # 如果没有传入c，则默认为None，即不使用上下文信息，初始化为全零张量。
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x.device)

        # c: batch_size, n_cfeat -> cemb1: batch_size, 2 * n_feat
        cemb1 = self.contextembed1(c)
        # cemb1: batch_size, 2 * n_feat -> cemb1: batch_size, 2 * n_feat, 1, 1
        cemb1 = cemb1.view(-1, 2 * self.n_feat, 1, 1)
        # t: batch_size, -> temb1: batch_size, 2 * n_feat
        temb1 = self.timeembed1(t)
        # temb1: batch_size, 2 * n_feat -> temb1: batch_size, 2 * n_feat, 1, 1
        temb1 = temb1.view(-1, 2 * self.n_feat, 1, 1)
        # c: batch_size, n_cfeat -> cemb: batch_size, n_feat
        cemb2 = self.contextembed2(c)
        # cemb: batch_size, n_feat -> cemb: batch_size, n_feat, 1, 1
        cemb2 = cemb2.view(-1, self.n_feat, 1, 1)
        # t: batch_size, -> temb: batch_size, n_feat
        temb2 = self.timeembed2(t)
        # temb: batch_size, n_feat -> temb: batch_size, n_feat, 1, 1
        temb2 = temb2.view(-1, self.n_feat, 1, 1)

        # hiddenvec: batch_size, 2 * n_feat, 1, 1 -> up1: batch_size, 2 * n_feat, h/4, w/4
        up1 = self.up0(hiddenvec)
        # cemb * up1: batch_size, 2 * n_feat, h/4, w/4 （broadcast?）
        # temb1: batch_size, 2 * n_feat, 1, 1 
        # cemb1 * up1 + temb1: batch_size, 2 * n_feat, h/4, w/4
        # down2: batch_size, 2 * n_feat, h/4, w/4
        # 在up1当中，先将cemb1 * up1 + temb1 和 down2 进行拼接，
        # 形状变成了(batch_size, 4 * n_feat, h/4, w/4)
        # 然后再进行卷积操作，pool_size: 2， out_channels: n_feat
        # 形状变成了(batch_size, n_feat, h/2, w/2)
        up2 = self.up1(cemb1 * up1 + temb1, down2)
        # 第二层上采样
        # cemb2 * up2 + temb2: batch_size, n_feat, h/2, w/2
        # down1: batch_size, n_feat, h/2, w/2
        # 在up2当中，先将cemb2 * up2 + temb2 和 down1 进行拼接，
        # 形状变成了(batch_size, 2 * n_feat, h/2, w/2)
        # 然后再进行卷积操作，pool_size: 2， out_channels: n_feat
        # 形状变成了(batch_size, n_feat, h, w)
        up3 = self.up2(cemb2 * up2 + temb2, down1)

        # 最后一层卷积
        # up3: batch_size, n_feat, h, w
        # x: batch_size, n_feat, h, w
        # last_cat: batch_size, 2 * n_feat, h, w
        last_cat = torch.cat((up3, x), dim=1)
        # out: 2 * feat -> feat -> in_channels
        # out: batch_size, in_channels, h, w
        out = self.out(last_cat)

        return out



### 2.2 训练代码

#### 2.2.1 参数设置

In [9]:
# 1. DDPM 超参数
# 时间步:
time_steps = 500
# 第一个时间步和最后一个时间步的超声水平，第一步一般比较小，最后一般比较大
beta1 = 1e-4
beta2 = 0.02

# 2. 网络超参数
device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device('cpu'))
# device = torch.device('cpu')
# 每个时间步的隐变量
n_feat = 32
# 上下文变量
n_cfeat = 5
# 图像的长和宽，默认是个正方形
height = 16

# 3. 训练超参数
batch_size = 100
n_epoch = 32
lr = 1e-3

#### 2.2.2 加噪过程

In [10]:
# 在β1 - β2 之间进行线性插值来进行噪声调度: b_t: (time_steps + 1,)
b_t = (beta2 - beta1) * torch.linspace(0, 1, time_steps + 1, device=device) + beta1
# 计算at,即保留率，原始图像的保留比率
a_t = 1 - b_t
# 计算累计噪声因子: loga1 + loga2 + loga3 + ... + logan = log (a1 * a2 * a3 * ... * an), 再取exp变为原值
# 计算的是在时间步，之前累计的原始图像的保有率是多少，例如： 1：80%, 2: 50%, 累计：80% * 50% = 40% 
ab_t = torch.cumsum(a_t.log(), dim=0).exp()
# 第一步为原始图像
ab_t[0] = 1

# 加噪函数： 向图像中加入高斯分布采样的噪声
# 加噪公式：ab_t * x + (1 - ab_t) * noise
# x shape: (batch_size, height, width, channels)
def perturb_input(x, t, noise):
    # [t, None, None, None]是广播机制，使得张量形状相同
    return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t)[t, None, None, None] * noise

#### 2.2.3 去噪模型

In [11]:
nn_model = ContextUnet(
    in_channels=3,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
    height=height
).to(device)

### 2.3 训练

#### 2.3.1 数据准备

In [22]:
class CustomDataset(Dataset):
    def __init__(self, sfilename, lfilename, transform, null_context = False):
        super().__init__()
        print('[Info] Loading dataset...')
        self.sprites = np.load(sfilename)
        self.slabels = np.load(lfilename)
        print('[Info] Loading success!')
        print(f'sprite shape: {self.sprites.shape}')
        print(f'labels shape: {self.slabels.shape}')

        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabels_shape = self.slabels.shape
    # 定义一些dataset 必备的方法：len/getitem
    def __len__(self):
        return len(self.sprites)
    
    def __getitem__(self, index):
        if self.transform:
            image = self.transform(self.sprites[index])
            if self.null_context:
                label = torch.tensor(0).to(torch.int64)
            else:
                label = torch.tensor(self.slabels[index]).to(torch.int64)
        else:
            # print('[Info] Without transform, you need check the image dtype and dimension order: (C H W) in torch')
            image = torch.tensor(self.sprites[index])
            label = torch.tensor(self.slabels[index],dtype=torch.int64)
        
        return (image, label)
    
    def getshapes(self):
        return self.sprites_shape, self.slabels_shape

In [23]:

# 对图像进行的预处理操作
transform = transforms.Compose([
    # 将图像从 NumPy 数组或 PIL 图像 转换为 PyTorch 张量
    # from [0,255] to range [0.0,1.0]
    # 自动改变图像的维度顺序，从 (H, W, C) -> (C, H, W)
    transforms.ToTensor(), 
    # 将每个通道的均值减去 0.5，然后除以标准差 0.5。
    # 将输入张量的数值范围从 [0.0, 1.0] 转换到 [-1, 1]              
    transforms.Normalize((0.5,), (0.5,))  # range [-1,1]
])

dataset = CustomDataset(
    sfilename=os.path.join(data_path, 'sprites_1788_16x16.npy'),
    lfilename=os.path.join(data_path, 'sprite_labels_nc_1788_16x16.npy'),
    transform=transform,
    null_context=False
)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1
)
optim = torch.optim.Adam(nn_model.parameters(), lr=lr)


[Info] Loading dataset...
[Info] Loading success!
sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


#### 2.3.2 开始训练

In [24]:
print('[Info] Start training...')
print(f'[Info] device: {device}')
print(f'[Info] batch_size: {batch_size}')
print(f'[Info] n_epoch: {n_epoch}')
print(f'[Info] lr: {lr}')
print(f'[Info] time_steps: {time_steps}')
print(f'[Info] beta1: {beta1}')
print(f'[Info] beta2: {beta2}')
print(f'[Info] n_feat: {n_feat}')
print(f'[Info] n_cfeat: {n_cfeat}')
print(f'[Info] height: {height}')
print(f'[Info] model total params: {sum(p.numel() for p in nn_model.parameters())}')
print(f'[Info] model total trainable params: {sum(p.numel() for p in nn_model.parameters() if p.requires_grad)}')
for ep in range(n_epoch):
    loss_epoch = []
    print(f'[Info] epoch: {ep}/{n_epoch}')
    # 动态学习率
    optim.param_groups[0]['lr'] = lr * (1-ep/n_epoch)
    
    # 创建进度条
    pbar = tqdm(dataloader, mininterval=2)
    for x, _ in pbar:
        optim.zero_grad()

        x = x.to(device)
        # 使用 torch.randn_like(x) 来生成一个与输入张量 x 形状相同服从标准正态分布（均值为 0，方差为 1）的随机噪声张量。
        noise = torch.randn_like(x)
        # 对于每一个样本 x，随机选择一个时间步 t，然后对 x 进行加噪，生成扰动图像 x_t，
        # 模型的目标是从 x_t 预测出原始的图像或中间步骤的干净图像
        t = torch.randint(1, time_steps + 1, (x.shape[0],)).to(device)
        x_pert = perturb_input(x, t, noise).to(device)
        pred_noise = nn_model(
            x_pert,
            # 将时间步进行归一化
            t / time_steps
        )
        # 回归问题，用MSE error
        loss = F.mse_loss(pred_noise, noise)
        loss_epoch.append(loss.item())
        loss.backward()
        optim.step()
    # 计算平均损失
    mean_loss = np.mean(loss_epoch)
    print(f'[Info] epoch: {ep}/{n_epoch}, loss: {mean_loss}')
    if ep == int(n_epoch - 1):
        if not os.path.exists(model_path):
            os.mkdir(model_path)
        current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        model_save_path = os.path.join(model_path, f'model_on_epoch_{ep}_{current_time}')
        torch.save(nn_model.state_dict(), model_save_path)
        print(f'[Info] model save at: {model_save_path}')
    

[Info] Start training...
[Info] device: cuda:0
[Info] batch_size: 100
[Info] n_epoch: 32
[Info] lr: 0.001
[Info] time_steps: 500
[Info] beta1: 0.0001
[Info] beta2: 0.02
[Info] n_feat: 32
[Info] n_cfeat: 5
[Info] height: 16
[Info] model total params: 375011
[Info] model total trainable params: 375011
[Info] epoch: 0/32


100%|██████████| 894/894 [00:16<00:00, 53.18it/s]


[Info] epoch: 0/32, loss: 0.23933626424679552
[Info] epoch: 1/32


100%|██████████| 894/894 [00:16<00:00, 54.74it/s]


[Info] epoch: 1/32, loss: 0.1836802214884118
[Info] epoch: 2/32


100%|██████████| 894/894 [00:17<00:00, 52.04it/s]


[Info] epoch: 2/32, loss: 0.17116769653238706
[Info] epoch: 3/32


100%|██████████| 894/894 [00:16<00:00, 53.73it/s]


[Info] epoch: 3/32, loss: 0.16452340353321976
[Info] epoch: 4/32


100%|██████████| 894/894 [00:17<00:00, 52.37it/s]


[Info] epoch: 4/32, loss: 0.15861609028923165
[Info] epoch: 5/32


100%|██████████| 894/894 [00:16<00:00, 54.41it/s]


[Info] epoch: 5/32, loss: 0.15482068496029117
[Info] epoch: 6/32


100%|██████████| 894/894 [00:15<00:00, 56.86it/s]


[Info] epoch: 6/32, loss: 0.1513781051507732
[Info] epoch: 7/32


100%|██████████| 894/894 [00:15<00:00, 58.04it/s]


[Info] epoch: 7/32, loss: 0.148964546029133
[Info] epoch: 8/32


100%|██████████| 894/894 [00:16<00:00, 54.70it/s]


[Info] epoch: 8/32, loss: 0.1451976345956192
[Info] epoch: 9/32


100%|██████████| 894/894 [00:16<00:00, 55.48it/s]


[Info] epoch: 9/32, loss: 0.14403079610142933
[Info] epoch: 10/32


100%|██████████| 894/894 [00:16<00:00, 54.95it/s]


[Info] epoch: 10/32, loss: 0.14345418669333393
[Info] epoch: 11/32


100%|██████████| 894/894 [00:16<00:00, 53.93it/s]


[Info] epoch: 11/32, loss: 0.14049297799533378
[Info] epoch: 12/32


100%|██████████| 894/894 [00:15<00:00, 58.06it/s]


[Info] epoch: 12/32, loss: 0.13853003588032137
[Info] epoch: 13/32


100%|██████████| 894/894 [00:17<00:00, 50.25it/s]


[Info] epoch: 13/32, loss: 0.137009203183971
[Info] epoch: 14/32


100%|██████████| 894/894 [00:17<00:00, 50.88it/s]


[Info] epoch: 14/32, loss: 0.13591475523924934
[Info] epoch: 15/32


100%|██████████| 894/894 [00:16<00:00, 52.93it/s]


[Info] epoch: 15/32, loss: 0.1350219123798862
[Info] epoch: 16/32


100%|██████████| 894/894 [00:16<00:00, 55.12it/s]


[Info] epoch: 16/32, loss: 0.1333006476395882
[Info] epoch: 17/32


100%|██████████| 894/894 [00:16<00:00, 54.22it/s]


[Info] epoch: 17/32, loss: 0.13201285110290686
[Info] epoch: 18/32


100%|██████████| 894/894 [00:16<00:00, 54.83it/s]


[Info] epoch: 18/32, loss: 0.1312951760354058
[Info] epoch: 19/32


100%|██████████| 894/894 [00:16<00:00, 52.85it/s]


[Info] epoch: 19/32, loss: 0.12919924982408815
[Info] epoch: 20/32


100%|██████████| 894/894 [00:16<00:00, 55.41it/s]


[Info] epoch: 20/32, loss: 0.12903485235485188
[Info] epoch: 21/32


100%|██████████| 894/894 [00:16<00:00, 55.67it/s]


[Info] epoch: 21/32, loss: 0.1275817568870199
[Info] epoch: 22/32


100%|██████████| 894/894 [00:15<00:00, 56.68it/s]


[Info] epoch: 22/32, loss: 0.1269939245460284
[Info] epoch: 23/32


100%|██████████| 894/894 [00:16<00:00, 54.92it/s]


[Info] epoch: 23/32, loss: 0.1270334833453699
[Info] epoch: 24/32


100%|██████████| 894/894 [00:17<00:00, 52.21it/s]


[Info] epoch: 24/32, loss: 0.12454572445827576
[Info] epoch: 25/32


100%|██████████| 894/894 [00:16<00:00, 54.84it/s]


[Info] epoch: 25/32, loss: 0.12442043757428659
[Info] epoch: 26/32


100%|██████████| 894/894 [00:17<00:00, 52.39it/s]


[Info] epoch: 26/32, loss: 0.1236797249990555
[Info] epoch: 27/32


100%|██████████| 894/894 [00:16<00:00, 53.10it/s]


[Info] epoch: 27/32, loss: 0.12250499064790323
[Info] epoch: 28/32


100%|██████████| 894/894 [00:16<00:00, 54.94it/s]


[Info] epoch: 28/32, loss: 0.12220784137546349
[Info] epoch: 29/32


100%|██████████| 894/894 [00:15<00:00, 56.58it/s]


[Info] epoch: 29/32, loss: 0.1222489106425103
[Info] epoch: 30/32


100%|██████████| 894/894 [00:16<00:00, 55.74it/s]


[Info] epoch: 30/32, loss: 0.12156609768335451
[Info] epoch: 31/32


100%|██████████| 894/894 [00:16<00:00, 54.30it/s]

[Info] epoch: 31/32, loss: 0.12147338491691573
[Info] model save at: /public/share/sd23/d2l/AI_study/5DiffusionModel/1homework/results/model/model_on_epoch_31_2024-10-11_22-14-34





### 3. 采样过程
采样过程是在一张全部是噪声的图像当中，不断减去预测的噪声，从而恢复原始的图像。

#### 3.1 去噪函数
定义一个从当前时间步的图像当中减去预测噪声的函数, 这一步的算法等同于原文的 Algorithm2：Sampling

In [12]:
def denoise_from_current_timestep(x, t, pred_noise, z = None):
    '''
    x: 原始图像
    t: 当前时间步
    pred_noise: 预测的噪声
    z: bias, 对应DDPM公式中 Samping 中的sigma_t_z
    '''
    if z is None:
        z = torch.randn_like(x)
    # 通过当前时间步 t 和预测的噪声 pred_noise，计算出当前时间步的图像
    noise = b_t.sqrt()[t] * z
    mean = (x - ((1 - a_t[t]) / (1 - ab_t[t]).sqrt()) * pred_noise) / a_t[t].sqrt()

    return mean + noise

#### 3.2 采样过程

In [37]:
print(time_steps)
@torch.no_grad()
def sample_ddpm(n_sample, save_rate = 20):
    '''
    n_sample: 采样数量
    save_rate: 保存采样图像的频率
    '''
    # torch.randn 用于生成服从标准正态分布（均值为0，标准差为1）的随机数张量的函数。
    samples = torch.randn(n_sample, 3, height, height).to(device)

    intermediate = []

    for i in range(time_steps, 0, -1):
        print(f'[Info] sampling at time step: {i}')
        
        # 将 t 转变成模型训练时输入的 张量形式 归一化
        t = torch.tensor([i / time_steps])[: , None, None, None].to(device)

        # 按照原文的算法2, 生成z
        z = torch.randn_like(samples) if i > 1 else 0
        # 预测当前时间步的噪声
        pred_noise = nn_model(samples, t)
        # 通过当前时间步 t 和预测的噪声 pred_noise，计算出当前时间步的图像
        samples = denoise_from_current_timestep(samples, i, pred_noise, z)

        if i % save_rate == 0 or i == time_steps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)

    return samples, intermediate

500


#### 3.3 采样实例

In [38]:
# 加载模型
model_save_path = '/public/share/sd23/d2l/AI_study/5DiffusionModel/1homework/results/model/model_on_epoch_31_2024-10-11_22-14-34'
nn_model.load_state_dict(torch.load(model_save_path, map_location=device))
nn_model.eval()
print(f'[Info] load model from: {model_save_path}')
print('[Info] model: ')
print(nn_model)

[Info] load model from: /public/share/sd23/d2l/AI_study/5DiffusionModel/1homework/results/model/model_on_epoch_31_2024-10-11_22-14-34
[Info] model: 
ContextUnet(
  (init_conv): ResidualConvBlock(
    (conv1): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
    )
    (conv2): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
    )
    (shortcut): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (down1): UnetDown(
    (model): Sequential(
      (0): ResidualConvBlock(
        (conv1): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_runni

In [39]:
# 定义可视化函数
# 将生成的数据 x_gen_store 以动态图（GIF）的形式可视化出来
def norm_all(store, n_t, n_s):
    # runs unity norm on all timesteps of all samples
    nstore = np.zeros_like(store)
    for t in range(n_t):
        for s in range(n_s):
            nstore[t,s] = unorm(store[t,s])
    return nstore
def unorm(x):
    # unity norm. results in range of [0,1]
    # assume x (h,w,3)
    # print(x.shape)
    # print(x[:,:,0].max())
    # print(x.max().shape)
    # print(x.max((0,1)))
    xmax = x.max((0,1))
    xmin = x.min((0,1))
    return (x - xmin)/(xmax - xmin)
def plot_sample(x_gen_store, n_sample, nrows, save_dir, file_name, w, save = False):
    '''
    x_gen_store: 生成的图像矩阵
    n_sample: 生成的样本数量
    nrows: 每行显示的样本数量
    save_dir: 保存路径
    file_name: 保存文件名
    w: 标识该 GIF 所基于的特定值
    save: 是否保存
    '''
    n_cols = n_sample // nrows
    # 将数据从 (channels, height, width) 格式转换为 (height, width, channels) 格式，这是为了适应 matplotlib 的图像展示格式（h, w, channels）
    print(f'[Info] x_gen_store shape: {x_gen_store.shape}')
    switch_x_gen_store = np.moveaxis(x_gen_store, 2, 4)
    print(f'[Info] sx_gen_store shape: {switch_x_gen_store.shape}')
    norm_sx_gen_store = norm_all(switch_x_gen_store, switch_x_gen_store.shape[0], n_sample)
    print(f'[Info] norm_sx_gen_store shape: {norm_sx_gen_store.shape}')
    fig, axs = plt.subplots(nrows=nrows, ncols=n_cols, sharex=True, sharey=True,figsize=(n_cols,nrows))
    def animate_diff(i, store):
        plots = []
        for row in range(nrows):
            for col in range(n_cols):
                axs[row, col].clear()
                axs[row, col].set_xticks([])
                axs[row, col].set_yticks([])
                plots.append(axs[row, col].imshow(store[i,(row*n_cols)+col]))
        return plots
    ani = FuncAnimation(fig, animate_diff, fargs=[norm_sx_gen_store],  interval=200, blit=False, repeat=True, frames=norm_sx_gen_store.shape[0]) 
    plt.close()
    if save:
        ani.save(save_dir + f"{file_name}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + save_dir + f"{file_name}_w{w}.gif")
    return ani

In [40]:
print('[Info] Start sampling...')
plt.clf()
samples, intermediate_ddpm = sample_ddpm(32)
print(samples.shape)
print(intermediate_ddpm.shape)
animation_ddpm = plot_sample(intermediate_ddpm,32,4,save_dir, "ani_run", None, save=False)
HTML(animation_ddpm.to_jshtml())

[Info] Start sampling...
[Info] sampling at time step: 500
[Info] sampling at time step: 499
[Info] sampling at time step: 498
[Info] sampling at time step: 497
[Info] sampling at time step: 496
[Info] sampling at time step: 495
[Info] sampling at time step: 494
[Info] sampling at time step: 493
[Info] sampling at time step: 492
[Info] sampling at time step: 491
[Info] sampling at time step: 490
[Info] sampling at time step: 489
[Info] sampling at time step: 488
[Info] sampling at time step: 487
[Info] sampling at time step: 486
[Info] sampling at time step: 485
[Info] sampling at time step: 484
[Info] sampling at time step: 483
[Info] sampling at time step: 482
[Info] sampling at time step: 481
[Info] sampling at time step: 480
[Info] sampling at time step: 479
[Info] sampling at time step: 478
[Info] sampling at time step: 477
[Info] sampling at time step: 476
[Info] sampling at time step: 475
[Info] sampling at time step: 474
[Info] sampling at time step: 473
[Info] sampling at time

<Figure size 640x480 with 0 Axes>