Skip to content

Commit

Permalink
Update: Edited to english in 'network.py' and 'modules.py'.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Aug 9, 2023
1 parent f2e9d7c commit 8c0bb55
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 69 deletions.
105 changes: 54 additions & 51 deletions model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@

class EMA:
"""
指数移动平均
Exponential Moving Average
"""

def __init__(self, beta):
"""
初始化EMA
Initialize EMA
:param beta: β
"""
super().__init__()
Expand All @@ -26,9 +26,9 @@ def __init__(self, beta):

def update_model_average(self, ema_model, current_model):
"""
更新模型均值
:param ema_model: EMA模型
:param current_model: 当前模型
Update model average
:param ema_model: EMA model
:param current_model: Current model
:return: None
"""
for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
Expand All @@ -37,21 +37,21 @@ def update_model_average(self, ema_model, current_model):

def update_average(self, old_weight, new_weight):
"""
更新均值
:param old_weight: 旧权重
:param new_weight: 新权重
:return: new_weight或old_weight * self.beta + (1 - self.beta) * new_weight
Update average
:param old_weight: Old weight
:param new_weight: New weight
:return: new_weight or old_weight * self.beta + (1 - self.beta) * new_weight
"""
if old_weight is None:
return new_weight
return old_weight * self.beta + (1 - self.beta) * new_weight

def step_ema(self, ema_model, model, step_start_ema=2000):
"""
EMA步长
:param ema_model: EMA模型
:param model: 原模型
:param step_start_ema: 开始 EMA步长
EMA step
:param ema_model: EMA model
:param model: Original model
:param step_start_ema: Start EMA step
:return: None
"""
if self.step < step_start_ema:
Expand All @@ -63,29 +63,30 @@ def step_ema(self, ema_model, model, step_start_ema=2000):

def reset_parameters(self, ema_model, model):
"""
重置参数
:param ema_model: EMA模型
:param model: 原模型
Reset parameters
:param ema_model: EMA model
:param model: Original model
:return: None
"""
ema_model.load_state_dict(model.state_dict())


class SelfAttention(nn.Module):
"""
自注意力模块
SelfAttention block
"""

def __init__(self, channels, size):
"""
初始化自注意力块
:param channels: 通道
:param size: 尺寸
Initialize the self-attention block
:param channels: Channels
:param size: Size
"""
super(SelfAttention, self).__init__()
self.channels = channels
self.size = size
# pytorch1.8中不支持batch_first,若要支持升级为1.9及以上,或使用一下代码进行转置
# batch_first is not supported in pytorch 1.8.
# If you want to support upgrading to 1.9 and above, or use the following code to transpose
self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=4, batch_first=True)
self.ln = nn.LayerNorm(normalized_shape=[channels])
self.ff_self = nn.Sequential(
Expand All @@ -97,14 +98,16 @@ def __init__(self, channels, size):

def forward(self, x):
"""
前向传播
:param x: 输入
SelfAttention forward
:param x: Input
:return: attention_value
"""
# 首先进行形状变换,再用swapaxes对新张量的第1和2维度进行交换
# First perform the shape transformation, and then use 'swapaxes' to exchange the first
# second dimensions of the new tensor
x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
x_ln = self.ln(x)
# pytorch1.8中不支持batch_first,若要支持升级为1.9及以上,或使用一下代码进行转置
# batch_first is not supported in pytorch 1.8.
# If you want to support upgrading to 1.9 and above, or use the following code to transpose
attention_value, _ = self.mha(x_ln, x_ln, x_ln)
attention_value = attention_value + x
attention_value = self.ff_self(attention_value) + attention_value
Expand All @@ -113,16 +116,16 @@ def forward(self, x):

class DoubleConv(nn.Module):
"""
双卷积
Double convolution
"""

def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
"""
初始化双卷积
:param in_channels: 输入通道
:param out_channels: 输出通道
:param mid_channels: 中间通道
:param residual: 是否残差
Initialize the double convolution block
:param in_channels: Input channels
:param out_channels: Output channels
:param mid_channels: Middle channels
:param residual: Whether residual
"""
super().__init__()
self.residual = residual
Expand All @@ -138,9 +141,9 @@ def __init__(self, in_channels, out_channels, mid_channels=None, residual=False)

def forward(self, x):
"""
前向传播
:param x: 输入
:return: 残差结果或非残差结果
DoubleConv forward
:param x: Input
:return: Residual or non-residual results
"""
if self.residual:
return F.gelu(x + self.double_conv(x))
Expand All @@ -150,15 +153,15 @@ def forward(self, x):

class DownBlock(nn.Module):
"""
下采样块
Downsample block
"""

def __init__(self, in_channels, out_channels, emb_channels=256):
"""
初始化下采样块
:param in_channels: 输入通道
:param out_channels: 输出通道
:param emb_channels: 嵌入通道
Initialize the downsample block
:param in_channels: Input channels
:param out_channels: Output channels
:param emb_channels: Embed channels
"""
super().__init__()
self.maxpool_conv = nn.Sequential(
Expand All @@ -174,9 +177,9 @@ def __init__(self, in_channels, out_channels, emb_channels=256):

def forward(self, x, time):
"""
前向传播
:param x: 输入
:param time: 时间
DownBlock forward
:param x: Input
:param time: Time
:return: x + emb
"""
x = self.maxpool_conv(x)
Expand All @@ -186,15 +189,15 @@ def forward(self, x, time):

class UpBlock(nn.Module):
"""
上采样块
Upsample Block
"""

def __init__(self, in_channels, out_channels, emb_channels=256):
"""
初始化上采样块
:param in_channels: 输入通道
:param out_channels: 输出通道
:param emb_channels: 嵌入通道
Initialize the upsample block
:param in_channels: Input channels
:param out_channels: Output channels
:param emb_channels: Embed channels
"""
super().__init__()

Expand All @@ -211,10 +214,10 @@ def __init__(self, in_channels, out_channels, emb_channels=256):

def forward(self, x, skip_x, time):
"""
前向传播
:param x: 输入
:param skip_x: 需要合并的输入
:param time: 时间
UpBlock forward
:param x: Input
:param skip_x: Merged input
:param time: Time
:return: x + emb
"""
x = self.up(x)
Expand Down
36 changes: 18 additions & 18 deletions model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@

class UNet(nn.Module):
"""
U型网络
UNet
"""

def __init__(self, in_channel=3, out_channel=3, channel=None, time_channel=256, num_classes=None, image_size=64,
device="cpu"):
"""
初始化UNet网络
:param in_channel: 输入通道
:param out_channel: 输出通道
:param channel: 总通道列表
:param time_channel: 时间通道
:param num_classes: 类别数
:param image_size: 自适应图片大小
:param device: 使用设备
Initialize the UNet network
:param in_channel: Input channel
:param out_channel: Output channel
:param channel: The list of channel
:param time_channel: Time channel
:param num_classes: Number of classes
:param image_size: Adaptive image size
:param device: Device type
"""
super().__init__()
if channel is None:
Expand Down Expand Up @@ -59,9 +59,9 @@ def __init__(self, in_channel=3, out_channel=3, channel=None, time_channel=256,

def pos_encoding(self, time, channels):
"""
位置编码
:param time: 时间
:param channels: 通道
Position encoding
:param time: Time
:param channels: Channels
:return: pos_enc
"""
inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=self.device).float() / channels))
Expand All @@ -72,10 +72,10 @@ def pos_encoding(self, time, channels):

def forward(self, x, time, y=None):
"""
前向传播
:param x: 输入
:param time: 时间
:param y: 标签
Forward
:param x: Input
:param time: Time
:param y: Input label
:return: output
"""
time = time.unsqueeze(-1).type(torch.float)
Expand Down Expand Up @@ -107,9 +107,9 @@ def forward(self, x, time, y=None):


if __name__ == "__main__":
# 无条件
# Unconditional
net = UNet(device="cpu", image_size=128)
# 有条件
# Conditional
# net = UNet(num_classes=10, device="cpu", image_size=128)
print(sum([p.numel() for p in net.parameters()]))
x = torch.randn(1, 3, 128, 128)
Expand Down

0 comments on commit 8c0bb55

Please sign in to comment.