In [1]:
# %pip install torch

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt



In [3]:
class LANDModel(nn.Module):
    """Multi-branch regressor combining DEM, month and reanalysis inputs.

    Two DEM vectors are each projected by their own Linear+ReLU branch, a
    per-channel (depthwise) convolution reduces each reanalysis channel to a
    single value, and the raw month one-hot vector is passed through
    unchanged. The four pieces are concatenated and fed through two
    fully connected blocks and a ReLU-clamped scalar output head.

    Args:
        Na: Currently unused by the network; retained only so existing
            callers that pass it keep working.
        Nb: Width of the two hidden fully connected layers.
        dropout_rate: Dropout probability applied after the first hidden
            layer.

    Notes:
        ``month_dense`` is constructed but never applied in ``forward``
        (the one-hot month is concatenated as-is). It is kept so that the
        ``state_dict`` keys stay compatible with previously saved weights.
        NOTE(review): confirm whether the month embedding was meant to be
        used; if so ``fc1`` would need a different input width.
    """

    # Feature sizes shared by the layer definitions below.
    DEM_INPUT_SIZE = 25        # length of each DEM input vector
    BRANCH_WIDTH = 128         # output width of each DEM branch (and month_dense)
    NUM_MONTHS = 12            # length of the month one-hot vector
    REANALYSIS_CHANNELS = 16   # channels of the reanalysis patch
    REANALYSIS_KERNEL = 3      # kernel size; a 3x3 patch collapses to 1x1

    def __init__(self, Na=512, Nb=1024, dropout_rate=0.45):
        super().__init__()

        # Local DEM branch
        self.dem_local = nn.Sequential(
            nn.Linear(self.DEM_INPUT_SIZE, self.BRANCH_WIDTH),
            nn.ReLU()
        )

        # Regional DEM branch
        self.dem_regional = nn.Sequential(
            nn.Linear(self.DEM_INPUT_SIZE, self.BRANCH_WIDTH),
            nn.ReLU()
        )

        # Channel-wise Conv2d: groups == channels, so every channel has its
        # own 3x3 filter and channels are never mixed.
        self.reanalysis_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=self.REANALYSIS_CHANNELS,
                out_channels=self.REANALYSIS_CHANNELS,
                kernel_size=self.REANALYSIS_KERNEL,
                groups=self.REANALYSIS_CHANNELS,
            ),  # (batch, 16, 3, 3) -> (batch, 16, 1, 1)
            nn.ReLU()
        )

        # Month one-hot (12D) projection. Not called in forward(); kept for
        # checkpoint compatibility (see class docstring).
        self.month_dense = nn.Linear(self.NUM_MONTHS, self.BRANCH_WIDTH)

        # Width of the concatenated feature vector: 128 + 128 + 12 + 16 = 284.
        fused_width = (
            2 * self.BRANCH_WIDTH + self.NUM_MONTHS + self.REANALYSIS_CHANNELS
        )

        # Dense layers
        self.fc1 = nn.Sequential(
            nn.Linear(fused_width, Nb),
            nn.BatchNorm1d(Nb),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(Nb, Nb),
            nn.BatchNorm1d(Nb),
            nn.ReLU()
        )
        self.output = nn.Sequential(
            nn.Linear(Nb, 1),
            nn.ReLU()   # Assuming the output is non-negative
        )

    def forward(self, local_dem, regional_dem, month_onehot, reanalysis_input):
        """Run one forward pass.

        Inputs:
        - local_dem: (batch, 25)
        - regional_dem: (batch, 25)
        - month_onehot: (batch, 12)
        - reanalysis_input: (batch, 16, 3, 3)

        Returns:
        - Tensor of shape (batch, 1); values are >= 0 because of the final
          ReLU. An empty batch (batch == 0) is accepted in eval mode.
        """

        # DEM branches
        local_feat = self.dem_local(local_dem)           # (batch, 128)
        regional_feat = self.dem_regional(regional_dem)  # (batch, 128)

        # Conv2D on reanalysis
        x_reanalysis = self.reanalysis_conv(reanalysis_input)  # (batch, 16, 1, 1)
        # flatten(1) instead of view(batch, -1): identical for non-empty
        # batches, but view() cannot infer -1 when the tensor has 0 elements.
        x_reanalysis = x_reanalysis.flatten(1)  # (batch, 16)

        # Concatenate all inputs; the month one-hot is used raw.
        x = torch.cat(
            [local_feat, regional_feat, month_onehot, x_reanalysis], dim=1
        )  # (batch, 284)

        # Fully connected
        x = self.fc1(x)
        x = self.fc2(x)
        out = self.output(x)

        return out

In [4]:
# Smoke test (commented out): uncomment to run one forward pass on random inputs and check the output shape
# batch_size = 4
# local_dem = torch.randn(batch_size, 25)
# regional_dem = torch.randn(batch_size, 25)
# month_onehot = torch.zeros(batch_size, 12)
# month_onehot[:, 3] = 1  # Assuming April (index 3) is the month of interest

# reanalysis_input = torch.randn(batch_size, 16, 3, 3)

# Use model
# model = LANDModel()
# output = model(local_dem, regional_dem, month_onehot, reanalysis_input)
# print(output.shape)  # torch.Size([4, 1])
