In [None]:
# default_exp train

# Train

> Functions and a command-line script for training a model on a single validation fold or with 5-fold cross-validation.

In [None]:
#hide
from nbdev.showdoc import *
from plant_pathology.config import DATA_PATH
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#export
from plant_pathology.dataset import *
from plant_pathology.evaluate import *

from fastai.vision.all import *
from fastcore.script import *
from fastai.callback.wandb import *
from wwf.vision.timm import *
import timm
import wandb
from typing import *
from sys import exit

## Train a Model on Data Split

In [None]:
#export
def timm_or_fastai_arch(arch: str) -> Tuple[Union[Any, str], Callable[..., Learner]]:
    """Resolve an architecture name to a model spec and the learner factory that accepts it.

    Args:
        arch: Architecture name. If a callable with this name exists in this
            module's globals (e.g. brought in by the star imports), it is
            treated as a fastai architecture; otherwise the name is passed
            through unchanged and assumed to be a timm model name.

    Returns:
        `(model, learner_func)`: the architecture callable paired with
        `cnn_learner`, or the original `arch` string paired with `timm_learner`.
    """
    candidate = globals().get(arch)
    # Only callables can be architectures; a same-named module or constant in
    # globals (e.g. "timm", "wandb") must not be handed to `cnn_learner`.
    if callable(candidate):
        return candidate, cnn_learner
    # Not a known fastai architecture: defer the name lookup to timm.
    return arch, timm_learner

In [None]:
#export
def train(
    data_path: Path, epochs: int = 1, lr: Union[float, str] = 3e-4, frz: int = 1,
    pre: Union[int, Tuple[int, ...]] = 800, re: int = 256,
    bs: int = 200, fold: int = 4, smooth: bool = False,
    arch: str = 'resnet18', dump: bool = False, log: bool = False, mixup: float = 0.,
    fp16: bool = False, dls: Optional[DataLoaders] = None, save: bool = False,
    pseudo: Optional[Path] = None,
) -> Learner:
    """Build a learner for one validation fold, then train it frozen and unfrozen.

    Args:
        data_path: Data directory, forwarded to `get_dls_all_in_1`.
        epochs: Number of epochs for the unfrozen phase.
        lr: Learning rate, or the string "find" to run the LR finder and exit.
        frz: Number of epochs for the frozen phase.
        pre: Presize forwarded to `get_dls_all_in_1` (an int, or a tuple as
            passed by `train_cv`).
        re: Resize forwarded to `get_dls_all_in_1`.
        bs: Batch size.
        fold: Validation fold; also used in the saved model's file name.
        smooth: Use label-smoothing cross entropy instead of plain cross entropy.
        arch: Architecture name, resolved by `timm_or_fastai_arch`.
        dump: Print the model and exit without training.
        log: Initialise a W&B run and attach `WandbCallback`.
        mixup: If truthy, attach `MixUp` constructed with this value.
        fp16: Switch the learner to mixed precision.
        dls: Pre-built `DataLoaders`; when given, no data loaders are created
            and `data_path`/`pre`/`re`/`bs`/`pseudo` are not used.
        save: Attach `SaveModelCallback` monitoring "roc_auc_score"
            (also attached whenever `log` is set).
        pseudo: Path to pseudo labels, forwarded to `get_dls_all_in_1`.

    Returns:
        The trained `Learner`.

    Raises:
        ValueError: If `lr` is a string other than "find".
        SystemExit: Via `sys.exit` when `dump` is set or `lr == "find"`.
    """
    # Fail fast: "find" is the only string handled below; any other string would
    # fail no later than the `lr / 100` in the unfrozen phase, after the data
    # and model have already been built.
    if isinstance(lr, str) and lr != "find":
        raise ValueError(f'lr must be a number or "find", got {lr!r}')

    # Prep Data, Opt, Loss, Arch
    if dls is None:
        dls = get_dls_all_in_1(
            data_path=data_path, presize=pre, resize=re, bs=bs, val_fold=fold, pseudo_labels_path=pseudo
        )
    if log: wandb.init(project="plant-pathology")
    loss_func = LabelSmoothingCrossEntropyFlat() if smooth else CrossEntropyLossFlat()
    m, learner_func = timm_or_fastai_arch(arch)

    # Add callbacks
    cbs = []
    if save or log:
        # Checkpoint on the monitored ROC AUC metric, one file per validation fold.
        cbs.append(SaveModelCallback("roc_auc_score", fname=f"model_val_on_{fold}"))
    if log: cbs.append(WandbCallback())
    if mixup: cbs.append(MixUp(mixup))

    # Build learner
    print(f"# train exs: {len(dls.train_ds)}, val exs: {len(dls.valid_ds)}")
    learn = learner_func(dls, m, loss_func=loss_func,
                         metrics=[accuracy, RocAuc()], cbs=cbs)
    # Inspection modes: both terminate the process instead of training.
    if dump:
        print(learn.model)
        exit()
    if lr == "find":
        learn.lr_find()
        exit()
    if fp16: learn.to_fp16()

    # Train: frozen phase first, then unfreeze with a learning-rate slice.
    learn.freeze()
    learn.fit_one_cycle(frz, lr)
    learn.unfreeze()
    learn.fit_one_cycle(epochs, slice(lr/100, lr/2))  # Explore other divs
    return learn

In [None]:
#slow
#hide
# Smoke test of `train`: epochs=0 requests zero unfrozen epochs, so only the
# frozen phase (default frz=1) is asked to run a full epoch.
learn = train(DATA_PATH, epochs=0, lr=0.001, bs=256, log=False)

# train exs: 1457, val exs: 364


epoch,train_loss,valid_loss,accuracy,roc_auc_score,time


KeyboardInterrupt: 

## Train Using Cross-Validation

In [None]:
#export
def softmax_RocAuc(logits, labels):
    "Apply softmax over the last dim of `logits`, then score the result against `labels` with a fresh `RocAuc` metric."
    class_probs = logits.softmax(-1)
    roc_auc = RocAuc()
    return roc_auc(class_probs, labels)

In [None]:
#hide
# Scratch tensors for poking at metric inputs: random scores of shape (2, 4)
# and a label column of shape (4, 1).
# NOTE(review): the batch sizes differ (2 vs 4), so these look like leftover
# scratch values rather than a usable metric input — confirm or remove.
preds = torch.randn(2, 4)
labels = tensor([1, 2, 3, 4]).unsqueeze(-1)
preds, labels.shape

(tensor([[-0.2578, -0.3262,  1.1754, -0.7382],
         [-1.1492,  0.2354, -1.1844, -1.6058]]),
 torch.Size([4, 1]))

In [None]:
#export
@call_parse
def train_cv(
    path:     Param("Path to data dir", Path),
    epochs:   Param("Number of unfrozen epochs", int)=1,
    lr:       Param("Initial learning rate", float)=3e-4,
    frz:      Param("Number of frozen epochs", int)=1,
    pre:      Param("Image presize", int, nargs="+")=(682, 1024),
    re:       Param("Image resize", int)=256,
    bs:       Param("Batch size", int)=256,
    smooth:   Param("Label smoothing?", store_true)=False,
    arch:     Param("Architecture", str)='resnet18',
    dump:     Param("Don't train, just print model", store_true)=False,
    log:      Param("Log w/ W&B", store_true)=False,
    save:     Param("Save model based on RocAuc", store_true)=False,
    mixup:    Param("Mixup (0.4 is good)", float)=0.0,
    tta:      Param("Test-time augmentation", store_true)=False,
    fp16:     Param("Mixed-precision training", store_true)=False,
    do_eval: Param("Evaluate model and save predictions CSV", store_true)=False,
    val_fold: Param("Don't go cross-validation, just do 1 fold (or pass 9 "
                    "to train on all data)", int)=None,
    pseudo:   Param("Path to pseudo labels to train on", Path)=None,
    export:   Param("Export learner(s) to export_val_on_{fold}.pkl", store_true)=False,
):
    "Train one model per validation fold (or only `val_fold` if given) and print the collected scores."
    # Echo the full run configuration; must stay first so only the arguments are listed.
    print(locals())
    # A requested fold replaces the 5-fold sweep with a single run.
    folds = range(5) if val_fold is None else [val_fold]
    scores = []
    for fold in folds:
        print(f"\nTraining on fold {fold}")
        learn = train(data_path=path, epochs=epochs, lr=lr, frz=frz, pre=pre,
                      re=re, bs=bs, smooth=smooth, arch=arch, dump=dump, log=log,
                      fold=fold, mixup=mixup, fp16=fp16, save=save, pseudo=pseudo)

        # MixUp has to be detached before TTA (bug when doing TTA w/ MixUp).
        # Keyed on the `mixup` argument — the same condition under which `train`
        # attaches the callback — rather than on a learner attribute name, since
        # fastai derives that attribute from the class name (`mix_up`, not `mixup`).
        if tta and mixup: learn.remove_cb(MixUp)

        if tta and val_fold != 9:  # There IS a valid set
            preds, lbls = learn.tta()
            # NOTE(review): if `learn.tta()` already returns activated
            # probabilities, the loss and `softmax_RocAuc` below are applied to
            # probabilities rather than logits — confirm this is intended.
            res = [f(preds, lbls) for f in [learn.loss_func, accuracy, softmax_RocAuc]]
        else:
            # No TTA (or no validation set): use the values recorded during training.
            res = learn.final_record
        scores.append(res)

        # Create submission file for this model
        if do_eval:
            print("Evaluating")
            evaluate(learn, path=path/"test.csv", name=f"predictions_fold_{fold}.csv", tta=tta)

        if export: learn.export(f"export_val_on_{fold}.pkl")
        # Delete learner to avoid OOM
        del learn
    scores = np.array(scores)
    print(f"Scores: {scores}\n")
    # A mean across folds is only meaningful for the full cross-validation sweep.
    if val_fold is None: print(f"Mean: {scores.mean(0)}")

In [None]:
# Sanity check of the aggregation used in `train_cv`: averaging a 2-D score
# array over axis 0 yields one mean per column.
scores = np.ones((5, 4))
scores.mean(0)

array([1., 1., 1., 1.])

In [None]:
#slow
#hide
# End-to-end smoke test on a single fold (val_fold=4) with small 64px images,
# exercising fp16, MixUp, TTA, evaluation and export in one run. The following
# cell checks for predictions_fold_4.csv and export_val_on_4.pkl.
train_cv(DATA_PATH, 0, 2e-2, pre=64, re=64, bs=512, fp16=True, val_fold=4, tta=True, mixup=0.4, do_eval=True, export=True)

{'path': Path('../data'), 'epochs': 0, 'lr': 0.02, 'frz': 1, 'pre': 64, 're': 64, 'bs': 512, 'smooth': False, 'arch': 'resnet18', 'dump': False, 'log': False, 'save': False, 'mixup': 0.4, 'tta': True, 'fp16': True, 'do_eval': True, 'val_fold': 4, 'pseudo': None, 'export': True}

Training on fold 4
# train exs: 1457, val exs: 364


epoch,train_loss,valid_loss,accuracy,roc_auc_score,time
0,2.303842,10.276655,0.332418,0.583753,00:22


  warn("Your generator is empty.")


Evaluating


  warn("Your generator is empty.")


Scores: [[1.387954   0.34065935 0.59147109]]



In [None]:
#slow
#hide
# Check predictions CSV was saved, then delete it so the check is meaningful on the next run
preds_path = Path("predictions_fold_4.csv")
assert preds_path.exists(), "Predictions CSV not saved properly"
preds_path.unlink()

# Check Learner was exported properly, then delete the export for the same reason
export_path = Path("export_val_on_4.pkl")
assert export_path.exists(), "Learner not exported properly"
export_path.unlink()

In [None]:
#hide
# Convert the project's notebooks into library modules (nbdev export step).
from nbdev.export import *; notebook2script()

Converted 00_utils.ipynb.
Converted 01_dataset.ipynb.
Converted 02_evaluate.ipynb.
Converted 03_train.ipynb.
Converted 04_generate_pseudo_labels.ipynb.
Converted 05_self_knowledge_distillation.ipynb.
Converted 06_create_folds.ipynb.
Converted 07_pretrained_models.ipynb.
Converted Untitled.ipynb.
Converted config.ipynb.
Converted index.ipynb.
