In [3]:
from glob import glob 
import xarray as xr

In [None]:
# Read in the EN4 profile files so we can calculate mixed-layer depths.

# glob() returns paths in arbitrary, filesystem-dependent order. Sort them so the
# profiles are concatenated along N_PROF in a reproducible order (this also fixes
# the relative order of profiles that share a JULD after the later sortby).
files = sorted(glob('../../../data/en4/profiles/EN.4.2.2.f.profiles.c14*.nc'))

# Variables kept when the data are loaded into memory.
# NOTE: this name shadows the builtin `vars`; it is kept because the load cell uses it.
vars = ['DEPH_CORRECTED', 'JULD', 'LATITUDE', 'LONGITUDE', 'PSAL_CORRECTED', 'PSAL_CORRECTED_QC', 'POTM_CORRECTED', 'POTM_CORRECTED_QC', 'WMO_INST_TYPE']

# Lazily open all EN4 files as one dataset, stacking profiles along N_PROF.
ds = xr.open_mfdataset(files, combine='nested', concat_dim='N_PROF')

In [None]:
# Give every profile an explicit integer coordinate along N_PROF: 0, 1, 2, ...

import numpy as np

N_PROF = np.arange(ds.JULD.size)
ds = ds.assign_coords(N_PROF=N_PROF)

In [None]:
# Restrict to the Southern Ocean: keep only profiles with LATITUDE < -40
# (the boolean mask is evaluated eagerly via .values, then used to index N_PROF).
idx = np.less(ds['LATITUDE'].values, -40)
ds = ds.isel({'N_PROF': idx})

In [None]:
# Subset to the variables of interest and pull them into memory.
# Dataset.load() loads in place and returns the same object, so binding first
# and loading afterwards is equivalent to `ds[vars].load()`.
en4 = ds[vars]
en4.load();

In [None]:
# Make the date (JULD) the main dimension: tag each profile with its position,
# sort the profiles chronologically, then swap N_PROF out for JULD.

N_PROF = np.arange(en4.JULD.size)

en4 = en4.assign_coords(N_PROF=N_PROF).sortby('JULD')

en4 = (
    en4
    .assign_coords(JULD=('N_PROF', en4.JULD.data))
    .swap_dims({'N_PROF': 'JULD'})
)

In [None]:
# Count the number of profiles in each calendar month (used by the data-density plot).
# NOTE(review): this runs before the LATITUDE > -90 filter below, so profiles removed
# by that filter are still included in these counts.

month_counts = en4.JULD.resample({'JULD': 'M'}).count(dim='JULD')

In [None]:
# Drop profiles whose latitude is -90 or below (there is some weird data there);
# only profiles with LATITUDE strictly greater than -90 are kept.
idx = np.greater(en4['LATITUDE'].values, -90)
en4 = en4.isel({'JULD': idx})

In [None]:
# Promote LONGITUDE and LATITUDE to coordinates along JULD so they can be used
# as the grouping variables for the groupby_bins-based gridding later on.
en4 = en4.assign_coords(
    LONGITUDE=('JULD', en4.LONGITUDE.data),
    LATITUDE=('JULD', en4.LATITUDE.data),
)

In [None]:
# Show data density: number of profiles per month as a bar chart.
# NOTE(review): relies on `plt` (matplotlib.pyplot) being imported in another cell.

fig, ax = plt.subplots(figsize=(12, 3))

n_months = month_counts['JULD'].size
ax.bar(np.arange(n_months), month_counts, width=1, edgecolor='w')

# One tick every 12 months, labelled with the year of that month. Deriving the labels
# from the data keeps the tick and label counts equal (recent matplotlib raises when
# they differ) instead of relying on a hard-coded 2004-2023 range.
tick_positions = np.arange(0, n_months, 12)
ax.xaxis.set_ticks(tick_positions)
ax.xaxis.set_ticklabels(month_counts['JULD'].dt.year.values[tick_positions])

ax.set_xlabel('Year')
ax.set_ylabel('Profiles per month');

In [None]:
# Calculate density from the temperature and salinity of the profiles.

import gsw

# gsw.rho(SA, CT, p) takes sea pressure p in dbar. Depth/1000 is neither the in-situ
# pressure (roughly depth in m) nor a reference pressure. The threshold MLD compares
# each level against the density at the 10 m reference depth, so evaluate at the
# surface reference pressure p = 0, i.e. potential density.
# NOTE(review): PSAL_CORRECTED (practical salinity) and POTM_CORRECTED (potential
# temperature) are passed where gsw expects Absolute Salinity and Conservative
# Temperature; gsw.SA_from_SP / gsw.CT_from_pt would give the TEOS-10 inputs.
density = gsw.rho(en4['PSAL_CORRECTED'], en4['POTM_CORRECTED'], 0)

# np.asarray works whether gsw returns a DataArray or a plain ndarray.
en4['DENSITY'] = (('JULD', 'N_LEVELS'), np.asarray(density))

In [None]:
# Vertically grid the density data onto a regular 5 m depth grid from 0 to 1000 m.

from scipy.interpolate import griddata
from tqdm.notebook import tqdm

z = np.arange(0, 1005, 5)

# Pull the profile arrays out once. Indexing the Dataset per profile inside the loop
# is slow, and working on plain arrays leaves the notebook-level `ds` untouched.
depth_profiles = en4['DEPH_CORRECTED'].transpose('JULD', 'N_LEVELS').values
density_profiles = en4['DENSITY'].transpose('JULD', 'N_LEVELS').values

# Start from NaN instead of uninitialised memory (np.ndarray), so no row can ever
# hold garbage values.
density_gridded = np.full((en4.JULD.size, z.size), np.nan)

for i in tqdm(range(en4.JULD.size)):
    # 1-D linear interpolation onto z; depths outside the profile's range become NaN
    # (griddata's default fill_value).
    density_gridded[i] = griddata(depth_profiles[i], density_profiles[i], z)


In [None]:
# Calculate the mixed-layer depth of every profile from its gridded density,
# using the project helper calc_mld with den_lim=0.03 and ref_dpt=10
# (presumably a 0.03 kg m-3 density threshold relative to 10 m - confirm in
# functions/calc_mld.py).

from functions.calc_mld import calc_mld

mld = calc_mld(density_gridded, z, den_lim=0.03, ref_dpt=10)
en4 = en4.assign(MLD=('JULD', mld))

In [None]:
# Save the profiles, now including the MLD variable, to netCDF.
output_path = '../../../data/en4/en4_profiles_with_mixed_layer_depth.nc'
en4.to_netcdf(output_path)

In [None]:
# Gridding functions from Estel Font: median-bin data onto a regular lon/lat grid,
# optionally grouped by a time component.

GS = 3  # default grid spacing in degrees

# grid 3d for 3d variables (2d + time)
def grid_lat_3df(dsgpd_ln, gs=GS):
    """Median of ``dsgpd_ln`` in ``gs``-degree LATITUDE bins between 90S and 40S.

    Parameters
    ----------
    dsgpd_ln : xarray DataArray or Dataset
        Data carrying a LATITUDE coordinate (one longitude bin's worth of data
        when called from ``grid_lon_3d_f``).
    gs : float
        Bin width in degrees.

    Returns
    -------
    The NaN-skipping median per latitude bin; bins are labelled 0, 1, 2, ...
    from south to north.
    """
    lat_min = -90
    lat_max = -40
    lat = np.arange(lat_min, lat_max + gs, gs)
    # Exactly one label per bin (one fewer than the number of edges), so pd.cut
    # can never see a label/edge count mismatch for any gs.
    lat_labels = np.arange(lat.size - 1, dtype=float)

    return dsgpd_ln.groupby_bins('LATITUDE', lat,
                                 labels=lat_labels,
                                 restore_coord_dims=True).median(skipna=True)


def grid_lon_3d_f(dsgpd_t, gs=GS):
    """Median-grid ``dsgpd_t`` onto ``gs``-degree LONGITUDE x LATITUDE bins.

    Longitude bins span -180 to 180 and are labelled 0, 1, 2, ... from west to
    east; within each longitude bin, ``grid_lat_3df`` bins by latitude using the
    same ``gs``.
    """
    lon_min = -180
    lon_max = 180
    lon = np.arange(lon_min, lon_max + gs, gs)
    lon_labels = np.arange(lon.size - 1, dtype=float)

    # Forward gs so both axes use the same bin width.
    return dsgpd_t.groupby_bins('LONGITUDE', lon,
                                labels=lon_labels,
                                restore_coord_dims=True).map(grid_lat_3df, gs=gs)


def grid_var_3dflt(dsvar, clim='month', gs=GS):
    """For gridding spatially in 2D and time (3D).

    Parameters
    ----------
    dsvar : xarray DataArray or Dataset
        Data along a JULD datetime dimension with LONGITUDE/LATITUDE coordinates.
    clim : str
        'season' groups months into four bins (Jan-Mar, Apr-Jun, Jul-Sep,
        Oct-Dec) labelled 0-3; any other value is used as a JULD datetime
        component for grouping (e.g. 'month', 'year').
    gs : float
        Grid spacing in degrees, used for both longitude and latitude bins.
    """
    if clim == 'season':
        # Right-closed bins (0,3], (3,6], (6,9], (9,12] map months 1-12 onto 0-3.
        # The time dimension of this data is JULD, so group on its month component.
        var = dsvar.groupby_bins(group=dsvar['JULD'].dt.month,
                                 bins=range(0, 15, 3),
                                 labels=range(0, 4)).map(grid_lon_3d_f, gs=gs)
    else:
        var = dsvar.groupby('JULD.' + clim).map(grid_lon_3d_f, gs=gs)
    return var

In [None]:
# Plot each monthly MLD map with the locations of the contributing profiles overlaid.
# NOTE(review): relies on plt, ccrs, cmo and southern_ocean_map from other cells.

gs = 3

# Grid-cell centres for bins whose edges are every gs degrees (as in grid_var_3dflt).
lat_min = -90
lat_max = -40
lat_grid = np.arange(lat_min, lat_max + gs, gs)[:-1] + gs / 2

lon_min = -180
lon_max = 180
lon_grid = np.arange(lon_min, lon_max + gs, gs)[:-1] + gs / 2

fig_dir = '/Users/xduplm/Google Drive/My Drive/projects/2023_duplessis_storms_fluxes/figs/en4_month_maps/'

for year in range(2004, 2024):

    # All profiles in this calendar year.
    ds = en4.sel(JULD=slice(f'{year}-01-01', f'{year}-12-31'))

    mld_month = grid_var_3dflt(ds['MLD'], gs=gs)

    for month, group in ds.groupby(ds.JULD.dt.month):

        mld_grid = mld_month.sel(month=month)

        # Locations of profiles with a positive MLD (NaN fails the > 0 test).
        valid = group.MLD > 0
        lon = group.LONGITUDE[valid]
        lat = group.LATITUDE[valid]

        fig = plt.figure(figsize=[3.5, 4.5])
        ax = fig.add_subplot(1, 1, 1, projection=ccrs.SouthPolarStereo())
        ax = southern_ocean_map(ax)

        ax.pcolormesh(lon_grid, lat_grid, mld_grid.T, cmap=cmo.dense, vmin=0, vmax=200, transform=ccrs.PlateCarree())
        ax.scatter(lon, lat, s=0.1, c='k', transform=ccrs.PlateCarree())

        # Zero-padded month gives e.g. 2004-03 / map_en4_2004_03.png, so files sort by date.
        ax.set_title(f'{year}-{month:02d}', fontsize=12, pad=7.5)
        fig.savefig(f'{fig_dir}map_en4_{year}_{month:02d}.png', dpi=300)
        # Close this figure explicitly to avoid accumulating ~240 open figures.
        plt.close(fig)

In [None]:
# Create a monthly gridded MLD array for 2004-2023, one (month, lon, lat) block per year,
# stacked along the first axis.

yearly_grids = []

for year in tqdm(range(2004, 2024)):

    # All profiles in this calendar year.
    ds = en4.sel(JULD=slice(f'{year}-01-01', f'{year}-12-31'))

    mld_month = grid_var_3dflt(ds['MLD'])

    # The monthly time axis built afterwards assumes 12 slices per year; fail early
    # with a clear message rather than silently misaligning months and dates.
    n_months = mld_month.sizes['month']
    if n_months != 12:
        raise ValueError(f'{year}: expected 12 monthly slices, got {n_months}')

    yearly_grids.append(np.asarray(mld_month))

# Concatenate once at the end (np.append inside the loop re-copies everything each year).
ds_mld = np.concatenate(yearly_grids, axis=0)

In [None]:
# Wrap the monthly MLD grids in a labelled Dataset and save it to netCDF.

# pd is used here but not imported in any of the cells above; import it so the
# notebook survives Restart & Run All.
import pandas as pd

# Month-end timestamps, one per monthly slice: 2004-01-31 ... 2023-12-31.
dates = pd.date_range('2004-01-01', '2024-01-01', freq='M')

# Create the Dataset; the dimension order follows the gridding output
# (time, longitude bins, latitude bins).
mld_month = xr.Dataset({
    'MLD': xr.DataArray(
        data=ds_mld,
        dims=['time', 'lon', 'lat'],
        coords={'time': dates, 'lon': lon_grid, 'lat': lat_grid},
        attrs={'long_name': 'Mixed Layer Depth', 'units': 'm'}
    )
})

mld_month.to_netcdf('/Users/xduplm/Google Drive/My Drive/data/duplessis-storms-paper/en4_monthly_mixed_layer_depth_median.nc')