In [None]:
import sys
import os
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML, display # NOTE: will require ffmpeg installation
import numpy as np
import xarray as xr
import pandas as pd
import geopandas as gpd
import rioxarray
from shapely.geometry import mapping
import seaborn as sns
from scipy.stats import pearsonr
# from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes (for inset plots)
# from mpl_toolkits.axes_grid1.inset_locator import mark_inset (for inset plots)
from xarrayutils.utils import linear_trend, xr_linregress
from eofs.xarray import Eof

In [None]:
from eofs.standard import Eof

In [None]:
# Define project repo path.
# Set the AISLENS_DIR environment variable to run this notebook on another
# machine; otherwise the original author's path is used. A trailing separator
# is guaranteed because the paths below are joined by string concatenation.
inDirName = os.path.join(
    os.environ.get('AISLENS_DIR', '/Users/smurugan9/research/aislens/aislens_emulation/'),
    '',
)

# Data file paths (relative to inDirName)
regriddedFluxSSH_filepath = 'data/interim/RegriddedFluxSSH.nc' # Data containing regridded flux and SSH for 150 years
iceShelvesShape_filepath = 'data/interim/iceShelves.geojson' # File contains all defined ice shelves
figures_folderpath = 'reports/figures/' # Folder for output figures

In [None]:
data = xr.open_dataset(inDirName + regriddedFluxSSH_filepath)

In [None]:
# Read geoJSON region feature file as GeoDataFrame
iceshelvesmask = gpd.read_file(inDirName + iceShelvesShape_filepath)
# Convert to south polar stereographic projection
icems = iceshelvesmask.to_crs({'init': 'epsg:3031'});
crs = ccrs.SouthPolarStereo();
# Specify projection for data file
data.rio.write_crs("epsg:3031",inplace=True);

In [None]:
## Following section tests the code for 1 ice shelf / catchment:
# Iceshelf chosen: 
#  34 : Amery
# 103 : Ronne        #TODO: Split polygons for Ronne1, Ronne2 etc.
# 104 : Western Ross #TODO: Split polygons for W-Ross1, W-Ross2 etc.
# 105 : Eastern Ross #TODO: Split polygons for E-Ross1, E-Ross2 etc.
# 114 : Thwaites

basin = 34
basinName = icems.name.values[basin]

In [None]:
# Mask data to chosen basin
ds = data.rio.clip(icems.loc[[basin],'geometry'].apply(mapping),icems.crs,drop=False)

In [None]:
dsn = ds.dropna('y',how='all')
dsn = dsn.dropna('x',how='all')
dsn = dsn.dropna('time',how='all')

In [None]:
flxn = dsn.timeMonthly_avg_landIceFreshwaterFlux
hn = dsn.timeMonthly_avg_ssh

In [None]:
# Time mean of melt flux across spatial domain (one value per pixel)
flxn_tmean = flxn.mean('time')
# Grand mean over all dimensions (a scalar)
flxn_mean = flxn.mean()
# Time series of integrated melt flux across ice shelf basin
# NOTE(review): this subtracts a per-pixel grand mean from a spatial SUM, so the
# result is not a zero-mean anomaly; flxn.sum(['x','y']).mean('time') may have been
# intended. flxn_ts is reassigned later (flxn_ts = fl). TODO confirm.
flxn_ts = flxn.sum(['x','y']) - flxn_mean

In [None]:
# Anomaly field: each pixel's deviation from its own time mean
eps = flxn - flxn_tmean

In [None]:
# Basin-mean anomaly time series (mean over x and y)
eps_spatialmean = eps.mean(['x','y'])

In [None]:
# Sanity check: anomaly + time mean should reproduce the original field
reconstr = eps+flxn_tmean
flxn.mean('time').plot()

In [None]:
# Should match the previous plot
reconstr.mean('time').plot()

In [None]:
# (1) Phase randomization of the basin-mean flux anomaly
# Execute to test and validate that spatial variability information is retained.
# Each realization keeps the Fourier amplitudes of the input series but draws new
# random phases. The input must be a 1-D time series: the realizations are later
# written into `eps_spatialmean[spinuptime:]` (via `.copy(data=new_fl[k])`) and
# added back onto the per-pixel time mean `flxn_tmean`. Passing the full 3-D flux
# field made rfft act along the x axis and could not fill the 2-D `new_fl`.

# Plot phase randomized data
plt.figure(figsize=(25, 8), dpi=80)

spinuptime = 0 # Ignore first few years of data in the phase randomization
n_realizations = 50 # Number of random Fourier realizations
seed = 0 # Fixed seed: realizations 8/28/35/45 are reused by later cells
rng = np.random.default_rng(seed)

# Time limits for plotting
t1 = 60
tf = 800

# Series to randomize and its spectrum (loop-invariant, so computed once)
fl = eps_spatialmean[spinuptime:]
n_time = len(fl)
fl_fourier = np.fft.rfft(fl.values)
new_fl = np.empty((n_realizations, n_time))

for i in range(n_realizations):
    random_phases = np.exp(rng.uniform(0, 2*np.pi, fl_fourier.size)*1.0j)
    # Keep the zero-frequency term so every realization has the mean of `fl`
    random_phases[0] = 1.0
    # For even lengths the Nyquist term must stay real to keep its amplitude
    if n_time % 2 == 0:
        random_phases[-1] = 1.0
    # n=n_time: by default irfft returns 2*(m-1) samples, one short for odd-length input
    new_fl[i,:] = np.fft.irfft(fl_fourier*random_phases, n=n_time)
    plt.plot(new_fl[i,t1:tf],'b', linewidth=0.15)

plt.plot(new_fl[8,t1:tf],'b', linewidth=1, label='Randomized Output')
plt.plot(new_fl[28,t1:tf],'b', linewidth=1)
plt.plot(new_fl[35,t1:tf],'b', linewidth=1)
plt.plot(new_fl[45,t1:tf],'b', linewidth=1)
plt.plot(fl[t1:tf],'k', linewidth=3, label='MPAS Output')
plt.title('Deseasonalized & Detrended Flux (Years: {:.1f} - {:.1f}): {}'.format((spinuptime+t1)/12,(spinuptime+tf)/12,basinName))
plt.ylabel('landIceFreshwaterFlux')
plt.legend()

In [None]:
flxn_ts = fl
#(orig_ts+eps).plot()
flxn.mean('time').plot(cmap='viridis')

In [None]:
flxn_reconstr = eps_spatialmean+flxn_tmean

In [None]:
flxn_reconstr.mean('time').plot(cmap='viridis',vmin=0)

In [None]:
flxn_k1_ts = eps_spatialmean[spinuptime:].copy(data=new_fl[8])
flxn_k1 = flxn_k1_ts + flxn_tmean

flxn_k2_ts = eps_spatialmean[spinuptime:].copy(data=new_fl[28])
flxn_k2 = flxn_k2_ts + flxn_tmean

flxn_k3_ts = eps_spatialmean[spinuptime:].copy(data=new_fl[35])
flxn_k3 = flxn_k3_ts + flxn_tmean

flxn_k4_ts = eps_spatialmean[spinuptime:].copy(data=new_fl[45])
flxn_k4 = flxn_k4_ts + flxn_tmean

flxn.plot()
flxn_k1.plot()
flxn_k2.plot()
flxn_k3.plot()
flxn_k4.plot()

In [None]:
plt.figure(figsize=(25, 8), dpi=80)
plt.plot(flxn_ts, 'k', linewidth=1.5)
plt.plot(flxn_k1_ts, linewidth=0.5)
plt.plot(flxn_k2_ts, linewidth=0.5)
plt.plot(flxn_k3_ts, linewidth=0.5)
plt.plot(flxn_k4_ts, linewidth=0.5)

In [None]:
##==================================================
##==============ANIMATION===========================
##==================================================
# Flux trend in time: Contourf animation

# Get a handle on the figure and the axes
fig, ax = plt.subplots(figsize=(15,8), subplot_kw={'projection': ccrs.SouthPolarStereo()})

# vmin=-0.000005
# vmax=0.0008

# Plot the initial frame.
# vmin = np.min(flux), vmax = np.max(flux) obtained manually. These should be modified to skip ocean flux values
cax = flxn[1,:,:].plot(add_colorbar=True,
                       cmap='coolwarm',vmax=5.5e-5, vmin=0,
                       cbar_kwargs={'extend':'neither'})

# Next we need to create a function that updates the values for the colormesh, as well as the title.
def animate(frame):
    cax.set_array(flxn[frame,:,:].values.flatten())
    ax.set_title("time = " + str(flxn.coords['time'].values[frame])[:7])

# Finally, we use the animation module to create the animation.
ani = animation.FuncAnimation(
    fig,             # figure
    animate,         # name of the function above
    frames=500,       # Could also be iterable or list
    interval=100     # ms between frames
)

In [None]:
# View animation in browser / save to file
HTML(ani.to_jshtml())
# ani.save(inDirName+figures_folderpath+'AIS_flux.mp4')

In [None]:
# View animation in browser / save to file
HTML(ani2.to_jshtml())
# ani.save(inDirName+figures_folderpath+'AIS_flux.mp4')

In [None]:
##==================================================
##==============ANIMATION===========================
##==================================================
# Flux trend in time: Contourf animation

# Get a handle on the figure and the axes
fig, ax = plt.subplots(figsize=(15,8), subplot_kw={'projection': ccrs.SouthPolarStereo()})

# vmin=-0.000005
# vmax=0.0008

# Plot the initial frame.
# vmin = np.min(flux), vmax = np.max(flux) obtained manually. These should be modified to skip ocean flux values
cax = flxn_k3[1,:,:].plot(add_colorbar=True, 
                       cmap='coolwarm', vmax=5.5e-5, vmin=0,
                       cbar_kwargs={'extend':'neither'})

# Next we need to create a function that updates the values for the colormesh, as well as the title.
def animate(frame):
    cax.set_array(flxn_k3[frame,:,:].values.flatten())
    ax.set_title("time = " + str(flxn_k3.coords['time'].values[frame])[:7])

# Finally, we use the animation module to create the animation.
ani2 = animation.FuncAnimation(
    fig,             # figure
    animate,         # name of the function above
    frames=500,       # Could also be iterable or list
    interval=100     # ms between frames
)

In [None]:
# create a figure and axes
fig, (ax1,ax2,ax3,ax4) = plt.subplots(4,1, figsize=(15,8), subplot_kw={'projection': ccrs.SouthPolarStereo()})

# set up the subplots as needed

cax = flxn[1,:,:].plot(ax = ax1, add_colorbar=True,
                       cmap='coolwarm', vmax=5.5e-5, vmin=0,
                       cbar_kwargs={'extend':'neither'})

cax2 = flxn_k2[1,:,:].plot(ax = ax2, add_colorbar=True,
                       cmap='coolwarm', vmax=5.5e-5, vmin=0,
                       cbar_kwargs={'extend':'neither'})

cax3 = flxn_k3[1,:,:].plot(ax = ax3, add_colorbar=True,
                       cmap='coolwarm', vmax=5.5e-5, vmin=0,
                       cbar_kwargs={'extend':'neither'})

cax4 = flxn_k4[1,:,:].plot(ax = ax4, add_colorbar=True,
                       cmap='coolwarm', vmax=5.5e-5, vmin=0,
                       cbar_kwargs={'extend':'neither'})


# Next we need to create a function that updates the values for the colormesh, as well as the title.
def animate(frame):
    cax.set_array(flxn[frame,:,:].values.flatten())
    ax1.set_title("time = " + str(flxn.coords['time'].values[frame])[:7])
    #ax1.set_title('Original Data')
    cax2.set_array(flxn_k2[frame,:,:].values.flatten())
    #ax2.set_title("time = " + str(flxn_k2.coords['time'].values[frame])[:7])
    ax2.set_title('Phase Randomized Data, k=2')
    cax3.set_array(flxn_k3[frame,:,:].values.flatten())
    #ax3.set_title("time = " + str(flxn_k3.coords['time'].values[frame])[:7])
    ax3.set_title('Phase Randomized Data, k=3')
    cax4.set_array(flxn_k4[frame,:,:].values.flatten())
    #ax4.set_title("time = " + str(flxn_k4.coords['time'].values[frame])[:7])
    ax4.set_title('Phase Randomized Data, k=4')

# Finally, we use the animation module to create the animation.
ani3 = animation.FuncAnimation(
    fig,             # figure
    animate,         # name of the function above
    frames=500,       # Could also be iterable or list
    interval=100     # ms between frames
)

In [None]:
HTML(ani3.to_jshtml())

In [None]:
ani3.save(inDirName+figures_folderpath+'phase_randomized.gif')

In [None]:
ani3.save(inDirName+figures_folderpath+'phase_randomized.mp4')

In [None]:
# NOTE(review): `pcs` is first assigned further down (npsolver.pcs() / solver.pcs()),
# so on a fresh kernel this cell raises NameError. Consider moving it below those cells.
pcs

In [None]:
# NOTE(review): `.sum('mode')` needs an xarray object with a 'mode' dimension
# (eofs.xarray); `Eof` is re-imported from eofs.standard, whose pcs() presumably
# returns a NumPy array. TODO confirm.
pcs.sum('mode').plot()

In [None]:
# Collapse the (x, y) grid into a single stacked dimension 'z'
flxn_stack = flxn.stack(z=("x", "y"))

In [None]:
def get_time_space(df, time_dim, lumped_space_dims):
    """Pivot a long-format frame into a time-by-space table.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data with one column per dimension, e.g. the output of
        ``DataArray.to_dataframe().reset_index()``.
    time_dim : str
        Column that becomes the row index.
    lumped_space_dims : list of str
        Columns moved into the column index (a MultiIndex together with the
        remaining value column(s)).

    Returns
    -------
    pandas.DataFrame
        One row per ``time_dim`` value, one column per value-column/space combination.
    """
    index_columns = [time_dim] + lumped_space_dims
    long_indexed = df.set_index(index_columns)
    wide = long_indexed.unstack(lumped_space_dims)
    return wide

In [None]:
df = flxn.to_dataframe().reset_index() 

In [None]:
df_data = get_time_space(df, time_dim = "time", lumped_space_dims = ["x","y"])

In [None]:
data_solver = Eof(data.timeMonthly_avg_landIceFreshwaterFlux.to_numpy())

In [None]:
data.timeMonthly_avg_landIceFreshwaterFlux.mean('time').plot()

In [None]:
ais_pcs = data_solver.pcs()

In [None]:
ais_eofs = data_solver.eofs()

In [None]:
ais_eigenvalues = data_solver.eigenvalues()

In [None]:
plt.figure(figsize=(25,8))
plt.plot(ais_pcs)

In [None]:
plt.figure(figsize=(25,8))
plt.plot(ais_pcs[:,4:6])

In [None]:
plt.figure(figsize=(10,10))
#plt.pcolormesh(ais_eofs[1000],vmin = 0, vmax = 0.05)
plt.pcolormesh(ais_eofs[30])
plt.colorbar()

In [None]:
ais_eofs.shape

In [None]:
# Catchment-specific analysis
flxn_np = flxn.to_numpy()

In [None]:
npsolver = Eof(flxn_np)

In [None]:
pcs = npsolver.pcs()

In [None]:
pcs = npsolver.pcs(npcs=5, pcscaling=1)
plt.plot(pcs[:,:])

In [None]:
eofs = npsolver.eofs()

In [None]:
plt.pcolormesh(eofs[600])
plt.colorbar()

In [None]:
eof1 = npsolver.eofs(neofs=1, eofscaling=1)

In [None]:
plt.pcolormesh(eofs[200])
plt.colorbar()

In [None]:
# FIXME: `data = ` was an incomplete assignment (a SyntaxError that stops
# "Run All") and would also have overwritten the `data` Dataset loaded above.
# Disabled until the intended right-hand side is known.
# data =

In [None]:
solver = Eof(flxn_stack)

In [None]:
pcs = solver.pcs()

In [None]:
pcs = solver.pcs(npcs=10, pcscaling=0)

In [None]:
eigenvalues = solver.eigenvalues()

In [None]:
eofs = solver.eofs()

In [None]:
plt.figure(figsize=(25,8))
plt.plot(eigenvalues[:100], '*')

In [None]:
flx_rgrs.intercept.plot()

In [None]:
flx_rgrs = xr_linregress(hn, flxn.mean('time'), dim='time')

In [None]:
flx_prd = flx_rgrs.intercept + flx_rgrs.slope*hn.mean('time')
flx_ddrft = flxn - flx_prd

In [None]:
####################
#### plot tests
####################

# plt.figure(figsize=(15,5))
#flxn.sum(['x','y']).plot()
#flx_dedraft.sum(['x','y']).plot()

#plt.subplot(projection=ccrs.SouthPolarStereo())
#flx_rgrs.r_value.plot()

In [None]:
# Split the basin flux by sign; all other values become NaN
flpos = flxn.where(flxn > 0) # Values where flux is positive, i.e., into the ocean
flneg = flxn.where(flxn < 0) # Values where flux is negative, i.e., out of the ocean

fig, axs = plt.subplots(2, figsize=(15,10))
ax_neg, ax_pos = axs
flneg.plot(ax=ax_neg)   # top panel: negative values
flpos.plot(ax=ax_pos)   # bottom panel: positive values

# TODO: What do the negative flux values, while minimal, signify?

In [None]:
# Spatial scaling factor, done after dedrafting
# average over time for each pixel
# remove this (scale this in comparison to total average over catchment)

flx_scl_fctr = flx_ddrft.mean('time')   # pixel-by-pixel time mean
flx_tsm = flx_scl_fctr.mean(['y','x']) # spatial mean over time, i.e., "spatial mean flux", SCALAR throughout
#flx_scl = flx_tsm - flx_scl_fctr
#flxn_scl = flxn_scl.transpose('time','y','x')

In [None]:
hn[100].plot()

In [None]:
# scl = flx_scl_fctr/flx_scl
scl = flx_scl_fctr/flx_tsm
scl = scl.transpose('time','y','x') # check dimension ordering before running this

In [None]:
flx_scl = flx_tsm - flx_scl_fctr
#flx_scl = flx_scl.transpose('time','y','x')

In [None]:
flx_scl.plot()

In [None]:
#fig, axs = plt.subplots(2,3)
flxn.plot()
flx_ddrft.plot()
#flx_scl.plot()

In [None]:
# Remove climatologies to isolate anomalies / deseasonalize 
flx_monthly = flx_ddrft.groupby("time.month") # flxn or flx_scl? How is spatial scaling incorporated?
flx_clm = flx_monthly.mean("time") # Climatologies
flx_anm = flx_monthly - flx_clm # Deseasonalized anomalies

# Integrate over entire AIS / basin for time series
flx_clm_ts = flx_clm.sum(['y','x']) # Seasonality / Climatology?
flx_anm_ts = flx_anm.sum(['y','x'], skipna=True)

In [None]:
diff = flx_ddrft.sum(['y','x']) - flx_anm_ts

In [None]:
flx_anm_ts.plot()
diff.plot()

In [None]:
eps = flxn.mean('time')

In [None]:
flx_rebuild_ts = flx_anm_ts

In [None]:
flx_rebuild = flx_rebuild_ts*flx_scl_fctr
flx_rebuild.sum(['y','x']).plot()

In [None]:
# Phase randomization of the basin-integrated flux time series.
# Each realization keeps the Fourier amplitudes of `fl` but draws new random phases.

# Plot phase randomized data
plt.figure(figsize=(25, 8), dpi=80)

spinuptime = 0 # Ignore first few years of data in the phase randomization
n_realizations = 50 # Number of random Fourier realizations
seed = 0 # Fixed seed: rows of new_fl (e.g. new_fl[25]) are reused by later cells
rng = np.random.default_rng(seed)

# Time limits for plotting
t1 = 60
tf = 800

# Basin-integrated series and its spectrum (loop-invariant, so computed once)
fl = flxn.sum(['x','y'])[spinuptime:]
n_time = len(fl)
fl_fourier = np.fft.rfft(fl.values)
new_fl = np.empty((n_realizations, n_time))

for i in range(n_realizations):
    random_phases = np.exp(rng.uniform(0, 2*np.pi, fl_fourier.size)*1.0j)
    # Keep the zero-frequency term: randomizing it rescales the (large, positive)
    # mean of the integrated flux by cos(phase)
    random_phases[0] = 1.0
    # For even lengths the Nyquist term must stay real to keep its amplitude
    if n_time % 2 == 0:
        random_phases[-1] = 1.0
    # n=n_time: by default irfft returns 2*(m-1) samples, one short for odd-length input
    new_fl[i,:] = np.fft.irfft(fl_fourier*random_phases, n=n_time)
    plt.plot(new_fl[i,t1:tf],'b', linewidth=0.15)

plt.plot(new_fl[45,t1:tf],'b', linewidth=1, label='Randomized Output')
plt.plot(new_fl[10,t1:tf],'b', linewidth=1)
plt.plot(new_fl[40,t1:tf],'b', linewidth=1)
plt.plot(fl[t1:tf],'k', linewidth=3, label='MPAS Output')
plt.title('Deseasonalized & Detrended Flux (Years: {:.1f} - {:.1f}): {}'.format((spinuptime+t1)/12,(spinuptime+tf)/12,basinName))
plt.ylabel('landIceFreshwaterFlux')
plt.legend()

In [None]:
orig_ts = fl
(orig_ts+eps).plot()
flxn.plot()
#eps.plot()

In [None]:
flx_recon = orig_ts+eps

In [None]:
flxn.plot()

In [None]:
orig_ts = fl
#orig_fl = (orig_ts+flx_scl_fctr)/788 # + flx_clm.unstack()

# Add seasonality back
# orig_fl = 0.000028+(orig_ts+diff)*flx_scl_fctr/(flx_tsm*788)/100 # factor of 100?
orig_fl = (orig_ts+diff)*flx_scl_fctr/(flx_tsm*788) # factor of 100?
plt.figure(figsize=(25, 5), dpi=80)
orig_fl.sum(['y','x']).plot()
# 788 is the number of actual pixels with moving data points in time
# stack dataarray to get 788

plt.figure(figsize=(25,5))
orig_fl.plot(label='reconstructed')
#flx_ddrft.plot(label='dedrafted')
flx_anm.plot(label='original')
plt.legend()

In [None]:
fig, axes = plt.subplots(nrows=2, figsize=(10,8))
flxn.mean('time').plot(ax=axes[0])
orig_fl.mean('time').plot(ax=axes[1])

In [None]:
np.min(flxn)

In [None]:
new_ts = xr.DataArray(data=new_fl[25],dims="time",coords=orig_ts.coords)
fl_new = (new_ts+diff)*flx_scl_fctr/flx_tsm/788

In [None]:
plt.figure(figsize=(25, 8), dpi=80)
plt.plot(orig_fl.sum(['y','x']), 'k', linewidth=2, label='MPAS Output')
plt.plot(fl_new.sum(['y','x']), 'b', linewidth=0.5, label='Randomized Output')
plt.legend()

In [None]:
plt.figure(figsize=(18,5))
orig_fl.plot(label='reconstructed')
#flx_ddrft.plot(label='dedrafted')
flx_anm.plot(label='anomalies')
plt.legend()

In [None]:
plt.figure(figsize=(25,5))
fl_new.plot(label='new')
orig_fl.plot(label='original - reconstructed')
#flx_ddrft.plot(label='dedrafted')
flx_anm.plot(label='actual')
plt.legend()

In [None]:
fl_new.plot()