In [None]:
import json
import pathlib
import sys

import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T

# Despite its name, `src` is the repository root: the parent of the notebook's working
# directory, which contains the `src` package as well as logs/, model_registry/ and datasets/.
src = pathlib.Path().absolute().parent

# Make `import src....` resolvable. Guarded so re-running this cell does not keep
# appending duplicate entries to sys.path.
if str(src) not in sys.path:
    sys.path.append(str(src))

from src.models.segmentation.unet import UNet
from src.data.segmentation import BoxSegmentationDataset, LGGSegmentationDataset 
from src.enums import DataSplit
from src.data.datasets import LGG_NORMALIZE_TRANSFORM
from src.utils.transforms import DualInputCompose, DualInputResize, DualInputTransform, ImgOnlyTransform
from src.utils.visualize import plot_semantic_predictions

# Project directories, all resolved relative to the repository root `src`.
LOG_DIR = src.joinpath('logs')
MODEL_REGISTRY = src.joinpath('model_registry')
DATASETS = src.joinpath('datasets')

# Default evaluation pipeline: resize to 320x320, then convert to tensors.
# The DualInput* wrappers are used so the (image, mask) pair is transformed together.
base_transforms = DualInputCompose(
    [
        DualInputResize((320, 320)),
        DualInputTransform(T.ToTensor()),
    ]
)


In [None]:
def load_and_test(type: str, path: pathlib.Path, normalized: bool = False, transforms: DualInputCompose = base_transforms, batch_size: int = 6):
    """Load a UNet checkpoint and plot its predictions on the first test batch.

    Args:
        type: Dataset family, ``"box"`` or ``"lgg"``. (Shadows the builtin; the name is
            kept so existing keyword callers keep working.)
        path: Checkpoint file handed to ``UNet.load``.
        normalized: For ``"lgg"`` only, append ``LGG_NORMALIZE_TRANSFORM`` (image-only)
            to ``transforms``. Ignored for ``"box"``.
        transforms: Paired image/mask transform passed to the test dataset.
        batch_size: Number of test samples loaded and plotted.

    Raises:
        ValueError: If ``path`` is None (e.g. no checkpoint was selected) or ``type``
            is not one of the supported families.
    """
    if path is None:
        raise ValueError(f"No checkpoint path given for {type!r} model")

    # Build the dataset first so an invalid `type` fails before the model load.
    if type == "box":
        dataset = BoxSegmentationDataset(root_dir=DATASETS, split=DataSplit.TEST, transform=transforms)
    elif type == "lgg":
        if normalized:
            # New compose object; the shared default `transforms` is not mutated.
            transforms = DualInputCompose([*transforms.transforms, ImgOnlyTransform(LGG_NORMALIZE_TRANSFORM)])
        dataset = LGGSegmentationDataset(root_dir=DATASETS, split=DataSplit.TEST, transform=transforms)
    else:
        raise ValueError("Invalid type")

    model = UNet.load(path)
    model.eval()

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    imgs, masks = next(iter(loader))

    # Inference only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        preds = model(imgs)
    # Sigmoid then round, i.e. threshold probabilities at 0.5 into a {0, 1} integer mask.
    preds = preds.sigmoid().round().int()
    plot_semantic_predictions(imgs, masks, preds, include_overlay=True, include_split=False)

In [None]:
# Scan every run in the model registry and keep, per model family, the checkpoint with
# the highest validation IoU ('val_BinaryIoU' in the run's metrics.json, keyed by step).
# The family is inferred from the run's registry-relative path:
#   contains 'lgg' and 'normalized' -> normalized LGG; 'lgg' -> LGG; 'box' -> box.


def _model_family(run_path: str):
    """Map a registry-relative run path to 'lgg_norm', 'lgg', 'box', or None if unknown."""
    if 'lgg' in run_path:
        return 'lgg_norm' if 'normalized' in run_path else 'lgg'
    if 'box' in run_path:
        return 'box'
    return None


def _checkpoint_step(check_point: pathlib.Path):
    """Return the trailing '_<step>' of a checkpoint stem as a canonical string, or None."""
    suffix = check_point.stem.split("_")[-1]
    try:
        # int() round-trip normalizes e.g. '0010' -> '10' to match the metrics keys.
        return str(int(suffix))
    except ValueError:
        # Checkpoint without a numeric step suffix: cannot be matched to metrics.
        return None


best = {
    family: {"score": 0, "path": None, "metrics": None}
    for family in ("lgg", "lgg_norm", "box")
}

# sorted() makes the scan order, and hence tie-breaking, deterministic.
for exp_dir in sorted(MODEL_REGISTRY.iterdir()):
    # Only experiment directories are of interest; skip stray files in the registry root.
    if not exp_dir.is_dir():
        continue
    for run_dir in sorted(exp_dir.iterdir()):
        # The logs mirror the registry layout. relative_to() is separator-agnostic,
        # unlike splitting the path string on 'model_registry/' (breaks on Windows).
        rel_run = run_dir.relative_to(MODEL_REGISTRY)
        family = _model_family(rel_run.as_posix())
        metric_path = LOG_DIR / rel_run / 'metrics.json'
        if family is None or not metric_path.exists():
            continue
        with open(metric_path, 'r') as f:
            metrics = json.load(f)

        for check_point in sorted(run_dir.glob("*.pth")):
            step = _checkpoint_step(check_point)
            step_metrics = metrics.get(step) if step is not None else None
            if step_metrics is None:
                continue
            score = step_metrics['val_BinaryIoU']
            if score > best[family]["score"]:
                best[family] = {"score": score, "path": check_point, "metrics": step_metrics}

# Expose the results under the names used by the following cells.
best_lgg_score = best["lgg"]["score"]
best_lgg = best["lgg"]["path"]
lgg_metrics = best["lgg"]["metrics"]

best_lgg_norm_score = best["lgg_norm"]["score"]
best_lgg_norm = best["lgg_norm"]["path"]
lgg_metrics_norm = best["lgg_norm"]["metrics"]

best_box_score = best["box"]["score"]
best_box = best["box"]["path"]
box_metrics = best["box"]["metrics"]

print(f"Best LGG Path: {best_lgg} \nMetrics: {lgg_metrics}")
print(f"Best LGG Normalized Path: {best_lgg_norm} \nMetrics: {lgg_metrics_norm}")
print(f"Best Box Path: {best_box}\nMetrics: {box_metrics}")

In [None]:
# Visualize each family's best checkpoint on its test split. A family with no qualifying
# checkpoint (path still None) is reported and skipped instead of crashing the cell.
for model_type, checkpoint, is_normalized in (
    ("lgg", best_lgg, False),
    ("lgg", best_lgg_norm, True),
    ("box", best_box, False),
):
    if checkpoint is None:
        print(f"Skipping {model_type} (normalized={is_normalized}): no checkpoint found")
        continue
    load_and_test(model_type, checkpoint, normalized=is_normalized)