In [None]:
"""
Hyperparameter optimization for Resnet-18 classification of roof materials using 'Optuna'
"""

import os, sys, time, glob
import geopandas as gpd
import pandas as pd
import rioxarray as rxr
import xarray as xr
import numpy as np
import rasterio as rio
import matplotlib.pyplot as plt
import seaborn as sns
import gc
import torch
import torch.nn as nn
import torch.optim as optim
import optuna

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision import transforms, utils
from torchsat.models.classification import resnet18

from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from sklearn.metrics import roc_auc_score, log_loss, roc_auc_score, roc_curve, auc, f1_score, precision_score, recall_score
from sklearn.model_selection import StratifiedKFold

from fiona.crs import from_epsg
from shapely.geometry import box
from os.path import join

import warnings
warnings.filterwarnings("ignore")

# Custom functions
# Make the project's 'code/' folder importable. The path is built from the
# working directory at the time this cell runs, which is before the later
# os.chdir() cell.
sys.path.append(os.path.join(os.getcwd(),'code/'))
# NOTE(review): star import hides where names come from; balance_sampling and
# split_training_data are used below but not defined in this notebook, so they
# presumably come from here -- confirm and import them explicitly.
from __functions import *

# Projection information
# CRS objects built from EPSG codes 4326 and 32618
wgs = from_epsg(4326)
proj = from_epsg(32618)
print(f'Projected CRS: {proj}')

# maindir = '/Users/max/Library/CloudStorage/OneDrive-Personal/mcook/earth-lab/opp-rooftop-mapping'

print("Successfully imported all packages!")

In [None]:
class RoofImageDatasetPlanet(Dataset):
    """Class to handle PlanetScope SuperDove imagery for Resnet-18

    Each item is the (n_bands, img_dim, img_dim) window of the image centered
    on a footprint centroid, paired with that footprint's numeric class code.
    """

    # Label given to placeholder samples so they can be filtered out downstream
    INVALID_CODE = 255

    def __init__(self, gdf, img_path, n_bands, img_dim, transform=None, skip_invalid=False):
        """
        Args:
            gdf: Geodataframe containing 'geometry' column and numeric 'code' column
            img_path: the path to the PlanetScope SuperDove composite image (single mosaic file)
                - see 'psscene-prep.py' for spectral indices calculation
            n_bands (int): Number of bands to read from the image (bands 1..n_bands)
            img_dim (int): Image dimension (window width/height in pixels) for CNN implementation
            transform (callable, optional): Optional transform to be applied on a sample
            skip_invalid (bool, optional): If False (default) a sample that cannot be
                read raises ValueError. If True it is replaced by an all-zero image
                labelled INVALID_CODE (255) instead of raising.

        Returns image chunks with class labels

        Raises:
            ValueError: if 'img_path' does not exist
        """

        if not os.path.exists(img_path):
            raise ValueError(f'Image does not exists: {img_path}')

        self.geometries = [p.centroid for p in gdf.geometry.values]  # gather centroid geoms
        self.img_path = img_path  # path to image data
        self.img_dim = img_dim  # window dimension in pixels
        self.n_bands = n_bands  # number of bands in the input image
        self.Y = gdf.code.values  # class codes (numeric)
        self.transform = transform
        self.skip_invalid = skip_invalid

    def __len__(self):
        """Number of footprints (one sample per centroid)."""
        return len(self.geometries)

    def __getitem__(self, idx):
        """
        Return {'image': FloatTensor, 'code': LongTensor} for sample 'idx'.

        Raises:
            ValueError: when the sample cannot be produced and 'skip_invalid' is False
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Get the geometry of the idx (centroid)
        geom = self.geometries[idx]
        expected_shape = (self.n_bands, self.img_dim, self.img_dim)

        try:
            sample = self.sample_image(geom)  # run the sampling function

            # Explicit check rather than `assert`, which is stripped under `python -O`
            if sample.shape != expected_shape:
                raise ValueError(f'Invalid sample shape: {sample.shape}')

            # Plain int conversion; a uint8 cast would silently wrap codes > 255
            cc = int(self.Y[idx])

            if self.transform:
                sample = self.transform(sample)

        except Exception as e:
            if not self.skip_invalid:
                # Keep the original exception chained for debugging
                raise ValueError(f'Failed to build sample at index {idx}: {e}') from e
            print(f"Skipping invalid sample at index: {idx}")
            sample = np.zeros(expected_shape, dtype=np.float32)
            cc = self.INVALID_CODE

        # Convert the sample to a Torch object; a transform may already have
        # produced a tensor, in which case no numpy conversion is needed
        if not torch.is_tensor(sample):
            sample = torch.from_numpy(np.ascontiguousarray(sample, dtype=np.float32))

        # Return the sample and the label as torch objects
        return {'image': sample.type(torch.FloatTensor),
                'code': torch.tensor(cc).type(torch.LongTensor)}

    def sample_image(self, geom):
        """
        Sample the image at a geometry for the specified image chunk size (window)

        Args:
            geom: point-like object exposing '.x' and '.y' in the image's coordinates

        Returns: numpy array holding bands 1..n_bands read from the window
        """

        N = self.img_dim  # window size to be used for cropping

        with rio.open(self.img_path) as src:
            # Row/column of the pixel containing the point
            py, px = src.index(geom.x, geom.y)
            # N x N window centered on that pixel
            window = rio.windows.Window(px - N // 2, py - N // 2, N, N)
            clip = src.read(window=window, indexes=list(range(1, self.n_bands + 1)))

        # The former "mask zeros, then fill with zero" step returned the same
        # values unchanged, so the clip is returned directly
        return np.array(clip)


def make_good_batch(batch):
    """
    Drop flagged samples from a batch.
    Args:
        - batch: dictionary with an 'image' tensor and a 'code' tensor that share
          their first dimension
    returns: new dictionary with 'image' and 'code' limited to the entries whose
        code is not 255
    """

    # Positions along the first dimension whose label is not the 255 flag
    keep_rows = torch.where(batch['code'] != 255)[0]

    return {field: batch[field][keep_rows] for field in ('image', 'code')}

print("Functions ready !!")

In [None]:
# Switch the working directory; the relative 'opp-data/...' paths used in later
# cells are resolved against it.
# NOTE(review): hardcoded absolute path -- the notebook only runs where
# '/home/jovyan' exists.
os.chdir('/home/jovyan')
print(os.getcwd())
print(os.listdir(os.getcwd()))

In [None]:
# Prepare the footprint data

In [None]:
# Load the training data (footprints)
# Path is relative to the current working directory
ref_path = 'opp-data/dc_data_reference_footprints.gpkg'
ref = gpd.read_file(ref_path)
# Preview the first rows
ref.head()

In [None]:
# Observe the class imbalance in the reference data
# (number of footprints per value of the 'description' column)
print(f"Class counts:\n\n{ref.description.value_counts()}\n")

In [None]:
# Add an integer 'code' column derived from 'class_code'; the array of unique
# values returned by factorize is discarded
ref['code'], _ = pd.factorize(ref['class_code']) # create a factorized version
print(ref['class_code'].value_counts())  # check the counts per original class_code

In [None]:
# Lookup tables keyed by class_code: one to the factorized code, one to the
# text description
code_mapping = {
    class_code: code
    for class_code, code in zip(ref['class_code'], ref['code'])
}
desc_mapping = {
    class_code: description
    for class_code, description in zip(ref['class_code'], ref['description'])
}
print(f'Code map: \n{code_mapping}\nDescription map: \n{desc_mapping}')

In [None]:
# Perform balanced sampling (random undersampling)
# NOTE(review): balance_sampling is not defined in this notebook; presumably it
# comes from the `__functions` star import -- confirm what `ratio=20` means there.
ref_bal = balance_sampling(ref, ratio=20, strategy='undersample')
# Per-class sample counts after balancing
ref_bal.code.value_counts()

In [None]:
# Split the train/test data
# NOTE(review): split_training_data is not defined in this notebook; presumably
# it comes from the `__functions` star import, with `ts`/`vs` as the test and
# validation fractions -- confirm.
train_df, val_df, test_df = split_training_data(ref_bal, ts=0.4, vs=0.2)

# Print the class distribution in each split to verify stratification
print("Train class distribution:\n", train_df['code'].value_counts())
print("Validation class distribution:\n", val_df['code'].value_counts())
print("Test class distribution:\n", test_df['code'].value_counts())

In [None]:
# Prepare the model parameter testing

In [None]:
# Load our image data to check on the format
stack_da_fp = 'opp-data/dc_0623_psscene8b_final_norm.tif'

# `masked=True` is the open_rasterio keyword that replaces nodata with NaN; the
# previous `mask=True` is not an open_rasterio parameter, so it was passed on as
# a raster open option and nothing was masked.
stack_da = rxr.open_rasterio(stack_da_fp, masked=True, cache=False).squeeze()

# Band count from the array's shape metadata; going through `.values` would
# read the whole raster just to get this number
n_bands = stack_da.shape[0]
print(
    f"shape: {stack_da.rio.shape}\n"
    f"bands: {n_bands}\n"
    f"resolution: {stack_da.rio.resolution()}\n"
    f"bounds: {stack_da.rio.bounds()}\n"
    f"sum: {stack_da.sum().item()}\n"
    f"CRS: {stack_da.rio.crs}\n"
    f"NoData: {stack_da.rio.nodata}\n"
    f"Array: {stack_da}"
)
# Only the file path and band count are needed later; release the array
del stack_da

In [None]:
# Build the Resnet-18 classifier

# Compute device: CUDA when it is available, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} for model dev ...')

# Class count comes from the balanced reference set, band count from the image
n_classes = len(ref_bal.class_code.unique())
print(f'There are {n_classes} roof type classes.')
print(f'Using {n_bands} bands for classification.')

# Resnet-18 with one input channel per image band, no pretrained weights
model = resnet18(n_classes, in_channels=n_bands, pretrained=False)

# Wrap in DataParallel in both cases; only the GPU path moves the model
if torch.cuda.device_count() < 1:
    model = nn.DataParallel(model)
    print('Made cpu parallel')
else:
    print("Using ", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
    model.to(device)

In [None]:
# Number of samples in each class, ordered by class code.
# value_counts() sorts by frequency, so without sort_index() the i-th weight
# would not belong to class i, which is how CrossEntropyLoss indexes `weight`.
val_counts = train_df['code'].value_counts().sort_index().tolist()
total_samples = sum(val_counts) # total number of samples
print(f'Total samples: {total_samples};\nValue counts: {val_counts}')

# Calculate class weights (inverse class frequency)
class_weights = [total_samples / count for count in val_counts]
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
print(f"Class weights: {class_weights}")

# Rescale so the weights sum to 1
class_weights_norm = class_weights / class_weights.sum()
print(f"Normalized class weights: {class_weights_norm}")

# Updated loss function with weights
criterion = nn.CrossEntropyLoss(weight=class_weights_norm).to(device)

In [None]:
# Run a garbage-collection pass before defining and running the tuning objective
gc.collect()

In [None]:
# Row labels of the validation split; the label now names the frame that is printed
print(f"Validation DataFrame indices: {val_df.index}")

In [None]:
# Alias for the image path.
# NOTE(review): `imdir` is not referenced anywhere else in the visible notebook;
# objective() reads `stack_da_fp` directly.
imdir = stack_da_fp

# Width/height in pixels of the image chips, passed as `img_dim` in objective()
window_size = 64

def objective(trial, n_epochs=10):
    """
    Function for fine-tuning Resnet-18 model using 'optuna' Python package.
    Trains a freshly initialized network with the trial's hyperparameters and
    scores it by accuracy on the validation split.

    Args:
        - trial: Optuna trial, used to sample 'batch_size', 'lr' and 'weight_decay'
        - n_epochs: number of passes over the training split per trial (default 10)

    Returns: fraction of validation samples classified correctly

    Raises: RuntimeError if no validation sample could be evaluated
    """

    # Suggest hyperparameters to test (lr and weight decay on a log scale).
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
    batch_size = trial.suggest_int('batch_size', 128, 256)
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3, log=True)

    # Build a new network for every trial. Reusing the notebook-level `model`
    # would let each trial continue from the previous trial's weights, making
    # the scores of different hyperparameter sets incomparable.
    trial_model = nn.DataParallel(resnet18(n_classes, in_channels=n_bands, pretrained=False))
    trial_model.to(device)

    # Pinned host memory only helps when batches are copied to a GPU
    pin = device.type == 'cuda'

    # Create the training samples (dataset class defined earlier in this notebook)
    train_ds = RoofImageDatasetPlanet(train_df[['geometry', 'code']], stack_da_fp, n_bands=n_bands, img_dim=window_size)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin)

    # Create the validation samples; order does not affect accuracy, so no shuffling
    val_ds = RoofImageDatasetPlanet(val_df[['geometry', 'code']], stack_da_fp, n_bands=n_bands, img_dim=window_size)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.SGD(trial_model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

    # Training loop
    trial_model.train()
    for epoch in range(n_epochs):
        for batch in train_loader:
            # Drop samples flagged as invalid (code 255)
            batch = make_good_batch(batch)
            if batch['code'].numel() == 0:
                continue  # nothing usable in this batch

            inputs, labels = batch['image'].to(device), batch['code'].to(device)
            optimizer.zero_grad()
            outputs = trial_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # Validation loop
    trial_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            # Drop samples flagged as invalid (code 255)
            batch = make_good_batch(batch)
            if batch['code'].numel() == 0:
                continue

            inputs, labels = batch['image'].to(device), batch['code'].to(device)
            predicted = trial_model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # A clear error instead of a ZeroDivisionError
        raise RuntimeError('No valid validation samples were evaluated in this trial.')

    return correct / total

print("Ready !!")

In [None]:
# Create an Optuna study
t0 = time.time()  # wall-clock start, for the elapsed-time report below

# 'maximize' because objective() returns validation accuracy
# NOTE(review): no sampler seed is set, so the sampled hyperparameters will
# likely differ between runs -- confirm whether reproducibility matters here.
study = optuna.create_study(study_name="Resnet-18 Hyperparameter Tuning", direction='maximize')
study.optimize(objective, n_trials=10)

# Display the best hyperparameters and accuracy
print("Best hyperparameters:", study.best_params)
print("Best accuracy:", study.best_value)

# Elapsed time in minutes
t1 = (time.time() - t0) / 60
print(f"Total elapsed time: {t1:.2f} minutes.")
print("\n~~~~~~~~~~\n")

In [None]:
# Run a garbage-collection pass after the study has finished
gc.collect()