# Co-register ArcticDEM strip files using ICESat-2 elevation data

In [None]:
import os
import numpy as np
import rasterio
from rasterio.mask import mask
from rasterstats import zonal_stats
from rasterio.plot import plotting_extent
import geopandas as gpd
from shapely.geometry import box
from shapely.geometry import Point
from shapely.geometry import shape
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pyproj import Transformer
from sklearn.linear_model import LinearRegression
from affine import Affine
import earthpy.plot as ep
import rasterstats as rs
from datetime import datetime

In [None]:
# Set directories
# All paths are relative to the notebook's working directory.
arcticdem_folder = "ArcticDEM Processing/1_download"  # input ArcticDEM .tif files
artefact_folder = "ArcticDEM Processing/masks/artefacts"  # per-DEM artefact shapefiles ("artefact_<dem name>.shp")
dynamic_mask_folder = "ArcticDEM Processing/masks/dynamic"  # folder holding the dynamic-area mask
clipped_dem_folder = "ArcticDEM Processing/2_crop"  # output of the mask/clip step
corrected_dem_folder = "ArcticDEM Processing/3_co-register"  # output of the co-registration step

# Polygon(s) excluded from the ICESat-2 point set
dynamic_mask_file = os.path.join(dynamic_mask_folder, "dynamic_mask.shp")
# Polygon(s) the DEMs are clipped to and within which ICESat-2 points are kept
alignment_aoi_file = os.path.join("ArcticDEM Processing/masks/plane/", "align_dem.shp")
# CSV of ICESat-2 ATL06 points; read later with pandas ("atlo6" is the variable's existing spelling)
icesat_atlo6 = "ArcticDEM Processing/csv/merged_ATL06_march_may.csv"

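Load each ArcticDEM strip, mask it with its artefact mask (where one exists) and clip it to the alignment plane.
The dynamic mask (covering the dynamic movement of the collapse basin and ice dynamics, e.g. crevassing) is applied later, to the ICESat-2 points.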

In [None]:
# Load the align_plane shapefile (the area the DEMs are clipped to)
align_plane = gpd.read_file(alignment_aoi_file)

# Reproject to the expected DEM CRS and dissolve into a single geometry
align_plane = align_plane.to_crs("EPSG:3413")  # Change to match your DEM CRS
align_plane_geom = align_plane.unary_union

# Make sure the output folder exists before any raster is written
os.makedirs(clipped_dem_folder, exist_ok=True)

# Iterate over all ArcticDEM .tif files
for dem_file in os.listdir(arcticdem_folder):
    if not dem_file.endswith(".tif"):
        continue

    # Get paths
    base_name = os.path.splitext(dem_file)[0]
    artefact_file = os.path.join(artefact_folder, f"artefact_{base_name}.shp")
    dem_path = os.path.join(arcticdem_folder, dem_file)
    output_path = os.path.join(clipped_dem_folder, dem_file)

    print(f"Processing {dem_file}...")
    with rasterio.open(dem_path) as src:
        # Check if raster contains valid data: all zeros, all NaN, or all
        # equal to the declared nodata value counts as empty.
        data = src.read(1)  # Read the first band
        all_nodata = src.nodata is not None and bool(np.all(data == src.nodata))
        if np.all(data == 0) or np.isnan(data).all() or all_nodata:
            print(f"Skipping {dem_file}: No valid data.")
            continue

        # Reproject the plane into THIS raster's CRS before testing overlap,
        # so the bounds and the plane are always compared in the same CRS.
        # A per-raster name is used so the EPSG:3413 geometry above is not
        # overwritten by the previous raster's CRS.
        plane_geom_in_src_crs = align_plane.to_crs(src.crs).unary_union

        # Check if raster overlaps the align_plane
        raster_bounds = box(*src.bounds)  # Convert raster bounds to a shapely box
        if not raster_bounds.intersects(plane_geom_in_src_crs):
            print(f"Skipping {dem_file}: No overlap with align_plane.")
            continue

        # Start with the original DEM
        masked_image, masked_transform = src.read(), src.transform
        meta = src.meta.copy()

        # Apply the artefact mask when one exists for this DEM
        if os.path.exists(artefact_file):
            print(f"Applying mask using {artefact_file}...")
            artefact_mask = gpd.read_file(artefact_file).to_crs(src.crs)
            mask_geom = [
                feature["geometry"]
                for feature in artefact_mask.__geo_interface__["features"]
            ]

            # invert=True masks the pixels INSIDE the artefact polygons
            masked_image, masked_transform = mask(src, mask_geom, invert=True, crop=False)
            meta.update({
                "height": masked_image.shape[1],
                "width": masked_image.shape[2],
                "transform": masked_transform
            })

        # Clip to align_plane; the masked array is staged in an in-memory
        # dataset because rasterio's mask() needs a dataset, not an array.
        print("Clipping to align_plane bounding box...")
        with rasterio.io.MemoryFile() as memfile:
            with memfile.open(**meta) as temp_raster:
                temp_raster.write(masked_image)
                clipped_image, clipped_transform = mask(
                    temp_raster, [plane_geom_in_src_crs], invert=False, crop=True
                )

        # Update meta after clipping
        meta.update({
            "height": clipped_image.shape[1],
            "width": clipped_image.shape[2],
            "transform": clipped_transform
        })

    # Save the result
    print(f"Saving raster to {output_path}")
    with rasterio.open(output_path, "w", **meta) as dest:
        dest.write(clipped_image)

print("Processing complete!")

In [None]:
# Load the ICESat-2 ATL06 elevation points.
# The table is bound to its own name so `icesat_atlo6` keeps holding the CSV
# path and this cell can be re-run without passing a DataFrame to read_csv.
icesat_atl06_table = pd.read_csv(icesat_atlo6)

# Create GeoDataFrame from DataFrame (one point per row, x=longitude, y=latitude)
geometry = [
    Point(lon, lat)
    for lon, lat in zip(icesat_atl06_table['longitude'], icesat_atl06_table['latitude'])
]
icesat_march_may = gpd.GeoDataFrame(icesat_atl06_table, geometry=geometry, crs="EPSG:4326")  # Assuming WGS84

# Import dynamic regions and study area mask
mask_aoi = gpd.read_file(alignment_aoi_file)
mask_dynamic = gpd.read_file(dynamic_mask_file)

In [None]:
# Reproject both masks into the CRS of the ICESat-2 points so that the
# clipping done later compares geometries in one coordinate system.
mask_aoi, mask_dynamic = (
    layer.to_crs(icesat_march_may.crs) for layer in (mask_aoi, mask_dynamic)
)

# Print the three CRSs for a visual check that they agree
print(f"Plane mask projection: {mask_aoi.crs}")
print(f"Dynamic_mask projection: {mask_dynamic.crs}")
print(f"ICESat-2 elevation projection: {icesat_march_may.crs}")

The ATL06_quality_summary parameter indicates the best-quality subset of all ATL06 data. A zero in
this parameter implies that no data-quality tests have found a problem with the segment, a one
implies that some potential problem has been found. Users who select only segments with zero values
for this flag can be relatively certain of obtaining high-quality data, but will likely miss a significant
fraction of usable data, particularly in cloudy, rough, or low-surface-reflectance conditions.

In [None]:
# Define output path
output_path = 'ArcticDEM Processing/shp/merged_ATL06_march_may_filtered.shp'

# Only build the filtered point set when it has not been saved before
if not os.path.exists(output_path):
    # Keep points inside the AOI, then drop those inside the dynamic mask
    icesat_mar_may_aoi_clip = gpd.clip(icesat_march_may, mask_aoi)
    inside_dynamic = icesat_mar_may_aoi_clip.geometry.within(mask_dynamic.unary_union)
    icesat_mar_may_clip = icesat_mar_may_aoi_clip[~inside_dynamic]

    # Filter icesat data by removing lower quality points atl06_quality_summary==1
    icesat_mar_may_filtered = icesat_mar_may_clip[icesat_mar_may_clip['atl06_quality_summary'] == 0]

    # Save the filtered points as a new shapefile
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    icesat_mar_may_filtered.to_file(output_path, driver='ESRI Shapefile')

    print("Clipped and filtered shapefile created: merged_ATL06_march_may_filtered.shp")
else:
    # Reload the saved result so `icesat_mar_may_filtered` is defined for the
    # following cells either way.
    # NOTE(review): the shapefile format truncates field names to 10
    # characters, so long column names differ from the freshly built frame.
    icesat_mar_may_filtered = gpd.read_file(output_path)
    print(f"File already exists and was not overwritten: {os.path.basename(output_path)}")

In [None]:
# (Optional) Plot the filtered points with both mask outlines on top
fig, ax = plt.subplots(figsize=(12, 8))
icesat_mar_may_filtered.plot(ax=ax, color="purple")

# Dynamic mask outline in black, AOI outline in red (drawn in that order)
for outline_layer, outline_colour in ((mask_dynamic, "black"), (mask_aoi, "red")):
    outline_layer.boundary.plot(ax=ax, color=outline_colour)

ax.set_title(
    "IceSat-2 ATL06 tracks (march-may, atl06_quality_summary==0)",
    fontsize=10,
)
ax.set_axis_off()
plt.show()

In [None]:
# Define output path
output_path = 'ArcticDEM Processing/shp/icesat2_mar_may_filter_20m_buffer.shp'

# Only build the buffers when they have not been saved before
if not os.path.exists(output_path):
    # Reproject to Polar Stereographic (EPSG:3413), whose units are metres,
    # so the buffer distance below is 20 m
    icesat2_20mbuffer = icesat_mar_may_filtered.to_crs("EPSG:3413")

    # Create a 20m buffer around each point
    icesat2_20mbuffer['geometry'] = icesat2_20mbuffer.geometry.buffer(20)

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    icesat2_20mbuffer.to_file(output_path, driver="ESRI Shapefile")
    print(f"Buffered shapefile created: {os.path.basename(output_path)}")
else:
    # Reload the saved buffers so `icesat2_20mbuffer` is defined for the
    # zonal-statistics cell even when nothing is regenerated here.
    icesat2_20mbuffer = gpd.read_file(output_path)
    print(f"File already exists and was not overwritten: {os.path.basename(output_path)}")


For each clipped ArcticDEM tif, extract the DEM elevations within the 20 m buffer around each ICESat-2 point and calculate their mean (along with max, min, standard deviation and count).

In [None]:
# Convert features into GeoDataFrame
def create_geodataframe(features):
    """Build a GeoDataFrame from an iterable of GeoJSON-like features.

    Each feature must provide a 'geometry' entry (converted with
    shapely's ``shape``) and a 'properties' entry (used as the row's
    attributes). No CRS is assigned to the result.
    """
    # Single pass over the input, keeping geometry and attributes paired
    rows = [(shape(item['geometry']), item['properties']) for item in features]
    attribute_rows = [attrs for _, attrs in rows]
    geometry_column = [geom for geom, _ in rows]
    return gpd.GeoDataFrame(attribute_rows, geometry=geometry_column)

In [None]:
zonal_stats_file = "ArcticDEM Processing/csv/zonal_stats_20mbuffer_icesat2_filtered.csv"

# Only compute the statistics when the CSV has not been written before
if not os.path.exists(zonal_stats_file):
    # List to store zonal stats for all DEMs
    zonal_stats_all = []

    # Reprojected buffers and their union are cached and only rebuilt when a
    # DEM uses a different CRS; both are independent of the DEM otherwise.
    cached_crs = None
    icesat2_points = None
    buffered_extent = None

    # Iterate over each masked DEM
    for dem_file in os.listdir(clipped_dem_folder):
        if not dem_file.endswith(".tif"):
            continue

        dem_path = os.path.join(clipped_dem_folder, dem_file)
        print(f"Processing {dem_file}...")

        with rasterio.open(dem_path) as src:
            # Align CRS
            if icesat2_points is None or src.crs != cached_crs:
                icesat2_points = icesat2_20mbuffer.to_crs(src.crs)
                # Union of the REPROJECTED buffers, so the overlap test below
                # compares geometries in the raster's own CRS
                buffered_extent = icesat2_points.geometry.unary_union
                cached_crs = src.crs

            raster_bounds = box(*src.bounds)

            # Retrieve NoData value
            nodata_value = src.nodata
            if nodata_value is None:
                nodata_value = -9999  # Set a default if not defined

        # Check for overlap
        if not raster_bounds.intersects(buffered_extent):
            print(f"No overlap between {dem_file} and buffered points. Skipping...")
            continue

        # Compute zonal stats (rasterstats opens the raster from its path)
        stats = zonal_stats(
            vectors=icesat2_points,  # Buffered points
            raster=dem_path,
            stats=["mean", "max", "min", "std", "count"],
            geojson_out=True,
            copy_properties=True,
            all_touched=True,
            nodata=nodata_value  # Exclude NoData values
        )

        # Tag every feature with the DEM it was sampled from
        fileid = os.path.splitext(dem_file)[0]
        for stat in stats:
            # Ensure `properties` exists and add custom attributes
            if "properties" not in stat:
                stat["properties"] = {}
            stat["properties"]["DEM"] = dem_file
            stat["properties"]["fileid"] = fileid

        # Append to list
        zonal_stats_all.extend(stats)

    # Create the GeoDataFrame
    zonal_stats_gdf = create_geodataframe(zonal_stats_all)

    # Save results to a CSV
    os.makedirs(os.path.dirname(zonal_stats_file), exist_ok=True)
    zonal_stats_gdf.to_csv(zonal_stats_file, index=False)
    print(f"Zonal statistics saved to {zonal_stats_file}")

else:
    print(f"Zonal statistics already exists: {os.path.basename(zonal_stats_file)}")

In [None]:
# Read the zonal statistics back from disk
zonal_stats_df = pd.read_csv(zonal_stats_file)

# Names of the coordinate columns expected in the CSV
latitude_column = 'latitude'
longitude_column = 'longitude'

# Rebuild point geometries from the coordinate columns (x=longitude, y=latitude)
geometry = [
    Point(lon, lat)
    for lon, lat in zip(zonal_stats_df[longitude_column], zonal_stats_df[latitude_column])
]
geo_df = gpd.GeoDataFrame(zonal_stats_df, geometry=geometry, crs="EPSG:4326")  # WGS84

# Reproject to Polar Stereographic and expose the result as zonal_stats_gdf
geo_df_polar = geo_df.to_crs("EPSG:3413")
zonal_stats_gdf = gpd.GeoDataFrame(geo_df_polar)

Example plot

In [None]:
# Open the ArcticDEM raster
example_raster = "SETSM_s2s041_WV02_20120821_103001001C45DA00_103001001B312B00_2m_lsf_seg2_dem.tif"

# The example is read from the clipped-DEM folder: the 'DEM' column used to
# select the matching ICESat-2 samples below holds file names from that folder.
with rasterio.open(os.path.join(clipped_dem_folder, example_raster)) as src:
    # Read the DEM data (assuming it has a single band)
    dem_data = src.read(1)  # Reads the first band

    # Replace null values (-9999) with NaN
    dem_data = np.where(dem_data == -9999, np.nan, dem_data)

    # Get the affine transform to map row/column indices to spatial coordinates (eastings/northings)
    transform = src.transform

    # Mask the DEM to focus on valid data points (non-NaN values)
    dem_masked = np.ma.masked_invalid(dem_data)  # Mask invalid (NaN) values

    # Get the eastings and northings for each pixel
    rows, cols = np.where(~dem_masked.mask)  # Find the valid pixels (non-masked)

    # Convert pixel row/col indices to geographic coordinates (eastings and northings)
    eastings, northings = src.xy(rows, cols)

    # Capture the plotting extent while the dataset is still open
    example_extent = plotting_extent(src)

# ICESat-2 samples that were extracted from this DEM
icesat_dem_extract = zonal_stats_gdf[zonal_stats_gdf['DEM'] == example_raster]

# The 4th underscore-separated field of the file name is parsed as YYYYMMDD.
# Parsing it outside the f-string avoids nesting double quotes inside it,
# which is a syntax error before Python 3.12.
acquisition_date = datetime.strptime(example_raster.split("_")[3], "%Y%m%d").date()

fig, ax = plt.subplots(figsize=(10, 10))

ep.plot_bands(dem_data,
              extent=example_extent, # Set spatial extent
              cmap='Greys',
              title=f"ICESat-2 filtered points \n ArcticDEM {acquisition_date}",
              scale=False,
              ax=ax)

icesat_dem_extract.plot(ax=ax,
                       marker='s',
                       markersize=10,
                       color='purple')
ax.set_axis_off()
plt.show()


In [None]:
# Keep only samples with a DEM mean; .copy() gives an independent frame so
# the column assignments below do not write into a view of zonal_stats_gdf.
zonal_stats_gdf_dropna = zonal_stats_gdf.dropna(subset=['mean']).copy()

# Sample coordinates, taken from the geometry centroids
sample_centroids = zonal_stats_gdf_dropna.geometry.centroid
zonal_stats_gdf_dropna['x'] = sample_centroids.x
zonal_stats_gdf_dropna['y'] = sample_centroids.y

# Use x, y, and DEM elevation as predictors
X = zonal_stats_gdf_dropna[['x', 'y', 'mean']].values
# Target: ArcticDEM zonal mean minus ICESat-2 h_li (i.e. DEM minus ICESat-2)
y = zonal_stats_gdf_dropna['mean'] - zonal_stats_gdf_dropna['h_li'].values

# Fit a linear regression model
model = LinearRegression()
model.fit(X, y)

print("Regression coefficients:", model.coef_)
print("Intercept:", model.intercept_)


Remove ICESat-2 outliers whose elevation (h_li) is more than 3 standard deviations away from the mean.

In [None]:
zonal_stats_gdf_dropna = zonal_stats_gdf.dropna(subset=['mean'])

# Mean and standard deviation of the ICESat-2 elevations (computed once)
h_li_mean = zonal_stats_gdf_dropna['h_li'].mean()
h_li_std = zonal_stats_gdf_dropna['h_li'].std()

# Define the range for valid elevations (within 3 standard deviations from the mean)
lower_bound = h_li_mean - 3 * h_li_std
upper_bound = h_li_mean + 3 * h_li_std

# Filter out the outliers by checking if the elevations are within the valid range
valid_elevations_mask = (zonal_stats_gdf_dropna['h_li'] >= lower_bound) & (zonal_stats_gdf_dropna['h_li'] <= upper_bound)
# .copy() makes the filtered frame independent, so adding the x/y columns
# below does not assign into a slice of zonal_stats_gdf_dropna.
zonal_stats_gdf_filtered = zonal_stats_gdf_dropna[valid_elevations_mask].copy()

# Calculate x/y from geometry
filtered_centroids = zonal_stats_gdf_filtered.geometry.centroid
zonal_stats_gdf_filtered['x'] = filtered_centroids.x
zonal_stats_gdf_filtered['y'] = filtered_centroids.y

Create a modelled elevation correction using the linear regression parameters obtained from the ICESat-2 elevation differences.
The modelled correction is a plane, generated from the regression model's coefficients and evaluated over the entire DEM extent; adding it to the DEM gives the corrected DEM.

In [None]:
def load_dem(file_path):
    """Load DEM file and return metadata, data, and coordinate grid.

    Parameters
    ----------
    file_path : str
        Path of a raster readable by rasterio; only band 1 is read.

    Returns
    -------
    meta : dict
        The dataset's metadata with ``dtype`` set to ``'float32'``.
    dem_data : numpy.ndarray
        Band 1 with -9999 and the dataset's declared nodata value (if any)
        replaced by NaN.
    x_coords, y_coords : numpy.ndarray
        2-D arrays (height x width) with the coordinates of each pixel
        CENTRE in the raster's CRS.
    """
    with rasterio.open(file_path) as src:
        meta = src.meta
        meta.update(dtype='float32')
        transform = src.transform
        declared_nodata = src.nodata

        # The affine transform maps (col, row) to a pixel's upper-left corner;
        # adding 0.5 to both indices gives the pixel centre.
        col_idx, row_idx = np.meshgrid(
            np.arange(src.width) + 0.5,
            np.arange(src.height) + 0.5
        )
        # Full affine form, so rotated/sheared rasters are handled as well
        x_coords = transform.a * col_idx + transform.b * row_idx + transform.c
        y_coords = transform.d * col_idx + transform.e * row_idx + transform.f

        dem_data = src.read(1)

    # Treat -9999 and the declared nodata value as missing
    invalid = dem_data == -9999
    if declared_nodata is not None and not np.isnan(declared_nodata):
        invalid |= dem_data == declared_nodata
    dem_data = np.where(invalid, np.nan, dem_data)

    return meta, dem_data, x_coords, y_coords

def correct_dem(dem_data, x_coords, y_coords, reg_model):
    """Add a fitted planar offset to a DEM.

    The plane is evaluated on coordinates centred on the NaN-ignoring mean
    of ``x_coords`` and ``y_coords``, using ``reg_model.coef_[0]`` as the
    x slope, ``reg_model.coef_[1]`` as the y slope and
    ``reg_model.intercept_`` as the constant term.

    Returns a tuple ``(dem_data + plane, plane)``.
    """
    slope_x = reg_model.coef_[0]
    slope_y = reg_model.coef_[1]

    # Centre the coordinate grids on their own means
    dx = x_coords - np.nanmean(x_coords)
    dy = y_coords - np.nanmean(y_coords)

    plane = (slope_x * dx) + (slope_y * dy) + reg_model.intercept_
    shifted_dem = dem_data + plane
    return shifted_dem, plane



In [None]:
# Make sure the output folder exists before any raster is written
os.makedirs(corrected_dem_folder, exist_ok=True)

# Process all .tif files in the input folder
for file_name in os.listdir(clipped_dem_folder):
    if not file_name.endswith(".tif"):
        continue

    clipped_dem_path = os.path.join(clipped_dem_folder, file_name)
    corrected_dem_path = os.path.join(corrected_dem_folder, file_name)

    # Check if output file exists and is valid
    if os.path.exists(corrected_dem_path) and os.path.getsize(corrected_dem_path) > 0:
        print(f"Skipping {file_name}: Output file already exists and is >0MB.")
        continue

    # Fit each strip on its own ICESat-2 samples only: the 'DEM' column holds
    # the name of the clipped DEM each zonal statistic was extracted from.
    strip_samples = zonal_stats_gdf_filtered[zonal_stats_gdf_filtered["DEM"] == file_name]

    print(f"Processing {file_name}...")

    try:
        # Load DEM file and return metadata, data, and coordinate grid
        meta, dem_data, x_coords, y_coords = load_dem(clipped_dem_path)

        # Linear regression setup: coordinates are centred on the DEM grid
        # mean, matching the centring applied inside correct_dem()
        x = strip_samples["x"].values - np.nanmean(x_coords)
        y = strip_samples["y"].values - np.nanmean(y_coords)
        # Target: ICESat-2 h_li minus DEM mean, i.e. the offset to ADD to the DEM
        Y = (strip_samples["h_li"] - strip_samples["mean"]).values

        # Named `valid_samples` (not `mask`) so the imported
        # rasterio.mask.mask function is not shadowed for other cells
        valid_samples = ~np.isnan(Y)
        if valid_samples.sum() < 3:
            # A plane has three free parameters; fewer samples cannot constrain it
            print(f"Skipping {file_name}: fewer than 3 valid ICESat-2 samples for this DEM.")
            continue

        X = np.stack([x, y], axis=1)[valid_samples]
        Y = Y[valid_samples]

        reg_model = LinearRegression().fit(X, Y)
        print("Regression score:", reg_model.score(X, Y))
        print("Coefficients:", reg_model.coef_)
        print("Intercept:", reg_model.intercept_)

        # Apply the regression model to correct the DEM
        corrected_dem, modelled_dz = correct_dem(dem_data, x_coords, y_coords, reg_model)

        # load_dem() turned missing pixels into NaN. Write them back as the
        # declared nodata value so the saved file's metadata stays truthful;
        # if none is declared, declare NaN as nodata instead.
        declared_nodata = meta.get("nodata")
        if declared_nodata is None or np.isnan(declared_nodata):
            meta.update(nodata=np.nan)
            output_dem = corrected_dem
        else:
            output_dem = np.where(np.isnan(corrected_dem), declared_nodata, corrected_dem)

        # Save the corrected DEM to the specified path
        meta.update(dtype=output_dem.dtype)
        with rasterio.open(corrected_dem_path, "w", **meta) as dst:
            dst.write(output_dem, 1)

    except Exception as e:
        # Best-effort batch: report the failure and carry on with the next DEM
        print(f"Error processing {file_name}: {e}")

print("DEM correction completed for all files.")

In [None]:
# Compare the last DEM handled by the correction loop before and after correction
fig, ax = plt.subplots(1, 4, figsize=(20, 10))

# Build the extent from the array being plotted and the transform returned
# with it by load_dem(), rather than from `src`: that name still refers to the
# (closed) dataset of the example raster opened in an earlier cell.
panel_extent = plotting_extent(dem_data, meta["transform"])

# (array, colour map, title) for each panel, left to right
panels = [
    (dem_data, 'Greys', "original dem "),
    (modelled_dz, 'Greys', "modelled dem "),
    (corrected_dem, 'Greys', "corrected dem "),
    (dem_data-corrected_dem, 'RdBu', "difference dem "),
]

for panel_ax, (panel_data, panel_cmap, panel_title) in zip(ax, panels):
    ep.plot_bands(panel_data,
                  extent=panel_extent, # Set spatial extent
                  cmap=panel_cmap,
                  title=panel_title,
                  scale=False,
                  ax=panel_ax)

# icesat_mar_may_clip_Polar_Stereographic.plot(ax=ax[1],
#                        marker='s',
#                        markersize=10,
#                        color='purple')

plt.tight_layout()
plt.show()
