In [1]:
from pathlib import Path
from hydra.utils import instantiate

import numpy as np
import matplotlib.pylab as plt

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm import tqdm

from src.data.transforms import GrayToRGB
from src.data.dataset import ChestXRayAlignmentDataset
from src.models import PLAlignmentModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os.path as osp
import re
from datetime import datetime
from omegaconf import OmegaConf


def get_last_ckpt(ckpt_dir):
    """Return the checkpoint in ``ckpt_dir`` with the highest epoch number.

    The epoch is parsed from an ``epoch=<int>`` tag in the file stem
    (e.g. ``epoch=17.ckpt``). ``*.ckpt`` files without such a tag (for
    example a file named ``last.ckpt``) cannot be ranked by epoch and are
    ignored instead of aborting the lookup.

    Args:
        ckpt_dir: Directory (``str`` or ``Path``) searched non-recursively
            for ``*.ckpt`` files.

    Returns:
        ``Path`` of the checkpoint whose epoch number is the largest.

    Raises:
        IndexError: If the directory holds no checkpoint with an
            ``epoch=<int>`` tag (same exception type an empty directory
            produced before, now with an explanatory message).
    """
    ckpt_dir = Path(ckpt_dir)

    ranked = []
    for ckpt_path in ckpt_dir.glob("*.ckpt"):
        match = re.search(r"epoch=(\d+)", ckpt_path.stem)
        if match is None:
            # No epoch tag in the name: nothing to rank this file by.
            continue
        ranked.append((int(match.group(1)), ckpt_path))

    if not ranked:
        raise IndexError(f"No '*.ckpt' file with an 'epoch=N' tag found in {ckpt_dir}")

    # Sort on the epoch only so paths are never compared with each other.
    ranked.sort(key=lambda item: item[0])
    return ranked[-1][1]


def get_best_ckpt(ckpt_dir, mode="min", monitor="val_loss"):
    """Return the checkpoint in ``ckpt_dir`` with the best monitored value.

    The value is parsed from a ``<monitor>=<number>`` tag in the file stem
    (e.g. ``epoch=3-val_loss=0.25.ckpt``). The number may be an integer or a
    decimal, optionally signed and optionally with an exponent. ``*.ckpt``
    files that do not carry the tag are ignored.

    Args:
        ckpt_dir: Directory (``str`` or ``Path``) searched non-recursively
            for ``*.ckpt`` files.
        mode: ``"min"`` selects the smallest monitored value; any other
            value selects the largest.
        monitor: Metric name as it appears in the checkpoint file names.

    Returns:
        ``Path`` of the selected checkpoint.

    Raises:
        IndexError: If no checkpoint carrying the ``<monitor>=<number>`` tag
            is found (same exception type an empty directory produced
            before, now with an explanatory message).
    """
    # re.escape keeps any regex metacharacters in the metric name literal, and
    # the dot of the decimal part is escaped so it only matches a real ".".
    value_pattern = re.compile(
        re.escape(monitor) + r"=([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)"
    )

    ckpt_dir = Path(ckpt_dir)

    scored = []
    for ckpt_path in ckpt_dir.glob("*.ckpt"):
        match = value_pattern.search(ckpt_path.stem)
        if match is None:
            # File name does not record the monitored metric; cannot rank it.
            continue
        scored.append((float(match.group(1)), ckpt_path))

    if not scored:
        raise IndexError(
            f"No '*.ckpt' file with a '{monitor}=<number>' tag found in {ckpt_dir}"
        )

    # Stable ascending sort on the metric only: first entry is the minimum,
    # last entry is the maximum.
    scored.sort(key=lambda item: item[0])
    return scored[0][1] if mode == "min" else scored[-1][1]


def get_last_config(cfg_dir, experiment_name, experiment_step=''):
    """Return the newest saved run config matching an experiment name and step.

    Run directories are expected at ``<cfg_dir>/<YYYY-MM-DD>/<HH-MM-SS>`` with
    the config stored in ``.hydra/config.yaml`` inside each of them. Runs are
    inspected from newest to oldest (by the timestamp encoded in the directory
    names) and the first config whose ``experiment_name`` and
    ``experiment_step`` both match is returned.

    Entries under ``cfg_dir`` that do not follow the date/time naming, or that
    have no ``.hydra/config.yaml``, are skipped rather than aborting the
    search.

    Args:
        cfg_dir: Root directory (``str`` or ``Path``) holding the dated run
            directories.
        experiment_name: Value the config's ``experiment_name`` must equal.
        experiment_step: Value the config's ``experiment_step`` must equal;
            a config without that key is treated as having step ``''``.

    Returns:
        The config object loaded by ``OmegaConf.load`` for the newest
        matching run.

    Raises:
        ValueError: If no run under ``cfg_dir`` matches.
    """
    cfg_dir = Path(cfg_dir)

    dated_runs = []
    for run_dir in cfg_dir.glob("*/*"):
        try:
            started_at = datetime.strptime(
                f"{run_dir.parent.name} {run_dir.name}", "%Y-%m-%d %H-%M-%S")
        except ValueError:
            # Not a <date>/<time> run directory; ignore it.
            continue
        dated_runs.append((started_at, run_dir))

    # Newest run first; sort on the timestamp only.
    dated_runs.sort(key=lambda item: item[0], reverse=True)

    for _, run_dir in dated_runs:
        config_file = run_dir / ".hydra" / "config.yaml"
        if not config_file.is_file():
            # Run directory without a saved config; nothing to compare.
            continue

        cfg = OmegaConf.load(config_file)

        # .get() so a config lacking either key simply does not match.
        cfg_exp_name = cfg.get("experiment_name")
        cfg_exp_step = cfg.get("experiment_step", '')

        if cfg_exp_name == experiment_name and cfg_exp_step == experiment_step:
            return cfg

    msg = f"Can't find config for experiment name: {experiment_name} "
    if experiment_step:
        msg += f"and step: {experiment_step}"

    raise ValueError(msg)

In [3]:
# Root directory handed to ChestXRayAlignmentDataset further down.
# NOTE(review): hard-coded absolute path on a user-specific mount; the notebook
# cannot run on another machine without editing this line.
main_dir = Path("/home/orogov/smbmount/from_DGX/cxr14-2")

In [4]:
# Newest saved run config under the relative "outputs" directory whose
# experiment name matches and whose step is empty. The path is relative, so the
# result depends on the notebook's working directory.
cfg = get_last_config("outputs", "V2_nih_resnet18_vgg16_320x320_perceptual_bs16", "")

# Step stored in the loaded config; falls back to "" when the key is absent.
experiment_step = cfg.get("experiment_step", "")

In [78]:
# Checkpoint directory: <experiment_path>/checkpoints/<experiment_name>/<experiment_step>
# (an empty experiment_step adds no extra path component).
ckpt_dir = Path(osp.join(cfg.experiment_path, "checkpoints", cfg.experiment_name, experiment_step))
# Checkpoint with the highest "epoch=N" number in its file name.
ckpt_path = get_last_ckpt(ckpt_dir)

print(f"Checkpoint: {ckpt_path}")

# Model section of the run config, forwarded to the checkpoint loader below.
model_cfg = cfg.model

# NOTE(review): presumably restores a PyTorch Lightning module from the
# checkpoint file — confirm against PLAlignmentModel.
model = PLAlignmentModel.load_from_checkpoint(ckpt_path, model_config=model_cfg)

Checkpoint: /home/orogov/smbmount/a_galichin/experiments/alignment/checkpoints/V2_nih_resnet18_vgg16_320x320_perceptual_bs16/epoch=17.ckpt


In [6]:
# Transform list passed as the dataset's second argument: resize to 1024x1024,
# GrayToRGB, then ToTensorV2.
transforms = [A.Resize(1024, 1024, always_apply=True), GrayToRGB(always_apply=True), ToTensorV2()]
# Transform list passed as the dataset's third argument: resize to 320x320,
# GrayToRGB, then ToTensorV2.
# NOTE(review): this Resize omits always_apply=True, unlike the one above —
# confirm whether the difference is intended.
anchor_transforms = [A.Resize(320, 320), GrayToRGB(always_apply=True), ToTensorV2()]

# NOTE(review): the transforms are handed over as plain lists, not A.Compose
# objects; presumably the dataset composes them itself — confirm.
dataset = ChestXRayAlignmentDataset(
    main_dir,
    transforms,
    anchor_transforms,
    anchor="canonical_chest",
    split="train_val_list"
)