In [1]:
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import os
import multiprocessing as mp
import seaborn as sns

from code.models import basicunet, resnetunet
from code.datasets import TGSTTADataset
from code.configs import *
from code.train import *
from code.losses import FocalRobustLoss
from code.metrics import *
from code.augmentations import *
from code.utils import *
from torch.utils.data import DataLoader
from IPython.display import clear_output
from code.inference import *

%matplotlib inline

In [4]:
# Ensemble the saved test logits.
# Every directory in `dirs` is expected to contain, for each sample id, the
# files "<id>.npy" and "<id>_flipped.npy"; all of them are averaged per id.
dirs = ["../resnet32_256/test1/logits/", "../resnet32_256/test2/logits/", 
        "../resnet32_256/test3/logits/", "../resnet32_256/test4/logits/", 
        "../resnet32_256/test5/logits/"]
logits = []

# Sample ids are taken from the first directory only: strip the extension and
# anything after the first underscore (so "<id>_flipped.npy" maps to "<id>").
# Sorting makes the id order (and therefore the row order of `logits`)
# reproducible; iterating a bare set of strings is not stable across kernels.
ids = sorted({name.split(".")[0].split("_")[0] for name in os.listdir(dirs[0])})

for id_ in tqdm_notebook(ids):
    logit_list = []
    for dirname in dirs:
        for suffix, flipped in (("", False), ("_flipped", True)):
            logit = np.load(os.path.join(dirname, id_ + suffix + ".npy"))
            if flipped:
                # Reverse axis 1 -- presumably undoes the horizontal flip that
                # was applied before this prediction was made (TODO confirm).
                logit = logit[:, ::-1]
            logit_list.append(logit)

    # Average over all directories and both variants.
    logit = np.mean(np.array(logit_list), axis=0)
    # Drop a 27-pixel border on every side -- presumably the padding added
    # around the image before inference (TODO confirm) -- then resize to 101x101.
    logit = logit[27:-27, 27:-27]
    logit = cv2.resize(logit, dsize=(101, 101))
    logits.append(logit)
logits = np.array(logits)

HBox(children=(IntProgress(value=0, max=18000), HTML(value='')))

In [6]:
# Persist the ensembled logits (written to "logits.npy") together with the id
# order they were built in: row i of the logits array belongs to ids[i], and
# without the ids the saved array cannot be matched back to samples later.
np.save("logits", logits)
np.save("logits_ids", np.array(ids))

In [5]:
# Sanity check: the number of logit maps should equal the number of ids.
np.shape(logits), len(ids)

((18000, 101, 101), 18000)

In [7]:
def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)).

    Only exp(-|x|) is evaluated, which lies in (0, 1], so very negative inputs
    no longer overflow `np.exp` and trigger a RuntimeWarning.

    Parameters
    ----------
    x : scalar or array-like of numbers

    Returns
    -------
    NumPy scalar for scalar input, otherwise an ndarray with the shape of `x`.
    Floating-point inputs keep their dtype.
    """
    x = np.asarray(x)
    decay = np.exp(-np.abs(x))
    # For x >= 0 this is the textbook form; for x < 0 numerator and denominator
    # are both multiplied by exp(x), which is mathematically the same value.
    result = np.where(x >= 0, 1. / (1. + decay), decay / (1. + decay))
    # Indexing with () turns a 0-d result back into a scalar and leaves
    # higher-dimensional arrays unchanged.
    return result[()]

In [134]:
# Map the ensembled logits to probabilities with the element-wise sigmoid.
probs = sigmoid(logits)

In [135]:
# Binarise: 1 where the probability is strictly above 0.5, else 0.
preds = np.greater(probs, 0.5).astype(int)

In [139]:
# Write each binary prediction as a 3-channel PNG under ./test_pl/masks/ and
# hard-link the corresponding test image into ./test_pl/images/.
# The original loop also zipped over `conf`, which is never defined in this
# notebook and whose values were unused; it is dropped so the cell runs on a
# fresh kernel.
mask_dir = "./test_pl/masks/"
image_dir = "./test_pl/images/"
os.makedirs(mask_dir, exist_ok=True)
os.makedirs(image_dir, exist_ok=True)

for p, id_ in zip(preds, ids):
    # Replicate the single-channel mask into three identical channels.
    # NOTE(review): `p` is an integer 0/1 array; confirm that plt.imsave in the
    # installed matplotlib version writes it with the intended value range.
    p = np.stack((p, p, p), axis=-1)
    target_name = "{}.png".format(id_)
    src_name = os.path.join(PATH_TO_TEST, "images", target_name)
    plt.imsave(os.path.join(mask_dir, target_name), p)
    link_name = os.path.join(image_dir, target_name)
    # os.link raises if the destination already exists; skip it so the cell
    # can be re-run.
    if not os.path.lexists(link_name):
        os.link(src_name, link_name)