Skip to content

Commit

Permalink
Update: Edited to english in 'train.py' and 'test_module.py'.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Aug 6, 2023
1 parent d858ef0 commit 788138d
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 147 deletions.
51 changes: 25 additions & 26 deletions test/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,66 +29,65 @@

class TestModule(unittest.TestCase):
"""
方法测试
1. 运行unittest测试模块
* 运行unittest测试模块。使用python -m unittest <test_module>命令运行测试,<test_module>是测试文件的名称或相对路径
* 例如:python -m unittest test_module.py
2. 运行单个测试类或测试方法
* 使用 -k 选项指定要运行的测试类或方法的名称,TestModule是要运行的测试类的名称,test_noising是要运行的测试方法的名称
* 例如:python -m unittest -k TestModule.test_noising
If you want to run the unittest test module, please use 'python -m unittest <test_module.py>'.
If you want to run a single test class or test method, please use
'python -m unittest -k <TestModule or TestClass>.<test_function_name>'.
Test Module
1. Run the unittest test module
* If you want to run the unittest test module, please use 'python -m unittest <test_module.py>',
<test module> is the name or relative path of the test file
* e.g: python -m unittest test_module.py
2. Run a single test class or test method
* Use the -k option to specify the name of the test class or method to run,
where TestModule is the name of the test class to run and test_noising is the name of the test method to run.
* e.g: python -m unittest -k TestModule.test_noising
"""

def test_num_cases(self):
"""
获取所有测试类名称
Get all test class names
:return: None
"""
# 获取所有测试名
# Get all test class names
test_cases = [method for method in dir(TestModule) if method.startswith('test_')]
logger.info(test_cases)
# 打印测试方法名称
# Print all test class names
for method_name in test_cases:
logger.info(method_name)

def test_noising(self):
"""
测试噪声
Test noising
:return: None
"""
# 参数设置
# Parameter settings
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_workers", type=int, default=2)
# 输入图像大小
# Input image size
parser.add_argument("--image_size", type=int, default=640)
parser.add_argument("--dataset_path", type=str, default="./noising_test")

args = parser.parse_args()
logger.info(msg=f"Input params: {args}")

# 开始测试
# Start test
logger.info(msg="Start noising noising_test.")
dataset_path = args.dataset_path
save_path = os.path.join(dataset_path, "noise")
# 需要先清除noise文件夹下所有文件
# You need to clear all files under the 'noise' folder first
delete_files(path=save_path)
dataloader = get_dataset(args=args)
# 重新创建文件夹
# Recreate the folder
os.makedirs(name=save_path, exist_ok=True)
# 扩散模型初始化
# Diffusion model initialization
diffusion = Diffusion(device="cpu")
# 获取图像和噪声Tensor
# Get image and noise tensor
image = next(iter(dataloader))[0]
time = torch.Tensor([0, 50, 125, 225, 350, 500, 675, 999]).long()

# 给图片分别增加噪声
# Add noise to the image
noised_image, _ = diffusion.noise_images(x=image, time=time)
# 保存噪声图片
# Save noise images
save_image(tensor=noised_image.add(1).mul(0.5), fp=os.path.join(save_path, "noise.jpg"))
logger.info(msg="Finish noising noising_test.")

Expand All @@ -113,7 +112,7 @@ def test_lr(self):

def test_summary(self):
"""
测试模型结构
Test model structure
:return: None
"""
image_size = 64
Expand Down
Loading

0 comments on commit 788138d

Please sign in to comment.