## Split the labeled plots into training and validation sets, stratified by Census Region and by the driver class of the auxiliary dataset from which each plot was sampled.

In [2]:
import os
import sys
from pathlib import Path

#Import utility functions
utilsPath = os.path.join(Path(os.getcwd()).parent.absolute(),'utils')
if utilsPath not in sys.path:
    sys.path.append(utilsPath)

from getTiles import image_gen
from getTiles import read_image
from lossFunctions import iou_coef, DiceLoss, weighted_categorical_crossentropy

#Import other functions
import random
import glob
import warnings
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib
import datetime

from sklearn.model_selection import train_test_split


In [3]:
#Define training data and test data

#Set seeds. random.seed and np.random.seed are functions and must be *called*;
#assigning to them (e.g. `random.seed = seed`) replaces the function with an int
#and leaves both generators unseeded, which made the shuffle below non-reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)

#Define number of classes
num_classes = 8

#Define dictionary to convert the sampledFrom attribute (pl_samplef) to a driver class
sampleDictionary = {'Burn Area': 'Wildfire',
                    'Insect Damage': 'Natural Disturbance',
                    'Mining': 'Hard Commodities',
                    'NLCD Developed': 'Urbanization',
                    'Plantations': 'Forestry',
                    'Power Plants': 'Hard Commodities',
                    'Regional Plantations': 'Forestry',
                    'Surface Water': 'Natural Disturbance',
                    'USDA Cultivated': 'Soft Commodities'}

#Read in Census regions (used as a state 'Name' -> 'Region' lookup table)
regions = pd.read_csv('../inputs/censusStatesRegionsDivisions.csv')

#Define the label tile path template and list the labeled tiles (plot_<plotID>_<year>.tif)
labelFilePath = '../inputs/labeledTilesValidYears/plot_{}_{}.tif'
fileNames = sorted(glob.glob('../inputs/labeledTilesValidYears/*.tif'))

#Define list of [[plotID, yearOfLoss]]. Only the file name is parsed, so underscores
#in the directory part of the path cannot shift the split fields.
fileNameParts = [Path(x).name.split('_') for x in fileNames]
plotIDYears = [[int(parts[1]), int(parts[2].split('.')[0])] for parts in fileNameParts]

#Get unique plot ID's (np.unique returns them sorted) and shuffle with the seeded generator
plotIDS = np.unique([plotID for plotID, _ in plotIDYears])
np.random.shuffle(plotIDS)


#Get a dataframe of plot ID's and who labeled them
plotAssignmentFile = '../inputs/plots/labeledSamples.shp'
plotAssignments = gpd.read_file(plotAssignmentFile)

plotAssignments = plotAssignments[['plotid','email','pl_samplef','pl_state','pl_latitud','pl_longitu']]
#NOTE(review): duplicates are dropped across all of these columns, so a plot labeled by
#more than one person (different 'email') keeps one row per labeler and could end up in
#both the training and validation sets below -- confirm whether that is intended.
#.copy() so the new columns below are added to an independent frame, not a slice.
plotAssignments = plotAssignments.drop_duplicates().copy()

#Get the region and sampled driver class for each of the plots.
#Build the state -> region lookup once (first matching row wins, like the previous
#row-by-row filter) instead of re-filtering `regions` for every plot.
regionByState = regions.drop_duplicates('Name').set_index('Name')['Region']
plotRegions = plotAssignments['pl_state'].map(regionByState)
missingStates = sorted(plotAssignments.loc[plotRegions.isna(), 'pl_state'].astype(str).unique())
if missingStates:
    #Fail with a readable message instead of an opaque IndexError
    raise ValueError('No Census region found for states: {}'.format(missingStates))
#Region stored as an integer string, e.g. '3'
plotAssignments['region'] = plotRegions.astype(int).astype(str)
#Sample sources not listed in sampleDictionary map to None (dict.get semantics)
plotAssignments['sampledClass'] = plotAssignments['pl_samplef'].map(sampleDictionary.get)
    
    
#Define stratify column by appending sampledClass and region, e.g. "Wildfire 3"
plotAssignments['stratify'] = (plotAssignments['sampledClass'].astype(str) + ' '
                               + plotAssignments['region'].astype(str))

#Get the number of counts of the stratify column
strat, counts = np.unique(plotAssignments['stratify'], return_counts=True)
stratCount = pd.DataFrame(list(zip(strat, counts)),
                          columns=['stratify', 'count'])

#train_test_split(stratify=...) needs at least two members per class, so strata that
#contain a single plot are pooled into one category called "Solos".
soloStrata = stratCount.loc[stratCount['count'] == 1, 'stratify']
isSolo = plotAssignments['stratify'].isin(soloStrata)
plotAssignments['stratify2'] = plotAssignments['stratify'].where(~isSolo, 'Solos')

#If exactly one plot was a solo, "Solos" itself has a single member and the split would
#still fail; fold that lone plot into the largest remaining stratum instead.
if isSolo.sum() == 1 and (~isSolo).any():
    largestStratum = plotAssignments.loc[~isSolo, 'stratify2'].value_counts().idxmax()
    plotAssignments.loc[isSolo, 'stratify2'] = largestStratum

#Stratify on this final stratify attribute and save to CSV with plot ID's
trainPlots, validatePlots = train_test_split(plotAssignments, test_size=0.3, random_state=seed,
                                             stratify=plotAssignments['stratify2'])
trainPlots.to_csv('../inputs/plots/trainingPlots.csv')
validatePlots.to_csv('../inputs/plots/validationPlots.csv')
