# Random sampling of xarray DataArrays

Tests a workflow for drawing random sample points from post-classification xarray DataArrays.

In [None]:
import pandas as pd
import xarray as xr
import geopandas as gpd
import numpy as np
import os
from datacube.utils.cog import write_cog
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import sys
sys.path.append('../Scripts')
from deafrica_plotting import map_shapefile
from deafrica_spatialtools import xr_rasterize


## Analysis Parameters

In [None]:
# --- Sampling configuration ---
n_sample = 2000  # total number of points requested across all sampled classes
# NOTE(review): the classification cell produces labels 1 (dry), 2 (ephemeral)
# and 3 (permanent), but only labels 1..n_class are sampled; with n_class = 2
# permanent water is never sampled — confirm this is intended.
n_class = 2

# --- Input / output locations, all derived from the area name ---
area_name = 'indianOcean'
results = f'results/NDWI_composite/{area_name}/'
pred_tif = f'{results}{area_name}_NDWI_mosaic.tif'
mask_shp = 'data/AEZs/AEZs_ExcludeLargeWB_IndianOcean.shp'

### Open and mask mosaic

In [None]:
# Load the NDWI mosaic; squeeze() drops length-1 dimensions (e.g. a single band).
# NOTE(review): xr.open_rasterio is deprecated and removed in recent xarray
# releases; rioxarray.open_rasterio is the replacement if the env is upgraded.
da = xr.open_rasterio(pred_tif).squeeze()

# Load the mask polygons and reproject them so they can be rasterized onto the
# mosaic grid. Assumes the mosaic itself is in EPSG:6933 — TODO confirm.
# `epsg=6933` replaces the deprecated {'init': 'epsg:6933'} dict form, which
# pyproj warns can yield a CRS with non-standard axis order.
gdf = gpd.read_file(mask_shp)
gdf = gdf.to_crs(epsg=6933)

In [None]:
# Burn the mask polygons onto the mosaic's pixel grid.
mask = xr_rasterize(gdf=gdf,
                    da=da)

# Blank out (set to NaN) pixels outside the polygons, then pixels whose value
# is exactly 0 (presumably no-data in the mosaic — confirm), and wrap the
# result in a Dataset under the variable name 'ndwi'.
inside = da.where(mask)
da = inside.where(inside != 0).to_dataset(name='ndwi')
del inside

In [None]:
# da.ndwi.isel(x=range(10000,20000)).isel(y=range(10000,20000)).plot(figsize=(10,10))

### Check NDWI distribution and determine thresholds

In [None]:
# Cache the cumulative NDWI distribution as CSV rows of
# (upper bin edge, cumulative density), so the 100-bin histogram over the
# whole mosaic is only computed on the first run.
ndwi_cdf_csv = f'{results}ndwi_{area_name}.csv'
if not os.path.exists(ndwi_cdf_csv):
    cum_density, bin_edges, _patches = da.ndwi.plot.hist(
        bins=100, cumulative=True, density=True)
    np.savetxt(ndwi_cdf_csv,
               np.column_stack((bin_edges[1:], cum_density)),
               fmt='%.3f', delimiter=',')

In [None]:
# Quantile matching against WOfS: find the cumulative fraction at which the
# WOfS curve reaches the ephemeral/permanent cut-offs, then the NDWI values at
# those same cumulative fractions become the NDWI class thresholds.
# Assumes wofs_<area>.csv holds (value, cumulative fraction) rows in the same
# layout as the NDWI CSV, with increasing values — TODO confirm.
ephem, perm = 0.1, 0.9

wofs_vals, wofs_cdf = np.loadtxt(f'{results}wofs_{area_name}.csv',
                                 delimiter=',', unpack=True)
perc = np.interp([ephem, perm], wofs_vals, wofs_cdf)
print('percentile for ephemeral and permanent water', perc)

ndwi_edges, ndwi_cdf = np.loadtxt(f'{results}ndwi_{area_name}.csv',
                                  delimiter=',', unpack=True)
# Invert the NDWI CDF: cumulative fraction -> NDWI value.
thresh = np.interp(perc, ndwi_cdf, ndwi_edges)
print('Thresholds', thresh)

### Classify into dry, ephemeral and permanent

In [None]:
low, high = thresh[:2]

# Label every pixel from its NDWI value:
#   3 -> NDWI >= high            (permanent)
#   2 -> low <= NDWI < high      (ephemeral)
#   1 -> NDWI < low              (dry)
#   0 -> NaN (masked pixels; every comparison with NaN is False)
ndwi = da.ndwi.values
is_perm = ndwi >= high
is_ephem = (ndwi >= low) & (ndwi < high)
is_dry = ndwi < low
label = (is_perm.astype(np.uint8) * 3
         + is_ephem.astype(np.uint8) * 2
         + is_dry.astype(np.uint8))
del is_perm, is_ephem, is_dry  # release the full-size boolean masks

da['label'] = ('y', 'x'), label
da['label'].attrs = da.ndwi.attrs

In [None]:
# Save the class raster as a Cloud-Optimized GeoTIFF. The labels depend on the
# thresholds computed above, so a re-run must replace the old file; write_cog
# raises instead of replacing an existing file unless overwrite=True.
write_cog(da.label, f'{results}{area_name}_label.tif', overwrite=True)

### Sample from array

In [None]:
# Pixel count for each sampled label (1..n_class) and its share of the total.
class_sizes = np.array([(da.label == cid).sum().values
                        for cid in np.arange(1, n_class + 1)])
print(class_sizes)
print(class_sizes / class_sizes.sum())

In [None]:
# Equal allocation: each class gets the same number of points, rounded up, so
# the grand total may slightly exceed n_sample.
n_sample_class = np.ceil(n_sample / n_class).astype(int)
print(n_sample_class)

In [None]:
# Draw n_sample_class pixels uniformly at random, without replacement, from each
# class (labels 1..n_class). Store their (row, col) positions in label_picked and
# their (y, x) coordinates in '<area>_class_<id>.csv'.
#
# Listing every pixel index of a very large class (> LARGE_CLASS_PIXELS) would
# need several GB of memory. Such classes are therefore rejection-sampled: draw
# random pixel positions over the whole raster, keep the hits that fall in the
# class, and repeat until enough distinct pixels have been found.

LARGE_CLASS_PIXELS = 1e9
# Pass an integer seed here for reproducible samples.
sample_rng = np.random.default_rng()


def _sample_class_pixels(flat_labels, class_id, n_pick, class_size, rng):
    """Return n_pick distinct flat indices of pixels whose label is class_id.

    flat_labels: 1-D view of the label raster.
    class_size: number of entries in flat_labels equal to class_id.
    rng: numpy Generator used for all random draws.
    Raises ValueError if the class has fewer than n_pick pixels.
    """
    if class_size < n_pick:
        raise ValueError(
            f'Class {class_id} has only {class_size} pixels; '
            f'cannot draw {n_pick} samples without replacement')

    if class_size > LARGE_CLASS_PIXELS:
        n_pixels = flat_labels.size
        # Expected hits per round = n_draw * class_size / n_pixels; oversample
        # by 50% so one round is normally enough, and loop if it is not.
        n_draw = int(np.ceil(1.5 * n_pick * n_pixels / class_size))
        found = np.empty(0, dtype=np.int64)
        while found.size < n_pick:
            draw = rng.integers(0, n_pixels, size=n_draw)
            found = np.union1d(found, draw[flat_labels[draw] == class_id])
        # Every class pixel is equally likely to be in `found`, so a uniform
        # subset of it is a uniform subset of the whole class.
        return rng.choice(found, n_pick, replace=False)

    index = np.flatnonzero(flat_labels == class_id)
    return rng.choice(index, n_pick, replace=False)


labels = da.label.values
flat_labels = labels.ravel()
label_picked = {}
for class_id in np.arange(1, n_class + 1):
    picked = _sample_class_pixels(flat_labels, class_id, n_sample_class,
                                  class_sizes[class_id - 1], sample_rng)
    # Flat index -> (row, col); the label array's dims are ('y', 'x').
    y, x = np.unravel_index(picked, labels.shape)
    label_picked[class_id] = (y, x)
    np.savetxt(f'{results}{area_name}_class_{class_id}.csv',
               np.vstack((da.y[y].values,
                          da.x[x].values)).transpose(), fmt='%d', delimiter=',')


In [None]:
# Combine the sampled coordinates of all classes into one table with columns
# y, x and class. pd.concat replaces DataFrame.append, which was deprecated in
# pandas 1.4 and removed in pandas 2.0. Concatenating once also avoids the
# repeated copying of append-in-a-loop.
frames = []
for class_id in np.arange(1, n_class + 1):
    y, x = label_picked[class_id]
    df = pd.DataFrame({'y': da.y[y].values, 'x': da.x[x].values})
    df['class'] = class_id
    frames.append(df)
dfs = pd.concat(frames, ignore_index=True)

In [None]:
# Turn the sample table into point geometries in the raster's CRS. The x/y
# columns duplicate the geometry, so they are dropped after reset_index().
sample_points = gpd.points_from_xy(dfs.x, dfs.y)
gdf = gpd.GeoDataFrame(dfs, geometry=sample_points, crs=da.label.crs)
gdf = gdf.reset_index().drop(columns=['x', 'y'])

In [None]:
# Quick visual check of where the sample points fall, coloured by class.
gdf.plot(figsize=(15, 8), column='class')

In [None]:
# Save the sample points as a shapefile. The name follows area_name like the
# notebook's other outputs, rather than hard-coding 'indianOcean'.
gdf.to_file(f'{results}{area_name}_samples.shp')

In [None]:
# #open
# da = xr.open_rasterio(pred_tif).squeeze()

# #reclassify
# da = xr.where(da >= 0, 3, da)
# da = xr.where((da >=-0.1) & (da < 0), 2, da)
# da = xr.where(da <-0.1, 1, da)

# #minimize data size by converting to int8
# da = da.fillna(0).astype(np.int8).assign_coords({'x': da.x.astype(np.float32).values,
#                                                   'y':da.y.astype(np.float32).values})

In [None]:
# def random_sampling(da,
#                     n,
#                     sampling='stratified_random',
#                     manual_class_ratios=None,
#                     out_fname=None
#                    ):
    
#     """
#     Creates randomly sampled points for post-classification
#     accuracy assessment.
    
#     Params:
#     -------
#     da: xarray.DataArray
#         A classified 2-dimensional xarray.DataArray
#     n: int
#         Total number of points to sample. Ignored if providing
#         a dictionary of {class:numofpoints} to 'manual_class_ratios'
#     sampling: str
#         'stratified_random' = Create points that are randomly 
#         distributed within each class, where each class has a
#         number of points proportional to its relative area. 
#         'equal_stratified_random' = Create points that are randomly
#         distributed within each class, where each class has the
#         same number of points.
#         'random' = Create points that are randomly distributed
#         throughout the image.
#         'manual' = user defined, each class is allocated a 
#         specified number of points, supply a manual_class_ratio 
#         dictionary mapping number of points to each class
#     manual_class_ratios: dict
#         If setting sampling to 'manual', then provide a dictionary
#         of type {'class': numofpoints} mapping the number of points
#         to generate for each class.
#     out_fname: str
#         If providing a filepath name, e.g 'sample_points.shp', the
#         function will export a shapefile/geojson of the sampling
#         points to file.
    
#     Output
#     ------
#     GeoPandas.Dataframe
    
#     """
    
#     if sampling not in ['stratified_random', 'equal_stratified_random', 'random', 'manual']:
#         raise ValueError("Sampling strategy must be one of 'stratified_random', "+
#                              "'equal_stratified_random', 'random', or 'manual'") 
#     print('here')
#     #open the dataset as a pandas dataframe
# #     df = da.squeeze()
#     df = da.to_dataframe(name='class')#.astype('category')
#     print('made the dataframe')
#     #list to store points
#     samples = []
    
#     if sampling == 'stratified_random':
#         #determine class ratios in image
#         class_ratio = pd.DataFrame({'proportion': df['class'].value_counts(normalize=True),
#                             'class':df['class'].unique()
#                                  })
        
#         for _class in class_ratio['class']:
#             #use relative proportions of classes to sample df
#             no_of_points = n * class_ratio[class_ratio['class']==_class]['proportion'].values[0]
#             #random sample each class
#             print('Class '+ str(_class)+ ': sampling at '+ str(round(no_of_points)) + ' coordinates')
#             sample_loc = df[df['class'] == _class].sample(n=int(round(no_of_points)))
#             samples.append(sample_loc)

#     if sampling == 'equal_stratified_random':
#         classes = df['class'].unique()
        
#         for _class in classes:
            
#             no_of_points = n / len(classes)
#             #random sample each classes
#             try:
#                 sample_loc = df[df['class'] == _class].sample(n=int(round(no_of_points)))
#                 print('Class '+ str(_class)+ ': sampling at '+ str(round(no_of_points)) + ' coordinates')
#                 samples.append(sample_loc)
            
#             except ValueError:
#                         print('Requested more sample points than population of pixels for class '+ str(_class)+', skipping')
#                         pass
    
#     if sampling == 'random':
#         no_of_points = n
#         #random sample entire df
#         print('Randomly sampling dataAraay at '+ str(round(no_of_points)) + ' coordinates')
#         sample_loc = df.sample(n=int(round(no_of_points)))
#         samples.append(sample_loc)
    
#     if sampling == 'manual':
#         if isinstance(manual_class_ratios, dict):
#             for _class in list(manual_class_ratios.keys()):
#                 no_of_points = manual_class_ratios.get(str(_class))
                
#                 try:
#                     print('sampling '+ _class)
#                     sample_loc = df[df['class'] == int(_class)].sample(n=int(round(no_of_points)))
#                     print('Class '+ str(_class)+ ': sampled at '+ str(round(no_of_points)) + ' coordinates')
#                     samples.append(sample_loc)

#                 except ValueError:
#                     print('Requested more sample points than population of pixels for class '+ str(_class)+', skipping')
#                     pass
            
#         else:
#             raise ValueError("Must supply a dictionary mapping {'class': numofpoints} if sampling" +
#                              " is set to 'manual'")
    
#     #join back into single dataframe
#     all_samples = pd.concat([samples[i] for i in range(0,len(samples))])
        
#     #get pd.multiindex coords as list 
#     y = [i[0] for i in list(all_samples.index)]
#     x = [i[1] for i in list(all_samples.index)]

#     #create geopandas dataframe
#     gdf = gpd.GeoDataFrame(
#         all_samples,
#         crs=da.crs,
#         geometry=gpd.points_from_xy(x,y)).reset_index()

#     gdf = gdf.drop(['x', 'y'],axis=1)
    
#     if out_fname is not None:
#         gdf.to_file(out_fname)
    
#     return gdf


# %%time
# gdf = random_sampling(da=da,
#                     n=total_points,
#                     sampling='manual',
#                     manual_class_ratios={'1':167, '2':167, '3':167},
#                     out_fname='results/test_western.shp'
#                        )

# gdf.plot(column='class', figsize=(10,10),legend=True)