-
Notifications
You must be signed in to change notification settings - Fork 0
/
usage.py
50 lines (44 loc) · 1.76 KB
/
usage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from pathlib import Path
from PIL import Image
import os
import numpy as np
import torch
import cv2
import pytorch_lightning as pl
from modules.human_parsing_with_classifcation import HumanParsingWtihClassifcation
from dataset.human_parsing_dataset_syntetic import HumanParsingDatasetSyntetic
from configs.config import read_conf_file
def save_images(img_tensors, img_names, save_dir):
    """Save image tensors as PNG files named ``<save_dir>/<img_name>.png``.

    Each tensor is mapped linearly from ``[-1, 1]`` to ``[0, 255]`` via
    ``(x + 1) * 0.5 * 255``, clamped to ``[0, 255]``, truncated to ``uint8``
    and written with Pillow.

    Args:
        img_tensors: Iterable of tensors. 3-D tensors are treated as
            channels-first ``(C, H, W)``: C == 1 is squeezed to a grayscale
            ``(H, W)`` array, C in {3, 4} is moved to channels-last. 2-D tensors
            are saved as-is (grayscale). Other shapes are passed unchanged to
            ``Image.fromarray``.
        img_names: Iterable of output file stems, paired with ``img_tensors``
            by position; iteration stops at the shorter of the two.
        save_dir: Output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)
    for img_tensor, img_name in zip(img_tensors, img_names):
        # detach() so tensors that still track gradients can be turned into
        # NumPy arrays; the arithmetic below already allocates a new tensor.
        tensor = (img_tensor.detach() + 1) * 0.5 * 255
        array = tensor.cpu().clamp(0, 255).numpy().astype("uint8")
        # Only reinterpret the leading axis as channels for 3-D inputs; a 2-D
        # (H, W) image whose height happens to be 1 or 3 must not be touched.
        if array.ndim == 3:
            if array.shape[0] == 1:
                array = array.squeeze(0)
            elif array.shape[0] in (3, 4):
                # (C, H, W) -> (H, W, C), the layout Pillow expects.
                array = array.transpose(1, 2, 0)
        Image.fromarray(array).save(os.path.join(save_dir, f"{img_name}.png"))
def logits_to_label_image(logits, scale=10):
    """Turn per-class scores for one image into a 3-channel label visualisation.

    Args:
        logits: Tensor or array of per-class scores laid out channels-first,
            ``(C, H, W)``. The argmax is taken over the raw scores, so any
            real-valued logits or probabilities work.
        scale: Multiplier applied to each class id so neighbouring labels are
            visually distinguishable. Products above 255 are clipped to 255
            instead of wrapping around.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` whose three channels all hold
        ``min(class_id * scale, 255)``.
    """
    if torch.is_tensor(logits):
        scores = logits.detach().cpu().numpy()
    else:
        scores = np.asarray(logits)
    # argmax on the un-cast scores: casting to uint8 first would truncate
    # fractional / negative values and corrupt the predicted labels.
    labels = np.argmax(scores, axis=0)
    # Widen before scaling so class ids >= 26 do not overflow uint8.
    gray = np.clip(labels.astype(np.int64) * scale, 0, 255).astype(np.uint8)
    # Replicate the gray map into three identical channels.
    return np.repeat(gray[:, :, np.newaxis], 3, axis=2)


def main(
    config_path="configs/configs/densenet121_AdamW_GCC.yaml",
    weight_path="best/last.ckpt",
    output_dir="test",
):
    """Run a checkpointed model over the evaluation dataset and save label maps.

    Args:
        config_path: YAML config passed to ``read_conf_file``; its
            ``"dataset_paths"`` entry selects the data.
        weight_path: Lightning checkpoint to load.
        output_dir: Directory receiving one ``<image_name>.png`` per sample;
            created if it does not exist.

    Raises:
        OSError: If OpenCV fails to write an output image.
    """
    hparams = read_conf_file(yaml_path=config_path)
    # load_from_checkpoint is a classmethod that builds its own instance:
    # calling it on the class avoids constructing a throwaway model (newer
    # Lightning versions refuse the call on an instance).
    model = HumanParsingWtihClassifcation.load_from_checkpoint(
        weight_path, hparams=hparams
    ).eval()

    dataset = HumanParsingDatasetSyntetic(
        dataset_paths=hparams["dataset_paths"], train_mode=False
    )
    # batch_size=1: the loop reads only the first output and image name.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

    out_dir = Path(output_dir)
    # cv2.imwrite does not create missing directories; it just returns False.
    out_dir.mkdir(parents=True, exist_ok=True)

    with torch.no_grad():
        for batch in dataloader:
            output = model(batch=batch)
            # NOTE(review): assumes output[0] is (C, H, W) class scores for the
            # single sample in the batch, as the original transpose implied.
            vis = logits_to_label_image(output[0])
            out_path = out_dir.joinpath(f"{batch['image_name'][0]}.png").as_posix()
            if not cv2.imwrite(out_path, vis):
                raise OSError(f"Failed to write {out_path}")


if __name__ == "__main__":
    main()