In [None]:
import json
import os
from argparse import ArgumentParser, Namespace
from functools import reduce
from glob import glob

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.utils import draw_segmentation_masks

from sidewalk_widths_extractor import Trainer, seed_all
from sidewalk_widths_extractor.dataset import SatelliteDataset
from sidewalk_widths_extractor.modules.seg import SegModule
from sidewalk_widths_extractor.utilities import get_device
from sidewalk_widths_extractor.utilities.io import mkdir

In [None]:
# Input folders: evaluated via SatelliteDataset and re-read for the prediction export.
IMAGE_FOLDER_PATH = "data//images"
MASK_FOLDER_PATH = "data//masks"
# Destination for per-run settings copies, prediction exports and results.csv.
TARGET_LOG_FOLDER = "logs//test"

# When True, also write overlay / predicted-mask / ground-truth-mask PNGs for every model.
EXTRACT_PREDICTIONS = True

# DataLoader settings used for evaluation.
batch_size = 8
shuffle = False
# Evaluation must see every sample: dropping the last partial batch would silently
# exclude up to batch_size - 1 images from the tp/fp/fn/tn totals.
drop_last_batch = False
pin_memory = True
num_workers = 0
persistent_workers = False

device = get_device()
print("using", device)


In [None]:
# One entry per hyper-parameter-tuning run: the best checkpoint and the settings it was
# trained with. Both paths live under the same run folder, so derive them from it.
target_models = [
    {
        "model": f"{run_dir}//checkpoints//best_network.pth.tar",
        "settings": f"{run_dir}//settings.json",
    }
    for run_dir in (
        "logs//hpt//18-09-2022 11-49-51 hpt 0",
        "logs//hpt//18-09-2022 11-51-13 hpt 1",
        "logs//hpt//18-09-2022 11-52-26 hpt 2",
        "logs//hpt//18-09-2022 11-53-42 hpt 3",
        "logs//hpt//18-09-2022 11-55-04 hpt 4",
    )
]

In [None]:
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``torch.tensor(0.0)`` when ``denominator`` is zero.

    Every metric uses this zero-denominator convention (previously only precision,
    recall and F1 did), so a degenerate confusion matrix - e.g. a run where neither the
    prediction nor the ground truth contains a positive pixel - yields 0.0 instead of
    writing NaN into results.csv.
    """
    if denominator == 0:
        return torch.tensor(0.0)
    return numerator / denominator


def _confusion_metrics(tp, fp, fn, tn):
    """Derive the reported segmentation metrics from summed confusion-matrix counts.

    Args:
        tp, fp, fn, tn: total true/false positives/negatives over the test set
            (presumably scalar tensors - the caller calls ``.item()`` on them).

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``, ``iou`` and
        ``dice`` (in that order), each mapped to a scalar tensor.
    """
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    return {
        "accuracy": _safe_ratio(tp + tn, tp + tn + fp + fn),
        "precision": precision,
        "recall": recall,
        "f1": _safe_ratio(2 * (precision * recall), precision + recall),
        "iou": _safe_ratio(tp, tp + fn + fp),
        "dice": _safe_ratio(2 * tp, 2 * tp + fn + fp),
    }


def _split_image_and_mask(full_image, full_mask, tile_source_size=512):
    """Cut a ``tile_source_size`` x ``tile_source_size`` image (and its mask) into four quadrants.

    The decision is made from the image's height/width only; the mask is sliced with the
    same bounds. Quadrants are ordered top-left, top-right, bottom-left, bottom-right.
    Images of any other size are returned whole, as single-element lists.

    Returns:
        ``(images, masks)`` - two lists of equal length.
    """
    h, w = full_image.shape[:2]
    if h != tile_source_size or w != tile_source_size:
        return [full_image], [full_mask]

    rows = (slice(None, h // 2), slice(h // 2, None))
    cols = (slice(None, w // 2), slice(w // 2, None))
    regions = [(r, c) for r in rows for c in cols]
    images = [full_image[r, c] for r, c in regions]
    masks = [full_mask[r, c] for r, c in regions]
    return images, masks


def _export_predictions(module, run_target_path, image_paths, mask_paths):
    """Save a visual check of ``module``'s predictions for every image/mask pair.

    For each image ``<stem>.<ext>`` a folder ``<run_target_path>/predictions/<stem>/`` is
    created holding, per tile ``k`` (see ``_split_image_and_mask``):
      * ``seg_{k}.png``        - the prediction drawn over the RGB tile (alpha 0.3, blue),
      * ``pred_mask{k}.png``   - the predicted mask scaled to 0/255,
      * ``truth_mask_{k}.png`` - the ground-truth mask as read (grayscale) from disk.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image or mask file.
    """
    to_tensor = T.ToTensor()
    to_pil = T.ToPILImage()

    target_pred_folder = os.path.join(run_target_path, "predictions")
    mkdir(target_pred_folder)

    for img_path, mask_path in zip(image_paths, mask_paths):
        target_img_path = os.path.join(
            target_pred_folder, os.path.splitext(os.path.basename(img_path))[0]
        )
        mkdir(target_img_path)

        # cv2.imread returns None instead of raising on unreadable files; fail with the path.
        full_image = cv2.imread(img_path)
        if full_image is None:
            raise FileNotFoundError(f"could not read image: {img_path}")
        full_image = cv2.cvtColor(full_image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

        full_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if full_mask is None:
            raise FileNotFoundError(f"could not read mask: {mask_path}")

        images, masks = _split_image_and_mask(full_image, full_mask)
        for tile_idx, (tile_image, tile_mask) in enumerate(zip(images, masks)):
            image = to_tensor(tile_image)
            truth = torch.from_numpy(tile_mask)
            # Inference only: no autograd graph is needed for the exported images.
            with torch.no_grad():
                pred = module.infer(image).detach().cpu()

            image = (image * 255).type(torch.uint8)
            # NOTE(review): assumes infer() returns a 0/1 mask; any non-zero value becomes True.
            pred = pred.type(torch.bool)

            seg = draw_segmentation_masks(image, pred, alpha=0.3, colors="blue")
            to_pil(seg).save(os.path.join(target_img_path, f"seg_{tile_idx}.png"))
            to_pil((pred * 255).type(torch.uint8)).save(
                os.path.join(target_img_path, f"pred_mask{tile_idx}.png")
            )
            to_pil(truth).save(os.path.join(target_img_path, f"truth_mask_{tile_idx}.png"))


if EXTRACT_PREDICTIONS:
    # Images and masks are paired purely by sorted filename order, so differing counts
    # guarantee mis-paired (and, via zip, silently truncated) exports. The listing does
    # not depend on the model, so it is done once instead of per model.
    image_paths = sorted(glob(os.path.join(IMAGE_FOLDER_PATH, "*")))
    mask_paths = sorted(glob(os.path.join(MASK_FOLDER_PATH, "*")))
    assert len(image_paths) == len(mask_paths), (
        f"{len(image_paths)} images in {IMAGE_FOLDER_PATH} but "
        f"{len(mask_paths)} masks in {MASK_FOLDER_PATH}"
    )

results = []

for i, m in enumerate(target_models):
    model_path = m["model"]
    settings_path = m["settings"]

    assert os.path.exists(model_path), f"missing checkpoint: {model_path}"
    assert os.path.exists(settings_path), f"missing settings: {settings_path}"

    with open(settings_path) as file:
        settings = json.load(file)

    module_settings = settings["module"]
    for part in ("network", "optimizer", "criterion"):
        assert isinstance(module_settings[part], dict), f"settings['module']['{part}'] must be a dict"

    run_id = settings["run"]["run_id"]
    no_train_samples = settings["run"]["no_train_samples"]
    no_parameters = module_settings["network"]["no_parameters"]

    # Re-seed before building data/model so every checkpoint is evaluated under the same RNG state.
    seed_all()

    dataset = SatelliteDataset(IMAGE_FOLDER_PATH, MASK_FOLDER_PATH)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last_batch,
        pin_memory=pin_memory,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
    )
    module = SegModule(
        module_settings["network"]["id"],
        module_settings["network"]["params"],
        module_settings["optimizer"]["id"],
        module_settings["optimizer"]["params"],
        module_settings["criterion"]["id"],
        module_settings["criterion"]["params"],
        device=device,
        save_network_checkpoint=False,
        save_optimizer_checkpoint=False,
    )

    module.load({"network": model_path})
    epoch = module.curr_epoch_idx

    trainer = Trainer(TARGET_LOG_FOLDER, progress_bar=False, transfer_results_to_cpu=True)
    test_output = trainer.test(dataloader, module)

    # Per-batch confusion counts -> totals over the whole test set.
    tp = sum(test_output["tp"])
    fp = sum(test_output["fp"])
    fn = sum(test_output["fn"])
    tn = sum(test_output["tn"])
    metrics = _confusion_metrics(tp, fp, fn, tn)

    # One CSV row per model; column order matches the previous results.csv layout.
    result = {
        "id": [run_id],
        "epoch": [epoch],
        "no_train_samples": [no_train_samples],
        "no_parameters": [no_parameters],
        "tp": [tp.item()],
        "fp": [fp.item()],
        "fn": [fn.item()],
        "tn": [tn.item()],
    }
    result.update({name: [value.item()] for name, value in metrics.items()})

    # Keep a copy of the settings next to this run's test outputs.
    run_target_path = os.path.join(TARGET_LOG_FOLDER, run_id)
    mkdir(run_target_path)
    with open(os.path.join(run_target_path, "settings.json"), "w") as file:
        json.dump(settings, file, indent=2, sort_keys=False)

    results.append(pd.DataFrame(result, index=[i]))

    if EXTRACT_PREDICTIONS:
        _export_predictions(module, run_target_path, image_paths, mask_paths)
# Stack the single-row frame of every evaluated model into one table and persist it;
# the index only mirrors the model's position, so it is not written.
out = pd.concat(objs=results)
out.to_csv(path_or_buf=os.path.join(TARGET_LOG_FOLDER, "results.csv"), index=False)