In [1]:
import sys, string, random
from omegaconf import OmegaConf
import wandb
import torch
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from monai.networks.nets import UNet

sys.path.append('../')
from data_utils import MNMv2DataModule
from model.unet import LightningSegmentationModel

nnUNet_raw_data_base is not defined and nnU-Net can only be used on data for which preprocessed files are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up properly.
RESULTS_FOLDER is not defined and nnU-Net cannot be used for training or inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information on how to set this up.


In [2]:
# Read the data, model and trainer settings from their YAML files.
mnmv2_config, unet_config, trainer_config = map(
    OmegaConf.load,
    (
        '../configs/data/mnmv2.yaml',
        '../configs/model/monai_unet.yaml',
        '../configs/trainer/unet_trainer.yaml',
    ),
)

In [3]:
# Build the MNMv2 data module.
# Every constructor argument is read from the data-config entry of the same name.
datamodule_fields = (
    'data_dir',
    'vendor_assignment',
    'batch_size',
    'binary_target',
    'non_empty_target',
)
datamodule = MNMv2DataModule(
    **{field: getattr(mnmv2_config, field) for field in datamodule_fields}
)

In [6]:
# Architecture: one channel count per level, doubling from `n_filters_init`,
# and a stride of 2 between each pair of consecutive levels.
unet_channels = [
    unet_config.n_filters_init * 2 ** level
    for level in range(unet_config.depth)
]
unet_strides = [2] * (unet_config.depth - 1)

unet = UNet(
    spatial_dims=unet_config.spatial_dims,
    in_channels=unet_config.in_channels,
    out_channels=unet_config.out_channels,
    channels=unet_channels,
    strides=unet_strides,
    num_res_units=4,
)

# Wrap the network in the Lightning module; a single output channel is
# flagged as a binary target.
model = LightningSegmentationModel(
    model=unet,
    binary_target=unet_config.out_channels == 1,
    lr=unet_config.lr,
    patience=unet_config.patience,
    cfg=OmegaConf.to_container(unet_config),
)

In [7]:
# Inferred settings: early stopping waits twice the patience configured for
# the model, and the checkpoint file name gets a random 6-character suffix.
patience = unet_config.patience * 2
suffix_alphabet = string.ascii_letters + string.digits
filename = 'mnmv2-' + ''.join(random.choices(suffix_alphabet, k=6))

# Logger: when logging is enabled, finish any W&B run that is still open and
# start a new one; otherwise the trainer receives `logger=None`.
logger = None
if trainer_config.logging:
    wandb.finish()
    logger = WandbLogger(project="lightning", log_model=True, name="mnmv2_unet")

# Callbacks, configured from the trainer config.
early_stopping = EarlyStopping(
    monitor=trainer_config.early_stopping.monitor,
    mode=trainer_config.early_stopping.mode,
    patience=patience,
)
model_checkpoint = ModelCheckpoint(
    dirpath=trainer_config.model_checkpoint.dirpath,
    filename=filename,
    save_top_k=trainer_config.model_checkpoint.save_top_k,
    monitor=trainer_config.model_checkpoint.monitor,
)

# Trainer: mixed 16-bit precision, gradient clipping at 0.5, pinned to the
# accelerator device with index 7.
trainer = L.Trainer(
    limit_train_batches=trainer_config.limit_train_batches,
    max_epochs=trainer_config.max_epochs,
    logger=logger,
    callbacks=[early_stopping, model_checkpoint],
    precision='16-mixed',
    gradient_clip_val=0.5,
    devices=[7],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [8]:
# Run training with the model and data module set up in the cells above.
trainer.fit(model=model, datamodule=datamodule)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjlennartz[0m. Use [1m`wandb login --relogin`[0m to force relogin


/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /out/trained_UNets exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | UNet       | 199 K  | train
1 | loss  | DiceCELoss | 0      | train
---------------------------------------------
199 K     Trainable params
0         Non-trainable params
199 K     Total params
0.798     Total estimated model params size (MB)
169       Modules in train mode
0         Modules in eval mode


Epoch 38:  48%|████▊     | 24/50 [00:04<00:04,  5.63it/s, v_num=am9e]      


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [24]:
# Restore a trained segmentation network from a Lightning checkpoint.
#
# Two loading modes are available. If both flags are set, the Lightning branch
# wins; if neither is set, nothing is loaded.
#   * load_as_lightning_module -> `model`, a LightningSegmentationModel
#   * load_as_pytorch_module   -> `unet`, a plain UNet with the weights loaded
checkpoint_path = '../../pre-trained/trained_UNets/mnmv2-Zyv5Y0.ckpt'

load_as_lightning_module = True
load_as_pytorch_module = False

# Device onto which the checkpoint tensors are deserialised. The training cell
# pins the trainer to device index 7, so a checkpoint produced there presumably
# stores tensors tagged with that device and cannot be restored as-is on a host
# that lacks it. Mapping to the CPU makes loading portable; move the result to
# an accelerator afterwards with `.to(...)` if needed. Set to None to fall back
# to the library's default placement.
map_location = 'cpu'


def build_unet(cfg):
    """Build the UNet described by ``cfg``.

    Args:
        cfg: Model config supporting item access (an OmegaConf ``DictConfig``
            or a plain ``dict``) with the entries ``spatial_dims``,
            ``in_channels``, ``out_channels``, ``n_filters_init`` and ``depth``.

    Returns:
        A freshly initialised ``UNet`` with ``depth`` levels whose channel
        count doubles per level starting at ``n_filters_init``, a stride of 2
        between consecutive levels, and 4 residual units.
    """
    depth = cfg['depth']
    return UNet(
        spatial_dims=cfg['spatial_dims'],
        in_channels=cfg['in_channels'],
        out_channels=cfg['out_channels'],
        channels=[cfg['n_filters_init'] * 2 ** level for level in range(depth)],
        strides=[2] * (depth - 1),
        num_res_units=4,
    )


if load_as_lightning_module:
    # The architecture is rebuilt from the YAML config currently on disk, so
    # that file must still describe the network that produced the checkpoint.
    unet_config = OmegaConf.load('../configs/model/monai_unet.yaml')
    unet = build_unet(unet_config)
    model = LightningSegmentationModel.load_from_checkpoint(
        checkpoint_path,
        map_location=map_location,
        model=unet,
        binary_target=unet_config.out_channels == 1,
        lr=unet_config.lr,
        patience=unet_config.patience,
        cfg=OmegaConf.to_container(unet_config),
    )

elif load_as_pytorch_module:
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Here the architecture comes from the config stored inside the checkpoint.
    model_config = checkpoint['hyper_parameters']['cfg']

    # The Lightning wrapper holds the network under its `model` attribute, so
    # the network's own keys carry one extra leading `model.`. Strip exactly
    # that prefix (once) and drop every entry that is not part of the network.
    wrapper_prefix = 'model.'
    model_state_dict = {
        key[len(wrapper_prefix):]: value
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(wrapper_prefix)
    }

    unet = build_unet(model_config)
    unet.load_state_dict(model_state_dict)

In [None]:
# TODO: write an example test