# In [None]:
#import netCDF4
import xarray as xr
# import cartopy.crs as ccrs
import numpy as np
# import pandas as pd
# from pandas import read_csv
# import datetime as dt
import matplotlib.pyplot as plt
#from scipy import stats
import time

start_time = time.time()

# %% Load GRACE and ERA5 data
#################################################
# Load the CSR RL06 mascon solution (GRACE + GRACE-FO).
gr = xr.open_dataset('F:/CSR_GRACE_GRACE-FO_RL0602_Mascons_all-corrections.nc')
# The raw `time` values are treated as (possibly fractional) day offsets from
# 2002-01-01. NOTE(review): confirm against gr.time.attrs['units'].
# Scale days -> nanoseconds BEFORE casting to timedelta64[ns]. Casting first
# truncates the fractional part (106.5 days would become 106 days).
gr["time"] = np.datetime64("2002-01-01") + (gr.time * 8.64e+13).astype("timedelta64[ns]")
land_mask = xr.open_dataset('F:/CSR_GRACE_GRACE-FO_RL06_Mascons_v02_LandMask.nc')

# Load the ERA5 data set.
era5 = xr.open_dataset('F:/data.nc')
print("gr", gr)
print("era5", era5)
#print("land_mask", land_mask)

# Regrid all data onto a common 1-degree grid (nearest neighbour).
# Cell centres: longitude 0.5..359.5, latitude 89.5..-89.5 (descending).
new_lon2 = np.linspace(0.5, 359.5, num=360)
new_lat2 = np.linspace(89.5, -89.5, num=180)
era5_1deg = era5.interp(latitude=new_lat2, longitude=new_lon2, method='nearest')
grace_1deg = gr.interp(lat=new_lat2, lon=new_lon2, method='nearest')
land_mask_1deg = land_mask.interp(lat=new_lat2, lon=new_lon2, method='nearest')


def _lon_to_pm180(ds, lon_name):
    """Return *ds* with longitude coordinate *lon_name* mapped from 0..360 to -180..180.

    Longitudes greater than 180 are shifted by -360. The dataset is then
    sorted by the new longitudes so they increase monotonically. The
    coordinate keeps its original name.
    """
    lon = ds[lon_name].values
    wrapped = ds.assign_coords({lon_name: np.where(lon > 180, lon - 360, lon)})
    return wrapped.sortby(lon_name)


# Change longitude from 0-360 degrees to -180..180 for every grid.
# ERA5 names its longitude "longitude"; GRACE and the land mask use "lon".
era5_1deg = _lon_to_pm180(era5_1deg, 'longitude')
grace_1deg = _lon_to_pm180(grace_1deg, 'lon')
land_mask_1deg = _lon_to_pm180(land_mask_1deg, 'lon')

#era5_1deg = era5_1deg.sel(expver=1).combine_first(era5_1deg.sel(expver=5))

# Longitude coordinate name used by the GRACE and land-mask grids.
lon_name = 'lon'

print("grace", grace_1deg)
print("era5", era5_1deg)

era5 = era5_1deg
grace = grace_1deg

# Compute all storage terms in mm (runoff in mm/month).
# Days in each monthly time step, taken from the calendar month of the time
# stamp. Differencing consecutive time stamps left the final month as NaN,
# which then propagated into sro_mm and the last month of gws.
# NOTE(review): assumes each ERA5 time stamp lies inside the month it
# describes (e.g. stamped on the 1st) -- confirm for data.nc.
era5['daysinmonth'] = era5['time'].dt.days_in_month.astype('float64')

# `sro` is scaled by 1000 and by the number of days, i.e. it is treated as a
# daily amount in m and converted to mm/month.
era5['sro_mm'] = era5['sro'] * 1000 * era5['daysinmonth']

# Soil water: volumetric content per layer times layer thickness in m
# (0.07, 0.21, 0.72 and 1.89 m), summed and converted m -> mm.
era5['SM'] = (era5['swvl1'] * 0.07 + era5['swvl2'] * (0.28-0.07)
              + era5['swvl3'] * (1-0.28) + era5['swvl4'] * (2.89-1)) * 1000

# Snow (sd) and skin-reservoir (src) water need the same m -> mm conversion as
# sro before they are combined with the mm-based GRACE, SM and runoff terms.
# NOTE(review): assumes both are in m of water equivalent -- confirm via
# era5['sd'].attrs['units'] and era5['src'].attrs['units'].
era5['sd_mm'] = era5['sd'] * 1000
era5['src_mm'] = era5['src'] * 1000

# Put both data sets on month-start time stamps: GRACE is linearly
# interpolated, ERA5 is averaged per month.
grace_MS = grace.resample(time="MS").interpolate()
era5_MS = era5.resample(time="MS").mean()
# Slice time to 2003-2022.
grace = grace_MS.sel(time=slice('2003-01-01T00:00:00', '2022-12-01T00:00:00'))
era5 = era5_MS.sel(time=slice('2003-01-01T00:00:00', '2022-12-01T00:00:00'))

# Express GRACE and ERA5 as anomalies relative to their 2003-2022 mean.
# GRACE is also converted from cm to mm.
grace_mean = grace['lwe_thickness'].mean(dim='time')
grace['lwe_thickness_scaled_ts_mm'] = (grace['lwe_thickness'] - grace_mean) * 10

era5_mean = era5.mean(dim='time')
era5 = era5 - era5_mean

# Compute GWS = TWS - (soil moisture + snow + runoff + skin reservoir), all in mm.
temp1 = grace['lwe_thickness_scaled_ts_mm']
temp2 = (era5['SM'] +
         era5['sd_mm'] +
         era5['sro_mm'] +
         era5['src_mm'])

# Align ERA5 coordinate names with GRACE so the two data sets can be merged.
era5 = era5.rename({'longitude': 'lon', 'latitude': 'lat'})

temp3 = xr.merge([grace, era5])
temp3 = temp3['lwe_thickness_scaled_ts_mm'] - (
    temp3['SM'] + temp3['sd_mm'] + temp3['sro_mm'] + temp3['src_mm'])
gws = temp3

# Keep latitudes from 90N down to 60S, i.e. drop everything south of 60S.
# The regridded latitude axis runs 89.5 .. -89.5 (descending), hence slice(90, -60).
lat_window = slice(90, -60)
gws = gws.sel({'lat': lat_window})

# Set every cell where the land mask is not 1 to NaN (ocean blanked out for plotting).
# NOTE(review): land_mask_1deg is a Dataset, so where() presumably returns a
# Dataset (one variable per mask variable) rather than a DataArray -- confirm.
is_land = land_mask_1deg == 1
gws_masked = gws.where(is_land)
print("gws_masked", gws_masked)