In [1]:
# default_exp models.task

In [2]:
from nbdev.showdoc import *
from nbdev.export import *

In [3]:
# export
import warnings
from collections import namedtuple
from typing import Any, Dict, List, Tuple

import pytorch_lightning as pl
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.metrics import Accuracy
from torch.utils.data import DataLoader

from src import _logger
from src.data import DatasetMapper
from src.models.builder import Net
from src.optimizers import create_optimizer
from src.schedulers import create_scheduler

warnings.filterwarnings("ignore")

In [4]:
# export
class Task(pl.LightningModule):
    "A general Task for Cassave Leaf Disease Classification"
    
    def __init__(self, conf: DictConfig):
        super().__init__()
        
        self.trn_metric = Accuracy()
        self.val_metric = Accuracy()
        self.tst_metric = Accuracy()
        self.save_hyperparameters(conf)
        
        # instantiate objects
        self.model = Net(self.hparams)
        self.criterion   = instantiate(self.hparams.loss)
        self.mixfunction = instantiate(self.hparams.mixmethod)

        _logger.info(f"LossFunction: {self.criterion}")
        if self.mixfunction is not None:
            _logger.info(f"Training with {self.mixfunction}.")
        
        self.lrs= None
            
    def setup(self, stage: str):
        "setups datasetMapper"
        mapper = DatasetMapper(self.hparams)
        mapper.generate_datasets()
        
        # Loads in the repective datasets
        self.train_dset = mapper.get_test_dataset()
        self.valid_dset = mapper.get_valid_dataset()
        self.test_dset  = mapper.get_test_dataset()
        
        # Loads in the transformations to be applied after mixmethod
        self.final_augs = mapper.get_transforms()

    def forward(self, x: Any) -> Any:
        "call the model"
        return self.model(x)
    
    def training_step(self, batch: Any, batch_idx: int) -> Any:
        "The Training Step: This is where the Magic Happens !!!"
        imgs, targs = batch
        self.preds, self.labels = None, None
        
        if self.mixfunction is not None:
            if self.current_epoch < self.hparams.training.mix_epochs :
                imgs = self.mixfunction(imgs, targs, model=self.model)
                logits= self.forward(imgs)
                loss= self.mixfunction.lf(logits, loss_func=self.criterion)
                acc = self.trn_metric(logits, targs)
            else:
                logits = self.forward(imgs)
                loss   = self.criterion(logits, targs)
                acc    = self.trn_metric(logits, targs)
        
        else:
            logits = self.forward(imgs)
            loss   = self.criterion(logits, targs)
            acc    = self.trn_metric(logits, targs)
        
        preds  = torch.argmax(logits, 1)
        self.labels = list(targs.data.cpu().numpy())
        self.preds  = list(preds.data.cpu().numpy())
        
        result_dict = {"train/loss": loss, "train/acc": acc}
        self.log_dict(result_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def validation_step(self, batch: Any, batch_idx: int) -> None:
        "The Validation Step"
        imgs, targs = batch
        self.preds, self.labels = None, None
        
        logits = self.forward(imgs)
        loss = self.criterion(logits, targs)
        acc = self.val_metric(logits, targs)
        
        preds  = torch.argmax(logits, 1)
        self.labels = list(targs.data.cpu().numpy())
        self.preds  = list(preds.data.cpu().numpy())
        
        result_dict = {"valid/loss": loss, "valid/acc": acc} 
        self.log_dict(result_dict)
    
    def test_step(self, batch: Any, batch_idx: int) -> None:
        "The Test Step"
        imgs, targs = batch
        self.preds, self.labels = None, None
        
        logits = self.forward(imgs)
        loss = self.criterion(logits, targs)
        acc = self.tst_metric(logits, targs)
        
        preds  = torch.argmax(logits, 1)
        self.labels = list(targs.data.cpu().numpy())
        self.preds  = list(preds.data.cpu().numpy())
        
        result_dict = {"test/loss": loss, "test/acc": acc} 
        self.log_dict(result_dict)
    
    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[Dict]]:
        
        lrs = (self.hparams.training.learning_rate/self.hparams.training.lr_mult, 
               self.hparams.training.learning_rate)
        
        lr_tuple = namedtuple("LearningRates", ["base", "head"])
        
        self.lrs = lr_tuple(lrs[0], lrs[1])
        
        epochs  = self.hparams.training.num_epochs
        steps   = len(self.train_dataloader())/ self.hparams.training.accumulate_grad_batches
        
        total_params = self.model.get_param_list()
        params = [
            {"params": total_params[0], "lr":lrs[0]}, 
            {"params": total_params[1], "lr":lrs[1]},
        ]
        
        optim = create_optimizer(self.hparams.optimizer, params=params)
        sched = create_scheduler(self.hparams.scheduler, optim, steps, epochs)
        return [optim], [sched]
        
    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        "returns a PyTorch DataLoader for Training"
        if self.current_epoch == self.hparams.training.mix_epochs:
            if self.mixfunction is not None:
                name = self.mixfunction.__class__.__name__
                _logger.info(f"Train [ {self.current_epoch}/{self.trainer.max_epochs}]: Stopping {name} !")
                self.mixfunction.stop()
            
            self.train_dset.reload_transforms(self.final_augs)
            dataloader = torch.utils.data.DataLoader(self.train_dset, **self.hparams.data.dataloader)
        else:
            dataloader = torch.utils.data.DataLoader(self.train_dset, **self.hparams.data.dataloader)
        return dataloader
    
    def val_dataloader(self, *args, **kwargs) -> DataLoader:
        "returns a PyTorch DataLoader for Validation"
        return torch.utils.data.DataLoader(self.valid_dset, **self.hparams.data.dataloader)
    
    def test_dataloader(self, *args, **kwargs) -> DataLoader:
        "returns a PyTorch DataLoader for Testing"
        return torch.utils.data.DataLoader(self.test_dset, **self.hparams.data.dataloader)

In [5]:
import os
from hydra.experimental import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf
from pytorch_lightning import Trainer

In [6]:
overrides = [
    "input.input_size=120",
    "training.batch_size=5",
    "augmentations=tfms-v0",
    "data.dataloader.num_workers=0",
    "general=default",
    "trainer=fast-dev-cpu",
    "optimizer=ranger",
    "mixmethod=snapmix",
    "training.mix_epochs=1",
    "training.batch_size=64",
    "model=v0",
    "model.base_model.activation=mish",
    "model.head.params.act_layer=mish",
    "training.accumulate_grad_batches=1",
    "loss=crossentropy",
]

with initialize(config_path=os.path.relpath("../conf/")):
    cfg = compose(config_name="effnet-base", overrides=overrides)

In [8]:
model = Task(cfg)

[[32m01/31 20:52:01[0m [35msrc.models.builder[0m]: Configuration for the current model :
[[32m01/31 20:52:01[0m [35msrc.models.builder[0m]:  feature_extractor: tf_efficientnet_b3_ns
[[32m01/31 20:52:01[0m [35msrc.models.builder[0m]:  activation: mish
[[32m01/31 20:52:01[0m [35msrc.models.builder[0m]:  params: {'drop_path_rate': 0.25}
[[32m01/31 20:52:01[0m [35msrc.models.builder[0m]:  head: CnnHeadV0
[[32m01/31 20:52:01[0m [35msrc.models.builder[0m]:  params: {'n_out': 5, 'pool_type': 'avg', 'use_conv': False, 'act_layer': 'mish'}
[[32m01/31 20:52:02[0m [35m__main__[0m]: LossFunction: CrossEntropyLoss()
[[32m01/31 20:52:02[0m [35m__main__[0m]: Training with Snapmix(alpha=5.0, conf_prob=0.5, num_iters=1).


In [7]:
trainer = Trainer(reload_dataloaders_every_epoch=True, 
                  limit_train_batches=2, limit_test_batches=2, 
                  limit_val_batches=2, weights_summary=None, 
                  accumulate_grad_batches=1, num_sanity_val_steps=1)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores


In [9]:
trainer.fit(model)

[[32m01/31 20:52:04[0m [35msrc.data.dataset_factory[0m]: Generating Datasets for FOLD :0
[[32m01/31 20:52:05[0m [35msrc.data.dataset_factory[0m]: Train Dataset has 17117, Validation Dataset has 17117 instances.
[[32m01/31 20:52:05[0m [35msrc.optimizers[0m]: Ranger loaded from OPTIM_REGISTERY
[[32m01/31 20:52:05[0m [35msrc.schedulers[0m]: FlatCosScheduler loaded from SCHEDULER_REGISTERY


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

[[32m01/31 20:52:48[0m [35m__main__[0m]: Train [ 1/1000]: Stopping Snapmix !


Validating: 0it [00:00, ?it/s]

1

In [10]:
notebook2script()

Converted 00_core.ipynb.
Converted 01a_data.datasets.ipynb.
Converted 01b_data.datasests_factory.ipynb.
Converted 01c_data.mixmethods.ipynb.
Converted 02_losses.ipynb.
Converted 03a_optimizers.ipynb.
Converted 03b_schedulers.ipynb.
Converted 04a_models.utils.ipynb.
Converted 04b_models.layers.ipynb.
Converted 04c_models.classifiers.ipynb.
Converted 04d_models.builder.ipynb.
Converted 04e_models.task.ipynb.
Converted 05_callbacks.ipynb.
Converted index.ipynb.
