# In [2]:
import torch
import torch.nn as nn

class SpatioTemporalTransformer(nn.Module):
    """
    Transformer encoder for spatio-temporal water quality prediction.

    Per-timestep input features are linearly projected to ``hidden_dim`` and
    run through a stack of ``nn.TransformerEncoderLayer`` blocks. The encoded
    representation of the last (valid) timestep is then projected to
    ``output_dim`` predictions.

    Args:
        input_dim: Number of features per timestep.
        hidden_dim: Model width (``d_model``); must be divisible by ``n_heads``.
        output_dim: Number of predicted values per sequence.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
        dim_feedforward: Width of each layer's feed-forward sublayer. The
            default 2048 matches ``nn.TransformerEncoderLayer``'s own default.
        dropout: Dropout probability inside each encoder layer. The default
            0.1 matches ``nn.TransformerEncoderLayer``'s own default.
        use_positional_encoding: If True, add fixed sinusoidal positional
            encodings after the input projection. Without them, self-attention
            is order-agnostic: permuting the earlier timesteps does not change
            the prediction. Off by default so existing behaviour is unchanged.
            The encoding has no parameters, so ``state_dict`` keys are the
            same either way.
    """

    def __init__(self, input_dim=8, hidden_dim=64, output_dim=4, n_heads=4, n_layers=2,
                 *, dim_feedforward=2048, dropout=0.1, use_positional_encoding=False):
        super().__init__()
        self.use_positional_encoding = use_positional_encoding

        # Submodules are created in the same order as before, so a given seed
        # still produces identical initial weights.
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.output_proj = nn.Linear(hidden_dim, output_dim)

    @staticmethod
    def _sinusoidal_encoding(seq_len, d_model, device, dtype):
        """Return the fixed [seq_len, d_model] sine/cosine position table (Vaswani et al., 2017)."""
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        inv_freq = 1.0 / (
            10000.0 ** (torch.arange(0, d_model, 2, device=device, dtype=torch.float32) / d_model)
        )
        angles = position * inv_freq  # [seq_len, ceil(d_model / 2)]
        pe = torch.zeros(seq_len, d_model, device=device, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.to(dtype)

    def forward(self, x, src_key_padding_mask=None):
        """
        Predict targets from a batch of feature sequences.

        Args:
            x: Tensor of shape [batch_size, seq_len, input_dim].
            src_key_padding_mask: Optional boolean [batch_size, seq_len] mask
                in which True marks padded timesteps (PyTorch's convention).
                Padded timesteps are excluded from attention. The prediction
                is read from each sequence's last non-padded timestep rather
                than from the final position.

        Returns:
            Tensor of shape [batch_size, output_dim].

        Raises:
            ValueError: If ``x`` is not 3-D, if the mask shape differs from
                ``x.shape[:2]``, or if a sequence is entirely padding.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape [batch_size, seq_len, input_dim], got {tuple(x.shape)}"
            )

        h = self.input_proj(x)
        if self.use_positional_encoding:
            h = h + self._sinusoidal_encoding(h.size(1), h.size(-1), h.device, h.dtype)

        if src_key_padding_mask is None:
            h = self.transformer(h)
            last = h[:, -1, :]  # final timestep of every sequence
        else:
            padding = src_key_padding_mask.bool()
            if padding.shape != x.shape[:2]:
                raise ValueError(
                    f"src_key_padding_mask shape {tuple(padding.shape)} does not match "
                    f"[batch_size, seq_len] = {tuple(x.shape[:2])}"
                )
            # Index of the last valid (non-padded) timestep per sequence.
            # Padded slots contribute 0 to the max, so an all-padding row yields -1.
            positions = torch.arange(x.size(1), device=x.device)
            last_idx = ((positions + 1) * ~padding).max(dim=1).values - 1
            if (last_idx < 0).any():
                # Attention over zero keys would produce NaNs; fail loudly instead.
                raise ValueError(
                    "src_key_padding_mask marks every timestep of a sequence as padding"
                )
            h = self.transformer(h, src_key_padding_mask=padding)
            last = h[torch.arange(h.size(0), device=h.device), last_idx]

        return self.output_proj(last)