Skip to content

Commit

Permalink
Update: Edited to English in 'base.py', 'ddpm.py' and 'ddim.py'.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Aug 10, 2023
1 parent 8c0bb55 commit c92fd82
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 95 deletions.
75 changes: 41 additions & 34 deletions model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,63 +12,68 @@

class BaseDiffusion:
"""
扩散模型基类
Base diffusion class
"""

def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cpu"):
"""
扩散模型基类
:param noise_steps: 噪声步长
:param beta_start: β开始值
:param beta_end: β结束值
:param img_size: 图像大小
:param device: 设备类型
Diffusion model base class
:param noise_steps: Noise steps
:param beta_start: β start
:param beta_end: β end
:param img_size: Image size
:param device: Device type
"""
self.noise_steps = noise_steps
self.beta_start = beta_start
self.beta_end = beta_end
self.img_size = img_size
self.device = device

# 噪声步长
# Noise steps
self.beta = self.prepare_noise_schedule().to(self.device)
# 公式α = 1 - β
# Formula: α = 1 - β
self.alpha = 1. - self.beta
# 这里做α累加和操作
# The cumulative sum of α.
self.alpha_hat = torch.cumprod(input=self.alpha, dim=0)

def prepare_noise_schedule(self, schedule_name="linear"):
"""
准备噪声schedule,可以自定义,可使用openai的schedule
:param schedule_name: 方法名称,linear线性方法;cosine余弦方法
Prepare the noise (β) schedule
:param schedule_name: Schedule name, either 'linear' or 'cosine'
:return: schedule
"""
if schedule_name == "linear":
# torch.linspace为指定的区间内生成一维张量,其中的值均匀分布
# 'torch.linspace' generates a 1-dimensional tensor for the specified interval,
# and the values in it are evenly distributed
return torch.linspace(start=self.beta_start, end=self.beta_end, steps=self.noise_steps)
elif schedule_name == "cosine":
def alpha_hat(t):
"""
其参数t从0到1,并生成(1 - β)到扩散过程的该部分的累积乘积
原式â计算公式为:α_hat(t) = f(t) / f(0)
原式f(t)计算公式为:f(t) = cos(((t / (T + s)) / (1 + s)) · (π / 2))²
在此函数中s = 0.008且f(0) = 1
所以仅返回f(t)即可
:param t: 时间
:return: t时alpha_hat的值
The parameter t ranges from 0 to 1
Returns the cumulative product of (1 - β) up to that fraction of the diffusion process
The original formula for â is: α_hat(t) = f(t) / f(0)
where f(t) = cos(((t / T + s) / (1 + s)) · (π / 2))²
In this function t is already normalized (it plays the role of t / T), s = 0.008 and f(0) ≈ 1
So just return f(t)
:param t: Time
:return: The value of alpha_hat at t
"""
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

# 要产生的beta的数量
# Number of betas to generate
noise_steps = self.noise_steps
# 使用的最大β值;使用小于1的值来防止出现奇点
# The max value of β, use a value less than 1 to prevent singularities
max_beta = 0.999
# 创建一个分散给定alpha_hat(t)函数的β时间表,从t = [0,1]定义了(1 - β)的累积产物
# Build a discrete beta schedule from the continuous alpha_hat(t) function,
# which defines the cumulative product of (1 - β) over t in [0, 1]
betas = []
# 循环遍历
# Loop
for i in range(noise_steps):
t1 = i / noise_steps
t2 = (i + 1) / noise_steps
# 计算β在t时刻的值,公式为:β(t) = min(1 - (α_hat(t) - α_hat(t-1)), 0.999)
# Calculate the value of β at time t
# Formula: β(t) = min(1 - α_hat(t) / α_hat(t-1), 0.999)
beta_t = min(1 - alpha_hat(t2) / alpha_hat(t1), max_beta)
betas.append(beta_t)
return torch.tensor(betas)
Expand All @@ -77,22 +82,24 @@ def alpha_hat(t):

def noise_images(self, x, time):
    """
    Add Gaussian noise to the input at the given time steps
    :param x: Input image tensor (presumably [n, channels, height, width] -- TODO confirm against caller)
    :param time: Indices into alpha_hat, one per sample
    :return: Tuple (sqrt(alpha_hat) * x + sqrt(1 - alpha_hat) * noise, noise);
             noise has the same shape as x
    """
    # Look up α_hat once for the requested steps
    alpha_hat_t = self.alpha_hat[time]
    # Append three singleton axes so the per-sample scales broadcast over x
    signal_scale = alpha_hat_t.sqrt()[:, None, None, None]
    noise_scale = (1 - alpha_hat_t).sqrt()[:, None, None, None]
    # Standard normal noise (mean 0, variance 1) with the same shape as x
    noise = torch.randn_like(x)
    noised = signal_scale * x + noise_scale * noise
    return noised, noise

def sample_time_steps(self, n):
    """
    Sample random time steps
    :param n: Number of time steps to draw (length of the returned tensor)
    :return: Integer tensor of shape (n,) with values in [1, noise_steps)
    """
    # torch.randint includes the lower bound and excludes the upper bound
    lowest_step = 1
    upper_bound = self.noise_steps
    shape = (n,)
    return torch.randint(low=lowest_step, high=upper_bound, size=shape)
66 changes: 34 additions & 32 deletions model/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,80 +19,82 @@

class Diffusion(BaseDiffusion):
    """
    DDIM class
    """

    def __init__(self, noise_steps=1000, sample_steps=20, beta_start=1e-4, beta_end=0.02, img_size=256, device="cpu"):
        """
        The implementation of DDIM
        Paper: Denoising Diffusion Implicit Models
        URL: https://arxiv.org/abs/2010.02502
        :param noise_steps: Noise steps
        :param sample_steps: Sample steps; the sampler strides over the noise steps
        :param beta_start: β start
        :param beta_end: β end
        :param img_size: Image size
        :param device: Device type
        """
        super().__init__(noise_steps, beta_start, beta_end, img_size, device)
        # Sample steps, used to skip over noise steps
        self.sample_steps = sample_steps

        # η scales the random-noise coefficient c1 in sample(); with 0 that term vanishes
        self.eta = 0

        # Strided grid of time indices 1, 1 + stride, 1 + 2 * stride, ...
        stride = self.noise_steps // self.sample_steps
        time_step = torch.arange(0, self.noise_steps, stride).long() + 1
        # The '+ 1' shift can produce index == noise_steps (whenever stride divides
        # noise_steps - 1), which is out of range for alpha_hat; drop it
        time_step = time_step[time_step < self.noise_steps]
        # Prepend 0 and reverse so sampling walks from the largest index down to 0
        time_step = reversed(torch.cat((torch.tensor([0], dtype=torch.long), time_step)))
        # Pairs of (current step, previous step)
        self.time_step = list(zip(time_step[:-1], time_step[1:]))

    def sample(self, model, n, labels=None, cfg_scale=None):
        """
        DDIM sample method
        :param model: Model; called as model(x, t) or model(x, t, labels)
        :param n: Number of sample images
        :param labels: Labels; None for unconditional sampling
        :param cfg_scale: Classifier-free guidance interpolation weight. When it is greater than 0,
                          the unconditional and conditional noise predictions are blended with torch.lerp.
                          Reference paper: 'Classifier-Free Diffusion Guidance'
        :return: Sample images as a uint8 tensor of shape [n, 3, img_size, img_size]
        """
        logger.info(msg=f"DDIM Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            # Input dim: [n, 3, img_size, img_size]
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            # Iterate over (current time, previous time) pairs
            for i, p_i in tqdm(self.time_step):
                # Current time step, as a tensor of size n
                t = (torch.ones(n) * i).long().to(self.device)
                # Previous time step, as a tensor of size n
                p_t = (torch.ones(n) * p_i).long().to(self.device)
                # Look up alpha_hat at both steps and expand to 4 dimensions for broadcasting
                alpha_t = self.alpha_hat[t][:, None, None, None]
                alpha_prev = self.alpha_hat[p_t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                # Whether the network has conditional input, such as multiple category input
                if labels is None and cfg_scale is None:
                    # Images and time steps input into the model
                    predicted_noise = model(x, t)
                else:
                    predicted_noise = model(x, t, labels)
                    # cfg_scale may still be None here (labels given without guidance);
                    # comparing None > 0 would raise TypeError, so check it explicitly
                    if cfg_scale is not None and cfg_scale > 0:
                        # Unconditional predictive noise
                        unconditional_predicted_noise = model(x, t, None)
                        # 'torch.lerp' performs linear interpolation between the start and end values
                        # according to the given weights
                        # Formula: input + weight * (end - input)
                        predicted_noise = torch.lerp(unconditional_predicted_noise, predicted_noise, cfg_scale)
                # Predicted x0 at this step, clamped to [-1, 1]
                x0_t = torch.clamp((x - (predicted_noise * torch.sqrt((1 - alpha_t)))) / torch.sqrt(alpha_t), -1, 1)
                # c1: coefficient of the random noise (0 while eta == 0)
                c1 = self.eta * torch.sqrt((1 - alpha_t / alpha_prev) * (1 - alpha_prev) / (1 - alpha_t))
                # c2: coefficient of the predicted noise
                c2 = torch.sqrt((1 - alpha_prev) - c1 ** 2)
                x = torch.sqrt(alpha_prev) * x0_t + c2 * predicted_noise + c1 * noise
        # NOTE: the model is always switched to train mode here, whatever mode it was in before
        model.train()
        # The last update can leave values slightly outside [-1, 1]; clamp so the
        # uint8 conversion below cannot wrap around, then map to the range [0, 1]
        x = (x.clamp(-1, 1) + 1) * 0.5
        # Multiply by 255 to enter the effective pixel range
        x = (x * 255).type(torch.uint8)
        return x
62 changes: 33 additions & 29 deletions model/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,70 +19,74 @@

class Diffusion(BaseDiffusion):
    """
    DDPM class
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cpu"):
        """
        The implementation of DDPM
        Paper: Denoising Diffusion Probabilistic Models
        URL: https://arxiv.org/abs/2006.11239
        :param noise_steps: Noise steps
        :param beta_start: β start
        :param beta_end: β end
        :param img_size: Image size
        :param device: Device type
        """

        super().__init__(noise_steps, beta_start, beta_end, img_size, device)

    def sample(self, model, n, labels=None, cfg_scale=None):
        """
        DDPM sample method
        :param model: Model; called as model(x, t) or model(x, t, labels)
        :param n: Number of sample images
        :param labels: Labels; None for unconditional sampling
        :param cfg_scale: Classifier-free guidance interpolation weight. When it is greater than 0,
                          the unconditional and conditional noise predictions are blended with torch.lerp.
                          Reference paper: 'Classifier-Free Diffusion Guidance'
        :return: Sample images as a uint8 tensor of shape [n, 3, img_size, img_size]
        """
        logger.info(msg=f"DDPM Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            # Input dim: [n, 3, img_size, img_size]
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            # Walk the time steps in reverse: noise_steps - 1, ..., 1
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                # Time step, as a tensor of size n
                t = (torch.ones(n) * i).long().to(self.device)
                # Whether the network has conditional input, such as multiple category input
                if labels is None and cfg_scale is None:
                    # Images and time steps input into the model
                    predicted_noise = model(x, t)
                else:
                    predicted_noise = model(x, t, labels)
                    # cfg_scale may still be None here (labels given without guidance);
                    # comparing None > 0 would raise TypeError, so check it explicitly
                    if cfg_scale is not None and cfg_scale > 0:
                        # Unconditional predictive noise
                        unconditional_predicted_noise = model(x, t, None)
                        # 'torch.lerp' performs linear interpolation between the start and end values
                        # according to the given weights
                        # Formula: input + weight * (end - input)
                        predicted_noise = torch.lerp(unconditional_predicted_noise, predicted_noise, cfg_scale)
                # Look up the schedule values at step t and expand to 4 dimensions for broadcasting
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # Only steps greater than 1 add noise.
                # For details, refer to line 3 of Algorithm 2 on page 4 of the paper
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                # In each iteration, compute x at step t - 1 from x at step t
                # For details, refer to line 4 of Algorithm 2 on page 4 of the paper
                x = 1 / torch.sqrt(alpha) * (
                        x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(
                    beta) * noise
        # NOTE: the model is always switched to train mode here, whatever mode it was in before
        model.train()
        # Clamp to [-1, 1], then map to the range [0, 1]
        x = (x.clamp(-1, 1) + 1) / 2
        # Multiply by 255 to enter the effective pixel range
        x = (x * 255).type(torch.uint8)
        return x

0 comments on commit c92fd82

Please sign in to comment.