# salad.solver Demo

In this tutorial we will step through the use of solvers within the ``salad`` package.
Solvers are located in the ``salad.solver`` package and form a class hierarchy, with different solvers serving different use cases.
All solvers are subclasses of ``salad.solver.Solver``, which implements the basic training mechanics, logging functions, and similar utilities.

In [1]:
import numpy
import torch
from torch import nn
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

The core idea behind solver classes is to cleanly separate the parts that are shared across deep learning and domain adaptation experiments (such as the training loop, logging, and checkpointing) from the actual algorithmic contribution of a paper.
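To make this separation concrete, here is a minimal plain-Python sketch of the pattern (a template method with overridable hooks). It is not salad's implementation: `BaseSolver`, `LinearRegressionSolver`, `train_step` and `history` are hypothetical names chosen for illustration, and only the `_init_models` hook name mirrors the one used later in this notebook. The base class owns the loop and the logging; the subclass contributes nothing but the algorithm.

```python
class BaseSolver:
    """Hypothetical base class: owns the training loop and logging."""

    def __init__(self, dataset, n_epochs=1):
        self.dataset = dataset
        self.n_epochs = n_epochs
        self.history = []    # shared logging: one (epoch, loss) entry per step
        self._init_models()  # hook for subclasses

    def _init_models(self):
        """Hook: create models/parameters. The default does nothing."""

    def train_step(self, batch):
        """Hook: the algorithm-specific update; must return the loss."""
        raise NotImplementedError

    def train(self):
        # Generic loop, written once and reused by every subclass
        for epoch in range(self.n_epochs):
            for batch in self.dataset:
                loss = self.train_step(batch)
                self.history.append((epoch, loss))
        return self.history


class LinearRegressionSolver(BaseSolver):
    """Only the algorithm: 1-D least squares fitted by gradient descent."""

    def _init_models(self):
        self.w = 0.0
        self.lr = 0.1

    def train_step(self, batch):
        x, y = batch
        err = self.w * x - y
        self.w -= self.lr * 2 * err * x  # gradient of err**2 w.r.t. w
        return err ** 2


solver = LinearRegressionSolver([(1.0, 2.0), (2.0, 4.0)], n_epochs=50)
history = solver.train()
print(round(solver.w, 3))  # 2.0
```

Because the loop and the logging live in the base class, a new method only needs to override the hooks; in a real solver, checkpointing would be added to the base class in the same way.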

Organizing solvers as a class hierarchy also lets derived solvers reuse functionality implemented in their base classes.
In the following, we quickly illustrate the main ideas with a simple example.
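The reuse mentioned here relies on cooperative `super()` calls: each level of the hierarchy first runs its parent's hook and then adds its own parts. The following plain-Python sketch shows this with loss terms. `ToySolver`, `ClassificationSolver`, `EntropySolver`, `_init_losses` and `total_loss` are hypothetical names for illustration and are not taken from salad's API.

```python
import math


class ToySolver:
    """Hypothetical base: collects named loss terms and sums them."""

    def __init__(self):
        self.losses = {}
        self._init_losses()

    def _init_losses(self):
        """Hook: subclasses add loss terms here."""

    def total_loss(self, probs, label):
        return sum(fn(probs, label) for fn in self.losses.values())


class ClassificationSolver(ToySolver):
    def _init_losses(self):
        super()._init_losses()
        # Negative log-likelihood of the true class
        self.losses["nll"] = lambda p, y: -math.log(p[y])


class EntropySolver(ClassificationSolver):
    def _init_losses(self):
        super()._init_losses()  # inherit the parent's "nll" term unchanged
        # Add an entropy penalty on the prediction (the label is not used)
        self.losses["entropy"] = lambda p, y: -sum(q * math.log(q) for q in p if q > 0)


print(list(EntropySolver().losses))  # ['nll', 'entropy']
```

`EntropySolver` never re-implements the classification loss; it inherits it and only adds its own term, which is the kind of reuse a solver hierarchy is meant to enable.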

In [2]:
from salad.solver import Solver

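We will first create a simple model and a toy dataset. The model takes two-dimensional inputs and returns both its 64-dimensional features and the logits for two classes.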

In [10]:
from torch import nn

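class SmallModel(nn.Module):
    """Model for the toy dataset (2-D inputs, 2 classes)."""

    def __init__(self, track_stats=True):
        super().__init__()

        # Feature extractor: two hidden layers of width 64 with batch norm
        self.features = nn.Sequential(
            nn.Linear(2, 64),
            nn.BatchNorm1d(64, track_running_stats=track_stats),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(64, 64),
            nn.BatchNorm1d(64, track_running_stats=track_stats),
            nn.ReLU(),
        )
        # Linear classifier on top of the 64-dimensional features
        self.classifier = nn.Linear(64, 2)
        self._weight_init()

    def parameters(self, d=0):
        # `d` (presumably a domain index) is ignored:
        # the same parameters are returned for every value of `d`.
        return super().parameters()

    def _weight_init(self):
        # Xavier (Glorot) normal initialization for all linear layers
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x, d=None):
        # `d` is unused here; it only keeps the call signature uniform
        z = self.features(x)    # 64-dimensional features
        y = self.classifier(z)  # class logits
        return z, y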

In [9]:
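class MySolver(Solver):

    def __init__(self, model, dataset, *args, **kwargs):
        # Store the model first, in case the base constructor
        # already calls the _init_* hooks below.
        self.model = model
        super().__init__(model, dataset, *args, **kwargs)

    def _init_models(self, **kwargs):
        super()._init_models(**kwargs)  # run the base-class setup first
        # Register the model with the solver. The (model, name) argument
        # order is an assumption; check it with `Solver.register_model?`.
        self.register_model(self.model, "model")

# Toy data wrapped in a DataLoader (assumption: the solver accepts one).
# Batches need more than one sample because of BatchNorm1d.
from torch.utils.data import TensorDataset, DataLoader
X = torch.randn(200, 2)
y = (X[:, 0] * X[:, 1] > 0).long()  # label: do both coordinates share a sign?
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

# Both arguments are required: MySolver() alone raises a TypeError.
solver = MySolver(SmallModel(), loader)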