In [1]:
import sys

import torch
from torchvision import datasets, transforms
import torch.nn as nn
from torch.utils.data import random_split, DataLoader

from hyperopt import fmin, tpe, hp, Trials
from hyperopt.pyll import scope

sys.path.append('../src')
sys.path.append('../configs')
sys.path.append('../../../utils')
from search_space import search_space
from unet import Unet
from train_utils import Trainer
from eval_utils import get_pixel_accuracy, evaluate_model
from tune_utils import get_config
from isbi_em_dataset import ISBIEMDataset

In [2]:
# Constants.
split_ratio = 0.8 # train / valid

In [3]:
transform = transforms.Compose([
    transforms.ToTensor()
])
DATA_DIR = '/home/kramasamy/Code/projects/cnn/data/isbi_em_segmentation'
train_dataset = ISBIEMDataset(DATA_DIR, transform=transform, train=False)

total_size = len(train_dataset)
train_size = int(total_size * split_ratio)
valid_size = total_size - train_size

train_dataset, valid_dataset = random_split(train_dataset, 
                                            [train_size, valid_size])


In [4]:
def objective_function(params):
    torch.cuda.empty_cache()
    model = Unet()
    config = get_config(params)
    trainer = Trainer(model, train_dataset, config)
    result = trainer.train(progress_bar=False)
    valid_loader = DataLoader(dataset=train_dataset, batch_size=2, 
                              shuffle=False)
    torch.cuda.empty_cache()
    y_pred, y_true = evaluate_model(result['model'], valid_loader, 'cuda')
    return -1 * get_pixel_accuracy(y_pred, y_true)

In [5]:
trials = Trials()

best_params = fmin(
    fn = objective_function,
    space = search_space,
    algo = tpe.suggest,
    max_evals = 25,
    trials = trials
)

100%|██████████| 25/25 [1:08:40<00:00, 164.84s/trial, best loss: -78.55237579345703]


In [6]:
print(best_params)

{'adam_beta1': np.float64(0.9076454389299679), 'adam_beta2': np.float64(0.9983641710260994), 'adam_lr': np.float64(0.0006976320483029133), 'adam_weight_decay': np.float64(0.008140957129697157), 'batch_size': np.int64(0), 'epochs': np.int64(0), 'optimizer': np.int64(1)}
