# Setup

In [None]:
# Packages -----------------------------------------------#

# Data Analysis
import xarray as xr
import numpy as np
import pandas as pd
import metpy.calc as mpcalc
import matplotlib.dates as dates
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

# Timing Processes and Progress
import time
from tqdm import tqdm

# make sure the figures plot inline rather than at the end
%matplotlib inline

# get data
path = '/home/jennap/projects/LRGROUP/shared_data/chl-globcolor-100km-case-1-and-case-2-waters/concatenated-monthly/'
chlinfn = 'all_L3m_AV_CHL1_100km_global_monthly_merged_1997_2020.nc'

ds = xr.open_dataset(path + chlinfn)
ds
# ds.chl1_mean.attrs["units"]
# # print(ds.keys())

# Subset ------------------------------------------------#
# Create slice variables to subset domain before finding means
lat_slice = slice(-20, 30) # bounds inclusive
lon_slice = slice(35, 120) # bounds inclusive
time_slice = slice('1993-01-01','2012-12-31')

# Get data, selecting lat/lon slice
daily_chl = ds['chl1_mean'].sel(lat=lat_slice,lon=lon_slice)
lat = daily_chl.lat.values
lon = daily_chl.lon.values

del lat_slice, lon_slice

# Anomalies relative to the 1993-2012 mean

In [None]:
# Per-grid-cell mean over the base period (time_slice). daily_chl is already
# restricted to the lat/lon domain, so only time is selected here.
mchl = daily_chl.sel(time=time_slice).mean('time', skipna=True)
# Anomaly of every month in the full record relative to that base-period mean.
mon_chl = daily_chl - mchl

# Detrend

In [None]:
# Calendar-year means of the monthly anomalies: one map per year
mon_chl_year_clim = mon_chl.groupby('time.year').mean(dim='time')


In [None]:
%%time

# stack lat and lon into a single dimension called allpoints
stacked = mon_chl.stack(allpoints=['lat','lon'])

# set places where there are nans to zero since polyfit can't deal with them
stacked_nonan = stacked.fillna(0)

# convert date to a number to polyfit can handle it
datenum = dates.date2num(stacked_nonan.time)
mon_chl_slope, mon_chl_intercept = np.polyfit(datenum, stacked_nonan, 1)

#reshape the data
mon_chl_slope = np.reshape(mon_chl_slope, mon_chl.shape[1:3])
mon_chl_intercept = np.reshape(mon_chl_intercept, mon_chl.shape[1:3])

# define a function to compute a linear trend of a timeseries
def linear_detrend(y):
    x = dates.date2num(y.time)
    m, b = np.polyfit(x, y, 1)
    # we need to return a dataarray or else xarray's groupby won't be happy
    return xr.DataArray(y - (m*x + b))

# apply the function over allpoints to calculate the trend at each point
mon_chl_dtrnd = stacked_nonan.groupby('allpoints').apply(linear_detrend)
# unstack back to lat lon coordinates
mon_chl_dtrnd = mon_chl_dtrnd.unstack('allpoints')

# fill all points we set originally to zero back to nan
mon_chl_dtrnd = mon_chl_dtrnd.where(~np.isnan(mon_chl))

# delete trended data to save on memory
del mon_chl,ds, stacked, stacked_nonan

In [None]:
# Calendar-year means of the detrended anomalies
mon_chl_year_clim_dtrnd = mon_chl_dtrnd.groupby('time.year').mean('time')

# Raw minus detrended annual means: the portion of each year's mean anomaly
# that was removed as the fitted linear trend
var = mon_chl_year_clim-mon_chl_year_clim_dtrnd

# Take panel titles from the data's own 'year' coordinate so they always match
# the panels, whatever time span the input file covers.
years = var['year'].values.tolist()

p = var.plot.pcolormesh(x="lon", y="lat", col="year", col_wrap=5,
                        cmap="coolwarm",
                        vmax=0.1, vmin=-0.1, # set colorbar lims
                        extend='neither', # make a box colorbar rather than pointed
                        figsize=(14, 12),
                        cbar_kwargs={"label": "Chlorophyll-a Anomaly (mg/m^3)"},
                        subplot_kws={'facecolor': 'gray'}
                        )

# zip stops at the last panel, so unused grid slots are left alone
for ax, year in zip(p.axes.flat, years):
    ax.set_title(str(year))
    ax.axis('tight')

p.set_xlabels('Longitude')
p.set_ylabels('Latitude')

# Resample to Monthly and Seasonal (DJF/MAM/JJA/SON) Temporal Resolution

In [None]:
%%time
# monthly
mon_chl_dtrnd = daily_chl_dtrnd.resample(time='1MS').mean(dim="time")
# seasonal
seas_chl_dtrnd = mon_chl_dtrnd.resample(time='QS-DEC').mean(dim="time")

# Find Climatologies
Reference: [xarray example: Calculating Seasonal Averages from Time Series of Monthly Means](https://docs.xarray.dev/en/stable/examples/monthly-means.html)

In [None]:
%%time
# -------------------------------------------
# weighted seasonal
# -------------------------------------------

# get months
month_length = mon_chl_dtrnd.time.dt.days_in_month

# calculate the weights by grouping by 'time.season'.
weights = month_length.groupby('time.season') / month_length.groupby('time.season').sum()

# calculate the weighted average
chl_seas_clim = (mon_chl_dtrnd * weights).groupby('time.season').sum(dim='time')   

# set the places that are now zero from the weights to nans
chl_seas_clim = chl_seas_clim.where(chl_seas_clim != 0,np.nan) # for some reason .where sets the locations not in the condition to nan by default

# -------------------------------------------
# monthly
# -------------------------------------------

chl_mon_clim = mon_chl_dtrnd.groupby('time.month').mean('time') 

# Find Anomalies

In [None]:
%%time
# monthly avg data - monthly climatology
mon_chl_mon_anom = mon_chl_dtrnd.groupby('time.month') - chl_mon_clim

# seasonal avg data - seasonal climatology
seas_chl_seas_anom = mon_chl_dtrnd.groupby('time.season') - chl_seas_clim

In [None]:
# convert to xarray dataset
ds = xr.Dataset(coords={'lon': mon_chl_dtrnd.lon,
                        'lat': mon_chl_dtrnd.lat,
                        'time': mon_chl_dtrnd.time})

# add variables to dataset.
# Each DataArray already carries its own index coordinates. reset_coords(drop=True)
# removes non-index coordinates left over from groupby arithmetic (e.g. a
# 'month' or 'season' label along time). These would clash with the
# 'month'/'season' dimensions of the climatologies.
# Seasonal series live on their own 'season_time' dimension because their
# time axis (one step per season) differs from the monthly one.
# NOTE: assumes seas_chl_seas_anom is on the seasonal (QS-DEC) time axis.

ds["mon_chl"] = mon_chl_dtrnd.reset_coords(drop=True)
ds["seas_chl"] = (seas_chl_dtrnd.reset_coords(drop=True)
                  .rename({'time': 'season_time'}))

# clim
ds["chl_mon_clim"] = chl_mon_clim.reset_coords(drop=True)
ds["chl_seas_clim"] = chl_seas_clim.reset_coords(drop=True)

# anom
ds["mon_chl_mon_anom"] = mon_chl_mon_anom.reset_coords(drop=True)
ds["seas_chl_seas_anom"] = (seas_chl_seas_anom.reset_coords(drop=True)
                            .rename({'time': 'season_time'}))


In [None]:
import os

# Output goes to the current working directory. Its name is the input file
# name with the trailing ".nc" (3 characters) replaced by "_processed.nc".
outfn = '{}_processed.nc'.format(chlinfn[:-3])

# Remove a stale copy first so the writer starts from a fresh file.
if os.path.isfile(outfn):
    os.remove(outfn)

ds.to_netcdf(outfn, format="NETCDF4", mode='w')

ds