Skip to content

Commit

Permalink
Update: Optimizing loading weight method, optimizing generate.py and …
Browse files Browse the repository at this point in the history
…Check for issues where distributed weight names are not the same as network weights.
  • Loading branch information
chairc committed Aug 3, 2023
1 parent 6c21b9b commit 24a03c3
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 20 deletions.
30 changes: 19 additions & 11 deletions tools/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
import coloredlogs

sys.path.append(os.path.dirname(sys.path[0]))
from model.ddpm import Diffusion
from model.ddpm import Diffusion as DDPMDiffusion
from model.ddim import Diffusion as DDIMDiffusion
from model.network import UNet
from utils.initializer import device_initializer
from utils.initializer import device_initializer, load_model_weight_initializer
from utils.utils import plot_images, save_images

logger = logging.getLogger(__name__)
Expand All @@ -27,6 +28,8 @@ def generate(args):
logger.info(msg="Start generation.")
# 是否启用条件生成
conditional = args.conditional
# 采样器类别
sample = args.sample
# 生成名称
generate_name = args.generate_name
# 图片大小
Expand All @@ -39,6 +42,14 @@ def generate(args):
result_path = args.result_path
# 设备初始化
device = device_initializer()
# 初始化扩散模型
if sample == "ddpm":
diffusion = DDPMDiffusion(img_size=image_size, device=device)
elif sample == "ddim":
diffusion = DDIMDiffusion(img_size=image_size, device=device)
else:
diffusion = DDPMDiffusion(img_size=image_size, device=device)
logger.warning(msg=f"[{device}]: Setting sample error, we has been automatically set to ddpm.")
# 模型初始化
if conditional:
# 类别个数
Expand All @@ -48,18 +59,12 @@ def generate(args):
# classifier-free guidance插值权重
cfg_scale = args.cfg_scale
model = UNet(num_classes=num_classes, device=device, image_size=image_size).to(device)
# 加载权重路径
weight = torch.load(f=weight_path)
model.load_state_dict(state_dict=weight)
diffusion = Diffusion(img_size=image_size, device=device)
load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False)
y = torch.Tensor([class_name] * num_images).long().to(device)
x = diffusion.sample(model=model, n=num_images, labels=y, cfg_scale=cfg_scale)
else:
model = UNet(device=device, image_size=image_size).to(device)
# 加载权重路径
weight = torch.load(f=weight_path)
model.load_state_dict(state_dict=weight)
diffusion = Diffusion(img_size=image_size, device=device)
load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False)
x = diffusion.sample(model=model, n=num_images)
# 如果不存在路径信息则只展示;存在则保存到指定路径并展示
if result_path == "" or result_path is None:
Expand All @@ -84,7 +89,10 @@ def generate(args):
# 保存路径
parser.add_argument("--result_path", type=str, default="/your/path/Defect-Diffusion-Model/results/vis")
# 开启条件生成,若使用False则不需要设置该参数之后的参数
parser.add_argument("--conditional", type=bool, default=False)
parser.add_argument("--conditional", type=bool, default=True)
# 采样器类别(必须设置)
# 不设置是为ddpm,可设置ddpm,ddim
parser.add_argument("--sample", type=str, default="ddpm")

# ==========================开启条件生成分界线==========================
# 类别个数
Expand Down
11 changes: 3 additions & 8 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from collections import OrderedDict

sys.path.append(os.path.dirname(sys.path[0]))
from model.ddpm import Diffusion as DDPMDiffusion
from model.ddim import Diffusion as DDIMDiffusion
from model.modules import EMA
from model.network import UNet
from utils.initializer import device_initializer, seed_initializer
from utils.initializer import device_initializer, seed_initializer, load_model_weight_initializer
from utils.lr_scheduler import set_cosine_lr
from utils.utils import plot_images, save_images, get_dataset, setup_logging

Expand Down Expand Up @@ -129,11 +128,7 @@ def train(rank=None, args=None):
load_epoch = str(start_epoch - 1).zfill(3)
model_path = os.path.join(result_path, load_model_dir, f"model_{load_epoch}.pt")
optim_path = os.path.join(result_path, load_model_dir, f"optim_model_{load_epoch}.pt")
model_dict = model.state_dict()
model_weights_dict = torch.load(f=model_path, map_location=device)
model_weights_dict = {k: v for k, v in model_weights_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
model_dict.update(model_weights_dict)
model.load_state_dict(state_dict=OrderedDict(model_dict))
load_model_weight_initializer(model=model, weight_path=model_path, device=device)
logger.info(msg=f"[{device}]: Successfully load model model_{load_epoch}.pt")
# 加载优化器参数
optim_weights_dict = torch.load(f=optim_path, map_location=device)
Expand Down Expand Up @@ -313,7 +308,7 @@ def main(args):
parser.add_argument("--conditional", type=bool, default=False)
# 采样器类别(必须设置)
# 不设置是为ddpm,可设置ddpm,ddim
parser.add_argument("--sample", type=str, default="ddim")
parser.add_argument("--sample", type=str, default="ddpm")
# 初始化模型的文件名称(必须设置)
parser.add_argument("--run_name", type=str, default="df")
# 训练总迭代次数(必须设置)
Expand Down
31 changes: 30 additions & 1 deletion utils/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import logging
import coloredlogs

from collections import OrderedDict

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")

Expand Down Expand Up @@ -46,7 +48,7 @@ def seed_initializer(seed_id=0):
"""
初始化种子
:param seed_id: 种子id
:return:
:return: None
"""
torch.manual_seed(seed_id)
torch.cuda.manual_seed_all(seed_id)
Expand All @@ -55,3 +57,30 @@ def seed_initializer(seed_id=0):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
logger.info(msg=f"The seed is initialized, and the seed ID is {seed_id}.")


def load_model_weight_initializer(model, weight_path, device, is_train=True):
    """
    Load a saved state dict from disk into ``model``.

    Checkpoint entries are copied into the model only when their key exists in
    the model's own state dict and the shapes match. Any other entry (unknown
    key or mismatched shape) is skipped and the model keeps its current value
    for that parameter.

    :param model: Model whose state dict is updated in place
    :param weight_path: Path of the weight file handed to ``torch.load``
    :param device: Device passed to ``torch.load`` as ``map_location``
    :param is_train: When False, a leading 'module.' is removed from every
        checkpoint key before matching against the model
    :return: None
    """
    model_dict = model.state_dict()
    model_weights_dict = torch.load(f=weight_path, map_location=device)
    # Weights saved after distributed training carry a 'module.' prefix on
    # their names; drop it so the names line up with the bare network.
    if not is_train:
        prefix = "module."
        model_weights_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in model_weights_dict.items()
        }
        logger.info(msg="Successfully check the load weight and rename.")
    # Keep only entries the model actually has, with an identical shape.
    # The membership test matters: indexing model_dict with a key it does not
    # contain would raise KeyError instead of skipping the entry.
    model_weights_dict = {
        k: v for k, v in model_weights_dict.items()
        if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
    }
    model_dict.update(model_weights_dict)
    model.load_state_dict(state_dict=OrderedDict(model_dict))

0 comments on commit 24a03c3

Please sign in to comment.