In [1]:
import sys

from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

from hyperopt import fmin, tpe, hp, Trials
from hyperopt.pyll import scope

from simple_cnn import SimpleCnn
sys.path.append('../../utils')
from trainer import Trainer

In [2]:
# Load model.
model = SimpleCnn()

In [3]:
train_dataset = datasets.FashionMNIST(root='../../data/fashion_mnist',
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True)

train_frac = 0.8

train_size = int(train_frac * len(train_dataset))
valid_size = len(train_dataset) - train_size

train_dataset, valid_dataset = random_split(
        train_dataset, [train_size, valid_size])


In [4]:
def objective_function(params):
    config = {
        'optimizer': {
            'name' : params['optimizer'],
            'params' : {
                'lr' : params['lr'],
            }
        },
        'loss_fn' : 'ce_loss',
        'batch_size' : 128,
        'epochs' : 5,
        'device' : 'cpu'
    }
    trainer = Trainer(model, train_dataset, config)
    result = trainer.train()
    return result['loss_history'][-1]

In [5]:
search_space = {
    'lr': hp.uniform('lr', 0.001, 0.1),  # Learning rate
    'optimizer': hp.choice('optimizer', ['sgd', 'adam']),  # Optimizer type
}

In [6]:
trials = Trials()

best_params = fmin(
    fn = objective_function,
    space = search_space,
    algo = tpe.suggest,
    max_evals = 1,
    trials = trials
)

  0%|          | 0/1 [00:00<?, ?trial/s, best loss=?]

  0%|          | 0/375 [00:00<?, ?it/s]
  1%|          | 2/375 [00:00<00:21, 17.15it/s]
  2%|1         | 7/375 [00:00<00:11, 31.49it/s]
  3%|3         | 12/375 [00:00<00:09, 37.34it/s]
  5%|4         | 17/375 [00:00<00:09, 39.33it/s]
  6%|5         | 22/375 [00:00<00:08, 40.22it/s]
  7%|7         | 27/375 [00:00<00:08, 40.54it/s]
  9%|8         | 32/375 [00:00<00:08, 40.10it/s]
 10%|9         | 37/375 [00:00<00:08, 40.15it/s]
 11%|#1        | 42/375 [00:01<00:08, 39.96it/s]
 13%|#2        | 47/375 [00:01<00:08, 40.17it/s]
 14%|#3        | 52/375 [00:01<00:08, 40.11it/s]
 15%|#5        | 57/375 [00:01<00:08, 37.19it/s]
 16%|#6        | 61/375 [00:01<00:08, 36.62it/s]
 17%|#7        | 65/375 [00:01<00:08, 36.25it/s]
 18%|#8        | 69/375 [00:01<00:08, 36.80it/s]
 19%|#9        | 73/375 [00:01<00:08, 37.07it/s]
 21%|##        | 77/375 [00:02<00:08, 36.99it/s]
 22%|##1       | 81/375 [00:02<00:07, 36.76it/s]
 23%|##2       | 85/375 [00:02<00:07, 36.57it/s]
 24%|##3       | 89/375 [00:02<

Epoch 1/5,                   Loss: 0.6382884060541789
  0%|          | 0/1 [00:09<?, ?trial/s, best loss=?]

  0%|          | 0/375 [00:00<?, ?it/s]
  1%|1         | 4/375 [00:00<00:09, 38.02it/s]
  2%|2         | 8/375 [00:00<00:09, 38.88it/s]
  3%|3         | 12/375 [00:00<00:09, 38.89it/s]
  4%|4         | 16/375 [00:00<00:09, 37.24it/s]
  6%|5         | 21/375 [00:00<00:09, 38.59it/s]
  7%|6         | 25/375 [00:00<00:08, 39.03it/s]
  8%|7         | 29/375 [00:00<00:08, 38.99it/s]
  9%|8         | 33/375 [00:00<00:08, 39.28it/s]
 10%|9         | 37/375 [00:00<00:08, 38.99it/s]
 11%|#         | 41/375 [00:01<00:08, 38.31it/s]
 12%|#2        | 45/375 [00:01<00:08, 38.68it/s]
 13%|#3        | 50/375 [00:01<00:08, 39.50it/s]
 15%|#4        | 55/375 [00:01<00:07, 40.32it/s]
 16%|#6        | 60/375 [00:01<00:07, 40.17it/s]
 17%|#7        | 65/375 [00:01<00:07, 40.16it/s]
 19%|#8        | 70/375 [00:01<00:07, 39.56it/s]
 20%|#9        | 74/375 [00:01<00:07, 38.93it/s]
 21%|##        | 78/375 [00:02<00:07, 38.01it/s]
 22%|##1       | 82/375 [00:02<00:07, 36.84it/s]
 23%|##2       | 86/375 [00:02<

Epoch 2/5,                   Loss: 0.44184005284309386
  0%|          | 0/1 [00:19<?, ?trial/s, best loss=?]

  0%|          | 0/375 [00:00<?, ?it/s]
  1%|1         | 4/375 [00:00<00:10, 34.58it/s]
  2%|2         | 8/375 [00:00<00:10, 34.94it/s]
  3%|3         | 12/375 [00:00<00:10, 35.60it/s]
  4%|4         | 16/375 [00:00<00:10, 35.23it/s]
  5%|5         | 20/375 [00:00<00:09, 35.61it/s]
  6%|6         | 24/375 [00:00<00:09, 36.23it/s]
  7%|7         | 28/375 [00:00<00:09, 37.03it/s]
  9%|8         | 32/375 [00:00<00:09, 36.05it/s]
 10%|9         | 36/375 [00:01<00:09, 35.87it/s]
 11%|#         | 40/375 [00:01<00:09, 35.53it/s]
 12%|#1        | 44/375 [00:01<00:09, 35.42it/s]
 13%|#2        | 48/375 [00:01<00:09, 35.73it/s]
 14%|#3        | 52/375 [00:01<00:09, 35.69it/s]
 15%|#4        | 56/375 [00:01<00:08, 36.24it/s]
 16%|#6        | 60/375 [00:01<00:08, 35.98it/s]
 17%|#7        | 64/375 [00:01<00:08, 35.55it/s]
 18%|#8        | 68/375 [00:01<00:08, 35.36it/s]
 19%|#9        | 72/375 [00:02<00:08, 35.93it/s]
 20%|##        | 76/375 [00:02<00:08, 36.45it/s]
 21%|##1       | 80/375 [00:02<

Epoch 3/5,                   Loss: 0.41620501899719237
  0%|          | 0/1 [00:30<?, ?trial/s, best loss=?]

  0%|          | 0/375 [00:00<?, ?it/s]
  1%|1         | 4/375 [00:00<00:10, 35.38it/s]
  2%|2         | 8/375 [00:00<00:10, 36.16it/s]
  3%|3         | 12/375 [00:00<00:09, 37.03it/s]
  4%|4         | 16/375 [00:00<00:09, 36.73it/s]
  5%|5         | 20/375 [00:00<00:09, 37.40it/s]
  6%|6         | 24/375 [00:00<00:09, 37.74it/s]
  7%|7         | 28/375 [00:00<00:09, 36.81it/s]
  9%|8         | 32/375 [00:00<00:09, 36.48it/s]
 10%|9         | 36/375 [00:00<00:09, 36.14it/s]
 11%|#         | 40/375 [00:01<00:09, 35.77it/s]
 12%|#1        | 44/375 [00:01<00:09, 35.97it/s]
 13%|#2        | 48/375 [00:01<00:09, 35.69it/s]
 14%|#3        | 52/375 [00:01<00:09, 35.11it/s]
 15%|#4        | 56/375 [00:01<00:09, 34.59it/s]
 16%|#6        | 60/375 [00:01<00:09, 34.86it/s]
 17%|#7        | 64/375 [00:01<00:08, 34.57it/s]
 18%|#8        | 68/375 [00:01<00:08, 34.28it/s]
 19%|#9        | 72/375 [00:02<00:08, 34.36it/s]
 20%|##        | 76/375 [00:02<00:08, 34.49it/s]
 21%|##1       | 80/375 [00:02<

Epoch 4/5,                   Loss: 0.39439649625619255
  0%|          | 0/1 [00:40<?, ?trial/s, best loss=?]

  0%|          | 0/375 [00:00<?, ?it/s]
  1%|1         | 4/375 [00:00<00:10, 35.38it/s]
  2%|2         | 8/375 [00:00<00:10, 35.56it/s]
  3%|3         | 12/375 [00:00<00:10, 35.95it/s]
  4%|4         | 16/375 [00:00<00:09, 36.23it/s]
  5%|5         | 20/375 [00:00<00:09, 36.93it/s]
  6%|6         | 24/375 [00:00<00:09, 37.33it/s]
  7%|7         | 28/375 [00:00<00:09, 37.00it/s]
  9%|8         | 32/375 [00:00<00:09, 36.23it/s]
 10%|9         | 36/375 [00:00<00:09, 36.73it/s]
 11%|#         | 40/375 [00:01<00:09, 36.26it/s]
 12%|#1        | 44/375 [00:01<00:09, 35.83it/s]
 13%|#2        | 48/375 [00:01<00:09, 35.43it/s]
 14%|#3        | 52/375 [00:01<00:08, 35.92it/s]
 15%|#4        | 56/375 [00:01<00:08, 35.74it/s]
 16%|#6        | 60/375 [00:01<00:08, 36.76it/s]
 17%|#7        | 64/375 [00:01<00:08, 37.03it/s]
 18%|#8        | 68/375 [00:01<00:08, 37.44it/s]
 19%|#9        | 72/375 [00:01<00:08, 36.87it/s]
 20%|##        | 76/375 [00:02<00:08, 37.37it/s]
 21%|##1       | 80/375 [00:02<

Epoch 5/5,                   Loss: 0.39210060584545137
100%|██████████| 1/1 [00:50<00:00, 50.41s/trial, best loss: 147.03772719204426]


In [7]:
print(best_params)

{'lr': np.float64(0.022792097124131792), 'optimizer': np.int64(1)}
