In [None]:
# Re-import edited project modules automatically and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import zeus.notebook_utils.syspath as syspath
# Presumably puts the notebook's parent folder on sys.path so the project
# packages imported below (e.g. `kidney`) resolve — verify against zeus.
syspath.add_parent_folder()

In [None]:
import os
from typing import Type

import cv2 as cv
import rasterio
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import PIL.Image
import torch
import pytorch_lightning as pl
from IPython.display import display
from rasterio.windows import Window

from zeus.utils import home
from zeus.plotting.utils import axes, calculate_layout
from kidney.experiments.aug import get_dataset_input_size
from kidney.datasets.kaggle import outlier, get_reader, DatasetReader, SampleType
from kidney.experiments import FCNExperiment
from kidney.inference.inference import SlidingWindow, SlidingWindowConfig
from kidney.inference.window import sliding_window_boxes
from kidney.utils.tiff import read_tiff
from kidney.utils.mask import rle_decode, rle_numba_encode
from kidney.utils.plotting import preview_arrays

In [None]:
# Directory that the submission CSV (CHECKPOINT_PATH) is written into.
SUBMIT_DIR = "/mnt/fast/kaggle/submits/kidney"

In [None]:
# Which trained run to restore. get_checkpoint_paths resolves these to
# experiments/<EXPERIMENT>/checkpoints/<TIMESTAMP>/{info.pth, <WEIGHTS>}.
EXPERIMENT = "aug"
TIMESTAMP = "Wed_06_Jan__18_19_20"
WEIGHTS = "epoch=6_avg_val_loss=0.1043.ckpt"
# NOTE(review): despite its name this is the path of the *submission CSV*
# written by the save cell, not a model checkpoint. The filename tags are
# maintained by hand: "e6"/"avl1043" mirror WEIGHTS and "o32" presumably
# mirrors the overlap passed to get_inference — keep them in sync.
CHECKPOINT_PATH = os.path.join(SUBMIT_DIR, "fcn_resnet50_e6_avl1043_w1024_o32_sz256_expo.csv")

In [None]:
# GPU index to run inference on when the machine has that many GPUs.
CUDA_DEVICE_INDEX = 1
# Fall back to the first GPU, then to the CPU, rather than failing later when
# the model is moved to a CUDA device that does not exist on this machine.
if torch.cuda.device_count() > CUDA_DEVICE_INDEX:
    DEVICE = torch.device(f"cuda:{CUDA_DEVICE_INDEX}")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
# Forwarded to SlidingWindow(debug=...) when building the predictor.
DEBUG = False

In [None]:
def get_checkpoint_paths(
    experiment: str,
    timestamp: str,
    weights: str
):
    """Locate the metadata and weights files of one training run.

    The run directory ``experiments/<experiment>/checkpoints/<timestamp>`` is
    passed through ``home`` and holds an ``info.pth`` metadata file plus the
    Lightning weights file named by ``weights``.

    Returns:
        ``(info_filename, weights_filename)`` — both paths, in that order.
    """
    run_dir = home(f"experiments/{experiment}/checkpoints/{timestamp}")
    return tuple(
        os.path.join(run_dir, file_name)
        for file_name in ("info.pth", weights)
    )

In [None]:
def get_inference(
    info_filename: str,
    weights_filename: str,
    factory: Type[pl.LightningModule],
    overlap: int,
    max_batch_size: int,
    check_for_outliers: bool,
    device: torch.device = DEVICE,
    debug: bool = False
):
    """Restore a trained model and wrap it in a sliding-window predictor.

    Args:
        info_filename: Path of the run's ``info.pth``. It must hold ``"params"``
            (experiment hyper-parameters, including ``"dataset"``) and
            ``"transformers"`` (an object exposing ``test_preprocessing`` and
            ``test_postprocessing``).
        weights_filename: Path of the Lightning checkpoint to restore.
        factory: The LightningModule *class* whose ``load_from_checkpoint``
            rebuilds the model.
        overlap: Window overlap, forwarded to ``SlidingWindowConfig``.
        max_batch_size: Maximum windows per batch, forwarded to ``SlidingWindowConfig``.
        check_for_outliers: Forwarded to ``SlidingWindowConfig``.
        device: Device the checkpoint tensors are mapped to and the model runs
            on. The default is the value of ``DEVICE`` when this cell was run.
        debug: Forwarded to ``SlidingWindow``.

    Returns:
        A ``SlidingWindow`` wrapping the restored model in eval mode on ``device``.
    """
    # NOTE(security): torch.load unpickles arbitrary objects (the transformers
    # live in this file) — only load checkpoints you produced yourself.
    # map_location keeps tensors off the GPU they were saved from, which may be
    # busy or absent on this machine.
    meta = torch.load(info_filename, map_location=device)
    params = meta["params"]
    # Presumably stops the model constructor from fetching pretrained weights;
    # the checkpoint supplies all weights anyway.
    params["fcn_pretrained"] = False
    params["fcn_pretrained_backbone"] = False
    experiment = factory.load_from_checkpoint(
        weights_filename,
        map_location=device,
        params=params,
        strict=False
    )
    transformers = meta["transformers"]
    inference = SlidingWindow(
        model=experiment.eval().to(device),
        config=SlidingWindowConfig(
            window_size=get_dataset_input_size(params["dataset"]),
            overlap=overlap,
            max_batch_size=max_batch_size,
            check_for_outliers=check_for_outliers,
            transform_input=transformers.test_preprocessing,
            transform_output=transformers.test_postprocessing
        ),
        device=device,
        debug=debug
    )
    return inference

In [None]:
# Resolve the configured run's files, then wrap the restored model in a
# sliding-window predictor (overlap / batch size go to SlidingWindowConfig).
info_path, weights_path = get_checkpoint_paths(
    experiment=EXPERIMENT,
    timestamp=TIMESTAMP,
    weights=WEIGHTS
)
inference = get_inference(
    info_path,
    weights_path,
    factory=FCNExperiment,
    check_for_outliers=True,
    overlap=32,
    max_batch_size=50,
    debug=DEBUG
)

In [None]:
# Dataset reader: used to drive batch prediction and, in the preview cells,
# to look up per-sample metadata via `fetch_meta`.
reader = get_reader()

In [None]:
# Which samples to predict. The submission CSV is written only when this is
# SampleType.Unlabeled (see the save cell).
sample_type = SampleType.Unlabeled

In [None]:
# Presumably predicts every `sample_type` sample that `reader` provides, encoding
# each predicted mask with `rle_numba_encode` — verify in SlidingWindow.
predictions = inference.predict_from_reader(reader, sample_type, encoder=rle_numba_encode)

In [None]:
# One row per prediction; the columns are named in the next cell
# (presumably each prediction is an (id, RLE mask) pair — verify).
predictions_df = pd.DataFrame(predictions)

In [None]:
# Column names used by the submission CSV. pandas raises ValueError here if
# the frame does not have exactly two columns.
predictions_df.columns = ["id", "predicted"]

In [None]:
# Write the submission only for unlabeled samples.
if sample_type == SampleType.Unlabeled:
    # Guard against out-of-order execution: the CSV must carry the column
    # names assigned in the renaming cell.
    assert list(predictions_df.columns) == ["id", "predicted"], list(predictions_df.columns)
    # to_csv does not create missing directories.
    os.makedirs(os.path.dirname(CHECKPOINT_PATH) or ".", exist_ok=True)
    predictions_df.to_csv(CHECKPOINT_PATH, index=False)
    print("saved:", CHECKPOINT_PATH)

In [None]:
# Show the predictions table; pandas truncates the long RLE strings in its display.
predictions_df

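## Small-Scale Predictions Preview

A downscaled (1024×1024) overview of every row in `predictions_df`, with the ground-truth mask overlaid whenever the reader has one for that sample.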

In [None]:
from zeus.plotting.utils import calculate_layout

In [None]:
# Overview grid: every sample in `predictions_df`, resized to sz x sz (aspect
# ratio is not preserved), with ground truth overlaid when the reader has a mask.
sz = 1024
total = predictions_df.shape[0]
n, m = calculate_layout(total, n_cols=3)
grid = axes(subplots=(n, m), figsize=(30, 40))

# Turn every axis off up front so unused grid cells stay blank.
for ax in grid.flat:
    ax.axis("off")

for i in range(total):
    record = predictions_df.iloc[i]
    meta = reader.fetch_meta(record.id)
    tiff = read_tiff(meta["tiff"])
    y_pred = rle_decode(record.predicted, tiff.shape)
    encoded_true = meta.get("mask")
    y_true = None if encoded_true is None else rle_decode(encoded_true, tiff.shape)
    # Shrinking: INTER_AREA averages source pixels instead of sampling a few of
    # them, which avoids aliasing in the image thumbnail. Nearest-neighbour keeps
    # mask thumbnails limited to label values present in the full-size masks.
    tiff = cv.resize(tiff, (sz, sz), interpolation=cv.INTER_AREA)
    y_pred = cv.resize(y_pred, (sz, sz), interpolation=cv.INTER_NEAREST)
    if y_true is not None:
        y_true = cv.resize(y_true, (sz, sz), interpolation=cv.INTER_NEAREST)
    ax = grid.flat[i]
    preview_arrays(tiff, gt=y_true, pred=y_pred, ax=ax)
    ax.set_title(record.id, fontsize=20)

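## Large-Scale Predictions Preview

Runs full-resolution inference on each key in `train_keys` and compares the predicted mask with the ground truth at `THUMB_SIZE` resolution. Note that `train_keys` is not assigned in any earlier cell, so it must be defined before running this section.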

In [None]:
def predict(key: str):
    """Run full-resolution inference on one sample.

    Args:
        key: Sample identifier accepted by ``reader.fetch_meta``.

    Returns:
        ``(full_size_image, mask_true, mask_pred)``:
        - ``full_size_image`` is the array returned by ``read_tiff``.
        - ``mask_true`` is the decoded ground-truth mask, or ``None`` when the
          sample's metadata has no ``"mask"`` (e.g. unlabeled samples).
        - ``mask_pred`` is whatever ``inference.predict_from_file`` returns.
    """
    # `reader` and `inference` are module-level objects from earlier cells; they
    # are only read here, so no `global` declaration is needed.
    meta = reader.fetch_meta(key)
    full_size_image = read_tiff(meta["tiff"])
    encoded_mask = meta.get("mask")
    mask_true = (
        None
        if encoded_mask is None
        else rle_decode(encoded_mask, full_size_image.shape)
    )
    mask_pred = inference.predict_from_file(meta["tiff"])
    return full_size_image, mask_true, mask_pred

In [None]:
# Side length, in pixels, of the square thumbnails `preview` resizes image and masks to.
THUMB_SIZE = 4096

In [None]:
def preview(
    image: np.ndarray,
    mask_true: np.ndarray,
    mask_pred: np.ndarray,
    thumb_size: int = THUMB_SIZE,
    ax=None,
    **fig_params
):
    """Plot an image with its ground-truth and predicted masks as square thumbnails.

    Args:
        image: Full-size image.
        mask_true: Full-size ground-truth mask, or ``None`` when unavailable.
        mask_pred: Full-size predicted mask, or ``None``.
        thumb_size: Side length (pixels) every array is resized to; the
            aspect ratio is not preserved.
        ax: Existing axes to draw on, forwarded to ``axes``.
        **fig_params: Extra figure options forwarded to ``axes`` (e.g. ``figsize``).

    Returns:
        The axes the preview was drawn on.
    """
    ax = axes(ax=ax, **fig_params)
    size = (thumb_size, thumb_size)
    # INTER_AREA averages source pixels when shrinking, avoiding aliasing.
    thumb_img = cv.resize(image, size, interpolation=cv.INTER_AREA)
    # Nearest-neighbour only emits label values present in the full-size mask.
    thumb_gt, thumb_pred = [
        None if mask is None else cv.resize(mask, size, interpolation=cv.INTER_NEAREST)
        for mask in (mask_true, mask_pred)
    ]
    preview_arrays(thumb_img, thumb_gt, thumb_pred, ax=ax)
    return ax

In [None]:
# Directory the full-resolution preview images are written to.
PREVIEW_DIR = "/mnt/fast/data"

# FIXME: `train_keys` is not assigned anywhere above, so a fresh kernel fails
# here. Define it as the list of labeled sample keys to preview.
if "train_keys" not in globals():
    raise NameError(
        "`train_keys` is not defined: assign the list of labeled sample keys "
        "to preview before running this cell"
    )

os.makedirs(PREVIEW_DIR, exist_ok=True)
for key in train_keys:
    img, y_true, y_pred = predict(key)
    ax = preview(img, y_true, y_pred, figsize=(30, 30))
    ax.set_title(key, fontsize=20)
    display(ax.figure)
    ax.figure.savefig(os.path.join(PREVIEW_DIR, f"{key}.png"), format="png")
    # Close once saved: otherwise every 30x30-inch figure stays in memory and
    # the inline backend renders it a second time when the cell finishes.
    plt.close(ax.figure)