In [None]:
import sys

from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

from hyperopt import fmin, tpe, hp, Trials
from hyperopt.pyll import scope

from simple_cnn import SimpleCnn
sys.path.append('../../utils')
from train_utils import Trainer

In [2]:
# Load model.
model = SimpleCnn()

In [3]:
train_dataset = datasets.FashionMNIST(root='../../data/fashion_mnist',
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True)

train_frac = 0.8

train_size = int(train_frac * len(train_dataset))
valid_size = len(train_dataset) - train_size

train_dataset, valid_dataset = random_split(
        train_dataset, [train_size, valid_size])


In [4]:
def find_suffix(s, t):
    if not t.startswith(s):
        return t
    return t[len(s)+1:]


def get_config(params):
    raw_config = params
    name = params['optimizer']['name']
    if name == 'adam':
        beta1 = raw_config['optimizer']['adam_beta1']
        beta2 = raw_config['optimizer']['adam_beta2']
        del raw_config['optimizer']['beta1']
        del raw_config['optimizer']['beta2']
        raw_config['optimize']['betas'] = (beta1, beta2)
    raw_config['device'] = 'cpu'

    config = {}

    for key, value in raw_config.items():
        new_key = find_suffix(name, key)
        config[new_key] = value

    return config


def objective_function(params):
    config = get_config(params)
    print(config)
    # config = {
    #     'optimizer': get_optimizer_config(params),
    #     'loss_fn' : 'ce_loss',
    #     'batch_size' : 128,
    #     'epochs' : 5,
    #     'device' : 'cpu'
    # }
    trainer = Trainer(model, train_dataset, config)
    result = trainer.train()
    return result['loss_history'][-1]

In [5]:
search_space = {
    'optimizer': hp.choice('optimizer', [
        {
            'name' : 'sgd',
            'params' : {
                'momentum' : hp.uniform('sgd_momentum', 0.001, 0.01),
                'lr': hp.uniform('sgd_lr', 0.001, 0.1),
                'weight_decay': hp.uniform('sgd_weight_decay', 0.5, 0.99)
            }
        },
        {
            'name' : 'adam',
            'params' : {
                'lr': hp.uniform('adam_lr', 0.001, 0.1),
                'weight_decay': hp.uniform('adam_weight_decay', 0.5, 0.99),
                'beta1' : hp.uniform('adam_beta1', 0.9, 0.99),
                'beta2' : hp.uniform('adam_beta2', 0.99, 0.999)
            }
        },
        {
            'name' : 'rmsprop',
            'params' : {
                'lr': hp.uniform('rmsprop_lr', 0.001, 0.1),
                'weight_decay': hp.uniform('rmsprop_weight_decay', 0.5, 0.99),
                'momentum' : hp.uniform('rmsprop_momentum', 0.001, 0.01),
                'alpha' : hp.uniform('rmsprop_alpha', 0.9, 0.99)
            }
        }
    ]),  
    'loss_fn' : 'ce_loss',
    'batch_size' : hp.choice('batch_size', [16, 32, 64, 128]),
    'epochs' : hp.choice('epochs', [5])
}

In [6]:
trials = Trials()

best_params = fmin(
    fn = objective_function,
    space = search_space,
    algo = tpe.suggest,
    max_evals = 1,
    trials = trials
)

{'batch_size': 64, 'epochs': 5, 'loss_fn': 'ce_loss', 'optimizer': {'name': 'rmsprop', 'params': {'alpha': 0.9644328599384232, 'lr': 0.04573058188168829, 'momentum': 0.0040741029581688164, 'weight_decay': 0.7689033224225683}}, 'device': 'cpu'}
  0%|          | 0/1 [00:00<?, ?trial/s, best loss=?]

  0%|          | 0/5 [00:00<?, ?it/s]
 20%|##        | 1/5 [00:09<00:38,  9.60s/it]
 40%|####      | 2/5 [00:19<00:28,  9.63s/it]
 60%|######    | 3/5 [00:27<00:18,  9.19s/it]
 80%|########  | 4/5 [00:37<00:09,  9.18s/it]
100%|##########| 5/5 [00:47<00:00,  9.69s/it]
100%|##########| 5/5 [00:47<00:00,  9.53s/it]


100%|██████████| 1/1 [00:47<00:00, 47.68s/trial, best loss: 2.3106859623591105]


In [7]:
print(best_params)

{'batch_size': np.int64(2), 'epochs': np.int64(0), 'optimizer': np.int64(2), 'rmsprop_alpha': np.float64(0.9644328599384232), 'rmsprop_lr': np.float64(0.04573058188168829), 'rmsprop_momentum': np.float64(0.0040741029581688164), 'rmsprop_weight_decay': np.float64(0.7689033224225683)}
