## Install the packages

In [None]:
# Change to your version of CUDA
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

UsageError: Line magic function `%pip3` not found.


In [None]:
# If gdal not installed, run: sudo apt install libgdal-dev gdal-bin
import subprocess

# Get the version of GDAL installed on the system via gdal-config
gdal_version = subprocess.check_output(["gdal-config", "--version"]).decode("utf-8").strip()

# Use pip to install the corresponding GDAL version
%pip install gdal=={gdal_version}

# Downgrade numpy to < 2.0.0
%pip install numpy==1.26.4

# Get cv2
%pip install opencv-python

# Jupyter specifics
%pip install ipywidgets

Note: you may need to restart the kernel to use updated packages.


UsageError: Line magic function `%pip3` not found.


In [None]:
# Install detectree2 from the default branch of its GitHub repository.
# Tested on version 2.0.1 for detectree2
# NOTE(review): the URL is not pinned, so a later run may install a newer revision
# than the tested 2.0.1; append @<tag-or-commit> to the URL to make it reproducible.
%pip install git+https://github.com/PatBall1/detectree2.git

## Train

**Important:**
1. Delete the `eval/` folder whenever the dataset changes.
2. Make sure the GPU has enough VRAM for the chosen base model.

In [None]:
# Test package imports, do not proceed if there are errors
import os
import rasterio
import geopandas as gpd
from detectron2.data.catalog import DatasetCatalog
from detectree2.preprocessing.tiling import tile_data, to_traintest_folders
from detectree2.models.train import register_train_data, MyTrainer, setup_cfg

In [None]:
# Base path for all sites. Each immediate subfolder is one site, expected to hold
# rgb/<site>.tif (RGB raster) and crowns/<site>.gpkg (crown polygons).
base_path = "../dataset"

# Parameters (change these as needed). buffer / tile size / threshold are passed
# straight to detectree2's tile_data (presumably in raster CRS units — TODO confirm).
buffer = 10
tile_width = 15
tile_height = 15
threshold = 0.5
appends = f"{tile_width}_{buffer}_{threshold}"  # suffix of each site's tile folder name
test_frac = 0.15  # test-split fraction for to_traintest_folders
folds = 5         # number of train folds; fold number `folds` is used for validation


def get_site_name(folder_path):
    """Return the site name, i.e. the last component of ``folder_path``."""
    return os.path.basename(folder_path)


def _prepare_site(site_folder):
    """Tile one site, split it into train/test folds and register it with detectron2.

    Returns the dataset base name (registered as ``<name>_train`` / ``<name>_val``),
    or None when the site's image or crown file is missing. Other failures raise.
    """
    site_path = os.path.join(base_path, site_folder)
    site_name = get_site_name(site_path)
    img_path = os.path.join(site_path, "rgb", f"{site_folder}.tif")
    crown_path = os.path.join(site_path, "crowns", f"{site_folder}.gpkg")
    tiles_dir = os.path.join(site_path, f"tiles_{appends}/")

    print(f"Processing site: {site_folder}")

    # Ensure the image and crown files exist
    if not os.path.exists(img_path):
        print(f"  Warning: Image file not found at {img_path}")
        return None
    if not os.path.exists(crown_path):
        print(f"  Warning: Crown file not found at {crown_path}")
        return None

    # Only the CRS is needed from the raster, so close the file handle right away
    with rasterio.open(img_path) as raster:
        raster_crs = raster.crs
    if raster_crs is None:
        raise ValueError(f"{img_path} has no CRS; cannot align crowns to it")

    # Reproject crowns to the raster's CRS. WKT is a complete CRS description,
    # unlike the PROJ.4-style dict returned by CRS.data.
    crowns = gpd.read_file(crown_path).to_crs(raster_crs.to_wkt())

    os.makedirs(tiles_dir, exist_ok=True)

    print(f"  Running tile_data for {site_folder}...")
    tile_data(img_path, tiles_dir, buffer, tile_width, tile_height, crowns, threshold, mode="rgb")

    print(f"  Running to_traintest_folders for {site_folder}...")
    to_traintest_folders(tiles_dir, tiles_dir, test_frac=test_frac, strict=False, folds=folds)

    dataset_name = f"oilpalm_msia_{site_name}"
    train_location = os.path.join(tiles_dir, "train")

    # detectron2 refuses to register a name twice. Drop entries left by an earlier
    # run of this cell so re-running re-registers them against the current tiles
    # instead of failing (which previously left the dataset lists empty).
    for split in ("train", "val"):
        existing = f"{dataset_name}_{split}"
        if existing in DatasetCatalog.list():
            DatasetCatalog.remove(existing)

    print(f"  Registering dataset: {dataset_name}")
    register_train_data(train_location, dataset_name, val_fold=folds)

    print(f"  Completed processing {site_folder}")
    return dataset_name


# All immediate subdirectories of the base path are sites; sort for a
# deterministic processing and registration order.
site_folders = sorted(
    f for f in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, f))
)
print(f"Found {len(site_folders)} site folders: {site_folders}")

registered_names = []
for site_folder in site_folders:
    try:
        dataset_name = _prepare_site(site_folder)
    except Exception as e:
        # Best effort: report the failing site and carry on with the others
        print(f"  Error processing {site_folder}: {type(e).__name__}: {e}")
        continue
    if dataset_name is not None:
        registered_names.append(dataset_name)

# Detectron2 expects tuples of registered dataset names
all_train_datasets = tuple(f"{name}_train" for name in registered_names)
all_val_datasets = tuple(f"{name}_val" for name in registered_names)

print("All sites processed.")
print("Available datasets:", DatasetCatalog.list())
print("Registered train datasets:", all_train_datasets)
print("Registered validation datasets:", all_val_datasets)

# Stop here rather than building a training config with no data
if not all_train_datasets:
    raise RuntimeError(f"No site under {base_path!r} was processed successfully; nothing to train on.")

# Setup base model from facebookresearch/detectron2
# https://github.com/facebookresearch/detectron2/blob/main/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml
# Need at least 6GB of VRAM
base_model = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"

# Setup configuration with all datasets
out_dir = "train_outputs"
cfg = setup_cfg(
    base_model,
    all_train_datasets,  # Use all registered train datasets
    all_val_datasets,    # Use all registered validation datasets
    workers=4,
    eval_period=100,
    max_iter=3000,
    out_dir=out_dir      # Default is "train_outputs"
)

# Check evaluation datasets configuration
print("Test datasets:", cfg.DATASETS.TEST)

In [None]:
# Early-stopping patience handed to detectree2's MyTrainer
PATIENCE = 5

# resume=False: start an independent run from cfg.MODEL.WEIGHTS instead of
# resuming from the last checkpoint in the output folder.
trainer = MyTrainer(cfg, patience=PATIENCE)
trainer.resume_or_load(resume=False)
trainer.train()

## Evaluation

Plots the training and validation loss curves recorded in `metrics.json` in the `train_outputs` folder.

In [None]:
import matplotlib.pyplot as plt
from detectree2.models.train import load_json_arr

# metrics.json in the training output folder (the out_dir given to setup_cfg)
METRICS_PATH = "train_outputs/metrics.json"


def last_run(metrics):
    """Return only the entries written by the most recent training run.

    detectron2's JSONWriter appends to an existing metrics.json, so training again
    in the same folder restarts the iteration counter part-way through the file.
    Everything before the last such restart is dropped. Entries without an
    "iteration" key are ignored when looking for restarts.
    """
    start = 0
    previous = None
    for index, entry in enumerate(metrics):
        iteration = entry.get("iteration")
        if iteration is None:
            continue
        if previous is not None and iteration < previous:
            start = index
        previous = iteration
    return metrics[start:]


def series(metrics, key):
    """Return (iterations, values) for every entry that logged ``key``."""
    rows = [(entry["iteration"], entry[key]) for entry in metrics if key in entry]
    return [it for it, _ in rows], [value for _, value in rows]


experiment_metrics = last_run(load_json_arr(METRICS_PATH))

fig, ax = plt.subplots()
ax.plot(*series(experiment_metrics, "validation_loss"), label="Total Validation Loss", color="red")
ax.plot(*series(experiment_metrics, "total_loss"), label="Total Training Loss")
ax.legend(loc="upper right")
ax.set_title("Comparison of the training and validation loss of detectree2")
ax.set_ylabel("Total Loss")
ax.set_xlabel("Number of Iterations")
plt.show()