In [2]:
import os
import cv2
import numpy as np
from PIL import Image

import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [3]:
gpu_idx = [0]
model_dir = './checkpoints/2024-10-24/16_57_08/best_model.pt'

resize=128
# data_dir = 'dataset/hot_august'
data_dir = 'dataset/demo_crop'
test_size = 32

save_dir = 'outputs/dataset/demo'

In [4]:
def set_transforms(resize=128):
    # set up transforms
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((resize, resize)),
        ]
    )
    return transform

In [5]:
class ImageDataset(Dataset):
    def __init__(self, data_dir, root, save_dir, transform=None):
        self.data = []
        self.data_dir = data_dir
        self.root = root
        self.save_dir = save_dir
        self.transform = transform

        for folder in self.data_dir:
            for file in os.listdir(folder):
                if file.endswith(".jpg"):
                    self.data.append(os.path.join(folder, file))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_path = self.data[idx].replace(self.root, self.save_dir)
        folder_path = file_path.split('/')[:-1]
        folder_path = os.path.join(*folder_path)
        os.makedirs(folder_path, exist_ok=True)

        img_path = self.data[idx]
        image = cv2.imread(self.data[idx])

        if self.transform:
            image = self.transform(image)

        return image, file_path, img_path

In [6]:
# set inference model
if torch.cuda.is_available():
    torch.cuda.set_device(gpu_idx[0])
    device = torch.device(f"cuda:{gpu_idx[0]}")
else:
    device = torch.device("cpu")

checkpoint = torch.load(model_dir, map_location=device)
model = checkpoint['model']
model.load_state_dict(checkpoint['model_state_dict'])

if len(gpu_idx) >= 2:
    model = torch.nn.DataParallel(model, device_ids=gpu_idx).to(device)
    print(f"Using multiple GPUs: {gpu_idx}")
else:
    model = model.to(device)
    print(f"Using single GPU: cuda:{gpu_idx[0]}")

# set dataset
test_data = [f.path for f in os.scandir(data_dir) if f.is_dir()]
transform = set_transforms(resize)
test_dataset = ImageDataset(test_data, data_dir, save_dir, transform)
test_loader = DataLoader(test_dataset, batch_size=test_size, shuffle=False)

Using single GPU: cuda:0


In [136]:
if isinstance(model, torch.nn.DataParallel):
    model.module.eval()  # DataParallel을 사용 중인 경우
else:
    model.eval()

with torch.no_grad():
    for data, file_path, _ in test_loader:
        # inference data
        data, file_path = data.cuda(), file_path
        outputs = model(data)

        # create labeling data
        for idx, output in enumerate(outputs):
            prediction = F.softmax(output, dim=0).cpu().numpy() > 0.5
            prediction = prediction[1, :, :]

            rgb_image = np.zeros((prediction.shape[0], prediction.shape[1], 3), dtype=np.uint8)
            rgb_image[prediction] = [255, 255, 255]

            image = Image.fromarray(rgb_image)
            image.save(file_path[idx])
            print(f"Saved labeled image at: {file_path[idx]}")


Saved labeled image at: outputs/CE/CE~06_45_51/F003-0022.jpg


### HotStamping Visualization

In [7]:
if isinstance(model, torch.nn.DataParallel):
    model.module.eval()  # DataParallel을 사용 중인 경우
else:
    model.eval()

with torch.no_grad():
    for data, file_path, img_path in test_loader:
        # inference data
        data, file_path, img_path = data.cuda(), file_path, img_path
        outputs = model(data)

        # create labeling data
        for idx, output in enumerate(outputs):
            image = cv2.imread(img_path[idx])
            image = cv2.resize(image, (resize, resize))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            prediction = F.softmax(output, dim=0).cpu().numpy() > 0.5
            prediction = prediction[1, :, :]

            pink_mask = np.zeros_like(image)
            pink_mask[prediction] = [255, 0, 0]

            overlay_image = cv2.addWeighted(image, 1.0, pink_mask, 0.4, 0)

            cv2.imwrite(file_path[idx], cv2.cvtColor(overlay_image, cv2.COLOR_RGB2BGR))
            print(f"Saved labeled image at: {file_path[idx]}")


Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F003-0016.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F002-0019.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F003-0014.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F002-0024.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F001-0034.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F002-0013.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F002-0025.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F002-0022.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F001-0038.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F003-0010.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F001-0033.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F002-0016.jpg
Saved labeled image at: outputs/dataset/demo/GN7 파노라마~07_16_51/F003-0013.jpg