## 模型预测分析、可视化

In [None]:
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import math
from matplotlib import pyplot as plt
from tqdm import tqdm, tqdm_notebook
import pickle
import os
import logging
import time
import gc
from IPython.core.debugger import set_trace

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.utils import save_checkpoint, load_checkpoint, set_logger
from utils.gpu_utils import set_n_get_device

#from model.deeplab_model.deeplab import DeepLab, predict_proba
from model.model_unet import UNetResNet34, predict_proba
from dataset.dataset_unet import prepare_trainset

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), element-wise for arrays."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def inverse_sigmoid(x):
    """Return the logit log(x / (1 - x)), i.e. the inverse of `sigmoid`."""
    odds = x / (1 - x)
    return np.log(odds)

%matplotlib inline

In [None]:
######### Config the training process #########
MODEL = 'UNetResNet34'  # alternative: 'resnet' (DeepLab backbone)
print('====MODEL ARCHITECTURE: %s====' % MODEL)

device = set_n_get_device("0", data_device_id="cuda:0")
multi_gpu = None  # e.g. [0, 1] to wrap the net in nn.DataParallel; None = single device

SEED = 2019
debug = False
IMG_SIZE = 512
BATCH_SIZE = 8  # alternative: 32
NUM_WORKERS = 24

# Seed every RNG this notebook can touch. Previously only the CUDA generators
# were seeded, leaving the CPU torch RNG and NumPy's global RNG unseeded.
# torch.manual_seed seeds the CPU and all CUDA devices; manual_seed_all is kept
# for explicitness.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
# Build the train/validation DataLoaders (train_dl is not used by the cells shown here).
train_dl, val_dl = prepare_trainset(
    BATCH_SIZE,
    NUM_WORKERS,
    SEED,
    IMG_SIZE,
    debug,
)

In [None]:
# Gather all ground-truth validation masks into one NumPy array, in the order
# val_dl yields them; later cells pair them with preds_valid by position.
y_valid = []
for _, masks in val_dl:
    # Cast on the CPU and convert directly; the old GPU round trip
    # (.to(device) followed by .cpu()) produced the same float32 values.
    y_valid.append(masks.float().numpy())
y_valid = np.concatenate(y_valid, axis=0)
y_valid.shape

In [None]:
# Model under evaluation. The DeepLabV3+ alternative (imported from
# model.deeplab_model.deeplab together with its own predict_proba) was built as:
#   DeepLab(num_classes=2, backbone=MODEL, output_stride=16,  # resnet34/resnet101; stride 16 or 8
#           sync_bn=None, freeze_bn=False, debug=False).cuda(device=device)
net = UNetResNet34(debug=False).cuda(device=device)

# Weights to load (DeepLab alternative:
# '../checkpoint/deeplabv3plus_resnet_1280_v2_seed2345/best.pth.tar').
checkpoint_path = '../checkpoint/UNetResNet34_512_v1_seed2019/best.pth.tar'
net, _ = load_checkpoint(checkpoint_path, net)

# Wrap for multi-GPU inference only when a device list has been configured.
net = net if multi_gpu is None else nn.DataParallel(net, device_ids=multi_gpu)

In [None]:
%%time
# Raw network outputs for the validation set (predict_mask below applies
# sigmoid to them, so they are treated as logits).
preds_valid = predict_proba(
    net, val_dl, device,
    multi_gpu=multi_gpu, mode='valid', tta=True,
)

In [None]:
# Sanity check: ground truth and predictions should cover the same samples.
(y_valid.shape, preds_valid.shape)

In [None]:
# ## search for best thresholds
# def calculate_dice(logit, truth, EMPTY_THRESHOLD=400, MASK_THRESHOLD=0.22):
#     IMG_SIZE = logit.shape[-1] #256
#     logit = sigmoid(logit)#.reshape(n, -1)
#     pred = (logit>MASK_THRESHOLD).astype(np.int)
#     pred_clf = (pred.reshape(pred.shape[0], -1).sum(axis=1)<EMPTY_THRESHOLD).astype(np.int)
#     pred[pred_clf.reshape(-1,)==1, ] = 0
#     return dice_overall(pred, truth)

# def dice_overall(pred_mask, truth_mask, eps=1e-8):
#     n = pred_mask.shape[0]
#     pred_mask = pred_mask.reshape(n, -1)
#     truth_mask = truth_mask.reshape(n, -1)
#     intersect = (pred_mask * truth_mask).sum(axis=1).astype(np.float)
#     union = (pred_mask + truth_mask).sum(axis=1).astype(np.float)
#     return ((2.0*intersect + eps) / (union+eps)).mean()

In [None]:
# EMPTY_THRESHOLD_candidate = np.arange(1400, 1520, 20)#for 512
# #np.arange(350, 450, 10) #for 256
# #np.arange(6000, 7000, 100)#for 1024
# #np.arange(2900, 4200, 100)#for 768
# MASK_THRESHOLD_candidate = np.arange(0.18, 0.23, 0.01)#np.arange(0.19, 0.27, 0.01)
# M, N = len(EMPTY_THRESHOLD_candidate), len(MASK_THRESHOLD_candidate)
# best_threshold = None
# best_score = 0

# for i in tqdm_notebook(range(M)):
#     EMPTY_THRESHOLD = EMPTY_THRESHOLD_candidate[i]
#     for j in range(N):
#         MASK_THRESHOLD = MASK_THRESHOLD_candidate[j]
#         dice_score = calculate_dice(preds_valid, y_valid.squeeze(1), EMPTY_THRESHOLD, MASK_THRESHOLD)
#         print('CLF_EMPTY_THRESHOLD: %f, MASK_THRESHOLD: %f, dice_score: %f'%(EMPTY_THRESHOLD, MASK_THRESHOLD, dice_score))
#         if dice_score>best_score:
#             best_threshold = [EMPTY_THRESHOLD, MASK_THRESHOLD]
#             best_score = dice_score

In [None]:
# Post-processing thresholds. The grid search above is commented out; to use
# its result instead: EMPTY_THRESHOLD, MASK_THRESHOLD = best_threshold
# NOTE(review): that search lists 1400-1520 as the EMPTY_THRESHOLD range for
# 512px inputs and 350-450 for 256px; 400 looks like the 256px value while
# IMG_SIZE is 512 here — confirm this is intended.
EMPTY_THRESHOLD = 400   # fewer positive pixels than this -> mask is zeroed
MASK_THRESHOLD = 0.21   # per-pixel probability cut-off
best_score = -1         # placeholder: no search was run
EMPTY_THRESHOLD, MASK_THRESHOLD, best_score

In [None]:
def predict_mask(logit, EMPTY_THRESHOLD, MASK_THRESHOLD):
    """Turn one raw prediction into a binary 0/1 mask.

    The logits are passed through ``sigmoid`` and thresholded at
    ``MASK_THRESHOLD``. If fewer than ``EMPTY_THRESHOLD`` pixels are positive,
    the prediction is treated as empty and an all-zero mask is returned
    (presumably to suppress small spurious detections).

    Args:
        logit: array of raw model outputs for a single image, any shape
            (e.g. (IMG_SIZE, IMG_SIZE)).
        EMPTY_THRESHOLD: minimum number of positive pixels for the mask to be kept.
        MASK_THRESHOLD: probability cut-off applied to every pixel.

    Returns:
        Integer array with the same shape as ``logit`` holding 0s and 1s.
    """
    # ``np.int`` was only an alias of the builtin ``int`` and was removed in
    # NumPy 1.24; ``int`` gives the identical dtype on every version.
    pred = (sigmoid(logit) > MASK_THRESHOLD).astype(int)
    if pred.sum() < EMPTY_THRESHOLD:
        return np.zeros_like(pred)
    return pred

In [None]:
## visualize predicted masks for validation images whose target is non-empty
start = 5   # number of leading batches to skip
rows = 10   # number of figures (one per image) to draw

cnt = 0
offset = 0  # index in preds_valid of the current batch's first sample
for idx, (img, mask) in enumerate(val_dl):
    batch_len = len(img)  # the last batch may be smaller than BATCH_SIZE
    if idx < start:
        offset += batch_len
        continue
    for j in range(batch_len):
        if mask[j][0].sum() == 0:
            continue  # only images with a non-empty target are shown
        # assumes val_dl yields samples in the same order as during
        # predict_proba (no shuffling) — TODO confirm in prepare_trainset
        pred_mask = predict_mask(preds_valid[offset + j], EMPTY_THRESHOLD, MASK_THRESHOLD)
        fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(12, 4))
        ax0.imshow(img[j][0].numpy(), cmap=plt.cm.bone)
        ax1.imshow(mask[j][0].numpy(), vmin=0, vmax=1, cmap="Reds")
        ax2.imshow(pred_mask, vmin=0, vmax=1, cmap="Blues")
        ax1.set_title('Targets(Has Mask)')
        ax2.set_title('Predictions')
        cnt += 1
        if cnt >= rows:  # the old `cnt > rows` drew rows + 1 figures
            break
    offset += batch_len
    if cnt >= rows:
        break

In [None]:
# ## visualize predicted masks
# rows = 20

# cnt = 0
# for idx, (img, mask) in enumerate(val_dl):
#     for j in range(BATCH_SIZE):#BATCH_SIZE=8
#         is_empty = mask[j][0].sum()==0
#         if is_empty:
#             cnt+=1
#             pred_mask = predict_mask(preds_valid[idx*BATCH_SIZE+j], EMPTY_THRESHOLD, MASK_THRESHOLD)
#             #if pred_mask.sum()==0:
#             #    continue
#             fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(12, 4))
#             ax0.imshow(img[j][0].numpy(), plt.cm.bone)
#             ax1.imshow(mask[j][0], vmin=0, vmax=1, cmap="Reds")
#             ax2.imshow(pred_mask, vmin=0, vmax=1, cmap="Blues")
#             if is_empty.item():
#                 ax1.set_title('Targets(Empty Mask)')
#             else:
#                 ax1.set_title('Targets(Has Mask)')
#             ax2.set_title('Predictions')
#         if cnt>rows:
#             break
#     if cnt>rows:
#             break

In [None]:
# s = (sigmoid(preds_valid)>MASK_THRESHOLD).reshape(1064, -1).sum(axis=1)
# (s>1420).mean(), (s>0).mean()

## 测试集

In [None]:
import glob
from dataset.dataset_unet import prepare_testset

In [None]:
# Test-image ids: file name with directory and extension stripped. Works for
# any extension length and for the OS's own path separator.
# NOTE(review): glob order is arbitrary; confirm it matches the order in which
# test_dl yields images before pairing these ids with preds_test.
test_fnames = [
    os.path.splitext(os.path.basename(f))[0]
    for f in glob.glob('../data/processed/test/*')
]
len(test_fnames), (test_fnames[0] if test_fnames else None)

In [None]:
# Test-set DataLoader, using the same batch size, worker count and image size as above.
test_dl = prepare_testset(
    BATCH_SIZE,
    NUM_WORKERS,
    IMG_SIZE,
)

In [None]:
%%time
# Raw network outputs for the test set; they are post-processed with
# predict_mask (sigmoid + thresholds) in the next cells.
preds_test = predict_proba(
    net, test_dl, device,
    multi_gpu=multi_gpu, mode='test', tta=True,
)

In [None]:
# Display the shape of the test predictions.
np.shape(preds_test)

In [None]:
## visualize predicted masks overlaid on the first test images
start = 0   # number of leading batches to skip
total = 20  # number of test images to show (the old `total = 19` with `cnt > total` also drew 20)
ncols = 4
nrows = math.ceil(total / ncols)

fig = plt.figure(figsize=(3.75 * ncols, 4 * nrows))  # (15, 20) for the default 5x4 grid
cnt = 0
offset = 0  # index in preds_test of the current batch's first sample
for idx, img in enumerate(test_dl):
    batch_len = len(img)  # the last batch may be smaller than BATCH_SIZE
    if idx < start:
        offset += batch_len
        continue
    for j in range(batch_len):
        pred_mask = predict_mask(preds_test[offset + j], EMPTY_THRESHOLD, MASK_THRESHOLD)
        cnt += 1
        ax = fig.add_subplot(nrows, ncols, cnt)
        ax.imshow(img[j][0].numpy(), cmap=plt.cm.bone)
        ax.imshow(pred_mask, alpha=0.3, cmap="Reds")
        ax.set_title('Predict Mask' if pred_mask.sum() > 0 else 'Predict Empty')
        if cnt >= total:
            break
    offset += batch_len
    if cnt >= total:
        break