Steps 2 and 3 of the fast-ice detection workflow: SAM segmentation and SVM classification

ARE Settings:

Queue: gpuvolta
Compute size: 1 gpu
Jobfs size: 10GB

Conda environment: sam_env

In [None]:
#Steps 2 & 3: SAM + SVM

import copy
import math
import os
import pathlib
import pdb
import re
from pathlib import Path

import cv2
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import supervision as sv
import torch
from rasterio.features import rasterize
from rasterio.transform import Affine
from rasterio.warp import reproject, Resampling
from scipy.ndimage import minimum_filter1d
from scipy.ndimage import zoom
from shapely.geometry import box
from skimage.transform import resize
from sklearn import svm

In [None]:
# SAM 1.1. Enable GPU
# IPython shell escape (only valid inside a notebook): runs the `nvidia-smi`
# command so the GPU status is shown in the cell output.
!nvidia-smi

# Set working directory
os.chdir('/g/data/jk72/gb4219/honours_data/Normprod_Demo/')
# Path to the SAM model weights ('vit_h' checkpoint)
CHECKPOINT_PATH = os.path.join("/", "g", "data", "jk72", "gb4219", "honours_data", "Normprod_Demo", "sam_vit_h_4b8939.pth")
# Use the first CUDA device when one is available, otherwise fall back to the CPU
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"

# Model specifications

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Build the model registered under MODEL_TYPE from the checkpoint and move it to DEVICE
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)

############# SAM Parameters, can play with these ####################
# `mask_generator` is used by the SAM loop further down to segment each scene.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=95,              # High sampling density
    pred_iou_thresh=0.85,              # ↓ Allow more low-confidence masks
    stability_score_thresh=0.85,      # ↓ Keep less "stable" masks
    crop_n_layers=2,                  # ↑ Enable multi-resolution crops
    crop_n_points_downscale_factor=2, # ↓ Fewer points per crop (more diversity)
    crop_overlap_ratio=0.5,           # ↑ More overlapping regions between crops
    box_nms_thresh=0.3,               # ↓ original note: "less aggressive suppression of overlapping masks" -- NOTE(review): a lower NMS IoU threshold normally suppresses MORE overlapping boxes; confirm the intended direction
    min_mask_region_area=50,          # ← original note: "allow tiny segments (don't filter small regions)" -- NOTE(review): a non-zero value presumably removes regions smaller than this area; confirm
    output_mode='binary_mask' 
)

#function for input image    
def fill_nans(image):
    """Return `image` with every NaN replaced by 0.

    When `image` contains no NaN it is handed back unchanged (the very same
    object, no copy). Otherwise a progress message is printed and a new array
    with zeros in place of the NaNs is returned; the input is not modified.
    """
    missing = np.isnan(image)
    if missing.any():
        print("üîß Filling NaNs before filtering...")
        return np.where(missing, 0, image)
    return image

In [None]:
##### SAM loop#####

# Base directory holding one sub-folder per scene (edit for your data)
base_dir = "/g/data/jk72/gb4219/honours_data/Normprod_Demo/"
# Path to the coastline shapefile used to build the land mask
shp_path = pathlib.Path("/g/data/jk72/gb4219/honours_data/Normprod_Demo/add_coastline_high_res_polygon_v7_10.shp")
shapefile = gpd.read_file(shp_path)

# Directory that receives one output folder per scene (edit for your data)
base_directory = '/g/data/jk72/gb4219/honours_data/Normprod_Demo/maps/'

# Loop over every scene folder found in base_dir
for folder_name in sorted(os.listdir(base_dir)):
    folder_path = os.path.join(base_dir, folder_name)
    if not os.path.isdir(folder_path):
        continue  # skip files, only process directories

    # The last 17 characters of the folder name are used as the date tag
    # of the scene (no regular expression involved, just slicing).
    date_str = folder_name[-17:]

    # Build expected file paths of the three input bands
    image_paths = [
        os.path.join(folder_path, f"NormProdSmoVar_{band}_EPSG3031_{date_str}.tif")
        for band in ("11", "21", "33")
    ]

    # base_dir also contains non-scene directories (the output folder
    # 'maps/' is created inside it by this very loop), so skip any folder
    # that does not hold all three bands instead of crashing on open.
    missing_inputs = [path for path in image_paths if not os.path.isfile(path)]
    if missing_inputs:
        print(f"Skipping folder {folder_name}: missing input file(s) {missing_inputs}")
        continue

    print(f"Processing folder: {folder_name}")

    # Load the 3 bands (first raster band of each file)
    images = []
    for path in image_paths:
        with rasterio.open(path) as src:
            images.append(src.read(1))

    # Generate landmask (using band 1 as reference grid)
    with rasterio.open(image_paths[0]) as src:
        bounds = src.bounds
        xmin, ymin, xmax, ymax = bounds.left, bounds.bottom, bounds.right, bounds.top
        width = src.width
        height = src.height
        transform = rasterio.transform.from_bounds(xmin, ymin, xmax, ymax, width, height)

    # Burn every coastline polygon as 1 onto the reference grid, 0 elsewhere
    img_landmask = rasterize(
        ((geom, 1) for geom in shapefile.geometry),
        out_shape=(height, width),
        transform=transform,
        fill=0,
        dtype="uint8"
    )
    print("landmask generated")

    # Erode landmask: a 100-pixel minimum filter along each axis shrinks the
    # land area, so a band of pixels next to the coastline is NOT masked.
    tmp = minimum_filter1d(img_landmask, size=100, axis=0, mode='constant', cval=0)
    eroded_landmask = minimum_filter1d(tmp, size=100, axis=1, mode='constant', cval=0)
    is_land = eroded_landmask == 1
    print("landmask eroded")

    # Normalise each band to uint8 [0,255]: -npmin maps to 0 and npmax to 255.
    # Values are clipped (and NaN set to 0) BEFORE the integer cast; casting
    # out-of-range or NaN floats to uint8 wraps around / is undefined.
    # Land pixels are then set to 0.
    npmin, npmax, newmax = 0.5, 1.0, 255
    channels = []
    for band in images:
        scaled = (band + npmin) / (npmax + npmin) * newmax
        scaled = np.clip(np.nan_to_num(scaled, nan=0.0), 0, newmax)
        channel = scaled.astype(np.uint8).astype(float)
        channel[is_land] = 0
        channels.append(channel)
    print('normalised three images')
    print('masked three images')

    # Stack 3 channels into an rgb image, shape (H, W, 3)
    img_rgb = np.dstack(channels)
    print('created rgb image')

    # Release the big full-resolution arrays before the next allocations
    del images, channels, scaled, channel, img_landmask, tmp, eroded_landmask, is_land
    print('cleaned up big files')

    # Downsample by a factor of 10 on the two spatial axes (linear interpolation)
    zoom_factors = (1/10, 1/10, 1)
    resampled = zoom(img_rgb, zoom=zoom_factors, order=1)
    img_uint8 = resampled.astype(np.uint8)
    del img_rgb, resampled
    print('downsampled')

    # Quick look at the downsampled RGB composite (own figure)
    plt.figure()
    plt.imshow(img_uint8/255)
    #plt.savefig("RGB.pdf", bbox_inches="tight", dpi=300)
    #plt.savefig("RGB.png", bbox_inches="tight", dpi=300)
    #plt.savefig("RGB.eps", bbox_inches="tight")

    # Segmentation
    sam_result = mask_generator.generate(img_uint8)
    print('segmented')
    if not sam_result:
        # Nothing to label or save for this scene
        print(f"No masks returned for {folder_name}, skipping")
        plt.close("all")
        continue

    # Visualise segmentation
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    detections = sv.Detections.from_sam(sam_result=sam_result)
    annotated_image = mask_annotator.annotate(scene=img_uint8.copy(), detections=detections)

    # Plot rgb image vs segmented image
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    axes[0].imshow(img_uint8)
    axes[0].set_title("a) source image")
    axes[0].axis("off")

    axes[1].imshow(annotated_image)
    axes[1].set_title("b) segmented image")
    axes[1].axis("off")

    plt.tight_layout()

    # Save in multiple formats
    #fig.savefig("SAMmask.png", dpi=300, bbox_inches="tight")
    #fig.savefig("SAMmask.pdf", bbox_inches="tight")
    #fig.savefig("SAMmask.eps", bbox_inches="tight")

    # Make an array of non-overlapping masks based on 'smallest area wins':
    # masks are visited from smallest to largest and each one only claims
    # pixels that are still unlabelled (0). Labels start at 1.
    masks = [mask['segmentation'] for mask in sam_result]
    mask_areas = [int(np.sum(mask)) for mask in masks]
    # Stable sort so equal-area masks keep a deterministic order
    sorted_indices = np.argsort(mask_areas, kind="stable")
    label_map = np.zeros(img_uint8.shape[:2], dtype=np.uint16)
    for label, idx in enumerate(sorted_indices, start=1):
        label_map[np.logical_and(masks[idx], label_map == 0)] = label

    # Label map in its own figure, so the panels above are left untouched
    plt.figure(figsize=(10, 10))
    plt.imshow(label_map, cmap='gist_ncar')
    plt.title("Segmented Mask ‚Äî Smallest Area Wins")
    plt.axis('off')
    plt.colorbar(label="Segment Label")
    #plt.savefig("myMasks.pdf", bbox_inches="tight", dpi=300)
    #plt.savefig("myMasks.png", bbox_inches="tight", dpi=300)
    #plt.savefig("myMasks.eps", bbox_inches="tight")
    plt.show()
    # Do not let figures pile up over the scenes
    plt.close("all")

    # Write out arrays to a folder specific to this scene.
    # NOTE: the 'prydz_' prefix is hard-coded (edit for your data).
    out_folder_name = f'prydz_{date_str}'
    new_folder_path = os.path.join(base_directory, out_folder_name)
    # Create the new folder if it doesn't exist
    os.makedirs(new_folder_path, exist_ok=True)
    # Write out segments
    output_filepath = os.path.join(new_folder_path, f'label_map_{date_str}.npy')
    np.save(output_filepath, label_map)
    # Write out rgb image
    output_filepath = os.path.join(new_folder_path, f'rgb_image_{date_str}.npy')
    np.save(output_filepath, img_uint8)
    print('files written out')
    

In [None]:
#SVM 

# Directory of labelled scenes to train the SVM (edit for your data)
TRAIN_DATA_DIR = pathlib.Path('/g/data/jk72/gb4219/honours_data/SVM_trainingdata/')

# List of all images that should be used for training.
# These folders are located in TRAIN_DATA_DIR defined above.
training_image_list = [
    'prydz_20210129_20210210',
    'prydz_20210330_20210411',
    'prydz_20210728_20210809',
    'prydz_20210821_20210902',
    'prydz_20211020_20211101',
    'thwaites_20240708_20240720',
    'thwaites_20240813_20240825',
    'thwaites_20241012_20241024',
    'thwaites_20241024_20241105',
    'thwaites_20241105_20241117',
]

# Initialize an empty training feature and label vector;
# the training data extracted from each segment is appended to these.
X_train = []
y_train = []

# The three class labels listed for the training data.
# NOTE(review): value 1 is absent here; the map export further down writes 1
# for land -- confirm that the labelled arrays never contain 1.
class_label_list = [0,2,3]

# Plot colours, indexed by class label (four entries, index 1 included)
class_color_list = [
    [1,0,0],
    [0,1,0],
    [0,0,1],
    [0,0,0]
]

for img in training_image_list:

    print(f'Processing image {img}')

    # Folder names look like '<site>_<date1>_<date2>'; keep '<date1>_<date2>'.
    # Built from the split parts rather than nesting the same quote character
    # inside an f-string, which is a SyntaxError before Python 3.12.
    name_parts = img.split('_')
    img_date = f'{name_parts[1]}_{name_parts[2]}'

    print(f'img_date: {img_date}')

    # Build the full paths to the files we need
    RGB_path     = TRAIN_DATA_DIR / img / f'rgb_image_{img_date}.npy'
    segment_path = TRAIN_DATA_DIR / img / f'label_map_{img_date}.npy'
    labels_path  = TRAIN_DATA_DIR / img / f'labelled_array_{img_date}.npy'

    # Read in data from current img
    rgb            = np.load(RGB_path)
    segments       = np.load(segment_path)
    # Indexed by segment id below, so it must hold one entry per id
    segment_labels = np.load(labels_path)

    # Loop over all segment ids present in the current img
    for current_segment in np.unique(segments):

        # Get class label of current segment
        current_segment_label = segment_labels[current_segment]

        # Only use the segment if its label is other than nan
        if np.isnan(current_segment_label):
            continue

        # Pixel values of the current segment, averaged over its pixels
        # into one single training vector
        current_segment_data = rgb[segments==current_segment]
        current_X_train = current_segment_data.mean(0)

        # Append the current segment to the training feature and label vector
        X_train.append(list(current_X_train))
        y_train.append(int(current_segment_label))


X_train = np.array(X_train)
y_train = np.array(y_train)

#visualise training dataset
#fig, ax = plt.subplots(1,1)
#for class_label in class_label_list:
#    ax.plot(
#        X_train[y_train==class_label,1],
#        X_train[y_train==class_label,2],
#        '.',
#        color=class_color_list[class_label]
#)
#ax.legend(['Pack/ocean ', 'Class 2=fast', 'Class 3 = melting fast ice'])
#ax.set_xlabel("Average pixel value (dimension 1)")   
#ax.set_ylabel("Average pixel value (dimension 2)")
#plt.show()

# Base directory containing test folders (edit for your data)
TEST_DATA_DIR = Path("/g/data/jk72/gb4219/honours_data/Normprod_Demo/maps/")

# Dates of the training scenes. Only needed by the optional skip inside the
# loop below, which is currently disabled.
training_dates = set()
for img in training_image_list:
    date_str = f"{img.split('_')[1]}_{img.split('_')[2]}"
    training_dates.add(date_str)

# Pattern that pulls '<date1>_<date2>' out of a label-map file name;
# compiled once instead of on every folder.
LABEL_MAP_PATTERN = re.compile(r'label_map_(\d{8}_\d{8})\.npy$')

# One feature vector per segment, plus (folder name, segment id) at the same
# position so predictions can be mapped back to their scene afterwards.
X_test = []
test_segments_info = []

# Loop over all folders in TEST_DATA_DIR
for folder in sorted(TEST_DATA_DIR.iterdir()):
    if not folder.is_dir():
        continue

    # Find the label_map file in the folder (sorted, so the pick is
    # deterministic when several exist)
    label_files = sorted(folder.glob("label_map_*.npy"))
    if not label_files:
        print(f"‚ö†Ô∏è No label_map found in {folder}, skipping")
        continue

    label_file = label_files[0]  # take the first one if multiple exist

    # Extract date string from filename
    match = LABEL_MAP_PATTERN.search(label_file.name)
    if not match:
        print(f"‚ö†Ô∏è Could not extract date from {label_file.name}, skipping")
        continue
    date_str = match.group(1)

    # Optional: skip scenes that were used for training
    #if date_str in training_dates:
        #print(f"Skipping {folder.name}, date {date_str} in training set")
        #continue

    # Build RGB path in the same folder
    rgb_file = folder / f"rgb_image_{date_str}.npy"
    if not rgb_file.exists():
        print(f"‚ö†Ô∏è No RGB file found at {rgb_file}, skipping")
        continue

    # Load data
    segments = np.load(label_file)
    rgb = np.load(rgb_file)

    # Feature of a segment = mean of its pixels
    for seg_id in np.unique(segments):
        seg_data = rgb[segments == seg_id]
        feature_vec = seg_data.mean(0)
        X_test.append(feature_vec.tolist())
        test_segments_info.append((folder.name, int(seg_id)))

# Convert to numpy array
X_test = np.array(X_test)

# Train the classifier on the labelled segments
clf = svm.SVC()
clf.fit(X_train, y_train)

# Apply model. predict() raises on an empty feature array, so fall back to
# an empty prediction vector when no test segment was found.
if len(X_test):
    y_test_pred = clf.predict(X_test)
else:
    print(f"No test segments found in {TEST_DATA_DIR}, nothing to predict")
    y_test_pred = np.array([], dtype=y_train.dtype)

##### Save prediction map output ####

# Group predictions by image: {folder name: {segment id: predicted class}}
predictions_by_image = {}
for (img_name, seg_id), pred_class in zip(test_segments_info, y_test_pred):
    predictions_by_image.setdefault(img_name, {})[seg_id] = pred_class

# Folder holding the original scene folders with the reference GeoTIFFs
# (edit for your data)
REF_BASE = Path("/g/data/jk72/gb4219/honours_data/Normprod_Demo/")

# Coastline polygons for the land mask: read once, not once per scene
gdf = gpd.read_file("/g/data/jk72/gb4219/honours_data/Normprod_Demo/add_coastline_high_res_polygon_v7_10.shp")
if gdf.crs is None:
    gdf = gdf.set_crs("EPSG:3031")  # replace with actual CRS if needed

# Quicklook colours, indexed by class value
class_colors = np.array([
    [255,   0,   0],   # class 0 → red
    [  0, 255,   0],   # class 1 → green (not drawn: value 1 is remapped to index 4 below)
    [  0,   0, 255],   # class 2 → blue
    [  0,   0,   0],   # class 3 → black
    [200, 200, 200],   # land (masked) → grey
], dtype=np.uint8)

# Now reconstruct a map for each image
for img_name, seg_predictions in predictions_by_image.items():

    # Folder names look like '<site>_<date1>_<date2>'; keep '<date1>_<date2>'
    img_date = f'{img_name.split("_")[1]}_{img_name.split("_")[2]}'

    # Path to this image's folder in TEST_DATA_DIR
    out_dir = TEST_DATA_DIR / img_name
    out_dir.mkdir(parents=True, exist_ok=True)  # just in case

    # Locate the reference raster that defines grid, transform and CRS
    ref_folder = REF_BASE / f"ISCE3_NormProd_EW_{img_date}"
    ref_path = ref_folder / f"NormProdSmoVar_11_EPSG3031_{img_date}.tif"

    if not ref_path.exists():
        # Fallback: try to find a close match inside the folder
        candidates = sorted(ref_folder.glob(f"NormProdSmoVar_11*{img_date}*.tif"))
        if not candidates:
            print(f"‚ö†Ô∏è No reference tif found for {img_date} in {ref_folder}; skipping GeoTIFF.")
            # Without a reference grid the map can be neither resized nor
            # georeferenced, so move on to the next scene.
            continue
        ref_path = candidates[0]

    # --- Open reference raster (defines grid, transform, CRS) ---
    with rasterio.open(ref_path) as src:
        ref_shape = (src.height, src.width)
        ref_transform = src.transform
        ref_crs = src.crs
        profile = src.profile.copy()

    # Load segments
    segment_path = out_dir / f'label_map_{img_date}.npy'
    segments = np.load(segment_path)

    # Create empty prediction map and fill each segment with its predicted class
    pred_map = np.zeros_like(segments, dtype=int)
    for seg_id, pred_class in seg_predictions.items():
        pred_map[segments == seg_id] = pred_class

    # --- Rasterize landmask from shapefile ---
    # Reproject only when needed; the reprojected frame is kept, so later
    # scenes with the same CRS do not repeat the work.
    if gdf.crs != ref_crs:
        gdf = gdf.to_crs(ref_crs)

    landmask_resampled = rasterize(
        [(geom, 1) for geom in gdf.geometry],
        out_shape=ref_shape,
        transform=ref_transform,
        fill=0,
        dtype=np.uint8
    )

    # --- Upsample pred_map to reference size safely ---
    pred_map_resized = resize(
        pred_map,
        ref_shape,             # target height, width
        order=0,               # nearest neighbor preserves classes
        preserve_range=True,    # keep original class IDs
        anti_aliasing=False
    ).astype(np.uint8)

    # --- Apply landmask ---
    pred_map_masked = pred_map_resized.copy()
    pred_map_masked[landmask_resampled == 1] = 1  # or 255/nodata if preferred

    # --- Save a quicklook PNG (land value 1 is drawn with the grey entry) ---
    rgb_pred_map = class_colors[np.where(pred_map_masked == 1, 4, pred_map_masked)]
    plt.imsave(out_dir / f'predicted_map_{img_date}.png', rgb_pred_map)

    # --- Save georeferenced GeoTIFF ---
    profile.update(
        dtype=rasterio.uint8,
        count=1,
        height=ref_shape[0],
        width=ref_shape[1],
        transform=ref_transform,
        crs=ref_crs,
        compress='lzw',
        nodata=255
    )

    geotiff_path = out_dir / f'predicted_map_{img_date}.tif'
    with rasterio.open(geotiff_path, 'w', **profile) as dst:
        dst.write(pred_map_masked, 1)

    print(f"‚úÖ Saved georeferenced prediction for {img_name} to {geotiff_path}")
