# <font color='darkblue'>Setup</font> 

## <font color='orange'>Packages</font> 

In [1]:
# Packages -----------------------------------------------#

# Data Analysis
import xarray as xr
import numpy as np
import pandas as pd
import metpy.calc as mpcalc
import matplotlib.dates as dates
from tqdm import tqdm

# Plotting
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy.ndimage import gaussian_filter
from matplotlib.lines import Line2D
import datetime as dt

# make sure the figures plot inline rather than at the end
%matplotlib inline

## <font color='orange'>Parameters</font> 

In [2]:
# Region used for all binning. latlonbin() reads the longitude limits from
# bounds[0:2] and the latitude limits from bounds[2:4], i.e. the order is
# [lon_min, lon_max, lat_min, lat_max].
bounds = [35,120,-20,30] # for all binning
binwidth = 2 # number of degrees for bins

# Used to group data into IOD peak season and non-peak season.
# Presumably calendar month numbers (9 = September, 11 = November) -- TODO confirm.
IODseason_begin = 9
IODseason_end = 11

# Used to group data into "IOD years", since the effects are not confined
# to a single calendar year. An earlier choice put an even number of months
# around the IOD peak; the current values instead follow the SLA plots,
# which start in month 06 and end in month 05 of the following year.
IODyear_begin = '-06-01' # month-day of IOD year
IODyear_end = '-05-31' # month-day of year AFTER IOD year

# Earlier sampling window, kept for reference only (not active):
# # define months to start and end on when sampling
# begin = '-05-01' #iod year
# end = '-04-30'  # post iod year

## <font color='orange'>Functions</font> 

In [3]:
def add_land(ax,bounds= [35,120,-20,30]):
    """Decorate a cartopy map axis with land, coastlines and labelled gridlines.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Axis that is modified in place.
    bounds : sequence of four numbers, optional
        Map extent handed to ``ax.set_extent`` with a PlateCarree CRS,
        i.e. ``[lon_min, lon_max, lat_min, lat_max]``. The default list is
        only read, never mutated.

    Returns
    -------
    None
    """
    from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

    ax.add_feature(cfeature.LAND, color='gray', zorder=0)

    # ``GeoAxes.background_patch`` is deprecated since cartopy 0.18; the map
    # background is the regular ``Axes.patch`` there. Older releases kept the
    # background as a plain instance attribute, so look for that first and
    # fall back to ``ax.patch`` without touching the deprecated property.
    background = vars(ax).get('background_patch')
    if background is None:
        background = ax.patch
    background.set_facecolor('k')

    ax.coastlines(resolution='110m', zorder=0)

    # alpha=0 hides the grid lines themselves; only the tick labels are kept.
    g = ax.gridlines(draw_labels=True, alpha=0)
    if hasattr(g, 'top_labels'):
        # cartopy >= 0.18 naming
        g.top_labels = False
        g.right_labels = False
    else:
        # legacy naming used by older cartopy releases
        g.xlabels_top = False
        g.ylabels_right = False
    g.xlabel_style = {'size': 15}
    g.ylabel_style = {'size': 15}
    g.xformatter = LONGITUDE_FORMATTER
    g.yformatter = LATITUDE_FORMATTER

    ax.axis('tight')
    ax.set_extent(bounds, crs=ccrs.PlateCarree())
    return None

def find_coast(arr):
    """Locate valid cells of a 2-D array that touch a NaN cell.

    A cell is reported when it is not NaN itself and at least one of its four
    direct neighbours (up, down, left, right) is NaN. Diagonal neighbours are
    not considered, and cells outside the array never count as NaN.

    Parameters
    ----------
    arr : 2-D array-like of numbers
        Field in which NaN marks the cells to be treated as "land".

    Returns
    -------
    rowind, colind : numpy.ndarray of int
        Row and column indices of the matching cells in row-major order.
        Both arrays are integer typed even when nothing matches, so they can
        always be used directly for indexing.
    """
    arr = np.asarray(arr)
    is_nan = np.isnan(arr)

    # Pad with False so that the array border is never treated as NaN, then
    # look at the four shifted views to get each cell's direct neighbours.
    # Testing isnan directly (instead of summing a window) keeps a pair of
    # +inf / -inf neighbours from being mistaken for a NaN.
    padded = np.pad(is_nan, 1, mode='constant', constant_values=False)
    nan_neighbour = (
        padded[:-2, 1:-1]      # cell above
        | padded[2:, 1:-1]     # cell below
        | padded[1:-1, :-2]    # cell to the left
        | padded[1:-1, 2:]     # cell to the right
    )

    rowind, colind = np.nonzero(~is_nan & nan_neighbour)
    return rowind, colind

def _next_coast_station(current, previous, candidates, steps):
    """Return the first candidate lying one of `steps` away from `current`, else None."""
    for cand in candidates:
        # never step straight back onto the station we just came from
        if cand == previous:
            continue
        step = (abs(current[0] - cand[0]), abs(current[1] - cand[1]))
        if step in steps:
            return cand
    return None


def order_coast(loninds,latinds,sta_zero):
    """Chain grid points into a connected path, starting from `sta_zero`.

    Starting at `sta_zero`, the next station is repeatedly picked from the
    points not yet used: first a point one step up/down/left/right of the
    current station, and only if none exists a diagonal neighbour. Candidates
    are scanned in descending (lon, lat) order, so the first match in that
    order wins. The walk stops when no neighbour is left.

    Parameters
    ----------
    loninds, latinds : sequences of int
        Index pairs of the points to order (same length).
    sta_zero : tuple of (lonind, latind)
        Starting point; must be one of the input pairs.

    Returns
    -------
    sta_lonind, sta_latind : numpy.ndarray
        Indices of the ordered stations. If the chain breaks before every
        point was used, only the stations reached so far are returned and a
        message listing the unreached points is printed.

    Raises
    ------
    ValueError
        If `sta_zero` is not among the input pairs (including empty input).
    """
    uplr_steps = ((0, 1), (1, 0))
    diagonal_steps = ((1, 1),)

    remaining = sorted(zip(loninds, latinds), reverse=True)
    remaining.remove(sta_zero)
    ordered = [sta_zero]

    # Loop only while points are left, so the "not found" message is printed
    # solely when the chain really breaks, not after a complete ordering.
    while remaining:
        current = ordered[-1]
        previous = ordered[-2] if len(ordered) > 1 else sta_zero

        # check up/down/left/right first, then the diagonals
        next_sta = _next_coast_station(current, previous, remaining, uplr_steps)
        if next_sta is None:
            next_sta = _next_coast_station(current, previous, remaining, diagonal_steps)

        if next_sta is None:
            print('No Next Station Found. Returning Previous Stations Only.')
            print(current, remaining)
            break

        ordered.append(next_sta)
        remaining.remove(next_sta)

    sta_lonind, sta_latind = map(np.array, zip(*ordered))

    return sta_lonind, sta_latind

# binning for one variable ------------------------------------------------------------#
def latlonbin(invar,lat,lon,bounds,binwidth):
    """Average scattered observations onto a regular lat/lon grid.

    Bin edges start half a bin width below the lower bound and advance in
    steps of `binwidth`; bins are right-closed, ``(lower, upper]``. Because
    ``np.arange`` excludes its stop value, a range that is a whole multiple of
    `binwidth` ends with a bin centred on ``upper bound - binwidth``.
    Observations whose lat or lon is NaN or outside the edges are ignored.

    Parameters
    ----------
    invar, lat, lon : 1-D array-like, same length
        Observed values and their positions.
    bounds : sequence of four numbers
        ``[lon_min, lon_max, lat_min, lat_max]``.
    binwidth : number
        Bin size, in the same units as `lat` / `lon`.

    Returns
    -------
    invar_binned_ave : numpy.ndarray of float, shape (len(latbins), len(lonbins))
        NaN-ignoring mean per bin; NaN where a bin has no non-NaN values.
    invar_bincounts : numpy.ndarray of float, same shape
        Number of finite values per bin; NaN (not 0) where no observation
        fell into the bin at all.
    latbins, lonbins : numpy.ndarray
        Bin-centre coordinates.
    """
    half = binwidth / 2

    # bin edges and the matching bin centres
    latedges = np.arange(bounds[2] - half, bounds[3] + half, binwidth)
    lonedges = np.arange(bounds[0] - half, bounds[1] + half, binwidth)
    latbins = latedges[1:] - half
    lonbins = lonedges[1:] - half

    df = pd.DataFrame(dict(
            invar = np.array(invar),
            lat= np.array(lat),
            lon= np.array(lon),
        ))

    # labels=False yields the integer bin position directly (NaN when the
    # point lies outside the edges), which avoids grouping on categoricals.
    df['latbins_ind'] = pd.cut(df['lat'], latedges, labels=False)
    df['lonbins_ind'] = pd.cut(df['lon'], lonedges, labels=False)
    df = df.dropna(subset=['latbins_ind', 'lonbins_ind'])
    df = df.astype({'latbins_ind': int, 'lonbins_ind': int})

    grid_shape = (len(latbins), len(lonbins))
    invar_binned_ave = np.full(grid_shape, np.nan)
    invar_bincounts = np.full(grid_shape, np.nan)

    # fill in the statistics of every bin that received observations
    for (i, j), group in df.groupby(['latbins_ind', 'lonbins_ind'])['invar']:
        values = np.asarray(group)
        # skip nanmean for all-NaN bins: the result is NaN either way and
        # this avoids numpy's "Mean of empty slice" warning
        if not np.isnan(values).all():
            invar_binned_ave[i, j] = np.nanmean(values)
        invar_bincounts[i, j] = np.count_nonzero(np.isfinite(values))

    return invar_binned_ave, invar_bincounts, latbins, lonbins


## <font color='orange'>Read Data</font> 

In [4]:
# Open the CD dataset. The path is relative to the notebook's working
# directory; the recorded run of this cell failed with FileNotFoundError, so
# the file must be present before the cells below can run. Later cells read
# the variables TCD_AWG, OCD_AWG, lat and lon from it.
ds_CD = xr.open_dataset('../data/CD/CD.nc')
# ds_CD

FileNotFoundError: [Errno 2] No such file or directory: b'/projects/GEOCLIM/LRGROUP/jennap/Modulation_of_Coastal_Hypoxia_by_the_IOD/data/CD/CD.nc'

## <font color='orange'>1 Degree TCD/SLA/CHL</font> 

### <font color='lightblue'>Unordered</font> 

In [None]:
# ------------------------------------------------------------#
# Only take data after 1994 to make sure the bins are close to
# the coast but also have enough data to make the pdfs
# ------------------------------------------------------------#
start_time = '1994-01-01'
end_time = '2020-02-07'
time_slice = slice(start_time, end_time)

TCD = ds_CD.TCD_AWG.sel(time=time_slice)
OCD = ds_CD.OCD_AWG.sel(time=time_slice)
lat = ds_CD.lat.sel(time=time_slice)
lon = ds_CD.lon.sel(time=time_slice)

# NOTE(review): the bin width is passed as the literal 1 (this is the
# "1 Degree" section), not the `binwidth` value from the Parameters cell.
# The returned grids are indexed [lat, lon].
TCD_binned_ave,TCD_bncts,latbins,lonbins = latlonbin(TCD,lat,lon,bounds,1)

# ------------------------------------------------------------#
# block out a few places that should not be included
# ------------------------------------------------------------#

# np.where below builds new arrays, so TCD_binned_ave itself is not modified.
mask = TCD_binned_ave
xx,yy = np.meshgrid(lonbins,latbins)

# set Gulfs to Nans
mask = np.where(~((yy>22) & (xx<60)),mask,np.nan)
mask = np.where(~((yy>10) & (xx<51)),mask,np.nan)

# AS inlets
# NOTE(review): the next three lines appear to mask nothing.
#  - In the 2nd and 3rd, yy cannot be both > 22.25 (or > 23.25) and < 22.
#  - In the 1st, with 1-degree bins on integer bounds the bin centres look to
#    be whole numbers, so none lies strictly between 21.5 and 22 -- confirm.
# They are left untouched because changing the mask would shift the
# hard-coded station indices further down.
mask = np.where(~((yy>21.5) & (yy<22) & (xx>72) & (xx<73)),mask,np.nan)
mask = np.where(~((yy>22.25) & (yy<22) & (xx>69) & (xx<70.5)),mask,np.nan)
mask = np.where(~((yy>23.25) & (yy<22) & (xx>69) & (xx<70)),mask,np.nan)

# Sumatra area

mask = np.where(~((yy>-2) & (yy<1) & (xx>97) & (xx<101)),mask,np.nan)

# equator: drop every bin whose centre latitude is below 0
mask = np.where(~(yy<0),mask,np.nan)

# ------------------------------------------------------------#
# locate points along the BoB and AS
# ------------------------------------------------------------#

# find_coast returns (row, col) indices; rows are the lat axis and columns
# the lon axis of the binned grid.
sta_latinds_unord, sta_loninds_unord = find_coast(np.array(mask))

# ------------------------------------------------------------#
# Remove Manually some other points
# ------------------------------------------------------------#

# Drops the first 9 points returned by find_coast. The number is hand-picked
# and depends on the mask above -- re-check it whenever the mask changes.
sta_latinds_unord = sta_latinds_unord[9:]
sta_loninds_unord = sta_loninds_unord[9:]

# AS: remove points west of 52E that are also south of 15N
lons_from_sta = np.array(lonbins[sta_loninds_unord])
lats_from_sta = np.array(latbins[sta_latinds_unord])

ind = ~((lons_from_sta <52) & (lats_from_sta < 15))
sta_latinds_unord = sta_latinds_unord[ind]
sta_loninds_unord = sta_loninds_unord[ind]

# BoB: remove points east of 99E
lons_from_sta = np.array(lonbins[sta_loninds_unord])
lats_from_sta = np.array(latbins[sta_latinds_unord])

ind = ~((lons_from_sta >99))
sta_latinds_unord = sta_latinds_unord[ind]
sta_loninds_unord = sta_loninds_unord[ind]

# ------------------------------------------------------------#
# put in the right order
# ------------------------------------------------------------#

# initial station = first remaining point, as a (lon index, lat index) tuple.
# NOTE(review): the trailing coordinates look like they come from a finer
# grid than the 1-degree bins used here -- confirm or remove.
sta_zero = (sta_loninds_unord[0],sta_latinds_unord[0]) # 98.625, 10.375

# walk along the coast starting at sta_zero
sta_loninds, sta_latinds =order_coast(sta_loninds_unord,sta_latinds_unord,sta_zero)

# Drops the last 3 ordered stations (hand-picked, like the [9:] slice above).
sta_latinds = sta_latinds[:-3]
sta_loninds = sta_loninds[:-3]

# ------------------------------------------------------------#
# identify initial stations to demarcate 
# stations for EQ, BOB, and AS
# ------------------------------------------------------------#

# Positions in the ordered station arrays; they are only valid for the exact
# mask / removal steps above.
# EQ = 0
JA = 9
sBoB = 18
# mBoB = 65
SL = 38
mAS = 48
# eAS = 90
loc_list = [JA,sBoB,SL,mAS]

# ------------------------------------------------------------#
# plot
# ------------------------------------------------------------#
cbounds = [40,100,-5,30]

# NOTE(review): cmin, cmax and levels are not used in the visible part of
# this cell.
cmin = -0.2
cmax = 0.2
levels = np.linspace(cmin, cmax, 10)

# Start figure
fig = plt.figure(figsize=(16, 8))
ax = plt.axes(projection=ccrs.PlateCarree())
add_land(ax,cbounds)
xx,yy = np.meshgrid(lonbins,latbins)
# xx = xx.flatten()
# yy = yy.flatten()
# plt.pcolormesh(xx,yy,TCD_binned_ave)
# p2 = plt.scatter(xx,yy,c = TCD_binned_ave.flatten(),marker='s',
#                       s = 500,cmap=plt.cm.Spectral,vmin=20,vmax=160,transform=ccrs.PlateCarree())
# all stations, coloured by their position along the ordered coast
p =plt.scatter(lonbins[sta_loninds],latbins[sta_latinds],
            s = 1000,c=np.arange(len(sta_latinds)),marker = '.',cmap =plt.cm.PiYG)

# highlight the demarcation stations from loc_list
plt.scatter(lonbins[sta_loninds[loc_list]],latbins[sta_latinds[loc_list]],
            s = 560,c='darkblue',marker = '.')
plt.colorbar(p,label = 'Station No.')

plt.title('Stations')

plt.savefig('../figures/stations-TCD-along-coast-1-degree-boxes.png', dpi=300, bbox_inches='tight')

# convert to xarray dataset
# The station indices get their own `station` dimension. The previous version
# built the dataset from `daily_sla_dtrnd`, which is not defined anywhere
# above, and attached the indices to ds_CD on the `time` dimension (whose
# length does not match the number of stations), so nothing about the
# stations ended up in the saved file.
ds = xr.Dataset(
    data_vars={
        'TCD_sta_loninds': ('station', sta_loninds),
        'TCD_sta_latinds': ('station', sta_latinds),
        # bin-centre coordinates of each station, for convenience
        'TCD_sta_lon': ('station', lonbins[sta_loninds]),
        'TCD_sta_lat': ('station', latbins[sta_latinds]),
    },
    coords={
        'station': np.arange(len(sta_loninds)),
        # full bin-centre axes that the index variables refer to
        'lonbins': ('lonbins', lonbins),
        'latbins': ('latbins', latbins),
    },
)

ds.to_netcdf('../data/CD/coastlines.nc',mode='w',format = "NETCDF4")
ds