In [1]:
import sys
import os
from pathlib import Path

import numpy as np
import pandas as pd
import skimage.io as io
import matplotlib.pyplot as plt
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
import albumentations as A
from albumentations.pytorch import ToTensorV2
from skimage.color import gray2rgb
from torchvision.ops import nms
from pytorch_toolbelt.utils import to_numpy, rle_encode
from dotenv import load_dotenv

from src.visualization import tensor_to_image
from src.postprocessing import remove_overlapping_pixels

# Load variables from a local .env file into the process environment
# (presumably including `dataset_path`, which the config cell reads - TODO confirm).
load_dotenv()

# Relative paths used later in the notebook are built on top of this directory.
current_dir = Path(".")
# Last expression of the cell: displays the resolved absolute location.
current_dir.absolute()

PosixPath('/workspaces/sartorius_instance_segmentation')

In [2]:
# Global configuration: input data, checkpoint location and inference device
test_images_dir = Path(os.environ['dataset_path'], "test")
# NOTE: despite the `_dir` suffix this points at the checkpoint *file*
weights_dir = current_dir.joinpath("weights", "maxim_baseline.ckpt")
device = "cpu"

# Tunable evaluation parameters
# Minimum detection score to keep; 0.0 keeps every prediction, however low its score
score_threshold = 0.0
# Threshold passed to box NMS; the lower the value, the less overlap between instances is tolerated
nms_threshold = 0.1
# Mask values greater than or equal to this are treated as foreground
mask_threshold = 0.5

In [3]:
# Fail early, with the resolved path in the message, if the inputs are missing.
assert test_images_dir.is_dir(), (
    f"Check test dir path for correctness, was looking at {test_images_dir.absolute()}"
)
assert weights_dir.is_file(), (
    f"File not found, was looking at {weights_dir.absolute()}"
)

In [4]:
# Preprocessing applied to every image before inference:
#   1. normalisation with a single mean/std value (i.e. one-channel statistics),
#   2. conversion of the numpy image into a torch tensor.
preprocess_image = A.Compose(
    transforms=[
        A.Normalize(mean=(0.485,), std=(0.229,)),
        ToTensorV2(),
    ]
)

In [5]:
# Two-class Mask R-CNN; the trained weights come from the checkpoint file.
model = maskrcnn_resnet50_fpn(progress=False, num_classes=2)
# The checkpoint is deserialised onto the CPU first, independent of the target device.
model.load_state_dict(
    torch.load(weights_dir, map_location=torch.device("cpu"))
)
# Move to the configured device and switch to inference mode.
model = model.to(device).eval()

def predict_masks(image: np.ndarray, model) -> np.ndarray:
    """Predict non-overlapping binary instance masks for a single image.

    The image is run through ``preprocess_image`` and the model. Predicted
    masks are binarised with ``mask_threshold``, all-zero masks are removed,
    detections scoring below ``score_threshold`` are dropped, box-level NMS
    with ``nms_threshold`` is applied, and the surviving masks are passed to
    ``remove_overlapping_pixels``.

    Args:
        image: Input image in a form accepted by ``preprocess_image``.
        model: Detection model. Called with a list holding one image tensor,
            it must return a list whose first item is a dict with the keys
            ``'scores'``, ``'masks'`` and ``'boxes'``. Masks are assumed to
            follow the torchvision Mask R-CNN layout ``(N, 1, H, W)``.

    Returns:
        The masks returned by ``remove_overlapping_pixels``, one instance per
        entry along the first axis. If no instance survives the filtering, an
        empty ``(0, H, W)`` integer array is returned instead.

    Raises:
        AssertionError: If any pixel is still claimed by more than one mask.
    """
    device = next(model.parameters()).device
    # Tensor.to() is NOT in-place: keep the returned tensor, otherwise the
    # input silently stays on the CPU while the model may live elsewhere.
    tensor = preprocess_image(image=image)['image'].to(device)
    with torch.no_grad():
        output = model([tensor])[0]

    scores = output['scores'].detach().cpu()
    # Drop only the channel axis. A bare squeeze() would also collapse the
    # instance axis when exactly one instance is detected, turning the masks
    # into a single (H, W) map and breaking the per-instance indexing below.
    masks = output['masks'].squeeze(1).detach().cpu()
    boxes = output['boxes'].detach().cpu()

    masks = (masks >= mask_threshold).int()

    # Thresholding can leave masks with no foreground at all; drop them.
    # The vectorised reduction yields a boolean index even for zero detections
    # (a tensor built from an empty Python list would be float and unusable
    # as an index).
    keep = masks.sum(dim=(1, 2)) > 0
    masks, boxes, scores = masks[keep], boxes[keep], scores[keep]

    keep = scores >= score_threshold
    masks, boxes, scores = masks[keep], boxes[keep], scores[keep]

    keep = nms(boxes, scores, nms_threshold)
    masks, boxes, scores = masks[keep], boxes[keep], scores[keep]

    if masks.shape[0] == 0:
        # Nothing survived the filtering: there is nothing to de-overlap.
        return masks.numpy()

    answer_masks = remove_overlapping_pixels(masks.numpy())
    assert np.max(np.sum(answer_masks, axis=0)) <= 1, "Masks overlap"
    return answer_masks

In [6]:
# One row per predicted instance: the image id and the space-separated
# run-length encoding of that instance's mask.
answers = {
    "id": [],
    "predicted" : [],
}
# Directory traversal order depends on the filesystem; sorting the paths keeps
# the row order of the submission reproducible between runs and machines.
for image_path in sorted(test_images_dir.glob("**/*.png")):
    image = io.imread(str(image_path))
    masks = predict_masks(image, model)
    # NOTE(review): an image with no surviving masks contributes no rows at
    # all - confirm that the submission format tolerates missing image ids.
    answers["id"].extend([image_path.stem] * len(masks))
    answers["predicted"].extend(" ".join(map(str, rle_encode(mask))) for mask in masks)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [7]:
# Assemble the per-instance predictions into the submission table.
submission = pd.DataFrame(answers)
# Preview a few random rows. The sample size is clamped because
# DataFrame.sample raises ValueError when asked for more rows than exist
# (e.g. when only a handful of instances, or none, were predicted).
submission.sample(min(8, len(submission)))

Unnamed: 0,id,predicted
99,d48ec7815252,163 14 682 16 1201 18 1720 20 2239 22 2759 22 ...
185,d8bfd1dafdc4,209651 1 210171 1 210690 1 211727 3 212246 4 2...
21,7ae19de7bc2a,57887 11 58406 14 58925 15 59444 16 59964 16 6...
116,d48ec7815252,82231 4 82751 5 83271 6 83791 6 84310 8 84830 ...
33,7ae19de7bc2a,89846 1 90364 5 90884 6 91403 8 91923 8 92443 ...
125,d48ec7815252,87707 1 88226 4 88746 5 89266 6 89786 6 90305 ...
109,d48ec7815252,199049 5 199566 8 200078 2 200085 9 200596 8 2...
166,d8bfd1dafdc4,163975 4 164495 8 165015 9 165534 9 166054 9 1...


In [8]:
# Write the submission next to the notebook; the DataFrame index is not part of
# the file, so it is left out.
submission_path = "submission.csv"
submission.to_csv(submission_path, index=False)