In [1]:
import importlib
import warnings
from pathlib import Path
from unittest import TestCase

import pytest
import pandas as pd
import numpy as np
import xarray as xr
import geopandas as gpd
from geopandas import testing as gpdt
from shapely.geometry import Polygon
import xesmf as xe

from xagg_new.core import (process_weights,create_raster_polygons,get_pixel_overlaps,aggregate)
from xagg_new.aux import (normalize,fix_ds,get_bnds,subset_find)
from xagg_new.wrappers import *
from xagg_new.classes import (weightmap,aggregated)


In [2]:
# Directory holding the sample inputs. Built from the current user's home
# directory instead of a hard-coded /Users/<name>/... path, so this is the
# only line to edit when the files live somewhere else.
DATA_DIR = Path('~/Downloads').expanduser()

# Gridded input: NetCDF file from which the 'sst' variable is read below
ds = xr.open_dataset(DATA_DIR / 'era5-sst.nc')
# Polygons to aggregate onto
gdf = gpd.read_file(DATA_DIR / 'xagg-sst' / 'state1deg.shp')

In [3]:
ds2 = fix_ds(ds)

In [4]:
wm = pixel_overlaps(ds2, gdf)

creating polygons for each pixel...
lat/lon bounds not found in dataset; they will be created.
calculating overlaps between pixels and output polygons...
success!


In [5]:
# Aggregate the gridded data in ds2 onto the polygons described by wm.
# NOTE(review): this rebinds `aggregated`, which the import cell bound to a
# name from xagg_new.classes; after this line that imported name is no
# longer reachable as `aggregated`. Later cells call .to_dataset() and
# .to_shp() on this result, so the variable name is kept as-is here.
aggregated = aggregate(ds2,wm)

aggregating sst...
all variables aggregated to polygons!


In [6]:
ds_out = aggregated.to_dataset()

In [16]:
# Name of the variable inspected and aggregated in the cells below
var = 'sst'
# Display the variable (the original line ended in a dangling '.', which is
# a SyntaxError and stops a top-to-bottom run of the notebook)
ds2[var]

In [13]:
np.isnan(ds_out.sst).any([])

In [None]:
ds_out.sst.isel(time=0)

In [None]:
ds.isel(time=0).sst.plot()

In [None]:
pix_agg = create_raster_polygons(ds)

In [None]:
# Work on a copy: the next cell adds a 'poly_idx' column to gdf_in, and a
# plain alias would write that column into the originally loaded gdf too.
gdf_in = gdf.copy()

In [None]:
# Store each polygon's index as an explicit 'poly_idx' column (only if the
# column is not there yet) so polygons can be referred to by id later on
if 'poly_idx' not in gdf_in.columns:
    gdf_in['poly_idx'] = gdf_in.index.values

# Bring the pixel polygons into the same CRS as the input polygons
pix_agg['gdf_pixels'] = pix_agg['gdf_pixels'].to_crs(gdf_in.crs)

# Pick the EASE grid (https://nsidc.org/data/ease) in which the
# pixel/polygon overlaps will be computed, based on the min/max lat of the
# polygons:
#   both > 0 -> North grid, both < 0 -> South grid,
#   otherwise -> global/temperate grid.
# Plain 'EPSG:xxxx' strings are used because the {'init': ...} form is
# deprecated in geopandas.
epsg_set = (
    'EPSG:6931' if np.all(gdf_in.total_bounds[[1,3]]>0)
    else 'EPSG:6932' if np.all(gdf_in.total_bounds[[1,3]]<0)
    else 'EPSG:6933'
)

In [None]:
# Reproject the polygons first and the pixel polygons second into the
# chosen grid, then intersect them
overlaps = gpd.overlay(
    *(frame.to_crs(epsg_set) for frame in (gdf_in, pix_agg['gdf_pixels'])),
    how='intersection',
)

# One group of overlap pieces per polygon of the shapefile
ov_groups = overlaps.groupby('poly_idx')


In [None]:
# Collapse every polygon's overlap pieces into one row:
#   rel_area - normalized areas of the pieces, wrapped in a one-element list
#   pix_idxs - the 'pix_idx' values of the pieces, as a list
#   lat, lon - the 'lat' / 'lon' values of the pieces, as lists
overlap_info = ov_groups.agg(
    rel_area=pd.NamedAgg(column='geometry',
                         aggfunc=lambda geoms: [normalize(geoms.area)]),
    pix_idxs=pd.NamedAgg(column='pix_idx', aggfunc=lambda col: list(col)),
    lat=pd.NamedAgg(column='lat', aggfunc=lambda col: list(col)),
    lon=pd.NamedAgg(column='lon', aggfunc=lambda col: list(col)),
)
   

In [None]:
overlap_info = (
    overlap_info
    # Pair each row's lat and lon lists into a list of (lat, lon) tuples
    .assign(coords=lambda frame: frame.apply(
        lambda row: list(zip(row['lat'],row['lon'])),axis=1))
    # The separate columns are redundant once 'coords' exists
    .drop(columns=['lat','lon'])
    # Turn the poly_idx index back into a column for merging with gdf_in
    .reset_index()
)


In [None]:
# Attach the per-polygon overlap info to the polygons. An outer merge keeps
# rows that appear on only one side.
gdf_in = pd.merge(left=gdf_in, right=overlap_info, how='outer')

In [None]:
# Build the weightmap from the merged table (without its geometry column),
# the source grid of the pixel polygons, and the polygon geometries.
# Both wm and wm_out refer to the same object.
wm = wm_out = weightmap(
    agg=gdf_in.drop(columns='geometry'),
    source_grid=pix_agg['source_grid'],
    geometry=gdf_in.geometry,
)

In [None]:
# Use the weightmap instance built above. The original line bound `wm` to
# the weightmap class itself, which would make the later wm.source_grid /
# wm.agg / wm.weights lookups fail when the notebook is run top to bottom.
wm = wm_out

In [None]:
wm

In [17]:
# Three steps applied in order:
#   1. fix_ds      - standardize the dataset
#   2. stack       - merge 'lat' and 'lon' into one stacked 'loc' dimension
#   3. subset_find - adjust the grid of ds, if necessary, to match
#                    wm.source_grid
ds = subset_find(
    fix_ds(ds).stack(loc=('lat','lon')),
    wm.source_grid,
)

In [31]:
np.atleast_1d(ds[var].dims)

array(['time', 'loc'], dtype='<U4')

In [33]:
poly_idx = 0
ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs).any([k for k in np.atleast_1d(ds[var].dims) if k != 'loc'])

In [None]:
# Build the per-pixel weight vector used in the aggregation below.
if isinstance(wm.weights, pd.Series):
    # Weights supplied as a pandas Series: convert to a float array
    weights = np.array([float(k) for k in wm.weights])
else:
    # Anything that is not exactly the string 'nowghts' is unsupported.
    # The isinstance guard keeps the comparison from yielding an
    # element-wise array when wm.weights is array-like.
    if not (isinstance(wm.weights, str) and wm.weights == 'nowghts'):
        # str(...) rather than print(...): print returns None, and
        # concatenating None to a string raises TypeError
        warnings.warn('wm.weights is: \n '+str(wm.weights)+
                        ', \n which is not a supported weight vector (in a pandas series) '+
                        'or "nowghts" as a string. Assuming no weights are included...')
    # No usable weights: weigh every pixel of the source grid equally
    weights = np.ones((len(wm.source_grid['lat'])))


In [None]:
var = 'sst'

In [None]:
wm.agg[var] = None

In [None]:
poly_idx = 3

In [None]:
# Get relative areas for the pixels overlapping with this Polygon
tmp_areas = np.squeeze(wm.agg.iloc[poly_idx,:].rel_area)



In [None]:
tmp_areas[np.isnan(ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)).all('time')] = np.nan

In [None]:
tmp_areas

In [None]:
np.squeeze(wm.agg.iloc[poly_idx,:].rel_area)

In [None]:
np.isnan(ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)).all()

In [None]:
np.isnan(wm.agg.iloc[10,:].pix_idxs).all()

In [None]:
ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs).isel(time=0)

In [None]:
poly_idx = 4
# Get relative areas for the pixels overlapping with this Polygon
tmp_areas = np.squeeze(wm.agg.iloc[poly_idx,:].rel_area)

# Replace overlapping pixel areas with nans if the corresponding pixel
# is only composed of nans
tmp_areas[np.array(np.isnan(ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)).all('time').values)] = np.nan
# Calculate the normalized area+weight of each pixel (taking into account
# nans)
normed_areaweights = normalize(tmp_areas*weights[wm.agg.iloc[poly_idx,:].pix_idxs],drop_na=True)


In [None]:
np.nansum(normed_areaweights)

In [None]:
np.array(np.isnan(ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)).all('time').values)

In [None]:
tuple(np.isnan(ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)).all('time').values)

In [None]:
# Dimensions of the variable other than the stacked pixel dimension 'loc'
# (for the data used in this notebook: 'time'). The NaN checks below reduce
# over exactly the dimensions named in the warning message.
other_dims = [k for k in ds[var].dims if k != 'loc']

# NOTE(review): poly_idx is used both positionally (.iloc) and as a label
# (.loc) below, which assumes wm.agg has a default 0..n-1 index in poly_idx
# order - TODO confirm.
for poly_idx in wm.agg.poly_idx:
    # Indices (along 'loc') of the pixels overlapping this polygon; NaN
    # when no pixel overlaps it
    pix_idxs = wm.agg.iloc[poly_idx,:].pix_idxs

    # Get average value of variable over the polygon; weighted by
    # how much of the pixel area is in the polygon, and by (optionally)
    # a separate gridded weight
    if not np.isnan(pix_idxs).all():
        # Select the overlapping pixels once and reuse the selection
        pix_vals = ds[var].isel(loc=pix_idxs)
        pix_isnan = np.isnan(pix_vals)
        nan_any = pix_isnan.any(other_dims)
        nan_all = pix_isnan.all(other_dims)

        # Check first if any nans are "complete" (meaning that a pixel
        # either has values for each step, or nans for each step - if
        # there are random nans within a pixel, throw a warning)
        if not nan_any.equals(nan_all):
            warnings.warn('Variable '+var+' has occasional nans in the dimension(s) '+
                          ', '.join(other_dims)+
                          '. The code can currently only deal with pixels for which the '+
                          'entire pixel has nan values in all dimensions; the aggregation '+
                          'calculation is likely incorrect.')

        # Relative areas of the pixels overlapping this polygon, as a fresh
        # float array. ndmin=1 keeps a polygon that overlaps a single pixel
        # as a length-1 vector; np.squeeze alone would give a 0-d array,
        # which cannot be indexed with the boolean mask below.
        tmp_areas = np.array(np.squeeze(wm.agg.iloc[poly_idx,:].rel_area),
                             dtype=float, ndmin=1)
        # Replace overlapping pixel areas with nans if the corresponding
        # pixel is only composed of nans (plain numpy mask via .values)
        tmp_areas[nan_all.values] = np.nan
        # Calculate the normalized area+weight of each pixel (taking into
        # account nans)
        normed_areaweights = normalize(tmp_areas*weights[pix_idxs],drop_na=True)

        # Take the weighted average of all the pixel values to calculate
        # the aggregated value for the shapefile
        wm.agg.loc[poly_idx,var] = [[((pix_vals*normed_areaweights).
                                       sum('loc')).values]]

        print('polygon: '+str(poly_idx)+', val: '+str(wm.agg.loc[poly_idx,var][0][0]))
    else:
        # No overlapping pixels: store NaNs shaped like a single pixel
        wm.agg.loc[poly_idx,var] = [[(ds[var].isel(loc=0)*np.nan).values]]

In [None]:
[[((ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)*
                                        normed_areaweights).
                                       sum('loc')).values]]

In [None]:
poly_idx = wm.agg.poly_idx.iloc[-1]

In [None]:
wm.agg.iloc[poly_idx,:].pix_idxs

In [None]:
ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs,time=0)

In [None]:
np.isnan(ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)).all('time')

In [None]:
tmp_areas = np.squeeze(wm.agg.iloc[poly_idx,:].rel_area)

In [None]:
tmp_areas[np.isnan(ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)).all('time')] = np.nan

In [None]:
normed_areas = normalize(tmp_areas*weights[wm.agg.iloc[poly_idx,:].pix_idxs],drop_na=True)

In [None]:
xr.Dataset.equals(np.isnan(ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)).any('time'),
                  np.isnan(ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)).all('time'))

In [None]:
[[((ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)*
                       normed_areas).
                      sum('loc')).values]]

In [None]:
def normalize(a,drop_na = False):
    """Scale the values of `a` so that they sum to 1.

    Parameters
    ----------
    a : array-like supporting boolean-mask indexing, ``.sum()`` and
        ``.copy()`` (e.g. a numpy array)
        Values to normalize.
    drop_na : bool, default False
        If True and `a` contains NaNs, only the non-NaN entries are
        normalized (so that they sum to 1) and the NaN entries stay NaN.

    Returns
    -------
    Same kind of object as `a`
        - `drop_na` set and NaNs present: the partially normalized values
          described above;
        - no NaNs and a positive sum: ``a / a.sum()``;
        - otherwise (NaNs present without `drop_na`, or a sum that is not
          > 0): all NaN.

    The input is never modified; the `drop_na` branch works on a copy.
    """
    nan_mask = np.isnan(a)

    if drop_na and np.any(nan_mask):
        # Copy first so the caller's array is left untouched
        out = a.copy()
        valid = out[~nan_mask]
        out[~nan_mask] = valid/valid.sum()
        return out

    elif (not np.any(nan_mask)) and (a.sum()>0):
        return a/a.sum()
    else:
        return a*np.nan

In [None]:
np.nansum(normalize(tmp_areas,drop_na=True))

In [None]:
a2 = a[~np.isnan(a)]

In [None]:
a3 = a

In [None]:
a3[~np.isnan(a)] = a2/a2.sum()

In [None]:
a2/a2.sum()

In [None]:
[[((ds[var].isel(loc=wm.agg.iloc[poly_idx,:].pix_idxs)*
                                                           normalize(np.squeeze(wm.agg.iloc[poly_idx,:].rel_area)*
                                                             weights[wm.agg.iloc[poly_idx,:].pix_idxs])).
                                                          sum('loc')).values]]

In [None]:
ds_out = aggregated.to_dataset()

In [None]:
gdf_out = aggregated.to_shp()

In [None]:
gdf.plot()

In [None]:
# Display the aggregated variable (the original line ended in a dangling
# '.', which is a SyntaxError)
ds_out.sst

In [None]:
ds_out.isel(time=0).sst