In [1]:
import math
import numpy as np

import torch
from torch import nn
from torch.nn import Module
from torch import Tensor, einsum
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
import torch.nn.functional as F

from torchmetrics import Metric as TorchMetric

import einops
from einops import rearrange

from operations.data import generate_dataset
from operations.data import generate_dataloader
from operations.embeds import Embedding
from operations.model import NewGELU
from operations.utils import generate_splits
from operations.utils import preprocess
from operations.utils import CutMix, Mixup

from typing import Optional, Dict, List, Tuple, Union, Any
from argparse import Namespace
from abc import ABC, abstractmethod

In [9]:
# metrics
class Metric(object):
    """Interface that every custom metric has to follow.

    The base class only stores an empty ``_name`` and raises from both
    ``reset`` and ``__call__``; concrete metrics override all three.
    """

    def __init__(self):
        # Label under which the metric value is reported by MultipleMetrics.
        self._name = ""

    def reset(self):
        """Clear accumulated state. Not implemented on the base class."""
        raise NotImplementedError("Custom Metrics must implement this function")

    def __call__(self, y_pred: Tensor, y_true: Tensor):
        """Score ``y_pred`` against ``y_true``. Not implemented on the base class."""
        raise NotImplementedError("Custom Metrics must implement this function")


class MultipleMetrics(object):
    """Evaluate a collection of metrics and gather the results in one dict.

    Parameters
    ----------
    metrics: List
        Metric instances or metric classes. Classes are instantiated with no
        arguments; instances are used as given.
    prefix: str, default = ""
        String prepended to every key of the returned dict.
    """

    def __init__(self, metrics: List[Metric], prefix: str = ""):
        # Accept both classes and ready-made instances.
        self._metrics = [m() if isinstance(m, type) else m for m in metrics]
        self.prefix = prefix

    def reset(self):
        """Call ``reset`` on every wrapped metric."""
        for wrapped in self._metrics:
            wrapped.reset()

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> Dict:
        """Return ``{prefix + name: value}`` for every supported metric.

        Custom ``Metric`` objects are called directly and keyed by ``_name``.
        ``TorchMetric`` objects are updated with ``y_true`` cast to int, then
        computed and converted to a numpy value keyed by their class name.
        Objects of any other type are skipped.
        """
        results = {}
        for wrapped in self._metrics:
            if isinstance(wrapped, Metric):
                value = wrapped(y_pred, y_true)
                results[self.prefix + wrapped._name] = value
                continue
            if isinstance(wrapped, TorchMetric):
                wrapped.update(y_pred, y_true.int())  # type: ignore[attr-defined]
                value = wrapped.compute().detach().cpu().numpy()
                results[self.prefix + type(wrapped).__name__] = value
        return results

In [10]:
# callbacks
class Callback(object):
    """
    Base class used to build new callbacks.

    Every hook is a no-op here; subclasses override the ones they need. The
    ``set_*`` methods simply store the object they receive on the callback.
    """

    def __init__(self):
        """Nothing to initialise on the base class."""

    def set_params(self, params):
        """Store ``params`` as ``self.params``."""
        self.params = params

    def set_model(self, model: Any):
        """Store ``model`` as ``self.model``."""
        self.model = model

    def set_trainer(self, trainer: Any):
        """Store ``trainer`` as ``self.trainer``."""
        self.trainer = trainer

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        """Hook for the start of an epoch."""

    def on_epoch_end(
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        """Hook for the end of an epoch."""

    def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
        """Hook for the start of a batch."""

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        """Hook for the end of a batch."""

    def on_train_begin(self, logs: Optional[Dict] = None):
        """Hook for the start of training."""

    def on_train_end(self, logs: Optional[Dict] = None):
        """Hook for the end of training."""

    def on_eval_begin(self, logs: Optional[Dict] = None):
        """Hook for the start of evaluation.

        At the moment only used to reset metrics before eval.
        """

In [11]:
# transforms
from torchvision.transforms import (
    Pad,
    Lambda,
    Resize,
    Compose,
    TenCrop,
    FiveCrop,
    ToTensor,
    Grayscale,
    Normalize,
    CenterCrop,
    RandomCrop,
    ToPILImage,
    ColorJitter,
    PILToTensor,
    RandomApply,
    RandomOrder,
    GaussianBlur,
    RandomAffine,
    RandomChoice,
    RandomInvert,
    RandomErasing,
    RandomEqualize,
    RandomRotation,
    RandomSolarize,
    RandomGrayscale,
    RandomPosterize,
    ConvertImageDtype,
    InterpolationMode,
    RandomPerspective,
    RandomResizedCrop,
    RandomAutocontrast,
    RandomVerticalFlip,
    LinearTransformation,
    RandomHorizontalFlip,
    RandomAdjustSharpness,
)

# Type alias over the torchvision names imported above. It is used to annotate
# the `transforms` argument of BaseTrainer (a list of such objects).
# NOTE(review): InterpolationMode is part of the union too; it looks like an
# option enum rather than a transform class, so it is presumably here only
# because the import list was copied wholesale -- confirm before relying on it.
Transforms = Union[
    Pad,
    Lambda,
    Resize,
    Compose,
    TenCrop,
    FiveCrop,
    ToTensor,
    Grayscale,
    Normalize,
    CenterCrop,
    RandomCrop,
    ToPILImage,
    ColorJitter,
    PILToTensor,
    RandomApply,
    RandomOrder,
    GaussianBlur,
    RandomAffine,
    RandomChoice,
    RandomInvert,
    RandomErasing,
    RandomEqualize,
    RandomRotation,
    RandomSolarize,
    RandomGrayscale,
    RandomPosterize,
    ConvertImageDtype,
    InterpolationMode,
    RandomPerspective,
    RandomResizedCrop,
    RandomAutocontrast,
    RandomVerticalFlip,
    LinearTransformation,
    RandomHorizontalFlip,
    RandomAdjustSharpness,
]

In [12]:
# initializer
class Initializer(object):
    """Interface for objects that initialise the parameters of a model.

    Subclasses override ``__call__``; the base implementation only raises.
    """

    def __call__(self, model: nn.Module):
        """Initialise ``model``. Not implemented on the base class."""
        raise NotImplementedError("Initializer must implement this method")

In [13]:
class BaseTabularModelWithAttention(nn.Module):
    """Common base for attention-based tabular models.

    The constructor stores every argument on the instance, builds the
    ``SameSizeCatAndContEmbeddings`` layer (``input_dim`` is used as the
    embedding size) and resolves the optional activations applied to the
    categorical and continuous embeddings.

    Subclasses must implement the ``output_dim`` and ``attention_weights``
    properties.
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]],
        cat_embed_dropout: float,
        use_cat_bias: bool,
        cat_embed_activation: Optional[str],
        full_embed_dropout: bool,
        shared_embed: bool,
        add_shared_embed: bool,
        frac_shared_embed: float,
        continuous_cols: Optional[List[str]],
        cont_norm_layer: Optional[str],
        embed_continuous: bool,
        cont_embed_dropout: float,
        use_cont_bias: bool,
        cont_embed_activation: Optional[str],
        input_dim: int,
    ):
        super().__init__()

        # Categorical-embedding configuration
        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.cat_embed_dropout = cat_embed_dropout
        self.use_cat_bias = use_cat_bias
        self.cat_embed_activation = cat_embed_activation
        self.full_embed_dropout = full_embed_dropout
        self.shared_embed = shared_embed
        self.add_shared_embed = add_shared_embed
        self.frac_shared_embed = frac_shared_embed

        # Continuous-embedding configuration
        self.continuous_cols = continuous_cols
        self.cont_norm_layer = cont_norm_layer
        self.embed_continuous = embed_continuous
        self.cont_embed_dropout = cont_embed_dropout
        self.use_cont_bias = use_cont_bias
        self.cont_embed_activation = cont_embed_activation

        self.input_dim = input_dim

        self.cat_and_cont_embed = SameSizeCatAndContEmbeddings(
            input_dim,
            column_idx,
            cat_embed_input,
            cat_embed_dropout,
            use_cat_bias,
            full_embed_dropout,
            shared_embed,
            add_shared_embed,
            frac_shared_embed,
            continuous_cols,
            cont_norm_layer,
            embed_continuous,
            cont_embed_dropout,
            use_cont_bias,
        )
        # Activations are optional: None means the embeddings are used as-is.
        self.cat_embed_act_fn = (
            get_activation_fn(cat_embed_activation)
            if cat_embed_activation is not None
            else None
        )
        self.cont_embed_act_fn = (
            get_activation_fn(cont_embed_activation)
            if cont_embed_activation is not None
            else None
        )

    def _get_embeddings(self, X: Tensor) -> Tensor:
        """Embed ``X`` and return the (optionally activated) embeddings.

        When both categorical and continuous embeddings are produced they are
        concatenated along dim 1, categorical first.

        Raises
        ------
        ValueError
            If the embedding layer returns neither categorical nor continuous
            embeddings (previously this surfaced as an ``UnboundLocalError``).
        """
        x_cat, x_cont = self.cat_and_cont_embed(X)
        if x_cat is None and x_cont is None:
            raise ValueError(
                "No embeddings were produced: at least one of 'cat_embed_input' "
                "or 'continuous_cols' must be provided"
            )

        if x_cat is not None and self.cat_embed_act_fn is not None:
            x_cat = self.cat_embed_act_fn(x_cat)
        if x_cont is not None and self.cont_embed_act_fn is not None:
            x_cont = self.cont_embed_act_fn(x_cont)

        if x_cont is None:
            return x_cat
        if x_cat is None:
            return x_cont
        return torch.cat([x_cat, x_cont], 1)

    @property
    def output_dim(self) -> int:
        """Size of the model output; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def attention_weights(self):
        """Attention weights of the model; must be provided by subclasses."""
        raise NotImplementedError

In [17]:
class SaintEncoder(nn.Module):
    """One SAINT block: column attention followed by row attention.

    The column stage attends over the feature axis of each sample with model
    size ``input_dim``. The row stage flattens every sample to a single
    vector of size ``n_feat * input_dim`` and attends over the batch axis.
    Each attention and feed-forward sublayer is wrapped in an ``AddNorm``.
    """

    def __init__(
        self,
        input_dim: int,
        n_heads: int,
        use_bias: bool,
        attn_dropout: float,
        ff_dropout: float,
        activation: str,
        n_feat: int,
    ):
        super().__init__()

        self.n_feat = n_feat

        # Column (feature-wise) stage
        self.col_attn = MultiHeadedAttention(input_dim, n_heads, use_bias, attn_dropout)
        self.col_attn_ff = FeedForward(input_dim, ff_dropout, activation)
        self.col_attn_addnorm = AddNorm(input_dim, attn_dropout)
        self.col_attn_ff_addnorm = AddNorm(input_dim, ff_dropout)

        # Row (sample-wise) stage works on the flattened sample representation
        row_dim = n_feat * input_dim
        self.row_attn = MultiHeadedAttention(row_dim, n_heads, use_bias, attn_dropout)
        self.row_attn_ff = FeedForward(row_dim, ff_dropout, activation)
        self.row_attn_addnorm = AddNorm(row_dim, attn_dropout)
        self.row_attn_ff_addnorm = AddNorm(row_dim, ff_dropout)

    def forward(self, X: Tensor) -> Tensor:
        """Apply column then row attention; input and output are ``(b, n, d)``."""
        col_out = self.col_attn_addnorm(X, self.col_attn)
        col_out = self.col_attn_ff_addnorm(col_out, self.col_attn_ff)

        # Treat the whole batch as one sequence whose tokens are the samples.
        row_in = rearrange(col_out, "b n d -> 1 b (n d)")
        row_out = self.row_attn_addnorm(row_in, self.row_attn)
        row_out = self.row_attn_ff_addnorm(row_out, self.row_attn_ff)

        return rearrange(row_out, "1 b (n d) -> b n d", n=self.n_feat)

In [18]:
class FeedForward(nn.Module):
    """Two-layer position-wise feed-forward network.

    The hidden size is ``int(input_dim * mult)``. When the activation name
    ends with ``"glu"`` the first projection is twice as wide, while the
    second projection still takes ``ff_hidden_dim`` inputs.
    """

    def __init__(
        self,
        input_dim: int,
        dropout: float,
        activation: str,
        mult: float = 4.0,
    ):
        super().__init__()
        hidden_dim = int(input_dim * mult)
        gated = activation.endswith("glu")
        first_out_dim = hidden_dim * 2 if gated else hidden_dim

        self.w_1 = nn.Linear(input_dim, first_out_dim)
        self.w_2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = get_activation_fn(activation)

    def forward(self, X: Tensor) -> Tensor:
        """Linear -> activation -> dropout -> linear."""
        hidden = self.w_1(X)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

    
class AddNorm(nn.Module):
    """Residual connection followed by layer normalisation.

    Computes ``LayerNorm(X + Dropout(sublayer(X)))``, i.e. normalisation is
    applied after the residual sum.
    """

    def __init__(self, input_dim: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(input_dim)

    def forward(self, X: Tensor, sublayer: nn.Module) -> Tensor:
        """Run ``sublayer`` on ``X``, add the residual and normalise."""
        residual_branch = self.dropout(sublayer(X))
        return self.ln(X + residual_branch)

    
def get_activation_fn(activation):
    """Return a new activation module for the given name.

    Parameters
    ----------
    activation: str
        One of 'relu', 'leaky_relu', 'tanh', 'gelu', 'geglu', 'reglu' or
        'softplus'.

    Returns
    -------
    nn.Module
        A freshly constructed activation layer. 'relu' and 'leaky_relu' are
        created with ``inplace=True``.

    Raises
    ------
    ValueError
        If ``activation`` is not one of the supported names.
    """
    # Defined locally so the error message below can always be built; the
    # original referenced a name that is not defined anywhere in this file.
    allowed_activations = [
        "relu",
        "leaky_relu",
        "tanh",
        "gelu",
        "geglu",
        "reglu",
        "softplus",
    ]

    if activation == "relu":
        return nn.ReLU(inplace=True)
    elif activation == "leaky_relu":
        return nn.LeakyReLU(inplace=True)
    elif activation == "tanh":
        return nn.Tanh()
    elif activation == "gelu":
        return nn.GELU()
    elif activation == "geglu":
        return GEGLU()
    elif activation == "reglu":
        return REGLU()
    elif activation == "softplus":
        return nn.Softplus()
    else:
        raise ValueError(
            "Only the following activation functions are currently "
            "supported: {}. Note that 'geglu' and 'reglu' "
            "should only be used as transformer's activations".format(
                ", ".join(allowed_activations)
            )
        )

In [19]:
class MultiHeadedAttention(nn.Module):
    """Scaled dot-product attention with ``n_heads`` heads.

    Queries are projected from ``query_dim`` (defaults to ``input_dim``) and
    keys/values from ``input_dim`` through a single fused projection. The
    output projection back to ``query_dim`` only exists when ``n_heads > 1``.
    The softmaxed weights of the last forward pass (before dropout) are kept
    in ``self.attn_weights``.
    """

    def __init__(
        self,
        input_dim: int,
        n_heads: int,
        use_bias: bool,
        dropout: float,
        query_dim: Optional[int] = None,
    ):
        super().__init__()

        assert input_dim % n_heads == 0, "'input_dim' must be divisible by 'n_heads'"

        self.head_dim = input_dim // n_heads
        self.n_heads = n_heads

        self.dropout = nn.Dropout(dropout)

        if query_dim is None:
            query_dim = input_dim
        self.q_proj = nn.Linear(query_dim, input_dim, bias=use_bias)
        # Keys and values share one projection and are split in forward().
        self.kv_proj = nn.Linear(input_dim, input_dim * 2, bias=use_bias)
        self.out_proj = (
            nn.Linear(input_dim, query_dim, bias=use_bias) if n_heads > 1 else None
        )

    def forward(self, X_Q: Tensor, X_KV: Optional[Tensor] = None) -> Tensor:
        """Attend from ``X_Q`` to ``X_KV`` (self-attention when ``X_KV`` is None).

        Axis letters used in the patterns below:
        b: batch size, s: query sequence length, l: key/value sequence length,
        m: either s or l, h: number of heads, d: head_dim.
        """
        queries = self.q_proj(X_Q)
        if X_KV is None:
            X_KV = X_Q
        keys, values = self.kv_proj(X_KV).chunk(2, dim=-1)

        # Split the feature axis into heads.
        queries, keys, values = (
            rearrange(t, "b m (h d) -> b h m d", h=self.n_heads)
            for t in (queries, keys, values)
        )

        scale = math.sqrt(self.head_dim)
        scores = einsum("b h s d, b h l d -> b h s l", queries, keys) / scale
        weights = scores.softmax(dim=-1)
        self.attn_weights = weights

        weights = self.dropout(weights)
        context = einsum("b h s l, b h l d -> b h s d", weights, values)
        merged = rearrange(context, "b h s d -> b s (h d)", h=self.n_heads)

        if self.out_proj is None:
            return merged
        return self.out_proj(merged)

In [20]:
class MLP(nn.Module):
    """Stack of dense layers built from consecutive pairs of ``d_hidden``.

    Parameters
    ----------
    d_hidden: List[int]
        Layer sizes; layer ``k`` maps ``d_hidden[k]`` to ``d_hidden[k + 1]``.
    activation: str
        Activation name forwarded to every dense layer.
    dropout: float, list of float, or None
        A float is repeated for every layer; a falsy value means no dropout.
    batchnorm: bool
        Whether dense layers use batch normalisation.
    batchnorm_last: bool
        Whether the last dense layer also uses it (only if ``batchnorm``).
    linear_first: bool
        Forwarded unchanged to every dense layer.
    """

    def __init__(
        self,
        d_hidden: List[int],
        activation: str,
        dropout: Optional[Union[float, List[float]]],
        batchnorm: bool,
        batchnorm_last: bool,
        linear_first: bool,
    ):
        super().__init__()

        n_dims = len(d_hidden)
        if not dropout:
            dropout = [0.0] * n_dims
        elif isinstance(dropout, float):
            dropout = [dropout] * n_dims

        self.mlp = nn.Sequential()
        for layer_idx, (d_in, d_out) in enumerate(zip(d_hidden, d_hidden[1:])):
            is_last_layer = layer_idx == n_dims - 2
            use_batchnorm = batchnorm and (not is_last_layer or batchnorm_last)
            layer = dense_layer(
                d_in,
                d_out,
                activation,
                dropout[layer_idx],
                use_batchnorm,
                linear_first,
            )
            self.mlp.add_module("dense_layer_{}".format(layer_idx), layer)

    def forward(self, X: Tensor) -> Tensor:
        """Pass ``X`` through the dense layers in order."""
        return self.mlp(X)

In [21]:
# embedding_layer
class SameSizeCatAndContEmbeddings(nn.Module):
    """Embedding front-end for categorical and continuous columns.

    Categorical columns (if any) go through ``SameSizeCatEmbeddings``.
    Continuous columns (if any) are selected by index, cast to float,
    normalised ('layernorm', 'batchnorm' or identity for any other value)
    and, when ``embed_continuous`` is set, embedded by ``ContEmbeddings``.
    """

    def __init__(
        self,
        embed_dim: int,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]],
        cat_embed_dropout: float,
        use_cat_bias: bool,
        full_embed_dropout: bool,
        shared_embed: bool,
        add_shared_embed: bool,
        frac_shared_embed: float,
        continuous_cols: Optional[List[str]],
        cont_norm_layer: str,
        embed_continuous: bool,
        cont_embed_dropout: float,
        use_cont_bias: bool,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.cat_embed_input = cat_embed_input
        self.continuous_cols = continuous_cols
        self.embed_continuous = embed_continuous

        # Categorical branch
        if cat_embed_input is not None:
            self.cat_embed = SameSizeCatEmbeddings(
                embed_dim=embed_dim,
                column_idx=column_idx,
                embed_input=cat_embed_input,
                embed_dropout=cat_embed_dropout,
                use_bias=use_cat_bias,
                full_embed_dropout=full_embed_dropout,
                shared_embed=shared_embed,
                add_shared_embed=add_shared_embed,
                frac_shared_embed=frac_shared_embed,
            )

        # Continuous branch
        if continuous_cols is not None:
            n_cont = len(continuous_cols)
            self.cont_idx = [column_idx[name] for name in continuous_cols]

            if cont_norm_layer == "layernorm":
                self.cont_norm = nn.LayerNorm(n_cont)
            elif cont_norm_layer == "batchnorm":
                self.cont_norm = nn.BatchNorm1d(n_cont)
            else:
                # Any other value (including None) disables normalisation.
                self.cont_norm = nn.Identity()

            if self.embed_continuous:
                self.cont_embed = ContEmbeddings(
                    n_cont,
                    embed_dim,
                    cont_embed_dropout,
                    use_cont_bias,
                )

    def forward(self, X: Tensor) -> Tuple[Tensor, Any]:
        """Return ``(x_cat, x_cont)``; either entry is None if not configured."""
        x_cat = self.cat_embed(X) if self.cat_embed_input is not None else None

        x_cont = None
        if self.continuous_cols is not None:
            cont_values = X[:, self.cont_idx].float()
            x_cont = self.cont_norm(cont_values)
            if self.embed_continuous:
                x_cont = self.cont_embed(x_cont)

        return x_cat, x_cont
    
class SameSizeCatEmbeddings(nn.Module):
    """Embed categorical columns into vectors of a common size ``embed_dim``.

    Two modes are supported:

    - ``shared_embed=False``: a single ``nn.Embedding`` table with
      ``n_tokens + 1`` rows (index 0 is the padding index) is shared by all
      columns, an optional per-column bias is added and dropout is applied.
    - ``shared_embed=True``: one ``SharedEmbeddings`` layer per column; no
      bias is applied in this mode (a ``UserWarning`` is emitted if
      ``use_bias`` was requested).

    Parameters
    ----------
    embed_dim: int
        Size of every embedding vector.
    column_idx: Dict[str, int]
        Maps column names to their position in the input tensor. The presence
        of the key ``"cls_token"`` enables [CLS]-token handling.
    embed_input: List[Tuple[str, int]]
        ``(column name, number of unique values)`` per categorical column.
    embed_dropout: float
        Dropout probability for the embeddings.
    use_bias: bool
        Whether to learn a per-column bias (one row less when a [CLS] token
        is present, since no bias is learned for it).
    full_embed_dropout: bool
        Use ``FullEmbeddingDropout`` instead of ``nn.Dropout`` (non-shared
        mode) / forwarded to ``SharedEmbeddings`` (shared mode).
    shared_embed, add_shared_embed, frac_shared_embed
        Shared-embedding options; the last two are forwarded to
        ``SharedEmbeddings``.

    Raises
    ------
    ValueError
        If ``embed_input`` is None.
    """

    def __init__(
        self,
        embed_dim: int,
        column_idx: Dict[str, int],
        embed_input: Optional[List[Tuple[str, int]]],
        embed_dropout: float,
        use_bias: bool,
        full_embed_dropout: bool,
        shared_embed: bool,
        add_shared_embed: bool,
        frac_shared_embed: float,
    ):
        super(SameSizeCatEmbeddings, self).__init__()

        # Every code path below iterates over embed_input, so fail early with
        # a clear message instead of an opaque "NoneType is not iterable".
        if embed_input is None:
            raise ValueError(
                "'embed_input' must be a list of (column name, n unique values) tuples"
            )

        self.n_tokens = sum([ei[1] for ei in embed_input])
        self.column_idx = column_idx
        self.embed_input = embed_input
        self.shared_embed = shared_embed
        self.with_cls_token = "cls_token" in column_idx

        # '.' is replaced because module names cannot contain dots.
        self.embed_layers_names = {
            e[0]: e[0].replace(".", "_") for e in self.embed_input
        }

        categorical_cols = [ei[0] for ei in embed_input]
        self.cat_idx = [self.column_idx[col] for col in categorical_cols]

        if use_bias:
            if shared_embed:
                # Local import: 'warnings' is not imported at the top of the file.
                import warnings

                warnings.warn(
                    "The current implementation of 'SharedEmbeddings' does not use bias",
                    UserWarning,
                )
            n_cat = (
                len(categorical_cols) - 1
                if self.with_cls_token
                else len(categorical_cols)
            )
            self.bias = nn.init.kaiming_uniform_(
                nn.Parameter(torch.Tensor(n_cat, embed_dim)), a=math.sqrt(5)
            )
        else:
            self.bias = None

        # Categorical: val + 1 because 0 is reserved for padding/unseen categories.
        if self.shared_embed:
            self.embed = nn.ModuleDict(
                {
                    "emb_layer_"
                    + self.embed_layers_names[col]: SharedEmbeddings(
                        val if col == "cls_token" else val + 1,
                        embed_dim,
                        embed_dropout,
                        full_embed_dropout,
                        add_shared_embed,
                        frac_shared_embed,
                    )
                    for col, val in self.embed_input
                }
            )
        else:
            self.embed = nn.Embedding(self.n_tokens + 1, embed_dim, padding_idx=0)
            if full_embed_dropout:
                self.dropout = FullEmbeddingDropout(embed_dropout)
            else:
                self.dropout = nn.Dropout(embed_dropout)

    def forward(self, X: Tensor) -> Tensor:
        """Embed the categorical columns of ``X``.

        Columns are selected from dim 1 of ``X`` and cast to long. The result
        has one embedding per categorical column stacked along dim 1.
        """
        if self.shared_embed:
            cat_embed = [
                self.embed["emb_layer_" + self.embed_layers_names[col]](  # type: ignore[index]
                    X[:, self.column_idx[col]].long()
                ).unsqueeze(
                    1
                )
                for col, _ in self.embed_input
            ]
            x = torch.cat(cat_embed, 1)
        else:
            # A single table serves all columns.
            # NOTE(review): presumably the values are encoded with ids that are
            # unique across columns -- confirm against the preprocessing step.
            x = self.embed(X[:, self.cat_idx].long())
            if self.bias is not None:
                if self.with_cls_token:
                    # no bias to be learned for the [CLS] token
                    cls_bias = torch.zeros(
                        1,
                        self.bias.shape[1],
                        device=x.device,
                        dtype=self.bias.dtype,
                    )
                    bias = torch.cat([cls_bias, self.bias])
                else:
                    bias = self.bias
                x = x + bias.unsqueeze(0)

            x = self.dropout(x)
        return x

class ContEmbeddings(nn.Module):
    """Embed each continuous column by scaling a learned per-column vector.

    For column ``j`` the embedding is ``weight[j] * value (+ bias[j])``, so an
    input with ``n_cont_cols`` columns gains a trailing axis of size
    ``embed_dim``. Dropout is applied to the result in training mode.
    """

    def __init__(
        self,
        n_cont_cols: int,
        embed_dim: int,
        embed_dropout: float,
        use_bias: bool,
    ):
        super().__init__()

        self.n_cont_cols = n_cont_cols
        self.embed_dim = embed_dim
        self.embed_dropout = embed_dropout
        self.use_bias = use_bias

        # Parameters are allocated uninitialised and then filled in place.
        self.weight = nn.Parameter(torch.Tensor(n_cont_cols, embed_dim))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if use_bias:
            self.bias = nn.Parameter(torch.Tensor(n_cont_cols, embed_dim))
            nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5))
        else:
            self.bias = None

    def forward(self, X: Tensor) -> Tensor:
        """Return the embedded continuous features with dropout applied."""
        embedded = self.weight[None, :, :] * X[:, :, None]
        if self.bias is not None:
            embedded = embedded + self.bias[None, :, :]
        return F.dropout(embedded, p=self.embed_dropout, training=self.training)

    def extra_repr(self) -> str:
        """Summary of the constructor arguments shown in the module repr."""
        return (
            f"{self.n_cont_cols}, {self.embed_dim}, "
            f"embed_dropout={self.embed_dropout}, use_bias={self.use_bias}"
        )

In [22]:
class SAINT(BaseTabularModelWithAttention):
    r"""Defines a [SAINT model](https://arxiv.org/abs/2106.01342) that
    can be used as the `deeptabular` component of a Wide & Deep model or
    independently by itself.

    **NOTE**: This is a slightly modified and enhanced version of the model
    described in the paper.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        _{'education': 0, 'relationship': 1, 'workclass': 2, ...}_
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name and number of unique values,
        e.g. _[(education, 11), ...]_. All columns are embedded with the same
        dimension, `input_dim`.
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False,
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: Optional, str, default = None,
        Activation function for the categorical embeddings, if any. _'tanh'_,
        _'relu'_, _'leaky_relu'_ and _'gelu'_ are supported.
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        `pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If `full_embed_dropout = True`, `cat_embed_dropout` is ignored.
    shared_embed: bool, default = False
        The idea behind `shared_embed` is described in the Appendix A in the
        [TabTransformer paper](https://arxiv.org/abs/2012.06678): the
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns. In other
        words, the idea is to let the model learn which column is embedded
        at the time.
    add_shared_embed: bool, default = False
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) to replace the first
        `frac_shared_embed` with the shared embeddings.
        See `pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if `add_shared_embed
        = False`) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: str, Optional, default = None
        Type of normalization layer applied to the continuous features. Options
        are: _'layernorm'_, _'batchnorm'_ or None (no normalization).
    cont_embed_dropout: float, default = 0.1,
        Continuous embeddings dropout
    use_cont_bias: bool, default = True,
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: str, default = None
        Activation function to be applied to the continuous embeddings, if
        any. _'tanh'_, _'relu'_, _'leaky_relu'_ and _'gelu'_ are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*. Is the number of
        embeddings used to encode the categorical and/or continuous columns
    use_qkv_bias: bool, default = False
        Boolean indicating whether or not to use bias in the Q, K, and V
        projection layers
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    n_blocks: int, default = 2
        Number of SAINT-Transformer blocks.
    attn_dropout: float, default = 0.1
        Dropout that will be applied to the Multi-Head Attention column and
        row layers
    ff_dropout: float, default = 0.2
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "gelu"
        Transformer Encoder activation function. _'tanh'_, _'relu'_,
        _'leaky_relu'_, _'gelu'_, _'geglu'_ and _'reglu'_ are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided no MLP is added and the model
        returns the encoder output directly (the [CLS] token embedding when
        `column_idx` contains `'cls_token'`, otherwise all the embeddings
        flattened).
    mlp_activation: str, default = "relu"
        MLP activation function. _'tanh'_, _'relu'_, _'leaky_relu'_ and
        _'gelu'_ are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating whether the order of the operations in the dense
        layer. If `True: [LIN -> ACT -> BN -> DP]`. If `False: [BN -> DP ->
        LIN -> ACT]`

    Attributes
    ----------
    cat_and_cont_embed: nn.Module
        This is the module that processes the categorical and continuous columns
    encoder: nn.Module
        Sequence of SAINT-Transformer blocks
    mlp: nn.Module or None
        MLP component in the model; None when `mlp_hidden_dims` is None

    Examples
    --------
    >>> import torch
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = SAINT(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols)
    >>> out = model(X_tab)
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: Optional[str] = None,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        use_qkv_bias: bool = False,
        n_heads: int = 8,
        n_blocks: int = 2,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.2,
        transformer_activation: str = "gelu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        # Continuous columns are always embedded for SAINT (embed_continuous=True).
        super(SAINT, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=True,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )

        self.use_qkv_bias = use_qkv_bias
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        # n_feats is the number of embedded columns; each SaintEncoder needs it
        # to size its row-attention layers (n_feats * input_dim).
        self.with_cls_token = "cls_token" in column_idx
        self.n_cat = len(cat_embed_input) if cat_embed_input is not None else 0
        self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
        self.n_feats = self.n_cat + self.n_cont

        # Embeddings are instantiated at the base model
        # Transformer blocks
        self.encoder = nn.Sequential()
        for i in range(n_blocks):
            self.encoder.add_module(
                "saint_block" + str(i),
                SaintEncoder(
                    input_dim,
                    n_heads,
                    use_qkv_bias,
                    attn_dropout,
                    ff_dropout,
                    transformer_activation,
                    self.n_feats,
                ),
            )

        # With a [CLS] token only its embedding is passed on; otherwise all
        # the column embeddings are flattened (see forward()).
        self.mlp_first_hidden_dim = (
            self.input_dim if self.with_cls_token else (self.n_feats * self.input_dim)
        )

        if mlp_hidden_dims is not None:
            self.mlp = MLP(
                [self.mlp_first_hidden_dim] + mlp_hidden_dims,
                mlp_activation,
                mlp_dropout,
                mlp_batchnorm,
                mlp_batchnorm_last,
                mlp_linear_first,
            )
        else:
            self.mlp = None

    def forward(self, X: Tensor) -> Tensor:
        """Embed ``X``, run the SAINT blocks and, if configured, the MLP."""
        x = self._get_embeddings(X)
        x = self.encoder(x)
        if self.with_cls_token:
            # keep only the first position along dim 1
            x = x[:, 0, :]
        else:
            x = x.flatten(1)
        if self.mlp is not None:
            x = self.mlp(x)
        return x

    @property
    def output_dim(self) -> int:
        r"""The output dimension of the model. This is a required property
        necessary to build the `WideDeep` class
        """
        return (
            self.mlp_hidden_dims[-1]
            if self.mlp_hidden_dims is not None
            else self.mlp_first_hidden_dim
        )

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. Each element of the list is a tuple
        where the first and the second elements are the column and row
        attention weights respectively
        The shape of the attention weights is:
        - column attention: $(N, H, F, F)$
        - row attention: $(1, H, N, N)$
        where $N$ is the batch size, $H$ is the number of heads and $F$ is the
        number of features/columns in the dataset

        The weights are those stored by the attention layers during their
        most recent forward pass, so a forward pass must have been run first.
        """
        attention_weights = []
        for blk in self.encoder:
            attention_weights.append(
                (blk.col_attn.attn_weights, blk.row_attn.attn_weights)
            )
        return attention_weights

In [26]:
# base trainer
# source: TabTransformers
class BaseTrainer(ABC):
    r"""Abstract base class with the set-up logic shared by trainers.

    The constructor validates its inputs, moves the model to the selected
    device and builds the loss function, optimizer, learning rate scheduler,
    transforms, callbacks and metrics. Subclasses must implement ``fit``,
    ``predict``, ``predict_proba`` and ``save``.

    Keyword arguments read from ``**kwargs``: ``device``, ``num_workers``,
    ``lambda_sparse`` (only when ``model.is_tabnet``), ``class_weight``,
    ``alpha``, ``gamma`` and ``reducelronplateau_criterion``.
    """

    def __init__(
        self,
        model: SAINT,
        objective: str,
        custom_loss_function: Optional[Module],
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]],
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]],
        initializers: Optional[Union[Initializer, Dict[str, Initializer]]],
        transforms: Optional[List[Transforms]],
        callbacks: Optional[List[Callback]],
        metrics: Optional[Union[List[Metric], List[TorchMetric]]],
        verbose: int,
        seed: int,
        **kwargs,
    ):
        """Validate the inputs and build every training component.

        Raises
        ------
        ValueError
            Propagated from ``_check_inputs`` when the arguments are
            inconsistent with each other.
        """

        self._check_inputs(
            model, objective, optimizers, lr_schedulers, custom_loss_function
        )
        self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)

        self.early_stop = False
        self.verbose = verbose
        self.seed = seed

        self.model = model
        if self.model.is_tabnet:
            self.lambda_sparse = kwargs.get("lambda_sparse", 1e-3)
            self.reducing_matrix = create_explain_matrix(self.model)
        self.model.to(self.device)
        self.model.wd_device = self.device

        self.objective = objective
        self.method = _ObjectiveToMethod.get(objective)

        # Order matters: _set_loss_fn reads self.method and self.device, and
        # _initialize reads self.verbose and self.model.
        self._initialize(initializers)
        self.loss_fn = self._set_loss_fn(objective, custom_loss_function, **kwargs)
        self.optimizer = self._set_optimizer(optimizers)
        self.lr_scheduler = self._set_lr_scheduler(lr_schedulers, **kwargs)
        self.transforms = self._set_transforms(transforms)
        self._set_callbacks_and_metrics(callbacks, metrics)

    @abstractmethod
    def fit(
        self,
        X_wide: Optional[np.ndarray],
        X_tab: Optional[np.ndarray],
        X_text: Optional[np.ndarray],
        X_img: Optional[np.ndarray],
        X_train: Optional[Dict[str, np.ndarray]],
        X_val: Optional[Dict[str, np.ndarray]],
        val_split: Optional[float],
        target: Optional[np.ndarray],
        n_epochs: int,
        validation_freq: int,
        batch_size: int,
    ):
        """Train the model. Must be overridden by subclasses."""
        raise NotImplementedError("Trainer.fit method not implemented")

    @abstractmethod
    def predict(
        self,
        X_wide: Optional[np.ndarray],
        X_tab: Optional[np.ndarray],
        X_text: Optional[np.ndarray],
        X_img: Optional[np.ndarray],
        X_test: Optional[Dict[str, np.ndarray]],
        batch_size: int,
    ) -> np.ndarray:
        """Return predictions. Must be overridden by subclasses."""
        raise NotImplementedError("Trainer.predict method not implemented")

    @abstractmethod
    def predict_proba(
        self,
        X_wide: Optional[np.ndarray],
        X_tab: Optional[np.ndarray],
        X_text: Optional[np.ndarray],
        X_img: Optional[np.ndarray],
        X_test: Optional[Dict[str, np.ndarray]],
        batch_size: int,
    ) -> np.ndarray:
        """Return predicted probabilities. Must be overridden by subclasses."""
        raise NotImplementedError("Trainer.predict_proba method not implemented")

    @abstractmethod
    def save(
        self,
        path: str,
        save_state_dict: bool,
        model_filename: str,
    ):
        """Persist the model. Must be overridden by subclasses."""
        raise NotImplementedError("Trainer.save method not implemented")

    def _restore_best_weights(self):
        """Load the best checkpointed weights unless early stopping did so.

        Nothing happens when an ``EarlyStopping`` callback with
        ``restore_best_weights`` set is registered. Otherwise every
        ``ModelCheckpoint`` callback is inspected: with ``save_best_only`` its
        ``best_state_dict`` is loaded into the model; without it only an
        informative message is printed (when ``self.verbose`` is truthy).
        """
        callbacks = self.callback_container.callbacks
        # Callbacks are matched by class name rather than by isinstance.
        already_restored = any(
            callback.__class__.__name__ == "EarlyStopping"
            and callback.restore_best_weights
            for callback in callbacks
        )
        if already_restored:
            return

        for callback in callbacks:
            if callback.__class__.__name__ != "ModelCheckpoint":
                continue
            if callback.save_best_only:
                if self.verbose:
                    print(
                        f"Model weights restored to best epoch: {callback.best_epoch + 1}"
                    )
                self.model.load_state_dict(callback.best_state_dict)
            elif self.verbose:
                print(
                    "Model weights after training corresponds to the those of the "
                    "final epoch which might not be the best performing weights. Use "
                    "the 'ModelCheckpoint' Callback to restore the best epoch weights."
                )

    def _initialize(self, initializers):
        """Apply the given initializer(s) to ``self.model``.

        ``initializers`` may be ``None`` (no-op), a dict (wrapped in
        ``MultipleInitializer``), an initializer class (instantiated with no
        arguments) or an ``Initializer`` instance. Any other value is
        silently ignored.
        """
        if initializers is None:
            return
        if isinstance(initializers, Dict):
            self.initializer = MultipleInitializer(
                initializers, verbose=self.verbose
            )
            self.initializer.apply(self.model)
        elif isinstance(initializers, type):
            self.initializer = initializers()
            self.initializer(self.model)
        elif isinstance(initializers, Initializer):
            self.initializer = initializers
            self.initializer(self.model)

    def _set_loss_fn(self, objective, custom_loss_function, **kwargs):
        """Return the loss function to train with.

        A non-``None`` ``custom_loss_function`` always wins. Otherwise the
        loss is looked up through ``alias_to_loss``: with the optional
        ``class_weight`` kwarg for non-regression, non-focal objectives; with
        ``alpha`` (default 0.25) and ``gamma`` (default 2.0) for objectives
        containing ``"focal_loss"``; and with no extra arguments otherwise.
        """

        class_weight = (
            torch.tensor(kwargs["class_weight"]).to(self.device)
            if "class_weight" in kwargs
            else None
        )

        if custom_loss_function is not None:
            return custom_loss_function

        is_focal = "focal_loss" in objective
        if self.method not in ["regression", "qregression"] and not is_focal:
            return alias_to_loss(objective, weight=class_weight)
        if is_focal:
            alpha = kwargs.get("alpha", 0.25)
            gamma = kwargs.get("gamma", 2.0)
            return alias_to_loss(objective, alpha=alpha, gamma=gamma)
        return alias_to_loss(objective)

    def _set_optimizer(self, optimizers):
        """Return the optimizer (or ``MultipleOptimizer``) to train with.

        ``None`` falls back to ``torch.optim.Adam`` over all model
        parameters. A single ``Optimizer`` is returned as is. A dict is
        checked so that every child module of the model has an entry and is
        then wrapped in ``MultipleOptimizer``.

        Raises
        ------
        TypeError
            If ``optimizers`` is neither ``None``, an ``Optimizer`` nor a
            dict. Previously this case failed with an ``UnboundLocalError``.
        AssertionError
            If a dict is passed and a model component has no optimizer.
        """
        if optimizers is None:
            return torch.optim.Adam(self.model.parameters())  # type: ignore

        if isinstance(optimizers, Optimizer):
            return optimizers

        if isinstance(optimizers, Dict):
            opt_names = list(optimizers.keys())
            mod_names = [n for n, c in self.model.named_children()]
            # if with_fds - the prediction layer is part of the model and
            # should be optimized with the rest of deeptabular
            # component/model
            if self.model.with_fds:
                if "enf_pos" in mod_names:
                    mod_names.remove("enf_pos")
                mod_names.remove("fds_layer")
                optimizers["deeptabular"].add_param_group(
                    {"params": self.model.fds_layer.pred_layer.parameters()}
                )
            for mn in mod_names:
                assert mn in opt_names, "No optimizer found for {}".format(mn)
            return MultipleOptimizer(optimizers)

        raise TypeError(
            "'optimizers' must be None, an Optimizer or a Dict[str, Optimizer]. "
            "Got {}".format(type(optimizers).__name__)
        )

    def _set_lr_scheduler(self, lr_schedulers, **kwargs):
        """Return the learning rate scheduler and record scheduler flags.

        Side effects: sets ``self.cyclic_lr`` (``True`` when the class name
        of any scheduler contains ``"cycl"``) and, through
        ``_set_reduce_on_plateau_criterion``, ``self.reducelronplateau`` and
        ``self.reducelronplateau_criterion``.
        """

        # ReduceLROnPlateau is special
        reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None)

        self._set_reduce_on_plateau_criterion(
            lr_schedulers, reducelronplateau_criterion
        )

        if lr_schedulers is None:
            lr_scheduler, cyclic_lr = None, False
        elif isinstance(lr_schedulers, (LRScheduler, ReduceLROnPlateau)):
            lr_scheduler = lr_schedulers
            cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
        else:
            # Anything else is handed to MultipleLRScheduler.
            lr_scheduler = MultipleLRScheduler(lr_schedulers)
            scheduler_names = [
                sc.__class__.__name__.lower()
                for _, sc in lr_scheduler._schedulers.items()
            ]
            cyclic_lr = any("cycl" in sn for sn in scheduler_names)

        self.cyclic_lr = cyclic_lr

        return lr_scheduler

    def _set_reduce_on_plateau_criterion(
        self, lr_schedulers, reducelronplateau_criterion
    ):
        """Detect ``ReduceLROnPlateau`` schedulers and store the criterion.

        Sets ``self.reducelronplateau`` to ``True`` when ``lr_schedulers`` is,
        or (for a dict) contains, a ``ReduceLROnPlateau``. When such a
        scheduler is present and no criterion was given, a ``UserWarning`` is
        emitted and the criterion defaults to ``"loss"``.
        """
        import warnings

        self.reducelronplateau = False

        if isinstance(lr_schedulers, Dict):
            for _, scheduler in lr_schedulers.items():
                if isinstance(scheduler, ReduceLROnPlateau):
                    self.reducelronplateau = True
        elif isinstance(lr_schedulers, ReduceLROnPlateau):
            self.reducelronplateau = True

        if self.reducelronplateau and not reducelronplateau_criterion:
            # The original code only instantiated the warning; it has to go
            # through warnings.warn to actually reach the user.
            warnings.warn(
                "The learning rate scheduler of at least one of the model components is of type "
                "ReduceLROnPlateau. The step method in this scheduler requires a 'metrics' param "
                "that can be either the validation loss or the validation metric. Please, when "
                "instantiating the Trainer, specify which quantity will be tracked using "
                "reducelronplateau_criterion = 'loss' (default) or reducelronplateau_criterion = 'metric'",
                UserWarning,
            )
            self.reducelronplateau_criterion = "loss"
        else:
            self.reducelronplateau_criterion = reducelronplateau_criterion

    @staticmethod
    def _set_transforms(transforms):
        """Return the result of calling ``MultipleTransforms(transforms)``, or
        ``None`` when no transforms were given."""
        if transforms is None:
            return None
        return MultipleTransforms(transforms)()

    def _set_callbacks_and_metrics(self, callbacks, metrics):
        """Build the callback list, the metric and the callback container.

        ``History`` and ``LRShedulerCallback`` are always registered first.
        Callback classes are instantiated with no arguments. When ``metrics``
        is given it is wrapped in ``MultipleMetrics`` and a matching
        ``MetricCallback`` is appended; otherwise ``self.metric`` is ``None``.
        """
        self.callbacks: List = [History(), LRShedulerCallback()]
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)
        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None
        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self.model)
        self.callback_container.set_trainer(self)

    @staticmethod
    def _check_inputs(
        model,
        objective,
        optimizers,
        lr_schedulers,
        custom_loss_function,
    ):
        """Validate that the constructor arguments are mutually consistent.

        Raises
        ------
        ValueError
            - ``model.with_fds`` is set for a non-regression objective;
            - a multiclass objective is used with ``model.pred_dim == 1``;
            - ``optimizers`` is a dict while ``lr_schedulers`` is a single
              scheduler;
            - a custom loss is given with an objective other than
              ``"binary"``, ``"multiclass"`` or ``"regression"``.
        """
        method = _ObjectiveToMethod.get(objective)

        if model.with_fds and method != "regression":
            raise ValueError(
                "Feature Distribution Smooting can be used only for regression"
            )

        if method == "multiclass" and model.pred_dim == 1:
            raise ValueError(
                "This is a multiclass classification problem but the size of the output layer"
                " is set to 1. Please, set the 'pred_dim' param equal to the number of classes "
                " when instantiating the 'WideDeep' class"
            )

        if isinstance(optimizers, Dict):
            if lr_schedulers is not None and not isinstance(lr_schedulers, Dict):
                raise ValueError(
                    "''optimizers' and 'lr_schedulers' must have consistent type: "
                    "(Optimizer and LRScheduler) or (Dict[str, Optimizer] and Dict[str, LRScheduler]) "
                    "Please, read the documentation or see the examples for more details"
                )

        if custom_loss_function is not None and objective not in [
            "binary",
            "multiclass",
            "regression",
        ]:
            raise ValueError(
                "If 'custom_loss_function' is not None, 'objective' must be 'binary' "
                "'multiclass' or 'regression', consistent with the loss function"
            )

    @staticmethod
    def _set_device_and_num_workers(**kwargs):
        """Return ``(device, num_workers)``, honouring user overrides.

        ``device`` defaults to ``"cuda"`` when available and ``"cpu"``
        otherwise. ``num_workers`` defaults to ``os.cpu_count()``, except on
        macOS with Python >= 3.8 where it defaults to 0. Both can be
        overridden through the ``device`` and ``num_workers`` kwargs.
        """
        # Imported locally so the method does not depend on another notebook
        # cell having imported these modules.
        import os
        import sys

        # Important note for Mac users: Since python 3.8, the multiprocessing
        # library start method changed from 'fork' to 'spawn'. This affects the
        # data-loaders, which will not run in parallel.
        default_num_workers = (
            0
            if sys.platform == "darwin" and sys.version_info >= (3, 8)
            else os.cpu_count()
        )
        default_device = "cuda" if torch.cuda.is_available() else "cpu"
        device = kwargs.get("device", default_device)
        num_workers = kwargs.get("num_workers", default_num_workers)
        return device, num_workers

In [25]:
# Smoke test: push a tiny random batch through SAINT and inspect the output.
# Four categorical columns (integer codes drawn by random_(4)) followed by one
# continuous column drawn by torch.rand.
X_tab = torch.cat(
    (torch.empty(5, 4).random_(4), torch.rand(5, 1)),
    dim=1,
)
colnames = ['a', 'b', 'c', 'd', 'e']
continuous_cols = ['e']
# Each categorical column is paired with the value 4.
cat_embed_input = list(zip(colnames[:4], [4] * 4))
# Map every column name to its position in X_tab.
column_idx = dict(zip(colnames, range(len(colnames))))
model = SAINT(
    column_idx=column_idx,
    cat_embed_input=cat_embed_input,
    continuous_cols=continuous_cols,
)
out = model(X_tab)

print("X_tab:\n", X_tab)
print("colnames:\n", colnames)
print("cat_embed_input:\n", cat_embed_input)
print("continuous_cols:\n", continuous_cols)
print("column_idx:\n", column_idx)
print("output shape:\n", out.shape)

X_tab:
 tensor([[2.0000, 1.0000, 3.0000, 1.0000, 0.0404],
        [1.0000, 2.0000, 0.0000, 2.0000, 0.8259],
        [3.0000, 0.0000, 2.0000, 0.0000, 0.7349],
        [2.0000, 1.0000, 1.0000, 3.0000, 0.2618],
        [2.0000, 2.0000, 1.0000, 1.0000, 0.3920]])
colnames:
 ['a', 'b', 'c', 'd', 'e']
cat_embed_input:
 [('a', 4), ('b', 4), ('c', 4), ('d', 4)]
continuous_cols:
 ['e']
column_idx:
 {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4}
output shape:
 torch.Size([5, 160])
