# In [1]:
import argparse
import os
import numpy as np
import torch
import cv2

from modeling.sync_batchnorm.replicate import patch_replication_callback
from modeling.deeplab import *

from dataloaders.utils import decode_segmap
# ---------------------------------------------------------------------------
# Fixed inference configuration. These values were formerly command-line
# options (an argparse parser, now removed); they are hard-coded here and below:
# backbone 'resnet', output stride 16, dataset 'pascal'.
# ---------------------------------------------------------------------------

# Mean image that is subtracted from every resized input image.
MI = np.load('meanimage.npy')

# Expose only GPU 0 to PyTorch (effective only before CUDA is initialised).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Number of model output labels per dataset name. 'pascal' is 6 here, which
# matches the 5 entries of `classes` plus label 0 (presumably background).
numClass = dict(pascal=6, coco=21, cityscapes=19)

# Class names; classes[i] is written out for predicted label i + 1.
classes = ['drusen', 'hemorrhage', 'exudate', 'scar', 'others']

# Structuring elements for per-class mask cleanup: `kernel` for the
# morphological opening, `kernel2` for the subsequent dilation.
kernel = np.ones((5, 5), dtype=np.uint8)
kernel2 = np.ones((7, 7), dtype=np.uint8)
# Build DeepLabV3+ (ResNet backbone, output stride 16) sized for the 'pascal'
# label set, load the best checkpoint from run/pascal/exp8, switch to eval mode.
cuda = torch.cuda.is_available()
nclass = numClass['pascal']
model = DeepLab(num_classes=nclass, backbone='resnet', output_stride=16,
                sync_bn=None, freeze_bn=False)
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only host; load_state_dict copies the weights onto the model's device.
weight_dict = torch.load('run/pascal/exp8/model_best.pth.tar', map_location='cpu')
if cuda:
    model = torch.nn.DataParallel(model, device_ids=[0])
    patch_replication_callback(model)
    model = model.cuda()

# DataParallel exposes the wrapped network as `.module`; without CUDA the
# model is not wrapped, so the weights are loaded into it directly.
net = model.module if isinstance(model, torch.nn.DataParallel) else model
net.load_state_dict(weight_dict['state_dict'])
model.eval()
# Create one output folder per class (val_result/<class>/). makedirs also
# creates val_result/ itself, which the per-image colour maps are written to,
# and exist_ok makes re-running the script harmless.
for c in classes:
    os.makedirs(os.path.join('val_result', c), exist_ok=True)

# Per-channel normalisation constants, applied after the BGR->RGB flip in
# _preprocess (hoisted out of the per-image loop).
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]


def _preprocess(image):
    """Convert an image from cv2.imread into the 1x3x513x513 float32 model input.

    Steps: resize to 513x513, subtract the mean image ``MI`` and divide by 255,
    reverse the channel order (cv2 loads BGR), then normalise each channel
    with ``means``/``stds``.
    """
    image = cv2.resize(image, dsize=(513, 513))
    image = (image.astype(np.float32) - MI) / 255.
    image = image[:, :, ::-1]
    # Vectorised per-channel (x - mean) / std. The constants take the image's
    # dtype so the arithmetic precision matches a scalar-per-channel loop.
    image = ((image - np.asarray(means, dtype=image.dtype))
             / np.asarray(stds, dtype=image.dtype))
    chw = np.ascontiguousarray(image.transpose((2, 0, 1)), dtype=np.float32)
    return torch.from_numpy(chw).unsqueeze(0)


def _class_mask(prediction, label):
    """Return a uint8 mask that is 0 on pixels predicted as ``label``, 255 elsewhere.

    The predicted region is cleaned with a morphological opening (``kernel``)
    and grown with a dilation (``kernel2``) before the mask is inverted.
    """
    fg = np.where(prediction == label, 255, 0).astype(np.uint8)
    fg = cv2.morphologyEx(fg, cv2.MORPH_OPEN, kernel)
    fg = cv2.morphologyEx(fg, cv2.MORPH_DILATE, kernel2)
    return 255 - fg


# Run inference on every image id listed in the validation split and write:
#   val_result/<class>/<id>.png  one mask per class (0 = class, 255 = other)
#   val_result/<id>.png          colour-coded segmentation from decode_segmap
with open('pascal/ImageSets/Segmentation/val.txt', 'r') as fp:
    data = fp.readlines()

for fn in data:
    # strip() also removes '\r' left by CRLF files; blank lines are skipped.
    fn = fn.strip()
    if not fn:
        continue

    imgPath = 'pascal/JPEGImages/' + fn + '.jpg'
    outPath = 'val_result/' + fn + '.png'

    image = cv2.imread(imgPath)
    if image is None:
        # cv2.imread returns None instead of raising for missing/unreadable files.
        raise FileNotFoundError('Cannot read image ' + imgPath)
    oriDim = image.shape
    height, width = oriDim[:2]

    image = _preprocess(image)
    # Inference runs on both devices; the input only moves to the GPU when the
    # model was placed there.
    if cuda:
        image = image.cuda()
    with torch.no_grad():
        output = model(image)
    prediction = np.argmax(output.cpu().numpy(), axis=1)[0]

    for cid, c in enumerate(classes):
        # classes[cid] corresponds to predicted label cid + 1; label 0 gets no mask.
        mask = _class_mask(prediction, cid + 1)
        mask = cv2.resize(mask, dsize=(width, height), interpolation=cv2.INTER_NEAREST)
        cv2.imwrite('val_result/' + c + '/' + fn + '.png', mask)

    segmap = decode_segmap(prediction, dataset='pascal')
    segmap = (segmap * 255).astype(np.uint8)
    # Nearest-neighbour (as for the class masks) so resizing never invents
    # colours that belong to no label at region boundaries.
    segmap = cv2.resize(segmap, dsize=(width, height), interpolation=cv2.INTER_NEAREST)
    # Reverse the channel order for cv2.imwrite, which expects BGR
    # (assumes decode_segmap returns RGB - confirm).
    cv2.imwrite(outPath, np.ascontiguousarray(segmap[:, :, ::-1]))
    print('Done inference ' + fn)

# Output of the cell above:
# Done inference A0036
# Done inference N0293
# Done inference A0013
# Done inference A0056
# Done inference N0174
# Done inference N0018
# Done inference A0060
# Done inference A0041
# Done inference N0290
# Done inference N0279
