In [1]:
import xarray as xr
import numpy as np

# Open the encoded dataset from its NetCDF file. Later cells read it through
# the name `regridded_encode`.
# NOTE(review): "dataa" in the file name may be a typo — confirm against the
# file that actually exists on disk before renaming anything.
regridded_encode = xr.open_dataset(
    "data/encoded_dataa_64x32.nc",
)

In [None]:
# Build a longer sequence by tiling the leading time steps of the dataset.
num_repeats = 5          # how many copies of the leading window are concatenated
base_window_length = 5   # how many leading time steps make up one window

# Take the leading window once (as an independent copy) and repeat it along
# the "time" dimension.
base_window = regridded_encode.isel(time=slice(0, base_window_length)).copy()
expanded_regridded_encode = xr.concat([base_window] * num_repeats, dim="time")

# Relabel the time axis with a plain 0..N-1 index. N is read from the
# concatenated dataset instead of being hard-coded, so the cell keeps working
# when `num_repeats` or `base_window_length` change, or when the source file
# holds fewer time steps than the requested window.
new_time_values = np.arange(expanded_regridded_encode.sizes["time"])
expanded_regridded_encode = expanded_regridded_encode.assign_coords(time=new_time_values)

In [None]:
import torch
from torch.optim import Adam
from torch.nn import MSELoss
from learnedPhysics import LearnedPhysicsModel
from dynamical_core import DynamicalCoreRunner
from loss_function import combined_loss
import jax

def _stack_kept_variables(dataset, variables_to_keep):
    """Stack the kept data variables of ``dataset`` into one float32 tensor.

    Data variables whose name is not in ``variables_to_keep`` are dropped. The
    remaining ones are converted to ``torch.float32`` tensors and stacked along
    a new leading dimension, in the order in which ``dataset`` lists them.
    """
    unwanted = [var for var in dataset.data_vars if var not in variables_to_keep]
    kept = dataset.drop_vars(unwanted)
    return torch.stack(
        [torch.tensor(kept[var].values, dtype=torch.float32) for var in kept.data_vars],
        dim=0,
    )


def train_model(expanded_regridded_encode, regridded_encode, variables_to_keep,
                num_iterations=24, integration_steps=5, loop_iterations=4, lr=1e-3,
                train_every=5, tendency_shape=(6, 32, 64, 32)):
    """Train a ``LearnedPhysicsModel`` and return it.

    For every iteration ``i`` the model is evaluated on time step ``i`` of
    ``expanded_regridded_encode``. On every ``train_every``-th iteration a
    ``DynamicalCoreRunner`` is run, its output and time step ``i + 1`` of
    ``expanded_regridded_encode`` are reduced to ``variables_to_keep`` and
    passed to ``combined_loss``, and the MSE between the model output and that
    value is back-propagated through one Adam step.

    Args:
        expanded_regridded_encode: Dataset indexed along ``time``; supplies the
            model input (step ``i``) and the comparison state (step ``i + 1``).
        regridded_encode: Dataset handed to ``DynamicalCoreRunner``.
        variables_to_keep: Names of the data variables used in the loss; all
            other data variables are dropped before stacking.
        num_iterations: Number of loop iterations (time steps visited).
        integration_steps: Forwarded to ``DynamicalCoreRunner``.
        loop_iterations: Forwarded to ``DynamicalCoreRunner``.
        lr: Learning rate of the Adam optimizer.
        train_every: An optimizer step is taken on iterations where
            ``(i + 1) % train_every == 0``.
        tendency_shape: Shape the model output is viewed as before the loss.

    Returns:
        The ``LearnedPhysicsModel`` instance after training.

    Raises:
        ValueError: If ``train_every`` is smaller than 1.
        IndexError: If the ``time`` axis of ``expanded_regridded_encode`` is too
            short for the requested iterations.
    """
    if train_every < 1:
        raise ValueError(f"train_every must be >= 1, got {train_every}")

    # Fail before any expensive work if the loop would index past the end of
    # the time axis: every iteration reads step i, and a training iteration
    # additionally reads step i + 1.
    n_times = expanded_regridded_encode.sizes["time"]
    highest_index = max(num_iterations - 1,
                        (num_iterations // train_every) * train_every)
    if num_iterations > 0 and highest_index >= n_times:
        raise IndexError(
            f"time index {highest_index} is required but the dataset only has "
            f"{n_times} time steps"
        )

    model = LearnedPhysicsModel()
    optimizer = Adam(model.parameters(), lr=lr)
    criterion = MSELoss()

    for i in range(num_iterations):
        data_lp = expanded_regridded_encode.isel(time=i)

        # NOTE(review): the forward pass runs on every iteration although its
        # result is only consumed on training iterations. It is left in place
        # because the model may keep internal state — confirm, then move it
        # below the guard if it does not.
        predicted_tendencies = model.forward(data_lp).view(*tendency_shape)

        if (i + 1) % train_every != 0:
            continue

        # NOTE(review): the runner is always built from `regridded_encode` with
        # time_i=0 and never receives `predicted_tendencies` — confirm this is
        # intended and not a placeholder for a time-dependent start state.
        runner = DynamicalCoreRunner(regridded_encode,
                                     integration_steps=integration_steps,
                                     loop_iterations=loop_iterations,
                                     time_i=0)
        out_state = runner.run()

        # Both tensors are stacked in their own dataset's variable order; this
        # presumes the two datasets list the kept variables identically.
        out_state_pred_tensor = _stack_kept_variables(out_state, variables_to_keep)
        data_pred_tensor = _stack_kept_variables(
            expanded_regridded_encode.isel(time=i + 1), variables_to_keep
        )

        # NOTE(review): assumes `combined_loss` returns a tensor whose shape is
        # compatible with `tendency_shape` for MSELoss — TODO confirm.
        target_error = combined_loss(out_state_pred_tensor, data_pred_tensor)
        loss = criterion(predicted_tendencies, target_error)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Itération {i+1}: Loss = {loss.item():.6f}")

    # Hand the trained model back; without this the training result is lost.
    return model

In [None]:
# Launch training with the default hyper-parameters.
# NOTE(review): `variables_to_keep` is not defined in any cell shown here — it
# must be defined in an earlier cell for this call to work on a fresh kernel.
train_model(
    expanded_regridded_encode=expanded_regridded_encode,
    regridded_encode=regridded_encode,
    variables_to_keep=variables_to_keep,
)