In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr
from sklearn.model_selection import train_test_split
import xgboost as xgb
from sklearn.metrics import mean_squared_error, r2_score

import torch

In [2]:
# Names of the predictor fields and of the field the model should predict.
input_vars = [
    "d2m",
    "u10",
    "v10",
    "tp",
    "lai_hv",
    "lai_lv",
]
output_vars = ["frp"]

# Open the combined NetCDF file (path is relative to the notebook's directory).
dataset = xr.open_dataset("../../combined_data/jan2025.nc")

# Display only the predictor variables as this cell's output.
dataset[input_vars]

In [None]:
# Define the ConvLSTM model
import torch.nn as nn
import torch.nn.functional as F

class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell working on channel-first feature maps.

    One 2-D convolution over the concatenation of the current input and the
    previous hidden state produces the pre-activations of all four gates.

    Args:
        input_channels: Number of channels of the input ``x``.
        hidden_channels: Number of channels of the hidden and cell states.
        kernel_size: Convolution kernel size, either a single value or a
            ``(height, width)`` tuple/list. Every dimension must be odd so
            that ``kernel // 2`` padding keeps the spatial size unchanged.

    Raises:
        ValueError: If any kernel dimension is even. An even kernel would make
            the convolution output one pixel larger than the carried states,
            so the state update could not be computed correctly.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super().__init__()

        # Normalise the kernel spec; scalars keep the original scalar padding,
        # sequences get one padding value per spatial axis.
        if isinstance(kernel_size, (tuple, list)):
            kernel_dims = tuple(kernel_size)
            conv_kernel = kernel_dims
            padding = tuple(k // 2 for k in kernel_dims)
        else:
            kernel_dims = (kernel_size,)
            conv_kernel = kernel_size
            padding = kernel_size // 2

        if any(k % 2 == 0 for k in kernel_dims):
            raise ValueError(
                f"kernel_size must be odd in every dimension to preserve the "
                f"spatial size of the hidden state, got {kernel_size!r}"
            )

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = padding

        # A single convolution emits 4 * hidden_channels maps, one chunk of
        # hidden_channels per gate.
        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=conv_kernel,
            padding=padding,
        )

    def forward(self, x, h_prev, c_prev):
        """Advance the cell by one time step.

        Args:
            x: Input of shape ``(batch, input_channels, height, width)``.
            h_prev: Previous hidden state,
                ``(batch, hidden_channels, height, width)``.
            c_prev: Previous cell state, same shape as ``h_prev``.

        Returns:
            Tuple ``(h_next, c_next)``, each with the shape of ``h_prev``.
        """
        # Stack input and previous hidden state along the channel axis.
        combined = torch.cat([x, h_prev], dim=1)
        conv_output = self.conv(combined)

        # Chunk order along the channel axis is: input gate, forget gate,
        # output gate, candidate cell values. The order is part of the
        # parameter layout, so it must stay fixed for saved weights to load.
        cc_i, cc_f, cc_o, cc_g = torch.split(conv_output, self.hidden_channels, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        # Standard LSTM update: forget part of the old memory, add the gated
        # candidate, then expose a gated view of the new memory.
        c_next = f * c_prev + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next


class ConvLSTM(nn.Module):
    """Stack of ConvLSTM cells followed by a 1x1 convolution per time step.

    Args:
        input_channels: Number of channels of each input frame.
        hidden_channels: Number of channels of every layer's hidden/cell state.
        kernel_size: Kernel size forwarded to every ``ConvLSTMCell``.
        num_layers: Number of stacked cells.
        output_channels: Number of channels produced for each time step.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, num_layers, output_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = num_layers

        # The first cell consumes the raw input; every later cell consumes the
        # hidden state of the cell below it.
        self.cells = nn.ModuleList(
            ConvLSTMCell(
                input_channels if layer_idx == 0 else hidden_channels,
                hidden_channels,
                kernel_size,
            )
            for layer_idx in range(num_layers)
        )

        # 1x1 convolution mapping the top hidden state to the output channels.
        self.output_layer = nn.Conv2d(hidden_channels, output_channels, kernel_size=1)

    def forward(self, x):
        """Run the network over a whole sequence.

        Args:
            x: Channel-last tensor of shape
                ``(batch_size, seq_len, height, width, channels)`` where
                ``channels`` equals ``input_channels``.

        Returns:
            Channel-first tensor of shape
            ``(batch_size, seq_len, output_channels, height, width)``.

        Raises:
            ValueError: If ``x`` is not 5-D or its last dimension differs from
                ``input_channels``.
        """
        if x.dim() != 5:
            raise ValueError(
                "expected input of shape (batch_size, seq_len, height, width, "
                f"channels), got a tensor with shape {tuple(x.shape)}"
            )
        batch_size, seq_len, height, width, channels = x.size()
        if channels != self.input_channels:
            raise ValueError(
                f"expected {self.input_channels} input channels in the last "
                f"dimension, got {channels}"
            )

        # Zero initial states that follow the input's dtype AND device, so the
        # model also works when run in half/bfloat16/double precision.
        state_shape = (batch_size, self.hidden_channels, height, width)
        h = [x.new_zeros(state_shape) for _ in range(self.num_layers)]
        c = [x.new_zeros(state_shape) for _ in range(self.num_layers)]

        # Move channels in front of the spatial axes once, instead of
        # permuting every frame inside the time loop.
        frames = x.permute(0, 1, 4, 2, 3)

        output_seq = []
        for t in range(seq_len):
            layer_input = frames[:, t]

            # Feed each layer with the freshly updated state of the one below.
            for layer_idx, cell in enumerate(self.cells):
                h[layer_idx], c[layer_idx] = cell(layer_input, h[layer_idx], c[layer_idx])
                layer_input = h[layer_idx]

            # Project the top layer's hidden state for this time step.
            output_seq.append(self.output_layer(h[-1]))

        # Re-assemble the per-step outputs along the sequence dimension.
        return torch.stack(output_seq, dim=1)


In [3]:
# Stack every predictor field along a new trailing axis so the channel
# dimension comes last.
X = torch.tensor(
    np.stack([dataset[name].values for name in input_vars], axis=-1),
    dtype=torch.float32,
)

# Target field, looked up through output_vars instead of repeating the
# variable name as a literal. Only the first (currently sole) entry is used,
# so y keeps the same shape as the underlying variable.
y = torch.tensor(dataset[output_vars[0]].values, dtype=torch.float32)

# Only the final expression of a cell is displayed, so show both shapes at once.
X.shape, y.shape

array([[[[ 2.73648438e+02,  2.31756210e-01, -1.31976318e+00,
           1.08896638e-05,  2.22534180e+00,  2.13476562e+00],
         [ 2.73448242e+02,  7.92903900e-02, -1.48608398e+00,
           6.59578654e-06,  3.32202148e+00,  1.88690186e+00],
         [ 2.73350098e+02, -6.51187897e-02, -1.62628174e+00,
           5.69983968e-07,  3.84265137e+00,  1.71270752e+00],
         ...,
         [ 2.70825684e+02, -1.45098305e+00, -9.96276855e-01,
           0.00000000e+00,  0.00000000e+00,  5.15624940e-01],
         [ 2.70826172e+02, -1.38738441e+00, -9.98779297e-01,
           0.00000000e+00,  0.00000000e+00,  5.15624940e-01],
         [ 2.71061035e+02, -1.26360512e+00, -9.84313965e-01,
           0.00000000e+00,  0.00000000e+00,  5.15624940e-01]],

        [[ 2.73007812e+02, -9.22183990e-02, -1.32757568e+00,
           9.58330929e-07,  2.24645996e+00,  2.39837646e+00],
         [ 2.72827148e+02, -2.69464493e-01, -1.51165771e+00,
           4.24275640e-07,  3.39392090e+00,  2.09686279e+00],


In [None]:
# Chronological hold-out along the first (time) axis: the earliest 70% of the
# steps are used for training and the remaining 30% for testing.
split_idx = int(X.size(0) * 0.7)

# Cut inputs and targets at the same index.
X_train, X_test = X[:split_idx], X[split_idx:]
y_train, y_test = y[:split_idx], y[split_idx:]

# Report the resulting shapes.
print(f"Training data shape: {X_train.shape}")
print(f"Testing data shape: {X_test.shape}")
print(f"Training labels shape: {y_train.shape}")
print(f"Testing labels shape: {y_test.shape}")
