In [None]:
import os
import typing as T
from collections import OrderedDict
from pathlib import Path

import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import torch
from torchmetrics import (
    Accuracy,
    F1Score,
    JaccardIndex,
    Precision,
    Recall,
    MetricCollection,
)
from empatches import EMPatches

from mine_seg_sat.config import get_model_config
from mine_seg_sat.dataset import MineSATDataset
from mine_seg_sat.train_utils.utils import get_model, get_loss
from mine_seg_sat.train_utils.scale import get_band_norm_values_from_root


# Runtime configuration: data/model locations, model config, band normalisation
# values and the compute device.
# NOTE(review): these are absolute, machine-specific paths — consider moving them to
# an environment variable or a config file so the notebook runs elsewhere.
data_path = Path("/mnt/media/data/mine_data/2021/canada/data_dir/mine_dataset")
model_path = Path("very_real_path_to_model_weights")
# MODEL_CONFIG is set before get_model_config() is called; presumably that helper reads
# the config location from this variable — verify in mine_seg_sat.config.
os.environ["MODEL_CONFIG"] = (model_path / "config.yaml").as_posix()
config = get_model_config()
# band_meta[1] is passed as `max_values` to the dataset in a later cell.
band_meta = get_band_norm_values_from_root(data_path, min_max=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A torch.device never compares equal to a plain string, so check its `.type`
# (comparing `device == "cpu"` is always False and the warning would never print).
if device.type == "cpu":
    print("Warning: Using CPU")

In [None]:
# Build the test split with a tensor-conversion-only transform pipeline.
transforms = A.Compose(
    [ToTensorV2()],
    is_check_shapes=False,
)
dataset = MineSATDataset(
    split="test",
    data_path=data_path,
    transformations=transforms,
    # second entry of the band statistics computed in the config cell
    max_values=band_meta[1],
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=1,
)
# Peek at the first sample to sanity-check shapes.
image, label = dataset[0]
print(f"Image and label shapes: {image.shape}, {label.shape}, number of samples: {len(dataset)}")

In [None]:
patchify = EMPatches()
# Permute the sample to channel-last (assuming it is (C, H, W)) before cutting it into
# overlapping square patches.
img_patches, indices = patchify.extract_patches(
    image.permute(1, 2, 0).detach().cpu(),
    patchsize=512,
    overlap=0.8,
)
(img_patches[0].shape, img_patches[1].shape, len(img_patches))

In [None]:
loss = get_loss(config)
model = get_model(config, device=device)
# SECURITY NOTE(review): torch.load unpickles arbitrary Python objects — only load
# checkpoints from a trusted source (or use weights_only=True if the checkpoint's
# contents allow it).
weights = torch.load((model_path / "best_model.pth").as_posix(), map_location=device)
# Checkpoints saved from a DataParallel/DDP-wrapped model typically carry a leading
# "module." on every key. Strip only that leading prefix: str.replace would also
# remove "module." occurring in the middle of a parameter name and break loading.
dp_prefix = "module."
new_state_dict = {
    (key[len(dp_prefix):] if key.startswith(dp_prefix) else key): value
    for key, value in weights["model_state_dict"].items()
}
model.load_state_dict(new_state_dict)

In [None]:
import typing
from itertools import zip_longest
from collections import Counter
from numpy import prod
from functools import partial

from fvcore.nn import FlopCountAnalysis
from fvcore.nn.jit_handles import generic_activation_jit

def get_shape(val: object) -> typing.List[int]:
    """
    Return the shape of a TorchScript (jit) graph value, for FLOP counting.

    Complete tensors yield their sizes (a 0-d tensor is reported as ``[1]``);
    int/float scalars are reported as ``[1]`` so they broadcast like a single
    element.

    Raises:
        ValueError: if the value is neither a complete tensor nor an int/float
            scalar (e.g. a string argument or a tensor with unknown sizes).
    """
    if val.isCompleteTensor():
        sizes = val.type().sizes()
        # A 0-d tensor has no dimensions; count it as one element.
        return sizes if sizes else [1]
    kind = val.type().kind()
    if kind in ("IntType", "FloatType"):
        return [1]
    # Name the offending kind so an unsupported op argument is easy to diagnose.
    raise ValueError(f"Cannot infer a shape for jit value of kind {kind!r}")

def basic_binary_op_flop_jit(inputs, outputs, name):
    """
    fvcore op handle for element-wise binary ops: one FLOP per element of the
    broadcast result shape of all ``inputs``, reported under ``name``.
    """
    # Right-align the shapes (broadcasting compares trailing dims first) by
    # reversing them, then take the largest extent per aligned dimension.
    reversed_shapes = [list(reversed(get_shape(value))) for value in inputs]
    broadcast_dims = [max(dims) for dims in zip_longest(*reversed_shapes, fillvalue=1)]
    return Counter({name: prod(broadcast_dims)})


def pretty_flops(num_flops: int):
    """
    Format a FLOP count with the largest unit (G/M/K) it reaches, to two
    decimals; counts below one are reported as "0 FLOPs".
    """
    scales = (("GFLOPs", 1e9), ("MFLOPs", 1e6), ("KFLOPs", 1e3), ("FLOPs", 1))
    chosen = next(((label, size) for label, size in scales if num_flops >= size), None)
    if chosen is None:
        return "0 FLOPs"
    label, size = chosen
    return f"{num_flops / size:.2f} {label}"


# Parameter count and estimated FLOPs for one forward pass of the model.
# NOTE(review): `input_dim` and `batch_size` are not used by the visible code (the
# dummy input below uses a batch of 2); kept so other cells referring to them work.
input_dim = 128
batch_size = 4


def num_params(model: torch.nn.Module) -> int:
    """Number of trainable parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# NOTE(review): 12 input bands and a 512x512 spatial size are assumed to match the
# training data and the patch size used in evaluate_model — confirm against config.
inputs = torch.randn(2, 12, 512, 512).to(device)

# FlopCountAnalysis traces the model by running a real forward pass. In train mode
# that pass would update BatchNorm running statistics (and apply dropout) using the
# random input above, silently altering the loaded weights before evaluation.
model.eval()

counter = FlopCountAnalysis(model, inputs=inputs)
# Extra handles for ops fvcore does not count by default; each handle is labelled
# with the op it counts so the per-operator breakdown is attributed correctly.
counter.set_op_handle("aten::softmax", generic_activation_jit("aten::softmax"))
counter.set_op_handle("aten::sigmoid", generic_activation_jit("aten::sigmoid"))
counter.set_op_handle("aten::gelu", generic_activation_jit("aten::gelu"))
counter.set_op_handle("aten::mish", generic_activation_jit("aten::mish"))
counter.set_op_handle("aten::div_", partial(basic_binary_op_flop_jit, name='aten::div_'))
counter.set_op_handle("aten::mul", partial(basic_binary_op_flop_jit, name='aten::mul'))
counter.set_op_handle("aten::add", partial(basic_binary_op_flop_jit, name='aten::add'))
counter.set_op_handle("aten::add_", partial(basic_binary_op_flop_jit, name='aten::add_'))

print(f"Number of parameters: {num_params(model)}")
print(f"Total number of estimated flops: {pretty_flops(counter.total())}")

In [None]:
def get_metric_collection(device: str) -> MetricCollection:
    """
    Build the evaluation metrics: accuracy, F1, IoU, precision and recall, each
    configured as 2-class multiclass with per-class results (``average=None``)
    and moved to ``device``.
    """
    metric_classes = {
        "accuracy": Accuracy,
        "f1": F1Score,
        "iou": JaccardIndex,
        "precision": Precision,
        "recall": Recall,
    }
    per_class_metrics = {
        metric_name: metric_cls(task="multiclass", num_classes=2, average=None).to(device)
        for metric_name, metric_cls in metric_classes.items()
    }
    return MetricCollection(per_class_metrics)

In [None]:
# Inspect the first test sample: shapes and value ranges of image and mask.
image, mask = dataset[0]
mask = mask.long().to(device)
print(f"Mask shape: {mask.shape}")
# Prepend two singleton axes so the mask matches the (1, 1, ...) layout used in evaluation.
mask = mask[None, None]
print(f"Mask shape: {mask.shape}")
print(f"Image shape: {image.shape}, transformed image shape: {image.permute(1, 2, 0).shape}")
print(f"Min and max values: {image.min().item()}, {image.max().item()}")
print(f"Mask values: {mask.min().item()}, {mask.max().item()}")

In [None]:
def _predict_full_image(
    model: torch.nn.Module,
    image: torch.Tensor,
    device: torch.device,
    patch_size: int,
    overlap: float,
    verbose: bool = False,
) -> torch.Tensor:
    """
    Run ``model`` patch-by-patch over one image and stitch the sigmoid outputs back
    together with EMPatches' "avg" merge mode.

    The image is assumed to be channel-first; it is permuted to channel-last for
    patching and each patch is permuted back before the forward pass. Returns a
    float32 tensor of shape (1, C_out, H', W') on ``device``, where (H', W') is the
    size EMPatches reconstructs from the patch indices.
    """
    patcher = EMPatches()
    hwc_image = image.permute(1, 2, 0).to(device).float()
    patches, patch_indices = patcher.extract_patches(hwc_image, patchsize=patch_size, overlap=overlap)
    patch_probs = []
    for patch in patches:
        output = model(patch.permute(2, 0, 1).unsqueeze(0))
        # Some models return a mapping with the prediction under "out"
        # (OrderedDict is a dict subclass, so this covers both).
        if isinstance(output, dict) and "out" in output:
            output = output["out"]
        if verbose:
            print(f"output max: {output.max()}, output min: {output.min()}, output shape: {output.shape}")
        prob_mask = output.sigmoid()
        if verbose:
            print(f"Mask min: {prob_mask.min()}, Mask max: {prob_mask.max()}")
        # Back to channel-last on CPU, which is what merge_patches consumes.
        patch_probs.append(prob_mask.squeeze(0).permute((1, 2, 0)).detach().cpu())

    merged = patcher.merge_patches(patch_probs, patch_indices, "avg")
    return torch.tensor(merged, device=device).float().permute((2, 0, 1)).unsqueeze(0)


def evaluate_model(
    dataset: torch.utils.data.Dataset,
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    device: torch.device,
    patch_size: int = 512,
    num_prints: int = 20,
    overlap: float = 0.8,
    threshold: float = 0.5,
    verbose: bool = False,
) -> T.Iterable[T.Tuple[str, torch.nn.Module]]:
    """
    Evaluate the model on every sample of ``dataset`` using patchwise inference.

    Args:
        dataset: indexable dataset returning ``(image, mask)`` and exposing
            ``display_model_output(index, prediction)``.
        model: segmentation model; its sigmoid output is treated as a probability map.
        criterion: loss applied to the stitched probabilities and the mask.
        device: device used for inference and metrics.
        patch_size: side length of the square patches fed to the model.
        num_prints: number of leading samples whose predictions are displayed.
        overlap: fractional overlap between neighbouring patches.
        threshold: probability above which a pixel is predicted as class 1.
        verbose: print per-patch output statistics (very noisy).

    Returns:
        ``metrics.items()`` of the accumulated MetricCollection (name, metric) pairs;
        computed values are printed before returning.
    """
    model.eval()
    metrics = get_metric_collection(device)
    total_loss = 0.0
    num_samples = len(dataset)
    with torch.no_grad():
        for index in range(num_samples):
            sample_image, mask = dataset[index]
            # Prepend two singleton axes so the target matches the prediction layout.
            mask = mask.to(device).long().view(1, 1, *mask.shape)

            probs = _predict_full_image(model, sample_image, device, patch_size, overlap, verbose)

            # torchmetrics' multiclass metrics treat predictions with the same shape as
            # the target as integer class labels, so the float probabilities must be
            # binarized first (the single-channel layout is implied by the (H, W)
            # reshape used for display below).
            pred_labels = (probs > threshold).long()
            metrics.update(pred_labels, mask)
            # NOTE(review): the criterion receives post-sigmoid probabilities; if it
            # expects logits (e.g. a *WithLogits loss), this value is biased — confirm.
            total_loss += criterion(probs, mask)

            if index < num_prints:
                prob_map = probs.view(probs.shape[2], probs.shape[3])
                dataset.display_model_output(index, prob_map.detach().cpu().numpy())

        # Report the per-sample mean so the value is comparable across dataset sizes.
        mean_loss = total_loss / max(num_samples, 1)
        print(f"Loss: {mean_loss:.4f}")
        for name, metric in metrics.items():
            print(f"{name}: {metric.compute()}")

    return metrics.items()

results = evaluate_model(dataset, model, loss, device)