# load data

In [None]:
# IPython automagic: show the current working directory.
pwd

In [None]:
%matplotlib inline
from netCDF4 import Dataset
import xarray as xr

import numpy as np
import matplotlib.pyplot as plt
import seaborn
# Seaborn plot styling for this notebook: 'darkgrid' style, 'notebook' context.
seaborn.set_style('darkgrid')
seaborn.set_context('notebook')

from predictability_utils.utils import helpers, io
from predictability_utils.methods.lrlin_method import run_lrlin
from predictability_utils.methods.cca_method import run_cca

import torch
# Reproducible torch setup: fixed seed, and default float tensors on the GPU
# whenever CUDA is available (CPU otherwise).
torch.manual_seed(42)
_has_cuda = torch.cuda.is_available()
if not _has_cuda:
    print("CUDA not available")
device = torch.device("cuda" if _has_cuda else "cpu")
torch.set_default_tensor_type("torch.cuda.FloatTensor" if _has_cuda else "torch.FloatTensor")

# Root directory of the 5.625-degree WeatherBench files.
root_data = '/gpfs/work/nonnenma/data/forecast_predictability/weatherbench/5_625deg/'

# Open every geopotential_500 file as one dataset and keep the 'z' variable;
# the unpacking below requires exactly three dimensions.
z500 = xr.open_mfdataset(f'{root_data}geopotential_500/*.nc', combine='by_coords').z
_, nlat, nlon = z500.shape

def create_training_data(data, lead_time_h, return_valid_time=False):
    """Split a series into (input, target) pairs separated by a lead time.

    Args:
        data: Object indexed along a ``time`` dimension via ``isel`` (an
            xr.DataArray in this notebook).
        lead_time_h: Offset between input and target, in time steps (equal to
            hours for hourly data). Must be >= 0; 0 pairs every step with itself.
        return_valid_time: If True, also return the ``time`` coordinate of the
            targets.
    Returns:
        ``(X, y)`` as arrays, or ``(X, y, valid_time)`` when
        ``return_valid_time`` is True. ``X[i]`` is the input at step ``i`` and
        ``y[i]`` the value ``lead_time_h`` steps later.
    Raises:
        ValueError: if ``lead_time_h`` is negative.
    """
    if lead_time_h < 0:
        raise ValueError(f'lead_time_h must be >= 0, got {lead_time_h}')
    # `-lead_time_h or None`: for a zero lead time, slice(0, -0) == slice(0, 0)
    # would select nothing; None keeps the whole series instead.
    X = data.isel(time=slice(0, -lead_time_h or None))
    y = data.isel(time=slice(lead_time_h, None))
    if return_valid_time:
        return X.values, y.values, y.time
    return X.values, y.values

# Display the Z500 DataArray opened above (last expression of the cell).
z500

In [None]:
def load_test_data(path, var, years=slice('2017', '2018')):
    """Load the Z500 or T850 field for a time window.

    Args:
        path: Directory containing the .nc files.
        var: Variable. Geopotential = 'z', Temperature = 't'.
        years: Slice selecting the time window (default 2017-2018).
    Returns:
        DataArray of ``var`` restricted to ``years``. If the files have a
        ``level`` dimension, level 500 (for 'z') or 850 (for 't') is selected
        and the ``level`` coordinate is dropped.
    Raises:
        ValueError: if ``var`` is not 'z' or 't'.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if var not in ('z', 't'):
        raise ValueError('Test data only for Z500 and T850')
    ds = xr.open_mfdataset(f'{path}/*.nc', combine='by_coords')[var]
    # Check for the dimension explicitly rather than catching ValueError, so
    # unrelated errors from `sel`/`drop` are not silently swallowed. Files
    # without a `level` dimension are used as-is.
    if 'level' in ds.dims:
        ds = ds.sel(level=500 if var == 'z' else 850).drop('level')
    return ds.sel(time=years)

def compute_weighted_rmse(da_fc, da_true, mean_dims=None):
    """
    Compute the RMSE with latitude weighting from two xr.DataArrays.

    Args:
        da_fc (xr.DataArray): Forecast. Time coordinate must be validation time.
        da_true (xr.DataArray): Truth.
        mean_dims: Dimension(s) to average over. The default None averages
            over all dimensions (replaces the removed ``xr.ALL_DIMS``
            constant; passing that constant still works where it exists).
    Returns:
        rmse: Latitude weighted root mean squared error. For a Dataset every
            variable is renamed to '<name>_rmse'; for a DataArray the result
            is named '<name>_rmse', or 'rmse' if the error has no name.
    """
    error = da_fc - da_true
    # cos(lat) weights (lat in degrees), normalised to mean 1 so the weighted
    # mean stays on the same scale as an unweighted one.
    weights_lat = np.cos(np.deg2rad(error.lat))
    weights_lat /= weights_lat.mean()
    rmse = np.sqrt((error ** 2 * weights_lat).mean(mean_dims))
    if isinstance(rmse, xr.Dataset):
        rmse = rmse.rename({v: v + '_rmse' for v in rmse})
    else:  # DataArray
        rmse.name = 'rmse' if error.name is None else error.name + '_rmse'
    return rmse

def evaluate_iterative_forecast(fc_iter, da_valid):
    """Latitude-weighted RMSE of an iterative forecast at every lead time.

    Args:
        fc_iter: Forecast with a ``lead_time`` dimension whose values are
            interpreted as hours; its ``time`` coordinate is shifted by the
            lead time before scoring.
        da_valid: Truth indexed by validation time.
    Returns:
        The per-lead-time results of ``compute_weighted_rmse`` concatenated
        along a ``lead_time`` dimension.
    """
    def _score(step):
        # Move the forecast's time stamps forward by the lead time so they
        # line up with the validation times in `da_valid`.
        fc = fc_iter.sel(lead_time=step)
        fc['time'] = fc.time + np.timedelta64(int(step), 'h')
        return compute_weighted_rmse(fc, da_valid)

    return xr.concat([_score(step) for step in fc_iter.lead_time], 'lead_time')


In [None]:
lead_time = 3 * 24  # 3 days, in time steps (hours for hourly data)

# Train on 1979-2016, test on 2017-2018. Source index i is paired with the
# target lead_time steps later, so pairs i < t_train have targets inside
# 1979-2016 and pairs t_train <= i < t_all have targets in 2017-2018.
z500_data = z500.sel(time=slice('1979', '2018'))
t_train = z500['time'].sel(time=slice('1979', '2016')).data.size - lead_time
t_all = z500_data['time'].data.size - lead_time

# Row vectors (shape (1, n)) of indices into source_data / target_data;
# the target arrays are independent copies of the source arrays.
idx_source_train = np.arange(t_train)[None, :]
idx_target_train = idx_source_train.copy()
idx_source_test = np.arange(t_train, t_all)[None, :]
idx_target_test = idx_source_test.copy()
idcs = (idx_source_train, idx_target_train, idx_source_test, idx_target_test)

n_latents = 5
map_shape = (nlat, nlon)

# Normalisation statistics from the training inputs only: one global mean,
# and the per-grid-point standard deviation over time averaged into a scalar.
z500_train = z500_data[idx_source_train.squeeze(), :, :]
z500_mean = z500_train.mean().values
z500_std = z500_train.std('time').mean().values

data_z500 = (z500_data - z500_mean) / z500_std
source_data, target_data = create_training_data(data_z500, lead_time_h=lead_time)

In [None]:
# Shape of the 1979-2018 Z500 array; presumably (time, lat, lon), matching
# the `_, nlat, nlon` unpacking above — confirm.
z500_data.shape

# simple low-rank linear prediction (pixel MSEs) 

- set up a simple linear model $Y = W X$ with $W = U V$
- low rank: if $Y \in \mathbb{R}^N$ and $X \in \mathbb{R}^M$, then $W \in \mathbb{R}^{N \times M}$, but $U \in \mathbb{R}^{N \times k}$ and $V \in \mathbb{R}^{k \times M}$ with $k \ll M, N$!
- the low-rank structure saves parameters: $W$ has $M N$ parameters, while $U$ and $V$ together have only $N k + k M = k (N + M)$, which helps prevent overfitting when the sample size is small

In [None]:
# Fit the low-rank linear model on the normalised source/target data using
# the train/test index sets in `idcs`; `n_latents` is presumably the rank k.
# NOTE(review): from its use in the next cell, `params` holds the factors
# 'U' and 'V'; the content of `corrs_map` is not visible here — confirm
# against predictability_utils.methods.lrlin_method.run_lrlin.
corrs_map, params = run_lrlin(source_data, target_data, n_latents, idcs, if_plot=True, map_shape=map_shape,
                              n_epochs=5, lr=1e-1, batch_size=10)

In [None]:
# Score the fitted low-rank model on the 2017-2018 test period.
T = source_data.shape[0]
# Flatten each map to a vector and take the test rows. The index arrays have
# shape (1, n_test), so mean(axis=0) only removes that leading singleton axis.
X = source_data.reshape(T, -1)[idx_source_test,:].mean(axis=0)
# NOTE(review): Y is not used in this cell.
Y = target_data.reshape(T, -1)[idx_target_test,:].mean(axis=0)
# Low-rank prediction in normalised units, then undo the normalisation.
# NOTE(review): this computes X @ V @ U, while the markdown writes W = U V
# acting on column vectors — presumably params store transposed factors;
# confirm against run_lrlin.
Ypred = X.dot(params['V']).dot(params['U']) * z500_std + z500_mean

z500_valid = load_test_data(f'{root_data}geopotential_500', 'z') #z500.sel(time=slice('2017', '2018'))

# Ypred is a plain array, so it is matched to z500_valid by position along
# time (target steps t_train + lead_time onward, i.e. 2017-2018), not by label.
compute_weighted_rmse(Ypred.reshape(-1, *map_shape), z500_valid).values

In [None]:
# Inspect the normalisation scale computed from the training inputs.
z500_std

# debug