In [None]:
# load dependencies and allow interactivity
%matplotlib widget

import glob
import os

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rioxarray as rxr

# dependencies for ML
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_validate

## SET USER DEFINITIONS

In [None]:
# root folder holding the PS imagery, and the year that training is performed on
# (keep the trailing slash: paths below are built by plain string concatenation)
data_direc = '/Users/jpflug/Documents/Projects/cubesatReanaly/Data/Meadows/STR/'
focus_year = 2023

# class written out for every pixel clicked in this session:
#   0 -> snow-absent gridcells
#   1 -> snow-covered gridcells
#   2 -> glint or other image errors/occlusions
snow_present = 2

# positions, within the sorted list of scene folders, of the images to click through
set_of_ten = [86,73,68,64,61,57,48,42,27,16]

# scene folders for the focus year: every directory whose path starts with data_direc + year
year_pattern = data_direc + str(focus_year) + '*'
subdirecs = sorted(filter(os.path.isdir, glob.glob(year_pattern)))
print('length of data record: ',len(subdirecs))

## FUNCTIONS -- DO NOT EDIT

In [None]:
#### functions

# split a raster into its bands and build a scaled RGB image for display
def calc_rgb(ds):
    """Return the raw red/green/blue/NIR bands of ``ds`` plus a scaled RGB stack.

    Parameters
    ----------
    ds : band-indexed raster
        Object supporting ``isel(band=...)`` (here, the result of
        ``rxr.open_rasterio``). Band positions 0, 1, 2 and 3 are read as blue,
        green, red and NIR respectively.

    Returns
    -------
    tuple
        ``(red, green, blue, nir, rgb_image)``. The first four are the unscaled
        band values as numpy arrays. ``rgb_image`` is a numpy array stacking the
        scaled red, green and blue bands along a new last axis.
    """
    # pull the individual bands out by position
    blue, green, red, nir = (ds.isel(band=idx) for idx in range(4))

    # all three visible bands are scaled with the green band's min and range
    lo = green.min().values
    span = green.max().values - lo
    scaled = [(band - lo) / span for band in (red, green, blue)]

    # wherever the scaled red is not <= 1, force every channel to 1;
    # the mask comes from red alone and no lower bound is applied
    keep = scaled[0] <= 1
    channels = [channel.where(keep, 1) for channel in scaled]

    # stack the scaled channels into an image and hand the raw bands back as numpy
    rgb_image = np.stack(channels, axis=-1)
    return red.values, green.values, blue.values, nir.values, rgb_image

# shared store for every click-classified pixel; the click handlers below append to it
# (re-running this cell discards all points collected so far)
clicked_points = list()

## START: USER DEFINITIONS OF SNOW PRESENCE, SNOW ABSENCE, AND IMAGE ARTIFACTS
#### Each cell plots a single image and allows the user to click-classify pixels in the image. 
#### The pixels selected from the PS imagery are stored in the list "clicked_points"

In [None]:
# load the image at position set_of_ten[0]: raw bands plus a display-ready RGB array
direcc = subdirecs[set_of_ten[0]]
fname = glob.glob(direcc+'/*/PSScene/*SR_clip.tif')[0]
ds = rxr.open_rasterio(fname)
red_band, green_band, blue_band, nir_band, rgb_image = calc_rgb(ds)

# Plot the image on its own figure/axes so clicks can be tied to this image only
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)  # no cmap: matplotlib ignores it for RGB input

# Function to handle mouse clicks
def onclick(event, ax=ax, bands=(red_band, green_band, blue_band, nir_band), rgb=rgb_image):
    """Append the pixel under a left click on this cell's image to ``clicked_points``.

    The axes and arrays are bound as default arguments when the cell runs, so this
    handler keeps sampling this cell's image even after a later cell reassigns the
    notebook-level ``red_band``/``rgb_image`` names.
    """
    # accept only left clicks inside this image's axes (xdata/ydata are None elsewhere)
    if event.button != 1 or event.inaxes is not ax:
        return
    # imshow centres pixel i on coordinate i, so round to the nearest index instead of truncating
    x = int(np.floor(event.xdata + 0.5))
    y = int(np.floor(event.ydata + 0.5))
    if not (0 <= y < rgb.shape[0] and 0 <= x < rgb.shape[1]):
        return
    # record layout: x, y, r, g, b, nir, r_norm, g_norm, b_norm
    clicked_points.append((x, y) + tuple(band[y, x] for band in bands)
                          + tuple(rgb[y, x, c] for c in range(3)))
    ax.plot(x, y, 'ro')  # Mark clicked point with red dot
    ax.figure.canvas.draw_idle()

# Connect the mouse click event to the onclick function
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
# load the image at position set_of_ten[1]: raw bands plus a display-ready RGB array
direcc = subdirecs[set_of_ten[1]]
fname = glob.glob(direcc+'/*/PSScene/*SR_clip.tif')[0]
ds = rxr.open_rasterio(fname)
red_band, green_band, blue_band, nir_band, rgb_image = calc_rgb(ds)

# Plot the image on its own figure/axes so clicks can be tied to this image only
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)  # no cmap: matplotlib ignores it for RGB input

# Function to handle mouse clicks
def onclick(event, ax=ax, bands=(red_band, green_band, blue_band, nir_band), rgb=rgb_image):
    """Append the pixel under a left click on this cell's image to ``clicked_points``.

    The axes and arrays are bound as default arguments when the cell runs, so this
    handler keeps sampling this cell's image even after a later cell reassigns the
    notebook-level ``red_band``/``rgb_image`` names.
    """
    # accept only left clicks inside this image's axes (xdata/ydata are None elsewhere)
    if event.button != 1 or event.inaxes is not ax:
        return
    # imshow centres pixel i on coordinate i, so round to the nearest index instead of truncating
    x = int(np.floor(event.xdata + 0.5))
    y = int(np.floor(event.ydata + 0.5))
    if not (0 <= y < rgb.shape[0] and 0 <= x < rgb.shape[1]):
        return
    # record layout: x, y, r, g, b, nir, r_norm, g_norm, b_norm
    clicked_points.append((x, y) + tuple(band[y, x] for band in bands)
                          + tuple(rgb[y, x, c] for c in range(3)))
    ax.plot(x, y, 'ro')  # Mark clicked point with red dot
    ax.figure.canvas.draw_idle()

# Connect the mouse click event to the onclick function
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
# load the image at position set_of_ten[2]: raw bands plus a display-ready RGB array
direcc = subdirecs[set_of_ten[2]]
fname = glob.glob(direcc+'/*/PSScene/*SR_clip.tif')[0]
ds = rxr.open_rasterio(fname)
red_band, green_band, blue_band, nir_band, rgb_image = calc_rgb(ds)

# Plot the image on its own figure/axes so clicks can be tied to this image only
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)  # no cmap: matplotlib ignores it for RGB input

# Function to handle mouse clicks
def onclick(event, ax=ax, bands=(red_band, green_band, blue_band, nir_band), rgb=rgb_image):
    """Append the pixel under a left click on this cell's image to ``clicked_points``.

    The axes and arrays are bound as default arguments when the cell runs, so this
    handler keeps sampling this cell's image even after a later cell reassigns the
    notebook-level ``red_band``/``rgb_image`` names.
    """
    # accept only left clicks inside this image's axes (xdata/ydata are None elsewhere)
    if event.button != 1 or event.inaxes is not ax:
        return
    # imshow centres pixel i on coordinate i, so round to the nearest index instead of truncating
    x = int(np.floor(event.xdata + 0.5))
    y = int(np.floor(event.ydata + 0.5))
    if not (0 <= y < rgb.shape[0] and 0 <= x < rgb.shape[1]):
        return
    # record layout: x, y, r, g, b, nir, r_norm, g_norm, b_norm
    clicked_points.append((x, y) + tuple(band[y, x] for band in bands)
                          + tuple(rgb[y, x, c] for c in range(3)))
    ax.plot(x, y, 'ro')  # Mark clicked point with red dot
    ax.figure.canvas.draw_idle()

# Connect the mouse click event to the onclick function
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
# load the image at position set_of_ten[3]: raw bands plus a display-ready RGB array
direcc = subdirecs[set_of_ten[3]]
fname = glob.glob(direcc+'/*/PSScene/*SR_clip.tif')[0]
ds = rxr.open_rasterio(fname)
red_band, green_band, blue_band, nir_band, rgb_image = calc_rgb(ds)

# Plot the image on its own figure/axes so clicks can be tied to this image only
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)  # no cmap: matplotlib ignores it for RGB input

# Function to handle mouse clicks
def onclick(event, ax=ax, bands=(red_band, green_band, blue_band, nir_band), rgb=rgb_image):
    """Append the pixel under a left click on this cell's image to ``clicked_points``.

    The axes and arrays are bound as default arguments when the cell runs, so this
    handler keeps sampling this cell's image even after a later cell reassigns the
    notebook-level ``red_band``/``rgb_image`` names.
    """
    # accept only left clicks inside this image's axes (xdata/ydata are None elsewhere)
    if event.button != 1 or event.inaxes is not ax:
        return
    # imshow centres pixel i on coordinate i, so round to the nearest index instead of truncating
    x = int(np.floor(event.xdata + 0.5))
    y = int(np.floor(event.ydata + 0.5))
    if not (0 <= y < rgb.shape[0] and 0 <= x < rgb.shape[1]):
        return
    # record layout: x, y, r, g, b, nir, r_norm, g_norm, b_norm
    clicked_points.append((x, y) + tuple(band[y, x] for band in bands)
                          + tuple(rgb[y, x, c] for c in range(3)))
    ax.plot(x, y, 'ro')  # Mark clicked point with red dot
    ax.figure.canvas.draw_idle()

# Connect the mouse click event to the onclick function
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
# load the image at position set_of_ten[4]: raw bands plus a display-ready RGB array
direcc = subdirecs[set_of_ten[4]]
fname = glob.glob(direcc+'/*/PSScene/*SR_clip.tif')[0]
ds = rxr.open_rasterio(fname)
red_band, green_band, blue_band, nir_band, rgb_image = calc_rgb(ds)

# Plot the image on its own figure/axes so clicks can be tied to this image only
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)  # no cmap: matplotlib ignores it for RGB input

# Function to handle mouse clicks
def onclick(event, ax=ax, bands=(red_band, green_band, blue_band, nir_band), rgb=rgb_image):
    """Append the pixel under a left click on this cell's image to ``clicked_points``.

    The axes and arrays are bound as default arguments when the cell runs, so this
    handler keeps sampling this cell's image even after a later cell reassigns the
    notebook-level ``red_band``/``rgb_image`` names.
    """
    # accept only left clicks inside this image's axes (xdata/ydata are None elsewhere)
    if event.button != 1 or event.inaxes is not ax:
        return
    # imshow centres pixel i on coordinate i, so round to the nearest index instead of truncating
    x = int(np.floor(event.xdata + 0.5))
    y = int(np.floor(event.ydata + 0.5))
    if not (0 <= y < rgb.shape[0] and 0 <= x < rgb.shape[1]):
        return
    # record layout: x, y, r, g, b, nir, r_norm, g_norm, b_norm
    clicked_points.append((x, y) + tuple(band[y, x] for band in bands)
                          + tuple(rgb[y, x, c] for c in range(3)))
    ax.plot(x, y, 'ro')  # Mark clicked point with red dot
    ax.figure.canvas.draw_idle()

# Connect the mouse click event to the onclick function
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
# load the image at position set_of_ten[5]: raw bands plus a display-ready RGB array
direcc = subdirecs[set_of_ten[5]]
fname = glob.glob(direcc+'/*/PSScene/*SR_clip.tif')[0]
ds = rxr.open_rasterio(fname)
red_band, green_band, blue_band, nir_band, rgb_image = calc_rgb(ds)

# Plot the image on its own figure/axes so clicks can be tied to this image only
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)  # no cmap: matplotlib ignores it for RGB input

# Function to handle mouse clicks
def onclick(event, ax=ax, bands=(red_band, green_band, blue_band, nir_band), rgb=rgb_image):
    """Append the pixel under a left click on this cell's image to ``clicked_points``.

    The axes and arrays are bound as default arguments when the cell runs, so this
    handler keeps sampling this cell's image even after a later cell reassigns the
    notebook-level ``red_band``/``rgb_image`` names.
    """
    # accept only left clicks inside this image's axes (xdata/ydata are None elsewhere)
    if event.button != 1 or event.inaxes is not ax:
        return
    # imshow centres pixel i on coordinate i, so round to the nearest index instead of truncating
    x = int(np.floor(event.xdata + 0.5))
    y = int(np.floor(event.ydata + 0.5))
    if not (0 <= y < rgb.shape[0] and 0 <= x < rgb.shape[1]):
        return
    # record layout: x, y, r, g, b, nir, r_norm, g_norm, b_norm
    clicked_points.append((x, y) + tuple(band[y, x] for band in bands)
                          + tuple(rgb[y, x, c] for c in range(3)))
    ax.plot(x, y, 'ro')  # Mark clicked point with red dot
    ax.figure.canvas.draw_idle()

# Connect the mouse click event to the onclick function
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
# load the image at position set_of_ten[6]: raw bands plus a display-ready RGB array
direcc = subdirecs[set_of_ten[6]]
fname = glob.glob(direcc+'/*/PSScene/*SR_clip.tif')[0]
ds = rxr.open_rasterio(fname)
red_band, green_band, blue_band, nir_band, rgb_image = calc_rgb(ds)

# Plot the image on its own figure/axes so clicks can be tied to this image only
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)  # no cmap: matplotlib ignores it for RGB input

# Function to handle mouse clicks
def onclick(event, ax=ax, bands=(red_band, green_band, blue_band, nir_band), rgb=rgb_image):
    """Append the pixel under a left click on this cell's image to ``clicked_points``.

    The axes and arrays are bound as default arguments when the cell runs, so this
    handler keeps sampling this cell's image even after a later cell reassigns the
    notebook-level ``red_band``/``rgb_image`` names.
    """
    # accept only left clicks inside this image's axes (xdata/ydata are None elsewhere)
    if event.button != 1 or event.inaxes is not ax:
        return
    # imshow centres pixel i on coordinate i, so round to the nearest index instead of truncating
    x = int(np.floor(event.xdata + 0.5))
    y = int(np.floor(event.ydata + 0.5))
    if not (0 <= y < rgb.shape[0] and 0 <= x < rgb.shape[1]):
        return
    # record layout: x, y, r, g, b, nir, r_norm, g_norm, b_norm
    clicked_points.append((x, y) + tuple(band[y, x] for band in bands)
                          + tuple(rgb[y, x, c] for c in range(3)))
    ax.plot(x, y, 'ro')  # Mark clicked point with red dot
    ax.figure.canvas.draw_idle()

# Connect the mouse click event to the onclick function
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
# load the image at position set_of_ten[7]: raw bands plus a display-ready RGB array
direcc = subdirecs[set_of_ten[7]]
fname = glob.glob(direcc+'/*/PSScene/*SR_clip.tif')[0]
ds = rxr.open_rasterio(fname)
red_band, green_band, blue_band, nir_band, rgb_image = calc_rgb(ds)

# Plot the image on its own figure/axes so clicks can be tied to this image only
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)  # no cmap: matplotlib ignores it for RGB input

# Function to handle mouse clicks
def onclick(event, ax=ax, bands=(red_band, green_band, blue_band, nir_band), rgb=rgb_image):
    """Append the pixel under a left click on this cell's image to ``clicked_points``.

    The axes and arrays are bound as default arguments when the cell runs, so this
    handler keeps sampling this cell's image even after a later cell reassigns the
    notebook-level ``red_band``/``rgb_image`` names.
    """
    # accept only left clicks inside this image's axes (xdata/ydata are None elsewhere)
    if event.button != 1 or event.inaxes is not ax:
        return
    # imshow centres pixel i on coordinate i, so round to the nearest index instead of truncating
    x = int(np.floor(event.xdata + 0.5))
    y = int(np.floor(event.ydata + 0.5))
    if not (0 <= y < rgb.shape[0] and 0 <= x < rgb.shape[1]):
        return
    # record layout: x, y, r, g, b, nir, r_norm, g_norm, b_norm
    clicked_points.append((x, y) + tuple(band[y, x] for band in bands)
                          + tuple(rgb[y, x, c] for c in range(3)))
    ax.plot(x, y, 'ro')  # Mark clicked point with red dot
    ax.figure.canvas.draw_idle()

# Connect the mouse click event to the onclick function
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
# load the image at position set_of_ten[8]: raw bands plus a display-ready RGB array
direcc = subdirecs[set_of_ten[8]]
fname = glob.glob(direcc+'/*/PSScene/*SR_clip.tif')[0]
ds = rxr.open_rasterio(fname)
red_band, green_band, blue_band, nir_band, rgb_image = calc_rgb(ds)

# Plot the image on its own figure/axes so clicks can be tied to this image only
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)  # no cmap: matplotlib ignores it for RGB input

# Function to handle mouse clicks
def onclick(event, ax=ax, bands=(red_band, green_band, blue_band, nir_band), rgb=rgb_image):
    """Append the pixel under a left click on this cell's image to ``clicked_points``.

    The axes and arrays are bound as default arguments when the cell runs, so this
    handler keeps sampling this cell's image even after a later cell reassigns the
    notebook-level ``red_band``/``rgb_image`` names.
    """
    # accept only left clicks inside this image's axes (xdata/ydata are None elsewhere)
    if event.button != 1 or event.inaxes is not ax:
        return
    # imshow centres pixel i on coordinate i, so round to the nearest index instead of truncating
    x = int(np.floor(event.xdata + 0.5))
    y = int(np.floor(event.ydata + 0.5))
    if not (0 <= y < rgb.shape[0] and 0 <= x < rgb.shape[1]):
        return
    # record layout: x, y, r, g, b, nir, r_norm, g_norm, b_norm
    clicked_points.append((x, y) + tuple(band[y, x] for band in bands)
                          + tuple(rgb[y, x, c] for c in range(3)))
    ax.plot(x, y, 'ro')  # Mark clicked point with red dot
    ax.figure.canvas.draw_idle()

# Connect the mouse click event to the onclick function
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
# load the image at position set_of_ten[9]: raw bands plus a display-ready RGB array
direcc = subdirecs[set_of_ten[9]]
fname = glob.glob(direcc+'/*/PSScene/*SR_clip.tif')[0]
ds = rxr.open_rasterio(fname)
red_band, green_band, blue_band, nir_band, rgb_image = calc_rgb(ds)

# Plot the image on its own figure/axes so clicks can be tied to this image only
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)  # no cmap: matplotlib ignores it for RGB input

# Function to handle mouse clicks
def onclick(event, ax=ax, bands=(red_band, green_band, blue_band, nir_band), rgb=rgb_image):
    """Append the pixel under a left click on this cell's image to ``clicked_points``.

    The axes and arrays are bound as default arguments when the cell runs, so this
    handler keeps sampling this cell's image even after a later cell reassigns the
    notebook-level ``red_band``/``rgb_image`` names.
    """
    # accept only left clicks inside this image's axes (xdata/ydata are None elsewhere)
    if event.button != 1 or event.inaxes is not ax:
        return
    # imshow centres pixel i on coordinate i, so round to the nearest index instead of truncating
    x = int(np.floor(event.xdata + 0.5))
    y = int(np.floor(event.ydata + 0.5))
    if not (0 <= y < rgb.shape[0] and 0 <= x < rgb.shape[1]):
        return
    # record layout: x, y, r, g, b, nir, r_norm, g_norm, b_norm
    clicked_points.append((x, y) + tuple(band[y, x] for band in bands)
                          + tuple(rgb[y, x, c] for c in range(3)))
    ax.plot(x, y, 'ro')  # Mark clicked point with red dot
    ax.figure.canvas.draw_idle()

# Connect the mouse click event to the onclick function
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

## END: USER DEFINITIONS OF SNOW PRESENCE, SNOW ABSENCE, AND IMAGE ARTIFACTS
#### Save the classified data

In [None]:
# save the classified data: one row per clicked pixel
df = pd.DataFrame(clicked_points, columns=['x', 'y','r','g','b','nir','r_norm','g_norm','b_norm'])
print(df)

# same mapping as the original if/elif/else: 0 -> class0, 1 -> class1, anything else -> class2
class_id = 0 if snow_present == 0 else (1 if snow_present == 1 else 2)
out_dir = data_direc+'self_classified/'
out_csv = out_dir+'self_classified_'+str(focus_year)+'_class'+str(class_id)+'.csv'

if df.empty:
    # nothing was clicked (e.g. a fresh "Run All"): do not overwrite labels saved earlier
    print('no points were clicked; leaving '+out_csv+' untouched')
else:
    # create the output folder on first use so the write cannot fail on a missing directory
    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(out_csv)

## TRAIN THE RANDOM FOREST MODEL USING THE USER-CLASSIFIED DATA

In [None]:
# every class file shares this path stem; only the '_classN.csv' suffix differs
csv_stem = data_direc+'self_classified/self_classified_'+str(focus_year)
# read the clicked-pixel table of each class and tag its rows with the class number
X0 = pd.read_csv(csv_stem+'_class0.csv').assign(label=0)
X1 = pd.read_csv(csv_stem+'_class1.csv').assign(label=1)
X2 = pd.read_csv(csv_stem+'_class2.csv').assign(label=2)
# concatenate the three classes into one training table
X = pd.concat([X0, X1, X2])
# user-defined labels as a flat array, row-aligned with X
y = X['label'].to_numpy()
# preview the dataframe
X

## CONSTRUCT AND PERFORM THE RF MODEL
#### Model motivated by the one constructed by Yang et al. (2023): https://github.com/KehanGit/High_resolution_snow_cover_mapping

In [None]:
# set up the random forest and the repeated stratified k-fold splitter
model = RandomForestClassifier(n_estimators=10, max_depth=10, max_features=3, random_state=1)
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=1000, random_state=1)

# score both metrics in one cross-validation pass, so each fold's model is fitted once
# instead of once per metric; splits and seeds are unchanged, so the scores are too
cv_scores = cross_validate(model, X[['b','g','r','nir']], y, scoring=['accuracy', 'balanced_accuracy'],
                           cv=cv, n_jobs=-1, error_score='raise')
n_accuracy = cv_scores['test_accuracy']
n_balanced_accuracy = cv_scores['test_balanced_accuracy']

# report performance; the first number counts every evaluated fold (n_splits * n_repeats)
print('Repeat times:', len(n_accuracy))
print('Balanced Accuracy: %.5f (%.5f)' % (n_balanced_accuracy.mean(), n_balanced_accuracy.std()))
print('Accuracy: %.5f (%.5f)' % (n_accuracy.mean(), n_accuracy.std()))

In [None]:
# train on every labelled pixel, using the same four raw bands scored during cross-validation
feature_columns = ['b','g','r','nir']
model.fit(X[feature_columns], y)
# persist the fitted model next to the classified-point CSVs
model_path = data_direc+'self_classified/3class_model.joblib'
joblib.dump(model, model_path)