In [None]:
#MLD adjustment applied on profiles dataset coming from product_generation_bootstrap.py
#date : February 2022
#author : Etienne Pauthenet (etienne.pauthenet@gmail.com)

import glob
import time
from pathlib import Path

import gsw
import numpy as np
import pandas as pd
import xarray as xr
from numba import float64, guvectorize

In [None]:
def add_mld(ds, ref_pres=10, delta_sigma=0.03):
    """Add a density-threshold mixed layer depth variable ``mld`` to ``ds``.

    The MLD is the smallest PRES_INTERPOLATED value at which SIG_adj exceeds
    its value at ``ref_pres`` by more than ``delta_sigma``. Columns where the
    threshold is never crossed get NaN.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``SIG_adj`` with a ``PRES_INTERPOLATED`` dimension and
        ``time``/``lat``/``lon`` dimensions. ``ref_pres`` must be one of the
        PRES_INTERPOLATED coordinate values (exact ``.sel`` lookup).
    ref_pres : float, default 10
        Reference level whose density is the baseline.
    delta_sigma : float, default 0.03
        Density increase that defines the base of the mixed layer
        (same units as SIG_adj).

    Returns
    -------
    xarray.Dataset
        A new dataset with ``mld`` on dims (time, lat, lon).
    """
    sig_diff = ds.SIG_adj - ds.SIG_adj.sel(PRES_INTERPOLATED=ref_pres) - delta_sigma
    # Keep the pressure value only where the threshold is exceeded; the
    # shallowest such level (the minimum) is the mixed layer depth.
    mld = sig_diff['PRES_INTERPOLATED'].where(sig_diff > 0).min(dim='PRES_INTERPOLATED')
    # The raw array below is labelled (time, lat, lon); transpose explicitly so
    # the labels stay correct however SIG_adj's dims happen to be ordered.
    mld = mld.transpose('time', 'lat', 'lon')
    return ds.assign(variables={"mld": (('time', 'lat', 'lon'), mld.data)})

In [None]:
osnet_rep = ''  # directory holding the produit_<YYYYMM>.nc files; set before running


def get_data(path, yy, mm):
    """Open one monthly product and add a density field and surface-reconstruction masks.

    Reads ``{path}/produit_{yy}{mm}.nc`` and adds:

    * ``SIG_predicted``: ``gsw.sigma0`` computed from PSAL_predicted (used as
      practical salinity) and TEMP_predicted (used as in-situ temperature)
      at pressure PRES_INTERPOLATED.
    * ``MLD_mask2``: MLD_mask where MLD_mask < H, otherwise 1.
    * ``MLD_mask3``: b - MLD_mask where H < MLD_mask < b2, otherwise MLD_mask2.

    Parameters
    ----------
    path : str
        Directory containing the monthly product files.
    yy : int or str
        Year, inserted verbatim in the file name.
    mm : str
        Month, inserted verbatim in the file name (e.g. "01").

    Returns
    -------
    xarray.Dataset
    """
    ds = xr.open_dataset(f"{path}/produit_{yy}{mm}.nc")
    SA_predicted = gsw.SA_from_SP(ds['PSAL_predicted'], ds['PRES_INTERPOLATED'], ds['lon'], ds['lat'])
    CT_predicted = gsw.CT_from_t(SA_predicted, ds['TEMP_predicted'], ds['PRES_INTERPOLATED'])
    SIG_predicted = gsw.sigma0(SA_predicted, CT_predicted)
    # NOTE(review): assumes the gsw result is ordered (PRES_INTERPOLATED, time, lat, lon) — confirm.
    ds = ds.assign(variables={"SIG_predicted": (('PRES_INTERPOLATED', 'time', 'lat', 'lon'), SIG_predicted.data)})

    # New mask for better surface reconstruction
    b = 2
    b2 = 1
    H = 0.5664  # threshold tuned for OSnet Gulf Stream
    mask_dims = ds['MLD_mask'].dims
    mask = ds['MLD_mask'].data
    mask2 = np.where(mask < H, mask, 1)
    ds = ds.assign(variables={"MLD_mask2": (mask_dims, mask2)})
    # Between H and b2 the mask is reflected around 1 (b - mask).
    mask3 = np.where((mask > H) & (mask < b2), b - mask, mask2)
    ds = ds.assign(variables={"MLD_mask3": (mask_dims, mask3)})

    return ds


In [None]:
@guvectorize(
    "(float64[:], float64[:], float64[:], float64[:], float64[:], float64[:])",
    "(n), (n), (n), (n) -> (n), (n)"
)
def MLD_adj_1d(temp_in, psal_in, depth, mask, temp, psal):
    """Rebuild TEMP/PSAL profiles upward after scaling each level-to-level increment by ``mask``.

    The last level (index n-1) is copied unchanged. Walking toward index 0,
    each value becomes the adjusted value below it plus the original
    increment ``(x[k] - x[k+1])`` scaled by ``mask[k]``. ``depth`` only
    supplies the profile length. ``temp`` and ``psal`` are the outputs.
    """
    nlev = depth.shape[0]
    temp[:] = temp_in
    psal[:] = psal_in
    k = nlev - 2
    while k >= 0:
        w = mask[k]
        # Same operation order as the reference formula, so results are bit-identical.
        temp[k] = (temp_in[k] * w - temp_in[k + 1] * w) + temp[k + 1]
        psal[k] = (psal_in[k] * w - psal_in[k + 1] * w) + psal[k + 1]
        k -= 1

def MLD_adj(ds, mask):
    """Apply ``MLD_adj_1d`` along PRES_INTERPOLATED and add the adjusted fields to ``ds``.

    TEMP_predicted and PSAL_predicted are adjusted profile by profile using
    ``mask``. Density is then recomputed from the adjusted fields with gsw
    (SA_from_SP -> CT_from_t -> sigma0).

    Returns
    -------
    xarray.Dataset
        ``ds`` plus TEMP_adj, PSAL_adj and SIG_adj on dims
        (time, lat, lon, PRES_INTERPOLATED).
    """
    core = ['PRES_INTERPOLATED']
    pres = ds['PRES_INTERPOLATED']
    # apply_ufunc moves the core dim (PRES_INTERPOLATED) to the end of each output.
    temp_out, psal_out = xr.apply_ufunc(
        MLD_adj_1d,
        ds['TEMP_predicted'], ds['PSAL_predicted'], pres, mask,
        input_core_dims=[core] * 4,
        output_core_dims=[core] * 2,
        output_dtypes=[np.float64, np.float64],
    )

    # Potential density of the adjusted profiles
    sa_out = gsw.SA_from_SP(psal_out, pres, ds['lon'], ds['lat'])
    ct_out = gsw.CT_from_t(sa_out, temp_out, pres)
    sig_out = gsw.sigma0(sa_out, ct_out)

    out_dims = ('time', 'lat', 'lon', 'PRES_INTERPOLATED')
    return ds.assign(variables={
        "TEMP_adj": (out_dims, temp_out.data),
        "PSAL_adj": (out_dims, psal_out.data),
        "SIG_adj": (out_dims, sig_out.data),
    })

In [None]:
# Adjust every monthly product from year_start to year_end (inclusive) and write
# produit_<YYYYMM>_adj.nc into osnet_rep.
year_start = 1993
year_end = 2020
BYTES_PER_GIB = 1024 ** 3  # 1073741824
for yy in range(year_start, year_end + 1):
    time_year = time.time()
    for month in range(1, 13):
        mm = f"{month:02d}"  # "01" ... "12", matching the file names
        time_start = time.time()
        print(f'Starting adjustment for month: {mm}-{yy}')
        ds = get_data(osnet_rep, yy, mm)
        # MLD_adj is the adjustment defined above; MLD_mask3 is built by get_data.
        ds = MLD_adj(ds, mask=ds['MLD_mask3'])
        out_file = f"{osnet_rep}/produit_{yy}{mm}_adj.nc"
        print(f"size output file: {np.around(ds.nbytes / BYTES_PER_GIB,2)} Go, saved in {out_file}")
        ds.to_netcdf(out_file)
        print(f"adjustment of month {yy}-{mm} finished in {np.around(time.time() - time_start,2)} secondes")
    print(f'Year: {yy} done in {time.time() - time_year}')
print('Computation finished')

In [None]:
%%time
#Clean and format the final product
file = glob.glob(f"{osnet_rep}/*_adj.nc")
file =np.sort(file)

mon = ["01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12"]
i = 1
for i in np.arange(len(file)):
    ds = xr.open_mfdataset(file[i])
    mm = ds.time.dt.month[0].data
    yy = ds.time.dt.year[0].data
    ds = xr.open_mfdataset(file[i])
    ds = add_mld(ds)

    #drop all other variables
    ds = ds.rename({"lat":"latitude","lon":"longitude","PRES_INTERPOLATED":"depth","TEMP_adj":"temp","PSAL_adj":"psal",
                      "TEMP_predicted_std":"temp_std","PSAL_predicted_std":"psal_std"})
    ds = ds[['mld', 'temp', 'psal','temp_std','psal_std']]
    ds = ds.drop('mask')
    ds['temp'] = ds.temp.astype(np.float32)
    ds['psal'] = ds.psal.astype(np.float32)
    ds['temp_std'] = ds.temp_std.astype(np.float32)
    ds['psal_std'] = ds.psal_std.astype(np.float32)
    ds['mld'] = ds.mld.astype(np.float32)

    ds.attrs['title'] = 'daily mean fields from Ocean Stratification network (OSnet)'
    ds.attrs['easting'] = 'longitude'
    ds.attrs['northing'] = 'latitude'

    ds.time.attrs['standard_name'] = 'time'

    ds.depth.attrs['standard_name'] = 'depth'
    ds.depth.attrs['units'] = 'meters'

    ds.latitude.attrs['standard_name'] = 'latitude'
    ds.latitude.attrs['units'] = 'degrees_north'

    ds.longitude.attrs['standard_name'] = 'longitude'
    ds.longitude.attrs['units'] = 'degrees_east'

    ds.mld.attrs['long_name'] = 'Density ocean mixed layer thickness'
    ds.mld.attrs['standard_name'] = 'ocean_mixed_layer_thickness_defined_by_sigma_03'
    ds.mld.attrs['units'] = 'meters'

    ds.temp.attrs['long_name'] = 'In situ temperature'
    ds.temp.attrs['standard_name'] = 'sea_water_in_situ_temperature'
    ds.temp.attrs['units'] = 'degrees_Celsius'

    ds.psal.attrs['long_name'] = 'Practical Salinity'
    ds.psal.attrs['standard_name'] = 'sea_water_salinity'
    ds.psal.attrs['units'] = 'Practical_Salinity_Unit'

    ds.temp_std.attrs['long_name'] = 'In situ temperature confidence interval'
    ds.temp_std.attrs['standard_name'] = 'sea_water_in_situ_temperature_confidence_interval'
    ds.temp_std.attrs['units'] = 'degrees_Celsius'

    ds.psal_std.attrs['long_name'] = 'Practical Salinity confidence interval'
    ds.psal_std.attrs['standard_name'] = 'sea_water_salinity_confidence_interval'
    ds.psal_std.attrs['units'] = 'Practical_Salinity_Unit'

    mm = ds.time.dt.month[0].data - 1
    yy = ds.time.dt.year[0].data
    #    print(f"size output file for {mm}-{yy}: {np.around(ds_temp_psal.nbytes / 1073741824,2)} Go")
    path_out = '/home/datawork-lops-bluecloud/osnet/product_out/OSnet/clean'
    ds.to_netcdf(f'{osnet_rep}/OSnet_GS_{yy}{mon[mm]}.nc')