
Commit

Add: Add activation function setting.
chairc committed Aug 28, 2023
1 parent f48c468 commit 8b34de2
Showing 5 changed files with 85 additions and 30 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -113,6 +113,7 @@ We named this project IDDM: Industrial Defect Diffusion Model. It aims to reprod
| --dataset_path | | Dataset path | str | Path to the conditional dataset, such as CIFAR-10, with each class in a separate folder, or the path to the unconditional dataset with all images in one folder |
| --fp16 | | Half precision training | bool | Enable half precision training. It effectively reduces GPU memory usage but may affect training accuracy and results |
| --optim | | Optimizer | str | Optimizer selection. Currently supports Adam and AdamW |
| --act | | Activation function | str | Activation function selection. Currently supports gelu, silu, relu, relu6 and lrelu |
| --lr | | Learning rate | float | Initial learning rate. Currently only supports a linear learning rate |
| --lr_func | | Learning rate schedule | str | Sets the learning rate schedule. Currently supports linear, cosine, and warmup_cosine |
| --result_path | | Save path | str | Path to save the training results |
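For reference, each --act choice corresponds to a standard PyTorch activation. The mapping below is an illustrative sketch that mirrors the table above; the project's actual lookup is the get_activation_function helper added to model/modules.py later in this diff.

import torch.nn as nn

# Illustrative mapping of the --act choices to PyTorch modules (sketch only;
# lrelu uses a 0.1 negative slope to match the helper in model/modules.py).
ACTIVATIONS = {
    "gelu": nn.GELU(),
    "silu": nn.SiLU(),
    "relu": nn.ReLU(),
    "relu6": nn.ReLU6(),
    "lrelu": nn.LeakyReLU(0.1),
}
print(ACTIVATIONS["relu6"])  # e.g. ReLU6()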
1 change: 1 addition & 0 deletions README_zh.md
@@ -117,6 +117,7 @@
| --dataset_path | | Dataset path | str | For a conditional dataset such as cifar10, each class has its own folder and the path points to the parent folder; for an unconditional dataset, all images are in one folder and the path points to that image folder |
| --fp16 | | Half precision training | bool | Enable half precision training. It effectively reduces GPU memory usage, but training accuracy and results cannot be guaranteed |
| --optim | | Optimizer | str | Optimizer selection. Currently supports adam and adamw |
| --act | | Activation function | str | Activation function selection. Currently supports gelu, silu, relu, relu6, and lrelu |
| --lr | | Learning rate | float | Initial learning rate. Currently only supports a linear learning rate |
| --lr_func | | Learning rate schedule | str | Sets the learning rate schedule. Currently supports linear, cosine, and warmup_cosine |
| --result_path | | Save path | str | Path to save the training results |
69 changes: 58 additions & 11 deletions model/modules.py
@@ -5,10 +5,39 @@
@Author : chairc
@Site : https://github.com/chairc
"""
import logging
import coloredlogs

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")


def get_activation_function(name="silu", inplace=False):
"""
Get activation function
:param name: Activation function name
:param inplace: can optionally do the operation in-place
:return: Activation function
"""
if name == "relu":
act = nn.ReLU(inplace=inplace)
elif name == "relu6":
act = nn.ReLU6(inplace=inplace)
elif name == "silu":
act = nn.SiLU(inplace=inplace)
elif name == "lrelu":
act = nn.LeakyReLU(0.1, inplace=inplace)
elif name == "gelu":
act = nn.GELU()
else:
logger.warning(msg=f"Unsupported activation function type: {name}")
act = nn.SiLU(inplace=inplace)
return act


class EMA:
"""
Expand Down Expand Up @@ -76,11 +105,12 @@ class SelfAttention(nn.Module):
SelfAttention block
"""

def __init__(self, channels, size):
def __init__(self, channels, size, act="silu"):
"""
Initialize the self-attention block
:param channels: Channels
:param size: Size
:param act: Activation function
"""
super(SelfAttention, self).__init__()
self.channels = channels
@@ -92,7 +122,7 @@ def __init__(self, channels, size):
self.ff_self = nn.Sequential(
nn.LayerNorm(normalized_shape=[channels]),
nn.Linear(in_features=channels, out_features=channels),
nn.GELU(),
get_activation_function(name=act),
nn.Linear(in_features=channels, out_features=channels),
)

@@ -119,22 +149,24 @@ class DoubleConv(nn.Module):
Double convolution
"""

def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
def __init__(self, in_channels, out_channels, mid_channels=None, residual=False, act="silu"):
"""
Initialize the double convolution block
:param in_channels: Input channels
:param out_channels: Output channels
:param mid_channels: Middle channels
:param residual: Whether residual
:param act: Activation function
"""
super().__init__()
self.residual = residual
if not mid_channels:
mid_channels = out_channels
self.act = act
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, padding=1, bias=False),
nn.GroupNorm(num_groups=1, num_channels=mid_channels),
nn.GELU(),
get_activation_function(name=self.act),
nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
nn.GroupNorm(num_groups=1, num_channels=out_channels),
)
@@ -146,7 +178,20 @@ def forward(self, x):
:return: Residual or non-residual results
"""
if self.residual:
return F.gelu(x + self.double_conv(x))
out = x + self.double_conv(x)
if self.act == "relu":
return F.relu(out)
elif self.act == "relu6":
return F.relu6(out)
elif self.act == "silu":
return F.silu(out)
elif self.act == "lrelu":
return F.leaky_relu(out, negative_slope=0.1)
elif self.act == "gelu":
return F.gelu(out)
else:
logger.warning(msg=f"Unsupported activation function type: {self.act}")
return F.silu(out)
else:
return self.double_conv(x)

@@ -156,18 +201,19 @@ class DownBlock(nn.Module):
Downsample block
"""

def __init__(self, in_channels, out_channels, emb_channels=256):
def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
"""
Initialize the downsample block
:param in_channels: Input channels
:param out_channels: Output channels
:param emb_channels: Embed channels
:param act: Activation function
"""
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True),
DoubleConv(in_channels=in_channels, out_channels=out_channels),
DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
DoubleConv(in_channels=in_channels, out_channels=out_channels, act=act),
)

self.emb_layer = nn.Sequential(
@@ -192,19 +238,20 @@ class UpBlock(nn.Module):
Upsample Block
"""

def __init__(self, in_channels, out_channels, emb_channels=256):
def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
"""
Initialize the upsample block
:param in_channels: Input channels
:param out_channels: Output channels
:param emb_channels: Embed channels
:param act: Activation function
"""
super().__init__()

self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.conv = nn.Sequential(
DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True),
DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2),
DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2, act=act),
)

self.emb_layer = nn.Sequential(
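A quick way to exercise the new helper is sketched below. This snippet is not part of the commit; it assumes it is run from the repository root so that model.modules is importable.

import torch

from model.modules import get_activation_function

# Each supported name returns a module that can be applied to a tensor.
x = torch.linspace(-1.0, 1.0, steps=5)
for name in ("gelu", "silu", "relu", "relu6", "lrelu"):
    act = get_activation_function(name=name)
    print(f"{name}: {act(x).tolist()}")

# An unsupported name logs a warning and falls back to SiLU.
print(get_activation_function(name="mish"))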
35 changes: 18 additions & 17 deletions model/network.py
@@ -17,7 +17,7 @@ class UNet(nn.Module):
"""

def __init__(self, in_channel=3, out_channel=3, channel=None, time_channel=256, num_classes=None, image_size=64,
device="cpu"):
device="cpu", act="silu"):
"""
Initialize the UNet network
:param in_channel: Input channel
@@ -27,31 +27,32 @@ def __init__(self, in_channel=3, out_channel=3, channel=None, time_channel=256,
:param num_classes: Number of classes
:param image_size: Adaptive image size
:param device: Device type
:param act: Activation function
"""
super().__init__()
if channel is None:
channel = [64, 128, 256, 512]
self.device = device
self.time_channel = time_channel
self.image_size = image_size
self.inc = DoubleConv(in_channels=in_channel, out_channels=channel[0])
self.down1 = DownBlock(in_channels=channel[0], out_channels=channel[1])
self.sa1 = SelfAttention(channels=channel[1], size=int(self.image_size / 2))
self.down2 = DownBlock(in_channels=channel[1], out_channels=channel[2])
self.sa2 = SelfAttention(channels=channel[2], size=int(self.image_size / 4))
self.down3 = DownBlock(in_channels=channel[2], out_channels=channel[2])
self.sa3 = SelfAttention(channels=channel[2], size=int(self.image_size / 8))
self.inc = DoubleConv(in_channels=in_channel, out_channels=channel[0], act=act)
self.down1 = DownBlock(in_channels=channel[0], out_channels=channel[1], act=act)
self.sa1 = SelfAttention(channels=channel[1], size=int(self.image_size / 2), act=act)
self.down2 = DownBlock(in_channels=channel[1], out_channels=channel[2], act=act)
self.sa2 = SelfAttention(channels=channel[2], size=int(self.image_size / 4), act=act)
self.down3 = DownBlock(in_channels=channel[2], out_channels=channel[2], act=act)
self.sa3 = SelfAttention(channels=channel[2], size=int(self.image_size / 8), act=act)

self.bot1 = DoubleConv(in_channels=channel[2], out_channels=channel[3])
self.bot2 = DoubleConv(in_channels=channel[3], out_channels=channel[3])
self.bot3 = DoubleConv(in_channels=channel[3], out_channels=channel[2])
self.bot1 = DoubleConv(in_channels=channel[2], out_channels=channel[3], act=act)
self.bot2 = DoubleConv(in_channels=channel[3], out_channels=channel[3], act=act)
self.bot3 = DoubleConv(in_channels=channel[3], out_channels=channel[2], act=act)

self.up1 = UpBlock(in_channels=channel[3], out_channels=channel[1])
self.sa4 = SelfAttention(channels=channel[1], size=int(self.image_size / 4))
self.up2 = UpBlock(in_channels=channel[2], out_channels=channel[0])
self.sa5 = SelfAttention(channels=channel[0], size=int(self.image_size / 2))
self.up3 = UpBlock(in_channels=channel[1], out_channels=channel[0])
self.sa6 = SelfAttention(channels=channel[0], size=int(self.image_size))
self.up1 = UpBlock(in_channels=channel[3], out_channels=channel[1], act=act)
self.sa4 = SelfAttention(channels=channel[1], size=int(self.image_size / 4), act=act)
self.up2 = UpBlock(in_channels=channel[2], out_channels=channel[0], act=act)
self.sa5 = SelfAttention(channels=channel[0], size=int(self.image_size / 2), act=act)
self.up3 = UpBlock(in_channels=channel[1], out_channels=channel[0], act=act)
self.sa6 = SelfAttention(channels=channel[0], size=int(self.image_size), act=act)
self.outc = nn.Conv2d(in_channels=channel[0], out_channels=out_channel, kernel_size=1)

if num_classes is not None:
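To confirm that the chosen activation actually propagates through every DoubleConv, DownBlock, UpBlock, and SelfAttention, the hedged sketch below builds a UNet on the CPU and counts the activation modules. It assumes the repository layout shown above and does not run a forward pass.

import torch

from model.network import UNet

# Build the network with ReLU instead of the default SiLU.
net = UNet(in_channel=3, out_channel=3, image_size=64, device="cpu", act="relu")
num_relu = sum(isinstance(m, torch.nn.ReLU) for m in net.modules())
num_gelu = sum(isinstance(m, torch.nn.GELU) for m in net.modules())
print(f"ReLU modules: {num_relu}, GELU modules: {num_gelu}")  # expect ReLU > 0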
9 changes: 7 additions & 2 deletions tools/train.py
@@ -52,6 +52,8 @@ def train(rank=None, args=None):
image_size = args.image_size
# Select optimizer
optim = args.optim
# Select activation function
act = args.act
# Learning rate
init_lr = args.lr
# Learning rate function
@@ -113,9 +115,9 @@
resume = args.resume
# Model
if not conditional:
model = UNet(device=device, image_size=image_size).to(device)
model = UNet(device=device, image_size=image_size, act=act).to(device)
else:
model = UNet(num_classes=num_classes, device=device, image_size=image_size).to(device)
model = UNet(num_classes=num_classes, device=device, image_size=image_size, act=act).to(device)
# Distributed training
if distributed:
model = nn.parallel.DistributedDataParallel(module=model, device_ids=[device], find_unused_parameters=True)
@@ -350,6 +352,9 @@ def main(args):
# Set optimizer (needed)
# Option: adam/adamw
parser.add_argument("--optim", type=str, default="adamw")
# Set activation function (needed)
# Option: gelu/silu/relu/relu6/lrelu
parser.add_argument("--act", type=str, default="gelu")
# Learning rate (needed)
parser.add_argument("--lr", type=int, default=3e-4)
# Learning rate function (needed)
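To illustrate how the new flag travels from the command line into the model, here is a hedged sketch of the parse-then-construct flow used in train(); the argument names come from the diff above, the import path is an assumption about the repository layout, and the real training loop is not started.

import argparse

import torch

from model.network import UNet  # assumed repository layout

parser = argparse.ArgumentParser()
# Mirrors the relevant flags from tools/train.py.
parser.add_argument("--act", type=str, default="gelu")
parser.add_argument("--image_size", type=int, default=64)
args = parser.parse_args(["--act", "lrelu"])

device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(device=device, image_size=args.image_size, act=args.act).to(device)
print(f"UNet built on {device} with act={args.act}")

From the shell, the same choice would be made by passing --act lrelu to tools/train.py along with the other required flags.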
