In [8]:
import sys
from functools import partial
from typing import Callable, Iterable, List, Optional

import pandas as pd
import torch
from torch import nn, optim
from torchvision import transforms

sys.path.append('../')
from modules import dataloaders, schedulers, model, train

# Probe CUDA once; cuda_flag is reused below to decide whether loaders pin host memory.
cuda_flag = torch.cuda.is_available()
print(cuda_flag, torch.backends.cudnn.enabled)
device = torch.device('cuda' if cuda_flag else 'cpu')
# Root directory for the per-experiment CSV logs.
RESULT_DIR = '../results/headline'

True True


In [10]:
# Mean-only normalisation: estimate the per-channel average from the training split with a
# throwaway loader, then subtract it (std fixed at 1, so the pixel scale is left as-is).
# NOTE(review): assumes the default transform of get_cifar10_data_loaders yields tensors
# that channel_avg can average -- confirm in modules/dataloaders.
CIFAR10_DIR = '../data/cifar10'
dl, _ = dataloaders.get_cifar10_data_loaders(data_dir=CIFAR10_DIR, batch_size=100)
rgb_ave = dataloaders.channel_avg(dl)

tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=rgb_ave, std=[1, 1, 1]),
])

# Train and test share the same normalisation; pin_memory is enabled only when CUDA is present.
train_dl, test_dl = dataloaders.get_cifar10_data_loaders(
    data_dir=CIFAR10_DIR,
    batch_size=100,
    num_workers=8,
    pin_memory=cuda_flag,
    train_transform=tfms,
    test_transform=tfms,
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [12]:
def run_experiment(train_dl: torch.utils.data.DataLoader,
                   test_dl: torch.utils.data.DataLoader,
                   net: nn.Module,
                   optim_groups: Iterable,
                   scheduler_init: Callable[[optim.Optimizer], optim.lr_scheduler._LRScheduler],
                   device: torch.device,
                   iterations: int,
                   result_path: str,
                   criterion: Optional[nn.Module] = None,
                   validate_it: int = 1000,
                   lr: float = 0.001,
                   momentum: float = 0.9,
                   weight_decay: float = 0.004) -> pd.DataFrame:
    """Train ``net`` with SGD under an LR schedule, save the log as CSV and return it.

    A fresh ``optim.SGD`` optimizer is built over ``optim_groups`` and handed to
    ``scheduler_init`` to create the scheduler. ``train.train_run`` then runs the
    training loop, with ``train.on_batch_end`` (bound to a shared recorder dict)
    used as the per-batch callback.

    Args:
        train_dl: Training data loader.
        test_dl: Loader forwarded to ``train.on_batch_end`` for validation.
        net: Model to train. It is not moved to ``device`` here; callers do that.
        optim_groups: Anything ``optim.SGD`` accepts, e.g. ``net.parameters()``.
            A one-shot generator is fine because it is consumed exactly once.
        scheduler_init: Factory ``optimizer -> scheduler``.
        device: Device forwarded to the training helpers.
        iterations: Iteration count forwarded to ``train.train_run``.
        result_path: CSV destination. Missing parent directories are created
            before training starts.
        criterion: Loss module. When ``None``, a new ``nn.CrossEntropyLoss()``
            is created for this call.
        validate_it: Validation interval forwarded to ``train.on_batch_end``.
        lr: Initial SGD learning rate (presumably overridden by the scheduler).
        momentum: SGD momentum.
        weight_decay: SGD L2 weight decay.

    Returns:
        The recorded log as a DataFrame with columns ``iteration``, ``trn_loss``,
        ``lr``, ``val_loss`` and ``val_acc``. This is the same data written to
        ``result_path``.
    """
    import os  # only needed to prepare the output directory

    # Create the output directory up front, so a missing directory fails in
    # seconds instead of after a long training run whose log would then be lost.
    out_dir = os.path.dirname(result_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if criterion is None:
        # Build the loss per call rather than once at definition time, so no
        # module instance is shared between experiments.
        criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(optim_groups, lr=lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = scheduler_init(optimizer)
    recorder = {
        'iteration': [],
        'trn_loss': [],
        'lr': [],
        'val_loss': [],
        'val_acc': []}
    obe = partial(train.on_batch_end,
                  recorder,
                  test_dl,
                  net,
                  criterion,
                  device,
                  validate_it=validate_it)
    train.train_run(net,
                    train_dl,
                    criterion,
                    optimizer,
                    scheduler,
                    iterations,
                    obe,
                    device)

    results = pd.DataFrame(recorder)
    results.to_csv(result_path, index=False)
    # DataFrame.to_csv returns None when given a path, so return the frame explicitly.
    return results

In [14]:
# EXPERIMENT: learning rate range test.
# A single 4000-iteration run under TriangularScheduler(4000, 0.001, 0.04), validated every
# 100 iterations, logged to lrrt.csv.
# NOTE(review): positional scheduler args presumed to be (stepsize, min_lr, max_lr) --
# confirm against schedulers.TriangularScheduler.
net = model.Cifar10Net_quick().to(device)
sched_init = partial(schedulers.TriangularScheduler, 4000, 0.001, 0.04)
run_experiment(train_dl=train_dl,
               test_dl=test_dl,
               net=net,
               optim_groups=net.parameters(),
               scheduler_init=sched_init,
               device=device,
               iterations=4000,
               result_path=f'{RESULT_DIR}/lrrt.csv',
               validate_it=100);

HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2.3102407455444336 | 2.3044705772399903 | 0.1
2.312624454498291 | 2.303440170288086 | 0.1
2.299867630004883 | 2.3018194746971132 | 0.1003
2.299769163131714 | 2.299508664608002 | 0.1189
2.2890985012054443 | 2.2889735388755796 | 0.1731
2.1724050045013428 | 2.1056038117408753 | 0.2212
1.9830414056777954 | 1.9662329626083375 | 0.2669
1.7505059242248535 | 1.8333210253715515 | 0.3211
1.8099440336227417 | 1.7154952216148376 | 0.3644
1.735664963722229 | 1.6707480144500733 | 0.3902
1.8090471029281616 | 1.6958602201938628 | 0.3747
1.4420520067214966 | 1.562908991575241 | 0.4226
1.5176358222961426 | 1.5288190817832947 | 0.4484
1.5646848678588867 | 1.4690062773227692 | 0.4626
1.5287458896636963 | 1.4254344367980958 | 0.4705
1.3944320678710938 | 1.3957298028469085 | 0.493
1.4541237354278564 | 1.482874664068222 | 0.4647
1.2937211990356445 | 1.3920288789272308 | 0.4978
1.4595838785171509 | 1.3885995137691498 | 0.4898
1.3977084159851074 | 1.3327751433849335 | 0.5184
1.4297860860824585 | 1.313413417339

In [17]:
# EXPERIMENT: triangular (CLR) learning-rate policy, five independent runs.
# Every run starts from a freshly constructed network and shares one scheduler factory;
# run_experiment builds a new optimizer and scheduler from it on each call.
sched_init = partial(schedulers.TriangularScheduler, 2000, 0.0025, 0.01)
for run in range(1, 6):
    net = model.Cifar10Net_quick().to(device)
    run_experiment(train_dl, test_dl, net, net.parameters(), sched_init, device,
                   35000, f'{RESULT_DIR}/triangular_{run}.csv')
               

# EXPERIMENT: fixed learning-rate policy, five independent runs.
# Each run trains ONE network through three consecutive fixed-LR stages and writes one
# CSV per stage: fixed_<stage>_<run>.csv.
# NOTE(review): run_experiment creates a new SGD optimizer per call, so momentum state
# is not carried across stage boundaries.
FIXED_STAGES = [
    # (stage id, learning rate, iterations)
    (1, 0.001, 60000),
    (2, 0.0001, 5000),
    (3, 0.00001, 5000),
]
for run in range(1, 6):
    net = model.Cifar10Net_quick().to(device)
    for stage, stage_lr, stage_iters in FIXED_STAGES:
        sched_init = partial(schedulers.FixedScheduler, stage_lr)
        run_experiment(train_dl, test_dl, net, net.parameters(), sched_init, device,
                       stage_iters, f'{RESULT_DIR}/fixed_{stage}_{run}.csv')


HBox(children=(IntProgress(value=0, max=35000), HTML(value='')))

2.3081281185150146 | 2.3051880550384523 | 0.1005
1.738964557647705 | 1.665504242181778 | 0.3855
1.4795106649398804 | 1.331455579996109 | 0.5143
1.150672435760498 | 1.1473946970701219 | 0.5882
1.0724025964736938 | 1.0152420073747634 | 0.6438
1.160057544708252 | 0.9966735970973969 | 0.654
0.9784924387931824 | 0.9519199401140213 | 0.6694
0.9737328886985779 | 0.8932129734754563 | 0.6839
0.7170023322105408 | 0.8222486644983291 | 0.7186
0.9279277920722961 | 0.8243024873733521 | 0.7127
0.8530652523040771 | 0.8690129601955414 | 0.6967
0.80401211977005 | 0.8342496061325073 | 0.7093
0.7014123797416687 | 0.7405118662118911 | 0.7455
0.6739166378974915 | 0.8270181936025619 | 0.7107
1.114944577217102 | 0.9178143709897995 | 0.6831
0.6370658278465271 | 0.7809011441469192 | 0.7254
0.6156260967254639 | 0.7235390922427177 | 0.7509
0.7376428842544556 | 0.7568728375434876 | 0.7392
0.6598946452140808 | 0.8181523317098618 | 0.7179
0.8681234121322632 | 0.7393771678209304 | 0.7427
0.6458835601806641 | 0.698226

HBox(children=(IntProgress(value=0, max=35000), HTML(value='')))

2.299347162246704 | 2.303637261390686 | 0.1
1.617915391921997 | 1.642737694978714 | 0.3965
1.225449800491333 | 1.4018867301940918 | 0.4846
1.0574604272842407 | 1.1851684534549713 | 0.5765
0.888020396232605 | 1.058115347623825 | 0.627
1.0745081901550293 | 1.0246241694688798 | 0.64
0.8646637201309204 | 0.9980749648809433 | 0.6453
0.9898565411567688 | 0.8737911486625671 | 0.7001
0.8038210272789001 | 0.8424463295936584 | 0.7096
0.8592265844345093 | 0.9025171422958373 | 0.6825
0.8175546526908875 | 0.880466188788414 | 0.6932
0.6128280758857727 | 0.8442606735229492 | 0.7032
0.6301299333572388 | 0.7580411154031753 | 0.7425
0.7640140652656555 | 0.7763451319932938 | 0.7337
0.772164523601532 | 0.8122978615760803 | 0.7227
0.8133676648139954 | 0.7431954419612885 | 0.7462
0.6177752017974854 | 0.7088285079598426 | 0.7549
0.7467721700668335 | 0.7654716563224793 | 0.7329
0.7681484818458557 | 0.8379352337121964 | 0.7118
0.6841251850128174 | 0.7114544740319252 | 0.7549
0.5848910808563232 | 0.704564609527

HBox(children=(IntProgress(value=0, max=35000), HTML(value='')))

2.2967615127563477 | 2.3065865063667297 | 0.1
1.7382959127426147 | 1.66662682056427 | 0.3861
1.4376815557479858 | 1.3496333837509156 | 0.5037
0.9594667553901672 | 1.1639399307966232 | 0.5821
1.0555710792541504 | 1.0355192428827287 | 0.637
0.9375876784324646 | 1.0330023556947707 | 0.6373
1.0066946744918823 | 0.9792000997066498 | 0.6525
0.9209896922111511 | 0.9121460723876953 | 0.68
0.8302682638168335 | 0.8245604473352433 | 0.712
0.7974468469619751 | 0.8463892287015915 | 0.7062
0.7657054662704468 | 0.8768961358070374 | 0.6993
0.6374652981758118 | 0.8050630861520767 | 0.7222
0.7283817529678345 | 0.766831705570221 | 0.731
0.5873028039932251 | 0.7840184801816941 | 0.7267
0.7146365642547607 | 0.8640239024162293 | 0.6952
0.7668939232826233 | 0.7430382764339447 | 0.7464
0.5502988696098328 | 0.7236514693498611 | 0.7435
0.7647444009780884 | 0.7737916851043701 | 0.7331
0.8036968111991882 | 0.800426282286644 | 0.7284
0.594533383846283 | 0.7221731323003769 | 0.7508
0.6039214730262756 | 0.6897313562

HBox(children=(IntProgress(value=0, max=35000), HTML(value='')))

2.3107714653015137 | 2.3037078833580016 | 0.1001
1.7871917486190796 | 1.6296509957313539 | 0.3891
1.397151231765747 | 1.4346000742912293 | 0.4785
1.3798251152038574 | 1.1770915579795838 | 0.5772
1.277231216430664 | 1.0596520459651948 | 0.6244
0.9499526023864746 | 0.9992859840393067 | 0.6465
0.9290933012962341 | 1.008832541704178 | 0.6426
0.9343844056129456 | 0.8648380535840988 | 0.7
0.8247926235198975 | 0.8192868012189866 | 0.7132
0.9912943840026855 | 0.8615369200706482 | 0.7007
0.773872435092926 | 0.8739808905124664 | 0.7023
0.6843559741973877 | 0.8040102541446685 | 0.7207
0.7088581919670105 | 0.7617915314435959 | 0.7373
0.6052320003509521 | 0.8647695463895798 | 0.6982
0.794687807559967 | 0.8799992960691452 | 0.7007
0.7306278944015503 | 0.7869658249616623 | 0.7191
0.6305683255195618 | 0.7061074459552765 | 0.7583
0.6098673343658447 | 0.7731593436002732 | 0.7361
0.8396649956703186 | 0.8128264397382736 | 0.7156
0.6211410164833069 | 0.7215045467019081 | 0.7487
0.5655218362808228 | 0.69371

HBox(children=(IntProgress(value=0, max=35000), HTML(value='')))

2.3086419105529785 | 2.305150496959686 | 0.1
1.6793400049209595 | 1.6421953332424164 | 0.3945
1.4035935401916504 | 1.4425472509860993 | 0.4875
1.2629897594451904 | 1.1422997510433197 | 0.5924
1.0646018981933594 | 1.054687920808792 | 0.626
1.0891112089157104 | 1.0199997323751449 | 0.6351
1.0268969535827637 | 0.979197204709053 | 0.6545
0.7591963410377502 | 0.9384381020069122 | 0.6754
0.7378935813903809 | 0.8171835106611252 | 0.7196
0.7479073405265808 | 0.9077901250123978 | 0.6814
0.8332498073577881 | 0.9133367121219635 | 0.6827
0.6751994490623474 | 0.8140503251552582 | 0.717
0.7222087383270264 | 0.7421411156654358 | 0.748
0.7132317423820496 | 0.8391111242771149 | 0.7118
0.7766240835189819 | 0.8012507647275925 | 0.726
0.5493375658988953 | 0.775205384194851 | 0.7324
0.6715793013572693 | 0.7169601884484291 | 0.7526
0.574696958065033 | 0.7526003652811051 | 0.7418
0.8064730763435364 | 0.780873510837555 | 0.7313
0.7076030969619751 | 0.7391607218980789 | 0.7438
0.647723913192749 | 0.69023797929

HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))

2.302161455154419 | 2.303682413101196 | 0.1
2.2194292545318604 | 2.196604869365692 | 0.1935
1.8911033868789673 | 1.829579312801361 | 0.3122
1.642427682876587 | 1.6669208014011383 | 0.3839
1.5316646099090576 | 1.5719902324676513 | 0.4163
1.667109727859497 | 1.4963288640975951 | 0.4483
1.5728968381881714 | 1.4321515989303588 | 0.4758
1.5480401515960693 | 1.4284981632232665 | 0.4872
1.3230552673339844 | 1.3662333154678346 | 0.5051
1.327164888381958 | 1.32965864777565 | 0.5243
1.389552116394043 | 1.2690277636051177 | 0.5461
1.0710490942001343 | 1.2374437129497529 | 0.5555
1.0825456380844116 | 1.221505137681961 | 0.5676
1.1476000547409058 | 1.166151067018509 | 0.5858
1.1713985204696655 | 1.151418097615242 | 0.5942
1.0483144521713257 | 1.1111010801792145 | 0.607
1.0236915349960327 | 1.0919683080911637 | 0.6183
1.0846236944198608 | 1.0807886445522308 | 0.6175
1.1846979856491089 | 1.0579353016614914 | 0.6304
0.9463834166526794 | 1.0462534660100937 | 0.6321
0.8931590914726257 | 1.01630709409713

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.480459064245224 | 0.7445696312189102 | 0.7491
0.516143798828125 | 0.7163516545295715 | 0.7592
0.7367179989814758 | 0.7143551474809646 | 0.76
0.5775317549705505 | 0.7142197108268737 | 0.7599
0.49819809198379517 | 0.7134304809570312 | 0.759
0.6096295714378357 | 0.7139431044459343 | 0.7589


HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.47544652223587036 | 0.7139426523447037 | 0.7588
0.6663917303085327 | 0.7109513586759567 | 0.7618
0.7243687510490417 | 0.7107250562310219 | 0.7601
0.49908941984176636 | 0.7107535094022751 | 0.7608
0.782410740852356 | 0.7105361387133599 | 0.7617
0.5956037640571594 | 0.7106218779087067 | 0.761


HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))

2.3125743865966797 | 2.3039981579780577 | 0.1
2.2647597789764404 | 2.2585552096366883 | 0.1716
1.8907458782196045 | 1.9379895663261413 | 0.272
1.7592915296554565 | 1.7582949686050415 | 0.3444
1.6241341829299927 | 1.6454595613479615 | 0.393
1.5875000953674316 | 1.576356165409088 | 0.4164
1.5342645645141602 | 1.5195362389087677 | 0.4368
1.6605738401412964 | 1.468752954006195 | 0.4606
1.319650411605835 | 1.414180474281311 | 0.4783
1.3844164609909058 | 1.371578073501587 | 0.491
1.4671516418457031 | 1.3303797924518586 | 0.5139
1.2029234170913696 | 1.2879391384124756 | 0.5323
1.2779635190963745 | 1.253465923666954 | 0.549
1.17975914478302 | 1.2369444274902344 | 0.5591
1.1509382724761963 | 1.2030966192483903 | 0.5758
1.183713436126709 | 1.2046618050336837 | 0.5702
1.1943981647491455 | 1.136709355711937 | 0.5972
1.295949101448059 | 1.1217266088724136 | 0.5998
1.1120951175689697 | 1.0926160448789597 | 0.6131
1.0200469493865967 | 1.0735883218050004 | 0.6154
0.9997021555900574 | 1.081205865740776

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.7536212205886841 | 0.7232513117790222 | 0.7542
0.7806680202484131 | 0.705348442196846 | 0.7584
0.5248857140541077 | 0.703777636885643 | 0.7594
0.5718789100646973 | 0.7068646782636643 | 0.7569
0.5383149981498718 | 0.7025806814432144 | 0.7584
0.7145435214042664 | 0.701644839644432 | 0.7585


HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.7400761246681213 | 0.7016532534360885 | 0.7584
0.6493908762931824 | 0.7000586172938347 | 0.7607
0.7983807921409607 | 0.7002582937479019 | 0.7601
0.6366581916809082 | 0.7003783422708512 | 0.7583
0.680396318435669 | 0.7000463408231735 | 0.7592
0.6009761691093445 | 0.7004066002368927 | 0.7588


HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))

2.2893190383911133 | 2.305589187145233 | 0.1109
2.117851734161377 | 2.1278552532196047 | 0.2106
1.7930207252502441 | 1.865374310016632 | 0.3151
1.5930087566375732 | 1.7053271186351777 | 0.3637
1.4978039264678955 | 1.590683308839798 | 0.4053
1.5037145614624023 | 1.518517769575119 | 0.433
1.4995872974395752 | 1.4822452425956727 | 0.4553
1.3237515687942505 | 1.4106175196170807 | 0.48
1.392795205116272 | 1.3448641657829286 | 0.5074
1.407563328742981 | 1.3059576058387756 | 0.5266
1.1862844228744507 | 1.271847310066223 | 0.5384
1.3045339584350586 | 1.2751716578006744 | 0.5419
1.2036088705062866 | 1.2160215497016906 | 0.5661
1.202947735786438 | 1.1749471193552017 | 0.5805
1.150506854057312 | 1.1608082675933837 | 0.5878
0.9414973258972168 | 1.1261319673061372 | 0.5995
1.2323987483978271 | 1.1276207000017167 | 0.5974
1.0142583847045898 | 1.0955677497386933 | 0.6085
1.0193076133728027 | 1.0668160605430603 | 0.621
1.081117033958435 | 1.0587258231639862 | 0.6289
1.0626050233840942 | 1.035202063918

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.6400810480117798 | 0.7470376598834991 | 0.7413
0.6437246203422546 | 0.7172367799282074 | 0.7541
0.6993143558502197 | 0.7131546598672867 | 0.7552
0.5067879557609558 | 0.713929348886013 | 0.7561
0.5510636568069458 | 0.7151313328742981 | 0.7559
0.8361061811447144 | 0.7156076937913894 | 0.7544


HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.514895498752594 | 0.7156133061647415 | 0.7544
0.679033637046814 | 0.711599503159523 | 0.7536
0.776516854763031 | 0.7112180674076081 | 0.7537
0.631237268447876 | 0.7111503225564957 | 0.7539
0.6389885544776917 | 0.7111429086327553 | 0.7535
0.5790529847145081 | 0.7111992639303207 | 0.7535


HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))

2.2942497730255127 | 2.304554522037506 | 0.1
2.275017023086548 | 2.2805168795585633 | 0.1906
1.9132020473480225 | 1.9282059252262116 | 0.3029
1.7522635459899902 | 1.6925348031520844 | 0.3743
1.8402700424194336 | 1.5746035492420196 | 0.4157
1.6396420001983643 | 1.5177902817726134 | 0.4387
1.5997252464294434 | 1.4513669657707213 | 0.4659
1.5713578462600708 | 1.4414219987392425 | 0.4762
1.3311092853546143 | 1.362903118133545 | 0.5051
1.3911932706832886 | 1.312984845638275 | 0.5258
1.2477164268493652 | 1.2973519742488862 | 0.5316
1.377392292022705 | 1.2451799380779267 | 0.5551
1.182533621788025 | 1.221853305697441 | 0.5638
1.1561472415924072 | 1.18661379635334 | 0.5821
1.2055118083953857 | 1.1569846081733703 | 0.5924
1.2438620328903198 | 1.1356801825761795 | 0.6003
0.9588149189949036 | 1.1636093878746032 | 0.5893
1.0649280548095703 | 1.107153869867325 | 0.6093
1.0031914710998535 | 1.0985066014528275 | 0.6075
1.453597903251648 | 1.0640332835912705 | 0.6219
1.067983865737915 | 1.041152359843

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.5727329254150391 | 0.7625291085243225 | 0.7376
0.5825637578964233 | 0.7100369697809219 | 0.7588
0.5754715204238892 | 0.7080162507295609 | 0.7589
0.5785784721374512 | 0.7117397880554199 | 0.7581
0.5773131847381592 | 0.708684373497963 | 0.7591
0.6024155616760254 | 0.7072399196028709 | 0.7594


HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.7596209645271301 | 0.7072223627567291 | 0.7594
0.500663697719574 | 0.7052286583185196 | 0.761
0.6125845909118652 | 0.7053691473603249 | 0.762
0.5493850111961365 | 0.7054248493909836 | 0.7616
0.6369531154632568 | 0.7053566789627075 | 0.7615
0.6935456991195679 | 0.7051638948917389 | 0.7615


HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))

2.3076300621032715 | 2.3058548402786254 | 0.1
2.2786638736724854 | 2.2763002848625185 | 0.2151
1.9515702724456787 | 1.9702565634250642 | 0.2841
1.8182286024093628 | 1.753011659383774 | 0.356
1.7078303098678589 | 1.6215591216087342 | 0.4
1.580376386642456 | 1.539909280538559 | 0.4288
1.4411572217941284 | 1.4893878173828126 | 0.4534
1.409519910812378 | 1.4303413724899292 | 0.4755
1.167982578277588 | 1.3637947535514832 | 0.5047
1.3845378160476685 | 1.329739637374878 | 0.5206
1.263371467590332 | 1.3070685338974 | 0.5289
1.2464780807495117 | 1.291545866727829 | 0.5382
1.2362220287322998 | 1.2329987275600434 | 0.5595
1.402095913887024 | 1.1989743721485138 | 0.5714
1.1734600067138672 | 1.1551129853725433 | 0.5921
1.1591438055038452 | 1.1297784858942033 | 0.6005
1.109192132949829 | 1.1020210641622543 | 0.611
1.173784613609314 | 1.1164515542984008 | 0.6011
1.1502082347869873 | 1.0843731743097305 | 0.616
0.9079688787460327 | 1.0691058397293092 | 0.6232
1.0360027551651 | 1.0437082797288895 | 0.63

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.6413182020187378 | 0.734231016933918 | 0.748
0.5960551500320435 | 0.7112665700912476 | 0.7564
0.6096856594085693 | 0.7114290902018547 | 0.7555
0.6046786308288574 | 0.7088432914018631 | 0.7577
0.57435542345047 | 0.7097105646133423 | 0.7573
0.6621636152267456 | 0.7108620977401734 | 0.7566


HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.5907119512557983 | 0.71086072742939 | 0.7567
0.5895251035690308 | 0.706057899594307 | 0.7585
0.6198488473892212 | 0.7063105863332748 | 0.7577
0.6500656008720398 | 0.7058626517653466 | 0.7583
0.6162900328636169 | 0.7059858545660973 | 0.7596
0.7409528493881226 | 0.7061109763383865 | 0.7599
