In [15]:
import os, cv2, warnings
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from tqdm import tqdm
import torch
import torchvision
from sklearn.metrics import f1_score
import multiprocessing as mp

from utils.Dataset import CustomAugmentation, CustomDataset
from utils.models import ResNet50, UnetResnet50
from utils.utils import dense_crf_wrapper

from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image

# Class names, indexed by the integer labels the dataset yields.
label_to_str = ['1', '1+', '1++', '3', '2']

# Prefer the GPU when one is visible; `cpu` is kept as an explicit CPU handle.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
cpu = torch.device('cpu')

# NOTE(review): this silences *every* warning for the rest of the session,
# including library deprecation notices — consider narrowing the filter.
warnings.filterwarnings("ignore")

Model load

In [2]:
# Load the three trained networks (raw-image classifier, segmentation network,
# seg-masked classifier), move them to `device` and switch them to eval mode.
print('model loading start')

# map_location lets checkpoints that were saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects (seg1.pt stores a whole
# pickled model, not just a state_dict) — only load trusted checkpoint files.

# classification model
model_classification = ResNet50()
model_classification.load_state_dict(torch.load('./saved/best2.pt', map_location=device))
model_classification = model_classification.model.to(device)

# segmentation model (checkpoint is a full module; only its weights are reused)
model_segmentation = UnetResnet50()
model_segmentation.load_state_dict(torch.load('./saved/seg1.pt', map_location=device).state_dict())
model_segmentation = model_segmentation.model.to(device)

# classifier applied to segmentation-masked images
model_seg_classification = ResNet50()
model_seg_classification.load_state_dict(torch.load('./saved/seg_class1.pt', map_location=device))
model_seg_classification = model_seg_classification.model.to(device)

# An explicit dim=1 matches the implicit choice for 2-D/4-D inputs and avoids the
# deprecated implicit-dim behaviour.
softmax = torch.nn.Softmax(dim=1)
# ImageNet mean/std normalization
normalize = torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

model_classification.eval()
model_segmentation.eval()
model_seg_classification.eval()

print('model loading end')

model loading start
model loading end


In [3]:
def grad_cam_resnet50(model, tensor_img, label, device):
    # grad cam
    target_layers = [model.layer4[-1]]
    input_tensor = tensor_img
    cam = GradCAMPlusPlus(model=model, target_layers=target_layers, use_cuda=device)
    target_category = np.array(label.detach().cpu())

    grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)
    return grayscale_cam

In [12]:
# Fonts loaded by _label_font, keyed by point size.
_LABEL_FONT_CACHE = {}


def _label_font(size):
    """Return a cached Arial font of `size`, falling back to PIL's default font."""
    font = _LABEL_FONT_CACHE.get(size)
    if font is None:
        try:
            font = ImageFont.truetype('arial.ttf', size)
        except OSError:
            # arial.ttf is usually absent on Linux; degrade instead of failing the render.
            try:
                font = ImageFont.load_default(size=size)  # Pillow >= 10.1 accepts a size
            except TypeError:
                font = ImageFont.load_default()
        _LABEL_FONT_CACHE[size] = font
    return font


def draw_label_pred(img, label, pred, font_size=40):
    """Overlay the ground-truth and predicted class names on an image.

    Args:
        img: array accepted by ``Image.fromarray`` (e.g. an HxWx3 uint8 image).
        label: integer class index into ``label_to_str`` (ground truth).
        pred: integer class index into ``label_to_str`` (prediction).
        font_size: text size in points; defaults to the previous hard-coded 40.

    Returns:
        A new numpy array with the text drawn at the top-left corner.
    """
    canvas = Image.fromarray(img)
    text = f'label: {label_to_str[label]}\npred:{label_to_str[pred]}'
    ImageDraw.Draw(canvas).text((10, 10), text, font=_label_font(font_size))
    return np.array(canvas)

In [6]:
# Loader over the validation split.
# NOTE(review): the test-split cell below reassigns batch_size, val_transform,
# val_dataset and val_loader, so inference never consumes this loader.
batch_size = 32
val_transform = CustomAugmentation('val')
val_dataset = CustomDataset(
    data_path='../data/QCdataset',
    mode='val',
    transform=val_transform,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    # drop_last and collate_fn stay at their defaults (False / None)
)

Dataloader (test split)

In [13]:
# Loader over the held-out test split; the inference cell iterates this loader.
# The eval-time ('val') augmentation is reused for the test images.
batch_size = 4
val_transform = CustomAugmentation('val')
val_dataset = CustomDataset(data_path='../data/QCdataset',
                            mode='test',
                            transform=val_transform)
loader_pin_memory = torch.cuda.is_available()
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,       # keep file order so visualization rows are reproducible
    num_workers=4,
    pin_memory=loader_pin_memory,
)
del loader_pin_memory

In [16]:
# Evaluate three pipelines on every batch of `val_loader` and render a grid of the
# first `max_vis_rows` samples.
#   pipelines: raw-image classifier | seg-masked classifier | CRF-masked classifier
#              (plus the seg-masked classifier on GT-masked images, for reference)
#   grid columns: raw | seg prob | seg threshold | CRF | seg-masked | CRF-masked |
#                 Grad-CAM++ for each of the three pipelines
print('inference start')

img_size = 224     # NOTE(review): assumes every loader image is img_size x img_size
max_vis_rows = 64  # cap on rendered samples; also keeps the JPEG under its 65535 px height limit
num_img = len(val_loader)
num_vis_rows = min(len(val_loader.dataset), max_vis_rows)
vis_img = np.zeros((num_vis_rows * img_size, 9 * img_size, 3), np.uint8)

count = [0, 0, 0, 0]  # correct predictions: class, seg_class, crf_class, gt_class
total = 0             # number of samples evaluated so far
fig_idx = 0           # next vis_img row to fill


def _normalize_each(images):
    """Apply `normalize` to every CHW image of a batch and restack the result on `device`."""
    return torch.stack([normalize(img) for img in images], dim=0).to(device)


def _three_channel(maps):
    """Stack an (N, H, W) tensor three times along dim 1, giving (N, 3, H, W)."""
    return torch.stack([maps] * 3, dim=1)


def _to_nhwc(tensor):
    """Detach an NCHW tensor and return it as an NHWC numpy array on the CPU."""
    return tensor.detach().cpu().numpy().transpose(0, 2, 3, 1)


pbar = tqdm(enumerate(val_loader), total=num_img)
# One CRF worker pool for the whole run instead of one per batch; the context
# manager also tears it down if the loop raises.
with mp.Pool(mp.cpu_count()) as pool:
    for batch, (image, label, mask) in pbar:
        image, label, mask = image.to(device), label.to(device), mask.to(device)
        n_samples = label.size(0)  # the final batch may be smaller than batch_size

        # Plain forward passes need no autograd graph. Grad-CAM++ (further down,
        # outside this context) runs its own forward/backward passes.
        with torch.no_grad():
            # classification on the raw image
            pred_classification = model_classification(image).argmax(1)

            # segmentation: threshold the logits at 0 and keep softmax probabilities
            output_segmentation = model_segmentation(image)
            output_segmentation_th = torch.where(output_segmentation > 0., 1., 0.)
            probs_seg = torch.nn.functional.softmax(output_segmentation, dim=1).cpu().numpy()

            # min-max rescale (over the whole batch) to 0..255 for the CRF and for masking
            image_unnorm = 255 * torch.div(torch.add(image, -image.min()), torch.add(image.max(), -image.min()))
            images_rgb = image_unnorm.cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)

            # dense CRF refinement, one (image, probabilities) pair per pool task
            probs_crf = np.array(pool.map(dense_crf_wrapper, zip(images_rgb, probs_seg)))
            probs_crf = torch.tensor(probs_crf).to(device)

            # Mask the rescaled image with each segmentation. The GT mask is reshaped to
            # (N, H, W) instead of squeeze()d so a final batch of one keeps its batch axis.
            gt_mask = mask.reshape(n_samples, image.size(2), image.size(3))
            masked_seg_images = torch.mul(image_unnorm, _three_channel(output_segmentation_th[:, 1])).cpu()
            masked_crf_images = torch.mul(image_unnorm, _three_channel(probs_crf[:, 1])).cpu()
            masked_gt_images = torch.mul(image_unnorm, _three_channel(gt_mask)).cpu()

            # NOTE(review): the masked images are on a 0..255 scale while `normalize` uses
            # ImageNet statistics for 0..1 inputs — confirm this matches seg-class training.
            input_seg_classification = _normalize_each(masked_seg_images)
            input_crf_classification = _normalize_each(masked_crf_images)
            input_gt_classification = _normalize_each(masked_gt_images)

            pred_seg_classification = model_seg_classification(input_seg_classification).argmax(1)
            pred_crf_classification = model_seg_classification(input_crf_classification).argmax(1)
            pred_gt_classification = model_seg_classification(input_gt_classification).argmax(1)

        preds = (pred_classification, pred_seg_classification,
                 pred_crf_classification, pred_gt_classification)
        for k, pred in enumerate(preds):
            count[k] += (pred == label).sum().item()
        # Count real samples rather than batch_size so a short last batch is not over-counted.
        total += n_samples

        # tqdm advances itself while iterating; a manual pbar.update() would double-count.
        pbar.set_description(
            f'{batch+1}/{num_img} - '
            f'class: {100*count[0]/total}, '
            f'seg_class: {100*count[1]/total}, '
            f'crf_class: {100*count[2]/total}, '
            f'gt_class: {100*count[3]/total}, '
        )

        # Grid already full: skip the costly Grad-CAM++ and rendering work.
        if fig_idx >= num_vis_rows:
            continue

        # Grad-CAM++ heatmaps, targeting the ground-truth label
        grad_cam_class = grad_cam_resnet50(model_classification, image, label, device)
        grad_cam_seg = grad_cam_resnet50(model_seg_classification, input_seg_classification, label, device)
        grad_cam_crf = grad_cam_resnet50(model_seg_classification, input_crf_classification, label, device)

        # NHWC arrays for rendering
        raws = images_rgb
        probs = 255 * np.stack([probs_seg[:, 1]] * 3, axis=1).transpose(0, 2, 3, 1)
        ths = 255 * _to_nhwc(_three_channel(output_segmentation_th[:, 1]))
        crfs = 255 * _to_nhwc(_three_channel(probs_crf[:, 1]))
        masked_segs = _to_nhwc(masked_seg_images)
        masked_crfs = _to_nhwc(masked_crf_images)

        labels = label.tolist()
        cams_and_preds = (
            (grad_cam_class, pred_classification.tolist()),
            (grad_cam_seg, pred_seg_classification.tolist()),
            (grad_cam_crf, pred_crf_classification.tolist()),
        )
        for idx in range(min(n_samples, num_vis_rows - fig_idx)):
            raw01 = raws[idx] / 255
            visualizations = [
                draw_label_pred(show_cam_on_image(raw01, cam[idx, :], use_rgb=True),
                                labels[idx], pipeline_preds[idx])
                for cam, pipeline_preds in cams_and_preds
            ]
            row = np.concatenate((raws[idx], probs[idx], ths[idx], crfs[idx],
                                  masked_segs[idx], masked_crfs[idx], *visualizations), axis=1)
            vis_img[fig_idx * img_size:(fig_idx + 1) * img_size] = row
            fig_idx += 1

pbar.close()
print('inference end')
im = Image.fromarray(vis_img)
im.save("test_visualization.jpg")
print('save vis')


inference start


15/15 - class: 21.666666666666668, seg_class: 23.333333333333332, crf_class: 25.0, : 100%|██████████| 15/15 [00:15<00:00,  1.03s/it]             


inference end
save vis
