In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

In [2]:
def get_activation_function(name):
    """Map an activation name to its functional form in ``torch.nn.functional``.

    Args:
        name: ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` or ``F.gelu``.

    Raises:
        ValueError: for any other name.
    """
    for candidate, activation in (("relu", F.relu), ("gelu", F.gelu)):
        if name == candidate:
            return activation
    raise ValueError(f"Unsupported activation function: {name}")


def get_activation_module(name):
    """Build a fresh ``nn.Module`` activation for ``name``.

    Args:
        name: ``"relu"`` or ``"gelu"``.

    Returns:
        A new ``nn.ReLU()`` or ``nn.GELU()`` instance (a new object per call).

    Raises:
        ValueError: for any other name.
    """
    factories = (("relu", nn.ReLU), ("gelu", nn.GELU))
    for candidate, factory in factories:
        if name == candidate:
            return factory()
    raise ValueError(f"Unsupported activation function: {name}")


class Attention(nn.Module):
    """
    Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    The scale factor is computed with ``torch.sqrt`` on a tensor that has the
    query's dtype and device, so the ``math`` module is not needed.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """
        Args:
            query: tensor (..., L, d_k); ``d_k`` is read from its last dim.
            key: tensor (..., S, d_k).
            value: tensor (..., S, d_v).
            mask: optional boolean tensor broadcastable to the score shape
                (..., L, S); True entries are filled with -inf before the
                softmax, i.e. excluded. A query row whose keys are all masked
                therefore comes out of the softmax as NaN.
            dropout: optional callable (e.g. an ``nn.Dropout``) applied to the
                attention weights.

        Returns:
            ``(output, weights)`` with output (..., L, d_v) and weights
            (..., L, S).
        """
        scale = torch.sqrt(
            torch.tensor(query.size(-1), dtype=query.dtype, device=query.device)
        )
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            logits = logits.masked_fill(mask, float("-inf"))

        weights = F.softmax(logits, dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        return torch.matmul(weights, value), weights


class MultiHeadedAttention(nn.Module):
    """
    Multi-headed attention built from plain linear layers and ``Attention``
    (no built-in ``nn.MultiheadAttention``).

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of heads; each head has width ``d_model // num_heads``.
        dropout: dropout probability applied to the attention weights.
        batch_first: if True, inputs and output are (batch, seq, d_model);
            otherwise (seq, batch, d_model).
    """

    def __init__(self, d_model, num_heads, dropout=0.1, batch_first=True):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.num_heads = num_heads
        self.batch_first = batch_first

        # Projections for query, key and value, in that order.
        self.linear_layers = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(3)]
        )
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)

    def _split_heads(self, linear, x):
        """Project ``x`` (batch, length, d_model) with ``linear`` and split it
        into heads, giving (batch, num_heads, length, d_k)."""
        batch_size, length, _ = x.size()
        return (
            linear(x)
            .reshape(batch_size, length, self.num_heads, self.d_k)
            .transpose(1, 2)
        )

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """
        Args:
            query: (batch, L, d_model), or (L, batch, d_model) if not batch_first.
            key, value: (batch, S, d_model), or (S, batch, d_model) if not
                batch_first. S may differ from L (e.g. cross-attention).
            attn_mask: optional bool tensor, True = may not attend; either
                (L, S), shared by all batches and heads, or
                (batch * num_heads, L, S).
            key_padding_mask: optional bool tensor (batch, S); True marks key
                positions to ignore.

        Returns:
            Tensor laid out like ``query``: (batch, L, d_model), or
            (L, batch, d_model) if not batch_first.
        """
        if not self.batch_first:
            # Work internally in batch-first layout.
            query, key, value = (t.transpose(0, 1) for t in (query, key, value))
        batch_size, tgt_len, _ = query.size()
        # Key/value length is taken from the key itself, not from the query.
        src_len = key.size(1)

        # Each of q, k, v is (batch, num_heads, length, d_k).
        q, k, v = (
            self._split_heads(linear, t)
            for linear, t in zip(self.linear_layers, (query, key, value))
        )

        # Build a single bool mask that broadcasts against (batch, heads, L, S).
        combined_mask = None
        if key_padding_mask is not None:
            combined_mask = key_padding_mask[:, None, None, :]  # (batch, 1, 1, S)
        if attn_mask is not None:
            if attn_mask.dim() == 3:
                # (batch * num_heads, L, S) -> (batch, num_heads, L, S)
                attn_mask = attn_mask.reshape(
                    batch_size, self.num_heads, tgt_len, src_len
                )
            # A 2-D (L, S) mask broadcasts over batch and heads as-is.
            combined_mask = (
                attn_mask if combined_mask is None else combined_mask | attn_mask
            )

        x, _ = self.attention(q, k, v, mask=combined_mask, dropout=self.dropout)

        # Merge heads: (batch, heads, L, d_k) -> (batch, L, d_model).
        x = x.transpose(1, 2).contiguous().view(batch_size, tgt_len, self.d_model)
        x = self.output_linear(x)

        if not self.batch_first:
            x = x.transpose(0, 1)

        return x


class TransformerEncoderLayer(nn.Module):
    """
    Post-norm Transformer encoder layer.

    ``h = LayerNorm(src + Dropout(SelfAttn(src)))`` followed by
    ``out = LayerNorm(h + Dropout(Linear2(Dropout(act(Linear1(h))))))``.

    Args:
        d_model: input/output width.
        num_heads: number of attention heads.
        dim_feedforward: hidden width of the feed-forward block.
        dropout: dropout probability for the attention weights, the
            feed-forward hidden layer and both residual branches.
        activation: feed-forward activation name, resolved through
            ``get_activation_function``.
        batch_first: tensor layout passed on to the attention module.
    """

    def __init__(
        self,
        d_model,
        num_heads,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        batch_first=True,
    ):
        super().__init__()
        # Creation order matters: it fixes the RNG draws used for weight
        # initialisation and the state_dict key order.
        self.self_attn = MultiHeadedAttention(
            d_model, num_heads, dropout=dropout, batch_first=batch_first
        )
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation_fn = get_activation_function(activation)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Apply the layer.

        ``src_mask`` is forwarded as the attention ``attn_mask`` and
        ``src_key_padding_mask`` as its ``key_padding_mask``. Output has the
        same shape as ``src``.
        """
        # Self-attention sub-block: residual add, then LayerNorm.
        attended = self.self_attn(
            src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        hidden = self.norm1(src + self.dropout1(attended))

        # Position-wise feed-forward sub-block: residual add, then LayerNorm.
        expanded = self.dropout(self.activation_fn(self.linear1(hidden)))
        return self.norm2(hidden + self.dropout2(self.linear2(expanded)))


class TransformerEncoder(nn.Module):
    """
    Stack of ``num_layers`` independent deep copies of ``encoder_layer``,
    applied in sequence and optionally followed by a final ``norm`` module.

    Each layer is called as ``layer(x, src_mask=mask,
    src_key_padding_mask=src_key_padding_mask)``.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        clones = [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(clones)
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None):
        hidden = src
        for encoder_block in self.layers:
            hidden = encoder_block(
                hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask
            )
        return hidden if self.norm is None else self.norm(hidden)

In [3]:
class SitsTransformer(nn.Module):
    """
    Transformer classifier over per-time-step feature sequences.

    The forward pass runs these stages in order:
    1. a per-step MLP embedding;
    2. LayerNorm;
    3. addition of a sinusoidal encoding indexed by ``doy``;
    4. the Transformer encoder;
    5. weighted temporal pooling;
    6. an MLP decoder that produces logits.

    Args:
        input_dim: number of features per time step (last dim of ``input["x"]``).
        num_classes: output width of the decoder.
        d_model: embedding / encoder width.
        n_head: number of attention heads (must divide ``d_model``).
        n_layers: number of encoder layers.
        d_inner: hidden width of the encoder feed-forward block.
        activation: encoder feed-forward activation name.
        dropout: dropout probability used throughout.
        max_len: exclusive upper bound of the ``doy`` values drawn for
            ``input_sample``; also sizes the positional table.
        max_seq_len: sequence length of ``input_sample``.
        T: base of the sinusoidal positional encoding.
        max_temporal_shift: enlarges the positional table by
            ``2 * max_temporal_shift`` rows (presumably to allow shifted
            ``doy`` values — confirm against the data pipeline).
    """

    def __init__(
        self,
        input_dim=10,
        num_classes=9,
        d_model=128,
        n_head=16,
        n_layers=1,
        d_inner=128,
        activation="relu",
        dropout=0.2,
        max_len=366,
        max_seq_len=70,
        T=1000,
        max_temporal_shift=30,
    ):
        super(SitsTransformer, self).__init__()
        self.modelname = self._get_name()
        self.max_seq_len = max_seq_len

        # Per-time-step embedding: input_dim -> 32 -> 64 -> d_model.
        self.mlp_dim = [input_dim, 32, 64, d_model]
        layers = []
        for i in range(len(self.mlp_dim) - 1):
            layers.append(LinLayer(self.mlp_dim[i], self.mlp_dim[i + 1]))
        self.mlp1 = nn.Sequential(*layers)

        self.inlayernorm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)
        self.position_enc = PositionalEncoding(
            d_model, max_len=max_len + 2 * max_temporal_shift, T=T
        )

        encoder_layer = TransformerEncoderLayer(
            d_model, n_head, d_inner, dropout, activation, batch_first=True
        )
        encoder_norm = nn.LayerNorm(d_model)
        self.transformerencoder = TransformerEncoder(
            encoder_layer, n_layers, encoder_norm
        )

        # Decoder: d_model -> 64 -> 32 -> num_classes, with BatchNorm + ReLU
        # after every layer except the last.
        layers = []
        decoder = [d_model, 64, 32, num_classes]
        for i in range(len(decoder) - 1):
            layers.append(nn.Linear(decoder[i], decoder[i + 1]))
            if i < (len(decoder) - 2):
                layers.extend([nn.BatchNorm1d(decoder[i + 1]), nn.ReLU()])
        self.decoder = nn.Sequential(*layers)

        # Example input / output used by the notebook's shape check.
        self.input_sample = {
            "doy": torch.randint(1, max_len, (2, self.max_seq_len), dtype=torch.int64),
            "mask": torch.zeros((2, self.max_seq_len), dtype=torch.bool),
            "weight": torch.rand((2, self.max_seq_len), dtype=torch.float32),
            "x": torch.rand((2, self.max_seq_len, input_dim), dtype=torch.float32)
        }
        self.expected_output_sample = torch.rand((2, num_classes), dtype=torch.float32)

    def forward(self, input, is_bert=False):
        """
        Args:
            input: dict with
                ``"x"``: float tensor (batch, seq, input_dim);
                ``"doy"``: integer tensor (batch, seq) of positional-table
                indices;
                ``"mask"``: bool tensor (batch, seq); True marks positions the
                encoder ignores (passed as the key padding mask);
                ``"weight"``: float tensor (batch, seq) of pooling weights.
                The tensors in ``input`` are not modified.
            is_bert: if True, skip pooling and decode every time step.

        Returns:
            Logits of shape (batch, num_classes), or (batch, seq, num_classes)
            when ``is_bert`` is True.
        """
        x = input["x"]
        doy = input["doy"]
        mask = input["mask"]
        weight = input["weight"]

        x = self.mlp1(x)

        x = self.inlayernorm(x)
        x = self.dropout(x + self.position_enc(doy))

        x = self.transformerencoder(x, src_key_padding_mask=mask)

        if not is_bert:
            weight = self.dropout(weight)
            # Out-of-place on purpose: in eval mode Dropout hands back its
            # input, so an in-place division would overwrite input["weight"].
            weight = weight / weight.sum(1, keepdim=True)
            # squeeze(1), not squeeze(): keep the batch dim when batch_size == 1
            # (BatchNorm1d in the decoder rejects 1-D input).
            x = torch.bmm(weight.unsqueeze(1), x).squeeze(1)
            return self.decoder(x)

        # Per-time-step decoding. BatchNorm1d reads dim 1 as channels, so fold
        # the time axis into the batch before decoding and unfold afterwards.
        batch_size, seq_len, width = x.shape
        logits = self.decoder(x.reshape(batch_size * seq_len, width))
        return logits.reshape(batch_size, seq_len, -1)


class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional-encoding table looked up by integer index.

    The table has ``max_len + 1`` rows. Row 0 is all zeros. Row ``i``
    (1 <= i <= max_len) holds the encoding of position ``i - 1``: sines in the
    even columns and cosines in the odd columns, with wavelengths forming a
    geometric progression with base ``T``.

    Args:
        d_model: encoding width (odd values are supported).
        max_len: number of non-zero rows in the table.
        T: base of the wavelength progression.
    """

    def __init__(self, d_model: int, max_len: int = 5000, T: int = 10000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(T) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len + 1, d_model)
        pe[1:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer cosine column than sine column.
        pe[1:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, doy):
        """
        Args:
            doy: integer tensor of table indices, e.g. shape [batch_size, seq_len]
                (values in 0..max_len).

        Returns:
            ``self.pe[doy]``, i.e. shape ``doy.shape + (d_model,)``.
        """
        return self.pe[doy]


class LinLayer(nn.Module):
    """Linear projection followed by LayerNorm and ReLU: ``relu(ln(lin(x)))``.

    Args:
        in_dim: size of the last input dimension.
        out_dim: size of the last output dimension (also the LayerNorm width).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)
        self.ln = nn.LayerNorm(out_dim)

    def forward(self, x):
        return F.relu(self.ln(self.lin(x)))

In [4]:
# Check consistency of output sample: the model built with default
# hyper-parameters must map its own `input_sample` to a tensor shaped like
# `expected_output_sample`.
torch.manual_seed(0)  # reproducible weight init and sample tensors

with torch.inference_mode():
    model = SitsTransformer()
    # Check inference behaviour: dropout off, BatchNorm uses (and does not
    # update) its running statistics.
    model.eval()
    output = model(model.input_sample)
    assert output.shape == model.expected_output_sample.shape, (
        f"expected shape {tuple(model.expected_output_sample.shape)}, "
        f"got {tuple(output.shape)}"
    )