**Import datacube and dependencies - should not need to copy entire cell**

In [1]:
%matplotlib inline

# Import the path to the root directory to import utilites.
import sys
sys.path.append('../..')

import datacube
from utils.data_cube_utilities.data_access_api import DataAccessApi
from utils.data_cube_utilities.dc_display_map import display_map
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr
import datetime as dt

# Create an instance of the datacube and API
api = DataAccessApi()#config="/home/localuser/.datacube.conf"
dc = api.dc

In [9]:
# Landsat 8 data is not ascending in time.
l8 = xr.open_dataset('l8_wofs.nc').wofs.load().sortby('time')
s1 = xr.open_dataset('s1_water_v2.nc').wofs.load()
s2 = xr.open_dataset('wofs_s2ab.nc').wofs.load()

In [14]:
# Create montly composites. 
# Reindex to ensure there is data for each month, even if NaN.
l8_monthly = l8.groupby('time.month').mean(dim='time')
    #l8.resample(time='2m').mean(dim='time')
# del l8
s1_monthly = s1.groupby('time.month').mean(dim='time')
# del s1
s2_monthly = s2.groupby('time.month').mean(dim='time')
# del s2

In [None]:
# print("\nLANDSAT 8")
# for month_val in l8_monthly.month:
#     l8_monthly.sel(month=month_val).plot()
#     plt.show()
# print("\nSENTINEL 1")
# for month_val in s1_monthly.month:
#     s1_monthly.sel(month=month_val).plot()
#     plt.show()
# print("\nSENTINEL 2")
# for month_val in s2_monthly.month:
#     s2_monthly.sel(month=month_val).plot()
#     plt.show()

In [None]:
print("Landsat 8:", l8_monthly)
print("Sentinel 1:", s1_monthly)
print("Sentinel 2:", s2_monthly)
print()
print("L8 fraction not nan:", l8.count(dim='time')/len(l8.time))
print("L8 monthly fraction not nan:", l8_monthly.count(dim='month')/len(l8_monthly.month))
print("S1 fraction not nan:", s1.count(dim='time')/len(s1.time))
print("S1 monthly fraction not nan:", s1_monthly.count(dim='month')/len(s1_monthly.month))
print("S2 fraction not nan:", s2.count(dim='time')/len(s2.time))
print("S2 monthly fraction not nan:", s2_monthly.count(dim='month')/len(s2_monthly.month))

In [15]:
def xr_calc_corr(da1, da2, dim):
    """
    Finds the minimum, mean, median, or maximum time difference between True values
    in a boolean xarray.DataArray.
    
    Parameters
    ----------
    da1, da2: list of 2 xarray.DataArray
        The xarray.DataArrays to calculate correlation for
    dim: list of str
        The dimensions to calculate correlation for. The remaining dimensions
        will be flattened and have their correlation calculated. So `dims` are
        the dimensions to keep in the final result.

    Returns
    -------
    out_arr: xarray.DataArray of np.timedelta64
        The time differences.
    """
    def calc_corr_np_arr(arr, axis):
        """
        Calculate the correlation of a NumPy array over a set of axes.
        
        Parameters
        ----------
        arr: NumPy array
        axis: list of int
            A list of the axes to calculate correlation over.
        """
#         def calc_corr_np_arr_1d(arr_1d, axis):
#             print("arr_1d.shape:", arr_1d.shape)
#             print("axis:", axis)
#             return numpy.corrcoef(list1, list2)[0, 1]
#             desired_mask = arr_1d == 1
#             arr_desired = arr_1d[desired_mask]
#             times_desired = data_arr.time.values[desired_mask]
#             # For each True element, calculate the time difference to the next True element.
#             time_diffs = np.diff(times_desired)
#             # Handle the case of there being no "True" instances.
#             if len(time_diffs) == 0: # Must return single value, not an array.
#                 return np.diff(data_arr.time.values[[0,-1]])[0]
#             # Calculate the desired statistic for the time differences for this lat/lon point.
#             if aggregation_method == 'min':      return np.min(time_diffs)
#             elif aggregation_method == 'mean':   return np.mean(time_diffs)
#             elif aggregation_method == 'median': return np.median(time_diffs)
#             elif aggregation_method == 'max':    return np.max(time_diffs)
        
#         print("arr:", arr.shape)
#         result = np.apply_along_axis(calc_corr_np_arr_1d, axis=axis, arr=arr)
#         result = np.apply_over_axes(calc_corr_np_arr_1d, da1, axes=axis)
#         result = np.corrcoef(da1, da2)[0,1]    
#         print("result.shape:", result.shape)
        return result
    
#     # The dimensions to flatten and calculate correlation for.
#     corr_dims = list(set(da1.dims) - (set(dims)))
#     print(corr_dims)
#     print(da1.reduce(calc_corr_np_arr, dim=corr_dims))
# #     merged = xr.merge([da1.rename("da1"), da2.rename("da2")])
# #     print("merged:", merged)
#     merged.reduce(calc_corr_np_arr, dim=corr_dims)
# #     print("merged:", merged)
#           #{"da1": da1.rename("da1"), "da2": da2.rename("da2")}))
# #     return 
    
#     np.corrcoef(l8_monthly)[0,1]

# xr_calc_corr(l8_monthly, s1_monthly, dims=['month'])

# def xr_calc_corr(*args, **kwargs):
#     """
#     Calculates correlation of two xarray.DataArrays
    
#     Parameters
#     ---------
#     da1, da2: xarray.DataArray
#         The xarray DataArrays 
#     *args: list
#         Must be a list of two xarray.DataArrays to calcualte correlation for.
#     **kwargs: dict
#         Unused.
#     """
#     args = [arg.flatten() for arg in args]
#     print(np.corrcoef(*args)[0,1])
#     return np.corrcoef(*args)[0,1]

# NO - have to resolve axis index from dimension name manually
# x,y = [1, 2, 3], [0, 1, 0.5]
# xr.apply_ufunc(xr_calc_corr, l8_monthly, s1_monthly, join='left')

time_axis_ind = l8_monthly.dims.index('month') # Axis to find correlation for.
import itertools
# Create a heat map of correlation across time for every pair of products.
for pair in itertools.combinations([l8_monthly, s1_monthly, s2_monthly], 2):
    # 1. Combine the data into a new NumPy array.
    # Shape should be [month, y, x, pair_index].
    combined_arr = np.stack([pair[0].values, pair[1].values], axis=-1)
#     print(combined_arr.shape)
    # 2. Create a new NumPy array to hold the correlation coefficients.
    # Shape should be [y, x].
    corr_arr = np.empty((combined_arr.shape[1], combined_arr.shape[2]))
#     print(corr_arr.shape)
#     print(len(pair[0].y), len(pair[0].x))
    # 3. Calculate correlation for
    for x_ind in range(len(pair[0].x)):
        if x_ind % 10 == 0:#(len(pair[0].x)/10) == 0:
            print("{} percent complete".format(x_ind/len(pair[0].x)))
        for y_ind in range(len(pair[0].y)):
            point_vecs = combined_arr[:,y_ind,x_ind,:]
            # Filter out any months which have NaN values for either
            # dataset member of this pair.
            months_inds_with_nan = np.any(np.isnan(point_vecs), axis=1)
            point_vecs_no_nan = point_vecs[~months_inds_with_nan,:]
            if point_vecs_no_nan.shape[0] < 2:
#                 print("Too many NaNs, returning NaN")
                corr_arr[y_ind, x_ind] = np.nan
            else:
#                 print("Enough values")
                corr_arr[y_ind, x_ind] = \
                    np.corrcoef(point_vecs_no_nan[:,0], point_vecs_no_nan[:,1])[0,1] # Shape is (month, pair_index)
    print(corr_arr.mean())
            
#     from xarray.ufuncs import fabs as xr_abs
#     diff = pair[1] - pair[0]
#     print(diff)
#     print()
#     diff = xr_abs(diff)
#     print(diff)


0.0 percent complete


  c /= stddev[:, None]
  c /= stddev[None, :]


0.011961722488038277 percent complete
0.023923444976076555 percent complete
0.03588516746411483 percent complete
0.04784688995215311 percent complete
0.05980861244019139 percent complete
0.07177033492822966 percent complete
0.08373205741626795 percent complete
0.09569377990430622 percent complete
0.1076555023923445 percent complete
0.11961722488038277 percent complete
0.13157894736842105 percent complete
0.14354066985645933 percent complete
0.15550239234449761 percent complete
0.1674641148325359 percent complete


KeyboardInterrupt: 

In [None]:
print(np.nanmean(corr_arr))
plt.imshow(corr_arr)
plt.colorbar()
plt.show()

In [None]:
# print(np.corrcoef([1,np.nan,3], [np.nan, 2, 3]))
# np.corrcoef([1, 2],[1, 2])