In [1]:
from trainlib.trainer import SegmentationTrainer
from trainlib.utils import load_config
from trainlib.report import ReportGenerator
from trainlib.data import segmentation_dataloaders
import pandas as pd
from trainlib.viewer import BasicViewer
from monai.transforms import (
    KeepLargestConnectedComponentd,
    Compose,
    AsDiscreted,
    SqueezeDimd,
    SaveImaged,
    Lambdad,
)
from pathlib import Path

In [None]:
# Checkpoint path handed to `trainer.predict` in the inference loop below.
ckpt = "models/model.pt"

# NOTE(review): "bianry" looks like a misspelling of "binary" — confirm the
# actual file name on disk before renaming anything.
config = load_config("../configs/bianry-best.yaml")
# Redirect the data directory to the sibling "test" folder of the configured one.
config.data.data_dir = config.data.data_dir.parent / "test"

In [None]:
# Table listing the test cases; its `image` column drives the inference loop.
test_df = pd.read_csv(config.data.test_csv)
# Only the test split is requested from the dataloader factory.
test_dl = segmentation_dataloaders(config, train=False, valid=False, test=True)

In [None]:
# Trainer instance; in this notebook it is only used for `trainer.predict`.
trainer = SegmentationTrainer(
    config=config,
    metrics=["MeanDice", "HausdorffDistance", "SurfaceDistance"],
    progress_bar=True,
    early_stopping=True,
    save_latest_metrics=True,
)

In [None]:
def squeeze_affine(x):
    """Squeeze the affine stored in ``x.meta`` and hand back the same object.

    The ``"affine"`` entry of ``x.meta`` is replaced in place by the result of
    calling ``.squeeze()`` on it, i.e. every size-1 dimension is removed.

    Args:
        x: Object exposing a dict-like ``meta`` attribute with an ``"affine"``
            entry that supports ``.squeeze()``.

    Returns:
        The very same ``x`` (mutated, not copied).
    """
    meta = x.meta
    squeezed = meta["affine"].squeeze()
    meta["affine"] = squeezed
    return x

In [None]:
# Post-processing applied to the dictionary returned by `trainer.predict`.
postprocessing = Compose(
    [
        # Squeeze both entries using the transform's default dimension
        # (presumably the batch axis — TODO confirm).
        SqueezeDimd(keys=["pred", "image"]),
        # Turn the prediction into discrete labels via argmax.
        AsDiscreted(keys="pred", argmax=True),
        # Keep only the single largest connected component of label 1.
        KeepLargestConnectedComponentd(
            keys="pred",
            applied_labels=1,
            is_onehot=False,
            num_components=1,
        ),
        # Squeeze the affine in the metadata of both entries as well.
        Lambdad(keys=["pred", "image"], func=squeeze_affine),
    ]
)

In [None]:
def save_prediction(data_dict):
    """Save the ``"pred"`` and ``"image"`` entries of ``data_dict`` with ``SaveImaged``.

    The output directory is the parent directory of the first
    ``filename_or_obj`` entry in the prediction's metadata, with the substring
    ``"/resampled/"`` replaced by ``"/test/"``. Each entry is written with an
    output postfix equal to its key, without a per-file subfolder and without
    resampling.

    Args:
        data_dict: Mapping with ``"pred"`` and ``"image"`` entries;
            ``data_dict["pred"].meta["filename_or_obj"]`` is indexed with
            ``[0]`` (presumably a one-element batch of paths — TODO confirm).

    Returns:
        None.
    """
    source_file = Path(data_dict["pred"].meta["filename_or_obj"][0])
    # NOTE(review): str(parent) has no trailing separator, so this replacement
    # only fires when "resampled" is followed by another path component.
    output_dir = str(source_file.parent).replace("/resampled/", "/test/")

    # Prediction first, then the image; the postfix mirrors the key.
    for key in ("pred", "image"):
        SaveImaged(
            output_dir=output_dir,
            keys=key,
            output_postfix=key,
            separate_folder=False,
            resample=False,
        )(data_dict)

In [None]:
from tqdm import tqdm
import SimpleITK as sitk

In [None]:
# Predict, post-process and save every case listed in the test table.
# `processed` keeps the result of the last iteration for the viewer cell.
for fn in tqdm(test_df.image):
    pred = trainer.predict(
        file=str(config.data.data_dir / fn),
        checkpoint=ckpt,
        roi_size=config.input_size,
    )
    processed = postprocessing(pred)
    save_prediction(processed)

In [None]:
# Show the last processed case only (`processed` comes from the final loop
# iteration). The last and third-from-last axes are swapped for display.
BasicViewer(
    processed["image"].transpose(-1, -3),
    processed["pred"].transpose(-1, -3),
    figsize=(6, 6),
).show()