# Best Augmentations Tool
Based on this article: https://platform.ai/blog/page/2/finding-useful-augmentations-with-minimal-use-of-compute/

This documentation assumes you have read the article and are working on a classification problem.

For other types of problems, the code should be straightforward to adapt. In particular, pay attention to the section titled 'Calculate TTA Error Rates for Augmentation Tests', where you will likely need to change the metric.

**Every numbered step probably requires you to do something. Unnumbered steps only need to be read if you want to understand what's going on.**

Note that the method implemented in this notebook has given some good results in practice, but, as mentioned in the article, has not been tested enough to determine whether these results are due to chance.

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *

## 1. Define Path to Data

In [3]:
def get_path(woof, size):
    if   size<=128: path = URLs.IMAGEWOOF_160 if woof else URLs.IMAGENETTE_160
    elif size<=224: path = URLs.IMAGEWOOF_320 if woof else URLs.IMAGENETTE_320
    else          : path = URLs.IMAGEWOOF     if woof else URLs.IMAGENETTE
    return untar_data(path)

woof = True
size = 128

path = get_path(woof=woof, size=size); path

PosixPath('/home/dc/.fastai/data/imagewoof-160')

## 2. Define `get_data`

It should accept transforms in the style of `get_transforms()`, i.e. a pair of lists (training transforms, validation transforms), and return a `DataBunch`.

It should also create the same validation set every time.
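In the cell below, `split_by_rand_pct(.2, seed=42)` is what makes the validation set repeatable. The idea can be sketched in plain Python with a hypothetical helper `seeded_split` (this is not fastai's implementation, which uses NumPy's random state): shuffling indices with a fixed seed always yields the same partition, no matter which transforms are passed in later.

```python
import random

def seeded_split(items, valid_pct=0.2, seed=42):
    """Hypothetical helper: split `items` into (train, valid) reproducibly.

    A private random.Random(seed) instance is used so the split neither
    depends on nor disturbs the global random state.
    """
    idxs = list(range(len(items)))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * len(items))
    valid_idxs = set(idxs[:cut])
    train = [x for i, x in enumerate(items) if i not in valid_idxs]
    valid = [x for i, x in enumerate(items) if i in valid_idxs]
    return train, valid

files = [f"img_{i}.jpg" for i in range(10)]
split_a = seeded_split(files)
split_b = seeded_split(files)
print(split_a[1] == split_b[1])  # True: same seed, same validation set
```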

In [4]:
def make_data(path, size, bs=64, tfms=None, workers=None):
    if workers is None: workers = min(8, num_cpus())
    return (ImageList.from_folder(path)
            .split_by_rand_pct(.2, seed=42)
            .label_from_folder()
            .transform(tfms, size=size)
            .databunch(bs=bs, num_workers=workers)
            .presize(size, scale=(0.35,1))
            .normalize(imagenet_stats))

# just add transforms!
get_data = partial(make_data, path, size)

## 3. Define `get_learner`

Define a function to make a learner.

In [5]:
def get_learner(data, pretrained=True, load_fn=None, model_fn='modelbest', csv_fn='history'):
    callback_fns = [partial(SaveModelCallback, name=model_fn),
                    partial(CSVLogger, filename=csv_fn)]
    learn = cnn_learner(data, models.resnet34, metrics=error_rate, pretrained=pretrained, 
                        callback_fns=callback_fns)
    if load_fn: learn = learn.load(load_fn)
    return learn

## 4. Define Transforms

Define which transforms to test using `tfm_dict`.

The following code block defines the same transforms as the article.

This method of creating tests is flexible, but rather complicated. For simple tests, it may be easier to write your own code that generates similarly structured output. In the end, `make_tfms` is just creating lists of transforms organized by transform type. If it's not apparent how you should modify `tfm_dict` to define the tests you want to perform, try running this cell as is, then inspecting the output of `make_tfms`. 

`make_tfms` is used again in the section titled 'Calculate TTA Error Rates for Augmentation Tests'.
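As a rough illustration of what `make_tfms` returns, here is a dependency-free sketch that expands a `tfm_dict`-shaped dictionary into `{name: [[func, params], ...]}`. Plain strings stand in for the fastai transform functions, and `expand_tfm_dict` is a hypothetical name; it uses `itertools.product` instead of the notebook's `np.meshgrid` approach, but should enumerate the same combinations.

```python
from itertools import product

def expand_tfm_dict(tfm_dict):
    """Hypothetical stand-in for `make_tfms`: {name: [[func, params], ...]}."""
    out = {}
    for name, (funcs, params) in tfm_dict.items():
        funcs = funcs if isinstance(funcs, list) else [funcs]
        names = list(params)
        # product() over zero iterables yields one empty tuple, i.e. params == {}
        combos = [dict(zip(names, vals)) for vals in product(*(params[n] for n in names))]
        out[name] = [[f, combo] for f in funcs for combo in combos]
    return out

# Strings stand in for fastai transforms such as flip_lr or skew
toy = {
    'flip': [['flip_lr', 'dihedral_affine'], {}],
    'skew': ['skew', {'direction': [(0, 0), (7, 7)], 'magnitude': [.2, .4]}],
}
expanded = expand_tfm_dict(toy)
print('flip', expanded['flip'])  # flip [['flip_lr', {}], ['dihedral_affine', {}]]
print('skew', len(expanded['skew']), 'combinations')
```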

In [6]:
# For each key in `tfm_dict`, this code block will produce all combinations
#    of aug_func(s) and parameter values

### Structure of `tfm_dict`
###
#   {'printable_tfm_name': 
#         [aug_func or [aug_funcs],
#          {'parameter_name1': [parameter1 values],
#           'parameter_name2': [parameter2 values]}
#         ]},
#    'next_tfm': ...
#   }
###

### For convenience, here are the transforms defined in fastai:

# all_tfms = ['brightness', 'contrast', 'crop', 'crop_pad', 'cutout', 'dihedral', 
#             'dihedral_affine', 'flip_affine', 'flip_lr', 'get_transforms', 'jitter', 
#             'pad', 'perspective_warp', 'rand_pad', 'rand_crop', 'rand_zoom', 
#             'rgb_randomize', 'rotate', 'skew', 'squish', 'rand_resize_crop', 
#             'symmetric_warp', 'tilt', 'zoom', 'zoom_crop']

tfm_dict = {'flip': [[flip_lr, dihedral_affine], {}], 
            'symmetric_warp': [symmetric_warp, 
                               {'magnitude': [(-m, m) for m in [.1, .2, .3, .4]]}], 
            'rotate': [rotate, 
                       {'degrees': [(-m, m) for m in [10., 20., 30., 40.]]}],
            'zoom': [zoom, 
                     {'scale': [(1., max_zoom) for max_zoom in [1.1, 1.2, 1.3, 1.4]]}], 
            'brightness': [brightness, 
                           {'change': [(.5*(1-max_lighting), .5*(1+max_lighting))
                                       for max_lighting in [.1, .2, .3, .4]]}], 
            'contrast': [contrast, 
                         {'scale': [(1-max_lighting, 1/(1-max_lighting))
                                    for max_lighting in [.1, .2, .3, .4]]}],
            'skew': [skew, 
                     {'direction': [(0, 0),(7, 7)], 
                      'magnitude': [max_skew for max_skew in [.2, .4, .6, .8]]}],
            'squish': [squish, 
                       {'scale': [(1/max_scale, max_scale) 
                                  for max_scale in [1.2, 1.8, 2.4, 3.]]}],
            'rand_pad_crop': [partial(rand_pad, size=size),
                              {'padding': [size/16, size/8, size/4]}]}

def generate_param_dict(names, values):
    """
    Returns a list of parameter dictionaries for all combinations of parameter values.
    
    Parameters:
      names  - list of parameter names           (e.g. ['direction', 'magnitude'] )
      values - list of values for each parameter (e.g. [[(0,0),(7,7)],  [.2, .4, .6, .8]] )
    """
    # Each row of `grid` holds one index into every parameter's list of values
    grid_shape = [range(len(lst)) for lst in values]
    grid = np.array(np.meshgrid(*grid_shape)).T.reshape(-1, len(names))
    combinations_d = [{names[n_ix]: values[n_ix][v_ix] for n_ix, v_ix in enumerate(row)}
                      for row in grid]
    return combinations_d

def make_tfms(tfm_dict, pretty_print=False):
    i = 0
    tfms = defaultdict(list)
    for name, tfm_info in tfm_dict.items():
        if pretty_print: print(f'{name}: ')
        sub_tfms = listify(tfm_info[0])
        params = tfm_info[1]
        for sub_tfm in sub_tfms:
            if len(params) == 0:
                i += 1
                if pretty_print: print(f'  {i}.  {sub_tfm}: <no params>')
                tfms[name].append([sub_tfm, {}])
            else:
                param_names = [k for k in params.keys()]
                param_values = [params[n] for n in param_names]
                value_combos = generate_param_dict(param_names, param_values)
                for combo in value_combos:
                    i += 1
                    if pretty_print: print(f'  {i}.  {sub_tfm}:  {combo}')
                    tfms[name].append([sub_tfm, combo])
    return dict(tfms)

make_tfms(tfm_dict, pretty_print=True);

flip: 
  1.  TfmPixel (flip_lr): <no params>
  2.  TfmAffine (dihedral_affine): <no params>
symmetric_warp: 
  3.  TfmCoord (symmetric_warp):  {'magnitude': (-0.1, 0.1)}
  4.  TfmCoord (symmetric_warp):  {'magnitude': (-0.2, 0.2)}
  5.  TfmCoord (symmetric_warp):  {'magnitude': (-0.3, 0.3)}
  6.  TfmCoord (symmetric_warp):  {'magnitude': (-0.4, 0.4)}
rotate: 
  7.  TfmAffine (rotate):  {'degrees': (-10.0, 10.0)}
  8.  TfmAffine (rotate):  {'degrees': (-20.0, 20.0)}
  9.  TfmAffine (rotate):  {'degrees': (-30.0, 30.0)}
  10.  TfmAffine (rotate):  {'degrees': (-40.0, 40.0)}
zoom: 
  11.  TfmAffine (zoom):  {'scale': (1.0, 1.1)}
  12.  TfmAffine (zoom):  {'scale': (1.0, 1.2)}
  13.  TfmAffine (zoom):  {'scale': (1.0, 1.3)}
  14.  TfmAffine (zoom):  {'scale': (1.0, 1.4)}
brightness: 
  15.  TfmLighting (brightness):  {'change': (0.45, 0.55)}
  16.  TfmLighting (brightness):  {'change': (0.4, 0.6)}
  17.  TfmLighting (brightness):  {'change': (0.35, 0.65)}
  18.  TfmLighting (brightness):  
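`generate_param_dict` builds its index grid with `np.meshgrid(...).T.reshape(...)`. The sketch below, which assumes NumPy is installed, compares that grid with `itertools.product` over the same index ranges: for one or two parameters the rows come out in exactly the order `product` gives, and with three or more the order differs but every combination still appears exactly once.

```python
from itertools import product
import numpy as np

def meshgrid_rows(lengths):
    """Index grid built the same way as in `generate_param_dict`."""
    grid_shape = [range(n) for n in lengths]
    grid = np.array(np.meshgrid(*grid_shape)).T.reshape(-1, len(lengths))
    return [tuple(int(v) for v in row) for row in grid]

def product_rows(lengths):
    """Reference enumeration: the Cartesian product of the index ranges."""
    return list(product(*(range(n) for n in lengths)))

print(meshgrid_rows([2, 4]) == product_rows([2, 4]))                   # True
print(sorted(meshgrid_rows([2, 3, 4])) == product_rows([2, 3, 4]))     # True
```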

## Define a Few Helper Functions

Feel free to skip over this part.

In [7]:
def error_rate_of_best_valid_loss(path, fn):
    """Returns the error_rate for the minimum validation loss."""
    end = '' if fn.endswith('.csv') else '.csv'
    df = pd.read_csv(path/f'{fn}{end}')
    idx = df.valid_loss.idxmin()   # row label of the lowest validation loss
    return df.loc[idx].error_rate
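`error_rate_of_best_valid_loss` reports the error rate from the epoch with the lowest validation loss, which is the checkpoint `SaveModelCallback` keeps when monitoring `valid_loss`. A pandas-free sketch of the same lookup, using a hypothetical helper `best_epoch_error_rate` and the training log printed for the no-augmentation run further down:

```python
import csv
import io

def best_epoch_error_rate(csv_text):
    """Hypothetical helper: error_rate of the row with the minimum valid_loss."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    best = min(rows, key=lambda r: float(r['valid_loss']))
    return float(best['error_rate'])

# Training log of the no-augmentation head run, as printed in this notebook
log = """epoch,train_loss,valid_loss,error_rate,time
0,0.850775,0.44907,0.145174,00:11
1,0.568087,0.331373,0.109266,00:09
2,0.419854,0.305615,0.102703,00:09
"""
print(best_epoch_error_rate(log))  # 0.102703
```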

In [8]:
def _tta_only_w_tfms(learn:Learner, tfms:list, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None) -> Iterator[List[Tensor]]:
    """
    Computes the outputs for several augmented inputs for TTA.
    Compare to `_tta_only` in fastai/vision/tta.py
    """
    from fastai.basic_train import _loss_func2activ
    
    dl = learn.dl(ds_type)
    ds = dl.dataset
    old = ds.tfms                # remember the dataset's own transforms
    activ = ifnone(activ, _loss_func2activ(learn.loss_func))
    tfms = listify(tfms)         # accept None, a single transform, or a list
    try:
        pbar = master_bar(range(8))
        for _ in pbar:
            # 8 passes; random transforms are re-sampled on every pass
            ds.tfms = tfms
            yield get_preds(learn.model, dl, pbar=pbar, activ=activ)[0]
    finally: ds.tfms = old       # always restore the original transforms
        
def _TTA_w_tfms(learn:Learner, tfms:list, beta:float=0.4, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None) -> Tensors:
    """
    Applies TTA to predict on `ds_type` dataset.
    Compare to `_TTA` in fastai/vision/tta.py
    """
    preds,y = learn.get_preds(ds_type, activ=activ)
    all_preds = list(_tta_only_w_tfms(learn, tfms, ds_type=ds_type, activ=activ))
    avg_preds = torch.stack(all_preds).mean(0)
    if beta is None: return preds,avg_preds,y
    else:
        final_preds = preds*beta + avg_preds*(1-beta)
        return final_preds, y

# Monkey-patch Learner
Learner.TTA_with_tfms = _TTA_w_tfms

In [9]:
def train(data, learn_func, epochs_head, epochs_full, 
          lr_head=3e-3, lr_full=slice(3e-5, 3e-4),
          model_fn='modelbest', csv_fn='history'):
    learn = learn_func(data, model_fn=f'{model_fn}_head', csv_fn=f'{csv_fn}_head')
    if epochs_head:
        learn.freeze()
        learn.fit_one_cycle(epochs_head, lr_head)
    if epochs_full:
        learn = learn_func(data, model_fn=f'{model_fn}_full', csv_fn=f'{csv_fn}_full')
        learn.load(f'{model_fn}_head')
        learn.unfreeze()
        learn.fit_one_cycle(epochs_full, lr_full)

## 5. Define Test Hyperparameters

The article doesn't mention what values were used for these. You will likely get better results by adjusting these.

In [10]:
EPOCHS_HEAD = 3         # Number of epochs to train the model head
EPOCHS_FULL = 5         # Number of epochs to train the full model
THRESHOLD = 0.99        # Exclude transforms which don't produce an error rate less than THRESHOLD*ERR_NONE
BETA = 0.4              # called WEIGHT_UNTRANSFORMED in the article

## Fine-Tune with No Data Augmentation

Train the last layer group on 80% of the training set for EPOCHS_HEAD epochs, without any data augmentation.

In [11]:
data = get_data()
csv_fn = 'history_no_aug'
model_fn = 'no_aug'
train(data, get_learner, EPOCHS_HEAD, 0, model_fn=model_fn, csv_fn=csv_fn)

epoch,train_loss,valid_loss,error_rate,time
0,0.850775,0.44907,0.145174,00:11
1,0.568087,0.331373,0.109266,00:09
2,0.419854,0.305615,0.102703,00:09


Better model found at epoch 0 with valid_loss value: 0.44907015562057495.
Better model found at epoch 1 with valid_loss value: 0.33137258887290955.
Better model found at epoch 2 with valid_loss value: 0.30561521649360657.


## Calculate Error Rate

Calculate the error rate ERR_NONE on the remaining 20% of the training set.

In [12]:
ERR_NONE = error_rate_of_best_valid_loss(data.path, f'{csv_fn}_head'); ERR_NONE

0.102703

## Calculate TTA Error Rates for Augmentation Tests

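For each kind of transformation, for each possible magnitude, calculate the TTA error rate on the remaining 20% of the training set. TTA predictions are computed as `BETA*preds_without_tfms + (1 - BETA)*avg_tta_preds`, where the predictions are the activations returned by `get_preds` (softmax probabilities for a cross-entropy loss), not raw logits.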
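The blending rule above can be sketched without fastai. In this toy version, all names are hypothetical: `predict` maps a batch of inputs to per-class probabilities and `augment` applies one random transform. As in `_tta_only_w_tfms`, the model is run on 8 augmented passes, their predictions are averaged, and the average is mixed with the un-augmented predictions using `beta`.

```python
import random

def tta_blend(predict, augment, inputs, beta=0.4, n_passes=8, seed=0):
    """Toy TTA: beta * plain + (1 - beta) * mean over augmented passes."""
    rng = random.Random(seed)
    plain = predict(inputs)
    passes = [predict([augment(x, rng) for x in inputs]) for _ in range(n_passes)]
    blended = []
    for i, p in enumerate(plain):
        avg = [sum(run[i][c] for run in passes) / n_passes for c in range(len(p))]
        blended.append([beta * pc + (1 - beta) * ac for pc, ac in zip(p, avg)])
    return blended

def toy_error_rate(probs, targets):
    """Fraction of inputs whose highest-probability class is wrong."""
    wrong = sum(max(range(len(p)), key=p.__getitem__) != t for p, t in zip(probs, targets))
    return wrong / len(targets)

def predict(xs):
    """Toy 2-class model: probability of class 1 is x clamped to [0, 1]."""
    return [[1 - min(max(x, 0.0), 1.0), min(max(x, 0.0), 1.0)] for x in xs]

def augment(x, rng):
    """Toy augmentation: a small random shift of the input."""
    return x + rng.uniform(-0.1, 0.1)

inputs, targets = [0.2, 0.45, 0.9], [0, 1, 1]
print(toy_error_rate(tta_blend(predict, augment, inputs), targets))
```

With `beta=1` the augmented passes are ignored entirely; with `beta=0` only their average counts.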

**Note: running this cell may take a while, depending on how many TTA tests there are to run**
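To estimate the cost of the next cell, count the tests: each test costs one un-augmented prediction pass plus 8 augmented passes over the validation set (see `_TTA_w_tfms`). For `tfm_dict`, an entry contributes (number of aug funcs) × (product of the number of values per parameter) tests. A small sketch with those shapes written out by hand (`tfm_shapes` is a hypothetical name):

```python
from math import prod

# (number of aug funcs, [number of values for each parameter]) for each entry of tfm_dict
tfm_shapes = {'flip': (2, []), 'symmetric_warp': (1, [4]), 'rotate': (1, [4]),
              'zoom': (1, [4]), 'brightness': (1, [4]), 'contrast': (1, [4]),
              'skew': (1, [2, 4]), 'squish': (1, [4]), 'rand_pad_crop': (1, [3])}
TTA_PASSES = 8  # hard-coded in `_tta_only_w_tfms`

n_tests = sum(n_funcs * prod(n_vals) for n_funcs, n_vals in tfm_shapes.values())
n_passes = n_tests * (1 + TTA_PASSES)  # 1 plain pass + 8 augmented passes per test
print(n_tests, n_passes)  # 37 333
```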

In [13]:
all_tests = make_tfms(tfm_dict)
results = {}
learn = get_learner(data, load_fn=f'{model_fn}_head')
for name, tests in all_tests.items():
    tfms = [tfm(**params) for tfm, params in tests]
    errs = [error_rate(*learn.TTA_with_tfms(tfm, beta=BETA)) for tfm in tfms]
    results[name] = list(errs)

In [14]:
results

{'flip': [tensor(0.0892), tensor(0.0927)],
 'symmetric_warp': [tensor(0.0892),
  tensor(0.0861),
  tensor(0.0876),
  tensor(0.0838)],
 'rotate': [tensor(0.0931), tensor(0.0911), tensor(0.0876), tensor(0.0884)],
 'zoom': [tensor(0.0923), tensor(0.0876), tensor(0.0838), tensor(0.0838)],
 'brightness': [tensor(0.1019),
  tensor(0.1039),
  tensor(0.1012),
  tensor(0.1015)],
 'contrast': [tensor(0.1019), tensor(0.1019), tensor(0.1019), tensor(0.0996)],
 'skew': [tensor(0.0919),
  tensor(0.0942),
  tensor(0.0973),
  tensor(0.0969),
  tensor(0.0938),
  tensor(0.0907),
  tensor(0.0965),
  tensor(0.0950)],
 'squish': [tensor(0.0942), tensor(0.0896), tensor(0.0896), tensor(0.0892)],
 'rand_pad_crop': [tensor(0.0903), tensor(0.0907), tensor(0.0919)]}

## Pick Best Transformations

For each kind of transformation, choose the magnitude which leads to the lowest TTA error rate, if that error rate is lower than THRESHOLD * ERR_NONE; otherwise, don't include that kind of transformation in the final set of augmentations.
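The selection rule can be checked by hand. This sketch (plain Python, hypothetical names chosen so they do not overwrite the notebook's `results` or `ERR_NONE`) applies it to three rows of the `results` output above, rounded as printed. With `THRESHOLD = 0.99` all three types pass with the same magnitudes that appear in the `final_tfms` output below, while a stricter 0.98 would drop `brightness`.

```python
# TTA error rates copied (rounded to 4 decimals) from the `results` output above
sample_results = {
    'rotate':     [0.0931, 0.0911, 0.0876, 0.0884],
    'brightness': [0.1019, 0.1039, 0.1012, 0.1015],
    'contrast':   [0.1019, 0.1019, 0.1019, 0.0996],
}
err_none = 0.102703  # ERR_NONE from the no-augmentation run

def pick_best(results, err_none, threshold):
    """For each transform type, return the index of its lowest TTA error rate,
    keeping the type only if that error is below threshold * err_none."""
    picked = {}
    for name, errs in results.items():
        best_idx = min(range(len(errs)), key=errs.__getitem__)  # first index on ties
        if errs[best_idx] < threshold * err_none:
            picked[name] = best_idx
    return picked

print(pick_best(sample_results, err_none, 0.99))  # {'rotate': 2, 'brightness': 2, 'contrast': 3}
print(pick_best(sample_results, err_none, 0.98))  # {'rotate': 2, 'contrast': 3}
```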

In [15]:
final_tfms = []
tfm_types = list(results.keys())
for tfm_type in tfm_types:
    errs = results[tfm_type]
    best_idx = np.argmin(errs)
    best_err = errs[best_idx]
    if best_err < THRESHOLD * ERR_NONE:
        final_tfms.append(all_tests[tfm_type][best_idx])
final_tfms

[[TfmPixel (flip_lr), {}],
 [TfmCoord (symmetric_warp), {'magnitude': (-0.4, 0.4)}],
 [TfmAffine (rotate), {'degrees': (-30.0, 30.0)}],
 [TfmAffine (zoom), {'scale': (1.0, 1.3)}],
 [TfmLighting (brightness), {'change': (0.35, 0.65)}],
 [TfmLighting (contrast), {'scale': (0.6, 1.6666666666666667)}],
 [TfmCoord (skew), {'direction': (7, 7), 'magnitude': 0.4}],
 [TfmAffine (squish), {'scale': (0.3333333333333333, 3.0)}],
 [functools.partial(<function rand_pad at 0x7f94ae5e8050>, size=128),
  {'padding': 8.0}]]

## Train with Best Transformations

With the chosen set of augmentations, train the head for EPOCHS_HEAD epochs and the full network for EPOCHS_FULL epochs.

In [16]:
tfms = []
for tfm_func, params in final_tfms:
    tfm = tfm_func(**params)
    # some tfm_funcs return a list of tfms, so flatten everything into one list
    tfms.extend(tfm if is_listy(tfm) else [tfm])
    
data = get_data(tfms=[tfms, []])
csv_fn = 'history_best'
model_fn = 'best_tfms'
train(data, get_learner, EPOCHS_HEAD, EPOCHS_FULL, model_fn=model_fn, csv_fn=csv_fn)

epoch,train_loss,valid_loss,error_rate,time
0,1.600415,0.553326,0.18417,00:15
1,1.1176,0.442382,0.146718,00:15
2,0.963632,0.379293,0.119691,00:15


Better model found at epoch 0 with valid_loss value: 0.5533258318901062.
Better model found at epoch 1 with valid_loss value: 0.44238224625587463.
Better model found at epoch 2 with valid_loss value: 0.3792932331562042.


epoch,train_loss,valid_loss,error_rate,time
0,0.92773,0.364375,0.120077,00:16
1,0.94038,0.442665,0.14749,00:15
2,0.8694,0.404624,0.13668,00:15
3,0.7588,0.35512,0.115058,00:16
4,0.740413,0.338701,0.109653,00:15


Better model found at epoch 0 with valid_loss value: 0.36437463760375977.
Better model found at epoch 3 with valid_loss value: 0.3551204204559326.
Better model found at epoch 4 with valid_loss value: 0.3387012779712677.
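The loop in the `In [16]` cell above flattens the chosen transforms because some entries return a list instead of a single transform; in fastai v1, `rand_pad` returns a pad transform followed by a random crop. A dependency-free sketch of the same flattening, with strings standing in for transform objects (`flatten_tfms` is a hypothetical name; the list/tuple check mirrors fastai's `is_listy`):

```python
def flatten_tfms(tfm_results):
    """Hypothetical helper: merge single transforms and lists of transforms
    into one flat list, mirroring the `tfms.extend(...)` loop above."""
    flat = []
    for tfm in tfm_results:
        flat.extend(tfm if isinstance(tfm, (list, tuple)) else [tfm])
    return flat

# e.g. rotate(...) yields one transform, rand_pad(...) yields two
print(flatten_tfms(['rotate', ['pad', 'crop'], 'zoom']))  # ['rotate', 'pad', 'crop', 'zoom']
```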


In [17]:
ERR_BEST_AUGS = error_rate_of_best_valid_loss(data.path, f'{csv_fn}_full')
ERR_BEST_AUGS

0.109653

## Train Baseline

As a baseline, train the network for the same number of epochs using the transforms provided by `get_transforms()`.

In [18]:
tfms = get_transforms()
data = get_data(tfms=tfms)
csv_fn = 'history_baseline'
model_fn = 'baseline'
train(data, get_learner, EPOCHS_HEAD, EPOCHS_FULL, model_fn=model_fn, csv_fn=csv_fn)

epoch,train_loss,valid_loss,error_rate,time
0,0.914859,0.40971,0.130888,00:11
1,0.63022,0.338875,0.116602,00:11
2,0.497871,0.295071,0.096525,00:11


Better model found at epoch 0 with valid_loss value: 0.4097101092338562.
Better model found at epoch 1 with valid_loss value: 0.3388746380805969.
Better model found at epoch 2 with valid_loss value: 0.2950712740421295.


epoch,train_loss,valid_loss,error_rate,time
0,0.501394,0.327873,0.106564,00:13
1,0.496799,0.340813,0.105019,00:13
2,0.432724,0.327442,0.105019,00:13
3,0.334245,0.287076,0.09305,00:13
4,0.275871,0.286491,0.095753,00:14


Better model found at epoch 0 with valid_loss value: 0.3278730511665344.
Better model found at epoch 2 with valid_loss value: 0.3274422883987427.
Better model found at epoch 3 with valid_loss value: 0.28707581758499146.
Better model found at epoch 4 with valid_loss value: 0.2864907383918762.


In [19]:
ERR_DEFAULT_AUGS = error_rate_of_best_valid_loss(data.path, f'{csv_fn}_full')
ERR_DEFAULT_AUGS

0.09575299999999999

## Optional: Train 2nd Baseline - No Augs

This wasn't done in the article, but it is a useful sanity check. If training with no data augmentation at all gives a better result, you are probably not training for enough epochs, since additional data augmentation can slow convergence (TODO: need to verify this).

In [20]:
tfms = None
data = get_data(tfms=tfms)
csv_fn = 'history_baseline_none'
model_fn = 'baseline_none'
train(data, get_learner, EPOCHS_HEAD, EPOCHS_FULL, model_fn=model_fn, csv_fn=csv_fn)

epoch,train_loss,valid_loss,error_rate,time
0,0.828232,0.469234,0.136293,00:10
1,0.563967,0.351915,0.12471,00:10
2,0.429604,0.31348,0.099228,00:10


Better model found at epoch 0 with valid_loss value: 0.4692339599132538.
Better model found at epoch 1 with valid_loss value: 0.35191527009010315.
Better model found at epoch 2 with valid_loss value: 0.3134795129299164.


epoch,train_loss,valid_loss,error_rate,time
0,0.41293,0.328887,0.105405,00:13
1,0.446553,0.347663,0.111969,00:13
2,0.314578,0.328951,0.113514,00:13
3,0.229698,0.312599,0.101544,00:13
4,0.180178,0.307187,0.102317,00:13


Better model found at epoch 0 with valid_loss value: 0.32888707518577576.
Better model found at epoch 3 with valid_loss value: 0.3125993311405182.
Better model found at epoch 4 with valid_loss value: 0.30718743801116943.


In [21]:
ERR_BASELINE_NONE = error_rate_of_best_valid_loss(data.path, f'{csv_fn}_full')
ERR_BASELINE_NONE

0.10231699999999999

## Compare to Baseline

In [22]:
try: print('baseline_none:', ERR_BASELINE_NONE, '\ndefault_augs:', ERR_DEFAULT_AUGS, '\nbest_augs:', ERR_BEST_AUGS)
except NameError:  # the optional no-augmentation baseline was skipped
    print('default_augs:', ERR_DEFAULT_AUGS, '\nbest_augs:', ERR_BEST_AUGS)

baseline_none: 0.10231699999999999 
default_augs: 0.09575299999999999 
best_augs: 0.109653
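As noted at the top, differences of this size may be due to chance. One rough way to gauge this is a two-proportion z-score. This sketch assumes a validation set of n = 2590 images, a value inferred from the logged error rates (for example 0.102703 ≈ 266/2590) rather than stated in the notebook; check `len(data.valid_ds)` in your own run. Because every run is scored on the same validation images, a paired test such as McNemar's would be more appropriate; the unpaired z-score below is only a first approximation.

```python
from math import sqrt

N_VALID = 2590  # assumption: inferred from the logged error rates, not read from the data

def two_proportion_z(err_a, err_b, n=N_VALID):
    """Unpaired z-score for the difference between two error rates measured on n samples each."""
    se = sqrt(err_a * (1 - err_a) / n + err_b * (1 - err_b) / n)
    return (err_a - err_b) / se

err_best, err_default = 0.109653, 0.095753
print(round(err_best * N_VALID), round(err_default * N_VALID))  # 284 248 misclassified images
print(round(two_proportion_z(err_best, err_default), 2))
```

Here the z-score comes out at about 1.65, below the 1.96 usually used as a two-sided 5% cutoff, so this single comparison does not clearly separate the two augmentation sets.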


In [23]:
final_tfms

[[TfmPixel (flip_lr), {}],
 [TfmCoord (symmetric_warp), {'magnitude': (-0.4, 0.4)}],
 [TfmAffine (rotate), {'degrees': (-30.0, 30.0)}],
 [TfmAffine (zoom), {'scale': (1.0, 1.3)}],
 [TfmLighting (brightness), {'change': (0.35, 0.65)}],
 [TfmLighting (contrast), {'scale': (0.6, 1.6666666666666667)}],
 [TfmCoord (skew), {'direction': (7, 7), 'magnitude': 0.4}],
 [TfmAffine (squish), {'scale': (0.3333333333333333, 3.0)}],
 [functools.partial(<function rand_pad at 0x7f94ae5e8050>, size=128),
  {'padding': 8.0}]]

In [24]:
get_transforms()

([RandTransform(tfm=TfmCrop (crop_pad), kwargs={'row_pct': (0, 1), 'col_pct': (0, 1), 'padding_mode': 'reflection'}, p=1.0, resolved={}, do_run=True, is_random=True, use_on_y=True),
  RandTransform(tfm=TfmPixel (flip_lr), kwargs={}, p=0.5, resolved={}, do_run=True, is_random=True, use_on_y=True),
  RandTransform(tfm=TfmCoord (symmetric_warp), kwargs={'magnitude': (-0.2, 0.2)}, p=0.75, resolved={}, do_run=True, is_random=True, use_on_y=True),
  RandTransform(tfm=TfmAffine (rotate), kwargs={'degrees': (-10.0, 10.0)}, p=0.75, resolved={}, do_run=True, is_random=True, use_on_y=True),
  RandTransform(tfm=TfmAffine (zoom), kwargs={'scale': (1.0, 1.1), 'row_pct': (0, 1), 'col_pct': (0, 1)}, p=0.75, resolved={}, do_run=True, is_random=True, use_on_y=True),
  RandTransform(tfm=TfmLighting (brightness), kwargs={'change': (0.4, 0.6)}, p=0.75, resolved={}, do_run=True, is_random=True, use_on_y=True),
  RandTransform(tfm=TfmLighting (contrast), kwargs={'scale': (0.8, 1.25)}, p=0.75, resolved={}, do