In [None]:
from src.dl.inference.inferer import Inferer
import src.dl.lightning as lightning
from src.utils.gson_merger import GSONMerger
from src.data import PannukeDataModule, ConsepDataModule

In [None]:
# Initialize the inferer.
# Data locations. `gt_dir` (ground-truth masks) is optional and may be None.
in_dir = "my_in_dir"
gt_dir = "my_gt_dir"

# Experiment directory, and the version sub-directory inside it.
exp_name = "ovca"
exp_version = "full"

lightning_model = lightning.SegModel.from_experiment(name=exp_name, version=exp_version)

# Inference settings, handed to `Inferer` as keyword arguments.
inferer_kwargs = dict(
    in_data_dir=in_dir,
    gt_mask_dir=gt_dir,
    patch_size=(256, 256),
    stride_size=80,
    fn_pattern="*",
    model_weights="last",
    apply_weights=True,
    post_proc_method="cellpose",
    loader_batch_size=1,
    loader_num_workers=1,
    model_batch_size=16,
    n_images=185,
    auto_range=False,
)
inferer = Inferer(lightning_model, **inferer_kwargs)

In [None]:
# Run inference with the settings configured above. Later cells read the results
# back from the inferer (`inferer.soft_insts`, `inferer.inst_maps`, `inferer.type_maps`).
# Uncomment the arguments below to also write results to disk, e.g. as GeoJSON
# (assumed meaning of `save_dir` / `fformat` / `offsets` — confirm in `Inferer.run_inference`).
inferer.run_inference(
    # save_dir="my_geojson_dir",
    # fformat="geojson",
    # offsets=True
)

In [None]:
import matplotlib.pyplot as plt
from skimage.color import label2rgb
from src.utils import FileHandler

# Visual check: predicted instance/type maps next to the ground-truth masks
# for one inferred image.
SAMPLE_IDX = 1  # position (in inference order) of the image to show

keys = list(inferer.soft_insts.keys())
if not keys:
    raise RuntimeError("No inference results in `inferer.soft_insts`; run `inferer.run_inference()` first.")
# Clamp so the cell still works when fewer images were inferred than SAMPLE_IDX + 1.
key = keys[min(SAMPLE_IDX, len(keys) - 1)]

# Locate the matching ground-truth file once and fail with a clear message if absent.
# NOTE(review): substring match on the file name, so "img1" also matches "img10";
# tighten this if file names can overlap.
gt_path = next((f for f in inferer.gt_mask_dir if key in f.name), None)
if gt_path is None:
    raise FileNotFoundError(f"No ground-truth mask whose name contains {key!r} in inferer.gt_mask_dir")

panels = [
    ("Predicted instances", inferer.inst_maps[key]),
    ("Predicted types", inferer.type_maps[key]),
    ("GT instances", FileHandler.read_mask(gt_path, "inst_map")),
    ("GT types", FileHandler.read_mask(gt_path, "type_map")),
]

fig, ax = plt.subplots(2, 2, figsize=(40, 40))
ax = ax.flatten()
for axis, (title, label_map) in zip(ax, panels):
    axis.imshow(label2rgb(label_map, bg_label=0))
    axis.set_title(title)
plt.show()

In [None]:
# Merge the output GeoJSON files into a single file that QuPath can read.
# Run this cell only for patched WSIs, and only if you need the outputs in GeoJSON format.
# gsonmerger = GSONMerger(in_dir="gson_dir")
# gsonmerger.merge(fname="big_gson_file.json")

In [None]:
# Benchmark the predicted instance maps (`benchmark_insts`) and display the scores.
# `pattern_list=None` presumably means no filename filtering — confirm in `Inferer`.
pattern_list = None
binary_scores = inferer.benchmark_insts(
    pattern_list=pattern_list,
    file_prefix="{}_{}".format(exp_name, exp_version),
)
binary_scores

In [None]:
# Benchmark the predicted type maps (`benchmark_types`) and keep only the "avg" rows.
# NOTE(review): class names come from PannukeDataModule; confirm they match the
# classes the `exp_name` model was trained on.
pattern_list = None
type_scores_all = inferer.benchmark_types(
    classes=PannukeDataModule.get_classes(),
    pattern_list=pattern_list,
    # Same "<exp_name>_<exp_version>" prefix as the instance benchmark; with only the
    # version as prefix, experiments sharing a version name (e.g. "full") would
    # presumably produce identically named outputs.
    file_prefix=f"{exp_name}_{exp_version}",
)
# Literal substring match (no regex): keep rows whose index label contains "avg".
type_scores = type_scores_all[type_scores_all.index.str.contains("avg", regex=False)]
type_scores

In [None]:
# Mean (pandas default axis=0) over the rows of `type_scores` whose index label contains "for_the".
for_the_mask = type_scores.index.str.contains("for_the")
type_scores.loc[for_the_mask].mean()