# Imports

In [1]:
# %load_ext autoreload
# %autoreload 2
# Stretch the notebook container to the full browser width so wide figures
# (such as the 4-panel comparison strip drawn by the viewer) are not squeezed.
# `display`/`HTML` come from the public IPython.display module; importing
# `display` from IPython.core.display is deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>.container {width:100% !important;}</style>"))

In [2]:
import os
import cv2
import numpy as np
from tqdm import tqdm
from ipywidgets import interact
import ipywidgets as widgets
import matplotlib.pyplot as plt
import albumentations as albu
import albumentations.pytorch as albu_pt
%matplotlib inline

import apex
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import pytorch_tools as pt

from src.dataset import OpenCitiesDataset, OpenCitiesTestDataset, InriaTilesDataset
from src.augmentations import get_aug
from src.callbacks import ThrJaccardScore
from pytorch_tools.fit_wrapper.callbacks import SegmCutmix

In [3]:
import yaml
from src.utils import MODEL_FROM_NAME
from src.utils import TargetWrapper
from pytorch_tools.fit_wrapper.callbacks import Callback
from pytorch_tools.utils.misc import to_numpy
from src.utils import criterion_from_list
from src.utils import ToCudaLoader

# Get dataloaders

In [4]:
# --- configuration ---------------------------------------------------------
SZ = 384  # size argument handed to every get_aug pipeline
BS = 32  # batch size shared by every loader in this notebook
BUILDINGS_ONLY = False
RETURN_DISTANCE = False  # forwarded as return_distance= to OpenCitiesDataset (alternative: True)

# --- augmentation pipelines, one per split (built in the order medium, val, test)
aug, val_aug, test_aug = [get_aug(mode, SZ) for mode in ("medium", "val", "test")]

# NOTE: calling iter() on a DataLoader with num_workers > 0 starts its worker
# processes immediately, so each *_i iterator below keeps workers alive.

# --- OpenCities validation split --------------------------------------------
val_dtst = OpenCitiesDataset(
    split="val",
    transform=val_aug,
    buildings_only=BUILDINGS_ONLY,
    return_distance=RETURN_DISTANCE,
)
val_dtld = DataLoader(
    val_dtst,
    batch_size=BS,
    shuffle=False,
    num_workers=4,
    drop_last=True,
)
val_dtld_i = iter(val_dtld)

# --- OpenCities training split ----------------------------------------------
train_dtst = OpenCitiesDataset(
    split="train",
    transform=aug,
    buildings_only=BUILDINGS_ONLY,
    return_distance=RETURN_DISTANCE,
)
train_dtld = DataLoader(
    train_dtst,
    batch_size=BS,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)
train_dtld_i = iter(train_dtld)

# --- OpenCities test set (keeps the final partial batch) --------------------
test_dtst = OpenCitiesTestDataset(transform=test_aug)
test_dtld = DataLoader(
    test_dtst,
    batch_size=BS,
    shuffle=False,
    num_workers=8,
    drop_last=False,
)
test_dtld_i = iter(test_dtld)

In [5]:
# Inria validation tiles concatenated with the OpenCities validation set via
# Dataset "+" (presumably the torch Dataset.__add__ -> ConcatDataset behaviour).
val_dtst_inria = InriaTilesDataset(split="val", transform=val_aug) + val_dtst
val_dtld_inria = DataLoader(
    val_dtst_inria,
    batch_size=BS,
    shuffle=True,  # shuffling draws each batch from both concatenated sources
    num_workers=8,
    drop_last=True,
)

# Inria training tiles on their own.
train_dtst_inria = InriaTilesDataset(split="train", transform=aug)
train_dtld_inria = DataLoader(
    train_dtst_inria,
    batch_size=BS,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)

In [6]:
# Wrap every CPU loader with src.utils.ToCudaLoader (presumably it moves each
# batch onto the GPU -- see its definition). Wrapped in the same order as the
# names on the left-hand side.
val_dtld_gpu, train_dtld_gpu, val_dtld_inria_gpu, train_dtld_inria_gpu = map(
    ToCudaLoader,
    (val_dtld, train_dtld, val_dtld_inria, train_dtld_inria),
)

# Results exploration

In [9]:
# Per-channel mean/std used to undo the input normalisation before plotting.
# These are the usual ImageNet statistics; presumably they match the Normalize
# step inside src.augmentations.get_aug -- verify if the pipeline changes.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

# Cache used by the interactive viewer: the weights folder whose predictions
# are currently held in PREDS (None until a model has been run).
PREV_WEIGHTS = None
PREDS = None

# Loader that supplies the single batch under inspection
# (switch to val_dtld_inria_gpu to inspect the Inria + OpenCities mix).
LOADER = val_dtld_gpu
IMGS, MASKS = next(iter(LOADER))
# To inspect on the CPU instead: IMGS, MASKS = IMGS.cpu(), MASKS.cpu()

In [10]:
def _load_predictions(weights):
    """Rebuild the model saved under ``logs/<weights>/`` and return its sigmoid predictions for IMGS.

    Reads ``config.yaml`` (keys ``segm_arch``, ``arch`` and optional
    ``model_params``) to construct the model, loads the ``state_dict`` entry of
    ``model.chpn`` and runs one no-grad forward pass on the cached batch IMGS.

    Returns:
        The sigmoid of the model output, as a CPU tensor.
    """
    log_path = os.path.join("logs", weights)
    with open(os.path.join(log_path, "config.yaml")) as config_file:
        # An explicit Loader is required: bare yaml.load() warns on PyYAML 5.1+
        # and raises TypeError on PyYAML 6. FullLoader is what bare yaml.load()
        # fell back to on 5.1+, so existing configs parse exactly as before.
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    model = MODEL_FROM_NAME[config["segm_arch"]](
        config["arch"], encoder_weights=None, **config.get("model_params", {})
    ).cuda()
    # torch.load unpickles the file: only point it at trusted, locally produced checkpoints.
    checkpoint = torch.load(os.path.join(log_path, "model.chpn"))
    # strict=False (kept from before) silently ignores missing/unexpected keys.
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    # Inference mode: BatchNorm uses its running statistics and dropout is off.
    # Left in train mode, predictions depend on the other images in the batch.
    model.eval()
    with torch.no_grad():
        return model(IMGS).cpu().sigmoid()


def _chw_to_hwc(tensor):
    """Return a channel-first tensor (assumed (C, H, W)) as an (H, W, C) numpy array for imshow."""
    return to_numpy(tensor).transpose(1, 2, 0)


@interact(
    weights=sorted(os.listdir("logs/")),
    # Valid sample indices are 0 .. len(IMGS) - 1; the previous max=BS went one past the end.
    N=widgets.IntSlider(min=0, max=len(IMGS) - 1, continuous_update=True),
    thr=widgets.FloatSlider(0.5, min=0.2, max=0.8, step=0.1, continuous_update=False),
)
def foo(weights=None, N=0, thr=0.5):
    """Interactively compare one sample of the cached batch with a model's prediction.

    Draws, left to right: the de-normalised image, the target mask, the sigmoid
    prediction, and the prediction binarised at ``thr``.

    Args:
        weights: name of a run folder under ``logs/``. Its predictions on IMGS
            are computed once and cached in the globals PREDS / PREV_WEIGHTS.
        N: index of the sample within IMGS / MASKS / PREDS.
        thr: threshold applied to the sigmoid prediction.
    """
    global PREV_WEIGHTS
    global PREDS

    if weights is None:
        print("select weights")
        return

    if weights != PREV_WEIGHTS:
        PREDS = _load_predictions(weights)
        # Record the cache key only after loading succeeded, so a failed load is
        # retried next time instead of showing the previous run's predictions.
        PREV_WEIGHTS = weights

    # transpose(1, 2, 0) yields (H, W, C); the earlier swapaxes(0, 2) produced
    # (W, H, C), so every panel used to be displayed transposed.
    img = np.clip(_chw_to_hwc(IMGS[N]) * STD + MEAN, 0, 1)

    # Copy before the in-place rescale: to_numpy may share memory with a CPU
    # tensor, and editing that buffer would corrupt MASKS on every redraw.
    mask = _chw_to_hwc(MASKS[N]).copy()
    # (x + 1) * 0.5 maps [-1, 1] onto [0, 1]; presumably the first two mask
    # channels are signed targets -- TODO confirm against OpenCitiesDataset.
    mask[:, :, :2] = (mask[:, :, :2] + 1) * 0.5

    # Assumes a single-channel prediction: repeat it to 3 channels so it can be
    # stacked next to the RGB panels.
    pred = np.repeat(_chw_to_hwc(PREDS[N]), 3, axis=2)
    thr_mask = pred > thr

    stacked = np.hstack([img, mask, pred, thr_mask])
    plt.figure(figsize=(32, 8))
    plt.imshow(stacked, cmap="gray")
    plt.axis("off")
    # Render inside the interact output area (recommended by the ipywidgets docs).
    plt.show()

interactive(children=(Dropdown(description='weights', options=('2.deeplab_nov_20200212_171538', '2.deeplab_nov…