In [1]:
import sys

from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

from hyperopt import fmin, tpe, hp, Trials
from hyperopt.pyll import scope

sys.path.append('../src')
sys.path.append('../configs')
sys.path.append('../../../utils')
from search_space import search_space
from simple_cnn import SimpleCnn
from train_utils import Trainer
from tune_utils import get_config

In [2]:
# Load model.
model = SimpleCnn()

In [3]:
train_dataset = datasets.FashionMNIST(root='../../../data/fashion_mnist',
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True)

train_frac = 0.8

train_size = int(train_frac * len(train_dataset))
valid_size = len(train_dataset) - train_size

train_dataset, valid_dataset = random_split(
        train_dataset, [train_size, valid_size])


In [4]:
def objective_function(params):
    config = get_config(params)
    print(config)
    # config = {
    #     'optimizer': get_optimizer_config(params),
    #     'loss_fn' : 'ce_loss',
    #     'batch_size' : 128,
    #     'epochs' : 5,
    #     'device' : 'cpu'
    # }
    trainer = Trainer(model, train_dataset, config)
    result = trainer.train()
    return result['loss_history'][-1]

In [5]:
trials = Trials()

best_params = fmin(
    fn = objective_function,
    space = search_space,
    algo = tpe.suggest,
    max_evals = 1,
    trials = trials
)

{'batch_size': 16, 'epochs': 5, 'loss_fn': 'ce_loss', 'optimizer': {'name': 'sgd', 'params': {'lr': 0.017125936022545085, 'momentum': 0.00928849323986532, 'weight_decay': 0.5068976757981482}}, 'device': 'cpu'}
  0%|          | 0/1 [00:00<?, ?trial/s, best loss=?]

  0%|          | 0/5 [00:00<?, ?it/s]
 20%|##        | 1/5 [00:13<00:52, 13.03s/it]
 40%|####      | 2/5 [00:28<00:43, 14.53s/it]
 60%|######    | 3/5 [00:42<00:28, 14.11s/it]
 80%|########  | 4/5 [01:08<00:18, 18.80s/it]
100%|##########| 5/5 [01:34<00:00, 21.59s/it]
100%|##########| 5/5 [01:34<00:00, 18.95s/it]


100%|██████████| 1/1 [01:34<00:00, 94.76s/trial, best loss: 2.3026270888646443]


In [6]:
print(best_params)

{'batch_size': np.int64(0), 'epochs': np.int64(0), 'optimizer': np.int64(0), 'sgd_lr': np.float64(0.017125936022545085), 'sgd_momentum': np.float64(0.00928849323986532), 'sgd_weight_decay': np.float64(0.5068976757981482)}
