# Run STARCOP models on full granules of EMIT data

>  V. Růžička, G. Mateo-Garcia, L. Gómez-Chova, A. Vaughan, L. Guanter, and A. Markham, [Semantic segmentation of methane plumes with hyperspectral machine learning models](https://www.nature.com/articles/s41598-023-44918-6). _Scientific Reports 13, 19999_ (2023). DOI: 10.1038/s41598-023-44918-6.

This demo loads the AVIRIS-trained models to show zero-shot generalisation on data from EMIT.

*Update Jan 2025: the library versions were updated to work with the current Colab environment.*

In [None]:
# Try this first:
# !pip install git+https://github.com/spaceml-org/STARCOP.git

# But for Google Colab (as of January 2025) instead use:
!pip install georeader-spaceml -q
!pip install torch==2.0.0 torchvision==0.15.1 torchtext==0.15.1 pytorch-lightning==2.2 -q
!pip install fsspec gcsfs omegaconf kornia==0.6.7  torchmetrics==0.10.0 wandb segmentation_models_pytorch hydra-core ipython rasterio  geopandas ipykernel matplotlib scikit-image scikit-learn wandb -q
!pip install netCDF4 spectral -q

!pip install huggingface_hub[cli,torch] -q
!pip install matplotlib-scalebar -q

In [None]:
!git clone https://github.com/spaceml-org/STARCOP.git

In [None]:
%cd STARCOP

## Step 1: download EMIT image

In order to download and process the EMIT image we will use the emit reader in the [georeader](https://github.com/spaceml-org/georeader/) package. See [this tutorial](https://github.com/spaceml-org/georeader/blob/main/notebooks/emit_explore.ipynb) for an example of how to load and plot the data.

In [None]:
from huggingface_hub import hf_hub_download
from georeader.readers import emit, download_utils
from starcop.models import mag1c_emit
from georeader import plot
import starcop
from starcop.models.model_module import ModelModule
import os
import torch
import omegaconf
import numpy as np
import matplotlib.pyplot as plt
from starcop.models.utils import padding
import georeader

In [None]:
# Granule to download in the single-granule part of this notebook.
file_name = "EMIT_L1B_RAD_001_20250813T111228_2522507_037.nc"
# SECURITY: never paste an Earthdata token into the notebook. A token that has
# been committed or shared must be treated as compromised and revoked in the
# Earthdata profile. Provide it through the environment instead, e.g.
#   import os; os.environ["EARTHDATA_TOKEN"] = "<your token>"
# in a cell that is not saved, or by exporting it before starting Jupyter.
token = os.environ.get("EARTHDATA_TOKEN", "")
earthdata_nasa_account = True

# NASA's data archive requires creating an account for downloading EMIT files directly.
# Create a user account and a token at the NASA Earthdata portal (https://urs.earthdata.nasa.gov/profile)

def download_granule(granule_name=file_name, token=token, earthdata_nasa_account=earthdata_nasa_account):
    """Download one EMIT radiance granule and open it as an ``emit.EMITImage``.

    Args:
        granule_name: File name of the EMIT L1B radiance granule.
        token: NASA Earthdata bearer token sent in the ``Authorization`` header.
        earthdata_nasa_account: Must be truthy; token download is the only
            path implemented here.

    Returns:
        The ``emit.EMITImage`` built from the downloaded product.

    Raises:
        ValueError: If ``earthdata_nasa_account`` is falsy or ``token`` is empty.
    """
    if not earthdata_nasa_account:
        # Previously this fell through to `return rst` with `rst` unbound.
        raise ValueError(
            "download_granule needs a NASA Earthdata account: "
            "no other download method is implemented."
        )
    if not token:
        raise ValueError("An Earthdata token is required to download EMIT granules.")

    link = emit.get_radiance_link(granule_name)
    # The emit reader keeps its credentials in module-level settings.
    emit.AUTH_METHOD = "token"
    emit.TOKEN = token
    headers = {"Authorization": f"Bearer {token}"}

    product = download_utils.download_product(link, headers=headers, verify=True)
    return emit.EMITImage(product)

granule = download_granule()

## Step 3: run mag1c on the EMIT product

Run the mag1c matched-filter retrieval, based on the work of [Foote et al. 2020](https://ieeexplore.ieee.org/document/9034492).

In [None]:
def get_rgb(granule):
    """Return the raw data of the three bands closest to 640, 550 and 460.

    The targets are compared against ``granule.wavelengths`` (presumably in
    nanometres, i.e. red/green/blue -- confirm against the reader) and the
    nearest band index is picked for each one.
    """
    targets = np.array([640, 550, 460])
    # Rows: target wavelengths; columns: distance to every band of the granule.
    distance_to_band = np.abs(targets[:, np.newaxis] - granule.wavelengths)
    nearest_bands = np.argmin(distance_to_band, axis=1).tolist()
    return granule.read_from_bands(nearest_bands).load_raw(transpose=True)
    
def apply_mf(granule):
    """Run ``mag1c_emit.mag1c_emit`` on the granule and return its first output.

    The call yields a pair; only the first element (the matched-filter output)
    is returned, the second (albedo) is discarded.
    """
    retrieval = mag1c_emit.mag1c_emit(granule, column_step=2, georreferenced=False)
    mf_map, _albedo = retrieval
    return mf_map

In [None]:
plt.imshow(apply_mf(granule), vmin=0,vmax=1750)
# Raw string: "\D" is not a valid escape sequence in a regular string literal.
plt.title(r"$\Delta$CH$_4$ [ppm x m]")
plt.colorbar()

## Step 4: Load STARCOP model

In [None]:
from huggingface_hub import hf_hub_download
# experiment_name = "hyperstarcop_mag1c_only"
experiment_name = "hyperstarcop_mag1c_rgb"
subfolder_local = f"models/{experiment_name}"
config_file = hf_hub_download(repo_id="isp-uv-es/starcop",subfolder=subfolder_local, filename="config.yaml",
                              local_dir=".", local_dir_use_symlinks=False)
model_file = hf_hub_download(repo_id="isp-uv-es/starcop",subfolder=subfolder_local,
                             filename="final_checkpoint_model.ckpt",
                              local_dir=".", local_dir_use_symlinks=False)

In [None]:
hsi_model_path = os.path.join(subfolder_local, "final_checkpoint_model.ckpt")
hsi_config_path =  os.path.join(subfolder_local, "config.yaml")

device = torch.device("cpu")
config_general = omegaconf.OmegaConf.load(os.path.join(os.path.dirname(os.path.abspath(starcop.__file__)), 'config.yaml'))

def load_model_with_emit(model_path, config_path):
    """Load a checkpoint into a ``ModelModule`` in eval mode on ``device``.

    The model-specific config at ``config_path`` is merged on top of the
    notebook-level ``config_general``.

    Returns:
        (model, merged_config)
    """
    merged_config = omegaconf.OmegaConf.merge(
        config_general,
        omegaconf.OmegaConf.load(config_path),
    )

    net = ModelModule.load_from_checkpoint(model_path, settings=merged_config)
    net.to(device)
    # Inference only: switch the module to evaluation mode.
    net.eval()

    print("Loaded model with",net.num_channels,"input channels")
    return net, merged_config

hsi_model, hsi_config = load_model_with_emit(hsi_model_path, hsi_config_path)
print("successfully loaded HyperSTARCOP model!")

## Step 6: run inference

In [None]:
def model_predict(granule):
    """Run the loaded HyperSTARCOP model on one EMIT granule.

    The matched-filter output and the RGB bands are rescaled from EMIT ranges
    to the ranges seen in the AVIRIS training data, stacked as
    ``[mag1c, R, G, B]`` and passed through ``padding.padded_predict`` with a
    sigmoid applied to the model output.

    Args:
        granule: The EMIT image accepted by ``get_rgb`` and ``apply_mf``.

    Returns:
        The output of ``padding.padded_predict`` (sigmoid scores).
    """
    # Divide the EMIT data by these factors ...
    MAGIC_DIV_BY = 240.
    RGB_DIV_BY = 20.
    # ... clip values that are too large ...
    MAGIC_CLIP_TO = [0., 2.]
    RGB_CLIP_TO = [0., 2.]
    # ... and multiply to get back to the range seen in the AVIRIS data.
    MAGIC_MULT_BY = 1750.
    RGB_MULT_BY = 60.

    # NORMALISE (EMIT RGB has a max of ~22)
    rgb_raw = get_rgb(granule)
    mfoutput = apply_mf(granule)
    e_mag1c = np.clip(mfoutput / MAGIC_DIV_BY, MAGIC_CLIP_TO[0], MAGIC_CLIP_TO[1]) * MAGIC_MULT_BY
    e_rgb = np.clip(rgb_raw / RGB_DIV_BY, RGB_CLIP_TO[0], RGB_CLIP_TO[1]) * RGB_MULT_BY
    input_data = np.concatenate([e_mag1c[None], e_rgb], axis=0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = padding.padded_predict(input_data, model=lambda x: torch.sigmoid(hsi_model(x)))
    return pred

pred = model_predict(granule)

In [None]:
plt.imshow(pred[0],vmin=0,vmax=1)

In [None]:
crs_utm = georeader.get_utm_epsg(granule.footprint("EPSG:4326"))
emit_image_utm = granule.to_crs(crs_utm)
mfgeo = emit_image_utm.georreference(apply_mf(granule), fill_value_default=-1)
predgeo = emit_image_utm.georreference(pred[0], fill_value_default=0)
rgbgeo = emit_image_utm.georreference(get_rgb(granule), fill_value_default=-1)
transform = predgeo.transform
#utm_x, utm_y = transform * (max_col + 0.5, max_row + 0.5)
source_crs = predgeo.crs

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage import measure
from scipy.spatial.distance import cdist # Efficiently calculate distances
from pyproj import Transformer

def get_maxima(pred, threshold=0.5, min_distance=40, border_margin=0, georef_image=None):
    """Find one peak per thresholded blob and return pixel and lat/lon positions.

    Steps: threshold the prediction, label connected components, take the
    highest pixel of each component, drop peaks closer than ``border_margin``
    to the image edge, then greedily suppress peaks closer than
    ``min_distance`` to a stronger peak. The surviving peaks are burnt into a
    mask that is georeferenced with ``georef_image`` to obtain coordinates.

    Args:
        pred: Prediction array; squeezed to 2D before use.
        threshold: Pixels strictly above this value form the blobs.
        min_distance: Minimum distance (in pixels) between two kept peaks.
        border_margin: Width (in pixels) of the border in which peaks are discarded.
        georef_image: Object whose ``georreference`` method warps a raw-grid
            array. Defaults to the notebook-level ``emit_image_utm``.

    Returns:
        (final_maxima_coords, final_geo_coordinates): a list of ``(row, col)``
        tuples in the raw grid and a list of ``{'lat': ..., 'lon': ...}`` dicts.
        NOTE(review): the geographic list is recovered from the warped marker
        mask, so its length/order may differ from the pixel list if the warp
        drops or duplicates a marker pixel -- confirm against georeader.
    """
    if georef_image is None:
        # Backward-compatible default: the UTM image defined at notebook level.
        georef_image = emit_image_utm

    pred_map = pred.squeeze()
    image_height, image_width = pred_map.shape

    # --- Step 1 & 2: Threshold, Label, and Find Initial Maxima ---
    binary_mask = pred_map > threshold
    labeled_mask, num_labels = measure.label(binary_mask, connectivity=2, return_num=True)

    initial_maxima = []
    for i in range(1, num_labels + 1):
        component_mask = (labeled_mask == i)
        component_values = np.where(component_mask, pred_map, 0)

        # Coordinate and value of the max in this component
        max_coord = np.unravel_index(np.argmax(component_values), pred_map.shape)
        initial_maxima.append({'coord': max_coord, 'value': component_values[max_coord]})

    print(f"Found {len(initial_maxima)} initial blobs.")

    # --- Step 3: Filter by Edge ---
    edge_filtered_maxima = []
    for maxima in initial_maxima:
        row, col = maxima['coord']
        if (border_margin <= row < image_height - border_margin and
                border_margin <= col < image_width - border_margin):
            edge_filtered_maxima.append(maxima)
        else:
            print(f"  - Discarding maxima at ({row}, {col}) due to border proximity.")

    print(f"{len(edge_filtered_maxima)} maxima remaining after edge filtering.")

    # --- Step 4: Filter by Proximity (Greedy Suppression) ---
    # Strongest peaks first; each kept peak removes its weaker neighbours.
    sorted_maxima = sorted(edge_filtered_maxima, key=lambda m: m['value'], reverse=True)

    final_maxima_coords = []
    while sorted_maxima:
        current_max = sorted_maxima.pop(0)
        final_maxima_coords.append(current_max['coord'])

        if not sorted_maxima:
            break

        current_coord = np.array([current_max['coord']])
        remaining_coords = np.array([m['coord'] for m in sorted_maxima])
        distances = cdist(current_coord, remaining_coords)[0]

        # Keep only the maxima that are at least min_distance away.
        sorted_maxima = [m for m, d in zip(sorted_maxima, distances) if d >= min_distance]

    print(f"{len(final_maxima_coords)} maxima remaining after proximity filtering.")

    # --- Marker mask on the same 2D grid as pred_map ---
    # (built from pred_map so that 2D and (1, H, W) inputs behave the same)
    maxima_mask = np.zeros(pred_map.shape, dtype=np.uint8)
    for row, col in final_maxima_coords:
        maxima_mask[row, col] = 1

    # --- Apply the same georeferencing transformation as for the images ---
    maxima_geo = georef_image.georreference(maxima_mask, fill_value_default=0)

    print("Successfully created a georeferenced layer for the maxima points.")

    transformer = Transformer.from_crs(maxima_geo.crs, "EPSG:4326", always_xy=True)

    # (row, col) indices of the marker pixels in the georeferenced grid
    marker_rows, marker_cols = np.where(maxima_geo.values == 1)

    final_geo_coordinates = []
    if len(marker_rows) > 0:
        # Pixel centres -> projected coordinates -> lon/lat
        xs, ys = maxima_geo.transform * (marker_cols + 0.5, marker_rows + 0.5)
        for x, y in zip(xs, ys):
            lon, lat = transformer.transform(x, y)
            final_geo_coordinates.append({'lat': lat, 'lon': lon})
    else:
        print("No maxima to process.")
    return final_maxima_coords, final_geo_coordinates

final_maxima_coords, final_geo_coordinates = get_maxima(pred)
# --- Step 5: Visualize the Final Results ---
plt.figure(figsize=(14, 12))
plt.imshow(pred.squeeze(), cmap='inferno', vmin=0, vmax=1)
plt.colorbar(label='Prediction Confidence')
plt.title(f'Final {len(final_maxima_coords)} Maxima')

# Plot only the final, filtered maxima
if final_maxima_coords:
    final_rows, final_cols = zip(*final_maxima_coords)
    plt.plot(final_cols, final_rows, 'c+', markersize=15, markeredgewidth=3, linestyle='None', label='Final Maxima')
    # Only draw a legend when there is a labelled artist to describe.
    plt.legend()

plt.show()

## Step 7: georeference and plot results

In [None]:
# from pyproj import Transformer
# # Prepare the tools for conversion (do this once for efficiency)
# affine_transform = predgeo.transform
# source_crs = predgeo.crs
# target_crs = "EPSG:4326" # WGS84 for Lat/Lon
# transformer = Transformer.from_crs(source_crs, target_crs, always_xy=True)

# final_geo_coordinates = [] # A list to store the final lat/lon pairs

# if not final_maxima_coords:
#     print("No maxima found after filtering.")
# else:
#     for i, (row, col) in enumerate(final_maxima_coords):
#         # A) Convert pixel to UTM
#         utm_x, utm_y = affine_transform * (col + 0.5, row + 0.5)
        
#         # B) Convert UTM to Lat/Lon
#         lon, lat = transformer.transform(utm_x, utm_y)
        
#         # Store the result
#         final_geo_coordinates.append({'lat': lat, 'lon': lon})

# # --- Step 6: Visualize the Final Results ---
# plt.figure(figsize=(14, 12))
# plt.imshow(pred_map, cmap='inferno', vmin=0, vmax=1)
# plt.colorbar(label='Prediction Confidence')
# plt.title(f'Final {len(final_maxima_coords)} Maxima (min_dist={min_distance}, border={border_margin})')

# if final_maxima_coords:
#     final_cols, final_rows = zip(*[(col, row) for row, col in final_maxima_coords])
#     plt.plot(final_cols, final_rows, 'c+', markersize=15, markeredgewidth=3, linestyle='None', label='Final Maxima')
#     plt.legend()

# plt.show()

In [None]:
# --- Final plotting block: RGB, methane enhancement and prediction side by side ---
fig, ax = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True)

# --- Plot 1: RGB Image ---
# Pixels filled with -1 in any band are outside the granule footprint.
rgbgeomask = np.any(rgbgeo.values == -1, axis=0, keepdims=False)
rgbplot = (rgbgeo/12).clip(0,1)
rgbplot.values[:, rgbgeomask] = -1
plot.show(rgbplot, ax=ax[0], title="RGB", mask=True, add_scalebar=True)

# --- Plot 2: Methane Enhancement Map ---
# Raw string: "\D" is not a valid escape sequence in a regular string literal.
plot.show(mfgeo, ax=ax[1], title=r"$\Delta$CH$_4$ [ppm x m]", mask=True, vmin=0, vmax=1750,
         add_colorbar_next_to=True, add_scalebar=True)

# --- Plot 3: Prediction Map ---
plot.show(predgeo, ax=ax[2], title="Prediction", mask=True, vmin=0, vmax=1, add_scalebar=True,
          add_colorbar_next_to=True)

plt.tight_layout()
plt.show()


# --- Print the Google Maps links of the detected maxima ---
print("\n" + "="*50)
print("GEOGRAPHIC LOCATIONS OF FINAL MAXIMA")
print("="*50)

if not final_geo_coordinates:
    print("No maxima found after filtering.")
else:
    for i, coords in enumerate(final_geo_coordinates):
        print(f"\n--- Maximum #{i + 1} ---")
        print(f"Latitude:  {coords['lat']:.6f}")
        print(f"Longitude: {coords['lon']:.6f}")
        print(f"Google Maps Link: https://www.google.com/maps?q={coords['lat']},{coords['lon']}")

In [None]:
import folium
from folium.plugins import HeatMap
import numpy as np
from pyproj import Transformer
import branca # Folium's parent library
from jinja2 import Template # Used to create the custom element's template

# ==============================================================================
# This script generates a map where clicking anywhere creates a popup with a
# link to Google Maps.
#
# REQUIRED INPUT VARIABLES:
# - predgeo: The georeferenced GeoTensor of the prediction heatmap.
# - final_geo_coordinates: A list of {'lat': ..., 'lon': ...} dicts for each detected source.
#
# ==============================================================================

# --- Step 1 & 2: Prepare Data and Transform Coordinates (Same as before) ---
def visualize_output(predgeo, final_geo_coordinates=None, heatmap_threshold=0.4,
                     output_map_path='methane_detection_map_popup.html'):
    """Build, save and return an interactive folium map of the predictions.

    The map contains a heatmap of all prediction pixels at or above
    ``heatmap_threshold``, an optional marker layer for the detected sources,
    and a click handler that opens a popup with a Google Maps link.

    Args:
        predgeo: Georeferenced prediction raster (``values``, ``transform``,
            ``crs`` and ``footprint`` are used).
        final_geo_coordinates: Optional list of ``{'lat': ..., 'lon': ...}``
            dicts, one per detected source. The map is centred on the first
            one; without any, it is centred on the raster footprint.
        heatmap_threshold: Minimum prediction value shown in the heatmap.
        output_map_path: Where the HTML map is written. Pass a distinct path
            per granule to avoid overwriting earlier maps.

    Returns:
        The ``folium.Map`` object.
    """
    # --- Step 1 & 2: Prepare Data and Transform Coordinates ---
    prediction_pixels = predgeo.values
    rows, cols = np.where(prediction_pixels >= heatmap_threshold)
    scores = prediction_pixels[rows, cols].astype(float)
    transformer = Transformer.from_crs(predgeo.crs, "EPSG:4326", always_xy=True)
    # Pixel centres -> projected coordinates -> lon/lat
    utm_x, utm_y = predgeo.transform * (cols + 0.5, rows + 0.5)
    lons, lats = transformer.transform(utm_x, utm_y)
    heatmap_data = list(zip(lats, lons, scores))

    # --- Step 3: Define Map Center and Create the Base Map ---
    points_layer = None
    if final_geo_coordinates:
        map_center = [final_geo_coordinates[0]['lat'], final_geo_coordinates[0]['lon']]
        points_layer = folium.FeatureGroup(name='Detected Plume Sources', show=True)
        for i, coords in enumerate(final_geo_coordinates):
            lat, lon = coords['lat'], coords['lon']
            popup_html = f"<b>Candidate #{i+1}</b><br>Lat: {lat:.6f}, Lon: {lon:.6f}<br><a href='https://www.google.com/maps?q={lat},{lon}' target='_blank'>Google Maps</a>"
            folium.Marker(
                location=[lat, lon],
                popup=folium.Popup(popup_html, max_width=300),
                # 'cyan' is not one of the colours folium.Icon accepts.
                icon=folium.Icon(color='lightblue', icon='cloud', prefix='fa')
            ).add_to(points_layer)
    else:
        footprint_wgs84 = predgeo.footprint("EPSG:4326")
        min_lon, min_lat, max_lon, max_lat = footprint_wgs84.bounds
        map_center = [(min_lat + max_lat) / 2, (min_lon + max_lon) / 2]
    m_final = folium.Map(location=map_center, zoom_start=14, tiles="Esri.WorldImagery")

    # --- Step 4: Create the heatmap layer ---
    heatmap_layer = HeatMap(data=heatmap_data, name='Dynamic Methane Heatmap')

    # --- Step 5: Add "Click for Popup" Functionality ---

    # Custom map element: a click anywhere opens a Leaflet popup with the
    # clicked coordinates and a Google Maps link.
    class ClickForPopup(branca.element.MacroElement):
        _template = Template(u"""
            {% macro script(this, kwargs) %}
                function create_popup_on_click(e) {
                    var lat = e.latlng.lat;
                    var lon = e.latlng.lng;
                    var url = `https://www.google.com/maps?q=${lat},${lon}`;
                    
                    // Create the HTML content for the popup
                    var html = `
                        <b>Location Info</b><br>
                        Latitude: ${lat.toFixed(6)}<br>
                        Longitude: ${lon.toFixed(6)}
                        <hr style="margin: 5px 0;">
                        <ul>
                            <li><a href="${url}" target="_blank" rel="noopener noreferrer">View in Google Maps</a></li>
                        </ul>
                    `;
                    
                    // Create a Leaflet popup object and open it on the map
                    var popup = L.popup()
                        .setLatLng(e.latlng)
                        .setContent(html)
                        .openOn({{this._parent.get_name()}});
                }
                // Attach the function to the map's click event
                {{this._parent.get_name()}}.on('click', create_popup_on_click);
            {% endmacro %}
            """)

        def __init__(self):
            super().__init__()
            self._name = 'ClickForPopup'

    m_final.add_child(ClickForPopup())

    # --- Step 6: Add Layers to the Map and Save ---
    heatmap_layer.add_to(m_final)
    if points_layer is not None:
        points_layer.add_to(m_final)
    folium.LayerControl().add_to(m_final)
    m_final.save(output_map_path)
    print(f"Interactive map with popups saved to: {output_map_path}")
    return m_final

#m_final = visualize_output(predgeo, final_geo_coordinates)
#m_final

## 8. Detection Over Multiple Granules

In [None]:
# EMIT L1B radiance granules processed in the multi-granule run.
granule_names = [
    "EMIT_L1B_RAD_001_20250730T075636_2521105_018.nc",
    "EMIT_L1B_RAD_001_20250813T111228_2522507_037.nc",
    "EMIT_L1B_RAD_001_20250417T095633_2510706_024.nc",
    "EMIT_L1B_RAD_001_20250413T113146_2510307_040.nc",
    "EMIT_L1B_RAD_001_20241007T051258_2428104_006.nc",
    "EMIT_L1B_RAD_001_20241003T064746_2427705_017.nc",
]
# Per-granule results, keyed by granule file name.
granules = dict()

In [None]:
for granule_name in granule_names:
    granule = download_granule(granule_name)
    pred = model_predict(granule)
    crs_utm = georeader.get_utm_epsg(granule.footprint("EPSG:4326"))
    emit_image_utm = granule.to_crs(crs_utm)
    predgeo = emit_image_utm.georreference(pred[0], fill_value_default=0)
    final_maxima_coords, final_geo_coordinates = get_maxima(pred)
    visualization = visualize_output(predgeo, final_geo_coordinates)
    granules[granule_name] = {"granule":granule, "pred":pred, "predgeo":predgeo, "maxima":final_maxima_coords, "geomaxima":final_geo_coordinates, "visualization":visualization}

In [None]:
granules[granule_names[0]]["visualization"]

In [None]:
# Write the georeferenced prediction of every processed granule to a GeoTIFF.
# NOTE: export_to_geotiff_compressed is defined in a later cell of this
# notebook; run that cell first, otherwise this loop raises NameError.
# The dict keys are the granule file names, so each output is named
# "starcop_output_<granule file name>.tif".
for name, g in granules.items():
    export_to_geotiff_compressed(
        g["predgeo"],
        f"starcop_output_{name}.tif",
        compression='DEFLATE'
    )

In [None]:
!ls

## 9. Aggregating Timeline

In [None]:
!pip install rasterio -q

In [None]:
import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.warp import calculate_default_transform, reproject, Resampling
from shapely.geometry import box

def aggregate_rasters_on_union(predgeo_objects):
    """Align prediction rasters on their combined extent and average them.

    Every raster is reprojected onto one grid covering the union of all
    footprints (CRS and pixel size taken from the first raster). The value 0
    is treated as "no data" in the inputs, so only pixels with a positive
    value count as observations.

    Args:
        predgeo_objects (list): Georeferenced raster objects exposing
            ``values``, ``transform``, ``crs`` and ``footprint``.

    Returns:
        tuple: ``(mean_signal_geo, observation_count_geo)``, both of the same
        type as the first input:
            - mean signal, min-max normalized to [0, 1], with -1 where no
              granule contributed;
            - number of granules that contributed to each pixel.
        ``(None, None)`` if the input list is empty.
    """
    if not predgeo_objects:
        return None, None

    # --- Step 1: Establish Common Grid based on UNION ---
    print("Step 1: Establishing a common grid based on UNION...")
    reference = predgeo_objects[0]
    target_crs = CRS.from_user_input(reference.crs)

    bounds = [pg.footprint(target_crs).bounds for pg in predgeo_objects]
    total_left = min(b[0] for b in bounds)
    total_bottom = min(b[1] for b in bounds)
    total_right = max(b[2] for b in bounds)
    total_top = max(b[3] for b in bounds)

    target_res = (abs(reference.transform.a), abs(reference.transform.e))
    common_transform, common_width, common_height = calculate_default_transform(
        target_crs, target_crs,
        width=int((total_right - total_left) / target_res[0]),
        height=int((total_top - total_bottom) / target_res[1]),
        left=total_left, bottom=total_bottom, right=total_right, top=total_top
    )

    # --- Step 2 & 3: Reproject and Aggregate ---
    print("\nStep 2 & 3: Reprojecting granules and aggregating...")
    grid_shape = (common_height, common_width)
    sum_raster = np.zeros(grid_shape, dtype=np.float32)
    count_raster = np.zeros(grid_shape, dtype=np.int16)

    for i, pg in enumerate(predgeo_objects):
        print(f"  - Processing granule {i+1}/{len(predgeo_objects)}...")
        reprojected_granule = np.zeros(grid_shape, dtype=np.float32)
        reproject(
            source=pg.values, destination=reprojected_granule,
            src_transform=pg.transform, src_crs=CRS.from_user_input(pg.crs),
            dst_transform=common_transform, dst_crs=target_crs,
            # Declare the source fill value so that averaging does not blend
            # fill pixels into valid pixels at the granule edges.
            src_nodata=0,
            resampling=Resampling.average, dst_nodata=0
        )
        sum_raster += reprojected_granule
        count_raster[reprojected_granule > 0] += 1

    # --- Step 4: Calculate the Mean Signal Map ---
    print("\nStep 4: Calculating the mean signal map...")
    nodata_value = -1  # marks areas with no coverage
    valid_mask = count_raster > 0
    mean_signal_raster = np.divide(
        sum_raster, count_raster,
        out=np.full_like(sum_raster, fill_value=nodata_value),
        where=valid_mask
    )

    # --- Step 5: Normalize the Mean Signal Map to [0, 1] ---
    print("Step 5: Normalizing the mean signal map...")
    normalized_raster = np.full_like(mean_signal_raster, fill_value=nodata_value)
    if np.any(valid_mask):
        min_val = np.min(mean_signal_raster[valid_mask])
        max_val = np.max(mean_signal_raster[valid_mask])
        if max_val > min_val:
            # Min-max normalization over the covered pixels only
            normalized_raster[valid_mask] = (mean_signal_raster[valid_mask] - min_val) / (max_val - min_val)
        else:
            # All covered pixels share the same value
            normalized_raster[valid_mask] = 0.5

    # --- Step 6: Construct and Return the Final Objects ---
    print("Step 6: Constructing the final georeferenced objects...")
    GeoObjectType = type(reference)

    mean_signal_geo = GeoObjectType(
        values=normalized_raster,
        transform=common_transform,
        crs=target_crs.to_string()
    )

    observation_count_geo = GeoObjectType(
        values=count_raster,
        transform=common_transform,
        crs=target_crs.to_string()
    )

    print("Aggregation complete. Returning mean signal and observation count objects.")
    return mean_signal_geo, observation_count_geo

import numpy as np

def calculate_final_confidence(mean_signal_geo, count_geo, return_penalized=False):
    """Weight the mean signal by observation count and re-normalize to [0, 1].

    Each pixel of the mean signal is multiplied by ``count / max(count)`` and
    the result is min-max stretched over the valid pixels. Pixels equal to -1
    in the mean signal are treated as nodata and stay at -1.

    Args:
        mean_signal_geo (GeoTensor): The object for the normalized mean signal.
        count_geo (GeoTensor): The object for the observation count.
        return_penalized (bool): If True, also return the intermediate,
            non-normalized penalized score.

    Returns:
        The final confidence object (same type as ``mean_signal_geo``), or the
        tuple ``(final_confidence_geo, penalized_score_geo)`` when
        ``return_penalized`` is True.
    """
    mean_signal = mean_signal_geo.values
    count_map = count_geo.values.astype(np.float32)
    nodata_value = -1

    # --- Step 1: Count weight in [0, 1] (0 where nothing was observed) ---
    count_weight_map = np.zeros_like(count_map)
    observed = count_map > 0
    if np.any(observed):
        count_weight_map[observed] = count_map[observed] / np.max(count_map[observed])

    # --- Step 2: Intermediate penalized score ---
    signal_valid_mask = mean_signal != nodata_value
    penalized_score_map = np.full_like(mean_signal, fill_value=nodata_value)
    penalized_score_map[signal_valid_mask] = mean_signal[signal_valid_mask] * count_weight_map[signal_valid_mask]

    # --- Step 3: Re-normalize the penalized score ---
    final_normalized_map = np.full_like(penalized_score_map, fill_value=nodata_value)
    final_valid_mask = penalized_score_map != nodata_value

    if np.any(final_valid_mask):
        valid_scores = penalized_score_map[final_valid_mask]
        min_score = np.min(valid_scores)
        max_score = np.max(valid_scores)
        if max_score > min_score:
            # Stretch the penalized scores to the full [0, 1] range
            final_normalized_map[final_valid_mask] = (valid_scores - min_score) / (max_score - min_score)
        else:
            # All valid scores are identical after penalization
            final_normalized_map[final_valid_mask] = 0.5

    # --- Step 4: Wrap the arrays in the same object type as the input ---
    GeoObjectType = type(mean_signal_geo)
    final_confidence_geo = GeoObjectType(
        values=final_normalized_map,
        transform=mean_signal_geo.transform,
        crs=mean_signal_geo.crs
    )
    if not return_penalized:
        return final_confidence_geo

    penalized_score_geo = GeoObjectType(
        values=penalized_score_map,
        transform=mean_signal_geo.transform,
        crs=mean_signal_geo.crs
    )
    return final_confidence_geo, penalized_score_geo

# --- Example Usage ---
predgeo_list = [data["predgeo"] for data in granules.values()]
mean_pred_geo, count_geo = aggregate_rasters_on_union(predgeo_list)

# aggregate_rasters_on_union returns (None, None) for an empty list, so test
# against None explicitly instead of relying on the truthiness of raster objects.
final_confidence_score_geo = None
if mean_pred_geo is not None and count_geo is not None:
    print("\nSuccessfully created the aggregated objects.")
    # Both `mean_pred_geo` and `count_geo` can be visualized on separate maps
    # or side by side to get the full picture.
    final_confidence_score_geo = calculate_final_confidence(mean_pred_geo, count_geo)

if final_confidence_score_geo is not None:
    print("Successfully created the final, re-normalized confidence score map.")
    # The values of this object are between 0 and 1 (and -1 for nodata);
    # it can be passed to the Folium visualization.

In [None]:
import rasterio

def export_to_geotiff_compressed(geopred_object, output_filepath, compression='DEFLATE', nodata=-1):
    """Export a georeferenced raster object to a compressed GeoTIFF file.

    Args:
        geopred_object (GeoTensor): The object to export; ``values`` must be
            2D ``(height, width)`` or 3D ``(bands, height, width)``.
        output_filepath (str): The path to save the new .tif file.
        compression (str): The compression method to use.
                           Recommended: 'DEFLATE', 'LZW'.
                           For visualization only: 'JPEG', 'WEBP'.
        nodata: Nodata value stored in the file metadata. It must be
            representable in the raster dtype; pass ``None`` to omit it
            (e.g. for unsigned integer rasters).

    Raises:
        ValueError: If ``values`` is neither 2D nor 3D.
    """
    raster_data = np.asarray(geopred_object.values)
    if raster_data.ndim == 2:
        # Single band: add the leading band axis expected below.
        raster_data = raster_data[np.newaxis]
    elif raster_data.ndim != 3:
        raise ValueError(
            f"Expected a 2D or 3D raster, got an array with shape {raster_data.shape}"
        )
    band_count, height, width = raster_data.shape

    metadata = {
        'driver': 'GTiff',
        'height': height,
        'width': width,
        'count': band_count,
        'dtype': raster_data.dtype,
        'crs': geopred_object.crs,
        'transform': geopred_object.transform,
        'compress': compression,
    }
    if nodata is not None:
        metadata['nodata'] = nodata

    # Some compression types accept extra creation options, e.g. for JPEG:
    # rasterio.open(..., compress='JPEG', jpeg_quality=85)

    with rasterio.open(output_filepath, 'w', **metadata) as dst:
        # A 3D array is written band by band in order.
        dst.write(raster_data)

    print(f"Raster data successfully exported to {output_filepath} with {compression} compression.")

# --- Example Usage ---
# Export with the recommended lossless compression
export_to_geotiff_compressed(
    final_confidence_score_geo,
    "final_methane_confidence_map_compressed.tif",
    compression='DEFLATE'
)

In [None]:
import numpy as np

def compare_signal_and_confidence(mean_signal_geo, count_geo, final_confidence_score_geo):
    """
    Print mean and maximum pixel values of the signal map and of the final
    confidence map, first over every observed pixel and then per observation count.

    A pixel counts as observed when its value in the count map is greater than 0.
    The function only prints; it returns None.

    Args:
        mean_signal_geo (GeoTensor): The georeferenced object for the normalized mean signal.
        count_geo (GeoTensor): The georeferenced object for the observation count.
        final_confidence_score_geo (GeoTensor): The georeferenced object for the final score.
    """
    signal = mean_signal_geo.values
    counts = count_geo.values
    confidence = final_confidence_score_geo.values

    banner = "=" * 60
    print(banner)
    print("Comparison of Mean Signal vs. Final Confidence Score")
    print(banner)

    # --- Part 1: statistics over all observed pixels ---
    print("\n--- Global Summary (across all pixels with data) ---")

    observed = counts > 0
    if not observed.any():
        print("No valid data found in the maps to compare.")
        return

    signal_observed = signal[observed]
    confidence_observed = confidence[observed]
    global_stats = (
        signal_observed.mean(),
        confidence_observed.mean(),
        signal_observed.max(),
        confidence_observed.max(),
    )
    mean_signal_all, mean_conf_all, peak_signal_all, peak_conf_all = global_stats

    print(f"Average Signal (Before Scoring):   {mean_signal_all:.4f}")
    print(f"Average Confidence (After Scoring):  {mean_conf_all:.4f}")
    print("-" * 40)
    print(f"Maximum Signal (Before Scoring):   {peak_signal_all:.4f}")
    print(f"Maximum Confidence (After Scoring):  {peak_conf_all:.4f}")

    # --- Part 2: the same statistics grouped by how often a pixel was observed ---
    print("\n--- Breakdown by Observation Count ---")

    for n_obs in np.unique(counts[observed]):
        # Select exactly the pixels that were seen `n_obs` times.
        selected = counts == n_obs
        signal_group = signal[selected]
        confidence_group = confidence[selected]

        print(f"\nFor {np.count_nonzero(selected)} pixels seen {int(n_obs)} time(s):")
        print(f"  - Avg Signal Before: {signal_group.mean():.4f}  |  Max Signal Before: {signal_group.max():.4f}")
        print(f"  - Avg Signal After:  {confidence_group.mean():.4f}  |  Max Signal After:  {confidence_group.max():.4f}")

    print("\n" + banner)
    print("Analysis complete.")


# --- Example Usage ---
# `mean_pred_geo`, `count_geo`, and `final_confidence_score_geo` are the outputs
# of the earlier aggregation and confidence-scoring cells.
# The report is printed once; a second identical call only duplicated the output.
compare_signal_and_confidence(mean_pred_geo, count_geo, final_confidence_score_geo)

In [None]:
# Build the interactive map for the final confidence raster. As the last expression of the
# cell, the returned map object is what the notebook displays.
# NOTE(review): heatmap_threshold=0.6 presumably hides pixels scoring below 0.6 — confirm against
# the visualize_output definition in scope for this cell.
visualize_output(final_confidence_score_geo, heatmap_threshold=0.6)

# Run Project-Eucalyptus

## Loading Model

In [1]:
!git clone https://github.com/Orbio-Earth/Project-Eucalyptus.git

fatal: destination path 'Project-Eucalyptus' already exists and is not an empty directory.


In [2]:
%cd Project-Eucalyptus/notebooks

/kaggle/working/Project-Eucalyptus/notebooks


In [3]:
# Choose ONE of the following commands based on how you uploaded the file.

# If you used Option A (Added as a Dataset):
# !pip install -r /kaggle/input/my-project-reqs/requirements.txt --upgrade --no-cache-dir

# If you used Option B (Uploaded directly):
!pip install -r ../requirements.txt --upgrade --no-cache-dir -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m125.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m251.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.8/139.8 kB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m293.4 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.6/8.6 MB[0m [31m70.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0mm
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.3/9.3 MB[0m [31m327.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.7/3.7 MB[0m [31m33.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:

You have to restart the kernel here.

In [4]:
!pip install georeader-spaceml -q
!pip install netCDF4 spectral -q
!pip install fsspec gcsfs omegaconf segmentation_models_pytorch hydra-core ipython rasterio  geopandas -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.8/174.8 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m338.4/338.4 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.8.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
bigframes 2.8.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.31.0, but you have google-cloud-bigquery 3.25.0 which is incompatible.
bigframes 2.8.0 requires rich<14,>=12.4.4, but you have rich 14.0.0 which is incompatible.[0m[31m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m249.0/249.0 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.4/194.4 kB[0m [31m7.0 MB/s[0m eta [36m0

In [5]:
import numpy as np
import torch
import xarray as xr
from utils import EMIT_SCALING_FACTOR, plot_predictions, predict

In [6]:
from huggingface_hub import hf_hub_download
from georeader.readers import emit, download_utils
from georeader import plot
import os
import torch
import omegaconf
import numpy as np
import matplotlib.pyplot as plt
import georeader

In [7]:
file_name = "EMIT_L1B_RAD_001_20250813T111228_2522507_037.nc"
# SECURITY: never hard-code an Earthdata bearer token in the notebook. Any token that was
# committed here must be treated as leaked and revoked in the Earthdata Login profile.
# The token is read from the EARTHDATA_TOKEN environment variable instead; it is an empty
# string when the variable is not set.
token = os.environ.get("EARTHDATA_TOKEN", "")
earthdata_nasa_account = True

def download_granule(granule_name, token=token, earthdata_nasa_account=earthdata_nasa_account):
    """
    Download an EMIT radiance granule with a NASA Earthdata token and open it.

    Args:
        granule_name (str): File name of the EMIT L1B radiance granule.
        token (str): Earthdata bearer token sent in the Authorization header.
        earthdata_nasa_account (bool): Must be True; token-based download is the
            only path implemented here.

    Returns:
        emit.EMITImage: The image object built from the downloaded product.

    Raises:
        ValueError: If `earthdata_nasa_account` is False or `token` is empty.
    """
    # Fail with a clear message instead of reaching `return rst` with `rst` unbound.
    if not earthdata_nasa_account:
        raise ValueError(
            "download_granule only supports downloads with a NASA Earthdata account "
            "(earthdata_nasa_account=True)."
        )
    if not token:
        raise ValueError("An Earthdata token is required to download EMIT granules.")

    link = emit.get_radiance_link(granule_name)

    # Configure the emit reader module for token authentication as well, so later
    # reads made through it use the same credentials.
    emit.AUTH_METHOD = "token"
    emit.TOKEN = token
    headers = {"Authorization": f"Bearer {token}"}

    product = download_utils.download_product(link, headers=headers, verify=True)

    return emit.EMITImage(product)

In [8]:
import folium
from folium.plugins import HeatMap
import numpy as np
from pyproj import Transformer
import branca # Folium's parent library
from jinja2 import Template # Used to create the custom element's template

# ==============================================================================
# visualize_output() generates a map where clicking anywhere creates a popup
# with a link to Google Maps for the clicked location.
#
# INPUTS:
# - predgeo: The georeferenced GeoTensor of the prediction heatmap (required).
# - final_geo_coordinates: Optional list of {'lat': ..., 'lon': ...} dicts, one for each detected source.
# - heatmap_threshold: Minimum pixel score that is included in the heatmap (default 0.4).
# ==============================================================================

# --- Step 1 & 2: Prepare Data and Transform Coordinates (Same as before) ---
def visualize_output(predgeo, final_geo_coordinates=None, heatmap_threshold=0.4,
                     output_map_path='methane_detection_map_popup.html'):
    """
    Build an interactive Folium map of a georeferenced prediction and save it as HTML.

    Pixels whose score is at least `heatmap_threshold` are added to a heatmap layer.
    Optional source coordinates are drawn as markers, and clicking anywhere on the
    map opens a popup with the clicked latitude/longitude and a Google Maps link.

    Args:
        predgeo: Georeferenced prediction exposing `values`, `transform`, `crs`
            and `footprint(crs)`. `values` is expected to be a 2-D array.
        final_geo_coordinates (list[dict] | None): Optional list of
            {'lat': ..., 'lon': ...} dicts, one per detected source. When given,
            the map is centred on the first entry; otherwise on the centre of
            the raster footprint.
        heatmap_threshold (float): Minimum pixel score included in the heatmap.
        output_map_path (str): Where the HTML map is written. The default keeps
            the historical file name, which is overwritten on every call; pass a
            per-granule path to keep one map per granule.

    Returns:
        folium.Map: The assembled map.
    """
    # --- Step 1: select the pixels to plot and read their scores ---
    prediction_pixels = predgeo.values
    rows, cols = np.where(prediction_pixels >= heatmap_threshold)
    scores = prediction_pixels[rows, cols].astype(float)

    # --- Step 2: pixel centres -> raster CRS coordinates -> longitude/latitude ---
    affine_transform = predgeo.transform
    transformer = Transformer.from_crs(predgeo.crs, "EPSG:4326", always_xy=True)
    # The +0.5 offset addresses the centre of each pixel rather than its corner.
    utm_x, utm_y = affine_transform * (cols + 0.5, rows + 0.5)
    lons, lats = transformer.transform(utm_x, utm_y)
    heatmap_data = list(zip(lats, lons, scores))

    # --- Step 3: choose the map centre and, if sources are given, build their markers ---
    points_layer = None
    if final_geo_coordinates:
        map_center = [final_geo_coordinates[0]['lat'], final_geo_coordinates[0]['lon']]
        points_layer = folium.FeatureGroup(name='Detected Plume Sources', show=True)
        for i, coords in enumerate(final_geo_coordinates):
            lat, lon = coords['lat'], coords['lon']
            popup_html = f"<b>Candidate #{i+1}</b><br>Lat: {lat:.6f}, Lon: {lon:.6f}<br><a href='https://www.google.com/maps?q={lat},{lon}' target='_blank'>Google Maps</a>"
            folium.Marker(
                location=[lat, lon],
                popup=folium.Popup(popup_html, max_width=300),
                icon=folium.Icon(color='cyan', icon='cloud', prefix='fa')
            ).add_to(points_layer)
    else:
        footprint_wgs84 = predgeo.footprint("EPSG:4326")
        min_lon, min_lat, max_lon, max_lat = footprint_wgs84.bounds
        map_center = [(min_lat + max_lat) / 2, (min_lon + max_lon) / 2]
    m_final = folium.Map(location=map_center, zoom_start=14, tiles="Esri.WorldImagery")

    # --- Step 4: heatmap layer ---
    heatmap_layer = folium.plugins.HeatMap(data=heatmap_data, name='Dynamic Methane Heatmap')

    # --- Step 5: "click for popup" behaviour ---
    # Custom map element whose script opens a Leaflet popup at the clicked position.
    class ClickForPopup(branca.element.MacroElement):
        _template = Template(u"""
            {% macro script(this, kwargs) %}
                function create_popup_on_click(e) {
                    var lat = e.latlng.lat;
                    var lon = e.latlng.lng;
                    var url = `https://www.google.com/maps?q=${lat},${lon}`;
                    
                    // Create the HTML content for the popup
                    var html = `
                        <b>Location Info</b><br>
                        Latitude: ${lat.toFixed(6)}<br>
                        Longitude: ${lon.toFixed(6)}
                        <hr style="margin: 5px 0;">
                        <ul>
                            <li><a href="${url}" target="_blank" rel="noopener noreferrer">View in Google Maps</a></li>
                        </ul>
                    `;
                    
                    // Create a Leaflet popup object and open it on the map
                    var popup = L.popup()
                        .setLatLng(e.latlng)
                        .setContent(html)
                        .openOn({{this._parent.get_name()}});
                }
                // Attach the function to the map's click event
                {{this._parent.get_name()}}.on('click', create_popup_on_click);
            {% endmacro %}
            """)

        def __init__(self):
            super(ClickForPopup, self).__init__()
            self._name = 'ClickForPopup'

    # The element must be a child of the map: its template refers to the parent's name.
    m_final.add_child(ClickForPopup())

    # --- Step 6: assemble the layers, save the HTML file and return the map ---
    heatmap_layer.add_to(m_final)
    if points_layer is not None:
        points_layer.add_to(m_final)
    folium.LayerControl().add_to(m_final)
    m_final.save(output_map_path)
    print(f"Interactive map with popups saved to: {output_map_path}")
    return m_final

#m_final = visualize_output(predgeo, final_geo_coordinates)
#m_final

In [16]:
# Load the serialized model from the repository's resources folder onto the CPU.
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python objects, which can
# execute code from the file. Only use it with a checkpoint from a trusted source — confirm that
# ./resources/emit/model.pth is one.
model = torch.load("./resources/emit/model.pth", weights_only=False, map_location="cpu")

## Run Inference

In [17]:
# EMIT L1B radiance granules to run through the model, in processing order.
granule_names = [
    "EMIT_L1B_RAD_001_20250730T075636_2521105_018.nc",
    "EMIT_L1B_RAD_001_20250813T111228_2522507_037.nc",
    "EMIT_L1B_RAD_001_20250417T095633_2510706_024.nc",
    "EMIT_L1B_RAD_001_20250413T113146_2510307_040.nc",
    "EMIT_L1B_RAD_001_20241007T051258_2428104_006.nc",
    "EMIT_L1B_RAD_001_20241003T064746_2427705_017.nc",
]

In [None]:
# %% [code] {"execution":{"execution_failed":"2025-09-09T05:43:18.326Z"},"jupyter":{"outputs_hidden":false}}
import torch
import torch.nn.functional as F
import xarray as xr
import georeader
from georeader.readers import emit
import gc # Garbage Collector
import numpy as np

# Import the necessary tools from rasterio, which you already use elsewhere
import rasterio
from rasterio.warp import reproject, Resampling
from rasterio.crs import CRS

# Run the model on every granule, warp the first output channel onto a UTM grid and
# keep the raw prediction, the georeferenced prediction and its map per granule.
euca_preds = {}
for file_name in granule_names:
    print(f"\n--- Processing: {file_name} ---")
    # The granule files are read from the STARCOP checkout two directories up.
    file_path = f"../../STARCOP/{file_name}"

    with xr.open_dataset(file_path) as ds:
        # Reorder to (bands, downtrack, crosstrack), scale, and add a leading batch axis.
        radiance = ds["radiance"].load().transpose("bands", "downtrack", "crosstrack")
        x = torch.from_numpy(radiance.values * EMIT_SCALING_FACTOR)[None, ...]

    with torch.no_grad():
        original_h, original_w = x.shape[-2:]

        # --- 1. PAD, PREDICT, AND CROP ---
        # Pad height and width up to the next multiple of 32, then crop the output back.
        # NOTE(review): 32 is presumably the model's output stride — confirm against the model.
        output_stride = 32
        pad_h = (output_stride - original_h % output_stride) % output_stride
        pad_w = (output_stride - original_w % output_stride) % output_stride
        # F.pad order for the last two axes: (left, right, top, bottom); only right/bottom are padded.
        # NOTE(review): this rebinds the name `padding`, which the STARCOP section imports from
        # starcop.models.utils — harmless only if that import is not needed afterwards.
        padding = (0, pad_w, 0, pad_h)
        x_padded = F.pad(x, padding, mode='constant', value=0)
        yhat_padded = predict(model, x_padded)
        # The output is indexed as a 4-axis tensor here; the last two axes are cropped.
        yhat_final = yhat_padded[:, :, :original_h, :original_w]

        # --- Print model output statistics ---
        print("\n--- Model Output Statistics (yhat_final) ---")
        # Drop the batch axis so that the first axis iterates over output channels.
        pred_channels = yhat_final.squeeze(0)
        for i in range(pred_channels.shape[0]):
            channel_data = pred_channels[i]
            min_val = torch.min(channel_data).item()
            max_val = torch.max(channel_data).item()
            avg_val = torch.mean(channel_data).item()
            print(f"  Channel {i}:")
            print(f"    Min: {min_val:.6f}")
            print(f"    Max: {max_val:.6f}")
            print(f"    Avg: {avg_val:.6f}")
        print("--------------------------------------------")

        # --- 2. PREPARE NUMPY ARRAY ---
        # Only channel 0 is carried forward into the map.
        pred_tensor = yhat_final.squeeze(0)[0]
        pred_numpy = pred_tensor.cpu().numpy()

    # --- 3. REPROJECT THE PREDICTION DATA ---

    # a) Load the original granule to define the SOURCE grid
    # NOTE(review): pred_numpy has the (downtrack, crosstrack) layout of the radiance variable;
    # confirm that original_granule.transform describes that same pixel grid before warping with it.
    original_granule = emit.EMITImage(file_path)
    src_transform = original_granule.transform
    src_crs = CRS.from_user_input(original_granule.crs)

    # b) Create the reprojected granule to define the TARGET grid
    target_crs_utm_str = georeader.get_utm_epsg(original_granule.footprint("EPSG:4326"))
    emit_image_utm = original_granule.to_crs(target_crs_utm_str)
    dst_transform = emit_image_utm.transform
    dst_crs = CRS.from_user_input(emit_image_utm.crs)

    # c) Create an empty array with the TARGET dimensions
    reprojected_pred = np.zeros((emit_image_utm.height, emit_image_utm.width), dtype=pred_numpy.dtype)

    # d) Warp the source data (pred_numpy) into the target array (reprojected_pred)
    reproject(
        source=pred_numpy,
        destination=reprojected_pred,
        src_transform=src_transform,
        src_crs=src_crs,
        dst_transform=dst_transform,
        dst_crs=dst_crs,
        resampling=Resampling.bilinear, # Bilinear is a good choice for continuous data
        dst_nodata=0
    )

    # --- 4. GEOREFERENCE THE REPROJECTED DATA ---
    # NOTE(review): the other inference cell passes the un-warped prediction straight to
    # georreference(); here an already warped array is passed. Confirm which input geometry
    # georreference() expects, since only one of the two can be right.
    predgeo_utm = emit_image_utm.georreference(reprojected_pred, fill_value_default=0)

    print(f"Successfully created georeferenced map with shape: {predgeo_utm.values.shape}")

    # heatmap_threshold=0 keeps every pixel whose value is >= 0 in the heatmap layer.
    visualization = visualize_output(predgeo_utm, heatmap_threshold=0)
    euca_preds[file_name] = {
        "pred": yhat_final,
        "predgeo": predgeo_utm,
        "visualization": visualization
    }

    # Drop the per-granule names before the next iteration. The objects stored in euca_preds
    # stay referenced by the dict, and `x` / `x_padded` are not deleted here.
    del original_granule, emit_image_utm, predgeo_utm, pred_numpy, pred_tensor, reprojected_pred
    gc.collect()

print("\n--- All granules processed successfully! ---")

In [20]:
# Display the stored map of the first granule (the cell's last expression is rendered by the notebook).
euca_preds[granule_names[0]]["visualization"]

In [21]:
# Imported here so the cell does not depend on another cell having imported them.
import gc

import torch.nn.functional as F

# Run the model on every granule, georeference the first output plane on a UTM grid
# and keep the raw prediction, the georeferenced prediction and its map per granule.
euca_preds = {}
for file_name in granule_names:
    print(f"Processing granule: {file_name}")

    # Step 1: Load radiance data from the NetCDF file in the STARCOP checkout
    granule_path = os.path.join("../../STARCOP/", file_name)
    with xr.open_dataset(granule_path) as ds:
        # Reorder to (bands, downtrack, crosstrack), scale, and add a leading batch axis.
        radiance = ds["radiance"].load().transpose("bands", "downtrack", "crosstrack")
        x = torch.from_numpy(radiance.values * EMIT_SCALING_FACTOR)[None, ...]

    # Remember the unpadded size so the output can be cropped back to it
    original_h, original_w = x.shape[-2:]

    # Step 2: Pad height and width up to the next multiple of 32
    output_stride = 32
    pad_h = (output_stride - original_h % output_stride) % output_stride
    pad_w = (output_stride - original_w % output_stride) % output_stride
    padding_tuple = (0, pad_w, 0, pad_h)  # (pad_left, pad_right, pad_top, pad_bottom)
    x_padded = F.pad(x, padding_tuple, mode='constant', value=0)

    # Step 3: Run inference without tracking gradients to save memory
    with torch.no_grad():
        yhat_padded = predict(model, x_padded)

    # Step 4: Crop the padding off the trailing (height, width) axes.
    # The ellipsis keeps the crop on the spatial axes whatever the number of leading
    # axes; the other inference cell indexes the same output as a 4-axis tensor, and
    # for that layout `[:, :original_h, :original_w]` would crop the wrong axes.
    yhat_final_cpu_np = yhat_padded[..., :original_h, :original_w].cpu().numpy()

    # First 2-D plane of the prediction (first batch item, first channel), used for the heatmap.
    heatmap_np = yhat_final_cpu_np.reshape(-1, original_h, original_w)[0]

    # Step 5: Obtain the EMIT granule object used for georeferencing
    granule_obj = download_granule(granule_name=file_name)

    # Step 6: Reproject the EMIT granule to its UTM zone
    crs_utm = georeader.get_utm_epsg(granule_obj.footprint("EPSG:4326"))
    emit_image_utm = granule_obj.to_crs(crs_utm)

    # Step 7: Georeference the 2-D prediction plane
    predgeo = emit_image_utm.georreference(heatmap_np, fill_value_default=0)

    # Step 8: Build the map. No source markers are passed; to add them, compute the
    # coordinates first (e.g. with get_maxima, if it is defined in this notebook) and
    # pass them as the second argument of visualize_output.
    visualization = visualize_output(predgeo)

    # Store results
    euca_preds[file_name] = {
        "pred": yhat_final_cpu_np,  # full cropped prediction as a NumPy array
        "predgeo": predgeo,
        "visualization": visualization
    }

    # Drop the per-granule names before the next iteration; objects stored in
    # euca_preds stay referenced by the dict.
    del x, x_padded, yhat_padded, yhat_final_cpu_np, heatmap_np
    del granule_obj, emit_image_utm, predgeo

    # Force Python's garbage collector to run
    gc.collect()

    # If using CUDA, clear the GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\nAll granules processed for Project-Eucalyptus.")

Processing granule: EMIT_L1B_RAD_001_20250730T075636_2521105_018.nc


KeyboardInterrupt: 

In [None]:
# Panel configuration passed to plot_predictions: one entry per model output, each with
# a title and the keyword arguments for imshow.
subplot_props = dict(
    likelihood=dict(
        title="Likelihood score",
        imshow_kwargs=dict(cmap="Greys", vmin=0, vmax=1),
    ),
    conditional=dict(
        title="Conditional prediction ({units})",
        imshow_kwargs=dict(cmap="Reds", vmin=0),
    ),
    marginal=dict(
        title="Marginal prediction ({units})",
        imshow_kwargs=dict(cmap="Reds", vmin=0),
    ),
)

# Plot the panels configured in `subplot_props`; `units` is presumably substituted into the
# "{units}" placeholders of the panel titles — confirm in utils.plot_predictions.
# NOTE(review): `yhat_gamma_concentration` is not assigned in any of the visible cells — make sure
# it is defined before running this cell.
plot_predictions(
    yhat_gamma_concentration, subplot_props, units=r"$\gamma \cdot mol/m^2$"
)

In [None]:
# Georeference the first element of `yhat_final` on the UTM granule and keep its transform and CRS.
# NOTE(review): both inference loops `del emit_image_utm` at the end of every iteration, and
# `yhat_final` is only assigned in the first loop (as a 4-axis crop, so `[0]` may still carry a
# channel axis). Confirm both names exist with the expected shape before running this cell.
pred2geo = emit_image_utm.georreference(yhat_final[0], fill_value_default=0)
transform = pred2geo.transform
#utm_x, utm_y = transform * (max_col + 0.5, max_row + 0.5)
source_crs = pred2geo.crs

In [None]:
# Build the map for `pred2geo` with the default arguments; as the cell's last expression,
# the returned map is what the notebook displays.
visualize_output(pred2geo)

# Ensembling

In [None]:
# %% [markdown]
# # Ensembling STARCOP and Project-Eucalyptus
# 
# To create a more robust prediction, we can ensemble the outputs of the two models. The most direct approach is to combine their respective likelihood scores. We will perform a weighted average of the STARCOP confidence map and the Project-Eucalyptus `likelihood` map.

# %% [code]
import gc
import math

import numpy as np

# --- Configuration for Ensembling ---
# You can adjust these weights. They must sum to 1.0.
# A 50/50 split is a good starting point.
WEIGHT_STARCOP = 0.5
WEIGHT_EUCALYPTUS = 0.5

# Compare with a tolerance: exact float equality rejects valid pairs such as 0.7 + 0.3.
if not math.isclose(WEIGHT_STARCOP + WEIGHT_EUCALYPTUS, 1.0):
    raise ValueError("Weights must sum to 1.0")

# This dictionary will store the final ensembled results for each granule
ensembled_results = {}

print("--- Starting Model Ensembling ---")

for granule_name in granule_names:
    print(f"\nProcessing granule: {granule_name}")

    # --- 1. Retrieve the prediction objects for this granule ---
    # `granules` holds the STARCOP results and is built in the STARCOP section of the notebook.
    starcop_predgeo = granules[granule_name]["predgeo"]
    # The Project-Eucalyptus loops store the georeferenced prediction directly under
    # "predgeo"; there is no nested "likelihood" entry in `euca_preds`.
    euca_likelihood_geo = euca_preds[granule_name]["predgeo"]

    # --- 2. Perform Sanity Checks ---
    # Both rasters must be on the same grid for a pixel-wise combination to be meaningful.
    print("  - Performing sanity checks...")
    if starcop_predgeo.values.shape != euca_likelihood_geo.values.shape:
        raise ValueError(
            f"Shape mismatch: STARCOP is {starcop_predgeo.values.shape}, Eucalyptus is {euca_likelihood_geo.values.shape}"
        )
    if starcop_predgeo.crs != euca_likelihood_geo.crs:
        raise ValueError("CRS mismatch")
    # Note: the transforms are not compared directly because an exact comparison can fail
    # on float precision; shape and CRS are the checks applied here.
    print("  - Sanity checks passed. Grids are aligned.")

    # --- 3. Perform the Weighted Average ---
    print("  - Calculating weighted average...")
    starcop_array = starcop_predgeo.values
    euca_array = euca_likelihood_geo.values

    # NOTE(review): nodata / fill pixels (e.g. -1 or 0) are averaged like any other value —
    # confirm whether they should be masked out before combining.
    ensembled_array = (WEIGHT_STARCOP * starcop_array) + (WEIGHT_EUCALYPTUS * euca_array)

    # --- 4. Create a new GeoTensor for the ensembled result ---
    # The georeferencing is reused from the STARCOP raster since both grids were checked above.
    GeoObjectType = type(starcop_predgeo)
    ensembled_geo = GeoObjectType(
        values=ensembled_array,
        transform=starcop_predgeo.transform,
        crs=starcop_predgeo.crs
    )
    print("  - Created new ensembled GeoTensor.")

    # --- 5. Visualize and Store the Result ---
    visualization = visualize_output(ensembled_geo)

    ensembled_results[granule_name] = {
        "predgeo": ensembled_geo,
        "visualization": visualization
    }

    # Drop the per-granule names; the stored results stay referenced by the dict.
    del starcop_predgeo, euca_likelihood_geo, ensembled_geo
    gc.collect()

print("\n--- All granules have been ensembled successfully! ---")

# %% [markdown]
# You can now display the ensembled map for any granule. For example, for the first one:

# %% [code]
# Display the visualization for the first granule in the list
ensembled_results[granule_names[0]]["visualization"]