In [1]:
from mlp import MLP
from wideresnet import WideResNet
import dnn_utils

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim

from tqdm import tqdm

import tensorboard_logger

  (fname, cnt))
  (fname, cnt))


In [2]:
def get_experiement_name(model_name, schedule_name, adaptive, amsgrad, momentum, lr):
    """Build the run name used as the tensorboard log sub-directory.

    The result is ``<model_name>_<method>_<schedule_name>_lr=<lr>``, where
    ``<method>`` is one of ``AMSGrad``, ``ADAM``, ``SGD`` or ``Nest_SGD``.

    Args:
        model_name: label for the model architecture.
        schedule_name: label for the learning-rate schedule.
        adaptive: True for the Adam family, False for SGD.
        amsgrad: True for the AMSGrad variant (only valid with ``adaptive``).
        momentum: truthy when SGD uses momentum (only valid without ``adaptive``).
        lr: learning rate, formatted into the name with ``str.format``.

    Returns:
        The experiment name string.

    Raises:
        ValueError: for adaptive + momentum, or amsgrad without adaptive.
    """
    # Raise real exception objects: `raise 'text'` is itself a TypeError in
    # Python 3 and would hide these messages.
    if adaptive and momentum:
        raise ValueError('Cannot run momentum on adaptive')
    if amsgrad and not adaptive:
        raise ValueError('Cannot run amsgrad on non adaptive')
    # With the invalid combinations rejected above, every remaining
    # (adaptive, amsgrad, momentum) case maps to exactly one method.
    if adaptive:
        method_name = 'AMSGrad' if amsgrad else 'ADAM'
    else:
        method_name = 'Nest_SGD' if momentum else 'SGD'
    lr_name = 'lr={}'.format(lr)
    return '_'.join([model_name, method_name, schedule_name, lr_name])

def run_experiment(model, lr, decay_delta, decay_k, epochs=400,
                   adaptive=False, amsgrad=False, momentum=0.9, 
                   model_name='default_model', schedule_name='default_schedule',
                   logdir='clean_runs', criterion=None):
    """Train ``model`` on the CIFAR-10 loaders and log train/val stats to tensorboard.

    Args:
        model: network whose parameters are optimized; passed to
            ``dnn_utils.train`` / ``dnn_utils.validate`` every epoch.
        lr: initial learning rate.
        decay_delta, decay_k: forwarded to ``dnn_utils.get_lr_decay_function``;
            the resulting schedule is applied only to SGD (non-adaptive) runs.
        epochs: number of training epochs.
        adaptive: use ``torch.optim.Adam`` (True) or ``torch.optim.SGD`` (False).
        amsgrad: use Adam's AMSGrad variant (requires ``adaptive``).
        momentum: SGD momentum; ignored when ``adaptive`` is True.
        model_name, schedule_name: labels used in the experiment name.
        logdir: parent directory of the tensorboard run directory.
        criterion: loss module; defaults to ``nn.CrossEntropyLoss()``.

    Raises:
        Propagates the error raised by ``get_experiement_name`` for invalid
        optimizer combinations (e.g. amsgrad without adaptive).
    """
    # Momentum only matters for SGD. Adam never uses it, so an adaptive run
    # with the default momentum=0.9 must not count as "momentum on adaptive".
    has_momentum = (not adaptive) and momentum != 0
    experiment_name = get_experiement_name(model_name, schedule_name, adaptive,
                                           amsgrad, has_momentum, lr)
    if criterion is None:
        # Avoid depending on a notebook-global `criterion` defined in another cell.
        criterion = nn.CrossEntropyLoss()
    tlog = tensorboard_logger.Logger(logdir + '/' + experiment_name)
    train_loader, val_loader = dnn_utils.get_cifar10_loaders()
    decay_lr = dnn_utils.get_lr_decay_function(decay_delta, decay_k, tlog.log_value)
    cudnn.benchmark = True
    if adaptive:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=amsgrad)
    else:
        # NOTE(review): runs with non-zero momentum are labelled 'Nest_SGD', but
        # torch.optim.SGD uses classical momentum unless nesterov=True - confirm intent.
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    for epoch in tqdm(range(epochs)):
        # The decay schedule is applied to SGD only; Adam keeps its initial lr.
        if not adaptive:
            decay_lr(optimizer, epoch)
        dnn_utils.train(train_loader, model, criterion, optimizer, epoch, 
            total_epochs=epochs, 
            performance_stats={'train_err': dnn_utils.top1error},
            verbose=False, tensorboard_log_function=tlog.log_value,
            tensorboard_stats=['train_loss', 'train_err'])
        dnn_utils.validate(val_loader, model, criterion, epoch,
            total_epochs=epochs, 
            performance_stats={'val_err': dnn_utils.top1error},
            verbose=False, tensorboard_log_function=tlog.log_value,
            tensorboard_stats=['val_loss', 'val_err'])

In [3]:
# Shared experiment settings.
epochs = 400
adam_lrs = [0.00005, 0.0001, 0.0005]
sgd_lrs = [0.005, 0.01, 0.05]
momentum = 0.9
seed = 0  # every run starts from the same initial weights
criterion = nn.CrossEntropyLoss().cuda()
# Not used in this cell.
resnet_model = WideResNet(depth=28, num_classes=10).cuda()

# The four valid optimizer settings as (adaptive, amsgrad, momentum).
# Adam ignores momentum. The adaptive entries pass the non-zero default so that
# run_experiment's momentum/adaptive check never rejects them.
optimizer_configs = [
    (True, True, momentum),    # AMSGrad
    (True, False, momentum),   # Adam
    (False, False, 0.0),       # plain SGD
    (False, False, momentum),  # SGD with momentum
]
for adaptive, amsgrad, run_momentum in optimizer_configs:
    # Use the middle learning rate of the grid for each optimizer family.
    lr = adam_lrs[1] if adaptive else sgd_lrs[1]
    # Build a fresh, identically seeded MLP for each run so that no optimizer
    # continues training the previous run's weights.
    torch.manual_seed(seed)
    mlp_model = MLP([32*32*3, 512, 10]).cuda()
    run_experiment(mlp_model, lr, decay_delta=0.999, decay_k=1, epochs=epochs,
                   adaptive=adaptive, amsgrad=amsgrad, momentum=run_momentum,
                   model_name='mlp')

Files already downloaded and verified
Files already downloaded and verified


 10%|▉         | 39/400 [05:47<53:36,  8.91s/it]Process Process-79:
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torchvision-0.2.0-py3.6.egg/torchvision/datasets/cifar.py", line 122, in __getitem__
    img = self.transform(img)
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 42, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 42, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/ubuntu/anaconda3/en

KeyboardInterrupt: 