# Random Forest

## Setup

Files and system

In [1]:
import os
import json

Arrays and math

In [2]:
import numpy as np
import xarray as xr
import pandas as pd
import geopandas as gpd

In [3]:
import random

Plotting

In [4]:
import matplotlib.pyplot as plt

Raster operations

In [5]:
import rasterio
import rioxarray

Modelling

In [32]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

Directories

In [6]:
DATA_DIR = '../data/'
FIGURES_DIR = '../saved_figures/'

Habitat label code

In [7]:
habtype_code = pd.read_csv(os.path.join(DATA_DIR, 'habtype_ids.csv'))[['id', 'HabType']]

Tile ID list

In [8]:
with open(os.path.join(DATA_DIR, 'tile_buffer_wkt.json'), 'r') as f:
    tile_list = list(json.load(f).keys())

Class sampling stratification: number of samples per class per tile

In [9]:
N_CLASS_SAMPLES = 5

## Prepare dataset

In [18]:
def stratified_sample(g, n_class_samples, background_multiplier):
    class_id = g.name[1] if isinstance(g.name, tuple) else g.name
    n = int(background_multiplier * n_class_samples) if class_id == -1 else n_class_samples
    return g.sample(n=min(n, len(g)), random_state=42)

In [19]:
def process_tile(tile_ds, n_class_samples, background_multiplier=0.5):
    tile_ds = tile_ds.drop_vars(['bitmask'])
    
    # filter based on unknown_flag
    valid_subtiles = (
        tile_ds['subtile_in']
        .where(~tile_ds['unknown_flag'])
        .dropna(dim='subtile_in')
        .values.flatten()
    )
    tile_ds = tile_ds.sel(subtile_in=valid_subtiles)

    # get habitat class map
    habitat_vars = [f"habitat_{i}" for i in habtype_code['id'].values]
    habitat_stack = tile_ds[habitat_vars].to_array(dim='habitat_class').fillna(0)

    # define presence and background
    presence_mask = habitat_stack.any(dim='habitat_class')
    class_map = xr.where(presence_mask, habitat_stack.argmax(dim='habitat_class'), -1)

    # apply coastal buffer
    mask = tile_ds['coastal_buffer'] == 1
    features = tile_ds.drop_vars(habitat_vars).where(mask)
    features['habitat_class'] = class_map.where(mask)

    # flatten
    flat = features.stack(pixel=('subtile_in', 'y', 'x')).dropna(dim='pixel', how='any')
    df = flat.to_dataframe().drop(columns=['y', 'x', 'subtile_in']).reset_index()

    # sample pixels for each habitat class
    df_sampled = (
        df.groupby(['subtile_in', 'habitat_class'])
        .apply(stratified_sample,
               include_groups=False,
               n_class_samples=n_class_samples,
               background_multiplier=background_multiplier)
    ).drop(columns=['y', 'x', 'lat', 'lon', 'unknown_flag', 'coastal_buffer']).reset_index()

    return df_sampled

In [20]:
habitat_vars = [f"habitat_{i}" for i in habtype_code['id'].values]

In [21]:
tile_samples = []

In [22]:
for tileId in tile_list:
    ds_fp = os.path.join(DATA_DIR, f's2_processed/{tileId}/data_cube.nc')
    if not os.path.exists(ds_fp):
        continue

    print(f"Processing tile {tileId}")

    # Load dataset
    tile_ds = xr.open_dataset(ds_fp)

    # Process tile
    df_sampled = process_tile(tile_ds, N_CLASS_SAMPLES)
    df_sampled = df_sampled.drop(columns=['level_2', 'tileId', 'subtile_in'])
    tile_samples.append(df_sampled)

    print(f"Sampled tile {tileId}")

Processing tile 29UPR
Sampled tile 29UPR
Processing tile 30UUA
Sampled tile 30UUA
Processing tile 30UVA
Sampled tile 30UVA
Processing tile 30UVB
Sampled tile 30UVB
Processing tile 30UWB
Sampled tile 30UWB
Processing tile 30UXA
Sampled tile 30UXA
Processing tile 30UVC
Sampled tile 30UVC
Processing tile 30UUC
Sampled tile 30UUC
Processing tile 30UUD
Sampled tile 30UUD
Processing tile 30UUE
Sampled tile 30UUE
Processing tile 30UVE
Sampled tile 30UVE
Processing tile 30UUF
Sampled tile 30UUF
Processing tile 30UVF
Sampled tile 30UVF
Processing tile 30UUG
Sampled tile 30UUG
Processing tile 29UPB
Sampled tile 29UPB
Processing tile 30VUH
Sampled tile 30VUH
Processing tile 30VUJ
Sampled tile 30VUJ
Processing tile 30VUK
Sampled tile 30VUK
Processing tile 30VVK
Sampled tile 30VVK
Processing tile 30VVL
Sampled tile 30VVL
Processing tile 30VWL
Sampled tile 30VWL
Processing tile 29VPC
Sampled tile 29VPC
Processing tile 29VPE
Sampled tile 29VPE
Processing tile 32UMF
Sampled tile 32UMF
Processing tile 

In [23]:
combined_samples = pd.concat(tile_samples, ignore_index=True)

In [24]:
combined_samples['habitat_class'] = combined_samples['habitat_class'].astype(int)

In [27]:
combined_samples['habitat_class'].value_counts()

habitat_class
-1     8674
 6     1570
 0     1156
 4      663
 8      627
 11     467
 7      455
 10     256
 13     135
 9       20
 1       10
 5       10
 14       5
Name: count, dtype: int64

Too few samples for habitat classes 9, 1 and 5.

In [41]:
habtype_code[habtype_code['id'].isin((9, 1, 5))]

Unnamed: 0,id,HabType
1,1,Lophelia pertusa reefs
5,5,Deep-sea sponge aggregations
9,9,Ostrea edulis beds


In [45]:
combined_samples = combined_samples[~combined_samples['habitat_class'].isin((9, 1, 5, 14))]

In [46]:
combined_samples['habitat_class'].value_counts()

habitat_class
-1     8674
 6     1570
 0     1156
 4      663
 8      627
 11     467
 7      455
 10     256
 13     135
Name: count, dtype: int64

In [47]:
combined_samples.to_csv(os.path.join(DATA_DIR, 'multi_class_dataset.csv'))

Close any datasets

In [48]:
tile_ds.close()

## Random Forest model

In [49]:
combined_samples.head()

Unnamed: 0,habitat_class,Rnir,Rgli,logchl,logfb,Rw443,Rw490,Rw560,Rw665,Rw705,Rw740,Rw783,Rw842,Rw865,Rw1610
0,-1,0.031737,0.02972,-0.064695,0.691891,0.031351,0.038182,0.023841,0.005025,0.002126,-0.001619,-0.000183,0.004902,0.001417,0.00979
1,-1,0.036553,0.030647,0.00147,0.767847,0.033966,0.042742,0.030115,0.005528,0.002342,-0.00136,0.001626,0.002456,0.000558,-0.004591
2,-1,0.034345,0.030222,-0.041288,0.727521,0.033409,0.039947,0.027236,0.00501,0.000764,-0.000352,0.001592,0.001467,0.000243,-0.000133
3,-1,0.040351,0.029758,0.19962,0.667138,0.021745,0.030908,0.026212,0.006919,0.001078,-0.00113,0.00222,0.006179,0.000556,0.010109
4,-1,0.046465,0.030413,0.056508,0.663881,0.025654,0.032844,0.024881,0.0048,0.001852,-0.000233,0.000246,0.000205,0.000867,0.003403


In [62]:
feature_cols = list(combined_samples.columns)[1:]

In [64]:
X = combined_samples[feature_cols]
y = combined_samples['habitat_class']

In [67]:
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=42
)

In [69]:
clf = RandomForestClassifier(
    n_estimators=100,
    max_depth=None,
    class_weight='balanced',
    random_state=42,
    n_jobs=-1
)

In [70]:
clf.fit(X_train, y_train)

In [71]:
y_pred = clf.predict(X_test)

In [72]:
print(confusion_matrix(y_test, y_pred))

[[1715    7    0    5    0    3    0    4    1]
 [ 213   12    0    1    0    3    0    2    0]
 [ 131    0    0    2    0    0    0    0    0]
 [ 304    1    1    6    0    0    0    2    0]
 [  88    0    0    1    1    0    0    1    0]
 [ 113    2    0    0    0   10    0    1    0]
 [  50    0    0    0    0    0    1    0    0]
 [  76    3    0    0    0    0    0   14    0]
 [  25    0    0    1    0    0    0    0    1]]


In [73]:
print(classification_report(y_test, y_pred, zero_division=0))

              precision    recall  f1-score   support

          -1       0.63      0.99      0.77      1735
           0       0.48      0.05      0.09       231
           4       0.00      0.00      0.00       133
           6       0.38      0.02      0.04       314
           7       1.00      0.01      0.02        91
           8       0.62      0.08      0.14       126
          10       1.00      0.02      0.04        51
          11       0.58      0.15      0.24        93
          13       0.50      0.04      0.07        27

    accuracy                           0.63      2801
   macro avg       0.58      0.15      0.16      2801
weighted avg       0.58      0.63      0.51      2801

