In [None]:
# Environment setup (IPython magics, so they act on the running kernel).
# Switch the working directory to the ttAugment checkout.
# NOTE(review): assumes the repository already exists at /content/ttAugment/ — confirm.
%cd /content/ttAugment/
# Install the package found in the current directory in editable mode.
%pip install -e . 
# Install the segmentation package, bypassing pip's download cache.
%pip install --no-cache-dir building-footprint-segmentation

In [None]:
import cv2
import torch
import numpy as np

from building_footprint_segmentation.seg.binary.models import DLinkNet34
from building_footprint_segmentation.helpers.normalizer import min_max_image_net
from building_footprint_segmentation.utils.py_network import (
    to_input_image_tensor,
    add_extra_dimension,
    convert_tensor_to_numpy,
    load_parallel_model,
    adjust_model
)
from torch.utils import model_zoo

from tt_augment.augment import generate_seg_augmenters

%matplotlib inline 
from matplotlib import pyplot as plt

from pathlib import Path

In [None]:
def get_model():
    """Build a DLinkNet34 and load the released pretrained weights for inference.

    The checkpoint is fetched with ``model_zoo.load_url`` (progress bar shown)
    and mapped onto the CPU. Its ``"model"`` entry is passed through
    ``adjust_model`` before being loaded into the network.

    Returns:
        The ``DLinkNet34`` instance with the pretrained weights loaded and
        switched to evaluation mode.
    """
    model = DLinkNet34()
    checkpoint = model_zoo.load_url(
        r"https://github.com/fuzailpalnak/building-footprint-segmentation/releases/download/v0.2.3/DlinkNet.zip",
        progress=True,
        map_location="cpu",
    )
    # presumably adjust_model renames the checkpoint keys so they match a
    # non-parallel model — TODO confirm against the library.
    state_dict = adjust_model(checkpoint["model"])
    model.load_state_dict(state_dict)
    # A freshly constructed module is in training mode. Switch to evaluation
    # mode so layers that behave differently while training (batch
    # normalisation, dropout) use their inference behaviour for predictions.
    model.eval()
    return model

In [None]:
# Test-time augmentations handed to generate_seg_augmenters through its
# `transformation_to_apply` argument; every entry uses the same settings.
# NOTE(review): crop_to_dimension is presumably a (height, width) in pixels — confirm.
TRANSFORMATION_TO_APPLY = [
    dict(name=augmenter_name, crop_to_dimension=(128, 128))
    for augmenter_name in ("Mirror", "CropScale")
]

In [None]:
# Run windowed test-time-augmented inference on every file in the test folder
# and show the thresholded prediction overlaid on the source image.
model = get_model()

# sorted() makes the processing order reproducible and materialises the glob
# generator, so an empty or missing folder is reported instead of the loop
# silently doing nothing.
img_pths = sorted(Path("/content/test_imgs").glob("*"))
if not img_pths:
    print("no files found in /content/test_imgs")

for img_pth in img_pths:
    # load og img
    print(f"running img {img_pth.name}")
    original_image = cv2.imread(str(img_pth))
    if original_image is None:
        # cv2.imread returns None rather than raising when it cannot decode
        # the path (sub-directories, hidden checkpoint folders, non-image
        # files all match "*"); skip those instead of crashing in cvtColor.
        print(f"skipping {img_pth.name}: not a readable image")
        continue
    original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
    print("image loaded")

    # prep for modelling: normalise, then add a leading axis of size 1
    image = min_max_image_net(img=original_image)
    image = np.expand_dims(image, 0)

    h, w, _ = original_image.shape
    number_of_output_channels = 1

    # init test-time transforms
    tta_alt = generate_seg_augmenters(
        image=image,
        window_size=(256, 256),
        output_dimension=(1, h, w, number_of_output_channels),
        transformation_to_apply=TRANSFORMATION_TO_APPLY,
    )

    print("applying transforms to each fragment and running model...")
    # apply transforms to each fragment of img
    for transformation in tta_alt:
        for augmented_fragment in transformation.transform_fragment():
            # NOTE(review): the swap presumably puts the batch axis first and
            # channels second, as the model expects — confirm against
            # to_input_image_tensor.
            tensor_image = to_input_image_tensor(augmented_fragment).swapaxes(0, 1)

            # no gradients are needed for prediction
            with torch.no_grad():
                # Perform prediction
                prediction = model(tensor_image)
                prediction = prediction.sigmoid()

                prediction_binary = convert_tensor_to_numpy(prediction.swapaxes(-1, 1))
                # need to correct for original swap - old example ntb failed to do
                # so and was clearly incorrect
                transformation.restore_fragment(prediction_binary.swapaxes(1, 2))
    # collect results (mean by default)
    tta_alt.merge()
    output = tta_alt.tta_output()

    # get binary mask from preds: keep values more than two standard
    # deviations above the median of this image's prediction map
    prediction_map = output[0]
    threshold = np.median(prediction_map) + 2 * np.std(prediction_map)
    msk = prediction_map > threshold
    # mask out everything below the threshold so only foreground is drawn
    masked_array = np.ma.array(msk, mask=~msk)

    # plot og img with the mask overlaid; original_image is not reassigned
    # after loading, so it is reused here instead of being read from disk again
    # (assumes min_max_image_net does not modify its argument in place — confirm)
    plt.imshow(original_image, interpolation='nearest')
    plt.imshow(masked_array, alpha=0.5, cmap='spring', interpolation='nearest')
    plt.show()