In [8]:
import sys
from functools import partial
from typing import Callable, List
from torch import nn, optim
from torchvision import transforms
import torch
import pandas as pd
sys.path.append('../')
from modules import dataloaders, schedulers, model, train

# Directory that receives the CSV written by each experiment.
RESULT_DIR = '../results/headline'

# Probe CUDA once and reuse the answer for both the report and the device.
cuda_flag = torch.cuda.is_available()
print(cuda_flag, torch.backends.cudnn.enabled)
device = torch.device('cuda' if cuda_flag else 'cpu')

True True


In [10]:
# Pass 1: loaders built without explicit transforms; only the first one is
# kept, and it is used solely to compute the value fed to Normalize below.
dl, _ = dataloaders.get_cifar10_data_loaders(
    data_dir='../data/cifar10',
    batch_size=100,
)
rgb_ave = dataloaders.channel_avg(dl)

# Subtract the measured channel average; a std of 1 leaves the scale as is.
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=rgb_ave, std=[1,1,1]),
])

# Pass 2: the loaders used by every experiment, sharing the same transform
# for the training and the test split.
train_dl, test_dl = dataloaders.get_cifar10_data_loaders(
    data_dir='../data/cifar10',
    batch_size=100,
    num_workers=8,
    pin_memory=cuda_flag,
    train_transform=tfms,
    test_transform=tfms,
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [12]:
def _save_recorder(recorder, result_path, allow_partial=False):
    """Write the recorder dict to ``result_path`` as CSV without an index column.

    With ``allow_partial=True`` every column is first cut to the length of the
    shortest one, so a run that stopped while a row was being recorded can
    still be written (``pd.DataFrame`` rejects columns of unequal length).

    Returns whatever ``DataFrame.to_csv`` returns (``None`` when a path is given).
    """
    if allow_partial:
        n_rows = min(len(column) for column in recorder.values())
        recorder = {name: column[:n_rows] for name, column in recorder.items()}
    return pd.DataFrame(recorder).to_csv(result_path, index=False)


def run_experiment(train_dl:torch.utils.data.DataLoader,
                   test_dl:torch.utils.data.DataLoader,
                   net:nn.Module,
                   optim_groups:List[torch.nn.parameter.Parameter],
                   scheduler_init:Callable[[optim.Optimizer], optim.lr_scheduler._LRScheduler],
                   device:torch.device,
                   iterations:int,
                   result_path:str,
                   criterion:nn.Module=nn.CrossEntropyLoss(),
                   validate_it=1000):
    """Train ``net`` with SGD under a given LR schedule and save the metrics as CSV.

    An ``optim.SGD`` optimizer (lr=0.001, momentum=0.9, weight_decay=0.004) is
    built over ``optim_groups``, ``scheduler_init`` is called with it to get the
    scheduler, and both are handed to ``train.train_run``. A recorder dict with
    the columns ``iteration``, ``trn_loss``, ``lr``, ``val_loss`` and ``val_acc``
    is bound into ``train.on_batch_end`` and written to ``result_path``
    afterwards.

    Args:
        train_dl: loader passed to ``train.train_run``.
        test_dl: loader bound into the ``train.on_batch_end`` callback.
        net: model to train; it is not moved to ``device`` here.
        optim_groups: parameters handed to ``optim.SGD`` (callers in this
            notebook pass ``net.parameters()``).
        scheduler_init: callable that receives the optimizer and returns the
            scheduler.
        device: forwarded to ``train.train_run`` and ``train.on_batch_end``.
        iterations: forwarded to ``train.train_run``
            (presumably the number of training batches - confirm in modules/train.py).
        result_path: destination of the CSV; its parent directory is created
            if it does not exist.
        criterion: loss module. NOTE: the default instance is created once,
            when the function is defined, and shared by all calls.
        validate_it: forwarded to ``train.on_batch_end``
            (presumably the validation interval in iterations - confirm).

    Returns:
        The result of ``DataFrame.to_csv`` (``None`` when ``result_path`` is a path).

    Raises:
        Whatever ``train.train_run`` raises, including ``KeyboardInterrupt``.
        Before re-raising, the rows recorded so far are written to
        ``result_path`` so an interrupted run does not lose its metrics.
    """
    import os  # local import so this cell does not rely on an edit to the import cell

    # Create the output directory up front: a missing directory would otherwise
    # only surface in to_csv, after the whole training run has finished.
    if isinstance(result_path, (str, os.PathLike)):
        out_dir = os.path.dirname(os.fspath(result_path))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    optimizer = optim.SGD(optim_groups, lr=0.001, momentum=0.9, weight_decay=0.004)
    scheduler = scheduler_init(optimizer)
    recorder = {
        'iteration' : [],
        'trn_loss' : [],
        'lr' : [],
        'val_loss' : [],
        'val_acc' : []}
    obe = partial(train.on_batch_end,
                  recorder,
                  test_dl,
                  net,
                  criterion,
                  device,
                  validate_it=validate_it)
    try:
        train.train_run(net,
                        train_dl,
                        criterion,
                        optimizer,
                        scheduler,
                        iterations,
                        obe,
                        device)
    except BaseException:
        # BaseException so that a manual kernel interrupt is covered as well.
        # Keep what was recorded, then let the original error propagate.
        _save_recorder(recorder, result_path, allow_partial=True)
        raise
    return _save_recorder(recorder, result_path)

In [14]:
# EXPERIMENT: learning rate range test
# The scheduler's first argument matches the iteration count of the run
# (presumably step size, lower LR bound, upper LR bound - confirm in
# modules/schedulers.py).

net = model.Cifar10Net_quick().to(device)
sched_init = partial(schedulers.TriangularScheduler, 4000, 0.001, 0.04)
run_experiment(
    train_dl=train_dl,
    test_dl=test_dl,
    net=net,
    optim_groups=net.parameters(),
    scheduler_init=sched_init,
    device=device,
    iterations=4000,
    result_path=f'{RESULT_DIR}/lrrt.csv',
    validate_it=100,
)

HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2.3102407455444336 | 2.3044705772399903 | 0.1
2.312624454498291 | 2.303440170288086 | 0.1
2.299867630004883 | 2.3018194746971132 | 0.1003
2.299769163131714 | 2.299508664608002 | 0.1189
2.2890985012054443 | 2.2889735388755796 | 0.1731
2.1724050045013428 | 2.1056038117408753 | 0.2212
1.9830414056777954 | 1.9662329626083375 | 0.2669
1.7505059242248535 | 1.8333210253715515 | 0.3211
1.8099440336227417 | 1.7154952216148376 | 0.3644
1.735664963722229 | 1.6707480144500733 | 0.3902
1.8090471029281616 | 1.6958602201938628 | 0.3747
1.4420520067214966 | 1.562908991575241 | 0.4226
1.5176358222961426 | 1.5288190817832947 | 0.4484
1.5646848678588867 | 1.4690062773227692 | 0.4626
1.5287458896636963 | 1.4254344367980958 | 0.4705
1.3944320678710938 | 1.3957298028469085 | 0.493
1.4541237354278564 | 1.482874664068222 | 0.4647
1.2937211990356445 | 1.3920288789272308 | 0.4978
1.4595838785171509 | 1.3885995137691498 | 0.4898
1.3977084159851074 | 1.3327751433849335 | 0.5184
1.4297860860824585 | 1.313413417339

In [15]:
# EXPERIMENT: triangular (CLR) learning rate policy
# A freshly initialised network is trained for 35000 iterations; validate_it
# is left at run_experiment's default.

net = model.Cifar10Net_quick().to(device)
sched_init = partial(schedulers.TriangularScheduler, 2000, 0.0025, 0.01)

run_experiment(
    train_dl=train_dl,
    test_dl=test_dl,
    net=net,
    optim_groups=net.parameters(),
    scheduler_init=sched_init,
    device=device,
    iterations=35000,
    result_path=f'{RESULT_DIR}/triangular.csv',
)
               

# EXPERIMENT: fixed learning rate policy
# Three consecutive runs with a tenfold smaller fixed LR each time. Only the
# first run creates a new network; the later two keep training the same `net`
# object (presumably a manual step-decay schedule - confirm). Each call builds
# its own optimizer inside run_experiment.

# Stage 1: fresh network, lr = 0.001
sched_init = partial(schedulers.FixedScheduler, 0.001)
net = model.Cifar10Net_quick().to(device)
run_experiment(
    train_dl=train_dl,
    test_dl=test_dl,
    net=net,
    optim_groups=net.parameters(),
    scheduler_init=sched_init,
    device=device,
    iterations=60000,
    result_path=f'{RESULT_DIR}/fixed_1.csv',
)

# Stage 2: same network, lr = 0.0001
sched_init = partial(schedulers.FixedScheduler, 0.0001)
run_experiment(
    train_dl=train_dl,
    test_dl=test_dl,
    net=net,
    optim_groups=net.parameters(),
    scheduler_init=sched_init,
    device=device,
    iterations=5000,
    result_path=f'{RESULT_DIR}/fixed_2.csv',
)

# Stage 3: same network, lr = 0.00001
sched_init = partial(schedulers.FixedScheduler, 0.00001)
run_experiment(
    train_dl=train_dl,
    test_dl=test_dl,
    net=net,
    optim_groups=net.parameters(),
    scheduler_init=sched_init,
    device=device,
    iterations=5000,
    result_path=f'{RESULT_DIR}/fixed_3.csv',
)


HBox(children=(IntProgress(value=0, max=35000), HTML(value='')))

2.3046481609344482 | 2.3050693583488466 | 0.1
1.6276497840881348 | 1.6658544802665711 | 0.3775
1.6071046590805054 | 1.4575368249416352 | 0.4723
1.1853220462799072 | 1.1849050027132035 | 0.5854
1.0007984638214111 | 1.0766128051280974 | 0.6206
1.098536491394043 | 1.098484729528427 | 0.6053
1.067954659461975 | 1.0179495507478713 | 0.6452
0.873240053653717 | 0.8730618339776993 | 0.6951
0.8067431449890137 | 0.8435088264942169 | 0.7058
0.7172715067863464 | 0.8401214396953582 | 0.6997
0.9136711955070496 | 0.9045740330219268 | 0.6888
0.7858867645263672 | 0.8323401719331741 | 0.7095
0.676919162273407 | 0.768417640030384 | 0.7356
0.569383442401886 | 0.7894460743665696 | 0.7283
0.8763219714164734 | 0.8593053424358368 | 0.6918
0.7031083703041077 | 0.7635247391462326 | 0.7405
0.7737318277359009 | 0.7039665549993515 | 0.7597
0.7017474174499512 | 0.7498368972539902 | 0.7382
0.6863064765930176 | 0.7777117562294006 | 0.7274
0.5760299563407898 | 0.7579437607526779 | 0.7367
0.8141185641288757 | 0.6805918

HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))

2.3076255321502686 | 2.3054757213592527 | 0.1
2.2966320514678955 | 2.2977423882484436 | 0.1386
1.9801298379898071 | 1.988151558637619 | 0.2702
1.7893126010894775 | 1.73906156539917 | 0.3537
1.5177464485168457 | 1.6250308978557586 | 0.3961
1.540555477142334 | 1.5488996946811675 | 0.4256
1.4656131267547607 | 1.5008938932418823 | 0.4429
1.3516697883605957 | 1.4505203294754028 | 0.463
1.4371273517608643 | 1.4071182465553285 | 0.4875
1.2681291103363037 | 1.3760166156291962 | 0.5007
1.322860598564148 | 1.3051490521430968 | 0.5242
1.350134015083313 | 1.2639894819259643 | 0.5442
1.3140590190887451 | 1.223908793926239 | 0.5625
1.4771825075149536 | 1.213692992925644 | 0.5653
1.0366463661193848 | 1.1663755536079408 | 0.585
0.9874826073646545 | 1.1559757047891617 | 0.5909
1.3291401863098145 | 1.1199875617027282 | 0.6039
1.1845009326934814 | 1.1341004800796508 | 0.6013
1.0622944831848145 | 1.1004797339439392 | 0.6124
1.0091549158096313 | 1.0626203805208205 | 0.6259
0.952467143535614 | 1.03741274237

KeyboardInterrupt: 