#### Load necessary packages

In [1]:
import os
import sys
sys.path.insert(0,'../../mocsy')

In [2]:
import numpy as np
import xarray as xr
import pickle 

import mocsy
from mocsy import mvars
from mocsy import mrhoinsitu
from mocsy import mrho

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

#### Check how to transform units from mol m-3 to mol kg-1

In [4]:
# Sanity check of the array-capable density routine on two example points.
# NOTE(review): arguments are presumably (salinity, temperature, pressure in dbar) — confirm against mocsy.
mrhoinsitu(np.array([35, 32]), np.array([18, 25]), np.array([5, 5]))

array([1025.2950932 , 1021.09931306])

In [5]:
# Same first point through the scalar routine, with the third argument divided by 10 by hand
# (dbar -> bar, see the note below); the result matches the first element of the array call above.
mrho(35, 18, 5 / 10)

1025.2950931958476

-> `mrhoinsitu` calls `mrho` after converting the pressure from dbar to bar; because it also accepts arrays, it is the preferable choice.

#### Read in CMIP6 data

In [6]:
def concat_hist_ssp(hist_array, ssp_array, hist_years=22, ssp_years=8):
    """
    Join the end of a historical monthly series with the start of a scenario series.

    The last ``hist_years`` years (12 time steps each) of ``hist_array`` are
    concatenated along axis 0 with the first ``ssp_years`` years of ``ssp_array``.
    With the defaults and historical files ending in 2014-12 this gives the
    30-year window 1993-2022 (360 monthly steps).

    hist_array : array with time as first axis (historical run)
    ssp_array  : array with time as first axis (scenario run), same trailing shape
    hist_years : number of years taken from the end of hist_array (default 22)
    ssp_years  : number of years taken from the start of ssp_array (default 8)

    Raises ValueError if a year count is negative.
    """
    n_hist = 12 * int(hist_years)
    n_ssp = 12 * int(ssp_years)
    if n_hist < 0 or n_ssp < 0:
        raise ValueError("hist_years and ssp_years must be non-negative")

    # A plain ``[-0:]`` would select the whole array, so treat zero explicitly.
    hist_tail = hist_array[-n_hist:] if n_hist > 0 else hist_array[:0]
    ssp_head = ssp_array[:n_ssp]
    return np.concatenate([hist_tail, ssp_head], axis=0)

def expand_co2(co2_array):
    """
    Blow a (time, latitude-band) CO2 field up to a (time, lat, lon) field.

    Every entry along axis 1 is repeated 15 times, then a new axis 2 is
    appended and filled with 360 copies of each value.
    """
    lat_expanded = np.repeat(co2_array, 15, axis=1)
    with_lon_axis = lat_expanded[:, :, np.newaxis]
    return np.repeat(with_lon_axis, 360, axis=2)

In [7]:
# Root directory of the local CMIP6 download. The CMIP6_DIR environment variable
# overrides the original machine-specific default; os.path.join(..., "") guarantees
# the trailing separator that the string-built paths in this notebook rely on.
base = os.path.join(os.environ.get("CMIP6_DIR", "/home/friedrich/Downloads/cmip6/"), "")

co2_file_stem = "mole-fraction-of-carbon-dioxide-in-air_input4MIPs_GHGConcentrations"

# Open both forcing files in context managers so the NetCDF handles are closed
# as soon as the values have been read into memory.
with xr.open_dataset(base + co2_file_stem
                     + "_CMIP_UoM-CMIP-1-2-0_gn-15x360deg_185001-201412.nc",
                     decode_times=False) as co2_hist_ds, \
        xr.open_dataset(base + co2_file_stem
                        + "_ScenarioMIP_UoM-MESSAGE-GLOBIOM-ssp245-1-2-1"
                          "_gn-15x360deg_201501-210012.nc") as co2_ssp_ds:
    # Cut both series to the common analysis window, then expand to the (time, lat, lon) grid.
    co2 = expand_co2(concat_hist_ssp(
        co2_hist_ds.mole_fraction_of_carbon_dioxide_in_air.values,
        co2_ssp_ds.mole_fraction_of_carbon_dioxide_in_air.values))

In [8]:
# CMIP6 variable names to load; the order must match domain_list.
var_list = [
    "talkos", "dissicos", "tos", "sos", "sios", "po4os",
    "mlotst", "zos", "chlos", "siconc", "uas", "vas", "pr", "clt",
]

# One (initially empty) list per variable; each model appends one array to it.
cmip6_data = {name: [] for name in var_list}

In [9]:
# CMIP6 table (realm) of each entry in var_list, in the same order:
# nine ocean variables, one sea-ice variable, four atmosphere variables.
domain_list = ["Omon"] * 9 + ["SImon"] + ["Amon"] * 4

In [10]:
# Add CMCC-ESM2 data
# NOTE: this cell appends to cmip6_data, so re-running it adds a duplicate entry per variable.
model = "CMCC-ESM2"
member_id = "r1i1p1f1"

for var, domain in zip(var_list, domain_list):
    # Layout: <base><var>/<domain>/<experiment>/<model>/<member>/gn/<file>
    file_hist = (f"{base}{var}/{domain}/historical/{model}/{member_id}/gn/"
                 f"{var}_{domain}_{model}_historical_{member_id}_1x1reg_185001-201412.nc")
    file_ssp = (f"{base}{var}/{domain}/ssp245/{model}/{member_id}/gn/"
                f"{var}_{domain}_{model}_ssp245_{member_id}_1x1reg_201501-210012.nc")
    # Context managers close both NetCDF handles once the values are in memory;
    # previously every iteration left two files open.
    with xr.open_dataset(file_hist, decode_times=False) as ds_hist, \
            xr.open_dataset(file_ssp) as ds_ssp:
        cmip6_data[var].append(concat_hist_ssp(ds_hist[var].values.squeeze(),
                                               ds_ssp[var].values.squeeze()))

In [11]:
# Add MPI-ESM1-2-LR data
# NOTE: this cell appends to cmip6_data, so re-running it adds a duplicate entry per variable.
model = "MPI-ESM1-2-LR"
member_id = "r1i1p1f1"

for var, domain in zip(var_list, domain_list):
    # Layout: <base><var>/<domain>/<experiment>/<model>/<member>/gn/<file>
    file_hist = (f"{base}{var}/{domain}/historical/{model}/{member_id}/gn/"
                 f"{var}_{domain}_{model}_historical_{member_id}_1x1reg_185001-201412.nc")
    file_ssp = (f"{base}{var}/{domain}/ssp245/{model}/{member_id}/gn/"
                f"{var}_{domain}_{model}_ssp245_{member_id}_1x1reg_201501-210012.nc")
    # Context managers close both NetCDF handles once the values are in memory;
    # previously every iteration left two files open.
    with xr.open_dataset(file_hist, decode_times=False) as ds_hist, \
            xr.open_dataset(file_ssp) as ds_ssp:
        cmip6_data[var].append(concat_hist_ssp(ds_hist[var].values.squeeze(),
                                               ds_ssp[var].values.squeeze()))

In [12]:
# Add UKESM1-0-LL data
# NOTE: this cell appends to cmip6_data, so re-running it adds a duplicate entry per variable.
model = "UKESM1-0-LL"
member_id = "r1i1p1f2"

for var, domain in zip(var_list, domain_list):
    # Layout: <base><var>/<domain>/<experiment>/<model>/<member>/gn/<file>
    file_hist = (f"{base}{var}/{domain}/historical/{model}/{member_id}/gn/"
                 f"{var}_{domain}_{model}_historical_{member_id}_1x1reg_185001-201412.nc")
    file_ssp = (f"{base}{var}/{domain}/ssp245/{model}/{member_id}/gn/"
                f"{var}_{domain}_{model}_ssp245_{member_id}_1x1reg_201501-210012.nc")
    # Context managers close both NetCDF handles once the values are in memory;
    # previously every iteration left two files open.
    with xr.open_dataset(file_hist, decode_times=False) as ds_hist, \
            xr.open_dataset(file_ssp) as ds_ssp:
        cmip6_data[var].append(concat_hist_ssp(ds_hist[var].values.squeeze(),
                                               ds_ssp[var].values.squeeze()))

In [13]:
# Quick check: number of loaded models and the array shape of each one for one variable.
talkos_fields = cmip6_data["talkos"]
print(len(talkos_fields))
for model_idx in range(3):
    print(talkos_fields[model_idx].shape)

3
(360, 180, 360)
(360, 180, 360)
(360, 180, 360)


#### Data strategy
since all data is on the same grid:
* load data, but only select range 1993-2022 (30y period), concat hist and ssp field
* draw random indices for time, lat, and lon, first for first model, then second, then third to fit memory
* check if any data for the time-lat-lon index combination is NaN; if so, draw a new random set of indices, until 2e6 valid combinations per model are collected.
* Load all variables: talkos, dissicos, tos, sos, sios, po4os, chlos, mlotst, zos, uas, vas, pr, clt, siconc

In [14]:
# Draw random (time, lat, lon) points from every loaded model and keep those that
# are free of NaNs and inside the accepted value ranges.
# The result depends on the exact order of the np.random calls below, so do not
# reorder them if the saved sample has to stay reproducible.
np.random.seed(0)

# One list of accepted values per variable, plus CO2, month and coordinates.
cmip6_random_samples = {}
for var in var_list:
    cmip6_random_samples[var] = []

cmip6_random_samples["co2"] = []
cmip6_random_samples["mon"] = []
cmip6_random_samples["lat"] = []
cmip6_random_samples["lon"] = []

# Coordinate values looked up by index: 180 latitudes (-89.5 ... 89.5),
# 360 longitudes (-179.5 ... 179.5) and the month index (0-11) of each of the 360 time steps.
lats = np.arange(-89.5, 90)
lons = np.arange(-179.5, 180)
mons = np.array([i % 12 for i in range(360)])

# Variables that are divided by the density rho after sampling
# (mol m-3 -> mol kg-1, cf. the unit-transformation section above).
transform_list = ["talkos", "dissicos", "sios", "po4os"]

obs_per_model = 2000_000  # accepted samples to collect per model
neglected_ocean_obs = 0   # non-NaN draws rejected only because of the range checks

for model_index in range(3):
    nobs = 0
    while nobs < obs_per_model:
        # np.random.randint excludes `high`, so these are valid indices into the
        # (360, 180, 360) arrays.
        time_ind = np.random.randint(low=0, high=360)
        lat_ind = np.random.randint(low=0, high=180)
        lon_ind = np.random.randint(low=0, high=360)
        
        # True if ANY of the variables is NaN at this point for this model.
        obs_is_nan = np.any(np.isnan(
            [cmip6_data[var][model_index][
                    time_ind, lat_ind, lon_ind] for var in var_list]))

        tos = cmip6_data["tos"][model_index][
                    time_ind, lat_ind, lon_ind]

        sos = cmip6_data["sos"][model_index][
                    time_ind, lat_ind, lon_ind]
        
        # NOTE(review): only the upper temperature bound is enforced; the normalisation
        # cell further down uses tem_range = [-2, 35] and the printed statistics show a
        # minimum of about -2.08 — confirm that values below -2 are meant to be kept.
        tos_out_of_range = tos > 35

        # For NaN salinity the chained comparison is False, so this flag is True;
        # that is why the counter below additionally requires `not obs_is_nan`.
        sos_out_of_range = not (10 <= sos <= 50)

        # Density at this point, evaluated for every draw (also for NaN points);
        # the third argument 5 is presumably the pressure in dbar — TODO confirm.
        rho = mrhoinsitu(sos, tos, 5)[0]

        # Accept only 1000e-6 <= dissicos/rho <= talkos/rho <= 3000e-6, i.e. both
        # converted values inside the range and dissicos not larger than talkos.
        talkos_dissicos_out_of_range = not (1000e-6
                                         <= cmip6_data["dissicos"][model_index][
                    time_ind, lat_ind, lon_ind] / rho 
                                         <= cmip6_data["talkos"][model_index][
                    time_ind, lat_ind, lon_ind] / rho 
                                         <= 3000e-6)
        
        any_out_of_range = (tos_out_of_range or sos_out_of_range
                            or talkos_dissicos_out_of_range)

        # Count points that have data everywhere but were rejected by a range check.
        if any_out_of_range and not obs_is_nan:
            neglected_ocean_obs += 1
        
        if not obs_is_nan and not any_out_of_range:
            for var in var_list:
                cmip6_random_samples[var].append(cmip6_data[var][model_index][
                    time_ind, lat_ind, lon_ind])
                
            # co2 is indexed with the same (time, lat, lon) indices as the model fields.
            cmip6_random_samples["co2"].append(co2[time_ind, lat_ind, lon_ind])
            
            cmip6_random_samples["mon"].append(mons[time_ind])
            cmip6_random_samples["lat"].append(lats[lat_ind])
            cmip6_random_samples["lon"].append(lons[lon_ind])

            # Convert the just-appended values of the concentration variables by dividing by rho.
            for var in transform_list:
                cmip6_random_samples[var][-1] /= rho
            
            nobs += 1

In [15]:
# Express the rejected draws relative to the number of accepted samples. The count is
# taken from the collected data instead of a hard-coded 6000000, so it stays correct
# if the number of models or obs_per_model changes.
n_accepted = len(cmip6_random_samples["lat"])
print("The number of neglected samples: {} ({:.2f}%)".format(
    neglected_ocean_obs, 100 * neglected_ocean_obs / n_accepted))

The number of neglected samples: 9900 (0.17%)


In [16]:
# Turn every accumulated sample list into a NumPy array, one key at a time.
sample_keys = [*var_list, "co2", "mon", "lat", "lon"]
for key in sample_keys:
    cmip6_random_samples[key] = np.array(cmip6_random_samples[key])

In [17]:
# Summary statistics (mean, std, min, max) of every sampled quantity.
print("Some statistics:")
stats_template = "Mean for {} samples: {:.6e}, std: {:.6e}, min: {:.6e}, max: {:.6e}"
for var in var_list + ["co2", "mon", "lat", "lon"]:
    samples = cmip6_random_samples[var]
    print(stats_template.format(
        var, np.mean(samples), np.std(samples), np.min(samples), np.max(samples)))
    print("-----")

Some statistics:
Mean for talkos samples: 2.244086e-03, std: 1.116865e-04, min: 1.002897e-03, max: 2.998658e-03
-----
Mean for dissicos samples: 2.018905e-03, std: 9.913828e-05, min: 1.000056e-03, max: 2.840916e-03
-----
Mean for tos samples: 1.385280e+01, std: 1.138069e+01, min: -2.079249e+00, max: 3.499847e+01
-----
Mean for sos samples: 3.392466e+01, std: 1.777865e+00, min: 1.000003e+01, max: 4.427740e+01
-----
Mean for sios samples: 1.619789e-05, std: 2.632815e-05, min: 1.858408e-11, max: 1.125364e-04
-----
Mean for po4os samples: 5.207525e-07, std: 5.477110e-07, min: -6.572454e-10, max: 3.090078e-06
-----
Mean for mlotst samples: 5.860219e+01, std: 8.698860e+01, min: 9.611196e-01, max: 4.780963e+03
-----
Mean for zos samples: -5.760545e-02, std: 7.792690e-01, min: -2.078595e+00, max: 1.822147e+00
-----
Mean for chlos samples: 3.806836e-07, std: 8.655427e-07, min: 5.255121e-30, max: 3.077705e-05
-----
Mean for siconc samples: 1.371948e+01, std: 3.225763e+01, min: 0.000000e+00, max:

In [18]:
# Persist the sample dictionary so later analyses can reuse it without re-drawing.
with open('../data/cmip6_random_samples.pkl', 'wb') as pkl_file:
    pickle.dump(cmip6_random_samples, pkl_file)

### Check how well the fCO2 NN model performs on the CMIP6 samples for AT, CT, T, S, SiT, PT

#### Reload NN model for fCO2

In [19]:
class MLP(nn.Module):
    """
    Fully connected network with three hidden layers and ELU activations.

    ELU is applied after every linear layer, including the output layer, so
    existing checkpoints produce the same predictions as before.
    """

    def __init__(self, input_size, hidden_size_1, hidden_size_2, hidden_size_3, output_size,
                 device=None):
        """
        input_size    : number of input features
        hidden_size_* : widths of the three hidden layers
        output_size   : number of outputs
        device        : optional device (string or torch.device); defaults to CUDA
                        when available and otherwise falls back to the CPU, so the
                        class can also be constructed on machines without a GPU
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.linear1 = nn.Linear(input_size, hidden_size_1, device=self.device)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2, device=self.device)
        self.linear3 = nn.Linear(hidden_size_2, hidden_size_3, device=self.device)
        self.linear4 = nn.Linear(hidden_size_3, output_size, device=self.device)

    def forward(self, x):
        """Move x to the model device and run it through the four layers."""
        x = x.to(self.device)
        x = F.elu(self.linear1(x))
        x = F.elu(self.linear2(x))
        x = F.elu(self.linear3(x))
        x = F.elu(self.linear4(x))
        return x

    def save(self, file_name='model.pth'):
        """Write the state dict to ../models/<file_name>, creating the folder if needed."""
        model_folder_path = '../models'
        # exist_ok avoids the check-then-create race of a separate os.path.exists test
        os.makedirs(model_folder_path, exist_ok=True)

        file_name = os.path.join(model_folder_path, file_name)
        torch.save(self.state_dict(), file_name)

# NOTE: `model` was a CMIP6 model-name string in the loading cells above; from here on
# it refers to the neural network.
model = MLP(6, 160, 160, 160, 1)
# map_location puts the checkpoint tensors on the device the model lives on, so the
# load does not depend on the device the checkpoint was saved from.
model_state_dict = torch.load("../models/fCO2_model_160x3_elu_10000epo.pth",
                              map_location=model.device)
model.load_state_dict(model_state_dict)

<All keys matched successfully>

#### Define mocsy fCO2

In [20]:
def extend(number, template):
    """
    Broadcast a constant to the shape of `template`.

    Returns an array of template's shape filled with `number` when template is
    a NumPy array, and `number` itself otherwise, so that the mocsy wrappers
    work on both floats and arrays.
    """
    if isinstance(template, np.ndarray):
        return number * np.ones(template.shape)
    return number

def calc_fCO2(alk, dic, tem, sal, sil, phos):
    """
    Evaluate mocsy's mvars and return element 2 of its output (used here as fCO2).

    input units
    alk in mol / kg
    dic in mol / kg
    tem in °C
    sal in PSU
    sil in mol / kg
    phos in mol / kg

    patm is fixed to 1, depth to 5 (with optp='db') and lat to NaN; these
    constants are broadcast to the shape of alk via extend().
    """
    carbonate_inputs = dict(alk=alk, dic=dic, temp=tem, sal=sal, sil=sil, phos=phos)
    fixed_conditions = dict(patm=extend(1, alk),
                            depth=extend(5, alk),
                            lat=extend(np.nan, alk))
    mocsy_options = dict(optcon='mol/kg', optt='Tpot', optp='db',
                         optk1k2='l', optb='u74', optkf='pf', opts='Sprc')
    mocsy_output = mvars(**carbonate_inputs, **fixed_conditions, **mocsy_options)
    return mocsy_output[2]

#### Assess error

In [21]:
# Reference fCO2 from mocsy for all samples; argument order: alk, dic, tem, sal, sil, phos.
mocsy_fco2 = calc_fCO2(
    *(cmip6_random_samples[name]
      for name in ("talkos", "dissicos", "tos", "sos", "sios", "po4os")))

In [22]:
# Value ranges of the six NN input features.
alk_range = [1000e-6, 3000e-6]
dic_range = [1000e-6, 3000e-6]
tem_range = [-2, 35]
sal_range = [10, 50]
sil_range = [0, 134e-6]
phos_range = [0, 4e-6]

# Feature name -> range, in the column order expected by the network.
feature_ranges = {"talkos": alk_range,
                  "dissicos": dic_range,
                  "tos": tem_range,
                  "sos": sal_range,
                  "sios": sil_range,
                  "po4os": phos_range}

# Mean and standard deviation of a uniform distribution on [lo, hi]:
# (lo + hi) / 2 and (hi - lo) / sqrt(12).
sample_means = {name: (lo + hi) / 2 for name, (lo, hi) in feature_ranges.items()}
sample_stds = {name: (hi - lo) / np.sqrt(12) for name, (lo, hi) in feature_ranges.items()}

# dissicos uses different moments: mean at a quarter of the range and variance 7/144
# of the squared range. These are the moments of a value drawn uniformly between the
# lower bound and a uniformly distributed upper limit — presumably the training
# sampling (DIC <= alkalinity); confirm against the training notebook.
sample_means["dissicos"] = dic_range[0] + (dic_range[1] - dic_range[0]) / 4
sample_stds["dissicos"] = (dic_range[1] - dic_range[0]) * np.sqrt(7 / 144)

# Standardise every feature and stack them as columns -> shape (n_samples, 6).
normalized_columns = [
    (cmip6_random_samples[key][:, np.newaxis] - sample_means[key]) / sample_stds[key]
    for key in ["talkos", "dissicos", "tos", "sos", "sios", "po4os"]
]
cmip6_samples_normalized = np.concatenate(normalized_columns, axis=1)

In [23]:
# Evaluate the fCO2 network on all normalised samples (inference only, no gradients).
model.eval()
nn_inputs = torch.from_numpy(cmip6_samples_normalized.astype("float32"))
with torch.no_grad():
    nn_output = model(nn_inputs)
    nn_fco2 = nn_output.detach().cpu().numpy().squeeze()

In [24]:
def MSE(x, y): # define MSE for offline calculations on numpy arrays
    """
    Mean squared error between x and y.

    Accepts NumPy arrays or array-likes (e.g. lists). The sum of squared
    differences is divided by the total number of elements, so the result is
    a true mean for inputs of any dimensionality; for 1-D arrays this is
    identical to the previous division by len(x).
    """
    squared_error = (np.asarray(x) - np.asarray(y)) ** 2
    return np.sum(squared_error) / squared_error.size

In [25]:
# Error of the NN fCO2 against the mocsy reference.
abs_dev_fco2 = np.abs(nn_fco2 - mocsy_fco2)
mse_fco2 = MSE(mocsy_fco2, nn_fco2)

print("MSE: ", mse_fco2)
print("RMSE: ", np.sqrt(mse_fco2))
print("Maximum absolute deviation: ", np.max(abs_dev_fco2))
for q in (99.9, 99.99, 99.999):
    # Number of samples above the percentile, derived from the actual sample size;
    # the previously hard-coded 1000/100/10 did not match the 6,000,000 samples used here.
    n_larger = round(abs_dev_fco2.size * (100 - q) / 100)
    print("{}th percentile of absolute deviation ({} val's larger): ".format(q, n_larger),
          np.percentile(abs_dev_fco2, q=q))

MSE:  0.0008695582316074656
RMSE:  0.029488272781013566
Maximum absolute deviation:  0.2435986346279151
99.9th percentile of absolute deviation (1000 val's larger):  0.1089044869370121
99.99th percentile of absolute deviation (100 val's larger):  0.11951280805522761
99.999th percentile of absolute deviation (10 val's larger):  0.12620468803811347


### Check how well the pH NN model performs on the CMIP6 samples for AT, CT, T, S, SiT, PT

#### Reload NN model for pH

In [27]:
# Reuse the existing MLP instance and overwrite its weights with the pH checkpoint.
# map_location puts the checkpoint tensors on the device the model lives on, so the
# load does not depend on the device the checkpoint was saved from.
model_state_dict = torch.load("../models/pH_model_160x3_elu_10000epo.pth",
                              map_location=model.device)
model.load_state_dict(model_state_dict)

<All keys matched successfully>

#### Define mocsy pH

In [28]:
def calc_pH(alk, dic, tem, sal, sil, phos):
    """
    Evaluate mocsy's mvars and return element 0 of its output (used here as pH).

    input units
    alk in mol / kg
    dic in mol / kg
    tem in °C
    sal in PSU
    sil in mol / kg
    phos in mol / kg

    patm is fixed to 1, depth to 5 (with optp='db') and lat to NaN; these
    constants are broadcast to the shape of alk via extend().
    """
    carbonate_inputs = dict(alk=alk, dic=dic, temp=tem, sal=sal, sil=sil, phos=phos)
    fixed_conditions = dict(patm=extend(1, alk),
                            depth=extend(5, alk),
                            lat=extend(np.nan, alk))
    mocsy_options = dict(optcon='mol/kg', optt='Tpot', optp='db',
                         optk1k2='l', optb='u74', optkf='pf', opts='Sprc')
    mocsy_output = mvars(**carbonate_inputs, **fixed_conditions, **mocsy_options)
    return mocsy_output[0]

#### Assess error

In [29]:
# Reference pH from mocsy for all samples; argument order: alk, dic, tem, sal, sil, phos.
mocsy_ph = calc_pH(
    *(cmip6_random_samples[name]
      for name in ("talkos", "dissicos", "tos", "sos", "sios", "po4os")))

In [30]:
# Evaluate the pH network on all normalised samples (inference only, no gradients).
model.eval()
nn_inputs = torch.from_numpy(cmip6_samples_normalized.astype("float32"))
with torch.no_grad():
    nn_output = model(nn_inputs)
    nn_ph = nn_output.detach().cpu().numpy().squeeze()

In [31]:
# Error of the NN pH against the mocsy reference.
abs_dev_ph = np.abs(nn_ph - mocsy_ph)
mse_ph = MSE(mocsy_ph, nn_ph)

print("MSE: ", mse_ph)
print("RMSE: ", np.sqrt(mse_ph))
print("Maximum absolute deviation: ", np.max(abs_dev_ph))
for q in (99.9, 99.99, 99.999):
    # Number of samples above the percentile, derived from the actual sample size;
    # the previously hard-coded 1000/100/10 did not match the 6,000,000 samples used here.
    n_larger = round(abs_dev_ph.size * (100 - q) / 100)
    print("{}th percentile of absolute deviation ({} val's larger): ".format(q, n_larger),
          np.percentile(abs_dev_ph, q=q))

MSE:  3.9446634414756045e-09
RMSE:  6.280655572052653e-05
Maximum absolute deviation:  0.0005569656017598845
99.9th percentile of absolute deviation (1000 val's larger):  0.0002824040739381581
99.99th percentile of absolute deviation (100 val's larger):  0.00038139157592631534
99.999th percentile of absolute deviation (10 val's larger):  0.0005012017243931727
