In [2]:
import os, random
import numpy as np
import zarr, copick
from tqdm import tqdm
from monai.data import DataLoader, CacheDataset, decollate_batch
from monai.transforms import (
    Compose, 
    EnsureChannelFirstd, 
    ScaleIntensityRanged, 
    CropForegroundd, 
    Orientationd, 
    Spacingd, 
    EnsureTyped, 
    Activations, 
    AsDiscrete, 
    Resized, 
    RandFlipd, 
    RandRotate90d, 
    RandZoomd,
    RandGridPatchd,
    NormalizeIntensityd,
    RandCropByLabelClassesd,
    Resized, 
    RandZoomd,
    Activations, 
    CropForegroundd, 
    ScaleIntensityRanged, 
    RandCropByPosNegLabeld,    
)
from monai.networks.nets import UNet
from monai.losses import DiceLoss, FocalLoss, TverskyLoss
from monai.metrics import DiceMetric, ConfusionMatrixMetric
# from skimage
import mlflow
import mlflow.pytorch
from sklearn.model_selection import train_test_split

from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, RandCropByLabelClassesd, RandRotate90d, RandFlipd
from monai.data import CacheDataset, DataLoader, Dataset
import mlflow, optuna, torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
copick_config_path = "../../../simulations/ml_challenge/ml_config.json"
root = copick.from_file(copick_config_path)

In [4]:
def get_tomogram_array(copick_run, voxel_spacing=10, tomo_type='wbp'):
    voxel_spacing_obj = copick_run.get_voxel_spacing(voxel_spacing)
    tomogram = voxel_spacing_obj.get_tomogram(tomo_type)
    image = zarr.open(tomogram.zarr(), mode='r')['0']
    return image[:]

def get_segmentation_array(copick_run, segmentation_name, voxel_spacing=10, is_multilabel=True):
    #seg_memb = copick_run.get_segmentations(name="membrane")
    seg_memb = copick_run.get_segmentations()
    seg = copick_run.get_segmentations(is_multilabel=is_multilabel, name=segmentation_name, voxel_size=voxel_spacing)
    if len(seg) == 0:
        raise ValueError(f"No segmentations found for session '{session_id}' and segmentation type '{segmentation_type}'.")
        
    segmentation = zarr.open(seg[0].zarr().path, mode="r")['0'][:]
    _, array = list(zarr.open(seg_memb[0].zarr()).arrays())[0]
    seg_membrane = np.array(array[:])
    #seg_membrane = zarr.open(seg_memb[0].zarr().path, mode="r")['0'][:]
    segmentation[seg_membrane==1] = 1
    return segmentation

In [5]:
voxel_spacing = 10
tomo_type = "wbp"
painting_segmentation_name = "segmentation" #"remotetargets"

runIDs = [run.name for run in root.runs]
data_dicts = []
for runID in tqdm(runIDs):
    run = root.get_run(str(runID))
    tomogram = get_tomogram_array(run)
    segmentation = get_segmentation_array(run, painting_segmentation_name)
    data_dicts.append({"image": tomogram, "label": segmentation})

100%|██████████| 24/24 [00:11<00:00,  2.01it/s]


In [6]:
my_num_samples = 16
train_batch_size = 1
val_batch_size = 1

test_files = data_dicts[-2:]
train_files, val_files = train_test_split(data_dicts[:-2], test_size=0.3)
print(f"Number of training samples: {len(train_files)}")
print(f"Number of validation samples: {len(val_files)}")

# Non-random transforms to be cached
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS")
])

# Random transforms to be applied during training
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[96, 96, 96],
        num_classes=8,
        num_samples=my_num_samples
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),    
])

# Create the cached dataset with non-random transforms
train_ds = CacheDataset(data=train_files, transform=non_random_transforms, cache_rate=1.0)

# Wrap the cached dataset to apply random transforms during iteration
train_ds = Dataset(data=train_ds, transform=random_transforms)

# DataLoader remains the same
train_loader = DataLoader(
    train_ds,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available()
)

# Validation transforms
val_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[96, 96, 96],
        num_classes=8,
        num_samples=my_num_samples,  # Use 1 to get a single, consistent crop per image
    ),
])

# Create validation dataset
val_ds = CacheDataset(data=val_files, transform=non_random_transforms, cache_rate=1.0)

# Wrap the cached dataset to apply random transforms during iteration
val_ds = Dataset(data=val_ds, transform=random_transforms)

# Create validation DataLoader
val_loader = DataLoader(
    val_ds,
    batch_size=val_batch_size,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    shuffle=False,  # Ensure the data order remains consistent
)

Number of training samples: 15
Number of validation samples: 7


Loading dataset: 100%|██████████| 15/15 [00:00<00:00, 24.33it/s]
Loading dataset: 100%|██████████| 7/7 [00:00<00:00, 25.15it/s]


In [None]:
#loss_function = DiceLoss(include_background=True, to_onehot_y=True, softmax=True)  # softmax=True for multiclass
loss_function = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  

dice_metric = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)  # must use onehot for multiclass
recall_metric = ConfusionMatrixMetric(include_background=False, metric_name="recall", reduction="None")

In [8]:
def stack_patches(data):
    shape = data.shape
    new_shape = (shape[0] * shape[1],) + shape[2:]
    return data.view(new_shape)

post_pred = AsDiscrete(argmax=True, to_onehot=len(root.pickable_objects)+1)
post_label = AsDiscrete(to_onehot=len(root.pickable_objects)+1)

def train(train_loader, 
          val_loader, 
          model, 
          device,
          loss_function, 
          metrics_function, 
          optimizer, 
          max_epochs=100,
          val_interval = 10,
          verbose=True):
    
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs = batch_data["image"].to(device)  # Shape: [B, C, H, W, D]
            labels = batch_data["label"].to(device)  # Shape: [B, C, H, W, D]            
            optimizer.zero_grad()
            outputs = model(inputs)    # Output shape: [B, num_classes, H, W, D]
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"batch {step}/{len(train_ds) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        mlflow.log_metric("train_loss", epoch_loss, step=epoch+1)

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)
                    val_outputs = model(val_inputs)
                    
                    # Decollate batches into lists
                    val_outputs_list = decollate_batch(val_outputs)
                    val_labels_list = decollate_batch(val_labels)
                    # Apply post-processing
                    metric_val_outputs = [post_pred(i) for i in val_outputs_list]
                    metric_val_labels = [post_label(i) for i in val_labels_list]
                    # Compute metrics
                    metrics_function(y_pred=metric_val_outputs, y=metric_val_labels)

                metrics = metrics_function.aggregate(reduction="mean_batch")
                metric_per_class = ["{:.4g}".format(x) for x in metrics]
                metric = torch.mean(metrics).numpy(force=True)
                mlflow.log_metric("validation metric", metric, step=epoch+1)
                for i,m in enumerate(metrics):
                    mlflow.log_metric(f"validation metric class {i+1}", m, step=epoch+1)
                metrics_function.reset()

                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join('./', "best_metric_model.pth"))
                    
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean recall per class: {', '.join(metric_per_class)}"
                    f"\nbest mean recall: {best_metric:.4f} "
                    f"at epoch: {best_metric_epoch}"
                )

In [12]:
def set_seed(seed):
    # Set the seed for Python's random module
    random.seed(seed)

    # Set the seed for NumPy
    np.random.seed(seed)

    # Set the seed for PyTorch (both CPU and GPU)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # If using multi-GPU

    # Ensure reproducibility of operations by disabling certain optimizations
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [20]:
def objective(trial,  
              n_classes,
              train_loader, 
              val_loader, 
              loss_function,
              metrics_function, 
              epochs,
              random_seed = 42,
              gpu_count = 1):

    set_seed(random_seed)

    # Assign each trial to a specific GPU based on the trial number
    if gpu_count > 1:
        gpu_id = trial.number % gpu_count  # Cycle through available GPUs
        device = torch.device(f"cuda:{gpu_id}")
        torch.cuda.set_device(device)  # Set the current GPU for this trial
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    # Set a unique run name for each trial
    trial_num = f"trial_{trial.number}"

    # Start a new MLflow run for each trial
    with mlflow.start_run(run_name = trial_num, nested=True):  # Nested=True allows it to be part of the overall experiment run
        
        # Sample number of channels
        num_layers = trial.suggest_int("num_layers", 2, 5)
        
        # Generate increasing channels based on 16, 32, 64
        base_channel = trial.suggest_categorical("base_channel", [8, 16, 32, 64])
        channels = [base_channel * (2 ** i) for i in range(num_layers)]

        # Strides pattern
        # Sample number of downsampling layers (those with stride of 2)
        num_downsampling_layers = trial.suggest_int("num_downsampling_layers", 1, num_layers - 1)

        # Define strides: first num_downsampling_layers with stride 2, and the rest with stride 1
        strides_pattern = [2] * num_downsampling_layers + [1] * (num_layers - num_downsampling_layers)        

        # Number of residual units
        num_res_units = trial.suggest_int("num_res_units", 1, 3)

        print('Current Parameters: ')
        print(f'Num Layers: {num_layers}')
        print(f'Num Res Units: {num_res_units}')
        print(f'Channels: {channels}')
        print(f'Num Downsampling Layers: {num_downsampling_layers}')
        print(f'Strides: {strides_pattern}\n')

        # Log parameters to MLflow at each trial
        mlflow.log_params({
            "num_layers": num_layers,
            "base_channel": base_channel,
            "channels": channels,
            "num_downsampling_layers": num_downsampling_layers,
            "strides_pattern": strides_pattern,
            "num_res_units": num_res_units,
            "random_seed": random_seed
        })

        # Now use channels, strides, and num_res_units in your model definition
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=n_classes,
            channels=channels,
            strides=strides_pattern,
            num_res_units=num_res_units,
        ).to(device)

        # # Sample learning rate using Optuna
        # learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1e-3)

        # Define your loss and optimizer, and return the objective (e.g., validation loss)
        lr = 1e-3
        optimizer = torch.optim.Adam(model.parameters(), lr)

        # score = train(train_loader, val_loader, 
        #               model, device, loss_function, 
        #               metrics_function, 
        #               optimizer, max_epochs=epochs)
        
        score = 1

        # Log the score (e.g., validation loss or F1 score) for each trial
        mlflow.log_metric("score", score)

    return score

#### Multi-GPU Exploration

In [21]:
epochs = 100
num_trials = 10
nclasses = len(root.pickable_objects)+1
loss_function = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  
my_metrics = ConfusionMatrixMetric(include_background=False, metric_name=["recall",'precision','f1 score'], reduction="None")

# Get available GPUs (for example, on an 8 GPU node)
gpu_count = torch.cuda.device_count()

print(f'Running Architecture Search Over {gpu_count} GPUS\n')
mlflow.set_experiment('unet-model-search')
with mlflow.start_run():
    study = optuna.create_study(direction="maximize")
    study.optimize(lambda trial: objective(trial, nclasses, train_loader, val_loader, loss_function, my_metrics, epochs), 
                   n_trials=num_trials,
                   n_jobs=gpu_count) # Run trials on multiple GPUs

    print(f"Best trial: {study.best_trial.value}")
    print(f"Best params: {study.best_params}")

    # # Log the best trial in MLflow
    # mlflow.log_metric("best_trial_value", study.best_trial.value)
    # mlflow.log_params(study.best_trial.params)

[I 2024-10-17 01:30:08,722] A new study created in memory with name: no-name-937eb630-ceea-41fc-8612-8406d9274cf2
[I 2024-10-17 01:30:08,832] Trial 0 finished with value: 1.0 and parameters: {'num_layers': 2, 'base_channel': 32, 'num_downsampling_layers': 1, 'num_res_units': 1}. Best is trial 0 with value: 1.0.


Running Architecture Search Over 1 GPUS

Current Parameters: 
Num Layers: 2
Num Res Units: 1
Channels: [32, 64]
Num Downsampling Layers: 1
Strides: [2, 1]



[I 2024-10-17 01:30:08,971] Trial 1 finished with value: 1.0 and parameters: {'num_layers': 3, 'base_channel': 64, 'num_downsampling_layers': 2, 'num_res_units': 3}. Best is trial 0 with value: 1.0.
[I 2024-10-17 01:30:09,071] Trial 2 finished with value: 1.0 and parameters: {'num_layers': 2, 'base_channel': 64, 'num_downsampling_layers': 1, 'num_res_units': 1}. Best is trial 0 with value: 1.0.


Current Parameters: 
Num Layers: 3
Num Res Units: 3
Channels: [64, 128, 256]
Num Downsampling Layers: 2
Strides: [2, 2, 1]

Current Parameters: 
Num Layers: 2
Num Res Units: 1
Channels: [64, 128]
Num Downsampling Layers: 1
Strides: [2, 1]



[I 2024-10-17 01:30:09,174] Trial 3 finished with value: 1.0 and parameters: {'num_layers': 2, 'base_channel': 64, 'num_downsampling_layers': 1, 'num_res_units': 1}. Best is trial 0 with value: 1.0.
[I 2024-10-17 01:30:09,277] Trial 4 finished with value: 1.0 and parameters: {'num_layers': 3, 'base_channel': 8, 'num_downsampling_layers': 2, 'num_res_units': 1}. Best is trial 0 with value: 1.0.


Current Parameters: 
Num Layers: 2
Num Res Units: 1
Channels: [64, 128]
Num Downsampling Layers: 1
Strides: [2, 1]

Current Parameters: 
Num Layers: 3
Num Res Units: 1
Channels: [8, 16, 32]
Num Downsampling Layers: 2
Strides: [2, 2, 1]



[I 2024-10-17 01:30:09,468] Trial 5 finished with value: 1.0 and parameters: {'num_layers': 4, 'base_channel': 64, 'num_downsampling_layers': 2, 'num_res_units': 2}. Best is trial 0 with value: 1.0.


Current Parameters: 
Num Layers: 4
Num Res Units: 2
Channels: [64, 128, 256, 512]
Num Downsampling Layers: 2
Strides: [2, 2, 1, 1]

Current Parameters: 
Num Layers: 3
Num Res Units: 2
Channels: [8, 16, 32]
Num Downsampling Layers: 1
Strides: [2, 1, 1]



[I 2024-10-17 01:30:09,580] Trial 6 finished with value: 1.0 and parameters: {'num_layers': 3, 'base_channel': 8, 'num_downsampling_layers': 1, 'num_res_units': 2}. Best is trial 0 with value: 1.0.
[I 2024-10-17 01:30:09,763] Trial 7 finished with value: 1.0 and parameters: {'num_layers': 5, 'base_channel': 16, 'num_downsampling_layers': 4, 'num_res_units': 2}. Best is trial 0 with value: 1.0.


Current Parameters: 
Num Layers: 5
Num Res Units: 2
Channels: [16, 32, 64, 128, 256]
Num Downsampling Layers: 4
Strides: [2, 2, 2, 2, 1]



[I 2024-10-17 01:30:09,906] Trial 8 finished with value: 1.0 and parameters: {'num_layers': 2, 'base_channel': 8, 'num_downsampling_layers': 1, 'num_res_units': 3}. Best is trial 0 with value: 1.0.
[I 2024-10-17 01:30:10,009] Trial 9 finished with value: 1.0 and parameters: {'num_layers': 2, 'base_channel': 8, 'num_downsampling_layers': 1, 'num_res_units': 1}. Best is trial 0 with value: 1.0.


Current Parameters: 
Num Layers: 2
Num Res Units: 3
Channels: [8, 16]
Num Downsampling Layers: 1
Strides: [2, 1]

Current Parameters: 
Num Layers: 2
Num Res Units: 1
Channels: [8, 16]
Num Downsampling Layers: 1
Strides: [2, 1]

Best trial: 1.0
Best params: {'num_layers': 2, 'base_channel': 32, 'num_downsampling_layers': 1, 'num_res_units': 1}


In [None]:
epochs = 100
num_trials = 10
nclasses = len(root.pickable_objects)+1
loss_function = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  
my_metrics = ConfusionMatrixMetric(include_background=False, metric_name=["recall",'precision','f1 score'], reduction="None")

# Get available GPUs (for example, on an 8 GPU node)
my_gpu_count = torch.cuda.device_count()

mlflow.set_experiment('unet-model-search')
with mlflow.start_run():
    study = optuna.create_study(direction="maximize")
    study.optimize(lambda trial: objective(trial, nclasses, train_loader, val_loader, loss_function, my_metrics, epochs, gpu_count = my_gpu_count), 
                   n_trials=num_trials,
                   n_jobs=my_gpu_count) # Run trials on multiple GPUs

    print(f"Best trial: {study.best_trial.value}")
    print(f"Best params: {study.best_params}")

    # # Log the best trial in MLflow
    # mlflow.log_metric("best_trial_value", study.best_trial.value)
    # mlflow.log_params(study.best_trial.params)