In [1]:
import glob
import time

import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shapereader
import dask.array as da
import geopandas as gpd
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import rioxarray
import xarray as xr
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy import stats

import preseason.onset_demise as od
import preseason.plotting as pp
import preseason.tools as sf

In [None]:
# Re-import preseason.onset_demise so that edits to the module are picked up
# without restarting the kernel.
from importlib import reload

reload(od)

In [12]:
def _detrend(data, slope, intercept):
    time_coord = np.arange(len(data))

    # Calculate the trend using broadcasting
    trend = slope * time_coord + intercept
    detrend = data - trend 
    return detrend

def detrend(data, slope, intercept, dim='time'):
    """Subtract a per-point linear trend from ``data`` along ``dim``.

    ``_detrend`` is applied to every 1-D series along ``dim`` via
    ``xr.apply_ufunc``; ``slope`` and ``intercept`` are broadcast against the
    remaining dimensions of ``data``.

    :param data: array containing the dimension ``dim``.
    :param slope: trend slope per sample step, with no ``dim`` dimension.
    :param intercept: trend value at the first sample, with no ``dim`` dimension.
    :param dim: name of the dimension to detrend along (defaults to ``'time'``).
    :return: detrended array; ``dim`` is a core dimension of the output.
    """
    # No 'output_sizes' entry is passed: the output core dimension already
    # exists on the input, so its length is known. The previous entry was keyed
    # on 'year', which is not an output core dimension, and used len(data),
    # i.e. the length of the first dimension rather than of ``dim``.
    return xr.apply_ufunc(
        _detrend,
        data,
        slope,
        intercept,
        input_core_dims=[[dim], [], []],
        output_core_dims=[[dim]],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[float],
    )

In [4]:
def create_composites(ds, dates, period):
    """Average ``ds`` over the window preceding each date, then across dates.

    For every entry of ``dates`` the window runs from ``period`` weeks before
    the date up to one day before it. Each window is averaged over ``time``,
    and the per-date means are then averaged together.

    :param ds: xarray object with a ``time`` dimension.
    :param dates: iterable of dates marking the end of each window.
    :param period: window length in weeks.
    :return: composite mean over all windows (no ``time`` dimension).
    :raises ValueError: if ``dates`` is empty.
    """
    # Keyword form states the units explicitly and avoids unit-string aliases.
    window = pd.Timedelta(weeks=period)
    one_day = pd.Timedelta(days=1)

    composite_data = [
        ds.sel(time=slice(date - window, date - one_day)).mean(dim='time')
        for date in dates
    ]

    # Fail with a clear message instead of the opaque error xr.concat raises
    # when given an empty list.
    if not composite_data:
        raise ValueError("create_composites: 'dates' is empty, nothing to composite")

    # Number of events that went into the composite.
    print(len(composite_data))
    return xr.concat(composite_data, dim='time').mean(dim='time')

In [11]:

def _linear_regression(y):
    x = np.arange(len(y))
    mask = ~np.isnan(y)
    if np.sum(mask) > 1:  # Ensure we have at least two non-NaN values
        slope, intercept, r_value, p_value, std_err = stats.linregress(x[mask], y[mask])
        return np.array([slope, intercept, r_value, p_value])
    else:
        return np.array([np.nan, np.nan, np.nan, np.nan])

# Apply the linear regression function to the data
def linear_regression(sst_data: xr.DataArray, dim: str = 'time'):
    """Fit a linear trend along ``dim`` at every point of ``sst_data``.

    ``_linear_regression`` is applied to each 1-D series along ``dim``. The
    result replaces ``dim`` with a ``params`` dimension of length 4, labelled
    ``'slope'``, ``'intercept'``, ``'r_value'`` and ``'p_value'``.

    For dask-backed input the computation is deferred
    (``dask='parallelized'``); call ``.compute()`` on the result to evaluate it.

    :param sst_data: array containing the dimension ``dim``.
    :param dim: name of the dimension to regress along (defaults to ``'time'``).
    :return: array of regression parameters with a ``params`` dimension.
    """
    result = xr.apply_ufunc(
        _linear_regression,
        sst_data,
        input_core_dims=[[dim]],
        output_core_dims=[['params']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[float],
        # 'params' is a new dimension, so dask must be told its length.
        dask_gufunc_kwargs={'output_sizes': {'params': 4}}
    )

    # Add parameter names
    result = result.assign_coords(params=['slope','intercept', 'r_value', 'p_value'])

    return result

In [4]:
### Allows us to use dask to speed up some calculations ###
# Start a local dask cluster with 4 workers (4 threads each, memory_limit='8GB')
# and connect a client to it.
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(n_workers=4, memory_limit='8GB', threads_per_worker=4)
client = Client(cluster)

In [None]:
# Display the dask client object as the cell output.
client

In [3]:
### Selecting out Peru from the global data.
# Grid point used later for the single-location onset/demise analysis.
center_lat = -10
center_lon = 285

# NOTE(review): lon_radius is not referenced anywhere else in the visible notebook.
lon_radius = 50

# Latitude/longitude box used to subset precipitation. The latitude slice runs
# from -5 down to -10, which presumably matches a descending (north-to-south)
# latitude axis — TODO confirm.
PERU_center = {'lat': slice(-5, -10), 'lon': slice(280, 285)}




In [64]:
# Open every SST file matching 'sst.[12]*' in the directory as one lazy dataset.
sst_data_dir = '/data/deluge/reanalysis/REANALYSIS/ERA5/2D/4xdaily/sst/' 
sst_files = glob.glob(sst_data_dir+'sst.[12]*')
ds_sst = xr.open_mfdataset(sst_files,parallel=True, chunks={'time' : -1})

sst_data = ds_sst['sst']

# Average to daily means, then rechunk so each chunk holds the full time axis
# for a 75x75 latitude/longitude tile.
sst_data = sst_data.resample(time="D").mean(dim='time').chunk({'time' : -1, 'latitude':75, 'longitude':75})

In [None]:
# Display the daily-mean SST array.
sst_data


In [None]:
# Fit a linear trend along time at every grid point and load the result into memory.
trends = linear_regression(sst_data).compute()



In [13]:
# Repackage the slope and intercept fields of `trends` as separate variables
# of a Dataset on the same latitude/longitude grid.
sst_trend = xr.Dataset(
    data_vars={
        param: (("latitude", "longitude"), trends.sel(params=param).data)
        for param in ("slope", "intercept")
    },
    coords={
        "latitude": trends["latitude"].data,
        "longitude": trends["longitude"].data,
    },
)

In [14]:
# Subtract the fitted linear trend from the daily SST at each grid point.
detrended_sst = detrend(sst_data, sst_trend['slope'], sst_trend['intercept'])

In [None]:
# Display the detrended SST array.
detrended_sst

In [11]:
### Onset Demise for Precipitation ###

precip_data_dir = '/data/deluge/reanalysis/REANALYSIS/ERA5/2D/daily/precip/'

precip_files = glob.glob(precip_data_dir+'precip.[12]*')


# Open all matched precipitation files as one lazy dataset.
ds_p = xr.open_mfdataset(precip_files, parallel=True, chunks={'time': -1})

# Restrict precipitation to the PERU_center latitude/longitude box.
precip_data = ds_p['precip'].sel(latitude = PERU_center['lat'], longitude = PERU_center['lon'])

# Anomaly relative to the mean over the whole record at each grid point.
precip_anom = precip_data - precip_data.mean(dim='time')

# Keep the full time axis in a single chunk.
precip_anom = precip_anom.chunk(chunks={'time':-1})

In [13]:
# Annual cycle of precipitation, computed eagerly; it feeds the analysis-start
# calculation below (helper semantics live in preseason.tools — not visible here).
p_annual_cycle = sf.calc_annual_cycle(precip_data).compute()

analysis_start = od.B17_analysis_start(p_annual_cycle)

# Persist analysis_start so the onset and demise computations can both reuse it.
analysis_start = analysis_start.persist()

In [None]:
# Compute onset and demise from the precipitation anomalies and analysis start
# (helpers from preseason.onset_demise), loading both results into memory.
onset = od.onset_B17(precip_anom, analysis_start).compute()

demise = od.demise_B17(precip_anom, analysis_start).compute()

In [19]:
#onset.to_netcdf('onset_era5_peru.nc')

In [20]:
#demise.to_netcdf('demise_era5_peru.nc')

In [None]:
# Display the detrended SST array again.
detrended_sst

In [16]:
# Load onset/demise from netCDF files in the working directory; this replaces
# any values computed earlier in the session.
# NOTE(review): the files are presumably written by the commented-out
# to_netcdf cells — confirm they exist before a fresh run.
onset = xr.open_dataarray('onset_era5_peru.nc')
demise = xr.open_dataarray('demise_era5_peru.nc')

In [17]:
# Onset and demise series at the single grid point (center_lat, center_lon).
# Exact-match selection: the grid must contain these coordinate values.
center_loc_onset = onset.sel(latitude=center_lat, longitude=center_lon)
center_loc_demise = demise.sel(latitude=center_lat, longitude=center_lon)

In [None]:
# Slope of the centre-point onset series multiplied by 10
# (presumably a per-decade trend if there is one value per year — confirm).
_linear_regression(center_loc_onset.values)[0]*10

In [37]:
# Regression parameters for the onset field.
# NOTE(review): linear_regression regresses along a 'time' dimension — confirm
# that `onset` actually carries a dimension with that name.
test = linear_regression(onset)

In [71]:
# Detrend onset everywhere, then keep only the centre grid point.
# NOTE(review): slope/intercept are passed as bare numpy arrays (.values), so
# alignment with `onset` is positional rather than by coordinate, and detrend()
# expects a 'time' core dimension — confirm against onset's actual dims.
detrend_onset = detrend(onset, test.sel(params='slope').values, test.sel(params='intercept').values).sel(latitude=center_lat, longitude=center_lon)

In [72]:
# Keep detrended onset values below the 10th percentile (early) or above the
# 90th percentile (late); all other entries are masked by .where().
early_onset = detrend_onset.where(detrend_onset < detrend_onset.quantile(0.10))
late_onset = detrend_onset.where(detrend_onset > detrend_onset.quantile(0.90))

In [18]:
# Same percentile split, applied to the raw (non-detrended) centre-point onset.
# NOTE(review): this reassigns early_onset/late_onset, so when cells run top to
# bottom it replaces the detrended versions defined in the preceding cell —
# confirm which definition is intended.
early_onset = center_loc_onset.where(center_loc_onset < center_loc_onset.quantile(0.10))
late_onset = center_loc_onset.where(center_loc_onset > center_loc_onset.quantile(0.90))

In [19]:
# Keep centre-point demise values below the 10th percentile (early) or above
# the 90th percentile (late); all other entries are masked by .where().
early_demise = center_loc_demise.where(center_loc_demise < center_loc_demise.quantile(0.10))
late_demise = center_loc_demise.where(center_loc_demise > center_loc_demise.quantile(0.90))

In [73]:
# Pass the masked onset series through sf.calcDates (presumably converting
# them to calendar dates — confirm) and drop the years that were masked out.
early_onset_dates = sf.calcDates(early_onset).dropna(dim='year')
late_onset_dates = sf.calcDates(late_onset).dropna(dim='year')

In [21]:
# Pass the masked demise series through sf.calcDates (presumably converting
# them to calendar dates — confirm) and drop the years that were masked out.
early_demise_dates = sf.calcDates(early_demise).dropna(dim='year')
late_demise_dates = sf.calcDates(late_demise).dropna(dim='year')

In [None]:
# Mean detrended SST over the 4 weeks preceding each early / late onset date.
composites_early_onset = create_composites(detrended_sst, early_onset_dates, period=4)
composites_late_onset = create_composites(detrended_sst, late_onset_dates, period=4)

In [None]:
# Quick-look plot of the early-onset SST composite.
composites_early_onset.plot()

In [None]:
# Map of the SST trend slope multiplied by 10 (slope is per time step of the
# daily series).
(trends.sel(params='slope')*10).plot()

In [30]:
# Mean SST over the 1 week preceding each early / late demise date.
# NOTE(review): unlike the onset composites, these use sst_data (not
# detrended_sst) and period=1 rather than 4 — confirm this is intentional.
composites_early_demise = create_composites(sst_data, early_demise_dates, period=1)
composites_late_demise = create_composites(sst_data, late_demise_dates, period=1)


In [31]:
### ENSO Indexes Analysis ###

# Read the whitespace-separated ONI table: one row per year (used as the
# index) and one column per month. Row 0 and rows 74-84 of the file are
# skipped. sep=r'\s+' replaces the deprecated delim_whitespace=True.
enso_data = pd.read_csv(
    '~/data/enso_index/oni.data',
    sep=r'\s+',
    index_col=0,
    skiprows=[0, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84],
    names=['Year', 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
           'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'],
)

In [None]:
### Turns 2D Year x Month representation to 1D time series ###

# 'Year' is the index of enso_data (index_col=0), so it must be moved into a
# column before melting; with melt(ignore_index=False) it stays in the index
# and melted['Year'] raises KeyError.
melted = enso_data.reset_index().melt(id_vars='Year', var_name='Month', value_name='ONI')

# The month columns are in calendar order, so a column's position is its month number.
month_number = {name: number for number, name in enumerate(enso_data.columns, start=1)}

# Build the first-of-month timestamp for every (Year, Month) pair.
melted['time'] = pd.to_datetime(
    pd.DataFrame({
        'year': melted['Year'],
        'month': melted['Month'].map(month_number),
        'day': 1,
    })
)

# Index by date, drop the helper columns, and sort chronologically.
melted = melted.set_index('time').drop(columns=['Year', 'Month']).sort_index()

# Get the data that overlaps with precipitation
ONI_index = melted['1950-01-01':'2020-01-01'].to_xarray()
print(melted.head())



In [79]:
# July-September mean ONI for each year, ending as a plain numpy array.
JAS_oni = ONI_index.sel(time=ONI_index.time.dt.month.isin([7, 8, 9]))
JAS_oni= JAS_oni.groupby('time.year').mean()
JAS_oni = JAS_oni['ONI'].values

In [46]:

# First 70 entries of the centre-point onset/demise series
# (presumably to match the length of JAS_oni — confirm).
peru_onset = center_loc_onset[0:70]
peru_demise = center_loc_demise[0:70] 

In [106]:
# Season length from the paired onset/demise series (helper in preseason.tools),
# extracted as a numpy array.
season_length = sf.calcSeasonLength(peru_onset, peru_demise).values

In [None]:
# Onset against JAS ONI: blue where ONI < -0.5, orange where ONI > 0.5.
plt.scatter(JAS_oni[JAS_oni<-.5], peru_onset[JAS_oni<-.5], color='blue')
plt.scatter(JAS_oni[JAS_oni>.5], peru_onset[JAS_oni>.5], color='orange')

In [None]:
# Demise against JAS ONI: blue where ONI < -0.5, orange where ONI > 0.5.
plt.scatter(JAS_oni[JAS_oni<-.5], peru_demise[JAS_oni<-.5], color='blue')
plt.scatter(JAS_oni[JAS_oni>.5], peru_demise[JAS_oni>.5], color='orange')

In [None]:
# NOTE(review): `series` is not defined anywhere in the visible notebook, so
# this cell raises NameError on a fresh kernel — confirm its source or remove.
# Plot points > 0 in blue
plt.scatter(series.index[series > 0], series[series > 0], color='blue', label='> 0')

# Plot points <= 0 in red
plt.scatter(series.index[series <= 0], series[series <= 0], color='red', label='<= 0')



In [None]:
# 12-week detrended-SST composite before early onset dates, loaded as a numpy array.
create_composites(detrended_sst, early_onset_dates, period=12).values

In [None]:
# NOTE(review): `test` is assigned from linear_regression(onset) in the visible
# notebook, whose output has a 'params' dimension in place of 'time', while
# create_composites selects along 'time' — confirm the intended input.
create_composites(test, early_demise_dates, period=12).values

In [None]:
# NOTE(review): `test` is assigned from linear_regression(onset) in the visible
# notebook, whose output has a 'params' dimension in place of 'time', while
# create_composites selects along 'time' — confirm the intended input.
create_composites(test, late_onset_dates, period=12).values

In [None]:
# NOTE(review): `test` is assigned from linear_regression(onset) in the visible
# notebook, whose output has a 'params' dimension in place of 'time', while
# create_composites selects along 'time' — confirm the intended input.
create_composites(test, late_demise_dates, period=12).values

In [None]:
# Plot of the centre-point demise series.
center_loc_demise.plot()


In [None]:
# Plot of the centre-point onset series.
center_loc_onset.plot()

In [None]:
# Line plot of season length against array position.
plt.plot(season_length)


In [28]:
# Difference between the early-onset and late-onset SST composites.
comp_diff = composites_early_onset - composites_late_onset

In [None]:
# Map of the early-minus-late onset composite difference.
# NOTE(review): the title says "1-week", but the onset composites in the
# visible notebook are built with period=4 — confirm the label.
pp.plot_spatial_data(comp_diff, vmax=25, vmin=-25, cmap='RdBu_r', title='Early Onset - Late Onset 1-week Composite SST')

In [None]:
# Histogram of the centre-point onset values.
center_loc_onset.plot.hist()

In [None]:
# NOTE(review): `comp_diff_demise` is never assigned in the visible notebook,
# so this cell raises NameError on a fresh kernel — presumably it should be
# composites_early_demise - composites_late_demise; confirm.
pp.plot_spatial_data(comp_diff_demise, vmax=25, vmin=-25, cmap='RdBu_r', title='Early Demise - Late Demise 1-week Composite SST')

In [5]:
# Repeats the precipitation file listing from the onset/demise section above
# (same directory and glob pattern).
precip_data_dir = '/data/deluge/reanalysis/REANALYSIS/ERA5/2D/daily/precip/'

precip_files = glob.glob(precip_data_dir+'precip.[12]*')

In [None]:
# Show the first path returned by glob (glob does not guarantee any ordering).
precip_files[0]

In [None]:
# NOTE(review): `peru` is not assigned anywhere in the visible notebook (only a
# commented-out gpd.read_file appears later) — confirm where it comes from.
peru.geometry[3]

In [101]:
# Wrap the single geometry in a GeoSeries; it is passed to rio.clip below.
test2 = gpd.GeoSeries(peru.geometry[3])

In [None]:
# NOTE(review): `mapping` is imported but not used in the visible notebook.
from shapely.geometry import mapping

# NOTE(review): `ds` is not assigned anywhere in the visible notebook — confirm
# which dataset is meant to be clipped.
#test = xr.open_dataset(precip_files[0])['precip']
# Tell rioxarray which dimensions are spatial and record the CRS (EPSG:4326).
ds.rio.set_spatial_dims(x_dim="longitude", y_dim="latitude", inplace=True)
ds.rio.write_crs("epsg:4326", inplace=True)
#peru = gpd.read_file(fname, crs="epsg:4326")#


# Clip the dataset to the geometry in test2, dropping data outside it.
clipped = ds.rio.clip(test2, ds.rio.crs, drop=True)

In [None]:
# Plot the clipped field at time index 10.
clipped.isel(time=10).plot()

In [18]:
# NOTE(review): `fname` is not assigned anywhere in the visible notebook —
# confirm the shapefile path.
shapefile = gpd.read_file(fname)

In [None]:
# Display the first geometry in the shapefile.
shapefile.geometry[0]

In [29]:
# Keep only the rows whose HYBAS_ID, read as a string, starts with '6'.
# This overwrites `shapefile`, so re-running the cell filters the already
# filtered frame.
shapefile = shapefile[shapefile['HYBAS_ID'].astype(str).str.startswith('6')]

In [None]:
# Removed the incomplete expression `rioxarray.` (apparently a leftover from
# tab-completion): as written it is a SyntaxError and stops Restart & Run All.

In [54]:

# Path to the Natural Earth 10m cultural 'admin_1_states_provinces' shapefile.
data = cartopy.io.shapereader.natural_earth(
    resolution='10m', category='cultural', 
    name='admin_1_states_provinces',
)
reader = cartopy.io.shapereader.Reader(data)

# Keep the records whose 'admin' attribute is "Peru" and wrap their geometries
# as a cartopy feature in PlateCarree coordinates.
states = [x for x in reader.records() if x.attributes["admin"] == "Peru"]
states_geom = cfeature.ShapelyFeature([x.geometry for x in states], ccrs.PlateCarree())

In [None]:
# Map of the `Peru` shapefile subset with coastlines, borders and state lines.
# NOTE(review): `Peru` is assigned in a later cell of the visible notebook, so
# this cell fails when the notebook runs top to bottom on a fresh kernel.
projection=ccrs.PlateCarree()
fig, ax = plt.subplots(1, 1, figsize=(16, 9), dpi=600,  subplot_kw={'projection': projection})

Peru.plot(ax=ax, column='HYBAS_ID', facecolor='blue', edgecolor='black', alpha=.5)
plt.xlim([-90,-25])
plt.ylim([-50,5])

# NOTE(review): states_provinces is created but not added to the axes in this cell.
states_provinces = cfeature.NaturalEarthFeature(
    category='cultural',
    name='admin_1_states_provinces_lines',
    scale='10m')
### Adding coastlines ###
ax.coastlines(edgecolor='black', linewidth=2)
ax.add_feature(cartopy.feature.BORDERS, edgecolor='black', linewidth=2)
ax.add_feature(cfeature.STATES, edgecolor='black', linewidth=2)


In [36]:
# Spatial subset of the shapefile: x (longitude) from -80 to -60 and
# y (latitude) from -20 to -10, via the coordinate-based .cx indexer.
Peru = shapefile.cx[-80:-60,-20:-10]

In [None]:
def plot_spatial_data(dataarray, projection=ccrs.PlateCarree(), cmap ='twilight', vmax = 365, vmin = 1, title='Spatial Data Plot', var='data_to_plot', xticks=(-80, -75, -70), yticks=(-5, -10, -15)):
    """
    Plots a spatial figure of a variable from an xarray DataArray.

    Requires ``matplotlib.colors`` as ``mcolors`` and ``matplotlib.ticker`` as
    ``mticker`` to be imported at module level.

    :param dataarray: xarray DataArray containing the geospatial data to be plotted.
    :param projection: Cartopy CRS projection. Defaults to PlateCarree.
    :param cmap: Colormap passed to the plot call.
    :param vmax: Upper limit of the linear colour normalisation.
    :param vmin: Lower limit of the linear colour normalisation.
    :param title: Title of the plot.
    :param var: Label shown on the colorbar.
    :param xticks: Longitudes (degrees east, -180..180) of the gridlines.
        Defaults to the previously hard-coded values.
    :param yticks: Latitudes of the gridlines. Defaults to the previously
        hard-coded values.
    :return: None. The figure is shown with ``plt.show()``.
    """
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    fig, ax = plt.subplots(1, 1, figsize=(16, 9), dpi=600,  subplot_kw={'projection': projection})
    # The data are given in latitude/longitude, hence the PlateCarree transform.
    # The colorbar is added manually below so it can sit in its own axes.
    p = dataarray.plot(ax=ax,transform=ccrs.PlateCarree(), add_colorbar=False, cmap=cmap,alpha = 0.8, norm=norm)

    ### Map features ###
    ax.coastlines(edgecolor='black', linewidth=2)
    ax.add_feature(cartopy.feature.BORDERS, edgecolor='black', linewidth=2)
    ax.add_feature(cfeature.STATES, edgecolor='black', linewidth=2)
    ax.add_feature(cfeature.LAKES, alpha=0.5, edgecolor='blue')
    ax.add_feature(cfeature.RIVERS, color='blue')

    # Colorbar in a plain matplotlib axes attached to the right of the map.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1, axes_class=plt.Axes)
    plt.colorbar(p, cax=cax, label=var)

    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    gl = ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False, linewidth=2, color='black', alpha=0.5, linestyle='--')
    gl.xlocator = mticker.FixedLocator(list(xticks))
    gl.ylocator = mticker.FixedLocator(list(yticks))
    # Label only the left and bottom edges.
    gl.left_labels = True
    gl.right_labels = False
    gl.top_labels = False
    gl.bottom_labels = True

    # Add a title
    ax.set_title(title, loc='center')

    # Show the plot
    plt.show()