In [1]:
import torch
from torch import nn
from monai.networks.nets import UNet 
from monai.networks.layers import Norm # For specifying normalization layers in UNet
import os


In [12]:
def generate_model(opt):
    """
    Generates a 3D segmentation model (e.g., UNet from MONAI) based on options,
    handles GPU placement, and loads pre-trained weights if specified.

    Args:
        opt: An object or dictionary containing configuration options like:
            - model (str): Type of model (e.g., 'unet_resnet_style' or just 'unet').
            - model_depth (int): Can be used to select UNet configurations if desired,
                                 or ignored if using a fixed UNet architecture.
            - input_W, input_H, input_D (int): Input dimensions (used for MONAI UNet).
            - n_seg_classes (int): Number of segmentation classes.
            - no_cuda (bool): If True, use CPU; otherwise, use GPU.
            - gpu_id (list of int): List of GPU IDs to use.
            - phase (str): Current phase ('train', 'test', etc.).
            - pretrain_path (str, optional): Path to pre-trained model weights.
            - new_layer_names (list of str, optional): Names of new layers for fine-tuning.
            - freeze_base (bool, optional): If True, freeze base parameters during fine-tuning.
    Returns:
        torch.nn.Module: The generated (and potentially pre-trained) model.
        dict or torch.nn.Parameter: Parameters for the optimizer.
    """
    print("--- Entering generate_model ---")
    print(f"Options received: {opt}")

    # For 3D segmentation, U-Net is a common choice.
    # The original code structure was trying to build a ResNet.
    # MONAI's UNet can have ResNet-like blocks, or you can use other 3D architectures.
    # We'll use a MONAI UNet here as an example.
    # The 'opt.model' and 'opt.model_depth' might need to be re-interpreted
    # for selecting different UNet configurations if you have them.

    assert opt.model in ['resnet', 'unet'], \
        f"Unsupported model type: {opt.model}. Expected 'resnet' (interpreted as UNet-style for 3D) or 'unet'."

    model = None

    # Example: Using MONAI UNet.
    # You might want to map opt.model_depth to different UNet channel configurations
    # or numbers of residual units if you want to keep that parameter meaningful.
    # For simplicity, this example uses a fixed UNet configuration.
    # The parameters like 'sample_input_W/H/D' are not directly used by MONAI UNet constructor
    # but are good for defining `roi_size` or `patch_size` in data loading/training.
    # `num_seg_classes` is equivalent to `out_channels`.
    
    # Assuming 'resnet' in opt.model implies a UNet-style architecture for 3D segmentation
    if opt.model == 'resnet' or opt.model == 'unet':
        print(f"Creating MONAI UNet for 3D segmentation with {opt.n_seg_classes} output classes.")
        # Example UNet configuration - you'll likely want to customize this
        # based on opt.model_depth or other specific needs.
        unet_channels = (16, 32, 64, 128, 256) # Default, can be adjusted
        unet_strides = (2, 2, 2, 2)
        num_res_units = 2 # ResNet-like blocks

        if hasattr(opt, 'model_depth'): # Optionally use model_depth to vary UNet
            if opt.model_depth <= 18:
                unet_channels = (16, 32, 64, 128)
                unet_strides = (2, 2, 2)
            elif opt.model_depth <= 34:
                unet_channels = (16, 32, 64, 128, 256)
                unet_strides = (2, 2, 2, 2)
            # Add more conditions for deeper/larger UNets if needed

        try:
            model = UNet(
                spatial_dims=3,
                in_channels=1,  # Assuming 1 input channel for CT scans (e.g., opt.input_channels)
                out_channels=opt.n_seg_classes,
                channels=unet_channels,
                strides=unet_strides,
                num_res_units=num_res_units,
                norm=Norm.BATCH # Or Norm.INSTANCE, etc.
            )
            print(f"MONAI UNet model instantiated: {type(model)}")
        except Exception as e:
            print(f"Error instantiating MONAI UNet: {e}")
            raise e # Re-raise the exception to stop execution if model creation fails

    if model is None:
        raise ValueError("Model could not be instantiated. Check model type and parameters in 'opt'.")

    parameters_for_optimizer = model.parameters()  # Default

    # --- Load Pretrained Weights (before moving to GPU and DataParallel for simplicity) ---
    if hasattr(opt, 'phase') and opt.phase != 'test' and \
       hasattr(opt, 'pretrain_path') and opt.pretrain_path:
        print(f'Loading pretrained model from: {opt.pretrain_path}')
        try:
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            current_model_dict = model.state_dict()
            
            pretrained_state_dict_source = pretrain['state_dict'] if isinstance(pretrain, dict) and 'state_dict' in pretrain else pretrain
            
            filtered_pretrained_dict = {}
            loaded_keys_count = 0
            skipped_keys = []

            for k, v_pretrain in pretrained_state_dict_source.items():
                original_key_pretrain = k
                if k.startswith('module.'): # If pretrain was saved from a DataParallel model
                    k = k[7:]
                
                if k in current_model_dict:
                    v_model = current_model_dict[k]
                    if v_model.shape == v_pretrain.shape:
                        filtered_pretrained_dict[k] = v_pretrain
                        loaded_keys_count += 1
                    else:
                        skipped_keys.append(
                            f"'{original_key_pretrain}' (shape mismatch: model {v_model.shape}, pretrain {v_pretrain.shape})"
                        )
                else:
                    skipped_keys.append(f"'{original_key_pretrain}' (not in current model)")

            if loaded_keys_count > 0:
                current_model_dict.update(filtered_pretrained_dict)
                model.load_state_dict(current_model_dict)
                print(f"Successfully loaded {loaded_keys_count} matching layers from pretrained model.")
            else:
                print("Warning: No layers were loaded from the pretrained model. Check keys and model architecture.")
            if skipped_keys:
                print(f"Skipped {len(skipped_keys)} pretrained layers due to name or shape mismatch: {', '.join(skipped_keys[:5])}{'...' if len(skipped_keys) > 5 else ''}")

        except FileNotFoundError:
            print(f"Error: Pretrained model file not found at {opt.pretrain_path}")
        except Exception as e:
            print(f"Error loading pretrained model: {e}")


        # --- Prepare parameter groups for fine-tuning ---
        new_parameters_list = []
        base_parameters_list = list(model.parameters()) # Default to all parameters being base

        if hasattr(opt, 'new_layer_names') and opt.new_layer_names:
            for param_name, param_obj in model.named_parameters():
                for new_layer_identifier in opt.new_layer_names:
                    if new_layer_identifier in param_name:
                        new_parameters_list.append(param_obj)
                        # Ensure requires_grad is True for new parameters
                        param_obj.requires_grad = True 
                        break 

            if new_parameters_list:
                new_param_ids = set(map(id, new_parameters_list))
                base_parameters_list = [p for p in model.parameters() if id(p) not in new_param_ids]
                print(f"Identified {len(new_parameters_list)} new parameters and {len(base_parameters_list)} base parameters for fine-tuning.")
            else:
                print("Warning: 'new_layer_names' provided, but no matching parameters found. All parameters treated as base.")
        else:
            print("No 'new_layer_names' specified for fine-tuning. All parameters treated as base.")
        
        parameters_for_optimizer = {
            'base_parameters': base_parameters_list,
            'new_parameters': new_parameters_list  # Will be empty if none found/specified
        }

        if hasattr(opt, 'freeze_base') and opt.freeze_base and base_parameters_list:
            print("Freezing base parameters.")
            for param in base_parameters_list:
                param.requires_grad = False
        elif base_parameters_list: # Ensure base parameters are trainable if not frozen
             for param in base_parameters_list:
                param.requires_grad = True

    # --- GPU / CPU Handling ---
    if not opt.no_cuda and torch.cuda.is_available():
        if hasattr(opt, 'gpu_id') and opt.gpu_id and len(opt.gpu_id) > 0:
            if len(opt.gpu_id) > 1:
                print(f"Using nn.DataParallel for GPUs: {opt.gpu_id}")
                model = nn.DataParallel(model, device_ids=opt.gpu_id)
                # The model (and its parameters) will be on opt.gpu_id[0]
                # DataParallel handles distributing data during forward pass
                model.cuda(opt.gpu_id[0]) 
            else: 
                device_id = opt.gpu_id[0]
                device = torch.device(f"cuda:{device_id}")
                model = model.to(device)
                print(f"Model moved to GPU: cuda:{device_id}")
        else: 
            device = torch.device("cuda") # Defaults to cuda:0
            model = model.to(device)
            print("Model moved to default GPU (cuda:0 or primary CUDA device).")
    else:
        if opt.no_cuda:
            print("CUDA disabled by user (opt.no_cuda=True). Using CPU.")
        else:
            print("CUDA not available on this system. Using CPU.")
        device = torch.device("cpu")
        model = model.to(device)
    
    print("--- Exiting generate_model ---")
    return model, parameters_for_optimizer,device

In [13]:

class Options:
    def __init__(self):
        self.model = 'unet'
        self.model_depth = 34 # Example depth
        self.input_W = 128 # Example, not directly used by MONAI UNet constructor
        self.input_H = 128 # Example
        self.input_D = 128 # Example
        self.n_seg_classes = 3 # Example: background, tumor, other structure
        self.no_cuda = False # Set to True to force CPU
        self.gpu_id = [0]    # Example: use GPU 0. For multi-GPU: [0, 1]
        self.phase = 'train' # Or 'test'
        self.pretrain_path = None # Path to your .pth or .pt file, or None
        # self.pretrain_path = "path/to/your/pretrained_model.pth" 
        self.new_layer_names = [] # Example: ['model.output_conv.conv.weight'] for fine-tuning specific layers
        # self.new_layer_names = ['final_conv'] # if your UNet has a layer with 'final_conv' in its name
        self.freeze_base = False

# Create an instance of the options
opts = Options()

# --- Special handling for GPU availability in a notebook environment ---
# If 'no_cuda' is False, but no GPUs are actually available, override to True
if not opts.no_cuda and not torch.cuda.is_available():
    print("CUDA specified but not available. Switching to CPU (no_cuda=True).")
    opts.no_cuda = True
    opts.gpu_id = [] # Clear gpu_id if forcing CPU

# Generate the model
try:
    model, params_for_optimizer = generate_model(opts)
    print(f"\nModel generated successfully: {type(model)}")
    if isinstance(params_for_optimizer, dict):
        print(f"Optimizer params: {len(params_for_optimizer.get('base_parameters', []))} base, {len(params_for_optimizer.get('new_parameters', []))} new.")
    else:
        print(f"Optimizer params: {len(list(params_for_optimizer))} total.")
    
    # You can print the model structure (can be very long for UNets)
    # print("\nModel Structure:")
    # print(model)

except Exception as e:
    print(f"An error occurred during model generation or setup: {e}")


CUDA specified but not available. Switching to CPU (no_cuda=True).
--- Entering generate_model ---
Options received: <__main__.Options object at 0x731353114230>
Creating MONAI UNet for 3D segmentation with 3 output classes.
MONAI UNet model instantiated: <class 'monai.networks.nets.unet.UNet'>
CUDA disabled by user (opt.no_cuda=True). Using CPU.
--- Exiting generate_model ---
An error occurred during model generation or setup: too many values to unpack (expected 2)


In [14]:
# This is an illustrative example of how you might set up options and call the function.
# In a real scenario, 'opt' would be populated by an argument parser or a config file.

class Options:
    def __init__(self):
        # Model options
        self.model = 'unet' 
        self.model_depth = 34 
        self.input_channels = 1 # Number of input channels for the model
        self.n_seg_classes = 3 # Example: background, class1, class2 (for DiceCELoss to_onehot_y=True)
        
        # Hardware options
        self.no_cuda = False 
        self.gpu_id = [0]    
        
        # Phase and Pretraining
        self.phase = 'train' 
        self.pretrain_path = None 
        self.new_layer_names = [] 
        self.freeze_base = False

      
        self.data_root = "./monai/scan/label-25.nii.gz" #callin the  with NIfTI files .nfi
        # self.img_list_file = "./data/train_list.txt" 
        self.roi_size = (96, 96, 96) # Example patch size for training
        self.batch_size = 2 # Example batch size
        self.num_workers = 0 # For DataLoader, set to 0 for initial testing, >0 for parallel loading

        # Intensity scaling/normalization ie houndsfield unit
        self.a_min = -200.0 # HU min for scaling (e.g., for soft tissue window)
        self.a_max = 200.0  # HU max for scaling
        self.b_min = 0.0    # Output min after scaling
        self.b_max = 1.0    # Output max after scaling
        # self.norm_mean = 0.5 # For NormalizeIntensityd, if used after ScaleIntensityRanged
        # self.norm_std = 0.5  # For NormalizeIntensityd

# Create an instance of the options
opts = Options()

# --- Special handling for GPU availability in a notebook environment --- gpu requirement code
if not opts.no_cuda and not torch.cuda.is_available():
    print("CUDA specified but not available. Switching to CPU (no_cuda=True).")
    opts.no_cuda = True
    opts.gpu_id = [] 

# Generate the model
try:
    # Now also expecting 'device' to be returned
    model, params_for_optimizer, device = generate_model(opts) 
    print(f"\nModel generated successfully: {type(model)}")
    print(f"Model is on device: {next(model.parameters()).device}") # Check actual device of model parameters
    
    if isinstance(params_for_optimizer, dict):
        base_param_count = sum(p.numel() for p in params_for_optimizer.get('base_parameters', []) if p.requires_grad)
        new_param_count = sum(p.numel() for p in params_for_optimizer.get('new_parameters', []) if p.requires_grad)
        print(f"Optimizer params: {base_param_count} trainable base params, {new_param_count} trainable new params.")
    else:
        total_trainable_params = sum(p.numel() for p in params_for_optimizer if p.requires_grad)
        print(f"Optimizer params: {total_trainable_params} total trainable parameters.")
    
except Exception as e:
    print(f"An error occurred during model generation or setup: {e}")
    # To see the full traceback if an error occurs within generate_model
    import traceback
    traceback.print_exc()
    model = None # Ensure model is None if generation failed
    params_for_optimizer = None
    device = torch.device("cpu")



CUDA specified but not available. Switching to CPU (no_cuda=True).
--- Entering generate_model ---
Options received: <__main__.Options object at 0x7313530b6870>
Creating MONAI UNet for 3D segmentation with 3 output classes.
MONAI UNet model instantiated: <class 'monai.networks.nets.unet.UNet'>
CUDA disabled by user (opt.no_cuda=True). Using CPU.
--- Exiting generate_model ---

Model generated successfully: <class 'monai.networks.nets.unet.UNet'>
Model is on device: cpu
Optimizer params: 4809920 total trainable parameters.
