In [1]:
import pytorch_lightning as pl
from src import Model, DataModule

In [2]:
# Edge length used for both crop dimensions and stored under config['size'].
size = 256

# Single experiment configuration. The same mapping is unpacked into the
# DataModule and passed whole to the Model further down in the notebook.
config = dict(
    # --- optimization ---
    # lr equals the value displayed for lr_finder.suggestion() at the end of
    # this notebook.
    lr=0.0009120108393559097,
    optimizer='Adam',
    batch_size=64,
    # --- data ---
    extra_data=1,   # truthy -> the 'data_extra' file is selected when building the DataModule
    subset=0.1,     # presumably the fraction of training samples to keep -- confirm in DataModule
    num_workers=4,
    pin_memory=True,
    # --- model ---
    backbone='efficientnet_b2a',
    pretrained=True,
    unfreeze=0,
    # --- data augmentation ---
    size=size,
    # Each entry maps a transform name to its keyword arguments; an empty
    # dict means no arguments are supplied here.
    train_trans=dict(
        RandomCrop=dict(height=size, width=size),
        HorizontalFlip={},
        VerticalFlip={},
        Normalize={},
    ),
    val_trans=dict(
        CenterCrop=dict(height=size, width=size),
        Normalize={},
    ),
    # --- training params ---
    precision=16,
    max_epochs=50,
    val_batches=5,   # forwarded to the Trainer as limit_val_batches
    es_start_from=0,
)

In [3]:
# Choose the dataset file from the 'extra_data' flag, then forward every
# config entry to the DataModule as keyword arguments.
data_file = 'data_extra' if config['extra_data'] else 'data_old'
dm = DataModule(file=data_file, **config)

# The model receives the same configuration as a single dict.
model = Model(config)

In [15]:

# Trainer configured for the batch-size search: 'binsearch' doubles the batch
# size until a trial fails and then bisects between the last success and the
# failure (see the log printed below this cell).
trainer = pl.Trainer(
    auto_scale_batch_size='binsearch',
    limit_val_batches=config['val_batches'],
    precision=config['precision'],
    gpus=1,
)

# Run the tuning step on the model with data supplied by the DataModule.
trainer.tune(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training samples:  21642
Validation samples:  5411
Training only on 2165 samples


Batch size 2 succeeded, trying batch size 4
Batch size 4 succeeded, trying batch size 8
Batch size 8 succeeded, trying batch size 16
Batch size 16 succeeded, trying batch size 32
Batch size 32 succeeded, trying batch size 64
Batch size 64 succeeded, trying batch size 128
Batch size 128 succeeded, trying batch size 256
Batch size 256 failed, trying batch size 192
Batch size 192 failed, trying batch size 160
Batch size 160 failed, trying batch size 144
Batch size 144 succeeded, trying batch size 152
Batch size 152 failed, trying batch size 148
Batch size 148 failed, trying batch size 146
Batch size 146 failed, trying batch size 145
Batch size 145 failed, trying batch size 144
Finished batch size finder, will continue with full run using batch size 144


In [10]:
# Return to the configured batch size instead of whatever the batch-size
# search may have left on the model (that search finished at 144). Reading the
# value from `config` keeps one source of truth rather than repeating the
# literal here.
# NOTE(review): only the model's hparams are updated; confirm whether the
# DataModule keeps its own batch_size that needs the same reset.
model.hparams.batch_size = config['batch_size']

# Display the resulting hyper-parameters for a visual check.
model.hparams

"backbone":      efficientnet_b2a
"batch_size":    64
"es_start_from": 0
"extra_data":    1
"lr":            0.0003
"max_epochs":    50
"num_workers":   4
"optimizer":     Adam
"pin_memory":    True
"precision":     16
"pretrained":    True
"size":          256
"subset":        0.1
"train_trans":   {'RandomCrop': {'height': 256, 'width': 256}, 'HorizontalFlip': {}, 'VerticalFlip': {}, 'Normalize': {}}
"unfreeze":      0
"val_batches":   5
"val_trans":     {'CenterCrop': {'height': 256, 'width': 256}, 'Normalize': {}}

In [11]:
# Fresh trainer for the learning-rate search; it replaces the one used for the
# batch-size search.
# NOTE(review): the finder is invoked explicitly through trainer.tuner.lr_find
# in the next cell -- confirm whether auto_lr_find=True is still required.
trainer = pl.Trainer(
    auto_lr_find=True,
    limit_val_batches=config['val_batches'],
    precision=config['precision'],
    gpus=1,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


In [12]:
# Run the learning-rate finder on the model with data from the DataModule and
# keep the returned object; it is queried below via .results and .suggestion().
# In the recorded run the sweep stopped early because the loss diverged.
lr_finder = trainer.tuner.lr_find(model, dm)


  | Name     | Type                 | Params
--------------------------------------------------
0 | backbone | EfficientNetFeatures | 7.2 M 
1 | head     | Sequential           | 1.8 K 
--------------------------------------------------
7.2 M     Trainable params
0         Non-trainable params
7.2 M     Total params


HBox(children=(FloatProgress(value=0.0, description='Finding best initial lr', style=ProgressStyle(description…

LR finder stopped early due to diverging loss.


In [7]:
# Show the raw sweep data recorded by the finder (a dict; the output below
# starts with an 'lr' list of the learning rates that were tried).
# NOTE(review): this prints every recorded value -- consider a plot or a
# truncated view to keep the notebook output short.
lr_finder.results

{'lr': [1e-08,
  1.4454397707459274e-08,
  1.7378008287493753e-08,
  2.0892961308540398e-08,
  2.51188643150958e-08,
  3.019951720402016e-08,
  3.630780547701014e-08,
  4.36515832240166e-08,
  5.248074602497726e-08,
  6.309573444801934e-08,
  7.585775750291837e-08,
  9.120108393559096e-08,
  1.0964781961431852e-07,
  1.3182567385564074e-07,
  1.5848931924611133e-07,
  1.9054607179632475e-07,
  2.2908676527677735e-07,
  2.7542287033381663e-07,
  3.311311214825911e-07,
  3.9810717055349735e-07,
  4.786300923226383e-07,
  5.75439937337157e-07,
  6.918309709189366e-07,
  8.317637711026709e-07,
  1e-06,
  1.2022644346174132e-06,
  1.445439770745928e-06,
  1.7378008287493761e-06,
  2.089296130854039e-06,
  2.5118864315095797e-06,
  3.0199517204020163e-06,
  3.630780547701014e-06,
  4.365158322401661e-06,
  5.248074602497728e-06,
  6.3095734448019305e-06,
  7.585775750291836e-06,
  9.120108393559096e-06,
  1.0964781961431852e-05,
  1.3182567385564076e-05,
  1.584893192461114e-05,
  1.90546071

In [8]:
# Use the finder's automatic suggestion as the new learning rate.
# Alternative: choose a value by eye from a plot of the sweep instead.
new_lr = lr_finder.suggestion()

In [9]:
# Display the suggested learning rate.
new_lr

0.0009120108393559097