In [2]:
from src.dl.inference.inferer import Inferer
import src.dl.lightning as lightning
from src.data import PannukeDataModule, ConsepDataModule

# Running Inference

- `stride_size` has a big impact on results: a smaller stride gives better accuracy, but it increases the memory footprint and running time
- `apply_weights=True` puts less weight on the tile boundaries, which prevents boundary artifacts. It is used for the auxiliary branch predictions
- By tweaking the `*_batch_size` arguments (`loader_batch_size`, `model_batch_size`) and `loader_num_workers`, you can get better performance on your machine. Finding the sweet spot usually takes a few runs
- If your `in_dir` contains hundreds of images, use `n_images={int}` to refresh memory after every `n_images` predictions

In [3]:
# Initialize the inferer.
# Paths and experiment identifiers: edit these for your own setup.
in_dir = "path/to/imgs"   # directory containing the input images
gt_dir = None             # optional ground-truth mask directory; benchmarking below needs it
exp_name = "panoptic"     # experiment name (the experiment directory)
exp_version = "effnetv2"  # experiment version (sub-directory inside the experiment dir)

lightning_model = lightning.SegModel.from_experiment(
    name=exp_name,
    version=exp_version,
)

inferer = Inferer(
    lightning_model,
    # input data
    in_data_dir=in_dir,
    gt_mask_dir=gt_dir,
    fn_pattern="*",
    # patch tiling (see the notes above on stride_size / apply_weights)
    patch_size=(256, 256),
    stride_size=80,
    apply_weights=True,
    # model weights and post-processing
    model_weights="last",
    post_proc_method="cellpose",
    auto_range=True,
    # throughput / memory tuning (see the notes above)
    loader_batch_size=1,
    loader_num_workers=1,
    model_batch_size=16,
    n_images=32,
)

In [None]:
# Class-name -> integer-label mappings passed to the inferer:
# `area_classes` for the semantic (area) output, `cell_classes` for the cell-type output.
area_classes = dict(
    background=0,
    areastroma=1,
    area_cin=2,
    areasquam=3,
    areagland=4,
)

cell_classes = dict(
    background=0,
    neoplastic=1,
    inflammatory=2,
    connective=3,
    dead=4,
    glandular_epithel=5,
    squamous_epithel=6,
)

# Run inference over `in_dir` and write the predictions as geojson into `save_dir`.
inferer.run_inference(
    save_dir="/path/to/my_geojson_dir",
    fformat="geojson",
    offsets=True,
    classes_sem=area_classes,
    classes_type=cell_classes,
)

In [None]:
import matplotlib.pyplot as plt
from skimage.color import label2rgb
from src.utils import FileHandler
from pathlib import Path

from src.utils import draw_thing_contours, draw_stuff_contours, label_sem_map


def _find_file_for_key(directory, key):
    """Return the first file (in sorted path order) in `directory` whose name contains `key`.

    Raises:
        FileNotFoundError: if no file name in `directory` contains `key`
            (instead of the opaque IndexError from `[...][0]`).
    """
    match = next(
        (f for f in sorted(Path(directory).glob("*")) if key in f.name),
        None,
    )
    if match is None:
        raise FileNotFoundError(f"No file whose name contains {key!r} in {directory}")
    return match


# Index of the prediction to visualize; must be smaller than the number of predicted images.
ix = 5
keys = list(inferer.soft_insts.keys())
key = keys[ix]

img = FileHandler.read_img(_find_file_for_key(inferer.in_data_dir, key))
sem_map = inferer.sem_maps[key]
type_map = inferer.type_maps[key]

# Draw filled area contours onto the input image, then draw the cell contours onto that result.
areas = draw_stuff_contours(
    label_sem_map(sem_map), img, sem_map,
    classes=area_classes, thickness=5, fill_contours=True,
)
everything = draw_thing_contours(inferer.inst_maps[key], areas, type_map, classes=cell_classes)

panels = [
    ("Semantic (area) map", label2rgb(sem_map, bg_label=0)),
    ("Cell type map", label2rgb(type_map, bg_label=0)),
    ("Input image", img),
    ("Areas + cell contours", everything),
]

fig, ax = plt.subplots(2, 2, figsize=(40, 40))
ax = ax.flatten()
for axis, (title, image) in zip(ax, panels):
    axis.imshow(image)
    axis.set_title(title)

# To show ground-truth masks in ax[2] / ax[3] instead (requires gt_mask_dir to be set;
# assumes inferer.gt_mask_dir is a directory path -- TODO confirm):
# gt_path = _find_file_for_key(inferer.gt_mask_dir, key)
# ax[2].imshow(label2rgb(FileHandler.read_mask(gt_path, "inst_map"), bg_label=0))
# ax[3].imshow(label2rgb(FileHandler.read_mask(gt_path, "type_map"), bg_label=0))
plt.show()

# Merging results

- If you saved your results in geojson format and the tiles are adjacent to each other (e.g. extracted from a whole-slide image, WSI), you can merge all the tiles into a single QuPath-readable geojson file


In [None]:
from src.utils import CellMerger, AreaMerger

# Merge the per-tile cell annotations into one geojson file.
cell_merger = CellMerger("/path/to/my_geojson_dir/cells/")
cell_merger.merge("/path/to/my_wsi_cells.json")

# If the network also outputs area (semantic) predictions, merge those as well.
area_merger = AreaMerger("/path/to/my_geojson_dir/areas/")
area_merger.merge("/path/to/my_wsi_areas.json")

# Run Benchmarking

- Only works if ground-truth masks were provided, i.e. `gt_dir` (passed to the inferer as `gt_mask_dir`) is not `None`
- The first cell below this one runs binary metrics i.e. segmentation metrics for all cells
- The second cell below this one runs metrics per cell type

In [None]:
# Optional filename filter for the benchmark; None is passed through unchanged
# (presumably meaning "evaluate all files" -- TODO confirm in Inferer.benchmark_insts).
pattern_list = None

# Binary (type-agnostic) instance-segmentation metrics over all cells.
results_prefix = f"{exp_name}_{exp_version}"
binary_scores = inferer.benchmark_insts(
    pattern_list=pattern_list,
    file_prefix=results_prefix,
)
binary_scores

In [None]:
# Optional filename filter for the benchmark (None is passed through unchanged).
pattern_list = None

# Per-cell-type metrics. The file prefix matches the binary benchmark cell
# (experiment name + version), so results from different experiments that share
# a version name do not overwrite each other.
type_scores_all = inferer.benchmark_types(
    # NOTE(review): these classes must match the model's type outputs. Inference above
    # used the 7-class `cell_classes`; if the model was not trained on PanNuke classes,
    # pass `cell_classes` here instead -- TODO confirm.
    classes=PannukeDataModule.get_classes(),
    pattern_list=pattern_list,
    file_prefix=f"{exp_name}_{exp_version}",
)

# Keep only the rows whose index label contains "avg"; the full table stays in `type_scores_all`.
type_scores = type_scores_all[type_scores_all.index.str.contains("avg")]
type_scores