# Tutorial on collocating a dataset with lagged data

#1) You will need to point the code to your input file.

#2) Ensure that your input file has the following Variables and Names:
    #Latitude = 'lat'
    #Longitude = 'lon'
    #Time = 'time'
    #Year = 'year'
    #Month = 'month'
    #Day = 'day'
    #Date = 'date'


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import pandas as pd
import warnings
import timeit
import glob
# filter some warning messages
warnings.filterwarnings("ignore") 
#from geopy.distance import geodesic 

####################you will need to change some paths here!#####################
# Input: CSV file holding the station metadata (one row per observation).
#filename_bird='f:/data/project_data/NASA_biophysical/collocated_data/zoo_selgroups_HadSST_relabundance_5aug2019_plumchrusV_4regions_final.csv'

# Data directory; both the input file and the output files live here.
#adir = 'F:/data/project_data/NASA_biophysical/jeff_collocations/'
adir = './../../data/'

filename_bird = f'{adir}allStationMetadata.csv'

# Output: common prefix of the files written later; the dataset name and the
# file extension are appended to it.
filename_bird_out = f'{adir}zoo_selgroups_HadSST_relabundance_5aug2019_plumchrusV_4regions_final_satsst'
#################################################################################

#import intake
import dask
import dask.array as dsa
import gcsfs
import fsspec

#subroutines to read data
#some of the data is on pangeo gcp, some on AWS
import sys
sys.path.append('./../subroutines/')  #directory that holds the data-reading subroutines
from get_data_pangeo import get_data,get_sst
#from get_data_local import get_sst


## Reading CSV datasets

In [None]:
# Load the input CSV into a pandas DataFrame (it is converted to xarray in a
# later cell) and display it.
df_bird = pd.read_csv(filepath_or_buffer=filename_bird)
df_bird

In [None]:
# Build a datetime64[ns] column named 'time64' from the year/month/day columns
# and the hour taken from the 'time' string, then convert to xarray.
# The 'time' strings are left-padded with zeros to 8 characters so the hour
# always sits in the first two characters.
# Only the hour is used: the minutes in the 'time' string do not enter the
# timestamp, so 'time64' is truncated to the hour.
clock = df_bird['time'].apply(lambda s: s.zfill(8))
date_parts = pd.DataFrame(
    {
        'Year': df_bird['year'],
        'Month': df_bird['month'],
        'Day': df_bird['day'],
        'Hour': clock.apply(lambda s: s[:2]),
    }
)

# Drop the raw date/time columns and put the assembled timestamp first.
df_bird = df_bird.drop(columns=['day', 'month', 'year', 'time', 'date'])
df_bird.insert(0, 'time64', pd.to_datetime(date_parts))

# transform to xarray
ds_bird = df_bird.to_xarray()

In [None]:
# Display the DataFrame to inspect the result of the time conversion above.
df_bird

In [None]:
# Sanity check of the station positions: scatter-plot them and print the
# bounding box (min/max latitude, then min/max longitude).
plt.scatter(x=ds_bird.lon, y=ds_bird.lat)
minlat, maxlat = ds_bird.lat.min(), ds_bird.lat.max()
minlon, maxlon = ds_bird.lon.min(), ds_bird.lon.max()
print(minlat, maxlat, minlon, maxlon)

In [None]:
#from dask_gateway import Gateway
#from dask.distributed import Client
#gateway = Gateway()
#cluster = gateway.new_cluster()
#cluster.adapt(minimum=1, maximum=200)
#cluster.scale(50)
#client = Client(cluster)
#cluster

In [None]:
#Resolution of the Satellite Data is 0.2 degrees.
#If you want to smooth the data, change smooth_lat and smooth_lon to the number of
#grid cells you want to include.
#Example: smooth_lat = 3 computes a rolling average over 3 grid cells, which equals 0.6 deg.
#smooth_lat=1
#smooth_lon=1
#ds = ds.rolling(lat=smooth_lat,center=True,keep_attrs=True).mean(keep_attrs=True)
#ds = ds.rolling(lon=smooth_lon,center=True,keep_attrs=True).mean(keep_attrs=True)
#ds

# Collocate all data with bird data

In [None]:
# Number of station observations (length of the 'lat' variable); displayed
# as the cell output.
ilen_bird = ds_bird.lat.shape[0]
ilen_bird

In [None]:
#get MUR SST
#file_location = 's3://mur-sst/zarr'
#ds = xr.open_zarr(fsspec.get_mapper(file_location, anon=True),consolidated=True)
#ds_sst = ds.drop({'analysis_error','mask','sea_ice_fraction'})
#tem = ds_sst.analysed_sst.attrs
#tem['var_name']='mur_sst'
#ds_sst.analysed_sst.attrs=tem
#data={'sst':ds_sst}

In [None]:
# Collocate the gridded datasets returned by get_data() with the station data.
# For every station, the values inside a 30-day window ending at the
# observation time are interpolated to the station position and stored in
# ds_bird with dims ('index', 'dtime'), where dtime runs from -30 to 0.
# After each dataset is processed, ds_bird is written to its own NetCDF file.
#
# NOTE(review): the assignment into the 31-slot 'dtime' axis assumes the time
# slice returns exactly 31 time steps per station (presumably one per day) --
# confirm against the datasets returned by get_data().

# Datasets that are NOT collocated. Remove a name from this set to process
# that dataset; the 'topo' and 'color' handling below only takes effect once
# the corresponding name is removed here.
skip_names = {'aviso', 'wnd', 'color', 'topo'}

lag_days = np.arange(-30, 1)   # 'dtime' coordinate: -30 ... 0
n_bird = len(ds_bird.lat)      # number of station observations

data, clim = get_data()
for name in data:
    ds_data = data[name]
    print('data', name)
    if name in skip_names:
        continue
    if name == 'topo':
        # No time window here: take the nearest grid value of 'z' at each station.
        temlat, temlon = ds_bird.lat, ds_bird.lon
        tem2 = ds_data.z.interp(lat=temlat, lon=temlon, method='nearest')
        # Store on ds_bird (the original wrote to an undefined name `ds`).
        ds_bird['ETOPO_depth'] = xr.DataArray(tem2.data, coords={'index': ds_bird.index}, dims=('index'))
        ds_bird['ETOPO_depth'].attrs = ds_data.attrs
    else:
        # Pre-allocate one (station, lag) array per variable, keeping the
        # source dtype and attributes.
        for var in ds_data:
            ds_bird[var] = xr.DataArray(
                np.empty((n_bird, lag_days.size), dtype=str(ds_data[var].dtype)),
                coords={'index': ds_bird.index, 'dtime': lag_days},
                dims=('index', 'dtime'),
            )
            ds_bird[var].attrs = ds_data[var].attrs
        for i in range(n_bird):
            # 30-day window ending at the observation time, shifted by 9 hours.
            t1 = ds_bird.time64[i] - np.timedelta64(30, 'D') + np.timedelta64(9, 'h')
            t2 = ds_bird.time64[i] + np.timedelta64(9, 'h')
            # Load only a +/- 0.25 degree box around the station before interpolating.
            lat1, lat2 = ds_bird.lat[i] - .25, ds_bird.lat[i] + .25
            lon1, lon2 = ds_bird.lon[i] - .25, ds_bird.lon[i] + .25
            if name == 'color':   # latitude runs from positive to negative
                lat2, lat1 = ds_bird.lat[i] - .25, ds_bird.lat[i] + .25
            tem = ds_data.sel(time=slice(t1, t2), lat=slice(lat1, lat2), lon=slice(lon1, lon2)).load()
            tem = tem.interp(lat=ds_bird.lat[i], lon=ds_bird.lon[i])
            for var in ds_data:
                ds_bird[var][i, :] = tem[var].data
            # Progress report for every station.
            print(i, n_bird)
    # output data
    ds_bird.to_netcdf(filename_bird_out + name + '.nc')
    print('output:', filename_bird_out + name + '.nc')

In [None]:
# Put it all together: merge the per-dataset NetCDF files written by the
# collocation loop into one dataset, then save it as NetCDF and as CSV.
import os

filename_all = filename_bird_out + 'all' + '.nc'

# filename_bird_out already contains the data directory, so glob on it
# directly. The combined output of a previous run is excluded so it is not
# read back in (and not overwritten while open); sorting makes the merge order
# reproducible.
filename = sorted(
    f for f in glob.glob(filename_bird_out + '*.nc')
    if os.path.basename(f) != os.path.basename(filename_all)
)
if not filename:
    raise FileNotFoundError('no collocated files match ' + filename_bird_out + '*.nc')

print(filename[0])
ds = xr.open_dataset(filename[0])
for iname in range(1, len(filename)):
    print(filename[iname])
    ds2 = xr.open_dataset(filename[iname])
    # Only add variables that the merged dataset does not have yet.
    for var in ds2:
        if var not in ds:
            ds[var] = ds2[var]

ds.to_netcdf(filename_all)
df_bird = ds.to_dataframe()
df_bird.to_csv(filename_bird_out + 'all' + '.csv')


# Code below here doesn't run

In [None]:
# Rolling means at set day windows (30, 15, 7 and 2 time steps), collocated
# with the station data at the time step nearest to each observation.
# Known problems reported by the original author (this cell does not run):
#   - this has errors that I don't understand
#   - run without a dask cluster, it just dies quietly with no errors
#   - run with a dask cluster, it dies with get_item errors and kills the cluster
#   - rolling keep_attrs doesn't seem to preserve attributes, so a subroutine
#     (set_attr) had to be written
# NOTE(review): set_attr is neither defined nor imported in the visible part
# of this file -- confirm where it lives before running this cell.
data,clim = get_data()
for name in data:
    ds_data=data[name]
    # Rolling means along 'time' for each window length; set_attr is called on
    # every result with the source dataset.
    ds_mon = ds_data.rolling(time=30, center=False).mean(dim='time',keep_attrs=True)
    ds_mon = set_attr(ds_data,ds_mon)
    ilat,ilon = len(ds_mon.lat.data),len(ds_mon.lon.data)
    ds_15 = ds_data.rolling(time=15, center=False).mean(dim='time',keep_attrs=True)
    ds_15 = set_attr(ds_data,ds_15)
    ds_week = ds_data.rolling(time=7, center=False).mean(dim='time',keep_attrs=True)
    ds_week = set_attr(ds_data,ds_week)
    ds_2dy = ds_data.rolling(time=2, center=False).mean(dim='time',keep_attrs=True)
    ds_2dy = set_attr(ds_data,ds_2dy)
    # Attach the smoothed fields to ds_data under suffixed variable names.
    for var in ds_data:
        ds_data[var+'_1mon']=ds_mon[var]
        ds_data[var+'_15dy']=ds_15[var]
        ds_data[var+'_1week']=ds_week[var]
        ds_data[var+'_2dy']=ds_2dy[var]  
    # NOTE(review): 'topo' is skipped only here, after the rolling means above
    # have already been requested for it.
    if name=='topo':
        continue
    # Pre-allocate one value per station for every variable (original and
    # smoothed), keeping the source dtype and attributes.
    for var in ds_data:
        var_tem=var
        ds_bird[var_tem]=xr.DataArray(np.empty(ilen_bird, dtype=str(ds_data[var].dtype)), coords={'index': ds_bird.index}, dims=('index'))
        ds_bird[var_tem].attrs=ds_data[var].attrs
    print('var',var_tem)
    for i in range(len(ds_bird.lat)):
        # Earlier selection attempts, kept for reference:
        #tem = ds_data.sel(time=ds_bird.time64[i])
        #tem = ds_data.sel(time=slice(t1,t2),lat=slice(lat1,lat2),lon=slice(lon1,lon2)).load()
        ilat,ilon = len(ds_data.lat.data),len(ds_data.lon.data)
        # Take the time step nearest to the observation, rechunk so lat/lon
        # are each a single chunk, then interpolate to the station position.
        tem = ds_data.sel(time=ds_bird.time64[i],method='nearest')
        tem = tem.chunk(chunks={'lat':ilat,'lon':ilon})
        tem = tem.interp(lat=ds_bird.lat[i],lon=ds_bird.lon[i])
        tem = tem.load()
        for var in ds_data:
            var_tem=var
            ds_bird[var_tem][i]=tem[var].data
        # Progress report every 10th station.
        if int(i/10)*10==i:
            print(i,len(ds_bird.lat))

    #output data (the CSV export is disabled; only NetCDF is written)
    #df_bird = ds_bird.to_dataframe()
    #df_bird.to_csv(filename_bird_out)
    ds_bird.to_netcdf(filename_bird_out+name+'.nc')

In [None]:
# Small self-contained check of how xarray's rolling mean places its window.
# Build 12 monthly samples with values 0..11, print them, then print the
# 3-step rolling mean for comparison.
da = xr.DataArray(
    np.linspace(0, 11, num=12),
    coords=[pd.date_range("15/12/1999", periods=12, freq=pd.DateOffset(months=1))],
    dims="time",
)
# The array must exist before it is printed (the original printed it first).
print(da.data)
# center=False: the window covers the current step and the steps before it,
# so the first two results are NaN.
dar = da.rolling(time=3, center=False).mean()
print(dar.data)