In [None]:
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from monai.networks.nets import DynUNet

'''
This module shows how to create a dynamically configurable DynUNet model in MONAI and how to load pre-trained weights into it.
It also shows how to calculate the model's parameters from the task properties.
'''


def get_kernels_strides(patch_size=None, spacing=None):
    """
    Compute the per-level kernel sizes and strides for a DynUNet model.

    Starting from the patch size and voxel spacing, each iteration chooses a stride of 2 for
    every spatial dimension whose spacing is at most twice the current minimum spacing and
    whose current size is at least 8, and a stride of 1 otherwise. Dimensions whose spacing
    is more than twice the minimum get a kernel size of 1 at that level; all others get 3.
    The loop stops once no dimension can be strided any further.

    This function was written for the Medical Segmentation Decathlon patch sizes. When using
    it for other tasks, make sure the patch size of each spatial dimension is divisible by
    the product of all strides chosen for that dimension; otherwise a ``ValueError`` is raised.

    :param patch_size: spatial patch size, one entry per spatial dimension.
        Defaults to ``[180, 240, 240]``.
    :param spacing: voxel spacing, one entry per spatial dimension.
        Defaults to ``[1.0, 1.0, 1.0]``.
    :return: ``(kernels, strides)``: two lists of equal length, each holding one
        per-dimension list per level (as passed to DynUNet's ``kernel_size`` / ``strides``).
        ``strides[0]`` is all ones and ``kernels[-1]`` is all threes.
    :raises ValueError: if ``patch_size`` and ``spacing`` are empty or differ in length, or
        if a dimension is not divisible by the stride chosen for it.

    NOTE(review): with the default isotropic spacing, the default patch size [180, 240, 240]
    is halved down to 45 / 15, which are odd and >= 8, so calling this without arguments
    raises ``ValueError``. Pass a supported patch size (e.g. [128, 128, 128]) or confirm the
    intended default.
    """
    sizes = [180, 240, 240] if patch_size is None else list(patch_size)
    spacings = [1.0, 1.0, 1.0] if spacing is None else list(spacing)
    if not sizes or len(sizes) != len(spacings):
        raise ValueError(
            f"patch_size and spacing must be non-empty and of equal length, got {sizes} and {spacings}."
        )

    input_size = list(sizes)  # original sizes, reported in the error message below
    strides, kernels = [], []
    while True:
        min_spacing = min(spacings)
        spacing_ratio = [sp / min_spacing for sp in spacings]
        # Stride a dimension only if it is not much coarser than the finest one and is
        # still at least 8 voxels long.
        stride = [2 if ratio <= 2 and size >= 8 else 1 for ratio, size in zip(spacing_ratio, sizes)]
        # Use a 1-wide kernel along dimensions that are much coarser than the finest one.
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        for idx, (size, step) in enumerate(zip(sizes, stride)):
            if size % step != 0:
                raise ValueError(
                    f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
                )
        # Divisibility was checked above, so this integer division is exact.
        sizes = [size // step for size, step in zip(sizes, stride)]
        spacings = [sp * step for sp, step in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)

    # The first level is not downsampled; the last level gets an extra 3-wide kernel.
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
    return kernels, strides


def get_network(properties, pretrain_path, checkpoint=None):
    """
    Create a DynUNet model for a task and optionally load pre-trained weights into it.

    :param properties: task properties. ``properties["labels"]`` and ``properties["modality"]``
        are only used through ``len()``, giving the number of output channels (classes) and
        input channels respectively.
    :param pretrain_path: directory that contains the pre-trained checkpoint.
    :param checkpoint: file name of the checkpoint inside ``pretrain_path``. If ``None``, no
        weights are loaded.

    :return: the DynUNet model, with the checkpoint's weights loaded when the file exists.
    """
    n_class = len(properties["labels"])
    in_channels = len(properties["modality"])
    kernels, strides = get_kernels_strides()

    net = DynUNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=n_class,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=False,
        deep_supr_num=1,
    )

    if checkpoint is not None:
        # Load the pre-trained weights if the checkpoint file exists; otherwise keep the
        # freshly initialised weights and only report it.
        checkpoint_path = os.path.join(pretrain_path, checkpoint)
        if os.path.isfile(checkpoint_path):
            # map_location="cpu" allows checkpoints saved from GPU tensors to be loaded on
            # CPU-only machines; load_state_dict then copies the values into the model.
            # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            net.load_state_dict(state_dict)
            print(f"pretrained checkpoint: {checkpoint_path} loaded")
        else:
            print("no pretrained checkpoint")
    return net