# Implementing a Custom Trainer

Abstract base classes (ABCs) define a blueprint for a class, specifying its methods and attributes, but not its implementation. They are important in implementing a consistent interface, as they enforce a set of requirements on implementing classes and make it easier to write code that can work with multiple implementations.

First, we define a boilerplate for the `TrainerBase` class, which is the same implemented in `archai.api.trainer_base` module.

In [41]:
from abc import abstractmethod

from overrides import EnforceOverrides


class TrainerBase(EnforceOverrides):
    def __init__(self) -> None:
        pass

    @abstractmethod
    def train(self) -> None:
        pass

    @abstractmethod
    def evaluate(self) -> None:
        pass

    @abstractmethod
    def predict(self) -> None:
        pass

## PyTorch-based Trainer

In the context of a custom trainer, using ABCs can help ensure that the provider implements the required methods and provides a consistent interface for training, evaluating and predicting. In this example, we will implement a PyTorch-based trainer, as follows:

In [42]:
from typing import Optional

import torch
from overrides import overrides
from torch.utils.data import Dataset


class PyTorchTrainer(TrainerBase):
    def __init__(
        self,
        model: torch.nn.Module,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
    ) -> None:
        super().__init__()

        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        # Setup the trainer
        self._setup()

    def _setup(self) -> None:
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def _train_step(self, inputs: torch.Tensor, labels: torch.Tensor) -> None:
        self.optimizer.zero_grad()

        outputs = self.model(inputs)

        loss = self.loss_fn(outputs, labels)
        loss.backward()

        self.optimizer.step()

        return loss.item()

    @overrides
    def train(self) -> None:
        total_loss = 0.0

        train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=64, shuffle=True)

        self.model.train()
        for idx, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.view(inputs.size(0), -1)
            
            total_loss += self._train_step(inputs, labels)

            if idx % 10 == 0:
                print(f"Batch {idx} loss: {total_loss / (idx + 1)}")

    def _eval_step(self, inputs: torch.Tensor, labels: torch.Tensor) -> None:
        with torch.no_grad():
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, labels)

        return loss.item()

    @overrides
    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> None:
        eval_dataset = eval_dataset if eval_dataset else self.eval_dataset
        assert eval_dataset is not None, "`eval_dataset` has not been provided."

        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=64, shuffle=False)

        eval_loss = 0.0

        self.model.eval()
        for idx, (inputs, labels) in enumerate(eval_loader):
            inputs = inputs.view(inputs.size(0), -1)

            loss = self._eval_step(inputs, labels)

            eval_loss += loss

        self.model.train()

        eval_loss /= idx

        return eval_loss

    @overrides
    def predict(self, inputs: torch.Tensor) -> None:
        self.model.eval()
        preds = self.model(inputs)
        self.model.train()

        return preds

### Defining the Model

Once the data is loaded, we can define any CV-based model. In this example, we will create a simple linear model using PyTorch:

In [43]:
from torch import nn

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = Model()

### Creating and Training with the Trainer

After loading the data and creating the data, we need to plug these instances into the `PyTorchTrainer` and start the training, as follows:

In [44]:

from archai.datasets.cv.mnist_dataset_provider import MnistDatasetProvider

dataset_provider = MnistDatasetProvider()
train_dataset = dataset_provider.get_train_dataset()

trainer = PyTorchTrainer(model, train_dataset=train_dataset)
trainer.train()

Batch 0 loss: 2.3260679244995117
Batch 10 loss: 2.170200304551558
Batch 20 loss: 2.013322977792649
Batch 30 loss: 1.88041107116207
Batch 40 loss: 1.7696490898364927
Batch 50 loss: 1.6737279284234141
Batch 60 loss: 1.580565809226427
Batch 70 loss: 1.5032367882594255
Batch 80 loss: 1.4377957427943195
Batch 90 loss: 1.3756403844435137
Batch 100 loss: 1.323841373519142
Batch 110 loss: 1.2777430780299075
Batch 120 loss: 1.2366919005212704
Batch 130 loss: 1.1966573777999587
Batch 140 loss: 1.1582764298357862
Batch 150 loss: 1.1270605416487385
Batch 160 loss: 1.0991555715199584
Batch 170 loss: 1.071832075105076
Batch 180 loss: 1.0470546456002399
Batch 190 loss: 1.0244618591837857
Batch 200 loss: 1.0021289893940313
Batch 210 loss: 0.9831973584059855
Batch 220 loss: 0.9649985725253956
Batch 230 loss: 0.9455169081945956
Batch 240 loss: 0.9282616229720135
Batch 250 loss: 0.911252826689724
Batch 260 loss: 0.894186370674221
Batch 270 loss: 0.8811650495027704
Batch 280 loss: 0.8695434934303854
Batch

### Evaluating and Predicting with the Trainer

Finally, we evaluate our pre-trained model with the validation set and create a set of random-based inputs to calculate the model's predictions:

In [45]:
val_dataset = dataset_provider.get_val_dataset()

eval_loss  = trainer.evaluate(eval_dataset=val_dataset)
print(f"Eval loss: {eval_loss}")

inputs = torch.zeros(1, 28 * 28)
preds = trainer.predict(inputs)
print(f"Predictions: {preds}")

Eval loss: 0.3327318702896054
Predictions: tensor([[-0.1380,  0.1920, -0.0317, -0.0902,  0.0679,  0.1481, -0.0149,  0.0998,
         -0.2689, -0.0635]], grad_fn=<AddmmBackward0>)
