In [None]:
# hide
# skip
!git clone https://github.com/benihime91/gale # install gale on colab
!pip install -e "gale[dev]"

In [None]:
# default_exp core.classes

In [None]:
# hide
%load_ext nb_black
%load_ext autoreload
%autoreload 2
%matplotlib inline

<IPython.core.display.Javascript object>

In [None]:
# hide
import warnings

from nbdev.export import *
from nbdev.showdoc import *

warnings.filterwarnings("ignore")

<IPython.core.display.Javascript object>

# Classes
> Interfaces common to all `Modules` and `Models` in Gale.

In [None]:
# export
import copy
import logging
import math
from abc import ABC, ABCMeta, abstractmethod
from contextlib import contextmanager
from typing import *

import hydra
import pytorch_lightning as pl
import torch
import torchmetrics
from fastcore.all import L, noop, patch
from omegaconf import DictConfig, OmegaConf
from torch.nn import Module

from gale.core.nn.optim import OPTIM_REGISTRY, SCHEDULER_REGISTRY
from gale.core.nn.utils import params, trainable_params
from gale.core.utils.logger import log_main_process

_logger = logging.getLogger(__name__)

<IPython.core.display.Javascript object>

## Configurable-

In [None]:
# export
class Configurable(ABC):
    """
    Helper Class to instantiate obj from config.

    Subclasses can be created from an OmegaConf `DictConfig` with `from_config_dict`
    and can hand back the config stored on them with `to_config_dict`.
    """

    @classmethod
    def from_config_dict(cls, config: DictConfig, **kwargs):
        """
        Instantiates object using `DictConfig-based` configuration. You can optionally
        pass in extra `kwargs`.

        - If `config` contains a `_target_` key the object is built with
          `hydra.utils.instantiate(config=config, **kwargs)`.
        - Otherwise `cls(cfg=config, **kwargs)` is tried first; if that raises, the
          config entries are passed as keyword arguments: `cls(**config, **kwargs)`.

        If the created instance has no `_cfg` attribute, `config` is stored as `_cfg`.
        """
        # Resolve interpolations into a fresh DictConfig
        if isinstance(config, DictConfig):
            config = OmegaConf.to_container(config, resolve=True)
            config = OmegaConf.create(config)

        if "_target_" in config:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config, **kwargs)
        else:
            try:
                # constructors that accept the whole config as `cfg`
                instance = cls(cfg=config, **kwargs)
            except Exception:
                # constructors that accept the config entries as keyword arguments.
                # `Exception` (not a bare `except:`) so KeyboardInterrupt/SystemExit are
                # never swallowed; if this call fails too, the first error is chained.
                instance = cls(**config, **kwargs)

        if not hasattr(instance, "_cfg"):
            instance._cfg = config
        return instance

    def to_config_dict(self) -> DictConfig:
        """
        Returns object's configuration to config dictionary.

        The stored `_cfg` is resolved (interpolations replaced by their values), set to
        struct mode, written back to `self._cfg` and returned.

        Raises `NotImplementedError` when the object has no `DictConfig` `_cfg`.
        """
        current = getattr(self, "_cfg", None)
        if not isinstance(current, DictConfig):
            raise NotImplementedError(
                "to_config_dict() can currently only return object._cfg but current object does not have it."
            )

        # Resolve the config dict
        resolved = OmegaConf.create(OmegaConf.to_container(current, resolve=True))
        OmegaConf.set_struct(resolved, True)
        self._cfg = resolved
        return self._cfg

<IPython.core.display.Javascript object>

This class provides a common interface for modules so that they can be easily loaded from a Hydra config file. It also supports instantiation via Hydra when the config contains a `_target_` key.

In [None]:
show_doc(Configurable.from_config_dict)

<h4 id="Configurable.from_config_dict" class="doc_header"><code>Configurable.from_config_dict</code><a href="__main__.py#L7" class="source_link" style="float:right">[source]</a></h4>

> <code>Configurable.from_config_dict</code>(**`config`**:`DictConfig`, **\*\*`kwargs`**)

Instantiates object using `DictConfig-based` configuration. You can optionally
pass in extra `kwargs`

<IPython.core.display.Javascript object>

In [None]:
show_doc(Configurable.to_config_dict)

<h4 id="Configurable.to_config_dict" class="doc_header"><code>Configurable.to_config_dict</code><a href="__main__.py#L33" class="source_link" style="float:right">[source]</a></h4>

> <code>Configurable.to_config_dict</code>()

Returns object's configuration to config dictionary

<IPython.core.display.Javascript object>

## GaleModule-

In [None]:
# export
class GaleModule(Module, Configurable, metaclass=ABCMeta):
    """
    Abstract class offering interface which should be implemented by all `Backbones`,
    `Heads` and `Meta Archs` in gale.

    Besides the abstract `forward` / `build_param_dicts`, it provides helpers to freeze
    (fully or partially) and unfreeze the module's parameters.
    """

    @abstractmethod
    def forward(self) -> Any:
        """
        The main logic for the model lives here. Can return either features, logits
        or loss.
        """
        raise NotImplementedError

    @abstractmethod
    def build_param_dicts(self) -> Union[Iterable, List[Dict], Dict, List]:
        """
        Should return the iterable of parameters to optimize or dicts defining parameter groups
        for the Module.
        """
        raise NotImplementedError

    @property
    def param_lists(self):
        "Returns the list of parameters in the module"
        return list(self.parameters())

    def all_params(self, n=slice(None), with_grad=False):
        """
        Parameters of the module (in `param_lists` order) selected by the slice `n`.
        With `with_grad=True`, only the parameters whose `.grad` is set are returned.
        """
        res = L(p for p in self.param_lists[n])
        if not with_grad:
            return res
        return L(o for o in res if hasattr(o, "grad") and o.grad is not None)

    def _set_require_grad(self, rg, p):
        p.requires_grad_(rg)

    def unfreeze(self) -> None:
        """
        Unfreeze all parameters for training and set the module to train mode.
        """
        for param in self.parameters():
            param.requires_grad = True

        self.train()

    def freeze(self) -> None:
        """
        Freeze all params for inference & set model to eval
        """
        for param in self.parameters():
            param.requires_grad = False
        self.eval()

    def freeze_to(self, n) -> None:
        """
        Freeze the first `n` parameters (in `param_lists` order) and unfreeze the rest.
        A negative `n` counts from the end. The resolved index is stored in `frozen_idx`.
        """
        self.frozen_idx = n if n >= 0 else len(self.param_lists) + n
        if self.frozen_idx >= len(self.param_lists):
            # fmt: off
            _logger.warning(f"Freezing {self.frozen_idx} groups; model has {len(self.param_lists)}; whole model is frozen.")
            # fmt: on

        for o in self.all_params(slice(n, None)):
            self._set_require_grad(True, o)

        for o in self.all_params(slice(None, n)):
            self._set_require_grad(False, o)

    @contextmanager
    def as_frozen(self):
        """
        Context manager which temporarily freezes a module, yields control and finally
        restores the state the module had before entering: every parameter's
        `requires_grad` flag and the module's train/eval mode.

        Restoring (instead of blindly unfreezing) keeps partial freezes, e.g. from
        `freeze_to`, intact.
        """
        was_training = self.training
        grad_flags = {name: p.requires_grad for name, p in self.named_parameters()}

        self.freeze()
        try:
            yield
        finally:
            # keep calling `unfreeze` so subclass overrides still run, then restore
            self.unfreeze()
            for name, p in self.named_parameters():
                if name in grad_flags:
                    p.requires_grad_(grad_flags[name])
            self.train(was_training)

<IPython.core.display.Javascript object>

Any Module that is registered in Gale should inherit from this class or one of its subclasses.

## Internals

In [None]:
show_doc(GaleModule.forward)

<h4 id="GaleModule.forward" class="doc_header"><code>GaleModule.forward</code><a href="__main__.py#L8" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleModule.forward</code>()

The main logic for the model lives here. Can return either features, logits
or loss.

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleModule.build_param_dicts)

<h4 id="GaleModule.build_param_dicts" class="doc_header"><code>GaleModule.build_param_dicts</code><a href="__main__.py#L16" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleModule.build_param_dicts</code>()

Should return the iterable of parameters to optimize or dicts defining parameter groups
for the Module.

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleModule.from_config_dict)

<h4 id="Configurable.from_config_dict" class="doc_header"><code>Configurable.from_config_dict</code><a href="__main__.py#L7" class="source_link" style="float:right">[source]</a></h4>

> <code>Configurable.from_config_dict</code>(**`config`**:`DictConfig`, **\*\*`kwargs`**)

Instantiates object using `DictConfig-based` configuration. You can optionally
pass in extra `kwargs`

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleModule.param_lists)

<h4 id="GaleModule.param_lists" class="doc_header"><code>GaleModule.param_lists</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Returns the list of paramters in the module

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleModule.all_params)

<h4 id="GaleModule.all_params" class="doc_header"><code>GaleModule.all_params</code><a href="__main__.py#L29" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleModule.all_params</code>(**`n`**=*`slice(None, None, None)`*, **`with_grad`**=*`False`*)

List of `param_groups` upto n

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleModule.freeze)

<h4 id="GaleModule.freeze" class="doc_header"><code>GaleModule.freeze</code><a href="__main__.py#L48" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleModule.freeze</code>()

Freeze all params for inference & set model to eval

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleModule.freeze_to)

<h4 id="GaleModule.freeze_to" class="doc_header"><code>GaleModule.freeze_to</code><a href="__main__.py#L56" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleModule.freeze_to</code>(**`n`**)

Freeze parameter groups up to `n`

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleModule.unfreeze)

<h4 id="GaleModule.unfreeze" class="doc_header"><code>GaleModule.unfreeze</code><a href="__main__.py#L39" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleModule.unfreeze</code>()

Unfreeze all parameters for training.

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleModule.as_frozen)

<h4 id="GaleModule.as_frozen" class="doc_header"><code>GaleModule.as_frozen</code><a href="__main__.py#L70" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleModule.as_frozen</code>()

Context manager which temporarily freezes a module, yields control
and finally unfreezes the module.

<IPython.core.display.Javascript object>

## OptimSchedBuilder-

In [None]:
# export
class OptimSchedBuilder:
    """
    Interface that constructs an Optimizer and a LR Scheduler from a config.

    `_train_dl` and `_trainer` must be set before calling `prepare_optimization_config`,
    which populates `optimization_cfg`; `build_optimizer` / `build_lr_scheduler` read it.
    """

    # training dataloader; `prepare_optimization_config` calls `len()` on it
    _train_dl: Callable
    _trainer: pl.Trainer
    # `None` until `prepare_optimization_config` runs, so the builders can detect an
    # unprepared config and raise their explanatory error instead of AttributeError
    optimization_cfg: Optional[DictConfig] = None

<IPython.core.display.Javascript object>

In [None]:
# collapse-output
# Demo setup (not exported): a FashionMNIST training set and gale's default config,
# used by the cells below to exercise `OptimSchedBuilder`.
# NOTE(review): `dataclass` / `field` are not used in the cells visible here — confirm
# before removing.
from dataclasses import dataclass, field

from fastcore.all import Path
from nbdev.export import Config
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from gale.config import get_config

# data lives next to the notebooks, under `<nbs_path>/data`
data_path = Path(Config().path("nbs_path")) / "data"
# downloads the dataset into `data_path` if it is not already there
dset = FashionMNIST(root=data_path, download=True)

# gale's default config; its `optimization` section is printed below
cfg = get_config()

<IPython.core.display.Javascript object>

In [None]:
# collapse-output
# A bare builder with the two attributes `prepare_optimization_config` reads.
builder = OptimSchedBuilder()
# no transform is applied: in this demo the builder only calls `len()` on the loader
builder._train_dl = DataLoader(dset, batch_size=32)
builder._trainer = pl.Trainer(max_epochs=10, accumulate_grad_batches=1)

# show the optimization section of the default config (the `-1` / null values get inferred)
print(OmegaConf.to_yaml(cfg.optimization))

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


optimizer:
  name: AdamW
  init_args:
    betas:
    - 0.95
    - 0.999
    eps: 1.0e-05
    weight_decay: 0.01
    amsgrad: false
scheduler:
  name: OneCycleLR
  init_args:
    max_lr: -1
    total_steps: null
    epochs: -1
    steps_per_epoch: null
    pct_start: 0.3
    anneal_strategy: cos
    cycle_momentum: true
    base_momentum: 0.85
    max_momentum: 0.95
    div_factor: 25.0
    final_div_factor: 10000.0
  interval: step
  monitor: null
steps_per_epoch: -1
max_steps: -1
max_epochs: -1



<IPython.core.display.Javascript object>

`OptimSchedBuilder` is used to instantiate the `optimizer` and the (optional) `lr_scheduler` from an OmegaConf config that follows the structure of the Gale default config.

A generic optimization config structure for image classification is shown above. Notice that some values are `-1` (and the scheduler's `steps_per_epoch` is `null`): these values are computed automatically during the training process. The requested optimizer and scheduler must be present in `OPTIM_REGISTRY` and `SCHEDULER_REGISTRY` respectively.

In [None]:
# export
@patch
def prepare_optimization_config(self: OptimSchedBuilder, config: DictConfig):
    """
    Prepares `OptimizationConfig` config and infers values like `steps_per_epoch`,
    `max_steps` and `max_epochs` from the Trainer and the training dataloader.

    A deep copy of `config` is stored in `self.optimization_cfg`. Scheduler `init_args`
    placeholders (`max_iters == -1`, `epochs == -1`, `steps_per_epoch is None`), which
    are required by some of the LearningRate Schedulers, are then filled in with the
    inferred values.

    Requires `self._train_dl` (supporting `len()`) and `self._trainer` to be set.

    Raises `ValueError` if the Trainer defines neither `max_epochs` nor `max_steps`.
    """
    self.optimization_cfg = copy.deepcopy(config)
    trainer = self._trainer
    num_batches = len(self._train_dl)

    self.optimization_cfg["steps_per_epoch"] = num_batches

    # Lightning marks an unset `max_steps` with `None` (older releases) or `-1` (newer)
    max_steps_set = trainer.max_steps not in (None, -1)
    if trainer.max_epochs is None and not max_steps_set:
        msg = "Either one of max_epochs or max_steps must be provided in Trainer"
        log_main_process(_logger, logging.ERROR, msg)
        raise ValueError(msg)

    # number of training batches actually used per epoch
    limit = trainer.limit_train_batches
    if isinstance(limit, int) and limit != 0:
        # an int limit is an absolute number of batches (capped by the loader length)
        dataset_size = min(limit, num_batches)
    elif isinstance(limit, float):
        # a float limit is a fraction of the batches
        dataset_size = int(num_batches * limit)
    else:
        dataset_size = num_batches

    num_devices = max(1, trainer.num_gpus, trainer.num_processes)
    tpu_cores = trainer.tpu_cores
    if tpu_cores:
        # `tpu_cores` may be a core count or a list of core indices
        n_tpu = len(tpu_cores) if isinstance(tpu_cores, (list, tuple)) else tpu_cores
        num_devices = max(num_devices, n_tpu)

    # optimizer steps per epoch; at least 1 so it is always a valid divisor
    effective_batch_size = trainer.accumulate_grad_batches * num_devices
    optim_steps_per_epoch = max(1, dataset_size // effective_batch_size)

    if max_steps_set:
        max_steps = trainer.max_steps
        max_epochs = math.ceil(max_steps / optim_steps_per_epoch)
    else:
        max_epochs = trainer.max_epochs
        max_steps = optim_steps_per_epoch * max_epochs

    self.optimization_cfg["max_epochs"] = max_epochs
    self.optimization_cfg["max_steps"] = max_steps

    # covert scheduler args to a plain dictionary for inspection
    sched_args = self.optimization_cfg.scheduler.init_args
    if sched_args is None:
        sched_config = {}
    else:
        sched_config = OmegaConf.to_container(sched_args, resolve=True)

    # scheduler arg -> (placeholder meaning "infer me", inferred value)
    inferred = {
        "max_iters": (-1, max_steps),
        "epochs": (-1, max_epochs),
        "steps_per_epoch": (None, num_batches),
    }
    for key, (placeholder, value) in inferred.items():
        if key in sched_config and sched_config[key] == placeholder:
            OmegaConf.update(self.optimization_cfg, f"scheduler.init_args.{key}", value)
            msg = f"Set the value of '{key}' to be {value}."
            log_main_process(_logger, logging.DEBUG, msg)

<IPython.core.display.Javascript object>

In [None]:
# fills `builder.optimization_cfg` with inferred max_steps / max_epochs / steps_per_epoch
builder.prepare_optimization_config(config=cfg.optimization)

<IPython.core.display.Javascript object>

In [None]:
# export
@patch
def build_optimizer(self: OptimSchedBuilder, params: Any) -> torch.optim.Optimizer:
    """
    Builds a single optimizer from `OptimizationConfig`. `params` are the parameters
    (or parameter-group dicts) for the optimizer to optimize.

    The optimizer class is looked up in `OPTIM_REGISTRY` by `optimizer.name` and
    constructed with `optimizer.init_args`, resolved to plain Python containers.

    Note this method must be called after `prepare_optimization_config()`.

    Returns `None` (after logging a warning) when `optimizer.name` is `None`.
    Raises `NameError` when the optimization config has not been prepared.
    """
    optim_cfg = getattr(self, "optimization_cfg", None)
    if not isinstance(optim_cfg, DictConfig):
        msg = "optimization_cfg not found, did you call `prepare_optimization_config`."
        log_main_process(_logger, logging.WARNING, msg)
        raise NameError(msg)

    opt_cfg = optim_cfg.optimizer
    if opt_cfg.name is None:
        msg = "Optimizer is None, so no optimizer will be created."
        log_main_process(_logger, logging.WARNING, msg)
        return None

    # resolve interpolations and turn ListConfig values (e.g. `betas`) into plain
    # python objects before handing them to the optimizer
    init_args = opt_cfg.init_args
    kwargs = {} if init_args is None else OmegaConf.to_container(init_args, resolve=True)

    opt = OPTIM_REGISTRY.get(opt_cfg.name)(params=params, **kwargs)
    msg = f"Built optimizer, {opt.__class__.__name__} with {len(opt.param_groups)} param group(s)."
    log_main_process(_logger, logging.DEBUG, msg)
    return opt

<IPython.core.display.Javascript object>

In [None]:
# export
@patch
def build_lr_scheduler(
    self: OptimSchedBuilder, optimizer: torch.optim.Optimizer
) -> Any:
    """
    Builds a LearningRate scheduler from `OptimizationConfig`. Returns an LRScheduler dict
    (`scheduler`, `interval`, `monitor`) as expected by PyTorch Lightning.

    The scheduler class is looked up in `SCHEDULER_REGISTRY` by `scheduler.name` and
    constructed with `optimizer` and `scheduler.init_args`.

    Note this method must be called after `prepare_optimization_config()`.

    Returns `None` (after logging a warning) when `scheduler.name` is `None`.
    Raises `NameError` when the optimization config has not been prepared.
    """
    optim_cfg = getattr(self, "optimization_cfg", None)
    if not isinstance(optim_cfg, DictConfig):
        msg = "optimization_cfg not found, did you call `prepare_optimization_config`."
        log_main_process(_logger, logging.WARNING, msg)
        raise NameError(msg)

    sched_cfg = optim_cfg.scheduler
    if sched_cfg.name is None:
        msg = "scheduler is None, so no scheduler will be created."
        log_main_process(_logger, logging.WARNING, msg)
        return None

    # `to_container` resolves interpolations and already converts ListConfig values
    # (e.g. a per-group `max_lr`) into plain python lists
    init_args = sched_cfg.init_args
    kwds = {} if init_args is None else OmegaConf.to_container(init_args, resolve=True)

    sched_cls = SCHEDULER_REGISTRY.get(sched_cfg.name)
    sch = sched_cls(optimizer=optimizer, **kwds)

    msg = f"LRScheduler : {sch.__class__.__name__}."
    log_main_process(_logger, logging.DEBUG, msg)

    # pytorch-lightning LRScheduler dictionary format
    return {
        "scheduler": sch,
        "interval": sched_cfg.interval,
        "monitor": sched_cfg.monitor,
    }

<IPython.core.display.Javascript object>

In [None]:
# Smoke test: build the optimizer and scheduler described by the default config.
# NOTE(review): this rebinds `params`, shadowing the `params` helper imported from
# `gale.core.nn.utils` in this notebook's namespace.
params = [torch.nn.Parameter(torch.randn(1, 2))]
optimizer = builder.build_optimizer(params)
assert isinstance(optimizer, torch.optim.AdamW)

# for onecycle lrs; we need max_lrs
# (the default config leaves `max_lr` at the placeholder -1, so set a concrete value)
builder.optimization_cfg.scheduler.init_args.max_lr = [1e-03]
scheduler = builder.build_lr_scheduler(optimizer)
# the scheduler comes back wrapped in Lightning's LRScheduler dict format
assert isinstance(scheduler, Dict)
assert isinstance(scheduler["scheduler"], torch.optim.lr_scheduler.OneCycleLR)

<IPython.core.display.Javascript object>

## GaleTask -

In [None]:
# export
class GaleTask(pl.LightningModule, OptimSchedBuilder, metaclass=ABCMeta):
    """
    Interface for Pytorch-lightning based Gale modules
    """

    def __init__(
        self,
        cfg: DictConfig,
        trainer: Optional[pl.Trainer] = None,
        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
    ):
        """
        Base class from which all PyTorch Lightning Tasks in Gale should inherit.

        Arguments:
        1. `cfg` `(DictConfig)`:  configuration object. cfg object should be inherited from `BaseGaleConfig`.
        2. `trainer` `(Optional, pl.Trainer)`: Pytorch Lightning Trainer instance
        3. `metrics` `(Optional)`: Metrics to compute for training and evaluation.
        """
        super().__init__()
        self._cfg = OmegaConf.structured(cfg)

        self.save_hyperparameters(self._cfg)
        # `noop` marks "not set up yet" for the dataloaders, optimizer and scheduler
        self._train_dl = noop
        self._validation_dl = noop
        self._test_dl = noop
        self._optimizer = noop
        self._scheduler = noop
        self._trainer = trainer
        self.metrics = metrics

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        "Returns the Dataloader used for Training"
        if self._train_dl is not None and self._train_dl is not noop:
            return self._train_dl

    def val_dataloader(self) -> Any:
        "Returns the List of Dataloaders or Dataloader used for Validation"
        if self._validation_dl is not None and self._validation_dl is not noop:
            return self._validation_dl

    def test_dataloader(self) -> Any:
        "Returns the List of Dataloaders or Dataloader used for Testing"
        if self._test_dl is not None and self._test_dl is not noop:
            return self._test_dl

    @abstractmethod
    def forward(self, x: torch.Tensor):
        """
        The Forward method for LightningModule, users should modify this method.
        """
        raise NotImplementedError

    @abstractmethod
    def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
        """
        Setups data loader to be used in training

        Arguments:
        1. `train_data_config`: training data loader parameters.
        """
        raise NotImplementedError

    @abstractmethod
    def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
        """
        Setups data loader to be used in validation

        Arguments:
        1. `val_data_config`: validation data loader parameters.
        """
        raise NotImplementedError

    def setup_test_data(
        self, test_data_config: Optional[Union[DictConfig, Dict]] = None
    ):
        """
        (Optionally) Setups data loader to be used in test

        Arguments:
        1. `test_data_config`: test data loader parameters.
        """
        raise NotImplementedError

    @property
    def param_dicts(self) -> Union[Iterator, List[Dict]]:
        """
        Property that returns the param dicts for optimization.
        Override for custom training behaviour. Currently returns all the trainable
        parameters as a single parameter group.
        """
        # Return the parameters themselves, not wrapped in another list: torch
        # optimizers accept an iterable of tensors (or of param-group dicts) but reject
        # `[[p0, p1, ...]]`. Assumes `trainable_params` yields the trainable tensors.
        return trainable_params(self)

    def shared_step(self, batch: Any, batch_idx: int, stage: str) -> Any:
        """
        The common training/validation/test step. Override for custom behavior. This step
        is shared between training/validation/test step. For training/validation/test steps
        `stage` is train/val/test respectively. Your training logic should go here; avoid
        directly overriding the training/validation/test step methods.
        """
        raise NotImplementedError

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """
        The training step of the LightningModule. For common use cases you do
        not need to override this method. See `GaleTask.shared_step()`
        """
        return self.shared_step(batch, batch_idx, stage="train")

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        """
        The validation step of the LightningModule. For common use cases you do
        not need to override this method. See `GaleTask.shared_step()`
        """
        return self.shared_step(batch, batch_idx, stage="val")

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        """
        The test step of the LightningModule. For common use cases you do
        not need to override this method. See `GaleTask.shared_step()`
        """
        return self.shared_step(batch, batch_idx, stage="test")

    def setup_optimization(self, optim_config: Optional[DictConfig] = None):
        """
        Prepares an optimizer (and optional LR scheduler) from an optimization config.

        Args:
        1. `optim_config`: A `dictionary`/`DictConfig` or instance of `OptimizationConfig`.
           Defaults to `self._cfg.optimization` when available.
        """
        # If config was not explicitly passed to us
        if optim_config is None:
            # See if internal config has `optimization` namespace
            if self._cfg is not None and hasattr(self._cfg, "optimization"):
                optim_config = self._cfg.optimization

        # If config is still None, or internal config has no Optim, return without instantiation
        if optim_config is None:
            msg = "No optimizer config provided, therefore no optimizer was created"
            log_main_process(_logger, logging.WARNING, msg)
            return

        # Preserve the configuration
        if not isinstance(optim_config, DictConfig):
            optim_config = OmegaConf.create(optim_config)

        # prepare the optimization config
        self.prepare_optimization_config(optim_config)

        # Setup optimizer and scheduler; a scheduler needs an optimizer to wrap
        self._optimizer = self.build_optimizer(self.param_dicts)
        if self._optimizer is None:
            self._scheduler = None
        else:
            self._scheduler = self.build_lr_scheduler(self._optimizer)

    def configure_optimizers(self):
        """
        Choose what optimizers and learning-rate schedulers to use in your optimization.
        See https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html

        Returns `None` when no optimizer could be built, the optimizer alone when there
        is no scheduler, and `[optimizer], [scheduler]` otherwise.
        """
        # if self.setup_optimization() has been called manually no
        # need to call again
        if self._optimizer is noop and self._scheduler is noop:
            self.setup_optimization()

        # no optimization config / optimizer: never hand the `noop` placeholder to Lightning
        if self._optimizer is None or self._optimizer is noop:
            return None

        if self._scheduler is None or self._scheduler is noop:
            return self._optimizer
        return [self._optimizer], [self._scheduler]

<IPython.core.display.Javascript object>

## Methods and Properties 

In [None]:
show_doc(GaleTask.forward)

<h4 id="GaleTask.forward" class="doc_header"><code>GaleTask.forward</code><a href="__main__.py#L48" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.forward</code>(**`x`**:`Tensor`)

The Forward method for LightningModule, users should modify this method.

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.train_dataloader)

<h4 id="GaleTask.train_dataloader" class="doc_header"><code>GaleTask.train_dataloader</code><a href="__main__.py#L33" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.train_dataloader</code>()

Returns the Dataloader used for Training

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.val_dataloader)

<h4 id="GaleTask.val_dataloader" class="doc_header"><code>GaleTask.val_dataloader</code><a href="__main__.py#L38" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.val_dataloader</code>()

Returns the List of Dataloaders or Dataloader used for Validation

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.test_dataloader)

<h4 id="GaleTask.test_dataloader" class="doc_header"><code>GaleTask.test_dataloader</code><a href="__main__.py#L43" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.test_dataloader</code>()

Returns the List of Dataloaders or Dataloader used for Testing

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.setup_training_data)

<h4 id="GaleTask.setup_training_data" class="doc_header"><code>GaleTask.setup_training_data</code><a href="__main__.py#L55" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.setup_training_data</code>(**`train_data_config`**:`Union`\[`DictConfig`, `Dict`\])

Setups data loader to be used in training

Arguments:
1. `train_data_config`: training data loader parameters.

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.setup_validation_data)

<h4 id="GaleTask.setup_validation_data" class="doc_header"><code>GaleTask.setup_validation_data</code><a href="__main__.py#L65" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.setup_validation_data</code>(**`val_data_config`**:`Union`\[`DictConfig`, `Dict`\])

Setups data loader to be used in validation

Arguments:
1. `val_data_config`: validation data loader parameters.

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.setup_test_data)

<h4 id="GaleTask.setup_test_data" class="doc_header"><code>GaleTask.setup_test_data</code><a href="__main__.py#L75" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.setup_test_data</code>(**`test_data_config`**:`Union`\[`DictConfig`, `Dict`, `NoneType`\]=*`None`*)

(Optionally) Setups data loader to be used in test

Arguments:
1. `test_data_config`: test data loader parameters.

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.shared_step)

<h4 id="GaleTask.shared_step" class="doc_header"><code>GaleTask.shared_step</code><a href="__main__.py#L94" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.shared_step</code>(**`batch`**:`Any`, **`batch_idx`**:`int`, **`stage`**:`str`)

The common training/validation/test step. Override for custom behavior. This step
is shared between training/validation/test step. For training/validation/test steps
`stage` is train/val/test respectively. You training logic should go here avoid directly overriding
training/validation/test step methods.

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.training_step)

<h4 id="GaleTask.training_step" class="doc_header"><code>GaleTask.training_step</code><a href="__main__.py#L103" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.training_step</code>(**`batch`**:`Any`, **`batch_idx`**:`int`)

The training step of the LightningModule. For common use cases you need
not need to override this method. See [`GaleTask.shared_step()`](/gale/core.classes.html#GaleTask.shared_step())

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.validation_step)

<h4 id="GaleTask.validation_step" class="doc_header"><code>GaleTask.validation_step</code><a href="__main__.py#L110" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.validation_step</code>(**`batch`**:`Any`, **`batch_idx`**:`int`)

The validation step of the LightningModule. For common use cases you need
not need to override this method. See [`GaleTask.shared_step()`](/gale/core.classes.html#GaleTask.shared_step())

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.test_step)

<h4 id="GaleTask.test_step" class="doc_header"><code>GaleTask.test_step</code><a href="__main__.py#L117" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.test_step</code>(**`batch`**:`Any`, **`batch_idx`**:`int`)

The test step of the LightningModule. For common use cases you need
not need to override this method. See [`GaleTask.shared_step()`](/gale/core.classes.html#GaleTask.shared_step())

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.param_dicts)

<h4 id="GaleTask.param_dicts" class="doc_header"><code>GaleTask.param_dicts</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Property that returns the param dicts for optimization.
Override for custom training behaviour. Currently returns all the trainable paramters.

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.configure_optimizers)

<h4 id="GaleTask.configure_optimizers" class="doc_header"><code>GaleTask.configure_optimizers</code><a href="__main__.py#L155" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.configure_optimizers</code>()

Choose what optimizers and learning-rate schedulers to use in your optimization.
See https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html

<IPython.core.display.Javascript object>

In [None]:
show_doc(GaleTask.setup_optimization)

<h4 id="GaleTask.setup_optimization" class="doc_header"><code>GaleTask.setup_optimization</code><a href="__main__.py#L124" class="source_link" style="float:right">[source]</a></h4>

> <code>GaleTask.setup_optimization</code>(**`optim_config`**:`DictConfig`=*`None`*)

Prepares an optimizer from a string name and its optional config parameters.

Args:
1. `optim_config`: A `dictionary`/`DictConfig` or instance of `OptimizationConfig`.

<IPython.core.display.Javascript object>

## Export-

In [None]:
# hide
notebook2script()

Converted 00_core.utils.logger.ipynb.
Converted 00a_core.utils.visualize.ipynb.
Converted 00b_core.utils.structures.ipynb.
Converted 01_core.nn.utils.ipynb.
Converted 01a_core.nn.losses.ipynb.
Converted 02_core.nn.optim.optimizers.ipynb.
Converted 02a_core.nn.optim.lr_schedulers.ipynb.
Converted 03_core.classes.ipynb.
Converted 04_classification.modelling.backbones.ipynb.
Converted 04a_classification.modelling.heads.ipynb.
Converted 04b_classification.modelling.meta_arch.common.ipynb.
Converted 04b_classification.modelling.meta_arch.vit.ipynb.
Converted 05_classification.data.common.ipynb.
Converted 05a_classification.data.transforms.ipynb.
Converted 05b_classification.data.build.ipynb.
Converted 06_classification.task.ipynb.
Converted 07_collections.pandas.ipynb.
Converted 07a_collections.callbacks.notebook.ipynb.
Converted 07b_collections.callbacks.ema.ipynb.
Converted index.ipynb.


<IPython.core.display.Javascript object>