In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import h5py
import natsort


In [2]:
# Turn on 64-bit floats in JAX so the JAX-based OTT/Sinkhorn computation further down
# produces float64 results (its cost dtype is displayed at the end of the notebook).
jax.config.update("jax_enable_x64", True)


In [3]:
# Alternative result files that were evaluated with this notebook (swap `name` to use one):
#   results_equinet_cnn_squares_{1000,2000,4000,8000,16000}_crps.h5
name = 'results_equinet_cnn_squares.h5'

# Read the reference fields (`eta`) and the model predictions (`eta_pred`) fully into memory.
with h5py.File(name, 'r') as h5_file:
    eta_re = h5_file['eta'][()]
    eta_re_pred = h5_file['eta_pred'][()]


## RRMSE

In [5]:
# Deterministic evaluation: keep the full prediction array; the next cell compares only
# ensemble member 0 (second index) against the reference field.
cond_samples_metric = eta_re_pred

In [7]:
# Relative l2 error ||ref - pred|| / ||ref|| of ensemble member 0, one value per sample.
# (Stored as `MSE` because the following cells read that name.)
MSE = np.zeros(cond_samples_metric.shape[0])
for (i,) in np.ndindex(MSE.shape):
    reference = eta_re[i, :, :]
    MSE[i] = np.linalg.norm(reference - cond_samples_metric[i, 0, :, :]) / np.linalg.norm(reference)

In [8]:
# Summary statistics of the per-sample relative l2 errors.
for label, statistic in (
    ('Mean of validation relative l2 error:', np.mean),
    ('Median of validation relative l2 error:', np.median),
    ('Min of validation relative l2 error:', np.min),
    ('Max of validation relative l2 error:', np.max),
    ('Standard deviation of validation relative l2 errors:', np.std),
):
    print(label, statistic(MSE))

Mean of validation relative l2 error: 0.017200465539882553
Median of validation relative l2 error: 0.009055773319286108
Min of validation relative l2 error: 0.005048117424074934
Max of validation relative l2 error: 0.09298317870554379
Standard deviation of validation relative l2 errors: 0.01909906958707674


### Probabilistic: every ensemble member

In [9]:
# Drop the trailing channel and move the ensemble axis last: (N, E, H, W, 1) -> (N, H, W, E).
cond_samples_metric = np.moveaxis(eta_re_pred[..., 0], 1, -1)

In [10]:
# Relative l2 error of every ensemble member j against the reference of sample i.
# The reference slice and its norm do not depend on j, so they are computed once per sample.
n_samples, n_members = cond_samples_metric.shape[0], cond_samples_metric.shape[-1]
MSE = np.zeros((n_samples, n_members))
for i in range(n_samples):
    truth = eta_re[i, :, :, 0]
    truth_norm = np.linalg.norm(truth)
    for j in range(n_members):
        MSE[i, j] = np.linalg.norm(truth - cond_samples_metric[i, :, :, j]) / truth_norm


In [11]:
# Summary statistics over all (sample, ensemble member) relative l2 errors.
for label, statistic in (
    ('Mean of validation relative l2 error:', np.mean),
    ('Median of validation relative l2 error:', np.median),
    ('Min of validation relative l2 error:', np.min),
    ('Max of validation relative l2 error:', np.max),
    ('Standard deviation of validation relative l2 errors:', np.std),
):
    print(label, statistic(MSE))

Mean of validation relative l2 error: 0.017441053943523816
Median of validation relative l2 error: 0.008988638466762119
Min of validation relative l2 error: 0.004793373108378484
Max of validation relative l2 error: 0.11489494693466881
Standard deviation of validation relative l2 errors: 0.019381169819276747


## CRPS

### Probabilistic: CRPS over the ensemble

In [12]:
from swirl_dynamics.lib import metrics

In [13]:
# Same layout as for the RRMSE: channel dropped, ensemble axis last -> (N, H, W, E).
cond_samples_metric = eta_re_pred[:, :, :, :, 0].transpose(0, 2, 3, 1)

In [14]:
# Pointwise CRPS of the ensemble against the reference, for every sample.
# The buffer shape is taken from the data, (N, H, W), rather than hard-coded as (N, 80, 80).
# Assumes crps reduces over `ensemble_axis` and returns an (H, W) field, as the original
# fixed-size buffer implied.
crpss_1 = np.zeros(cond_samples_metric.shape[:-1])
for i in range(cond_samples_metric.shape[0]):
    crpss_1[i, :, :] = metrics.probabilistic_forecast.crps(
        cond_samples_metric[i, :, :, :], eta_re[i, :, :, 0], ensemble_axis=-1
    )

In [15]:
# Computing the \ell^1 metric: the absolute CRPS averaged over all grid points, one value per
# sample. The grid size comes from `crpss_1` instead of the hard-coded 80*80.
crpss_11 = np.mean(np.abs(crpss_1), axis=(1, 2))

In [16]:
# Mean spatially averaged CRPS over all samples, reported scaled by 1e4.
print(np.mean(crpss_11)*10**4)

3.9164176288078565


## Energy Spectrum

In [17]:
from pysteps.utils.spectral import rapsd

Pysteps configuration file found at: /grad/bzhang388/anaconda3/envs/jaxflax/lib/python3.11/site-packages/pysteps/pystepsrc



In [18]:
# Deterministic spectrum check: ensemble member 0 only, channel dropped -> (N, H, W).
cond_samples_metric = eta_re_pred[:,0,:,:,0]

In [19]:
# Radially averaged power spectral density (RAPSD) of the reference field and of ensemble
# member 0 for every sample, shape (N, radial bins). Building the arrays from the rapsd
# outputs avoids hard-coding the number of radial bins (40 for the 80x80 fields here).
# The error array is computed in the next cell, so no placeholder is allocated here.
rapsds_ref = np.array(
    [rapsd(eta_re[i, :, :, 0], fft_method=np.fft) for i in range(eta_re.shape[0])],
    dtype=float,
)
rapsds = np.array(
    [rapsd(cond_samples_metric[i, :, :], fft_method=np.fft) for i in range(eta_re.shape[0])],
    dtype=float,
)

In [20]:
# Absolute log-ratio of predicted to reference spectrum, per sample and radial bin.
rapsds_error_ = np.abs(np.log(np.divide(rapsds, rapsds_ref)))

In [21]:
# Mean absolute log-spectral error over samples and radial bins, scaled by 100.
print('radially averaged power spectrum', np.mean(rapsds_error_)*100)

radially averaged power spectrum 1.9212547201809513


### Probabilistic: spectrum of every ensemble member

In [22]:
# Channel dropped, ensemble axis moved last -> (N, H, W, E).
cond_samples_metric = np.transpose(eta_re_pred[..., 0], (0, 2, 3, 1))

In [23]:
# RAPSD of the reference field (rapsds_ref: (N, bins)) and of every ensemble member
# (rapsds: (N, E, bins)) for each sample. The bin count comes from rapsd itself rather
# than a hard-coded 40, and the never-used `rapsds_mean` buffer is no longer allocated.
rapsds_ref = np.array(
    [rapsd(eta_re[i, :, :, 0], fft_method=np.fft) for i in range(cond_samples_metric.shape[0])],
    dtype=float,
)
rapsds = np.array(
    [
        [
            rapsd(cond_samples_metric[i, :, :, j], fft_method=np.fft)
            for j in range(cond_samples_metric.shape[-1])
        ]
        for i in range(cond_samples_metric.shape[0])
    ],
    dtype=float,
)

In [39]:
# Mean absolute log-spectral error (over radial bins) for every (ensemble member, sample)
# pair, shape (E, N). The member count follows `rapsds` instead of the hard-coded 50;
# the reference spectrum is broadcast across the ensemble axis.
rapsds_error_ = np.mean(np.abs(np.log(rapsds) - np.log(rapsds_ref)[:, np.newaxis, :]), axis=-1).T

In [40]:
# Mean log-spectral error over all members and samples, scaled by 100.
print(np.mean(rapsds_error_)*100)

1.979019942466451


## Sinkhorn

In [26]:
import ott

from ott import problems
from ott.geometry import costs, pointcloud
from ott.solvers import linear
from ott.problems.linear import linear_problem
from ott.solvers.linear import acceleration, sinkhorn
from ott.tools.sinkhorn_divergence import sinkhorn_divergence

In [27]:
# Predictions: (samples, ensemble members, H, W, channels).
eta_re_pred.shape

(500, 50, 80, 80, 1)

In [34]:
# Reference fields: (samples, H, W, channels).
eta_re.shape

(500, 80, 80, 1)

In [35]:
# Ensemble member 0 of every sample, keeping the channel axis: (N, H, W, 1).
pred = eta_re_pred[:, 0]

In [36]:
# Flatten each field into a single point so the sample sets form point clouds of shape
# (N, H*W*channels). Sizes come from the arrays instead of the hard-coded 500 and 6400.
samples_pred = pred.reshape(pred.shape[0], -1)
samples_true = eta_re.reshape(eta_re.shape[0], -1)

In [37]:
# Entropic OT (Sinkhorn) between the empirical distributions of predicted and true fields,
# with the costs.Euclidean ground cost. No epsilon or solver options are passed, so the
# library defaults apply.
# NOTE(review): consider checking that the solver converged before reporting the cost.
solver = sinkhorn.Sinkhorn()
geom = pointcloud.PointCloud(samples_pred, samples_true, cost_fn=costs.Euclidean())
ot_prob_test = linear_problem.LinearProblem(geom)
ot = solver(ot_prob_test)

In [38]:
# Regularized OT cost between predicted and true field distributions.
print('testing', ot.reg_ot_cost)

testing 3.860173790392494


In [33]:
# Sanity check that the cost was computed in float64 (jax_enable_x64 is set at the top).
ot.reg_ot_cost.dtype

dtype('float64')