In [1]:
import sys

from copy import deepcopy

import torch
from torchvision import datasets, transforms
import torch.nn as nn
from torch.utils.data import random_split, DataLoader

import mlflow
from hyperopt import fmin, tpe, hp, Trials
from hyperopt.pyll import scope

sys.path.append('../src')
sys.path.append('../configs')
sys.path.append('../../../utils')
from search_space_pretrained_unet import search_space
from unet import Unet
from train_utils import Trainer, he_init, OptimizerFactory
from eval_utils import get_pixel_accuracy, evaluate_model
from tune_utils import get_config
from isbi_em_dataset import ISBIEMDataset
from utils import flatten_params

2024-12-21 14:45:30.814207: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1734772530.831164   34859 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1734772530.836170   34859 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-21 14:45:30.854384: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  check_for_updates()


In [2]:
# Constants.
split_ratio = 0.8 # train / valid

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.25)
])
DATA_DIR = '/home/kramasamy/Code/projects/cnn/data/isbi_em_segmentation'
train_dataset = ISBIEMDataset(DATA_DIR, transform=transform, train=False)

total_size = len(train_dataset)
train_size = int(total_size * split_ratio)
valid_size = total_size - train_size

train_dataset, valid_dataset = random_split(train_dataset, 
                                            [train_size, valid_size])


In [4]:
import pickle as pkl
import yaml 

pretrain_unet_result = pkl.load(open('../checkpoints/pretrain_unet_results.pkl', 'rb'))
pretrained_unet = pretrain_unet_result['model'].unet

with open('../configs/train_pretrained_config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [5]:
def objective_function(params, log_to_mlflow=True):
    torch.cuda.empty_cache()
    config = deepcopy(get_config(params))
    torch.cuda.empty_cache()
    model = Unet()
    model.apply(he_init)
    model.load_state_dict(pretrained_unet.state_dict())

    # Custom optimizer.
    encoder_params = list(model.block1.parameters()) + \
                    list(model.block2.parameters()) + \
                    list(model.block3.parameters()) + \
                    list(model.block4.parameters()) + \
                    list(model.block5.parameters())

    decoder_params = list(model.block6.parameters()) + \
                    list(model.block7.parameters()) + \
                    list(model.block8.parameters()) + \
                    list(model.block9.parameters()) + \
                    list(model.final_conv.parameters())

    optimizer_factory = OptimizerFactory(config['optimizer'])
    lr = config['optimizer']['params']['lr']
    del config['optimizer']['params']['lr']
    optimizer = optimizer_factory.get_optimizer([
        {'params': encoder_params, 'lr': lr / config['optimizer']['lr_damp_pretrained']},  
        {'params': decoder_params, 'lr': lr}   # Larger LR for decoder
    ])

    trainer = Trainer(model, train_dataset, config)
    result = trainer.train(optimizer=optimizer, progress_bar=False)
    valid_loader = DataLoader(dataset=train_dataset, batch_size=2, 
                              shuffle=False)
    torch.cuda.empty_cache()
    y_pred, y_true = evaluate_model(result['model'], valid_loader, 'cuda')
    accuracy = get_pixel_accuracy(y_pred, y_true)

    if log_to_mlflow:
        with mlflow.start_run(nested=True):
            mlflow.log_params(flatten_params(params))
            mlflow.log_metric("accuracy", accuracy)

    return -1 * accuracy

In [6]:
experiment_name = "tune_pretrained_unet"
mlflow.set_tracking_uri("/home/kramasamy/Code/projects/cnn/models/unet/logs/mlflow")
mlflow.set_experiment(experiment_name=experiment_name)
with mlflow.start_run():
    trials = Trials()
    best_params = fmin(
        fn = objective_function,
        space = search_space,
        algo = tpe.suggest,
        max_evals = 100,
        trials = trials
    )

100%|██████████| 100/100 [4:32:18<00:00, 163.39s/trial, best loss: -84.77850341796875] 
