In [63]:
# Use the %pip magic (not !pip) so the package is installed into the environment
# of the running kernel, and pin the release so a re-run reproduces the same setup.
%pip install --upgrade plotly==6.4.0

Collecting plotly
  Downloading plotly-6.4.0-py3-none-any.whl.metadata (8.5 kB)
Downloading plotly-6.4.0-py3-none-any.whl (9.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m14.5 MB/s[0m  [33m0:00:00[0m eta [36m0:00:01[0m
[?25hInstalling collected packages: plotly
  Attempting uninstall: plotly
    Found existing installation: plotly 6.3.0
    Uninstalling plotly-6.3.0:
      Successfully uninstalled plotly-6.3.0
Successfully installed plotly-6.4.0


In [64]:
import pickle as pkl
import random
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd

import plotly
import chart_studio.plotly as py
import plotly.graph_objs as go
import seaborn as sns

from numpy.random import gamma, binomial
import numpy as np
import scipy
import pandas as pd 
import xarray as xr 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from convCNP.validation.utils import get_dists
import os
from pathlib import Path
import torch.nn as nn
import torch.optim as optim
import xarray as xr

from convCNP.models.elev_models import TmaxBiasConvCNPElev, GammaBiasConvCNPElev
from convCNP.models.cnn import CNN, ResConvBlock
from convCNP.training.training_elev import train_elev
from convCNP.training.loss_functions import gll, gamma_ll
from convCNP.training.utils import get_value_tmax
from convCNP.validation.utils import get_dists, generate_context_mask, get_output, load_model

In [65]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

Using device: cuda


In [66]:
# Data and mode parameters
VARIABLE = 'tmax'   # 'tmax' or 'precip'; decides which dataset paths are selected
DATA_YEAR_START = None  # Year to start loading data from, None to include all data
                        # Used for quicker testing

# Model parameters
N_CHANNELS = 128    # default in paper is 128
N_BLOCKS = 6        # default in paper is 6
KERNEL_SIZE = 5     # default in paper is 5
LENGTH_SCALE = 0.1  # default in paper is 0.1
IN_CHANNELS = 25    # default in paper is 25

# Training parameters
N_EPOCHS = 100      # default in paper is 100
BATCH_SIZE = 16     # default in paper is 16    # TODO: wire this in (not used yet)
LR = 5e-4           # default in paper is 5e-4
PATIENCE = 10       # default in paper is 10    # TODO: wire this in (not used yet)

# Cross-validation parameters
N_FOLDS = 5         # default in paper is 5

# Other parameters
SEED = 42           # seed passed to set_seed() for reproducibility

In [67]:
OUTPUT_DIR = 'trained_models/'

# Auto-set data paths based on variable.
# GRID_INPUTS / GRID_TARGETS are glob patterns matching the NetCDF files of the
# ERA5-Land inputs and the EOBS targets for the selected VARIABLE.
if VARIABLE == 'tmax':
    GRID_INPUTS = './datasets/ERA5_Land/max_temperature/*.nc'
    GRID_TARGETS = './datasets/EOBS/max_temperature/*.nc'
elif VARIABLE == 'precip':
    GRID_INPUTS = './datasets/ERA5_Land/precipitation/*.nc'
    GRID_TARGETS = './datasets/EOBS/precipitation/*.nc'
else:
    # Fail here instead of leaving the paths undefined (or stale from an earlier
    # run of this cell), which would only surface later as a confusing error.
    raise ValueError(
        f"Unsupported VARIABLE {VARIABLE!r}; expected 'tmax' or 'precip'"
    )
ELEV_INPUTS = ''  # Empty for now

In [68]:
# Load data
# Each path is a glob pattern; open_mfdataset opens all matching NetCDF files as a
# single dataset, and combine='by_coords' joins them using their coordinate values.
full_input_ds = xr.open_mfdataset(GRID_INPUTS, combine='by_coords')
full_target_ds = xr.open_mfdataset(GRID_TARGETS, combine='by_coords')

# Print both dataset summaries to sanity-check dimensions, time span and variable names.
print("Input data:\n", full_input_ds)
print("\nTarget data:\n", full_target_ds)

Input data:
 <xarray.Dataset> Size: 166MB
Dimensions:    (time: 23376, latitude: 29, longitude: 61)
Coordinates:
  * time       (time) datetime64[ns] 187kB 1960-01-01 1960-01-02 ... 2023-12-31
  * latitude   (latitude) float64 232B 48.2 48.1 48.0 47.9 ... 45.6 45.5 45.4
  * longitude  (longitude) float64 488B 5.0 5.1 5.2 5.3 ... 10.7 10.8 10.9 11.0
Data variables:
    t2m_max    (time, latitude, longitude) float32 165MB dask.array<chunksize=(366, 29, 61), meta=np.ndarray>

Target data:
 <xarray.Dataset> Size: 122MB
Dimensions:    (time: 19540, latitude: 30, longitude: 52)
Coordinates:
  * time       (time) datetime64[ns] 156kB 1971-01-01 1971-01-02 ... 2024-06-30
  * latitude   (latitude) float64 240B 45.35 45.45 45.55 ... 48.05 48.15 48.25
  * longitude  (longitude) float64 416B 5.65 5.75 5.85 ... 10.55 10.65 10.75
Data variables:
    t_max      (time, latitude, longitude) float32 122MB dask.array<chunksize=(1, 30, 52), meta=np.ndarray>
Attributes:
    CDI:            Climate Data Int

In [69]:
# Optionally restrict both datasets to DATA_YEAR_START onwards.
if DATA_YEAR_START is None:
    # No start year configured: work on the complete datasets.
    input_ds = full_input_ds
    target_ds = full_target_ds
else:
    # Open-ended time slice beginning at the configured year.
    input_ds = full_input_ds.sel(time=slice(str(DATA_YEAR_START), None))
    target_ds = full_target_ds.sel(time=slice(str(DATA_YEAR_START), None))

print("Input data xarray:\n", input_ds)
print("\nTarget data xarray:\n", target_ds)

# create a torch tensor with dimensions (time, lat, long)

Input data xarray:
 <xarray.Dataset> Size: 166MB
Dimensions:    (time: 23376, latitude: 29, longitude: 61)
Coordinates:
  * time       (time) datetime64[ns] 187kB 1960-01-01 1960-01-02 ... 2023-12-31
  * latitude   (latitude) float64 232B 48.2 48.1 48.0 47.9 ... 45.6 45.5 45.4
  * longitude  (longitude) float64 488B 5.0 5.1 5.2 5.3 ... 10.7 10.8 10.9 11.0
Data variables:
    t2m_max    (time, latitude, longitude) float32 165MB dask.array<chunksize=(366, 29, 61), meta=np.ndarray>

Target data xarray:
 <xarray.Dataset> Size: 122MB
Dimensions:    (time: 19540, latitude: 30, longitude: 52)
Coordinates:
  * time       (time) datetime64[ns] 156kB 1971-01-01 1971-01-02 ... 2024-06-30
  * latitude   (latitude) float64 240B 45.35 45.45 45.55 ... 48.05 48.15 48.25
  * longitude  (longitude) float64 416B 5.65 5.75 5.85 ... 10.55 10.65 10.75
Data variables:
    t_max      (time, latitude, longitude) float32 122MB dask.array<chunksize=(1, 30, 52), meta=np.ndarray>
Attributes:
    CDI:            Cl

In [70]:
def calculate_dists(era5_inputs, eobs_targets):
    """
    Compute the distances between the input grid points and the target points.

    The longitude/latitude coordinate vectors of ``era5_inputs`` are expanded
    into 2-D meshes and moved to the global ``device``; the latitude/longitude
    coordinates of ``eobs_targets`` are stacked into a flat list of points.
    Both are handed to ``get_dists``, whose result is returned unchanged.
    """
    # np.meshgrid yields the longitude mesh first, then the latitude mesh.
    lon_mesh, lat_mesh = np.meshgrid(era5_inputs['longitude'], era5_inputs['latitude'])
    lat_tensor = torch.from_numpy(lat_mesh).to(device)
    lon_tensor = torch.from_numpy(lon_mesh).to(device)

    # Collapse the (latitude, longitude) dimensions into one 'coords' dimension
    # and take its values as the target points.
    stacked_targets = eobs_targets.stack(coords=['latitude', 'longitude'])
    target_points = stacked_targets['coords'].values

    return get_dists(target_points, lat_tensor, lon_tensor)

# Compute the grid-to-target distances once and keep the result on the selected device.
dists = calculate_dists(input_ds, target_ds).to(device)

In [71]:
def set_seed(seed):
    """Set random seeds for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch. When CUDA
    is available, every CUDA device is seeded as well and cuDNN is switched to
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to all random number generators.
    """
    # The notebook imports `random`; without seeding it, any sampling done
    # through the stdlib RNG would differ between runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every GPU, including the current one.
        torch.cuda.manual_seed_all(seed)
        # Trade some speed for repeatable cuDNN kernel selection.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Apply the configured SEED to the global random number generators.
set_seed(SEED)

In [73]:
# Grid points that actually carry data: take the first time step, flatten it to a
# table and drop the rows whose value is missing.
# NOTE(review): the variable names 't_max' / 't2m_max' are hard-coded here and do
# not follow VARIABLE — confirm what the precipitation files call their variable.
target_coords = (
    target_ds['t_max']
    .isel(time=0)
    .to_dataframe()
    .dropna()
    .reset_index()
)
input_coords = (
    input_ds['t2m_max']
    .isel(time=0)
    .to_dataframe()
    .dropna()
    .reset_index()
)

# One marker layer per grid; the target grid comes first, then the input grid.
fig = go.Figure(data=[
    go.Scattergeo(
        lon=coords['longitude'],
        lat=coords['latitude'],
        mode='markers',
        marker_color=colour,
        marker={'size': 3},
        name=label,
    )
    for coords, colour, label in (
        (target_coords, 'olivedrab', 'EOBS (target)'),
        (input_coords, 'sandybrown', 'ERA5 (input)'),
    )
])

fig.update_layout(
    title='EOBS target grid and ERA5 input grid (only points with data)',
    geo={
        'scope': 'europe',
        # Disabled options kept for reference:
        # lataxis = dict(range=[41, 52]),
        # lonaxis = dict(range=[2, 15]),
        # constrainwidth='width',
        # Zoom the map to the plotted points.
        'fitbounds': 'locations',
    },
    width=1000,
    height=700,
)
fig.show()