In [1]:
import sys

from torchvision import datasets, transforms
import torch.nn as nn
from torch.utils.data import random_split, DataLoader

from hyperopt import fmin, tpe, hp, Trials
from hyperopt.pyll import scope

sys.path.append('../src')
sys.path.append('../configs')
sys.path.append('../../../utils')
from search_space import search_space
from simple_cnn import SimpleCnn
from train_utils import Trainer
from eval_utils import get_top_k_accuracy, evaluate_model
from tune_utils import get_config

In [2]:
# Constants.
split_ratio = 0.8 # train / valid

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.25)
])
train_dataset = datasets.FashionMNIST(root='../../../data/fashion_mnist',
                                train=True,
                                transform=transform,
                                download=True)

total_size = len(train_dataset)
train_size = int(total_size * split_ratio)
valid_size = total_size - train_size

train_dataset, valid_dataset = random_split(train_dataset, 
                                            [train_size, valid_size])


In [4]:
def objective_function(params):
    model = SimpleCnn()
    config = get_config(params)
    trainer = Trainer(model, train_dataset, config)
    result = trainer.train(progress_bar=False)
    valid_loader = DataLoader(dataset=train_dataset, batch_size=128, 
                              shuffle=False)
    y_pred, y_true = evaluate_model(result['model'], valid_loader, 'cuda')
    return -1 * get_top_k_accuracy(y_pred, y_true, 1)

In [5]:
trials = Trials()

best_params = fmin(
    fn = objective_function,
    space = search_space,
    algo = tpe.suggest,
    max_evals = 25,
    trials = trials
)

100%|██████████| 25/25 [22:14<00:00, 53.38s/trial, best loss: -89.72291350364685]


In [6]:
print(best_params)

{'batch_size': np.int64(2), 'epochs': np.int64(0), 'optimizer': np.int64(0), 'sgd_lr': np.float64(0.09899191131179588), 'sgd_momentum_choice': np.int64(0), 'sgd_weight_decay': np.float64(0.0008258631628090254)}
