Merge pull request #18 from chairc/dev
Reconstruct modules.py, and update package path; Rewrite checkpoint storage and loading functions, add checkpoint files
chairc committed Dec 5, 2023
2 parents 9300334 + b2caa37 commit d94531b
Showing 15 changed files with 592 additions and 419 deletions.
27 changes: 23 additions & 4 deletions README.md
@@ -11,13 +11,19 @@ We named this project IDDM: Industrial Defect Diffusion Model. It aims to reprod
**Repository Structure**

```yaml
Industrial Defect Diffusion Model
├── datasets
│ └── dataset_demo
│ ├── class_1
│ ├── class_2
│ └── class_3
├── model
│ ├── modules
│ │ ├── activation.py
│ │ ├── attention.py
│ │ ├── block.py
│ │ ├── conv.py
│ │ ├── ema.py
│ │ └── module.py
│ ├── networks
│ │ ├── base.py
@@ -38,6 +44,7 @@ We named this project IDDM: Industrial Defect Diffusion Model. It aims to reprod
│ ├── generate.py
│ └── train.py
├── utils
│ ├── checkpoint.py
│ ├── initializer.py
│ ├── lr_scheduler.py
│ └── utils.py
@@ -139,13 +146,26 @@ The GPU training environment used for this README is as follows: models ar
**Conditional Resume Training Command**

```bash
python train.py --resume True --start_epoch 10 --load_model_dir df --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path
# With --start_epoch set, the checkpoint saved at that epoch is loaded
python train.py --resume True --start_epoch 10 --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path
```

```bash
# Without --start_epoch, the last saved checkpoint is loaded by default
python train.py --resume True --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path
```
**Unconditional Resume Training Command**

```bash
python train.py --resume True --start_epoch 10 --load_model_dir df --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path
# With --start_epoch set, the checkpoint saved at that epoch is loaded
python train.py --resume True --start_epoch 10 --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path
```

```bash
# Without --start_epoch, the last saved checkpoint is loaded by default
python train.py --resume True --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path
```
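
The selection logic behind these two variants can be pictured with a small sketch. This is a hypothetical, heavily simplified helper: `resolve_checkpoint`, the directory layout, and the file names are assumptions for illustration, not the code in `utils/checkpoint.py`.

```python
import os
from typing import Optional


def resolve_checkpoint(weight_dir: str, start_epoch: Optional[int] = None) -> str:
    """Pick which checkpoint file to resume from (hypothetical helper, illustration only)."""
    # With --start_epoch, resume from that epoch's checkpoint;
    # without it, fall back to the "last" checkpoint saved after every epoch.
    # File names are placeholders, not the repository's actual naming scheme.
    if start_epoch is not None:
        return os.path.join(weight_dir, f"ckpt_{start_epoch}.pt")
    return os.path.join(weight_dir, "ckpt_last.pt")
```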

#### Distributed Training

1. The basic configuration is similar to regular training, but note that enabling distributed training requires setting `--distributed` to `True`. To prevent arbitrary use of distributed training, we have several conditions for enabling distributed training, such as `args.distributed`, `torch.cuda.device_count() > 1`, and `torch.cuda.is_available()`.
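
A minimal sketch of such a guard, assuming an `argparse` namespace with a `distributed` flag; `can_use_distributed` is an illustrative name, not the repository's exact initialization code:

```python
import argparse

import torch


def can_use_distributed(args: argparse.Namespace) -> bool:
    # Distributed training is enabled only when it was requested explicitly,
    # CUDA is available, and more than one GPU is visible.
    return bool(args.distributed) and torch.cuda.is_available() and torch.cuda.device_count() > 1
```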
@@ -197,8 +217,7 @@ The GPU training environment used for this README is as follows: models ar
| --vis | | Visualize dataset information | bool | Enable visualization of dataset information and use the visualized samples to select models |
| --num_vis | | Number of visualization images generated | int | Number of visualization images generated. If not filled, the default is the number of image classes |
| --resume | | Resume interrupted training | bool | Set to "True" to resume interrupted training. Note: if the interruption epoch falls outside the --start_model_interval condition, resuming will not take effect. For example, if model saving starts at epoch 100 and training was interrupted at epoch 50, there is no per-epoch checkpoint to load. Because an xxx_last.pt file is saved after every epoch, the last saved model can always be used to resume interrupted training |
| --start_epoch | | Epoch number of interruption | int | Epoch number where the training was interrupted |
| --load_model_dir | | Folder name of the loaded model | str | Folder name of the previously loaded model |
| --start_epoch | | Epoch number of interruption | int | Epoch number where training was interrupted; the model loads the checkpoint saved at this epoch |
| --distributed | | Distributed training | bool | Enable distributed training |
| --main_gpu | | Main GPU for distributed | int | Set the main GPU for distributed training |
| --world_size | | Number of distributed nodes | int | Number of distributed nodes, corresponds to the actual number of GPUs or distributed nodes being used |
25 changes: 21 additions & 4 deletions README_zh.md
@@ -11,13 +11,19 @@
**Repository Structure**

```yaml
Industrial Defect Diffusion Model
├── datasets
│ └── dataset_demo
│ ├── class_1
│ ├── class_2
│ └── class_3
├── model
│ ├── modules
│ │ ├── activation.py
│ │ ├── attention.py
│ │ ├── block.py
│ │ ├── conv.py
│ │ ├── ema.py
│ │ └── module.py
│ ├── networks
│ │ ├── base.py
@@ -38,6 +44,7 @@
│ ├── generate.py
│ └── train.py
├── utils
│ ├── checkpoint.py
│ ├── initializer.py
│ ├── lr_scheduler.py
│ └── utils.py
@@ -138,12 +145,23 @@
**Conditional Resume Training Command**

```bash
python train.py --resume True --start_epoch 10 --load_model_dir df --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path
# With the --start_epoch argument provided, the checkpoint saved at that epoch is used
python train.py --resume True --start_epoch 10 --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path
```

```bash
# Without --start_epoch, the last saved checkpoint is used by default
python train.py --resume True --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path
```
**Unconditional Resume Training Command**

```bash
python train.py --resume True --start_epoch 10 --load_model_dir df --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path
python train.py --resume True --start_epoch 10 --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path
```

```bash
# Without --start_epoch, the last saved checkpoint is used by default
python train.py --resume True --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path
```

#### Distributed Training
@@ -200,8 +218,7 @@
| --vis | | Visualize dataset information | bool | Enable visualization of dataset information and use the visualized samples to select models |
| --num_vis | | Number of visualization images generated | int | Number of visualization images generated. If not set, it defaults to the number of dataset classes |
| --resume | | Resume interrupted training | bool | Set to "True" to resume interrupted training. Note: if the interruption epoch falls outside the --start_model_interval condition, resuming will not take effect. For example, if model saving starts at epoch 100 and training was interrupted at epoch 50, there is no per-epoch checkpoint to load. Because an xxx_last.pt file is saved after every epoch, the last saved model can always be used to resume interrupted training |
| --start_epoch | | Epoch number of interruption | int | Epoch number where training was interrupted |
| --load_model_dir | | Folder name of the loaded model | str | Folder name of the model loaded before the interruption |
| --start_epoch | | Epoch number of interruption | int | Epoch number where training was interrupted; the model automatically loads the checkpoint saved at this epoch |
| --distributed | | Distributed training | bool | Enable distributed training |
| --main_gpu | | Main GPU for distributed training | int | Set the main GPU for distributed training |
| --world_size | | Number of distributed nodes | int | Number of distributed nodes; the value of world_size corresponds to the actual number of GPUs or distributed nodes used |
36 changes: 36 additions & 0 deletions model/modules/activation.py
@@ -0,0 +1,36 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2023/12/5 10:19
@Author : chairc
@Site : https://github.com/chairc
"""
import logging
import coloredlogs
import torch.nn as nn

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")


def get_activation_function(name="silu", inplace=False):
"""
Get activation function
:param name: Activation function name
:param inplace: can optionally do the operation in-place
    :return: Activation function
"""
if name == "relu":
act = nn.ReLU(inplace=inplace)
elif name == "relu6":
act = nn.ReLU6(inplace=inplace)
elif name == "silu":
act = nn.SiLU(inplace=inplace)
elif name == "lrelu":
act = nn.LeakyReLU(0.1, inplace=inplace)
elif name == "gelu":
act = nn.GELU()
else:
logger.warning(msg=f"Unsupported activation function type: {name}")
act = nn.SiLU(inplace=inplace)
return act
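
A short usage sketch of the helper above (the input tensor shape is arbitrary and chosen only for illustration):

```python
import torch

from model.modules.activation import get_activation_function

act = get_activation_function(name="gelu")            # returns nn.GELU()
fallback = get_activation_function(name="unknown")    # logs a warning and falls back to nn.SiLU()
y = act(torch.randn(2, 8))                            # apply like any other nn.Module
```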
53 changes: 53 additions & 0 deletions model/modules/attention.py
@@ -0,0 +1,53 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2023/12/5 10:19
@Author : chairc
@Site : https://github.com/chairc
"""
import torch.nn as nn
from model.modules.activation import get_activation_function


class SelfAttention(nn.Module):
"""
SelfAttention block
"""

def __init__(self, channels, size, act="silu"):
"""
Initialize the self-attention block
:param channels: Channels
:param size: Size
:param act: Activation function
"""
super(SelfAttention, self).__init__()
self.channels = channels
self.size = size
        # batch_first is not supported in PyTorch 1.8; it requires PyTorch 1.9 or above.
        # To stay on 1.8, remove batch_first and transpose the inputs manually instead.
self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=4, batch_first=True)
self.ln = nn.LayerNorm(normalized_shape=[channels])
self.ff_self = nn.Sequential(
nn.LayerNorm(normalized_shape=[channels]),
nn.Linear(in_features=channels, out_features=channels),
get_activation_function(name=act),
nn.Linear(in_features=channels, out_features=channels),
)

def forward(self, x):
"""
SelfAttention forward
:param x: Input
:return: attention_value
"""
        # Flatten the spatial dimensions, then use 'swapaxes' to exchange the first and
        # second dimensions of the new tensor
x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
x_ln = self.ln(x)
        # batch_first is not supported in PyTorch 1.8; it requires PyTorch 1.9 or above.
        # To stay on 1.8, transpose the tensors before and after the attention call instead.
attention_value, _ = self.mha(x_ln, x_ln, x_ln)
attention_value = attention_value + x
attention_value = self.ff_self(attention_value) + attention_value
return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
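
A quick shape check for the block above (batch size and channel count are arbitrary; `size` must match the spatial resolution of the input, and `channels` must be divisible by the 4 attention heads):

```python
import torch

from model.modules.attention import SelfAttention

attn = SelfAttention(channels=64, size=16)
x = torch.randn(4, 64, 16, 16)   # (batch, channels, height, width)
out = attn(x)
print(out.shape)                 # torch.Size([4, 64, 16, 16]) -- shape is preserved
```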
131 changes: 131 additions & 0 deletions model/modules/block.py
@@ -0,0 +1,131 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2023/12/5 10:21
@Author : chairc
@Site : https://github.com/chairc
"""
import torch
import torch.nn as nn

from model.modules.conv import BaseConv, DoubleConv
from model.modules.module import CSPLayer


class DownBlock(nn.Module):
"""
Downsample block
"""

def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
"""
Initialize the downsample block
:param in_channels: Input channels
:param out_channels: Output channels
:param emb_channels: Embed channels
:param act: Activation function
"""
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
DoubleConv(in_channels=in_channels, out_channels=out_channels, act=act),
)

self.emb_layer = nn.Sequential(
nn.SiLU(),
nn.Linear(in_features=emb_channels, out_features=out_channels),
)

def forward(self, x, time):
"""
DownBlock forward
:param x: Input
:param time: Time
:return: x + emb
"""
x = self.maxpool_conv(x)
emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + emb


class UpBlock(nn.Module):
"""
Upsample Block
"""

def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
"""
Initialize the upsample block
:param in_channels: Input channels
:param out_channels: Output channels
:param emb_channels: Embed channels
:param act: Activation function
"""
super().__init__()

self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.conv = nn.Sequential(
DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2, act=act),
)

self.emb_layer = nn.Sequential(
nn.SiLU(),
nn.Linear(in_features=emb_channels, out_features=out_channels),
)

def forward(self, x, skip_x, time):
"""
UpBlock forward
:param x: Input
:param skip_x: Merged input
:param time: Time
:return: x + emb
"""
x = self.up(x)
x = torch.cat([skip_x, x], dim=1)
x = self.conv(x)
emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + emb


class CSPDarkDownBlock(nn.Module):
    """
    CSPDarknet-style downsample block
    """

    def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"):
super().__init__()
self.conv_csp = nn.Sequential(
BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, act=act),
CSPLayer(in_channels=out_channels, out_channels=out_channels, n=n, act=act)
)

self.emb_layer = nn.Sequential(
nn.SiLU(),
nn.Linear(in_features=emb_channels, out_features=out_channels),
)

def forward(self, x, time):
x = self.conv_csp(x)
emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + emb


class CSPDarkUpBlock(nn.Module):
    """
    CSPDarknet-style upsample block
    """

    def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"):
super().__init__()
self.up = nn.Upsample(scale_factor=2, mode="nearest")
self.conv = BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, act=act)
self.csp = CSPLayer(in_channels=in_channels, out_channels=out_channels, n=n, shortcut=False, act=act)

self.emb_layer = nn.Sequential(
nn.SiLU(),
nn.Linear(in_features=emb_channels, out_features=out_channels),
)

    def forward(self, x, skip_x, time):
        x = self.conv(x)
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        # Apply the CSP layer to the concatenated features; self.csp was otherwise unused
        x = self.csp(x)
emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + emb
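
A minimal shape check for the plain `DownBlock`/`UpBlock` pair above (channel counts are illustrative; the 256-dimensional time embedding matches the `emb_channels` default, and `DoubleConv` is assumed to be the residual double-convolution defined in `model/modules/conv.py`):

```python
import torch

from model.modules.block import DownBlock, UpBlock

down = DownBlock(in_channels=64, out_channels=128)
# After upsampling, the 128-channel feature map is concatenated with the 64-channel skip
up = UpBlock(in_channels=192, out_channels=64)

x = torch.randn(2, 64, 32, 32)   # input feature map
t = torch.randn(2, 256)          # time embedding (emb_channels=256 by default)

feat = down(x, t)                # (2, 128, 16, 16)
out = up(feat, x, t)             # (2, 64, 32, 32)
```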