In [1]:
import math
import sys

from icecream import ic, install

install()

sys.path.append("..")

In [2]:
import logging
import os

import pandas as pd
import torch
from model.models import ClimODE
from model.velocity import get_kernel, get_velocities
from torch import optim
from torch.utils.data import DataLoader
from utils.loss import CustomGaussianNLLLoss

import data.loading as loading
from data.dataset import Forcasting_ERA5Dataset
from data.embeddings import get_time_localisation_embeddings
from data.processing import select_data

variables_time_dependant = ["t2m", "t", "z", "u10", "v10"]
variables_static = ["lsm", "orography"]

gpu_device = torch.device("cpu")  # fallback to cpu
if torch.cuda.is_available():
    gpu_device = torch.device("cuda")
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
    gpu_device = torch.device("mps")
    torch.mps.empty_cache()

config = {
    "data_path_wb1": "../data/era5_data/",
    "data_path_wb2": "../data/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr",
    "freq": 6,  # In hours
    "nb_variable_time_dependant": len(variables_time_dependant),
    "periods": {
        "train": ("2006-01-01", "2015-12-31"),
        "val": ("2016-01-01", "2016-12-31"),
        "test": ("2017-01-01", "2018-12-31"),
    },
    "vel": {
        "rbf_alpha": 1.0,
        "stacking": 3,
        "bs": 50,
        "fitting_epoch": 200,
        "regul_coeff": 1e-7,
        "lr": 2,
        "device": gpu_device,
    },
    "model": {
        "VelocityModel": {
            "local": {
                "in_channels": 30 + 34,  # 34 d'embeding, 30 = jsp
                "layers_length": [5, 3, 2],
                "layers_hidden_size": [128, 64, 2 * 5],
                # 5 = out_types = len(paths_to_data)
            },
            "global": {
                "in_channels": 30 + 34,
                "out_channels": 2 * 5,
            },
            "gamma": 0.1,
        },
        "EmissionModel": {
            "in_channels": 9 + 34,  # err_in ; je sais pas pourquoi 9
            "layers_length": [3, 2, 2],
            "layers_hidden_size": [
                128,
                64,
                2 * len(variables_time_dependant),
            ],  # 5 = out_types = len(paths_to_data)
        },
        "norm_type": "batch",
        "n_res_blocks": [3, 2, 2],
        "kernel_size": 3,
        "stride": 1,
        "dropout": 0.1,
    },
    "bs": 8,
    "max_epoch": 300,
    "lr": 0.0005,
    "device": gpu_device,
}

if __name__ == "__main__":
    # check the script is executed within the parent directory

    logging.basicConfig(level=logging.INFO)

    periods = {
        k: pd.date_range(*p, freq=str(config["freq"]) + "H")
        for (k, p) in config["periods"].items()
    }
    raw_data = loading.wb1(config["data_path_wb1"], periods)
    train_raw_data = raw_data.sel(time=periods["train"])
    # data = loading.wb2(config["data_path_wb2"], periods)

    logging.info("Raw data loaded, merged and normalized")
    logging.info("Raw data disk size: {} MiB".format(raw_data.nbytes / 1e6))

    data_selected = select_data(raw_data, periods)

    kernel = get_kernel(raw_data, config["vel"])
    data_velocities = get_velocities(data_selected, kernel, config)
    train_velocities = torch.cat(tuple(data_velocities["train"].values()), dim=1).view(
        -1, 32, 64, 10
    )  # (1826, 10, 32, 64) -> (1826, 32, 64, 10) pour compatibilité avec les futurs cat

    train_data = torch.cat(
        [t.unsqueeze(-1) for t in data_selected["train"].values()], dim=-1
    )
    dataset = Forcasting_ERA5Dataset(train_data)
    train_loader = DataLoader(dataset, batch_size=config["bs"], shuffle=True)

    time_step = torch.Tensor(list(range(len(train_data))))
    time_step = torch.arange(0, len(train_data), 1)
    # time_step = torch.Tensor(list(range(22)))
    time_pos_embedding = get_time_localisation_embeddings(
        time_step,
        torch.tensor(train_raw_data["lat"].values),
        torch.tensor(train_raw_data["lon"].values),
        torch.tensor(train_raw_data["lsm"].values),
        torch.tensor(train_raw_data["orography"].values),
    ).float()  # float64 to float32 (important for conv) TODO
    model = ClimODE(config, time_pos_embedding)
    optimizer = optim.AdamW(model.parameters(), lr=config["lr"])
    criterion = CustomGaussianNLLLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 300)

  k: pd.date_range(*p, freq=str(config["freq"]) + "H")
INFO:root:Raw data loaded, merged and normalized
INFO:root:Raw data disk size: 777.712696 MiB
INFO:root:Velocities for train loaded from cache
INFO:root:Velocities for val loaded from cache
INFO:root:Velocities for test loaded from cache


In [5]:
import torch
from model.conv import ClimateResNet2D
from torch import nn
from torch.nn import functional as F
from torchdiffeq import odeint

"""
WIP
"""


class EmissionModel(nn.Module):
    """
    Equivalent of noise_net_contrib() using a class format.
    """

    def __init__(self, config, time_pos_embedding):
        super().__init__()
        self.sub_config = config["model"]["EmissionModel"]
        self.model = ClimateResNet2D(
            self.sub_config["in_channels"],
            self.sub_config["layers_length"],
            self.sub_config["layers_hidden_size"],
            config,
        )
        self.time_pos_embedding = time_pos_embedding
        self.nb_var_time_dep = config["nb_variable_time_dependant"]

    def forward(
        self,
        t,
        x,
    ):
        """
        WIP, not tested yet.
        """
        ic(x.shape)  # [8, 32, 64, 5]
        original_x = x.reshape(
            8, self.nb_var_time_dep, 32, 64
        )  # view fonctionne pas et puis anyway j'dois faire une copy
        x = torch.cat([x, self.time_pos_embedding[t]], dim=-1)
        x = x.view(
            8, -1, 32, 64
        )  # Matching conv shape (8, 32, 64, 43) -> (8, 43, 32, 64)
        x = self.model(x)
        ic(
            x.shape, # [8, 10, 32, 64]
            x[:, : self.nb_var_time_dep].shape, original_x.shape
        )  
        # From original code, not sure if it's correct
        mean = original_x + x[:, : self.nb_var_time_dep]
        std = F.softmax(x[:, self.nb_var_time_dep :])
        return mean, std


class AttentionModel(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = in_channels // 2
        self.query = self.make_layer(
            in_channels, in_channels // 8, hidden_channels, stride=1, padding=True
        )
        self.key = self.make_layer(
            in_channels, in_channels // 8, hidden_channels, stride=2
        )
        self.value = self.make_layer(
            in_channels, out_channels, hidden_channels, stride=2
        )
        self.post_map = nn.Conv2d(out_channels, out_channels, kernel_size=(1, 1))

    @staticmethod
    def make_layer(in_channels, out_channels, hidden_channels, stride, padding=False):
        def get_block(in_channels, out_channels, stride, padding):
            if padding:
                block = [
                    nn.ReflectionPad2d((0, 0, 1, 1)),
                    nn.CircularPad2d((1, 1, 0, 0)),
                ]
            else:
                block = []

            block += [
                nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=stride),
                nn.LeakyReLU(0.3),
            ]
            return block

        return nn.Sequential(
            *get_block(in_channels, hidden_channels, stride, padding),
            *get_block(hidden_channels, out_channels, stride, padding),
            *get_block(out_channels, out_channels, 1, padding),
        )

    def forward(self, x):
        """WIP

        Parameters
        ----------
        x : _type_
            shape code origin: (1, 64, 32, 64)

        Returns
        -------
        _type_
            _description_
        """
        # On flatten sur la latitude et la longitude
        q = self.query(x).flatten(-2, -1)  # (1, 64, 32, 64) -> (1, 64, 2048)
        k = self.key(x).flatten(-2, -1)  # (1, 8, 3, 13) -> (1, 8, 65)
        v = self.value(x).flatten(-2, -1)  # (1, 10, 3, 13) -> (1, 10, 65)
        # ic(q.shape, k.shape, v.shape)
        # Il doit y avoir moyen de mieux faire, le contiguous est salle je pense
        attention_beta = F.softmax(torch.bmm(q.transpose(1, 2), k), dim=1)
        attention_beta = torch.bmm(v, attention_beta.transpose(1, 2))
        attention_beta = attention_beta.view(1, -1, 32, 64).contiguous()
        # ic(self.post_map(attention_beta).shape) # (1, 10, 32, 64)
        return self.post_map(attention_beta)
        """
        size = x.size()
        x = x.float()
        q, k, v = (
            self.query(x).flatten(-2, -1),
            self.key(x).flatten(-2, -1),
            self.value(x).flatten(-2, -1),
        )
        beta = F.softmax(torch.bmm(q.transpose(1, 2), k), dim=1)
        o = torch.bmm(v, beta.transpose(1, 2))
        o = self.post_map(o.view(-1, self.out_ch, size[-2], size[-1]).contiguous())
        return o
        """


class VelocityModel(nn.Module):
    """
    Equivalent of $f_\theta$ in the paper.
    """

    def __init__(self, config, time_pos_embedding):
        super().__init__()
        sub_config = config["model"]["VelocityModel"]
        self.time_pos_embedding = time_pos_embedding
        self.local_model = ClimateResNet2D(
            sub_config["local"]["in_channels"],
            sub_config["local"]["layers_length"],
            sub_config["local"]["layers_hidden_size"],
            config,
        )
        self.global_model = AttentionModel(
            sub_config["global"]["in_channels"],
            sub_config["global"]["out_channels"],
        )  # input original code ([1, 64, 32, 64])
        self.gamma = nn.Parameter(torch.tensor([sub_config["gamma"]]))

    def forward(self, t, x):
        """
        WIP, not tested yet.
        Input must directly have all the parameters concatenated.
        x: shape: (32,64,15) -> (32,64,10) + (32,64,5)
        """

        # Obligé de cat avant puis uncat ici car odeint ne peut pas split ces param je pense
        # pour le coup un tensors dict ici serait plus propre mais plus le temps
        # Si on passe en (batch, timestep, année, ...,...), il faudra rajouter un :
        past_velocity = x[:, :, :10]  # v in original code
        past_velocity_x = past_velocity[:, :, :5]
        past_velocity_y = past_velocity[:, :, 5:]
        past_velocity_grad_x = torch.gradient(past_velocity_x, dim=-2)[0]
        past_velocity_grad_y = torch.gradient(past_velocity_y, dim=-3)[0]

        x_0 = x[:, :, 10:]  # ds in original code
        x_0_grad_x = torch.gradient(x_0, dim=-2)[0]  # sur la dim de la logitude (64)
        x_0_grad_y = torch.gradient(x_0, dim=-3)[0]  # sur la dim de la latitude (32)
        nabla_u = torch.cat([x_0_grad_x, x_0_grad_y], dim=-1)  # (32,64,2*5)

        t_emb = t.view(1, 1, 1).expand(32, 64, 1)
        t = int(t.item()) * 100
        # ic(
        #     x.shape,
        #     nabla_u.shape,
        #     self.time_pos_embedding[t].shape,
        #     past_velocity.shape,
        #     x_0.shape,
        # )

        x = torch.cat([t_emb, x, nabla_u, self.time_pos_embedding[t]], dim=-1)
        # Unsquueze for simulate a batch of 1
        # and inverting the last dimension to the match CNN style conv (sorry j'aurai pu le faire avant j'ai merdé tant pis TODO)
        x = x.view(1, 64, 32, 64)
        dv = self.local_model(x)
        dv += self.gamma * self.global_model(x)
        dv = dv.squeeze().view(32, 64, -1)  # (32, 64, 10)

        adv1 = past_velocity_x * x_0_grad_x + past_velocity_y * x_0_grad_y
        adv2 = x_0 * (past_velocity_grad_x + past_velocity_grad_y)

        # ic(dv.shape, adv1.shape, adv2.shape)
        dvs = torch.cat([dv, adv1 + adv2], dim=-1)
        return dvs


class ClimODE(nn.Module):
    def __init__(self, config, time_pos_embedding):
        super(ClimODE, self).__init__()
        self.config = config
        self.device = config["device"]
        self.freq = config["freq"]
        self.velocity_model = VelocityModel(config, time_pos_embedding)
        self.emission_model = EmissionModel(config, time_pos_embedding)
        self.time_pos_embedding = time_pos_embedding

    def forward(self, t, x):
        """
        WIP, not tested yet.
        """

        # Calcul of news timesteps
        init_time = t[0].item() * self.freq
        final_time = t[-1].item() * self.freq
        steps_val = final_time - init_time
        ode_t = 0.01 * torch.linspace(
            init_time, final_time, steps=int(steps_val) + 1
        ).to(self.device)  # Je sais pas pourquoi 0.01

        # Solvings ODE
        x = odeint(self.velocity_model, x, ode_t, method="euler")
        # ic(x.shape) # torch.Size([43, 32, 64, 15])
        x = x[
            :, :, :, -5:
        ]  # On récupère que les données de la prédiction uniquement, pas des past velocities si je comprends bien ???
        # Nan je sais pas
        # ic(x.shape) # torch.Size([43, 32, 64, 5])
        x = x[::6]  # idk pourquoi on fait ça, je crois qu'on rediscretise en 8 morceaux
        # ic(x.shape) # ([8, 32, 64, 5])
        mean, std = self.emission_model(t, x)
        return mean, std


model = ClimODE(config, time_pos_embedding)

In [10]:
for epoch in range(config["max_epoch"]):
    if epoch == 0:
        var_coeff = 0.001
    else:
        var_coeff = 2 * scheduler.get_last_lr()[0]
    for i, x in enumerate(train_loader):
        t = time_step[i * config["bs"] : (i + 1) * config["bs"]]
        y_forecast = x[1:].to(gpu_device)
        # Past velocities cat with x[0]
        x_0 = torch.cat((train_velocities[i], x[0]), dim=-1)
        # simplement i car indexé comme ça, demander mathis why mais anyway,
        # maybe parce que c'est long à compute donc on en fait le minimum
        # train_velocities.shape: torch.Size([1826, 32, 64, 10])
        # train_data.shape: torch.Size([14605, 32, 64, 5])
        # 14605/8: 1825.625

        mean, std = model(t, x_0)
        # ic(mean.shape, std.shape) # [8, 5, 32, 64] both
        ic(x.shape)
        x = x.view(-1, config['nb_variable_time_dependant'], 32, 64)
        loss = criterion(x, mean, std, var_coeff)
        loss.backward()

        ic(x.shape)
        ic(y_forecast.shape)

    break

ic| x.shape: torch.Size([8, 32, 64, 5])
ic| x.shape: torch.Size([8, 10, 32, 64])
    x[:, : self.nb_var_time_dep].shape: torch.Size([8, 5, 32, 64])
    original_x.shape: torch.Size([8, 5, 32, 64])
  std = F.softmax(x[:, self.nb_var_time_dep :])
ic| x.shape: torch.Size([8, 32, 64, 5])
ic| x.shape: torch.Size([8, 5, 32, 64])
ic| y_forecast.shape: torch.Size([7, 32, 64, 5])
ic| x.shape: torch.Size([8, 32, 64, 5])
ic| x.shape: torch.Size([8, 10, 32, 64])
    x[:, : self.nb_var_time_dep].shape: torch.Size([8, 5, 32, 64])
    original_x.shape: torch.Size([8, 5, 32, 64])
ic| x.shape: torch.Size([8, 32, 64, 5])
ic| x.shape: torch.Size([8, 5, 32, 64])
ic| y_forecast.shape: torch.Size([7, 32, 64, 5])
ic| x.shape: torch.Size([8, 32, 64, 5])
ic| x.shape: torch.Size([8, 10, 32, 64])
    x[:, : self.nb_var_time_dep].shape: torch.Size([8, 5, 32, 64])
    original_x.shape: torch.Size([8, 5, 32, 64])
ic| x.shape: torch.Size([8, 32, 64, 5])


KeyboardInterrupt: 

In [None]:
lat = torch.tensor(raw_data.coords["lat"].values)
lon = torch.tensor(raw_data.coords["lon"].values)
lsm = torch.tensor(raw_data.lsm.values)
oro = torch.tensor(raw_data.orography.values)
raw_data

In [None]:
# TODO
def print_time(t, paper=False):
    day_in_years = t / 24  # 365 or 366
    hours_of_day = t % 24
    day_of_years = t // 24
    (torch.sin(2 * torch.pi * hours_of_day),)  # sin temporal embedding
    (torch.sin(2 * torch.pi * day_of_years / day_in_years),)  # sin seasonal embedding
    print(f"{t%24}ème heure")
    print(f"{t//24}ème jour")
    print(f"Sinus hour: {math.sin(t%24)}")
    print(f"Sinus day: {math.sin(t//24)}")
    if paper:
        t_papier = t % 24
        print("\npapier")
        print(f"{t_papier}ème heure")
        print(f"{t_papier/24}ème jour")
        print(f"Sinus hour: {math.sin(t_papier%24 - math.pi / 2)}")
        print(f"Sinus day: {math.sin(t_papier/24 - math.pi / 2)}")


feb_28 = (31 + 28) * 24 + 3  # 28 feb 3h
mar_1 = (31 + 28 + 1) * 24 + 3  # 1 mars 3h
mar_1_bi = (31 + 29 + 1) * 24 + 3  # 1 mars 3h
print_time(feb_28)
print("1er mars pas bissextile")
print_time(mar_1)
print("1er mars bissextile")
print_time(mar_1_bi)