Commit

Add: Add learning rate functions: default, cosine, and warmup_cosine.
chairc committed Jul 31, 2023
1 parent 9189913 commit 35d024e
Showing 3 changed files with 74 additions and 0 deletions.
23 changes: 23 additions & 0 deletions test/test_module.py
@@ -14,9 +14,13 @@
import coloredlogs

from torchvision.utils import save_image
from matplotlib import pyplot as plt

from model.ddpm import Diffusion
from model.network import UNet
from utils.utils import get_dataset, delete_files
from utils.initializer import device_initializer
from utils.lr_scheduler import set_cosine_lr

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")
@@ -87,6 +91,25 @@ def test_noising(self):
        save_image(tensor=noised_image.add(1).mul(0.5), fp=os.path.join(save_path, "noise.jpg"))
        logger.info(msg="Finish noising_test.")

    def test_lr(self):
        # Sweep the warmup-cosine schedule over a full training run and plot it
        image_size = 64
        device = device_initializer()
        net = UNet(num_classes=10, device=device, image_size=image_size)
        optimizer = torch.optim.AdamW(net.parameters(), lr=3e-4)
        lr_max = 3e-4
        lr_min = 3e-6
        max_epoch = 300
        lrs = []
        for epoch in range(max_epoch):
            set_cosine_lr(optimizer=optimizer, current_epoch=epoch, max_epoch=max_epoch, lr_min=lr_min,
                          lr_max=lr_max, warmup=True)
            logger.info(msg=f"{epoch}: {optimizer.param_groups[0]['lr']}")
            lrs.append(optimizer.param_groups[0]["lr"])
            optimizer.step()

        # Visualize the learning rate curve
        plt.plot(lrs)
        plt.show()


if __name__ == "__main__":
pass
16 changes: 16 additions & 0 deletions tools/train.py
@@ -27,6 +27,7 @@
from model.modules import EMA
from model.network import UNet
from utils.initializer import device_initializer, seed_initializer
from utils.lr_scheduler import set_cosine_lr
from utils.utils import plot_images, save_images, get_dataset, setup_logging

logger = logging.getLogger(__name__)
@@ -50,6 +51,8 @@ def train(rank=None, args=None):
    optim = args.optim
    # Learning rate
    init_lr = args.lr
    # Learning rate schedule
    lr_func = args.lr_func
    # Number of classes
    num_classes = args.num_classes
    # classifier-free guidance interpolation weight, used to improve the quality of generated samples
@@ -159,6 +162,16 @@ def train(rank=None, args=None):
    # Start iterating
    for epoch in range(start_epoch, args.epochs):
        logger.info(msg=f"[{device}]: Start epoch {epoch}:")
        # Set the learning rate for this epoch
        if lr_func == "cosine":
            current_lr = set_cosine_lr(optimizer=optimizer, current_epoch=epoch, max_epoch=args.epochs,
                                       lr_min=init_lr * 0.01, lr_max=init_lr, warmup=False)
        elif lr_func == "warmup_cosine":
            current_lr = set_cosine_lr(optimizer=optimizer, current_epoch=epoch, max_epoch=args.epochs,
                                       lr_min=init_lr * 0.01, lr_max=init_lr, warmup=True)
        else:
            current_lr = init_lr
        logger.info(msg=f"[{device}]: Learning rate for this epoch is {current_lr}")
        pbar = tqdm(dataloader)
        # Initialize images and labels
        images, labels = None, None
@@ -311,6 +324,9 @@ def main(args):
parser.add_argument("--optim", type=str, default="adamw")
# 学习率(酌情设置)
parser.add_argument("--lr", type=int, default=3e-4)
# 学习率方法(酌情设置)
# 不设置时为空,可设置cosine,warmup_cosine
parser.add_argument("--lr_func", type=str, default="")
# 保存路径(必须设置)
parser.add_argument("--result_path", type=str, default="/your/path/Defect-Diffusion-Model/results")
# 是否每次训练储存(建议设置)
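For reference, a run that enables the new scheduler might be launched along these lines. This is a minimal sketch rather than part of the commit: the flag names and defaults come from the diff above, while the exact script invocation and the chosen values are illustrative assumptions.

    python tools/train.py --lr 3e-4 --lr_func warmup_cosine --result_path /your/path/Defect-Diffusion-Model/results
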
35 changes: 35 additions & 0 deletions utils/lr_scheduler.py
@@ -0,0 +1,35 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2023/7/15 23:50
@Author : chairc
@Site : https://github.com/chairc
"""
import math


def set_cosine_lr(optimizer, current_epoch, max_epoch, lr_min=0, lr_max=0.1, warmup=True, num_warmup=5):
"""
设置优化器学习率
:param optimizer: 优化器
:param current_epoch: 当前迭代次数
:param max_epoch: 最大迭代次数
:param lr_min: 最小学习率
:param lr_max: 最大学习率
:param warmup: 预热
:param num_warmup: 预热个数
:return: lr
"""
warmup_epoch = num_warmup if warmup else 0
if current_epoch < warmup_epoch:
lr = lr_max * current_epoch / warmup_epoch
elif current_epoch < max_epoch:
lr = lr_min + (lr_max - lr_min) * (
1 + math.cos(math.pi * (current_epoch - warmup_epoch) / (max_epoch - warmup_epoch))) / 2
else:
lr = lr_min + (lr_max - lr_min) * (
1 + math.cos(math.pi * (current_epoch - max_epoch) / max_epoch)) / 2
for param_group in optimizer.param_groups:
param_group["lr"] = lr
return lr
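
To see the shape of the schedule without building a full UNet, set_cosine_lr can be driven with a throwaway optimizer. This is a minimal sketch, not part of the commit; the single-tensor parameter and the use of torch.optim.SGD are illustrative assumptions.

# Minimal sketch (illustrative, not part of the commit)
import torch

from utils.lr_scheduler import set_cosine_lr

param = torch.zeros(1, requires_grad=True)  # dummy parameter, only for illustration
optimizer = torch.optim.SGD([param], lr=3e-4)

max_epoch = 20
for epoch in range(max_epoch):
    lr = set_cosine_lr(optimizer=optimizer, current_epoch=epoch, max_epoch=max_epoch,
                       lr_min=3e-6, lr_max=3e-4, warmup=True, num_warmup=5)
    print(f"epoch {epoch:02d}: lr = {lr:.6f}")
# With warmup=True and num_warmup=5, epochs 0-4 ramp linearly from 0 toward lr_max,
# and epochs 5-19 follow a cosine decay from lr_max down toward lr_min.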
