In [None]:
import xarray as xr
import glob
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import generic_filter

In [None]:
# Float identifier; it is used both as the directory name and inside the file names.
wmoid = 6901474
# glob.glob returns matches in arbitrary (filesystem) order, so sort the paths to
# process the profiles in a reproducible, name-sorted order on every run.
list_file = sorted(glob.glob(f"{wmoid}/S*{wmoid}_[0-9][0-9][0-9].nc"))
list_file

### Read all profiles

In [None]:
# Per-profile containers. Entry i of every list below belongs to list_file[i],
# so all of them can be indexed together with juld and pres.
juld = [] # date and time
pres = [] # pressure level (depth)
# NOTE: `vars` shadows the Python builtin of the same name; the name is kept
# because the later cells of this notebook refer to it.
vars = ['temp','psal','down','nitr','chla','bbp7','doxy','ph_i'] # shortened names
vars_original = ['TEMP','PSAL', # original names
                 'DOWNWELLING_PAR_ADJUSTED','NITRATE_ADJUSTED',
                 'CHLA_ADJUSTED','BBP700_ADJUSTED',
                 'DOXY_ADJUSTED','PH_IN_SITU_TOTAL_ADJUSTED']
raw = {var: [] for var in vars} # raw data
qc = {var: [] for var in vars} # qc flags
qc_valid = {var: [] for var in vars} # good data
pres_qc_valid = {var: [] for var in vars} # corresponding pres for good data
smooth = {var: [] for var in vars} # smoothed data

#concat_raw = {f"{var}": [] for var in vars} # valid qc masks
#concat_qc_valid = {f"{var}": [] for var in vars} # valid qc masks

# QC flags that are kept as good data
good_flags = ['1','2','8']
good_flags_chla = ['1','2','5','8'] # include the qc flag of 5 (NPQ)

for path in list_file: # loop over profiles
    # The context manager closes the netCDF file as soon as the arrays have been
    # copied out with .values, so no file handle is left open per profile.
    with xr.open_dataset(path) as ds:
        juld.append(ds['JULD'][0].values) # date and time
        pres.append(ds['PRES'][0,:].values) # pressure
        for var, var_original in zip(vars, vars_original):
            if var_original in ds.data_vars:
                values = ds[var_original][0,:].values # raw profile
                flags = ds[var_original+'_QC'][0,:].values.astype(str) # qc flags
            else:
                # Variable missing from this file: store an all-NaN placeholder
                # (with blank flags) so that entry i of every per-variable list
                # still corresponds to juld[i] and pres[i].
                values = np.full(pres[-1].shape, np.nan)
                flags = np.full(pres[-1].shape, '', dtype=str)
            accepted = good_flags_chla if var_original == 'CHLA_ADJUSTED' else good_flags
            good = np.isin(flags, accepted) # mask of good data, computed once
            raw[var].append(values) # store raw profiles
            qc[var].append(flags) # store qc flags
            qc_valid[var].append(values[good]) # store good data only
            pres_qc_valid[var].append(pres[-1][good]) # pressure of the good data
        
# concatenate all profiles to calculate statistics for plotting purpose
#for j in range(len(vars)): # loop over variables
#    if vars_original[j] in ds.data_vars:
#        concat_raw[vars[j]] = np.concatenate(raw[vars[j]])
#        concat_qc_valid[vars[j]] = np.concatenate(qc_valid[vars[j]])

In [None]:
# The per-variable profile lists live in the `raw` dict; bare names such as
# `chla` are never defined in this notebook and fail on a fresh kernel.
bgc_vars = ['chla','nitr','bbp7','down','doxy','ph_i']
lists = [raw[name] for name in bgc_vars]
# A variable counts as available when at least one of its profiles holds a finite value.
num_bgc = sum(1 for profiles in lists
              if any(np.any(np.isfinite(profile)) for profile in profiles)) #number of bgc variables
num_bgc

### Plot the raw data to understand the data coverage

In [None]:
def plot_raw(vari_in,pres_in,juld_in,vari_name_in):
    """Scatter-plot a list of profiles on the current axes, coloured by value.

    Parameters
    ----------
    vari_in : list of arrays, one per profile, with the values to colour by.
    pres_in : list of arrays, one per profile, used as the y coordinate;
        pres_in[i] must have the same length as vari_in[i].
    juld_in : sequence with one date per profile, used as the x coordinate
        and to set the x-axis limits.
    vari_name_in : str, label written on the colour bar.

    Profiles without any finite value are skipped. The y axis is inverted, and
    the 25th/50th/75th percentiles of all values are marked on the colour bar.
    The caller is expected to make sure at least one finite value exists.
    """
    # Pool all profiles once; the colour limits and the percentiles are the
    # same for every profile, so they are computed outside the loop.
    all_values = np.concatenate(vari_in)
    value_min = np.nanmin(all_values)
    value_max = np.nanmax(all_values)
    for i in range(len(vari_in)):
        if np.any(np.isfinite(vari_in[i])): # ignore the profile with all NaNs
            plt.scatter(np.full(len(pres_in[i]),juld_in[i]),pres_in[i],c=vari_in[i],
                        vmin=value_min,
                        vmax=value_max
                       )
    plt.gca().invert_yaxis()
    cbar = plt.colorbar()
    cbar.set_label(vari_name_in)
    vari_ptile = [np.nanpercentile(all_values,25),
                  np.nanpercentile(all_values,50),
                  np.nanpercentile(all_values,75)
                 ]
    cbar.ax.hlines(vari_ptile,xmin=0,xmax=1,color='r') # draw percentiles
    plt.gcf().autofmt_xdate() # automatically format date
    # use the dates passed in (not a notebook global) to align the date range
    plt.xlim(np.min(juld_in),np.max(juld_in))

# One panel per variable on a 3x3 grid, showing every raw profile.
plt.figure(figsize=(12,10))
for panel, (name, name_original) in enumerate(zip(vars, vars_original), start=1):
    plt.subplot(3,3,panel)
    profiles = raw[name]
    # draw the panel only if finite values exist; otherwise show a placeholder text
    if len(profiles) and np.any(np.isfinite(np.concatenate(profiles))):
        plot_raw(profiles,pres,juld,name_original)
    else:
        plt.text(0.1,0.5,'NO DATA for \n'+name_original)
        plt.axis('off')
plt.tight_layout()

### Plot good data

In [None]:
# Same 3x3 layout as the raw overview, restricted to the data that passed the QC selection.
plt.figure(figsize=(12,10))
for panel, (name, name_original) in enumerate(zip(vars, vars_original), start=1):
    plt.subplot(3,3,panel)
    good_profiles = qc_valid[name]
    good_pressures = pres_qc_valid[name]
    # draw the panel only if finite values exist; otherwise show a placeholder text
    if len(good_profiles) and np.any(np.isfinite(np.concatenate(good_profiles))):
        plot_raw(good_profiles,good_pressures,juld,name_original)
    else:
        plt.text(0.1,0.5,'NO DATA for \n'+name_original)
        plt.axis('off')
plt.tight_layout()

### Smoothing

In [None]:
# Median filter kernel that ignores NaNs
def nanmedian_filter(values):
    """Return the median of the non-NaN entries of `values`, or NaN if there are none."""
    is_valid = np.logical_not(np.isnan(values))  # mask of the non-NaN entries
    if not is_valid.any():
        return np.nan  # no valid value in the window
    return np.median(values[is_valid])

for j in range(len(vars)): # loop over variables
    # Start from an empty list so that re-running this cell replaces the smoothed
    # profiles instead of appending a second copy (which would break the
    # one-to-one indexing with pres_qc_valid).
    smooth[vars[j]] = []
    for i in range(len(qc_valid[vars[j]])): # loop over profiles
        profile = qc_valid[vars[j]][i]
        profile_pres = pres_qc_valid[vars[j]][i]
        if np.size(profile) == 0:
            # no good data in this profile: nothing to filter, keep an empty entry
            smooth[vars[j]].append(np.array(profile, copy=True))
            continue
        # vertical resolution = median spacing between consecutive pressure levels;
        # it is undefined (NaN) when the profile has fewer than two levels
        if np.size(profile_pres) > 1:
            pres_res = np.median(np.diff(profile_pres))
        else:
            pres_res = np.nan
        # choose the window size (number of points) from the resolution:
        # coarse sampling -> short window, fine sampling -> long window;
        # an undefined resolution falls through to the intermediate window
        if pres_res >= 3:
            nsmooth = 5
        elif pres_res <= 1:
            nsmooth = 11
        else:
            nsmooth = 7
        # apply the NaN-aware running median
        smooth[vars[j]].append(generic_filter(
            profile,nanmedian_filter,size=nsmooth,mode='nearest')
                              )


In [None]:
# Same 3x3 layout again, this time showing the median-filtered profiles.
plt.figure(figsize=(12,10))
for panel, (name, name_original) in enumerate(zip(vars, vars_original), start=1):
    plt.subplot(3,3,panel)
    smoothed_profiles = smooth[name]
    # draw the panel only if finite values exist; otherwise show a placeholder text
    if len(smoothed_profiles) and np.any(np.isfinite(np.concatenate(smoothed_profiles))):
        plot_raw(smoothed_profiles,pres_qc_valid[name],juld,name_original)
    else:
        plt.text(0.1,0.5,'NO DATA for \n'+name_original)
        plt.axis('off')
plt.tight_layout()