In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [None]:
import optuna
optuna.logging.disable_default_handler()

In [3]:
from fastai.vision import *

In [4]:
bs = 32
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((.5,),(.5,))
]
)

In [5]:
train_ds = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    transform=transform,
    download=True
)
train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: ../data
    Split: Train

In [6]:
test_ds = torchvision.datasets.MNIST(
    root='../data',
    train=False,
    transform=transform
)
test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: ../data
    Split: Test

In [7]:
train_dl = torch.utils.data.DataLoader(train_ds,bs,True,num_workers=12)
test_dl = torch.utils.data.DataLoader(test_ds,bs,False,num_workers=12)

In [8]:
input_size = train_ds[0][0].numel()

In [9]:
train_ds[0][0].size()

torch.Size([1, 28, 28])

In [10]:
def conv(ni,nf):
    return nn.Sequential(
        nn.Conv2d(ni,nf,kernel_size=3,stride=2,padding=1),
        nn.BatchNorm2d(nf),
        nn.ReLU(True)
    )

In [11]:
def resblock(nf):
    return nn.Sequential(
        nn.Conv2d(nf,nf,kernel_size=3),
        nn.Conv2d(nf,nf,kernel_size=3)
    )

In [25]:
class ResNet(nn.Module):
    def __init__(self,num_class):
        super().__init__()
        self.cnn1 = conv(1,16) #14
        self.res1 = res_block(16)
        self.cnn2 = conv(16,32)#7     
        self.res2 = res_block(32)
        self.cnn3 = conv(32,16)#4        
        self.res3 = res_block(16)
#         self.cnn4 = conv(32,16)#2       
#         self.res4 = res_block(16)
        self.cnn5 = conv(16,num_class)#1
    def forward(self,x):
        x = self.cnn1(x)
        x = x + self.res1(x)
        x = self.cnn2(x)        
        x = x + self.res2(x)
        x = self.cnn3(x)        
        x = x + self.res3(x)
#         x = self.cnn4(x)        
#         x = x + self.res4(x)
        x = self.cnn5(x)
        return x.view(x.size(0),-1)

In [26]:
model = ResNet(10)

In [27]:
images = next(iter(train_dl))

In [28]:
model(images[0]);

In [29]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=3e-3)

In [30]:
from tqdm import tqdm_notebook as tqdm

In [31]:
def train(model,dl,optimizer):
    model.cuda()
    model.train()
    t = tqdm(dl)
    for i,(images,labels) in enumerate(t):
        images,labels = images.cuda(),labels.cuda()
        optimizer.zero_grad()
        preds = model(images)
        loss = criterion(preds,labels)
        t.set_description(f'{loss.item()}')
        t.refresh()
        loss.backward()
        optimizer.step()
def test(model,dl):
    model.cuda()
    model.eval()
    correct = 0
    total = 0
    for img,y in dl:
        img,y = img.cuda(),y.cuda()
        preds = model(img)
        pred = preds.max(1, keepdim=True)[1]
        total += preds.size(0)
        correct += pred.eq(y.view_as(pred)).sum().item()
    return 1-(correct / len(dl.dataset))

In [None]:
for i in range(5):
    train(model,train_dl,optimizer)
    test(model,test_dl)

In [32]:
def get_optimizer(trial, model):
    optimizer_names = ['Adam', 'MomentumSGD']
    optimizer_name = trial.suggest_categorical('optimizer', optimizer_names)
    weight_decay = trial.suggest_loguniform('weight_decay', 1e-10, 1e-3)
    if optimizer_name == optimizer_names[0]: 
        adam_lr = trial.suggest_loguniform('adam_lr', 1e-5, 1e-1)
        optimizer = optim.Adam(model.parameters(), lr=adam_lr, weight_decay=weight_decay)
    else:
        momentum_sgd_lr = trial.suggest_loguniform('momentum_sgd_lr', 1e-5, 1e-1)
        optimizer = optim.SGD(model.parameters(), lr=momentum_sgd_lr,
                              momentum=0.9, weight_decay=weight_decay)
    return optimizer

In [33]:
def objective_wrapper(pbar):
    def objective(trial):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(device)
        model = ResNet(10).cuda()
        optimizer = get_optimizer(trial, model)

        for step in range(EPOCH):
            train(model,train_dl, optimizer)
            error_rate = test(model, test_dl)

            trial.report(error_rate, step)
            if trial.should_prune(step):
                pbar.update()
                raise optuna.structs.TrialPruned()

        pbar.update()

        return error_rate
    
    return objective

In [34]:
TRIAL_SIZE = 50
EPOCH = 1

with tqdm(total=TRIAL_SIZE) as pbar:
    study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
    study.optimize(objective_wrapper(pbar), n_trials=TRIAL_SIZE)

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

cuda


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

In [48]:
study.best_params

{'optimizer': 'Adam',
 'weight_decay': 0.00019131849419605393,
 'adam_lr': 0.0014739748148838122}

In [49]:
study.best_params['adam_lr']

0.0014739748148838122

In [50]:
study.best_value

0.01200000000000001

In [51]:
for i in range(5):
    train(model,train_dl,optim.Adam(model.parameters(),lr=study.best_params['adam_lr'],weight_decay=study.best_params['weight_decay']))

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

In [47]:
1-test(model,test_dl)

0.9901

In [53]:
df = study.trials_dataframe()

In [55]:
df.head()

Unnamed: 0_level_0,number,state,value,datetime_start,datetime_complete,params,params,params,params,system_attrs,intermediate_values
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,adam_lr,momentum_sgd_lr,optimizer,weight_decay,_number,0
0,0,TrialState.COMPLETE,0.1355,2019-07-15 23:51:38.487460,2019-07-15 23:52:09.109515,9e-05,,Adam,6.861286e-10,0,0.1355
1,1,TrialState.COMPLETE,0.0143,2019-07-15 23:52:09.109937,2019-07-15 23:52:39.823919,0.007205,,Adam,4.62272e-08,1,0.0143
2,2,TrialState.COMPLETE,0.0265,2019-07-15 23:52:39.824519,2019-07-15 23:53:04.153158,,0.000971,MomentumSGD,0.0005418758,2,0.0265
3,3,TrialState.COMPLETE,0.1339,2019-07-15 23:53:04.154335,2019-07-15 23:53:28.654131,,0.0001,MomentumSGD,8.117865e-09,3,0.1339
4,4,TrialState.COMPLETE,0.0196,2019-07-15 23:53:28.655204,2019-07-15 23:53:52.791244,,0.036752,MomentumSGD,2.819793e-08,4,0.0196


In [57]:
df.sort_values(['value'])

Unnamed: 0_level_0,number,state,value,datetime_start,datetime_complete,params,params,params,params,system_attrs,intermediate_values
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,adam_lr,momentum_sgd_lr,optimizer,weight_decay,_number,0
24,24,TrialState.COMPLETE,0.012,2019-07-16 00:03:14.357314,2019-07-16 00:03:44.942917,0.001474,,Adam,0.0001913185,24,0.012
1,1,TrialState.COMPLETE,0.0143,2019-07-15 23:52:09.109937,2019-07-15 23:52:39.823919,0.007205,,Adam,4.62272e-08,1,0.0143
16,16,TrialState.COMPLETE,0.0146,2019-07-15 23:59:11.601895,2019-07-15 23:59:41.868783,0.003975,,Adam,4.03447e-06,16,0.0146
9,9,TrialState.COMPLETE,0.0147,2019-07-15 23:55:42.031665,2019-07-15 23:56:06.288145,,0.021316,MomentumSGD,3.319922e-05,9,0.0147
27,27,TrialState.COMPLETE,0.015,2019-07-16 00:04:46.124208,2019-07-16 00:05:16.700701,0.008598,,Adam,2.594062e-08,27,0.015
18,18,TrialState.COMPLETE,0.0154,2019-07-16 00:00:12.170658,2019-07-16 00:00:42.810779,0.004216,,Adam,3.696161e-06,18,0.0154
12,12,TrialState.COMPLETE,0.0161,2019-07-15 23:57:07.994161,2019-07-15 23:57:38.843075,0.001521,,Adam,1.070563e-10,12,0.0161
40,40,TrialState.COMPLETE,0.0161,2019-07-16 00:10:55.952270,2019-07-16 00:11:20.442152,,0.012749,MomentumSGD,1.520565e-08,40,0.0161
30,30,TrialState.COMPLETE,0.0168,2019-07-16 00:06:13.166648,2019-07-16 00:06:44.396502,0.001767,,Adam,0.0001197752,30,0.0168
25,25,TrialState.COMPLETE,0.0171,2019-07-16 00:03:44.951208,2019-07-16 00:04:15.485369,0.001219,,Adam,0.0001989868,25,0.0171
