The Decoder module generates local high-resolution forecast maps (1 km resolution) from the
coarse predictions. Because the ConvLSTM-based Processor may operate on a lower-resolution grid (for
computational efficiency or because of patch embedding), the Decoder upsamples and refines its output.

For the Decoder we use a U-Net, a well-known convolutional network with an encoder-decoder structure
and skip connections that preserve fine spatial detail. The U-Net takes the coarse forecast as input
(e.g. a 2D field at 10 km resolution) and outputs a finer 1 km grid. In effect, it acts as a
super-resolution or downscaling model, adding local detail that can be informed by high-resolution
static data such as topography or coastlines if we include those as additional inputs.
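One way to feed static high-resolution data to the Decoder is to interpolate the coarse forecast onto the fine grid and stack the static fields as extra input channels. The sketch below assumes bilinear interpolation and hypothetical static inputs (a topography map, a land-sea mask); the function name `build_decoder_input` is illustrative, not part of the original design.

```python
import torch
import torch.nn.functional as F


def build_decoder_input(coarse, static_hr):
    """Upsample a coarse frame to the static grid and append static fields.

    coarse:    (B, C, H_coarse, W_coarse) coarse forecast, e.g. at 10 km
    static_hr: (S, H, W) time-invariant high-res fields at, e.g., 1 km
               (hypothetical inputs such as topography or a land-sea mask)
    returns:   (B, C + S, H, W)
    """
    b = coarse.shape[0]
    h, w = static_hr.shape[-2:]
    # Bilinear interpolation of the coarse forecast onto the high-res grid
    up = F.interpolate(coarse, size=(h, w), mode="bilinear", align_corners=False)
    # Broadcast the static fields across the batch and stack them as channels
    static = static_hr.unsqueeze(0).expand(b, -1, -1, -1)
    return torch.cat([up, static], dim=1)


# Example: 3 forecast variables on a 12 x 12 coarse grid, 2 static fields at 10x resolution
coarse = torch.randn(2, 3, 12, 12)
static = torch.randn(2, 120, 120)
x = build_decoder_input(coarse, static)
print(x.shape)  # torch.Size([2, 5, 120, 120])
```

The result is already on the fine grid, so it would go to a U-Net that operates entirely at the target resolution; the network then only has to learn the residual local detail that interpolation cannot supply.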

Our U-Net Decoder operates per forecast time step: it processes each frame independently, since the
spatial super-resolution mapping can be learned without reference to time. We design the U-Net with a
contracting path that reduces the spatial dimensions and an expanding path that restores them, with
skip connections from the contracting to the expanding path to preserve high-frequency information.
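Because each frame is decoded independently, a common way to apply a 2D decoder to a whole forecast sequence is to fold the time axis into the batch axis, run one batched call, and unfold the result. This is a minimal sketch; `decode_per_frame` is a hypothetical helper name, and a plain `nn.Conv2d` stands in for the U-Net.

```python
import torch
import torch.nn as nn


def decode_per_frame(decoder, frames):
    """Apply a 2D decoder independently to every forecast frame.

    frames:  (B, T, C, H, W) sequence of coarse predictions
    returns: (B, T, C_out, H_out, W_out)
    """
    b, t = frames.shape[:2]
    # Merge batch and time so each frame becomes an independent 2D sample
    flat = frames.reshape(b * t, *frames.shape[2:])
    out = decoder(flat)
    # Split the merged dimension back into (B, T)
    return out.reshape(b, t, *out.shape[1:])


torch.manual_seed(0)
dec = nn.Conv2d(4, 2, 3, padding=1)       # stand-in for the U-Net Decoder
frames = torch.randn(2, 48, 4, 16, 16)    # 2 samples x 48 hourly frames
out = decode_per_frame(dec, frames)
print(out.shape)  # torch.Size([2, 48, 2, 16, 16])
```

Since the decoder shares weights across lead times, this is equivalent to looping over frames but makes better use of the GPU. With 48 high-resolution frames per sample, memory grows with B x T, so the flattened batch may need to be processed in chunks.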

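import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetDecoder(nn.Module):
    """U-Net that refines a coarse forecast frame into a high-resolution frame.

    Assumption: the coarse-to-fine gap is bridged by bilinear pre-upsampling of
    the input (scale_factor=10 maps 10 km to 1 km). The upsampled H and W must
    be divisible by 4 so the two pooling levels line up with the skip connections.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, scale_factor=10):
        super().__init__()
        self.scale_factor = scale_factor

        # Contracting path: two 3x3 convs per level, then 2x2 max pooling
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1), nn.ReLU())
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = nn.Sequential(
            nn.Conv2d(base_channels, base_channels * 2, 3, padding=1), nn.ReLU(),
            nn.Conv2d(base_channels * 2, base_channels * 2, 3, padding=1), nn.ReLU())
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = nn.Sequential(  # bottleneck
            nn.Conv2d(base_channels * 2, base_channels * 4, 3, padding=1), nn.ReLU(),
            nn.Conv2d(base_channels * 4, base_channels * 4, 3, padding=1), nn.ReLU())

        # Expanding path: each transposed conv doubles H and W; the following
        # convs take twice the channels because of the skip concatenation
        self.up2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(base_channels * 4, base_channels * 2, 3, padding=1), nn.ReLU(),
            nn.Conv2d(base_channels * 2, base_channels * 2, 3, padding=1), nn.ReLU())
        self.up1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(base_channels * 2, base_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1), nn.ReLU())
        # 1x1 conv maps the features to the output variables
        self.final = nn.Conv2d(base_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # x: (B, in_channels, H_coarse, W_coarse)
        # Upsample to the target grid first (assumption): H = H_coarse * scale_factor
        if self.scale_factor != 1:
            x = F.interpolate(x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False)

        # Encoder (channel counts in comments assume base_channels=64)
        e1 = self.enc1(x)    # (B, 64, H, W)
        p1 = self.pool1(e1)  # (B, 64, H/2, W/2)
        e2 = self.enc2(p1)   # (B, 128, H/2, W/2)
        p2 = self.pool2(e2)  # (B, 128, H/4, W/4)
        e3 = self.enc3(p2)   # (B, 256, H/4, W/4)

        # Decoder
        u2 = self.up2(e3)                # (B, 128, H/2, W/2)
        u2 = torch.cat([u2, e2], dim=1)  # skip connection from e2 -> (B, 256, H/2, W/2)
        d2 = self.dec2(u2)               # (B, 128, H/2, W/2)
        u1 = self.up1(d2)                # (B, 64, H, W)
        u1 = torch.cat([u1, e1], dim=1)  # skip connection from e1 -> (B, 128, H, W)
        d1 = self.dec1(u1)               # (B, 64, H, W)
        out = self.final(d1)
        return out                       # (B, out_channels, H, W)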
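The skip concatenations only work when H and W at the U-Net's working resolution are divisible by 4 (two 2x poolings). For example, if the U-Net runs on a 450 x 450 fine grid (a 45 x 45 field at 10 km refined to 1 km), the second pooling floors 225 to 112, the transposed conv restores only 224, and `torch.cat` with the 225-pixel skip tensor fails. A minimal sketch, assuming reflect padding is acceptable at the domain edges, is to pad before the U-Net and crop afterwards; `pad_to_multiple` and `crop_to` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x, multiple=4, mode="reflect"):
    """Pad H and W of x (B, C, H, W) up to the next multiple of `multiple`.

    Returns the padded tensor and the original (H, W) for cropping later.
    Reflect padding requires the pad to be smaller than the input size.
    """
    h, w = x.shape[-2:]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    x = F.pad(x, (0, pad_w, 0, pad_h), mode=mode)
    return x, (h, w)


def crop_to(x, size):
    """Undo pad_to_multiple by cropping back to the original (H, W)."""
    h, w = size
    return x[..., :h, :w]


x = torch.randn(1, 2, 450, 450)
padded, size = pad_to_multiple(x)
print(padded.shape)  # torch.Size([1, 2, 452, 452])
restored = crop_to(padded, size)
```

In practice the U-Net output, not the input, would be cropped: run the network on `padded` and apply `crop_to` to its result.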

We integrate the Encoder, Processor, and Decoder into one end-to-end model, which we'll call
WeatherForecastNet. During training, the dataflow is:

1. Input: A sequence of past observations (or just the current state) on the common grid. For
   example, we might use the last 3 hourly grids as input to give the model some notion of
   recent motion, although the problem statement focuses on assimilating current data.
2. Encoder: Produces encoded features (global and/or spatial) from the input.
3. Processor: Generates a sequence of coarse future predictions (up to 48 frames for 48 hours).
4. Decoder: Upscales each coarse frame to high resolution.
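The four-stage dataflow above can be sketched as a single module. This is a toy illustration under stated assumptions, not the actual WeatherForecastNet: each stage is a hypothetical single-convolution stand-in (the Processor is a simple convolutional recurrence rather than a ConvLSTM, and the Decoder is bilinear upsampling plus one conv rather than a U-Net). It only demonstrates how the tensor shapes flow from past frames to 48 high-resolution frames.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyWeatherForecastNet(nn.Module):
    """Encoder -> autoregressive Processor -> per-frame Decoder (toy stand-ins)."""

    def __init__(self, in_steps, channels, hidden=16, scale=10, lead_times=48):
        super().__init__()
        self.lead_times = lead_times
        self.scale = scale
        # Encoder: stacks the past frames along channels and maps them to features
        self.encoder = nn.Conv2d(in_steps * channels, hidden, 3, padding=1)
        # Processor: advances the latent state one hour per step
        self.processor = nn.Conv2d(hidden, hidden, 3, padding=1)
        self.readout = nn.Conv2d(hidden, channels, 1)
        # Decoder: refines each upsampled frame on the fine grid
        self.decoder = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, past):
        # past: (B, T_in, C, H, W) on the common coarse grid
        b, t_in, c, h, w = past.shape
        state = torch.relu(self.encoder(past.reshape(b, t_in * c, h, w)))
        frames = []
        for _ in range(self.lead_times):
            state = torch.relu(self.processor(state))
            frames.append(self.readout(state))    # (B, C, H, W)
        coarse = torch.stack(frames, dim=1)       # (B, T_out, C, H, W)
        # Decode every frame independently by folding time into the batch
        flat = coarse.reshape(b * self.lead_times, c, h, w)
        fine = F.interpolate(flat, scale_factor=self.scale, mode="bilinear", align_corners=False)
        fine = self.decoder(fine)
        return fine.reshape(b, self.lead_times, c, h * self.scale, w * self.scale)


model = ToyWeatherForecastNet(in_steps=3, channels=2, lead_times=4, scale=5)
past = torch.randn(1, 3, 2, 8, 8)   # last 3 hourly grids, 2 variables
pred = model(past)
print(pred.shape)  # torch.Size([1, 4, 2, 40, 40])
```

Keeping all three stages in one `nn.Module` is what makes training end-to-end: a loss on the high-resolution output backpropagates through the Decoder, Processor, and Encoder together.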
We train the model in a supervised manner on historical data, which requires training pairs of
(input data, ground-truth future data). Ground truth could come from a reanalysis such as ERA5 or
from actual observations at the corresponding future times (for example, satellite images or buoy
readings), interpolated to the grid.
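Training pairs can be cut from aligned hourly archives with a sliding window: each sample takes `in_steps` consecutive input grids and the following `lead_times` ground-truth grids. The sketch below assumes both sequences are already interpolated to their grids and cover the same hours; the class name `ForecastPairs` and the tensor layout are hypothetical.

```python
import torch
from torch.utils.data import Dataset


class ForecastPairs(Dataset):
    """Sliding-window (input, target) pairs from aligned hourly sequences.

    coarse: (N, C, Hc, Wc) hourly input grids (e.g. gridded observations)
    fine:   (N, C, Hf, Wf) hourly ground truth on the 1 km grid
            (e.g. reanalysis or interpolated observations), same N hours
    """

    def __init__(self, coarse, fine, in_steps=3, lead_times=48):
        assert coarse.shape[0] == fine.shape[0], "sequences must cover the same hours"
        self.coarse, self.fine = coarse, fine
        self.in_steps, self.lead_times = in_steps, lead_times

    def __len__(self):
        # Start hours that have a full input window and a full target window
        return max(0, self.coarse.shape[0] - self.in_steps - self.lead_times + 1)

    def __getitem__(self, i):
        t0 = i + self.in_steps                    # first forecast hour
        x = self.coarse[i:t0]                     # (in_steps, C, Hc, Wc)
        y = self.fine[t0:t0 + self.lead_times]    # (lead_times, C, Hf, Wf)
        return x, y


# 60 hours of synthetic data where every grid holds its own hour index
hours = torch.arange(60.0).reshape(60, 1, 1, 1)
ds = ForecastPairs(hours.repeat(1, 1, 4, 4), hours.repeat(1, 1, 40, 40))
print(len(ds))  # 10
x0, y0 = ds[0]  # inputs: hours 0-2, targets: hours 3-50
```

The pairs can then be batched with `torch.utils.data.DataLoader` and compared against the model's high-resolution output with a regression loss such as `nn.MSELoss`. Splitting train and validation sets by contiguous time ranges, rather than random windows, avoids overlapping windows leaking between them.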