In [None]:
# default_exp train

# Train

> This module contains scripts to train a model, either on a single train/validation split (`train`) or with 5-fold cross-validation (`train_cv`).

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#export
from plant_pathology.dataset import *
from plant_pathology.evaluate import *

from fastai.vision.all import *
from fastcore.script import *
from fastai.callback.wandb import *
from wwf.vision.timm import *
import timm
import wandb
from typing import *
from sys import exit

## Train a Model on a Single Data Split

In [None]:
#export
def timm_or_fastai_arch(arch: str) -> Tuple[Union[Any, str], Callable[..., Learner]]:
    """Resolve an architecture name to a `(model, learner_func)` pair.

    If `arch` names an object in this module's global namespace (presumably a
    fastai/torchvision architecture such as `resnet18` brought in by the star
    imports), that object is returned together with `cnn_learner`. Otherwise
    `arch` is assumed to be a timm model name and is returned unchanged together
    with `timm_learner`.

    NOTE(review): any global with a matching name is accepted, not only model
    constructors; and a misspelled fastai name silently falls through to timm.
    """
    try:  # fastai arch: resolvable as a module-level global
        model = globals()[arch]
    except KeyError:  # not a known global -> treat as a timm model name
        return arch, timm_learner
    return model, cnn_learner

In [None]:
#export
def train(
    epochs: int, lr: Union[float, str], frz: int=1, pre: int=800, re: int=256,
    bs: int=256, fold: int=4, smooth: bool=False,
    arch: str='resnet18', dump: bool=False, log: bool=False, mixup: float=0.,
    fp16: bool=False, dls: DataLoaders=None,
):
    """Train `arch` on one train/validation split and return the fitted `Learner`.

    Training runs `frz` one-cycle epochs with the body frozen at `lr`, then
    `epochs` one-cycle epochs fully unfrozen with learning rates `slice(lr/100, lr/2)`.

    Args:
        epochs: number of unfrozen epochs.
        lr: peak learning rate, or the string "find" to run `lr_find` and exit.
        frz: number of frozen epochs.
        pre: presize passed to `get_dls_all_in_1` (only used when `dls` is None).
        re: resize passed to `get_dls_all_in_1` (only used when `dls` is None).
        bs: batch size (only used when `dls` is None).
        fold: validation fold index (only used when `dls` is None).
        smooth: use `LabelSmoothingCrossEntropyFlat` instead of `CrossEntropyLossFlat`.
        arch: fastai architecture name or timm model name.
        dump: print the model and exit without training.
        log: log to Weights & Biases; also attaches `WandbCallback` and `SaveModelCallback`.
        mixup: value passed to `MixUp`; 0 disables MixUp.
        fp16: train with mixed precision (`learn.to_fp16()`).
        dls: prebuilt DataLoaders; when given, `pre`, `re`, `bs` and `fold` are ignored.

    Returns:
        The trained `Learner`.
    """
    # Prep data, loss and architecture
    if dls is None: dls = get_dls_all_in_1(presize=pre, resize=re, bs=bs, val_fold=fold)
    if log: wandb.init(project="plant-pathology")
    loss_func = LabelSmoothingCrossEntropyFlat() if smooth else CrossEntropyLossFlat()
    m, learner_func = timm_or_fastai_arch(arch)

    # Add callbacks
    cbs = [WandbCallback(), SaveModelCallback()] if log else []
    if mixup: cbs.append(MixUp(mixup))

    # Build learner
    learn = learner_func(dls, m, loss_func=loss_func,
                         metrics=[accuracy, RocAuc()], cbs=cbs)
    if dump: print(learn.model); exit()
    # Enable mixed precision *before* the LR finder so it probes the same
    # numeric regime that training will use.
    if fp16: learn.to_fp16()
    if lr == "find": learn.lr_find(); exit()

    # Train: frozen body first, then the whole network
    learn.freeze()
    learn.fit_one_cycle(frz, lr)
    learn.unfreeze()
    learn.fit_one_cycle(epochs, slice(lr/100, lr/2))  # Explore other divs
    # Close the W&B run so a subsequent `wandb.init` (e.g. the next CV fold)
    # starts a fresh run rather than reusing this one.
    if log: wandb.finish()
    return learn

In [None]:
# Smoke test: 1 frozen epoch (default `frz`), 0 unfrozen epochs, no W&B logging
learn = train(epochs=0, lr=0.001, bs=256, log=False)

epoch,train_loss,valid_loss,accuracy,roc_auc_score,time
0,2.117751,1.265273,0.605479,0.803554,00:44


In [None]:
# Last epoch's recorded values: train_loss, valid_loss, accuracy, roc_auc_score
learn.final_record

(#4) [2.117751121520996,1.265272617340088,0.6054794788360596,0.8035537836126292]

## Train Using Cross-Validation

In [None]:
#export
@call_parse
def train_cv(
    epochs:   Param("Number of unfrozen epochs", int),
    lr:       Param("Initial learning rate", float),
    frz:      Param("Number of frozen epochs", int)=1,
    pre:      Param("Presize", int)=800,
    re:       Param("Resize", int)=256,
    bs:       Param("Batch size", int)=256,
    smooth:   Param("Label smoothing?", store_true)=False,
    arch:     Param("Architecture", str)='resnet18',
    dump:     Param("Print model", store_true)=False,
    log:      Param("Log w/ W&B", store_true)=False,
    mixup:    Param("Mixup", float)=0.0,
    tta:      Param("Test-time augmentation", store_true)=False,
    fp16:     Param("Use mixed-precision", store_true)=False,
    eval_dir: Param("Evaluate model, save results in dir", Path)=None,
):
    """Cross-validate: train one model per validation fold (0-4) and report scores.

    For each fold a fresh learner is built with `train` and scored:
      * without `tta`: `learn.final_record` of the last epoch
        (train loss, valid loss, accuracy, ROC AUC);
      * with `tta`: [valid loss, accuracy, ROC AUC] on test-time-augmented
        validation predictions.
    If `eval_dir` is given, `evaluate` writes `predictions_fold_{fold}.csv` into it
    for every fold (always with TTA, independent of the `tta` flag).
    Per-fold scores and their column-wise mean are printed.
    """
    print(locals())  # log the run configuration (only the arguments exist at this point)
    import gc
    import torch

    if eval_dir:
        eval_dir = Path(eval_dir)
        # Create the output directory up front rather than relying on `evaluate`.
        eval_dir.mkdir(parents=True, exist_ok=True)

    scores = []
    for fold in range(5):
        print(f"\nTraining on fold {fold}")
        learn = train(epochs, lr, frz=frz, pre=pre, re=re, bs=bs, smooth=smooth,
                      arch=arch, dump=dump, log=log, fold=fold, mixup=mixup,
                      fp16=fp16)
        if tta:
            preds, lbls = learn.tta()
            # `tta` returns *activated* predictions (softmax probabilities), whereas
            # the loss and `RocAuc` apply their own (log-)softmax to raw scores.
            # Since softmax(log p) == p, log-probabilities act as equivalent logits.
            # Cast to float32 and clamp so zero probabilities don't become -inf.
            logits = preds.float().clamp_min(1e-8).log()
            res = [float(f(logits, lbls)) for f in [learn.loss_func, accuracy, RocAuc()]]
        else:
            res = learn.final_record
        scores.append(res)

        # Create submission file for this model
        if eval_dir: evaluate(learn, eval_dir/f"predictions_fold_{fold}.csv", tta=True)

        # Free the learner before the next fold to avoid OOM. The Learner and its
        # callbacks presumably reference each other (`cb.learn`), so `del` alone may
        # not release it; collect the cycle and drop cached GPU memory too.
        del learn
        gc.collect()
        torch.cuda.empty_cache()

    scores = np.array(scores)
    print(f"Scores: {scores}\n")
    print(f"Mean: {scores.mean(0)}")

In [None]:
# Sanity check for the aggregation in `train_cv`: a (folds x metrics) array
# averaged over axis 0 yields one mean per metric column
scores = np.ones((5, 4))
np.mean(scores, axis=0)

array([1., 1., 1., 1.])

In [None]:
# Full 5-fold CV run: 1 frozen + 1 unfrozen epoch per fold at lr=2e-2,
# writing predictions_fold_{k}.csv for each fold into the current directory
train_cv(epochs=1, lr=2e-2, eval_dir=".")


Training on fold 0


epoch,train_loss,valid_loss,accuracy,roc_auc_score,time
0,1.440475,5.252672,0.512329,0.838716,00:42


epoch,train_loss,valid_loss,accuracy,roc_auc_score,time
0,0.787104,5.80955,0.523288,0.891827,00:42


AttributeError: 'str' object has no attribute 'mkdir'

In [None]:
#hide
from nbdev.export import *
# Export the `#export` cells of the project's notebooks to the Python package
notebook2script()

Converted 00_utils.ipynb.
Converted 01_dataset.ipynb.
Converted 02_evaluate.ipynb.
Converted 03_train.ipynb.
Converted index.ipynb.
