In [1]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve  # 生成S形二维数据点 https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_s_curve.html
import torch
import torch.nn as nn
from tqdm import tqdm

## ----------------------------- 1、生成数据，(10000, 2)的数据点集，组成一个S形 ----------------------------- ##
s_curve, _ = make_s_curve(10 ** 4, noise=0.1)  # 生成10000个数据点，形状为S形并且带有噪声，shape为(10000,3)，形状是3维的
s_curve = s_curve[:, [0, 2]] / 10.0 # 选择数据的第一列和第三列，并进行缩放
print("shape of s:", np.shape(s_curve))
dataset = torch.Tensor(s_curve).float()

shape of s: (10000, 2)


In [2]:
## ----------------------------- 2、确定超参数的值 ----------------------------- ##
# 采样时间步总长度 t
num_steps = 100
 
# 制定每一步的beta
betas = torch.linspace(-6, 6, num_steps) # 在-6到6之间生成100个等间距的值
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5 # 将betas缩放到合适的范围
 
# 计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1 - betas # 计算每一步的alpha值
alphas_prod = torch.cumprod(alphas, 0) # 每个t时刻的alpha值的累积乘积
# alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)
alphas_bar_sqrt = torch.sqrt(alphas_prod) # 计算累积乘积的平方根
one_minus_alphas_bar_log = torch.log(1 - alphas_prod) # 计算1减去累积乘积的对数
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod) # 计算1减去累积乘积的平方根

In [3]:
## ----------------------------- 3、确定扩散前向过程任意时刻的采样值 x[t]： x[0] + t --> x[t] ----------------------------- ##此代码并未使用这个
def q_x(x_0, t):
    """
    x[0] + t --> x[t]
    :param x_0:初始数据
    :param t:任意时刻
    :return:
    """
    noise = torch.randn_like(x_0)
    alphas_t = alphas_bar_sqrt[t]
    alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
    x_t = alphas_t * x_0 + alphas_1_m_t * noise
    return x_t


In [4]:
## ----------------------------- 4、编写求逆扩散过程噪声的模型U-Net（这里使用的是MLP模拟U-Net，官方使用的是U-Net） x[t] + t --> noise_predict----------------------------- ##预测噪声
class MLPDiffusion(nn.Module):
    def __init__(self, n_steps, num_units=128):
        super(MLPDiffusion, self).__init__()
 
        self.linears = nn.ModuleList(
            [
                nn.Linear(2, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, 2),
            ]
        )
        self.step_embeddings = nn.ModuleList(
            [
                nn.Embedding(n_steps, num_units),
                nn.Embedding(n_steps, num_units),
                nn.Embedding(n_steps, num_units),
            ]
        )
 
    def forward(self, x, t):
        #  x = x[0]
        for idx, embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t)
            x = self.linears[2 * idx](x)
            x += t_embedding
            x = self.linears[2 * idx + 1](x)
        x = self.linears[-1](x)
 
        return x


In [5]:
## ----------------------------- 损失函数 = 真实噪声eps与预测出的噪声noise_predict 之间的loss ----------------------------- ##
def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps):
    """对任意时刻t进行采样计算loss"""
    batch_size = x_0.shape[0]
 
    # 对一个batchsize样本生成随机的时刻t, t的形状是torch.Size([batchsize, 1])
    t = torch.randint(0, n_steps, size=(batch_size // 2,)) # 随机生成时间步t，一半时间
    t = torch.cat([t, n_steps - 1 - t], dim=0) # 创建对称的时间步
    t = t.unsqueeze(-1) # 添加一个维度，使t的形状为(batch_size, 1)
 
    ## 1) 根据 alphas_bar_sqrt, one_minus_alphas_bar_sqrt --> 得到任意时刻t的采样值x[t]
    # x0的系数
    a = alphas_bar_sqrt[t] # 获取时间步t对应的alphas_bar_sqrt值
    # 噪声eps的系数
    aml = one_minus_alphas_bar_sqrt[t] # 获取时间步t对应的one_minus_alphas_bar_sqrt值
    # 生成生成与x_0形状相同的随机噪声e
    e = torch.randn_like(x_0)
    # 计算任意时刻t的采样值
    x = x_0 * a + e * aml
 
    ## 2) x[t]送入U-Net模型，得到t时刻的随机噪声预测值，这里是用UNet直接预测噪声，输入网络的参数是加上噪声的图像和时间t，网络返回预测所加的噪声
    output = model(x, t.squeeze(-1))
 
    ## 3)计算真实噪声eps与预测出的噪声之间的loss
    loss = (e - output).square().mean()
    return loss

In [6]:
## ----------------------------- 训练模型 ----------------------------- ##

if __name__ == "__main__":
    print('Training model...')
    batch_size = 128
    num_epoch = 4000
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    model = MLPDiffusion(num_steps)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    for t in tqdm(range(num_epoch),desc="Traing epoch"):
        for idx, batch_x in enumerate(dataloader):
            loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            optimizer.step()
    
        if (t % 100 == 0):
            print(loss)
            torch.save(model.state_dict(), 'model_{}.pth'.format(t))


Training model...


Traing epoch:   0%|                                                                   | 1/4000 [00:00<26:15,  2.54it/s]

tensor(0.6277, grad_fn=<MeanBackward0>)


Traing epoch:   3%|█▋                                                               | 101/4000 [00:50<33:59,  1.91it/s]

tensor(0.7523, grad_fn=<MeanBackward0>)


Traing epoch:   5%|███▎                                                             | 201/4000 [01:45<40:31,  1.56it/s]

tensor(0.8193, grad_fn=<MeanBackward0>)


Traing epoch:   8%|████▉                                                            | 301/4000 [02:36<24:34,  2.51it/s]

tensor(0.4371, grad_fn=<MeanBackward0>)


Traing epoch:  10%|██████▌                                                          | 401/4000 [03:17<23:39,  2.54it/s]

tensor(0.2920, grad_fn=<MeanBackward0>)


Traing epoch:  13%|████████▏                                                        | 501/4000 [03:56<22:51,  2.55it/s]

tensor(0.4311, grad_fn=<MeanBackward0>)


Traing epoch:  15%|█████████▊                                                       | 601/4000 [04:35<22:17,  2.54it/s]

tensor(0.3186, grad_fn=<MeanBackward0>)


Traing epoch:  18%|███████████▍                                                     | 701/4000 [05:14<21:37,  2.54it/s]

tensor(0.4662, grad_fn=<MeanBackward0>)


Traing epoch:  20%|█████████████                                                    | 801/4000 [05:54<21:24,  2.49it/s]

tensor(0.2823, grad_fn=<MeanBackward0>)


Traing epoch:  23%|██████████████▋                                                  | 901/4000 [06:34<21:05,  2.45it/s]

tensor(0.3782, grad_fn=<MeanBackward0>)


Traing epoch:  25%|████████████████                                                | 1001/4000 [07:20<18:50,  2.65it/s]

tensor(0.3412, grad_fn=<MeanBackward0>)


Traing epoch:  28%|█████████████████▌                                              | 1101/4000 [07:58<18:34,  2.60it/s]

tensor(0.2555, grad_fn=<MeanBackward0>)


Traing epoch:  30%|███████████████████▏                                            | 1201/4000 [08:37<19:01,  2.45it/s]

tensor(0.3031, grad_fn=<MeanBackward0>)


Traing epoch:  33%|████████████████████▊                                           | 1301/4000 [09:18<18:39,  2.41it/s]

tensor(0.2732, grad_fn=<MeanBackward0>)


Traing epoch:  35%|██████████████████████▍                                         | 1401/4000 [09:59<17:45,  2.44it/s]

tensor(0.3399, grad_fn=<MeanBackward0>)


Traing epoch:  38%|████████████████████████                                        | 1501/4000 [10:40<17:26,  2.39it/s]

tensor(0.2077, grad_fn=<MeanBackward0>)


Traing epoch:  40%|█████████████████████████▌                                      | 1601/4000 [11:22<16:43,  2.39it/s]

tensor(0.3160, grad_fn=<MeanBackward0>)


Traing epoch:  43%|███████████████████████████▏                                    | 1701/4000 [12:04<16:08,  2.37it/s]

tensor(0.3021, grad_fn=<MeanBackward0>)


Traing epoch:  45%|████████████████████████████▊                                   | 1801/4000 [12:45<15:06,  2.43it/s]

tensor(0.2696, grad_fn=<MeanBackward0>)


Traing epoch:  48%|██████████████████████████████▍                                 | 1901/4000 [13:27<14:26,  2.42it/s]

tensor(0.5254, grad_fn=<MeanBackward0>)


Traing epoch:  50%|████████████████████████████████                                | 2001/4000 [14:08<13:40,  2.44it/s]

tensor(0.2974, grad_fn=<MeanBackward0>)


Traing epoch:  53%|█████████████████████████████████▌                              | 2101/4000 [14:50<12:56,  2.44it/s]

tensor(0.2059, grad_fn=<MeanBackward0>)


Traing epoch:  55%|███████████████████████████████████▏                            | 2201/4000 [15:31<12:27,  2.41it/s]

tensor(0.4414, grad_fn=<MeanBackward0>)


Traing epoch:  58%|████████████████████████████████████▊                           | 2301/4000 [16:13<11:45,  2.41it/s]

tensor(0.7000, grad_fn=<MeanBackward0>)


Traing epoch:  60%|██████████████████████████████████████▍                         | 2401/4000 [16:55<10:57,  2.43it/s]

tensor(0.2715, grad_fn=<MeanBackward0>)


Traing epoch:  63%|████████████████████████████████████████                        | 2501/4000 [17:37<10:28,  2.39it/s]

tensor(0.2859, grad_fn=<MeanBackward0>)


Traing epoch:  65%|█████████████████████████████████████████▌                      | 2601/4000 [18:21<10:09,  2.29it/s]

tensor(0.2143, grad_fn=<MeanBackward0>)


Traing epoch:  68%|███████████████████████████████████████████▏                    | 2701/4000 [19:03<09:13,  2.35it/s]

tensor(0.3457, grad_fn=<MeanBackward0>)


Traing epoch:  70%|████████████████████████████████████████████▊                   | 2801/4000 [19:45<08:32,  2.34it/s]

tensor(0.2908, grad_fn=<MeanBackward0>)


Traing epoch:  73%|██████████████████████████████████████████████▍                 | 2901/4000 [20:28<07:49,  2.34it/s]

tensor(0.2334, grad_fn=<MeanBackward0>)


Traing epoch:  75%|████████████████████████████████████████████████                | 3001/4000 [21:11<07:56,  2.09it/s]

tensor(0.3302, grad_fn=<MeanBackward0>)


Traing epoch:  78%|█████████████████████████████████████████████████▌              | 3101/4000 [22:02<05:56,  2.52it/s]

tensor(0.3580, grad_fn=<MeanBackward0>)


Traing epoch:  80%|███████████████████████████████████████████████████▏            | 3201/4000 [22:43<05:30,  2.42it/s]

tensor(0.2557, grad_fn=<MeanBackward0>)


Traing epoch:  83%|████████████████████████████████████████████████████▊           | 3301/4000 [23:24<05:05,  2.29it/s]

tensor(0.1737, grad_fn=<MeanBackward0>)


Traing epoch:  85%|██████████████████████████████████████████████████████▍         | 3401/4000 [24:08<04:13,  2.36it/s]

tensor(0.1599, grad_fn=<MeanBackward0>)


Traing epoch:  88%|████████████████████████████████████████████████████████        | 3501/4000 [24:52<03:36,  2.30it/s]

tensor(0.5057, grad_fn=<MeanBackward0>)


Traing epoch:  90%|█████████████████████████████████████████████████████████▌      | 3601/4000 [25:36<02:54,  2.29it/s]

tensor(0.4080, grad_fn=<MeanBackward0>)


Traing epoch:  93%|███████████████████████████████████████████████████████████▏    | 3701/4000 [26:18<02:05,  2.39it/s]

tensor(0.6685, grad_fn=<MeanBackward0>)


Traing epoch:  95%|████████████████████████████████████████████████████████████▊   | 3801/4000 [27:02<01:21,  2.44it/s]

tensor(0.3077, grad_fn=<MeanBackward0>)


Traing epoch:  98%|██████████████████████████████████████████████████████████████▍ | 3901/4000 [27:46<00:41,  2.36it/s]

tensor(0.2923, grad_fn=<MeanBackward0>)


Traing epoch: 100%|████████████████████████████████████████████████████████████████| 4000/4000 [28:30<00:00,  2.34it/s]


In [9]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve  # 生成S形二维数据点 https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_s_curve.html
import torch
import torch.nn as nn
from tqdm import tqdm

from train import MLPDiffusion



## ----------------------------- 1、生成数据，(10000, 2)的数据点集，组成一个S形 ----------------------------- ##
s_curve, _ = make_s_curve(10 ** 4, noise=0.1)  # 10000个数据点
s_curve = s_curve[:, [0, 2]] / 10.0
print("shape of s:", np.shape(s_curve))
dataset = torch.Tensor(s_curve).float()

## ----------------------------- 2、确定超参数的值 ----------------------------- ##
# 采样时间步总长度 t
num_steps = 100
 
# 制定每一步的beta
betas = torch.linspace(-6, 6, num_steps)
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
 
# 计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1 - betas
alphas_prod = torch.cumprod(alphas, 0)
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)


def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt):
    """
    从x[t]采样t-1时刻的重构值x[t-1]，根据论文中的采样公式计算单步的采样
    :param model:
    :param x: x[T]
    :param t:
    :param betas:
    :param one_minus_alphas_bar_sqrt:
    :return:
    """
    ## 1) 求出 bar_u_t
    t = torch.tensor([t])
    coeff = betas[t] / one_minus_alphas_bar_sqrt[t] # 这里先计算采样公式中的一部分参数，方便后面表示，看不懂的可以直接对着论文公式看
    # 送入U-Net模型，得到t时刻的随机噪声预测值 eps_theta
    eps_theta = model(x, t)
    mean = (1 / (1 - betas[t]).sqrt()) * (x - (coeff * eps_theta))
 
    ## 2) 得到 x[t-1]
    z = torch.randn_like(x)
    sigma_t = betas[t].sqrt()
    sample = mean + sigma_t * z
    return sample

def p_sample_loop(model, noise_x_t, n_steps, betas, one_minus_alphas_bar_sqrt):
    """
    从x[T]恢复x[T-1]、x[T-2]|...x[0] 的循环
    :param model:
    :param shape:数据集的形状，也就是x[T]的形状
    :param n_steps:
    :param betas:
    :param one_minus_alphas_bar_sqrt:
    :return: x_seq由x[T]、x[T-1]、x[T-2]|...x[0]组成, cur_x是从噪声中生成的图片
    """
    # 得到噪声x[T]
    cur_x = noise_x_t # 初始化当前的x为噪声x[T]
    x_seq = [noise_x_t] # 初始化x序列为第一个元素为x[T],也就是纯噪声
    # 从x[T]恢复x[T-1]、x[T-2]|...x[0]
    for i in reversed(range(n_steps)):
        cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt)
        x_seq.append(cur_x)
    return x_seq, cur_x

# 1) 加载训练好的diffusion model
model = MLPDiffusion(num_steps)
model.load_state_dict(torch.load('./checkpoints_cpu/model_3900.pth'))

# 2) 生成随机噪声x[T]
noise_x_t = torch.randn(dataset.shape)

# 3) 根据随机噪声逆扩散为x[T-1]、x[T-2]|...x[0] + 图片x[0]
x_seq, cur_x = p_sample_loop(model, noise_x_t, num_steps, betas, one_minus_alphas_bar_sqrt)

# 4) 绘制并保存图像
def plot_samples(x_seq, cur_x):
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    
    # 绘制 x_seq
    for i, x in enumerate(x_seq):
        if i % 10 == 0:  # 每10个时间步绘制一次
            ax[0].scatter(x.detach().numpy()[:, 0], x.detach().numpy()[:, 1], label=f'Step {i}', alpha=0.5)
    ax[0].legend()
    ax[0].set_title('x_seq')
    
    # 绘制 cur_x
    ax[1].scatter(cur_x.detach().numpy()[:, 0], cur_x.detach().numpy()[:, 1], color='red')
    ax[1].set_title('cur_x')
    
    plt.savefig('samples_plot.png')
    plt.show()

plot_samples(x_seq, cur_x)

shape of s: (10000, 2)


  model.load_state_dict(torch.load('./checkpoints_cpu/model_3900.pth'))


FileNotFoundError: [Errno 2] No such file or directory: './checkpoints_cpu/model_3900.pth'