In [1]:
# Reload imported modules before each cell executes, so edits to the local
# modules imported below (data_utils, model.unet) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys, string, random
from datetime import datetime
from omegaconf import OmegaConf
import wandb
import torch
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from monai.networks.nets import UNet

sys.path.append('../')
from data_utils import MNMv2DataModule
from model.unet import LightningSegmentationModel

In [3]:
# Load the dataset, model and trainer configs (paths are relative to the
# current working directory), in that order.
mnmv2_config, unet_config, trainer_config = (
    OmegaConf.load(config_path)
    for config_path in (
        '../configs/datasets/mnmv2.yaml',
        '../configs/model/monai_unet.yaml',
        '../configs/trainer/unet_trainer.yaml',
    )
)

In [4]:
# Build the MNMv2 datamodule; each constructor argument is taken verbatim from
# the dataset config entry of the same name.
datamodule = MNMv2DataModule(**{
    option: mnmv2_config[option]
    for option in (
        'data_dir',
        'vendor_assignment',
        'batch_size',
        'binary_target',
        'non_empty_target',
    )
})

In [5]:
# init model
# Optional global seeding (weight init, data order, dataloader workers). It is
# only applied when the trainer config defines `seed`, so existing configs
# behave exactly as before.
seed = trainer_config.get('seed')
if seed is not None:
    L.seed_everything(seed, workers=True)

unet = UNet(
    spatial_dims=unet_config.spatial_dims,
    in_channels=unet_config.in_channels,
    out_channels=unet_config.out_channels,
    # channel width doubles at each of the `depth` levels
    channels=[unet_config.n_filters_init * 2 ** i for i in range(unet_config.depth)],
    # one stride-2 step between consecutive levels
    strides=[2] * (unet_config.depth - 1),
    # taken from the config (which defines it) instead of being hard-coded;
    # 4 is the previously hard-coded value, kept as fallback for older configs
    num_res_units=unet_config.get('num_res_units', 4),
)

model = LightningSegmentationModel(
    model=unet,
    # a single output channel means a binary target
    binary_target=unet_config.out_channels == 1,
    lr=unet_config.lr,
    patience=unet_config.patience,
    # Stored with the model (and exposed later as `model.cfgs`). Resolve any
    # ${...} interpolations so the stored dicts contain concrete values.
    cfgs={
        'dataset': OmegaConf.to_container(mnmv2_config, resolve=True),
        'unet': OmegaConf.to_container(unet_config, resolve=True),
        'trainer': OmegaConf.to_container(trainer_config, resolve=True),
    },
)

In [6]:
# inferred variable: early stopping waits twice as many epochs as the
# `patience` handed to the model.
# NOTE(review): presumably so the model's own patience-based schedule can act
# before training stops — confirm in model.unet.
patience = unet_config.patience * 2

# Timestamped run name, shared by the W&B run and the checkpoint file.
now = datetime.now()
filename = 'mnmv2-' + now.strftime("%H-%M_%d-%m-%Y")

# init logger
if trainer_config.logging:
    # close any W&B run left open (e.g. by a previous execution of this cell)
    wandb.finish()
    logger = WandbLogger(
        project="lightning",
        log_model=True,
        name=filename
    )
else:
    logger = None

# Hardware / numerics settings: read from the trainer config when present,
# otherwise fall back to the values this notebook previously hard-coded.
devices = trainer_config.get('devices', [7])
if OmegaConf.is_config(devices):
    # hand Lightning a plain list rather than an OmegaConf ListConfig
    devices = OmegaConf.to_container(devices)

# trainer
trainer = L.Trainer(
    limit_train_batches=trainer_config.limit_train_batches,
    max_epochs=trainer_config.max_epochs,
    logger=logger,
    callbacks=[
        EarlyStopping(
            monitor=trainer_config.early_stopping.monitor,
            mode=trainer_config.early_stopping.mode,
            patience=patience
        ),
        ModelCheckpoint(
            dirpath=trainer_config.model_checkpoint.dirpath,
            filename=filename,
            save_top_k=trainer_config.model_checkpoint.save_top_k,
            monitor=trainer_config.model_checkpoint.monitor,
        )
    ],
    precision=trainer_config.get('precision', '16-mixed'),
    gradient_clip_val=trainer_config.get('gradient_clip_val', 0.5),
    devices=devices,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [7]:
# Train; early stopping and checkpointing are driven by the trainer's callbacks.
trainer.fit(model=model, datamodule=datamodule)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjlennartz[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | UNet       | 794 K  | train
1 | loss  | DiceCELoss | 0      | train
---------------------------------------------
794 K     Trainable params
0         Non-trainable params
794 K     Total params
3.178     Total estimated model params size (MB)
163       Modules in train mode
0         Modules in eval mode


Epoch 99: 100%|██████████| 50/50 [00:08<00:00,  6.08it/s, v_num=8psl]      

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 50/50 [00:08<00:00,  6.05it/s, v_num=8psl]


In [12]:
# Reload a trained network from a Lightning checkpoint, either as the full
# LightningSegmentationModel or as a bare MONAI UNet (plain nn.Module).
checkpoint_path = '../../checkpoints/mnmv2-11-52_29-10-2024.ckpt'

load_as_lightning_module = True
load_as_pytorch_module = False


def _build_unet(cfg):
    """Construct a MONAI UNet from a UNet config section.

    `cfg` may be an OmegaConf DictConfig or a plain dict (such as the hparams
    stored in a checkpoint); only item access and `.get` are used, so both work.
    """
    return UNet(
        spatial_dims=cfg['spatial_dims'],
        in_channels=cfg['in_channels'],
        out_channels=cfg['out_channels'],
        channels=[cfg['n_filters_init'] * 2 ** i for i in range(cfg['depth'])],
        strides=[2] * (cfg['depth'] - 1),
        # older configs may lack this key; 4 is the previously hard-coded value
        num_res_units=cfg.get('num_res_units', 4),
    )


if load_as_lightning_module:
    # Architecture and optimiser settings come from the *current* YAML, so it
    # must still match the config the checkpoint was trained with.
    unet_config = OmegaConf.load('../configs/model/monai_unet.yaml')
    unet = _build_unet(unet_config)
    model = LightningSegmentationModel.load_from_checkpoint(
        checkpoint_path,
        model=unet,
        binary_target=unet_config.out_channels == 1,
        lr=unet_config.lr,
        patience=unet_config.patience,
    )

elif load_as_pytorch_module:
    # map_location='cpu' lets a checkpoint saved on any GPU (training used
    # device 7) load on any machine; load_state_dict below copies the weights
    # into `unet`'s own parameters.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # The Lightning module keeps the UNet under its `model` attribute, so UNet
    # weights are stored as 'model.<unet key>'. Strip only that leading prefix
    # (str.replace would also rewrite matches inside a key) and drop every
    # entry that does not belong to the UNet.
    prefix = 'model.'
    model_state_dict = {
        key[len(prefix):]: value
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(prefix)
    }

    # Models built in this notebook receive `cfgs={'dataset', 'unet', 'trainer'}`
    # (exposed as `model.cfgs` after loading), so the UNet settings live under
    # hparams['cfgs']['unet']; fall back to the legacy 'cfg' key for older
    # checkpoints.
    hparams = checkpoint['hyper_parameters']
    model_config = hparams['cfgs']['unet'] if 'cfgs' in hparams else hparams['cfg']

    unet = _build_unet(model_config)
    unet.load_state_dict(model_state_dict)

In [13]:
# Display the config dicts attached to `model` (dataset / unet / trainer sections).
model.cfgs

{'dataset': {'data_dir': '../../../../../data/MNM/',
  'vendor_assignment': {'train': 'siemens', 'test': 'ge'},
  'batch_size': 32,
  'binary_target': False,
  'non_empty_target': False},
 'unet': {'n_filters_init': 16,
  'depth': 4,
  'spatial_dims': 2,
  'in_channels': 1,
  'out_channels': 4,
  'num_res_units': 4,
  'lr': 0.001,
  'patience': 5},
 'trainer': {'train_transforms': 'global_transforms',
  'limit_train_batches': 50,
  'max_epochs': 100,
  'early_stopping': {'monitor': 'val_loss', 'mode': 'min'},
  'model_checkpoint': {'save_top_k': 2,
   'dirpath': '../../pre-trained/monai-unets',
   'monitor': 'val_loss'},
  'logging': True}}