# PanDerm - Skin Lesion Segmentation Evaluation (single-image inference and boundary visualization)

In [None]:
import os
import torch
import argparse
import cv2
import numpy as np
from torchvision import transforms
from skimage.segmentation import mark_boundaries
from PIL import Image
from models.cae_seg import CAEv2_seg
from utils.train_utils import largestConnectComponent

In [None]:
# Directory where the visualization is written (created if missing).
save_path = './'
os.makedirs(save_path, exist_ok=True)

# Input image and network input resolution; edit these to run on another image.
image_path = '/data2/wangzh/datasets/ISIC2018/Test_Data/ISIC_0012236.jpg'
input_size = 224

# load dataset
# Preprocessing: resize to the network input size, convert to a float tensor,
# then normalize each channel with (x - 0.5) / 0.5.
image_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# cv2.imread returns None (instead of raising) when the file is missing or
# unreadable; fail loudly here rather than with an opaque TypeError below.
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError('Could not read image: {}'.format(image_path))
# OpenCV loads channels as BGR; reverse the channel axis to get RGB for PIL.
image = image[..., ::-1]
# Bicubic resize here; the transforms.Resize above should then be a no-op
# because the image already has the target size.
image = cv2.resize(image, (input_size, input_size), interpolation=cv2.INTER_CUBIC)
image = Image.fromarray(np.uint8(image))
# Add a leading batch dimension for the model.
image = image_transform(image).unsqueeze(0)

In [None]:
# load model
model = CAEv2_seg()

# load model weights
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model_path = '/data/wangzh/experiments/skinfm/finals/cae_seg_isic18/lr_1e-4_decay_0.05_full/0/model_best_0.ckpt'
pretrained_dict = torch.load(model_path, map_location="cpu")
# Training checkpoints keep the weights under "state_dict"; if that key is
# absent, treat the loaded object itself as the state dict.
if "state_dict" in pretrained_dict:
    pretrained_dict = pretrained_dict["state_dict"]
model_dict = model.state_dict()
print('Model dict: ', model_dict.keys())
available_pretrained_dict = {}

# Keep only checkpoint tensors whose name AND shape match a model entry. Each
# key is tried as-is and with its first 6 characters removed.
# NOTE(review): 6 == len("model."), which looks like a training-wrapper
# attribute prefix; the slice is unconditional, so confirm the checkpoint keys.
for k, v in pretrained_dict.items():
    print('Pretrained dict: ', k)
    if k in model_dict and v.shape == model_dict[k].shape:
        available_pretrained_dict[k] = v
    stripped_key = k[6:]
    if stripped_key in model_dict and v.shape == model_dict[stripped_key].shape:
        available_pretrained_dict[stripped_key] = v

for k in available_pretrained_dict:
    print("loading {}".format(k))

# Model entries with no matching checkpoint tensor (missing name or shape
# mismatch) keep their freshly initialized values. Report them so a partially
# loaded model does not go unnoticed during evaluation.
not_loaded = [k for k in model_dict if k not in available_pretrained_dict]
if not_loaded:
    print("WARNING: {} of {} model entries were not loaded from the checkpoint "
          "(missing or shape mismatch):".format(len(not_loaded), len(model_dict)))
    for k in not_loaded:
        print("  not loaded: {}".format(k))

model_dict.update(available_pretrained_dict)
model.load_state_dict(model_dict)

In [None]:
# inference: switch the model to evaluation behaviour (train(False) is what
# eval() does) and run a single forward pass with autograd tracking disabled.
model.train(False)
with torch.set_grad_enabled(False):
    output = model(image)

In [None]:
# save result
# Undo Normalize(mean=0.5, std=0.5): x * 0.5 + 0.5, then reorder the tensor's
# (C, H, W) layout to (H, W, C) for skimage/OpenCV.
image_save = image.squeeze(0).cpu().detach().numpy()
image_save = image_save * 0.5 + 0.5
image_save = np.transpose(image_save, (1, 2, 0))

# Per-pixel class = argmax over the channel axis of the single batch item.
# New names are used so `output` keeps the raw model output and this cell can
# be re-run without re-running inference.
# NOTE(review): assumes output is (1, num_classes, H, W) at the input
# resolution (mark_boundaries needs the mask to match image_save) - confirm.
pred_mask = torch.argmax(output.squeeze(0), dim=0).cpu().detach().numpy()
# Project helper; presumably keeps only the largest connected region.
pred_mask = largestConnectComponent(pred_mask)

overlay = mark_boundaries(image_save, pred_mask, color=(0, 1, 1), mode='thick')
# Round and clip instead of truncating: the float normalize/denormalize round
# trip can land just below an integer (e.g. 199.99998), which astype(uint8)
# alone would truncate to one level darker.
overlay = np.clip(np.rint(overlay * 255), 0, 255).astype(np.uint8)

# The overlay is RGB and OpenCV expects BGR, so reverse the channel axis.
# cv2.imwrite reports failure via its return value rather than raising.
out_file = os.path.join(save_path, 'result.png')
if not cv2.imwrite(out_file, overlay[..., ::-1]):
    raise OSError('Failed to write {}'.format(out_file))