# Imports

In [1]:
%load_ext autoreload
%autoreload 2
from IPython.core.display import display, HTML
display(HTML("<style>.container {width:100% !important;}</style>"))

In [2]:
import os
import cv2
import numpy as np
from tqdm import tqdm
from ipywidgets import interact
import ipywidgets as widgets
import matplotlib.pyplot as plt
import albumentations as albu
import albumentations.pytorch as albu_pt
from sklearn.metrics import jaccard_score
%matplotlib inline

import apex
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import pytorch_tools as pt

from src.dataset import OpenCitiesDataset, OpenCitiesTestDataset, InriaTilesDataset, get_dataloaders
from src.augmentations import get_aug
from src.callbacks import ThrJaccardScore
from pytorch_tools.fit_wrapper.callbacks import SegmCutmix
from pytorch_tools.tta_wrapper import TTA

In [3]:
import yaml
from src.utils import MODEL_FROM_NAME
from src.utils import TargetWrapper
from pytorch_tools.fit_wrapper.callbacks import Callback
from pytorch_tools.utils.misc import to_numpy
from src.utils import criterion_from_list
from src.utils import ToCudaLoader
from src.dataset import get_aug

# Validation and test datasets

In [4]:
# BS = 32
# train_dtld_gpu , val_dtld_gpu = get_dataloaders(
#     [
#         "tier1", 
# #         "tier2"
#     ],
#     batch_size=BS
# )

In [5]:
# One validation-time transform (size=384), shared by every validation dataset.
val_aug = get_aug(
    "val",
    size=384,
)

In [6]:
# Validation splits of the tier-1 / tier-2 tiles, plus the test set at two input sizes.
_TIER1_PATHS = dict(imgs_path="data/tier_1-images-512", masks_path="data/tier_1-masks-512")
_TIER2_PATHS = dict(imgs_path="data/tier_2-images-512", masks_path="data/tier_2-masks-512")

val_tier1 = OpenCitiesDataset(split="val", transform=val_aug, **_TIER1_PATHS)
# Same tiles as val_tier1, constructed with buildings_only=True.
val_tier1_bo = OpenCitiesDataset(split="val", transform=val_aug, buildings_only=True, **_TIER1_PATHS)
val_tier2 = OpenCitiesDataset(split="val", transform=val_aug, **_TIER2_PATHS)

test_dataset512 = OpenCitiesTestDataset(transform=get_aug("test", size=512))
test_dataset384 = OpenCitiesTestDataset(transform=get_aug("test", size=384))

# Results exploration

In [7]:
def auto_canny(image, sigma=0.33, fixed_thr=False):
    """Run Canny edge detection on an image or a single-channel map.

    Args:
        image: HxW array, or HxWx3 RGB array. Arrays whose max is <= 1 are
            treated as being in [0, 1] and scaled to [0, 255]. Any other
            non-uint8 array is clipped to [0, 255] and cast, because
            ``cv2.Canny`` only accepts 8-bit input.
        sigma: relative width of the hysteresis band around the median
            intensity. Only used when ``fixed_thr`` is False.
        fixed_thr: if True, use the fixed thresholds (160, 200) instead of
            median-based ones.

    Returns:
        Float edge map with values in {0, 1}. It is stacked to HxWx3 when the
        input had 3 channels.
    """
    if image.max() <= 1:
        image = (image * 255).astype(np.uint8)
    elif image.dtype != np.uint8:
        # Previously such arrays reached cv2.Canny unconverted and raised.
        image = np.clip(image, 0, 255).astype(np.uint8)

    is_image = image.ndim == 3 and image.shape[2] == 3
    if is_image:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    if fixed_thr:
        lower, upper = 160, 200
    else:
        # Thresholds at (1 - sigma) and (1 + sigma) times the median intensity.
        median = np.median(image)
        lower = int(max(0, (1.0 - sigma) * median))
        upper = int(min(255, (1.0 + sigma) * median))

    edges = cv2.Canny(image, lower, upper) / 255
    if is_image:
        return np.stack([edges] * 3, axis=2)
    return edges

def remove_model_from_name(state_dict, prefix="model."):
    """Return a copy of ``state_dict`` with ``prefix`` stripped from its keys.

    Checkpoints of a network that was stored as the ``model`` attribute of a
    wrapper have keys like ``model.encoder.weight``; this maps them back to
    ``encoder.weight``. Keys that do not start with ``prefix`` are kept as-is.

    Args:
        state_dict: mapping of parameter names to tensors.
        prefix: leading key fragment to remove (default ``"model."``).

    Returns:
        A new dict; the input is not modified.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

def process_adaptive_thr(pred, min_prob=0.2):
    """Binarize a probability map with Otsu's threshold.

    Args:
        pred: probability map. It is scaled by 255 and cast to uint8, so
            values are expected in [0, 1].
        min_prob: pixels with ``pred < min_prob`` are always set to 0, even
            if Otsu's threshold would keep them. This guards against Otsu
            picking a very low threshold on a mostly-empty map.

    Returns:
        Float array shaped like ``pred`` with values in {0., 1.}.
    """
    pred_u8 = (pred * 255).astype(np.uint8)
    # With THRESH_OTSU the given threshold (0) is ignored; OpenCV picks it from the histogram.
    _, otsu_mask = cv2.threshold(pred_u8, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    thr_mask = otsu_mask / 255
    thr_mask[pred < min_prob] = 0
    return thr_mask

In [8]:
# Module-level state shared with the interactive explorer cells.
PREV_WEIGHTS = None  # run name whose weights are currently loaded in PREV_MODEL
PREV_MODEL = None
SINGLE_IMG = None
SINGLE_MASK = None
# The explorers declare `global SINGLE_PRED` and assign it. Initialise it here
# so reading it before a widget has run does not raise NameError.
SINGLE_PRED = None
# Normalisation stats (the standard ImageNet values), used to undo input
# normalisation for display: img * STD + MEAN.
MEAN=(0.485, 0.456, 0.406)
STD=(0.229, 0.224, 0.225)

In [9]:
# Datasets selectable in the explorer widgets; insertion order is the dropdown order.
DATASETS = {
    "val_tier1": val_tier1,
    "val_tier1_bo": val_tier1_bo,
    "val_tier2": val_tier2,
    "test512": test_dataset512,
    "test384": test_dataset384,
}
OFFSET = 1000  # lower bound of the ensemble explorer's N slider

In [10]:
def _load_explorer_model(weights):
    """Build the network described by ``logs/<weights>/config.yaml`` and load its checkpoint.

    Returns:
        The model on GPU in eval mode.

    Loading is non-strict, so unmatched keys are ignored without error.
    """
    log_path = os.path.join("logs", weights)
    with open(os.path.join(log_path, "config.yaml")) as config_file:
        # A bare `yaml.load(stream)` warns on PyYAML 5.x and raises TypeError on 6.x.
        # `full_load` is what the bare call defaulted to.
        config = yaml.full_load(config_file)
    model = MODEL_FROM_NAME[config["segm_arch"]](config["arch"], **config.get("model_params", {})).cuda()
    state_dict = torch.load(os.path.join(log_path, "model.chpn"))["state_dict"]
    keys = list(state_dict)
    if keys and keys[0].startswith("model."):
        # The model was trained inside a TTA wrapper: strip only the leading "model.".
        # A substring test plus blind k[6:] could mangle keys, and with strict=False
        # the mangled keys would simply be ignored.
        state_dict = {
            (k[len("model."):] if k.startswith("model.") else k): v
            for k, v in state_dict.items()
        }
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model


@interact(
    weights=sorted(os.listdir("logs/")),
    N=widgets.IntSlider(min=0, max=32, continuous_update=False),
    offset=(0, 2000, 100),
    thr=widgets.FloatSlider(0.5, min=0.2, max=0.8, step=0.1, continuous_update=False),
)
def foo(
    dataset=DATASETS.keys(),
    weights=None, N=0, offset=0, thr=0.5, adaptive_thr=False, use_tta=False, 
    equilize=False, detect_edges=False, watershed=False, overlay_pred=False, overlay_true=False,
    use_gmean=False,
):
    """Interactive single-model explorer.

    Shows side by side: the de-normalised image (optionally with masks
    overlaid), the mask, the probability map and the binarised prediction.
    The Jaccard score appears in the title.

    Args:
        dataset: key of ``DATASETS`` to sample from.
        weights: run directory under ``logs/``. The model is reloaded only
            when this changes.
        N, offset: the sample index is ``N + offset``.
        thr: fixed binarisation threshold (overridden when ``adaptive_thr``).
        adaptive_thr: binarise with Otsu via ``process_adaptive_thr``.
        use_tta: wrap the model in TTA (h-flip + 90 deg rotation).
        use_gmean: with ``use_tta``, merge with ``"gmean"`` instead of ``"mean"``.
        equilize: histogram-equalise the displayed image.
        detect_edges: replace image and prediction by their Canny edges.
        watershed: currently unused.
        overlay_pred: paint predicted pixels red on the image.
        overlay_true: paint ground-truth pixels blue on the image. Pixels in
            both masks become magenta.
    """
    global PREV_WEIGHTS, PREV_MODEL, SINGLE_IMG, SINGLE_PRED

    if weights is None:
        print("select weights")
        return

    if weights != PREV_WEIGHTS:
        # Update the cache key only after a successful load, so a failed load
        # cannot leave the previous model labelled with the new weights.
        PREV_MODEL = _load_explorer_model(weights)
        PREV_WEIGHTS = weights

    if "test" in dataset:
        # Test samples have no ground truth; use an all-zero mask.
        _, img, _ = DATASETS[dataset][N + offset]
        mask = torch.zeros_like(img)
    else:
        img, mask = DATASETS[dataset][N + offset]

    with torch.no_grad():
        if use_tta:
            model = pt.tta_wrapper.TTA(PREV_MODEL, segm=True, h_flip=True, rotation=[90], merge="gmean" if use_gmean else "mean", activation="sigmoid")
        else:
            model = PREV_MODEL
        pred = model(img.view(1, *img.shape).cuda())
        if not use_tta:
            # The TTA wrapper is built with activation="sigmoid"; the raw model is not.
            pred.sigmoid_()

    img = to_numpy(img).swapaxes(0, 2)
    img = np.clip((img * STD + MEAN), 0, 1)
    if equilize:
        img = albu.Equalize(always_apply=True, by_channels=False)(image=(img * 255).astype(np.uint8))["image"] / 255

    mask = to_numpy(mask).swapaxes(0, 2)
    # Rescale the first two mask channels for display (presumably stored in [-1, 1] - TODO confirm).
    mask[:, :, :2] = (mask[:, :, :2] + 1) * 0.5
    pred = to_numpy(pred.squeeze())
    if detect_edges:
        img = auto_canny(img, fixed_thr=True, sigma=0.5)
        pred = auto_canny(pred, fixed_thr=True)
    thr_mask = (pred > thr).astype(np.uint8)
    if adaptive_thr:
        thr_mask = process_adaptive_thr(pred)

    # img/mask were swapaxes'd (H and W exchanged), so the 2-D prediction is transposed to match.
    score = jaccard_score(mask[:,:,2], thr_mask.T, average="micro")
    if mask[:,:, 2].sum() == 0 and thr_mask.sum() == 0:
        # Both masks empty: the score is ill-defined, but the prediction is perfect.
        score = 1
    if overlay_pred:
        img[thr_mask.T == 1] = [1, 0, 0]
    if overlay_true:
        img[mask[:, :, 2] == 1] = [0, 0, 1]
        if overlay_pred:
            img[(mask[:, :, 2] == 1) & (thr_mask.T == 1)] = [1, 0, 1]
    pred = np.stack([pred.T] * 3, 2)
    thr_mask = np.stack([thr_mask.T] * 3, 2)
    stacked = np.hstack([img, mask, pred, thr_mask])
    plt.figure(figsize=(32,8))
    plt.imshow(stacked, cmap="gray")
    plt.title(f"Jaccard={score:.3f}", fontdict={"fontsize": 25})
    plt.axis("off")
    SINGLE_IMG = pred
    SINGLE_PRED = thr_mask

interactive(children=(Dropdown(description='dataset', options=('val_tier1', 'val_tier1_bo', 'val_tier2', 'test…

# Test Results Exploration

In [11]:
!ls logs/

10.bifpn_2l_2dtst_reduced_focal_20200303_202324
10.bifpn_2l_3dtst_reduced_focal_20200303_205035
11.bifpn_2l_2dtst_lookahead_cut_20200304_180726
11.bifpn_2l_2dtst_seres101_20200304_233016
11.bifpn_2l_2dtst_seres101_resume_20200305_111319
12.bifpn_3l_1dtst_effnetb3_20200305_134355
12.bifpn_3l_1dtst_seres50_tta_20200305_172219
13.bifpn_3l_3dtst_effnetb3_resume_20200305_230828
13.bifpn_4l_3dtst_effnetb5_hard_augs_20200306_133212
13.bifpn_4l_3dtst_seres101_better_hard_augs_20200307_113601
13.bifpn_4l_3dtst_seres101_hard_augs_20200307_113309
13.deeplab_3dtst_seres101_hard_augs_20200311_202325
13.test_bifpn_4l_3dtst_effnetb5_20200305_231118
14.tune_2dtst_bifpn_4l_effnetb5_20200306_184111
14.tune_2dtst_bifpn_4l_effnetb5_hard_augs_20200306_232835
14.tune_2dtst_bifpn_4l_seres101_hard_augs_20200310_210420
14.tune_2dtst_bifpn_4l_seres101_hard_augs_OS16_20200313_123409
14.tune_2dtst_deeplab_seres101_hard_augs_20200312_113206
15.tune_1dtst_bifpn_4l_effnetb5_hard_augs_buildings_only

In [12]:
@torch.no_grad()
def load_from_path(path):
    """Load the run ``logs/<path>`` as an inference-ready ensemble member.

    Args:
        path: run directory under ``logs/`` containing ``config.yaml`` and
            the ``model.chpn`` checkpoint.

    Returns:
        The result of ``apex.amp.initialize`` on a TTA wrapper (h-flip +
        90 deg rotation, mean merge, sigmoid activation) around the model in
        eval mode. The checkpoint is loaded strictly.
    """
    log_path = os.path.join("logs", path)
    with open(os.path.join(log_path, "config.yaml")) as config_file:
        # A bare `yaml.load(stream)` warns on PyYAML 5.x and raises TypeError on 6.x.
        # `full_load` is what the bare call defaulted to.
        config = yaml.full_load(config_file)
    model = MODEL_FROM_NAME[config["segm_arch"]](config["arch"], **config.get("model_params", {})).cuda()
    state_dict = torch.load(os.path.join(log_path, "model.chpn"))["state_dict"]
    model.load_state_dict(state_dict, strict=True)
    model = pt.tta_wrapper.TTA(
        model.eval(), segm=True, h_flip=True, rotation=[90], merge="mean", activation="sigmoid"
    )
    return apex.amp.initialize(model, verbosity=False)

# Run directories (under logs/) whose checkpoints are ensembled below.
# Earlier candidates, kept for reference:
#     "10.bifpn_2l_2dtst_reduced_focal_20200303_202324",
#     "11.bifpn_2l_2dtst_lookahead_cut_20200304_180726",
best_models = [
    "14.tune_2dtst_bifpn_4l_effnetb5_20200306_184111",
    "14.tune_2dtst_bifpn_4l_effnetb5_hard_augs_20200306_232835",
    "14.tune_2dtst_bifpn_4l_seres101_hard_augs_20200310_210420",
    "14.tune_2dtst_bifpn_4l_seres101_hard_augs_OS16_20200313_123409",
    "15.tune_1dtst_bifpn_4l_effnetb5_hard_augs_buildings_only_20200311_001521",
    "15.tune_1dtst_bifpn_4l_seres101_hard_augs_buildings_only_20200311_101635",
    "15.tune_1dtst_bifpn_4l_seres101_hard_augs_OS16_20200314_122204",
]
loaded_best_models = [load_from_path(run_name) for run_name in best_models]

In [13]:
def process_single_img(img, merge_by_mean=True, adaptive_thr=True, thr=0.5, models=None):
    """Run an ensemble on one image tensor and binarise the result.

    Args:
        img: a single image tensor without batch dim (a batch dim of 1 is
            added before inference; assumed CHW - TODO confirm).
        merge_by_mean: if True, binarise the mean probability map. Otherwise
            binarise each model's map and take the per-pixel median, i.e. a
            majority vote where ties go to 0.
        adaptive_thr: binarise with Otsu (``process_adaptive_thr``) instead of ``thr``.
        thr: fixed threshold used when ``adaptive_thr`` is False.
        models: models to ensemble. Defaults to the global ``loaded_best_models``.

    Returns:
        ``(pred, thr_mask)``: the mean probability map and a uint8 {0, 1} mask.
        The dtype is now uint8 on every branch; the mean+adaptive branch used
        to return float64.
    """
    if models is None:
        models = loaded_best_models

    with torch.no_grad():
        # Move the input to the GPU once rather than once per model.
        batch = img.view(1, *img.shape).cuda()
        preds = np.array([to_numpy(model(batch).squeeze()) for model in models])
    pred = np.mean(preds, axis=0)

    if merge_by_mean:
        if adaptive_thr:
            thr_mask = process_adaptive_thr(pred)
        else:
            thr_mask = pred > thr
    else:
        if adaptive_thr:
            thr_masks = [process_adaptive_thr(p) for p in preds]
        else:
            thr_masks = (preds > thr).astype(np.uint8)
        # The median of 0/1 votes is 0.5 on a tie, which the uint8 cast below turns into 0.
        thr_mask = np.median(thr_masks, axis=0)
    return pred, thr_mask.astype(np.uint8)

In [14]:
@interact(
    N_model=(0, len(loaded_best_models), 1),
    N=widgets.IntSlider(min=0 + OFFSET, max=32 + OFFSET, continuous_update=False),
    thr=widgets.FloatSlider(0.5, min=0.2, max=0.8, step=0.1, continuous_update=False),
)
def foo(
    dataset=DATASETS.keys(),
    N_model=0,
    N=0, 
    offset=(0, 4000, 100),
    thr=0.5,
    adaptive_thr=True,
    overlay_pred=False, 
    overlay_true=False,
    merge_by_mean=True,
):
    """Interactive explorer for the ensemble (``N_model == 0``) or one member.

    Args:
        dataset: key of ``DATASETS`` to sample from.
        N_model: 0 for the full ensemble (``process_single_img``), otherwise
            the 1-based index into ``loaded_best_models``.
        N, offset: the sample index is ``N + offset``.
        thr: fixed binarisation threshold (overridden when ``adaptive_thr``).
        adaptive_thr: binarise with Otsu, for the ensemble AND for single
            models. Single models used to ignore this flag.
        overlay_pred: paint predicted pixels red on the overlay image.
        overlay_true: paint ground-truth pixels blue on the overlay image.
            Pixels in both masks become magenta.
        merge_by_mean: ensemble merge mode, forwarded to ``process_single_img``.

    Test images (no ground truth) show image / overlay / probabilities.
    Validation images show overlay / mask / probabilities / binary mask, with
    the Jaccard score in the title.
    """
    global SINGLE_PRED
    N += offset
    is_test = "test" in dataset
    if is_test:
        _, img, _ = DATASETS[dataset][N]
        mask = torch.zeros_like(img)
    else:
        img, mask = DATASETS[dataset][N]

    if N_model == 0:
        pred, thr_mask = process_single_img(img, merge_by_mean, adaptive_thr, thr)
    else:
        with torch.no_grad():
            pred = to_numpy(loaded_best_models[N_model - 1](img.view(1, *img.shape).cuda()).squeeze())
        if adaptive_thr:
            thr_mask = process_adaptive_thr(pred).astype(np.uint8)
        else:
            thr_mask = (pred > thr).astype(np.uint8)

    img = to_numpy(img).swapaxes(0, 2)
    img = np.clip((img * STD + MEAN), 0, 1)
    mask = to_numpy(mask).swapaxes(0, 2)
    # Rescale the first two mask channels for display (presumably stored in [-1, 1] - TODO confirm).
    mask[:, :, :2] = (mask[:, :, :2] + 1) * 0.5

    # img/mask were swapaxes'd (H and W exchanged), so 2-D predictions are transposed to match.
    img_over = img.copy()
    if overlay_pred:
        img_over[thr_mask.T == 1] = [1, 0, 0]
    if overlay_true:
        img_over[mask[:, :, 2] == 1] = [0, 0, 1]
        if overlay_pred:
            img_over[(mask[:, :, 2] == 1) & (thr_mask.T == 1)] = [1, 0, 1]

    pred = np.stack([pred.T] * 3, 2)
    thr_mask_rgb = np.stack([thr_mask.T] * 3, 2)
    if is_test:
        stacked = np.hstack([img, img_over, pred])
        plt.figure(figsize=(24,8))
    else:
        score = jaccard_score(mask[:, :, 2], thr_mask.T, average="micro")
        if mask[:, :, 2].sum() == 0 and thr_mask.sum() == 0:
            # Both masks empty: the score is ill-defined, but the prediction is perfect.
            score = 1
        stacked = np.hstack([img_over, mask, pred, thr_mask_rgb])
        plt.figure(figsize=(32,8))
    plt.imshow(stacked, cmap="gray")
    if not is_test:
        plt.title(f"Jaccard={score:.3f}. Max value: {pred.max():.2f}", fontdict={"fontsize": 25})
    plt.axis("off")
    SINGLE_PRED = pred

interactive(children=(Dropdown(description='dataset', options=('val_tier1', 'val_tier1_bo', 'val_tier2', 'test…

# Run inference

In [15]:
!rm data/preds/*

In [16]:
# Ids of test images that the inference loop writes as all-zero masks without running the models.
empty_idx = {test_id for test_id in np.load("empty_test_idx.npy")}

In [17]:
PREDS_PATH = "data/preds"
PRED_SIZE = 1024  # side length of every mask written to PREDS_PATH
COLLECT_MAX_PREDS = False  # set True to record (max, 95th percentile, id) per image in `max_preds`
max_preds = []
os.makedirs(PREDS_PATH, exist_ok=True)
for _img, aug_imgs, idx in tqdm(test_dataset512):
    if idx in empty_idx:
        # Listed in empty_test_idx.npy: skip the models and write an all-zero mask.
        resized_thr_mask = np.zeros((PRED_SIZE, PRED_SIZE), dtype=np.uint8)
    else:
        pred, thr_mask = process_single_img(aug_imgs, merge_by_mean=False, adaptive_thr=True, thr=0.5)
        if COLLECT_MAX_PREDS:
            max_preds.append((pred.max(), np.percentile(pred, 95), idx))
        resized_thr_mask = cv2.resize(thr_mask, (PRED_SIZE, PRED_SIZE), interpolation=cv2.INTER_NEAREST)
    out_path = os.path.join(PREDS_PATH, idx + ".tif")
    # cv2.imwrite reports failure through its return value instead of raising,
    # so an unwritable path would otherwise fail silently for every image.
    if not cv2.imwrite(out_path, resized_thr_mask):
        raise IOError(f"cv2.imwrite failed for {out_path}")

100%|██████████| 11481/11481 [1:17:01<00:00,  2.48it/s]


In [33]:
# `max_preds` is a list of (max prob, 95th-percentile prob, image id) tuples appended
# during inference. Slicing the list as [:, :2] raises TypeError, so convert it to an
# (n, 3) object array first. The reshape also keeps an empty list usable.
max_preds_arr = np.asarray(max_preds, dtype=object).reshape(-1, 3)
max_preds_values = max_preds_arr[:, :2].astype(float)
max_preds_names = max_preds_arr[:, 2]

In [74]:
def show_by_name(name, test_path="/home/zakirov/datasets/opencities/test"):
    """Display the raw test image ``<test_path>/<name>/<name>.tif`` in RGB.

    Args:
        name: test image id; used both as the directory and the file stem.
        test_path: root directory of the test set. The default is the
            original author's machine-specific location; pass your own.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    img_path = f"{test_path}/{name}/{name}.tif"
    img = cv2.imread(img_path)
    # cv2.imread returns None instead of raising on a missing or unreadable file.
    if img is None:
        raise FileNotFoundError(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(6,6))
    plt.imshow(img)