### Load the libraries (if an import error occurs, install the missing package, e.g. `pip install xarray` or `conda install xarray` for `xarray`)

In [None]:
import xarray as xr
import glob
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import generic_filter
import pvlib
import gsw
import scipy
import datetime

### Specify the float number you want to post-process
- Note that the file list is not in chronological order. It is deliberately left unsorted so that no later step assumes the file order follows the time sequence (which is not always the case).

In [None]:
# WMO identifier of the float to post-process
wmoid = 6901474
# Version of the notebook, written to the output file (generally, no need to change)
vernum = 0

# Per-profile files live in a folder named after the float and end in a
# three-digit number; glob returns them in arbitrary order.
list_file = glob.glob(
    f"{wmoid}/S*{wmoid}_[0-9][0-9][0-9].nc"
)
print('Total number of profies:',len(list_file))
list_file

### Read all profiles

In [None]:
# Containers for all profiles. For every variable that occurs in at least one
# file, each list below receives exactly one entry per file, so index i refers
# to the same profile in juld/lon/lat/pres and in the per-variable lists.
juld = [] # date and time
lon = [] # longitude
lat = [] # latitude
pres = [] # pressure level (depth)
vars = ['temp','psal','down','nitr','chla','bbp7','doxy','ph_i'] # shortened names
vars_original = ['TEMP','PSAL', # original names
                 'DOWNWELLING_PAR_ADJUSTED','NITRATE_ADJUSTED',
                 'CHLA_ADJUSTED','BBP700_ADJUSTED',
                 'DOXY_ADJUSTED','PH_IN_SITU_TOTAL_ADJUSTED']
raw = {f"{var}": [] for var in vars} # raw data
qc = {f"{var}": [] for var in vars} # qc flags
qc_valid = {f"{var}": [] for var in vars} # good data
pres_qc_valid = {f"{var}": [] for var in vars} # corresponding pres for good data
npq5_qc_valid = [] # corresponding qc=5 for good chla data (used for identifying whether npq correction is needed)

good_flags = ['1','2','8'] # QC flags kept as good data
good_flags_chla = ['1','2','5','8'] # CHLA additionally keeps flag 5 (NPQ)
vars_found = set() # shortened names of the variables seen in at least one file

for i in range(len(list_file)): # loop over profiles
    # The context manager closes the netCDF file once its arrays have been
    # copied into memory via `.values`; `ds` itself stays bound afterwards.
    with xr.open_dataset(list_file[i]) as ds:
        juld.append(ds['JULD'][0].values)
        lon.append(ds['LONGITUDE'][0].values)
        lat.append(ds['LATITUDE'][0].values)
        pres.append(ds['PRES'][0,:].values)
        for j in range(len(vars)): # loop over variables
            is_chla = vars_original[j] == 'CHLA_ADJUSTED'
            if vars_original[j] in ds.data_vars:
                vars_found.add(vars[j])
                raw[vars[j]].append(ds[vars_original[j]][0,:].values) # store raw profiles
                qc[vars[j]].append(ds[vars_original[j]+'_QC'][0,:].values.astype(str)) # store qc flags
                mask = np.isin(qc[vars[j]][-1], good_flags_chla if is_chla else good_flags)
                qc_valid[vars[j]].append(raw[vars[j]][-1][mask]) # good samples only
                pres_qc_valid[vars[j]].append(pres[-1][mask]) # pressures of the good samples
                if is_chla:
                    # True if the profile contains QC of 5 (NPQ corrected)
                    npq5_qc_valid.append(np.any(np.isin(qc[vars[j]][-1],['5'])))
            else:
                # The variable is absent from this file: store placeholders so
                # that later profiles keep the same index as pres/juld/lat/lon.
                raw[vars[j]].append(np.full(pres[-1].shape, np.nan))
                qc[vars[j]].append(np.full(pres[-1].shape, '', dtype=str))
                qc_valid[vars[j]].append(np.array([], dtype=float))
                pres_qc_valid[vars[j]].append(pres[-1][:0])
                if is_chla:
                    npq5_qc_valid.append(False)

# A variable that never occurs in any file keeps empty lists, exactly as if no
# placeholder had been stored, so `len(...)`-based "no data" checks still hold.
for name in vars:
    if name not in vars_found:
        raw[name] = []
        qc[name] = []
        qc_valid[name] = []
        pres_qc_valid[name] = []
if 'chla' not in vars_found:
    npq5_qc_valid = []

### Plot the raw data to understand the data coverage

In [None]:
def plot_raw(vari_in,pres_in,juld_in,vari_name_in):
    """Scatter-plot a list of profiles against date and pressure on the current axes.

    Parameters
    ----------
    vari_in : list of 1-D arrays
        Values of one variable, one array per profile.
    pres_in : list of 1-D arrays
        Pressures matching ``vari_in`` profile by profile.
    juld_in : sequence
        One date per profile; also used for the x-axis limits.
    vari_name_in : str
        Label for the colour bar.

    Nothing is drawn (and no colour bar is created) when ``vari_in`` holds no
    finite value.
    """
    # Concatenate once: the colour limits and percentiles are the same for
    # every profile, so there is no need to rebuild them inside the loop.
    all_values = np.concatenate(vari_in) if len(vari_in) else np.array([])
    if not np.any(np.isfinite(all_values)):
        return # a colour bar cannot be created without any plotted data
    vmin = np.nanmin(all_values)
    vmax = np.nanmax(all_values)

    for i in range(len(vari_in)): # loop over profiles
        if np.any(np.isfinite(vari_in[i])): # ignore the profile with all NaNs
            plt.scatter(np.full(len(pres_in[i]),juld_in[i]),pres_in[i],c=vari_in[i],s=0.1,
                        vmin=vmin,
                        vmax=vmax
                       )
    plt.gca().invert_yaxis()
    cbar = plt.colorbar()
    cbar.set_label(vari_name_in)
    vari_ptile = np.nanpercentile(all_values,[25,50,75])
    cbar.ax.hlines(vari_ptile,xmin=0,xmax=1,color='r') # draw percentiles
    plt.gcf().autofmt_xdate() # automatically format date
    # use the dates passed in rather than a global, to align the date range across all variables
    plt.xlim(np.min(juld_in),np.max(juld_in))

# One panel per variable on a 3x3 grid, showing the raw profiles
plt.figure(figsize=(12,10))
for panel, (short_name, long_name) in enumerate(zip(vars, vars_original), start=1):
    plt.subplot(3,3,panel)
    profiles = raw[short_name]
    if len(profiles) and np.any(np.isfinite(np.concatenate(profiles))): # at least one finite value exists
        plot_raw(profiles,pres,juld,long_name)
    else:
        # nothing to draw: replace the axes with a short message
        plt.text(0.1,0.5,'NO DATA for \n'+long_name)
        plt.axis('off')
plt.tight_layout()

### Plotting the good data

In [None]:
# One panel per variable on a 3x3 grid, showing only the samples that passed QC
plt.figure(figsize=(12,10))
for panel, (short_name, long_name) in enumerate(zip(vars, vars_original), start=1):
    plt.subplot(3,3,panel)
    good_profiles = qc_valid[short_name]
    if len(good_profiles) and np.any(np.isfinite(np.concatenate(good_profiles))): # at least one finite value exists
        plot_raw(good_profiles,pres_qc_valid[short_name],juld,long_name)
    else:
        # nothing to draw: replace the axes with a short message
        plt.text(0.1,0.5,'NO DATA for \n'+long_name)
        plt.axis('off')
plt.tight_layout()

### Smoothing

In [None]:
# Median filter kernel that ignores NaNs
def nanmedian_filter(values):
    """Return the median of the non-NaN entries of ``values``, or NaN if there are none."""
    not_nan = np.logical_not(np.isnan(values))
    if not not_nan.any(): # no usable value in the window
        return np.nan
    return np.median(values[not_nan])

# define lists
smooth = {f"{var}": [] for var in vars} # smoothed data
pres_res = {f"{var}": [] for var in vars} # vertical resolution
pres_mid = {f"{var}": [] for var in vars} # midpoint depth at which vertical resolutions are defined

# Compute and assign smoothed values
for j in range(len(vars)): # loop over variables
    for i in range(len(qc_valid[vars[j]])): # loop over profiles
        pres_i = pres_qc_valid[vars[j]][i]
        data_i = qc_valid[vars[j]][i]
        pres_res[vars[j]].append(np.diff(pres_i)) # vertical resolution between consecutive samples
        pres_mid[vars[j]].append((pres_i[:-1] + pres_i[1:]) / 2) # midpoint pressure of each pair of samples
        # Median resolution of the profile; undefined (NaN) with fewer than two
        # samples, in which case the middle window size below is used.
        if len(pres_res[vars[j]][-1]):
            pres_res_med = np.median(pres_res[vars[j]][-1])
        else:
            pres_res_med = np.nan
        # choose the window size from the median resolution
        if pres_res_med >= 3:
            nsmooth = 5
        elif pres_res_med <= 1:
            nsmooth = 11
        else:
            nsmooth = 7
        if len(data_i) == 0:
            # nothing to filter; keep an (independent) empty entry so that the
            # profile index stays aligned with qc_valid
            smooth[vars[j]].append(data_i.copy())
            continue
        # apply the NaN-ignoring median filter
        smooth[vars[j]].append(generic_filter(
            data_i,nanmedian_filter,size=nsmooth,mode='nearest')
                              )

# Plot the smoothed profiles, one panel per variable on a 3x3 grid
plt.figure(figsize=(12,10))
for panel, (short_name, long_name) in enumerate(zip(vars, vars_original), start=1):
    plt.subplot(3,3,panel)
    smoothed_profiles = smooth[short_name]
    if len(smoothed_profiles) and np.any(np.isfinite(np.concatenate(smoothed_profiles))): # at least one finite value exists
        plot_raw(smoothed_profiles,pres_qc_valid[short_name],juld,long_name)
    else:
        # nothing to draw: replace the axes with a short message
        plt.text(0.1,0.5,'NO DATA for \n'+long_name)
        plt.axis('off')
plt.tight_layout()

### Plot the vertical resolution information

In [None]:
def plot_res(pres_res_in,pres_mid_in,vari_name_in):
    """Scatter vertical resolution (x) against midpoint pressure (y) on the current axes.

    ``pres_res_in`` and ``pres_mid_in`` are lists with one array per profile;
    ``vari_name_in`` is used as the panel title.
    """
    for idx, resolution in enumerate(pres_res_in):
        if not np.isfinite(resolution).any(): # ignore the profile with all NaNs
            continue
        plt.scatter(resolution,pres_mid_in[idx],color='k',s=0.1,alpha=0.1)
    plt.title(vari_name_in)
    plt.xlabel('Resolution (dbar)')
    plt.ylabel('Depth (dbar)')
    # axis ranges start at zero and end at the largest value over all profiles
    largest_res = np.nanmax(np.concatenate(pres_res_in))
    largest_mid = np.nanmax(np.concatenate(pres_mid_in))
    plt.xlim(0,largest_res)
    plt.ylim(0,largest_mid)
    plt.gca().invert_yaxis() # pressure increases downwards

# Plot the vertical resolution, one panel per variable on a 3x3 grid
plt.figure(figsize=(12,10))
for panel, (short_name, long_name) in enumerate(zip(vars, vars_original), start=1):
    plt.subplot(3,3,panel)
    resolutions = pres_res[short_name]
    if len(resolutions) and np.any(np.isfinite(np.concatenate(resolutions))): # at least one finite value exists
        plot_res(resolutions,pres_mid[short_name],long_name)
    else:
        # nothing to draw: replace the axes with a short message
        plt.text(0.1,0.5,'NO DATA for \n'+long_name)
        plt.axis('off')
plt.tight_layout()

### Interpolation (Please specify the resolution and the depth for interpolation)
- Profiles are sorted by date before being placed in the 2D (time, depth) array.
- Set negatives to zeros for all variables other than temperature.

In [None]:
# Settings of the vertical grid used for interpolation
# (dbar, which can be approximated as meter)
int_res = 10    # grid spacing
int_dep0 = 1    # the shallowest depth
int_dep1 = 1000 # the deepest depth (np.arange stops before this value)

# depth grid for interpolation
pres_int = np.arange(start=int_dep0, stop=int_dep1, step=int_res)
# interpolated data, keyed by shortened variable name; filled in later
data_int = {name: [] for name in vars}

def interpolate_argo(pres_in,pres_out,data_in,date,vari_name_in):
    """Linearly interpolate profiles onto a common pressure grid, ordered by date.

    Parameters
    ----------
    pres_in : list of 1-D arrays
        Pressures of each profile, in the same order as ``date``.
    pres_out : 1-D array
        Target pressure grid.
    data_in : list of 1-D arrays
        Values matching ``pres_in`` profile by profile.
    date : sequence
        One date per profile; rows of the result follow ``np.argsort(date)``.
    vari_name_in : str
        Variable name (not used in the computation).

    Returns
    -------
    numpy.ndarray of float32, shape (len(date), len(pres_out))
        Row k holds the profile with the k-th earliest date. Grid points
        outside a profile's pressure range, and profiles with fewer than two
        samples, are NaN.
    """
    import warnings
    # Import the submodule explicitly: a bare `import scipy` is not guaranteed
    # to expose `scipy.interpolate` on every SciPy version.
    from scipy.interpolate import interp1d

    data2d = np.full((len(date),len(pres_out)), np.nan) # create 2d array filled with NaNs
    ind_sorted = np.argsort(date) # computed once; maps output row -> input profile
    if len(pres_in) != len(date):
        warnings.warn('interpolate_argo: number of profiles ('+str(len(pres_in))
                      +') differs from number of dates ('+str(len(date))
                      +'); rows may not match their dates')
    for i in range(len(pres_in)):
        pres_i = pres_in[ind_sorted[i]]
        data_i = data_in[ind_sorted[i]]
        # A line needs two points: skip empty and single-sample profiles (the
        # latter can make interp1d raise on some SciPy versions) and profiles
        # whose pressures are all NaN.
        if len(pres_i) < 2 or not np.any(np.isfinite(pres_i)):
            continue
        f = interp1d(x=pres_i,y=data_i,
                     kind='linear',
                     bounds_error=False,  # do not raise outside the sampled range ...
                     fill_value=np.nan    # ... but return NaN there instead
                    )
        data2d[i,:] = f(pres_out)
    return np.float32(data2d) # single precision is sufficient

# function to plot the interpolated data (2d array)
def plot_int(vari_in,pres_in,juld_in,vari_name_in):
    """Scatter-plot a (time, depth) array on the current axes.

    Parameters
    ----------
    vari_in : 2-D array, shape (len(juld_in), len(pres_in))
        Interpolated values; rows are expected in ascending date order.
    pres_in : 1-D array
        Pressure grid of the columns.
    juld_in : sequence
        One date per row (in any order; it is sorted here for the x-axis).
    vari_name_in : str
        Label for the colour bar.

    When ``vari_in`` holds no finite value a message is printed and nothing is
    drawn, because a colour bar cannot be created without plotted data.
    """
    if not np.any(np.isfinite(vari_in)): # all values are NaNs, which seems weird
        print('all interpolated values are NaN? CHECK')
        return
    X, Y = np.meshgrid(np.sort(juld_in),pres_in, indexing='ij')  # (time, depth)
    plt.scatter(X,Y,c=vari_in,s=0.1,vmin=np.nanmin(vari_in),vmax=np.nanmax(vari_in))
    cbar = plt.colorbar()
    cbar.set_label(vari_name_in)
    vari_ptile = np.nanpercentile(vari_in,[25,50,75])
    cbar.ax.hlines(vari_ptile,xmin=0,xmax=1,color='r') # draw percentiles
    plt.gcf().autofmt_xdate() # automatically format date
    # use the dates passed in rather than a global, to align the date range across all variables
    plt.xlim(np.min(juld_in),np.max(juld_in))
    plt.ylim(0,np.max(pres_in)) # align the depth range across all variables
    plt.gca().invert_yaxis()

# Interpolate every variable onto the common grid and plot it (3x3 panels)
plt.figure(figsize=(12,10))
for panel, (short_name, long_name) in enumerate(zip(vars, vars_original), start=1):
    plt.subplot(3,3,panel)
    smoothed_profiles = smooth[short_name]
    has_values = len(smoothed_profiles) and np.any(np.isfinite(np.concatenate(smoothed_profiles)))
    if not has_values:
        # nothing to interpolate: replace the axes with a short message
        plt.text(0.1,0.5,'NO DATA for \n'+long_name)
        plt.axis('off')
        continue
    gridded = interpolate_argo(pres_qc_valid[short_name],pres_int,smoothed_profiles,juld,long_name)
    if short_name != 'temp': # temperature is the only variable allowed to be negative
        gridded[gridded<0] = 0 # set negatives to zeros
    data_int[short_name] = gridded
    plot_int(gridded,pres_int,juld,long_name)
plt.tight_layout()

### Saving
Save the post-processed 2D array as a netCDF file

In [None]:
# Initialize dictionary to hold DataArrays
data_vars = {}

# Loop to generate DataArrays
for i in range(len(vars)):
    if len(data_int[vars[i]]): # continue if data exist
        # Copy the attributes from `ds`, the last input file that was opened.
        # That file may lack this variable even though other files have it, so
        # fall back to no attributes instead of raising KeyError.
        if vars_original[i] in ds.data_vars:
            source_attrs = ds[vars_original[i]].attrs
        else:
            source_attrs = {}
        data_array = xr.DataArray(
            data_int[vars[i]],
            coords={
                'time': ('time', np.sort(juld)), #, {'units': 'days since 1950-01-01'}),
                'depth': ('depth', pres_int, {'units': 'dbar'})  # or 'meters' if it's depth below sea surface
            },
            dims=['time', 'depth'],
            attrs=source_attrs
        )
        data_vars[vars_original[i]+'_AR'] = data_array # adding the array to the dataset (AR: Analysis-Ready)
    else: # skip if data are empty
        print(vars_original[i],'is empty so not adding to the file')

# Create Dataset from all variables
ds_int = xr.Dataset(data_vars)

print(ds_int)

## EXTRA (variable-specific post-processing)
- Chlorophyll-a: NPQ and dark corrections

### **Chlorophyll-a only** Apply NPQ correction and dark correction for relevant profiles

In [None]:
def calc_solar_elevation(latitude, longitude, utc):
    """Compute the solar elevation angle from latitude, longitude and UTC time.

    Parameters
    ----------
    latitude, longitude : float
        Position passed to ``pvlib.location.Location``.
    utc : datetime-like
        A single timestamp (e.g. ``numpy.datetime64``, ``datetime.datetime``)
        or a ``pandas.DatetimeIndex``; only the first time is used.

    Returns
    -------
    float
        90 minus the ``zenith`` value returned by pvlib for the first time,
        i.e. positive when the sun is above the horizon.
    """
    import pandas as pd # local import: pandas is not imported at the top of this notebook

    # pvlib documents `times` as a pandas.DatetimeIndex, so wrap single
    # timestamps (including 0-d numpy arrays) into a one-element index.
    if isinstance(utc, pd.DatetimeIndex):
        times = utc
    else:
        times = pd.DatetimeIndex(np.atleast_1d(utc))
    loc = pvlib.location.Location(latitude, longitude)
    solar_position = loc.get_solarposition(times)
    solar_zenith = solar_position['zenith'].values[0]  # solar zenith angle
    solar_elevation = 90 - solar_zenith  # solar elevation angle
    return solar_elevation

def _mld_density_threshold(psal_in, pres_psal_in, temp_in, pres_temp_in, lon_in, lat_in,
                           ref_pres=10, threshold=0.03):
    """Return the mixed layer depth (dbar) from a potential-density threshold.

    The mixed layer depth is the shallowest pressure at or below ``ref_pres``
    where sigma0 exceeds its value at ``ref_pres`` by more than ``threshold``
    (kg/m3). Salinity is first interpolated onto the temperature pressures,
    because the two QC-filtered profiles can have different lengths. Returns
    NaN if either profile is empty or the threshold is never exceeded.
    """
    if len(psal_in) == 0 or len(temp_in) == 0:
        return np.nan
    # np.interp needs increasing pressures, so sort both profiles by pressure
    order_t = np.argsort(pres_temp_in)
    pres_t = np.asarray(pres_temp_in, dtype=float)[order_t]
    temp_t = np.asarray(temp_in, dtype=float)[order_t]
    order_s = np.argsort(pres_psal_in)
    psal_t = np.interp(pres_t,
                       np.asarray(pres_psal_in, dtype=float)[order_s],
                       np.asarray(psal_in, dtype=float)[order_s])
    SA = gsw.SA_from_SP(psal_t, pres_t, lon_in, lat_in) # Absolute Salinity
    CT = gsw.CT_from_t(SA, temp_t, pres_t) # Conservative Temperature
    sigma0 = gsw.sigma0(SA, CT) # potential density anomaly, reference pressure 0 dbar
    # sigma0 at the reference pressure, by linear interpolation
    sigma0_ref = np.interp(ref_pres, pres_t, sigma0)
    exceeds = np.flatnonzero((pres_t >= ref_pres) & (sigma0 > sigma0_ref + threshold))
    if exceeds.size == 0: # the threshold is never met
        return np.nan
    return pres_t[exceeds[0]] # first occurrence


# Define empty lists (each gets one entry per CHLA profile)
mld = [] # mixed layer depth (NaN where it was not needed or could not be found)
chla_npq = [] # NPQ corrected
chla_dark = [] # dark corrected (declared only; not filled in this cell)
count_npq = 0 # number of NPQ correction applied here

# Compute and assign NPQ corrected values
if any(len(profile) for profile in qc_valid['chla']): # if chla data exist
    for i in range(len(qc_valid['chla'])): # loop over profiles
        # work on a copy so that the smoothed data are not modified in place
        chla_i = np.array(smooth['chla'][i], copy=True)
        mld_i = np.nan
        # NPQ correction is needed for daytime profiles (sun above the
        # horizon) that do not already carry QC = 5
        if (len(chla_i) and not npq5_qc_valid[i]
                and calc_solar_elevation(lat[i], lon[i], juld[i]) > 0):
            mld_i = _mld_density_threshold(smooth['psal'][i], pres_qc_valid['psal'][i],
                                           smooth['temp'][i], pres_qc_valid['temp'][i],
                                           lon[i], lat[i])
            if np.isfinite(mld_i):
                idx90 = np.argmin(np.abs(pres_qc_valid['chla'][i] - 0.9*mld_i)) # depth index closest to mld*0.9
                chla_i[:idx90+1] = np.nanmax(chla_i[:idx90+1]) # set the upper 90% mld to have uniform chla
                count_npq += 1 # count the number of corrected profiles
        mld.append(mld_i)
        chla_npq.append(chla_i) # corrected profile, or the smoothed data unchanged
else:
    print ('No good data for CHLA, so not calculating NPQ or dark correction')

print('Total number of all profiles:',len(npq5_qc_valid))
print('Total number of profiles with QC = 5:',sum(npq5_qc_valid))
print('Total number of profiles corrected here:',count_npq)

### Saving to netCDF

In [None]:
# Global attributes describing the output file
ds_int.attrs['title'] = 'Analysis-ready BGC-Argo dataset'
ds_int.attrs['institution'] = 'JAMSTEC Application Laboratory (APL)'
ds_int.attrs['notes'] = 'Reference: Hayashida and Fujishima (2025): Jupyter Notebook for generating analysis-ready BGC-Argo datasets, Journal of Open Source Software'
# The parentheses are required: without them the continuation lines are a
# syntax error (unexpected indent).
ds_int.attrs['history'] = ('Created on '
                           + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
                           + ' using Version '+str(vernum))
# Save to NetCDF (file name: 'AR' followed by the float number)
ds_int.to_netcdf('AR'+str(wmoid)+'.nc')