# Southern Ocean Codes
## Environment Settings

In [None]:
# filter some warning messages
import warnings
warnings.filterwarnings("ignore") 

In [None]:
from dask.distributed import Client
from dask_gateway import Gateway

# Start an adaptive Dask cluster through Dask Gateway (scaling between 1 and 10 workers)
# and attach a client to it. An earlier version used dask_kubernetes.KubeCluster() instead.
gateway = Gateway()
cluster = gateway.new_cluster()
cluster.adapt(maximum=10, minimum=1)

client = Client(cluster, timeout="50s")
cluster  # show the cluster's rich repr as the cell output

In [None]:
!pip install --upgrade git+https://github.com/JiaweiZhuang/xESMF.git

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import gcsfs # Pythonic file-system for Google Cloud Storage
import xesmf as xe
import math
import copy
from scipy.interpolate import griddata
# import seawater as sw
from tqdm.autonotebook import tqdm

# from mpl_toolkits.basemap import Basemap, cm, maskoceans

# import os
# os.environ['NUMPY_EXPERIMENTAL_ARRAY_FUNCTION'] = '0'

## Data Access and processing
### 1. Fetch data and calculation
If this step has already been run and its results saved, it does not need to be run again.
Skip to part 2 to restore the data from the saved files.
#### a) get data from gcs

In [None]:
# CMIP6 Zarr catalogue: each row describes one store (`zstore`) and its facets
# (table_id, variable_id, experiment_id, source_id, grid_label, ...).
df = pd.read_csv(
    'https://storage.googleapis.com/cmip6/cmip6-zarr-consolidated-stores.csv'
)

In [None]:
def func_load_ds_uri(uri, token='anon'):
    """
    Open a consolidated Zarr store on Google Cloud Storage as an xarray Dataset.

    Parameters
    ----------
    uri : str
        GCS path of the store, e.g. a ``zstore`` entry of the CMIP6 catalogue.
    token : str, optional
        Credentials option forwarded to ``gcsfs.GCSFileSystem``. The default ``'anon'``
        uses anonymous access (it does NOT use your gcloud credentials); pass e.g.
        ``'google_default'`` to authenticate with your default Google credentials.

    Returns
    -------
    xarray.Dataset
        The dataset returned by ``xr.open_zarr(..., consolidated=True)``.
    """
    gcs = gcsfs.GCSFileSystem(token=token)
    return xr.open_zarr(gcs.get_mapper(uri), consolidated=True)

In [None]:
# CMIP6 source_id of the model analysed in this notebook.
GCM_name = "GFDL-CM4"

In [None]:
# Locate the piControl monthly-ocean (Omon) potential temperature (thetao) store of the
# selected model on grid_label 'gr'.
df_plt = df[(df.table_id == 'Omon') &
            (df.variable_id == 'thetao') &
            (df.activity_id == 'CMIP') &
            (df.experiment_id == 'piControl') &
            (df.source_id == GCM_name) &
            (df.grid_label == 'gr')]
if df_plt.empty:
    raise ValueError('No CMIP piControl Omon thetao store (grid gr) found for ' + GCM_name)
# NOTE(review): if several stores match (e.g. several members or versions), the first row is used.
uri_thetao = df_plt.zstore.values[0]
uri_thetao

In [None]:
# Locate the piControl monthly-ocean (Omon) salinity (so) store of the selected model
# on grid_label 'gr'.
df_plt = df[(df.table_id == 'Omon') &
            (df.variable_id == 'so') &
            (df.activity_id == 'CMIP') &
            (df.experiment_id == 'piControl') &
            (df.source_id == GCM_name) &
            (df.grid_label == 'gr')]
if df_plt.empty:
    raise ValueError('No CMIP piControl Omon so store (grid gr) found for ' + GCM_name)
# NOTE(review): if several stores match (e.g. several members or versions), the first row is used.
uri_so = df_plt.zstore.values[0]
uri_so

In [None]:
# Open the thetao and so stores selected above (thetao first, then so).
ds_thetao, ds_so = (func_load_ds_uri(uri) for uri in (uri_thetao, uri_so))

#### b) select data

In [None]:
# Analysis settings
GCM = GCM_name                  # model label, also used in the saved-results file names
year_start, year_end = 1, 500   # first and last model year to analyse (1-based, inclusive)

conv_index_depth = 500          # depth for the convection index (same units as `lev`)

In [None]:
# Keep model years year_start..year_end (inclusive). The time axis is indexed in months,
# so the first month index is (year_start-1)*12 — assumes monthly data starting in
# January (TODO confirm). Note: re-running this cell slices the already-sliced data again.
month_window = slice((year_start - 1) * 12, year_end * 12)
ds_thetao = ds_thetao.isel(time=month_window)
ds_so = ds_so.isel(time=month_window)

In [None]:
# Model depth levels (the `lev` coordinate) of the selected thetao data.
Depths = ds_thetao['lev'].values

#### c) Calculate MLD and related diagnostics
Define the functions used for the calculation.

In [None]:
def _mld_grid_cell_areas(Lon_orig, Lat_orig, earth_R=6378e3):
    """
    Approximate cell areas (in earth_R units squared; earth_R is in metres) of a 2-D lat/lon grid.

    Uses A = (pi/180) R^2 |lon1-lon2| |sin(lat1)-sin(lat2)| with cell edges halfway between
    neighbouring grid points. Interior cells are computed directly. A cell more than 3x larger
    than its west/east neighbour (e.g. at the 0/360 longitude seam) is replaced by that
    neighbour, and the outermost rows/columns copy their inner neighbours.
    """
    lat_n = Lat_orig.shape[0]  # number of rows (latitudes)
    lon_n = Lon_orig.shape[1]  # number of columns (longitudes)
    areas = np.zeros((lat_n, lon_n))

    # Edge coordinates of every interior cell (vectorised form of the per-cell formula).
    lon_edge_a = (Lon_orig[1:-1, :-2] + Lon_orig[1:-1, 1:-1]) / 2
    lon_edge_b = (Lon_orig[1:-1, 1:-1] + Lon_orig[1:-1, 2:]) / 2
    lat_edge_a = (Lat_orig[:-2, 1:-1] + Lat_orig[1:-1, 1:-1]) / 2
    lat_edge_b = (Lat_orig[1:-1, 1:-1] + Lat_orig[2:, 1:-1]) / 2
    areas[1:-1, 1:-1] = ((earth_R ** 2) * (math.pi / 180)
                         * np.abs(lon_edge_a - lon_edge_b)
                         * np.abs(np.sin(np.radians(lat_edge_a)) - np.sin(np.radians(lat_edge_b))))

    # Sequential outlier smoothing: each replacement can feed the next comparison,
    # so this stays an explicit loop.
    for ii in range(1, lat_n - 1):
        for jj in range(2, lon_n - 2):
            if areas[ii, jj] > areas[ii, jj - 1] * 3:
                areas[ii, jj] = areas[ii, jj - 1]
            if areas[ii, jj] > areas[ii, jj + 1] * 3:
                areas[ii, jj] = areas[ii, jj + 1]
    areas[0, :] = areas[1, :]
    areas[-1, :] = areas[-2, :]
    areas[:, 0] = areas[:, 1]
    areas[:, -1] = areas[:, -2]
    return areas


def func_MLD(dset_thetao, dset_so, month_no, hemis_no):
    """
    Mixed-layer-depth (MLD) based deep-convection diagnostics for one calendar month.

    For each model year the density ``sw.dens0(so, thetao)`` of month ``month_no`` is regridded
    to the analysis grid from ``func_latlon_regrid_eq(90, 180, -90, 90, 0, 360)``. At every
    point inside the region's latitude band whose surface density is defined, the MLD is the
    depth where density exceeds its (linearly interpolated) 10 m value by 0.03. It is found by
    picking the level closest to that criterion and interpolating over the three levels around
    it. Points with MLD >= the region's threshold add their cell area to that year's
    convection area.

    Parameters
    ----------
    dset_thetao, dset_so : xarray.Dataset
        Datasets with variables ``thetao`` / ``so`` and coordinates ``time``, ``lev``, ``lat``,
        ``lon``. The time axis is assumed monthly and starting in January (TODO confirm).
    month_no : int
        Calendar month (1-12) analysed in every year.
    hemis_no : int
        Region: 0 Weddell Sea, 1 Labrador Sea, 2 Norwegian Sea, 3 Labrador Sea (narrow),
        4 Ross Sea.

    Returns
    -------
    deep_conv_area : ndarray, shape (n_years,)
        Summed cell area (m^2) of points with MLD >= threshold, per year.
    data_plot : ndarray, shape (n_years, n_lat, n_lon) of the analysis grid
        MLD where it reaches the threshold, NaN elsewhere.
    lon, lat : ndarray
        2-D analysis-grid coordinates.
    indeces : tuple of two index arrays
        Grid points inside the region's longitude window whose time-mean ``data_plot``
        exceeds the threshold.

    Raises
    ------
    ValueError
        If ``hemis_no`` is not one of the regions above.
    """
    import seawater as sw

    # hemis_no -> (MLD threshold [depth units of `lev`], latitude band, longitude window for `indeces`)
    regions = {
        0: (2000, lambda la: la <= -50, lambda lo: (lo <= 30) | (lo >= 300)),    # Weddell Sea [60W-30E]
        # NOTE(review): the window 30E-330E is much wider than the stated Labrador Sea [60W-30W];
        # it mainly excludes the Nordic Seas — confirm intent.
        1: (1000, lambda la: la >= 50, lambda lo: (lo >= 30) & (lo <= 330)),     # Labrador Sea
        2: (2000, lambda la: la >= 58, lambda lo: (lo <= 30) | (lo >= 345)),     # Norwegian Sea
        3: (1000, lambda la: la >= 50, lambda lo: (lo >= 30) & (lo <= 320)),     # Labrador Sea - extended
        4: (2000, lambda la: la <= -50, lambda lo: (lo >= 160) & (lo <= 230)),   # Ross Sea [160E-130W] (SH)
    }
    if hemis_no not in regions:
        raise ValueError('invalid input for hemisphere option: {!r}'.format(hemis_no))
    depth_MLD_tr, lat_band, lon_window = regions[hemis_no]

    yrs_no = len(dset_thetao.time) // 12
    Depths = np.asarray(dset_thetao.lev.values)

    Lon_orig = dset_thetao.lon.values
    Lat_orig = dset_thetao.lat.values
    if np.ndim(Lon_orig) == 1:  # regular (non-curvilinear) grid: expand to 2-D
        Lon_orig, Lat_orig = np.meshgrid(Lon_orig, Lat_orig)

    # Analysis grid and cell areas regridded onto it.
    Lat_regrid_1D, Lon_regrid_1D, _, _ = func_latlon_regrid_eq(90, 180, -90, 90, 0, 360)
    lon, lat = np.meshgrid(Lon_regrid_1D, Lat_regrid_1D)
    areacello = func_regrid(_mld_grid_cell_areas(Lon_orig, Lat_orig), Lat_orig, Lon_orig, lat, lon)

    # Model levels bracketing 10 m: last level at/above 10 m, first level at/below 10 m (0 if none).
    shallow_levels = np.flatnonzero(Depths <= 10)
    deep_levels = np.flatnonzero(Depths >= 10)
    depth10m_shalow = int(shallow_levels[-1]) if shallow_levels.size else 0
    depth10m_deep = int(deep_levels[0]) if deep_levels.size else 0
    interpol_x = [Depths[depth10m_shalow], Depths[depth10m_deep]]

    # Candidate points of the region (independent of time, so computed once).
    ii, jj = np.where(lat_band(lat))

    data_plot = np.full([yrs_no, len(lon), len(lon[0])], np.nan)
    deep_conv_area = np.zeros(yrs_no)

    for t in tqdm(range(yrs_no)):
        time_index = 12 * t + month_no - 1
        data_thetao_extracted = dset_thetao.thetao.isel(time=time_index).values
        data_so_extracted = dset_so.so.isel(time=time_index).values
        data_dens = sw.dens0(data_so_extracted, data_thetao_extracted)
        data_i = func_regrid(data_dens, Lat_orig, Lon_orig, lat, lon)
        data_i[data_i > 100000] = np.nan  # treat huge values as missing
        n_lev = data_i.shape[0]

        area = 0.0
        for i, j in zip(ii, jj):
            profile = data_i[:, i, j]
            if np.isnan(profile[0]):  # no surface value (land / missing)
                continue
            p_10m_dens = np.interp(10, interpol_x, [profile[depth10m_shalow], profile[depth10m_deep]])

            # Level whose density excess over 10 m is closest to 0.03 (first one on ties;
            # NaN levels and misfits >= 100 are never selected).
            misfit = np.abs(profile - p_10m_dens - 0.03)
            if np.all(np.isnan(misfit)) or np.nanmin(misfit) >= 100:
                continue
            MLD = int(np.nanargmin(misfit))

            # Interpolate over the three levels centred on MLD, shifted inward at the ends.
            MLD = min(max(MLD, 1), n_lev - 2)
            window = slice(MLD - 1, MLD + 2)
            interpol_z = np.interp(0.03, profile[window] - p_10m_dens, Depths[window])
            if interpol_z >= depth_MLD_tr:
                area += areacello[i, j]
                data_plot[t, i, j] = float(interpol_z)
        deep_conv_area[t] = area

    average_MLD = np.nanmean(data_plot, axis=0)
    indeces = np.where(lon_window(lon) & (average_MLD > depth_MLD_tr))
    return deep_conv_area, data_plot, lon, lat, indeces
    

In [None]:
def func_time_depth_plot(ds_thetao, indeces, conv_index_depth, var_name='thetao'):
    """
    Annual-mean depth profile of a variable averaged over selected analysis-grid points.

    Each year's 12 monthly fields are averaged. The mean is regridded with ``func_regrid`` to the
    grid from ``func_latlon_regrid_eq(90, 180, -90, 90, 0, 360)`` and then averaged over the
    points ``indeces``.

    Parameters
    ----------
    ds_thetao : xarray.Dataset
        Dataset with ``time`` (assumed monthly, starting in January — TODO confirm), ``lev``,
        ``lat``, ``lon`` and the variable ``var_name``.
    indeces : tuple (ii, jj)
        Row/column indices into the analysis grid, e.g. the ``indeces`` from ``func_MLD``.
    conv_index_depth : float
        The convection index is taken at the deepest level with ``lev <= conv_index_depth``
        (level 0 if there is none).
    var_name : str, optional
        Variable to average (default ``'thetao'``).

    Returns
    -------
    region : ndarray, shape (n_years, n_lev)
        Regional mean profile for each year.
    convection_index : ndarray, shape (n_years,)
        ``region`` at the convection-index level.
    """
    Depths = np.asarray(ds_thetao.lev.values)
    yrs_no = len(ds_thetao.time) // 12
    Lon_orig = ds_thetao.lon.values
    Lat_orig = ds_thetao.lat.values
    if np.ndim(Lon_orig) == 1:  # regular grid: expand to 2-D, as func_MLD does before regridding
        Lon_orig, Lat_orig = np.meshgrid(Lon_orig, Lat_orig)

    ii, jj = indeces
    Lat_regrid_1D, Lon_regrid_1D, _, _ = func_latlon_regrid_eq(90, 180, -90, 90, 0, 360)
    lon, lat = np.meshgrid(Lon_regrid_1D, Lat_regrid_1D)

    region = []
    for t in tqdm(range(yrs_no)):
        # All 12 months of year t (slice end is exclusive).
        data = np.asarray(ds_thetao[var_name].isel(time=slice(12 * t, 12 * t + 12)).values)
        data = np.squeeze(np.nanmean(data, axis=0))
        data_i = func_regrid(data, Lat_orig, Lon_orig, lat, lon)
        region.append(np.nanmean(data_i[:, ii, jj], axis=1))
    region = np.asarray(region)

    levels_above = np.flatnonzero(Depths <= conv_index_depth)
    depth_index_start = int(levels_above[-1]) if levels_above.size else 0
    convection_index = region[:, depth_index_start]
    return region, convection_index

In [None]:
#added by Grace, adapting Behzad's code to get average salinity over convection area
# added by Grace, adapting Behzad's code to get average salinity over convection area
def func_time_depth_plot_so(ds, indeces, conv_index_depth):
    """
    Annual-mean salinity (``so``) depth profile averaged over selected analysis-grid points.

    Each year's 12 monthly fields are averaged. The mean is regridded with ``func_regrid`` to the
    grid from ``func_latlon_regrid_eq(90, 180, -90, 90, 0, 360)`` and then averaged over the
    points ``indeces``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``so`` and coordinates ``time`` (assumed monthly, starting in January —
        TODO confirm), ``lev``, ``lat``, ``lon``.
    indeces : tuple (ii, jj)
        Row/column indices into the analysis grid, e.g. the ``indeces`` from ``func_MLD``.
    conv_index_depth : float
        The convection index is taken at the deepest level with ``lev <= conv_index_depth``
        (level 0 if there is none).

    Returns
    -------
    region : ndarray, shape (n_years, n_lev)
    convection_index : ndarray, shape (n_years,)
    """
    Depths = np.asarray(ds.lev.values)
    yrs_no = len(ds.time) // 12
    Lon_orig = ds.lon.values
    Lat_orig = ds.lat.values
    if np.ndim(Lon_orig) == 1:  # regular grid: expand to 2-D, as func_MLD does before regridding
        Lon_orig, Lat_orig = np.meshgrid(Lon_orig, Lat_orig)

    ii, jj = indeces
    Lat_regrid_1D, Lon_regrid_1D, _, _ = func_latlon_regrid_eq(90, 180, -90, 90, 0, 360)
    lon, lat = np.meshgrid(Lon_regrid_1D, Lat_regrid_1D)

    region = []
    for t in tqdm(range(yrs_no)):
        # All 12 months of year t (slice end is exclusive).
        data = np.asarray(ds.so.isel(time=slice(12 * t, 12 * t + 12)).values)
        data = np.squeeze(np.nanmean(data, axis=0))
        data_i = func_regrid(data, Lat_orig, Lon_orig, lat, lon)
        region.append(np.nanmean(data_i[:, ii, jj], axis=1))
    region = np.asarray(region)

    levels_above = np.flatnonzero(Depths <= conv_index_depth)
    depth_index_start = int(levels_above[-1]) if levels_above.size else 0
    convection_index = region[:, depth_index_start]
    return region, convection_index

### 2. Restore data from the saved files

In [None]:
import os
import shelve

# Restore the variables saved by an earlier run of part 1 into the notebook namespace.
# Later files overwrite keys loaded from earlier ones.
# NOTE(review): shelve unpickles its contents — only open result files you produced yourself.
dir_pwd = os.getcwd()  # result files are read from the current working directory
for suffix in ('', '_ROSS', '_WS'):
    filename_out = os.path.join(dir_pwd, 'AllResults_' + GCM + '_500yr' + suffix + '.out')
    # flag='r': fail loudly if a results file is missing instead of creating an empty shelf
    with shelve.open(filename_out, flag='r') as my_shelf:
        for key in my_shelf:
            globals()[key] = my_shelf[key]

### 3. Calculation

## Plotting