Note: This notebook assumes that "gmn_extract_candidate_contrails.ipynb" and "gmn_run_segmentation.ipynb" have already been run.

In [1]:
import sys
sys.path.append("..")

from datetime import datetime, timezone
from glob import glob
import json
import os
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import imageio

from matplotlib import pyplot as plt
from PIL import Image

from pycontrails.datalib.goes import GOES as pycon_GOES, extract_goes_visualization
from segmentation_utils import masks_to_geo

In [2]:
# Stations and dates whose extracted flight folders are processed below
stations = ["US001N"]
dates = ["20230805"]
region = "C"  # For satellite imagery

# Normally the masks would only be lifted to geo once a contrail has been detected
# (see gmn_example4_classify_candidates.ipynb). To keep this example short, the
# flight below, which has a contrail, is selected by hand instead.
flight_id_filter = ["3C6514_DLH413"]

In [3]:
# Prepare function for GOES visualization
def fetch_goes(time, region):
    """Load the GOES scene for a UNIX timestamp and return its visualization.

    Parameters
    ----------
    time : int or float
        Seconds since the epoch. Interpreted as UTC and truncated to the minute.
    region : str
        Region handed unchanged to the pycontrails ``GOES`` constructor.

    Returns
    -------
    tuple
        ``(rgb, transform, extent)`` as produced by ``extract_goes_visualization``.
    """
    # Strip the tz marker so the ISO string is a plain "YYYY-MM-DDTHH:MM",
    # which numpy parses as a naive minute-resolution datetime64.
    utc_moment = datetime.fromtimestamp(time, timezone.utc).replace(tzinfo=None)
    image_time = np.datetime64(utc_moment.isoformat(timespec="minutes"))

    # Using pycontrails to load GOES data
    goes_source = pycon_GOES(region=region)
    scene = goes_source.get(pd.to_datetime(image_time))

    rgb, transform, extent = extract_goes_visualization(scene)
    return rgb, transform, extent

# Seconds between consecutive GOES scans for the regions handled here
# (full disk "F": every 10 minutes, CONUS "C": every 5 minutes).
_GOES_SCAN_SECONDS = {"F": 600, "C": 300}


def _round_to_goes_scan(time, goes_region):
    """Round a UNIX timestamp down to the start of its GOES scan interval."""
    if goes_region not in _GOES_SCAN_SECONDS:
        raise ValueError(
            f"Unsupported GOES region {goes_region!r}; expected one of {sorted(_GOES_SCAN_SECONDS)}"
        )
    return time - time % _GOES_SCAN_SECONDS[goes_region]


def _save_satellite_figure(rgb, transform, extent, bounding_box, out_path, mask_lons=None, mask_lats=None):
    """Draw the GOES image cropped to ``bounding_box`` and save it, optionally with mask points."""
    fig = plt.figure(figsize=(5.5, 5.5))
    try:
        pc = ccrs.PlateCarree()
        ax = fig.add_subplot(projection=pc, extent=bounding_box)
        ax.coastlines()  # type: ignore[attr-defined]
        ax.axis('off')
        ax.imshow(rgb, extent=extent, transform=transform)
        if mask_lons is not None and mask_lats is not None:
            # Black outline first, then the light-red fill on top
            ax.scatter(mask_lons, mask_lats, edgecolors=(0, 0, 0), transform=pc, s=0.1, facecolor='none')
            ax.scatter(mask_lons, mask_lats, color=(1, 0.5, 0.5), transform=pc, s=0.1, edgecolors='none')
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Close only this figure so other open figures in the kernel are left alone
        plt.close(fig)


def save_images_for_video(frames, sam2_output, output_dir, mask_longitudes, mask_latitudes, bounding_box, image_corners, goes_region=None):
    """Write the per-frame PNGs that are later stitched into the comparison video.

    For frame ``idx`` three files are written to ``output_dir``:
    ``cam_{idx:04d}.png`` (camera image blended with its segmentation mask),
    ``sat_with_mask_{idx:04d}.png`` and ``sat_without_mask_{idx:04d}.png``
    (GOES imagery cropped to ``bounding_box``, with and without the mask points).

    Parameters
    ----------
    frames : sequence of dict
        Each entry must provide ``'file_path'`` (camera image) and ``'time'``
        (UNIX timestamp in seconds).
    sam2_output : sequence of str
        Paths of the segmentation mask images, aligned with ``frames`` by index.
    output_dir : str
        Existing directory that receives the PNG files.
    mask_longitudes, mask_latitudes : sequence
        Per-frame mask coordinates, aligned with ``frames`` by index.
    bounding_box : sequence
        Map extent passed to the cartopy axes.
    image_corners
        Currently unused; kept so existing callers keep working.
    goes_region : str, optional
        GOES region ("F" or "C"). Defaults to the notebook-level ``region``.

    Raises
    ------
    IndexError
        If any per-frame input has fewer entries than ``frames``.
    ValueError
        If the GOES region is not one of the supported regions.
    """
    # Fail before writing anything if the per-frame inputs cannot cover all frames
    n_frames = len(frames)
    per_frame_inputs = (
        ("sam2_output", sam2_output),
        ("mask_longitudes", mask_longitudes),
        ("mask_latitudes", mask_latitudes),
    )
    for label, values in per_frame_inputs:
        if len(values) < n_frames:
            raise IndexError(
                f"{label} has {len(values)} entries but {n_frames} frames were given"
            )

    prev_time = None
    rgb = transform = extent = None
    for idx, frame in enumerate(frames):
        # Context managers release the file handles as soon as the blend is computed
        with Image.open(frame['file_path']) as camera_file, Image.open(sam2_output[idx]) as mask_file:
            combined_image = Image.blend(camera_file.convert('RGBA'), mask_file.convert('RGBA'), 0.1)
        combined_image.save(os.path.join(output_dir, f'cam_{idx:04d}.png'))

        # Get satellite imagery at the same time
        active_region = region if goes_region is None else goes_region
        rounded_time = _round_to_goes_scan(frame['time'], active_region)

        # Only fetch again when the frame falls into a new GOES scan
        if prev_time != rounded_time:
            prev_time = rounded_time
            rgb, transform, extent = fetch_goes(rounded_time, active_region)

        _save_satellite_figure(
            rgb, transform, extent, bounding_box,
            os.path.join(output_dir, f'sat_with_mask_{idx:04d}.png'),
            mask_lons=mask_longitudes[idx], mask_lats=mask_latitudes[idx],
        )
        _save_satellite_figure(
            rgb, transform, extent, bounding_box,
            os.path.join(output_dir, f'sat_without_mask_{idx:04d}.png'),
        )

In [4]:
# Lift the segmentation masks of every station/date folder to latitudes and longitudes
station_date_pairs = [(s, d) for s in stations for d in dates]
for station, date in station_date_pairs:
    target_dir = f'../data/gmn_extracted_flight_images/{station}/{date}/'
    masks_to_geo(target_dir, flight_id_filter)

Loading config file: d:\Clouds\tmp_gmn_contrails\src\data\gmn_extracted_flight_images\US001N\20230805\.config


In [5]:
# Visualise on satellite imagery (per-frame images are saved in "/flight_dir/debug"
# and turned into a video by the next cell)

# Collect the metadata files of every station/date pair. A plain assignment inside
# the loops would keep only the files of the last pair.
meta_data_files = []
for station in stations:
    for date in dates:
        target_dir = f'../data/gmn_extracted_flight_images/{station}/{date}/'
        meta_data_files.extend(sorted(glob(os.path.join(target_dir, '*/metadata/*.json'))))

for meta_data_file in meta_data_files:
    # Metadata lives in "<flight_dir>/metadata/<file>.json"
    flight_dir = os.path.dirname(os.path.dirname(meta_data_file))
    flight_id = os.path.basename(flight_dir)

    if flight_id not in flight_id_filter:
        continue

    output_dir = os.path.join(flight_dir, 'debug')
    os.makedirs(output_dir, exist_ok=True)

    with open(meta_data_file, 'r') as f:
        data = json.load(f)

    # Get the mask latitude and longitudes
    mask_geo_path = os.path.join(flight_dir, 'mask_geo', 'mask_geo.json')
    if not os.path.exists(mask_geo_path):
        print(f"Skipping {flight_dir} as mask_geo.json does not exist")
        # Without this, the code below would reuse the previous flight's masks
        # or fail with a NameError on the first flight.
        continue

    with open(mask_geo_path, 'r') as f:
        mask_geo = json.load(f)
    sorted_keys = sorted(mask_geo['masks'].keys())
    latitudes = [mask_geo['masks'][key]['lats'] for key in sorted_keys]
    longitudes = [mask_geo['masks'][key]['lons'] for key in sorted_keys]
    image_corners = mask_geo['image_corners']

    # The map extent is derived from the first mask, so it must contain points
    if not longitudes or not longitudes[0] or not latitudes[0]:
        print(f"Skipping {flight_dir} as the first mask in mask_geo.json is empty")
        continue

    # Visualise frames after flight
    sam2_output = sorted(glob(os.path.join(flight_dir, 'sam2_output/*.png')))

    padding = 0.25  # margin added on every side of the first mask
    bounding_box = [
        min(longitudes[0]) - padding,
        max(longitudes[0]) + padding,
        min(latitudes[0]) - padding,
        max(latitudes[0]) + padding,
    ]

    save_images_for_video(data['frames_after_flight'], sam2_output, output_dir, longitudes, latitudes, bounding_box, image_corners)

In [6]:
# Stitch the per-frame PNGs of each selected flight into "<flight_dir>/debug/combined.mp4"
for meta_data_file in meta_data_files:
    flight_dir = os.path.dirname(os.path.dirname(meta_data_file))
    flight_id = os.path.basename(flight_dir)

    if flight_id not in flight_id_filter:
        continue

    # Convert saved images to video
    debug_dir = os.path.join(flight_dir, 'debug')
    cam_images = sorted(glob(os.path.join(debug_dir, 'cam_*.png')))
    sat_with_mask_images = sorted(glob(os.path.join(debug_dir, 'sat_with_mask_*.png')))
    sat_without_mask_images = sorted(glob(os.path.join(debug_dir, 'sat_without_mask_*.png')))

    frame_counts = (len(cam_images), len(sat_with_mask_images), len(sat_without_mask_images))
    if min(frame_counts) == 0:
        print(f"Skipping {flight_dir} as there are no saved frames to convert")
        continue
    if len(set(frame_counts)) > 1:
        # zip() below stops at the shortest list, so the panels stay aligned
        print(f"Warning: unequal frame counts {frame_counts} in {debug_dir}; using the first {min(frame_counts)}")

    # Concatenate horizontally into single video
    frames = []
    for cam_path, with_mask_path, without_mask_path in zip(cam_images, sat_with_mask_images, sat_without_mask_images):
        # Close the source files right away so they can be deleted afterwards
        # (an open file cannot be removed on Windows).
        with Image.open(cam_path) as cam_image, \
                Image.open(with_mask_path) as sat_with_mask_image, \
                Image.open(without_mask_path) as sat_without_mask_image:
            total_width = cam_image.width + sat_with_mask_image.width + sat_without_mask_image.width
            max_height = max(cam_image.height, sat_with_mask_image.height, sat_without_mask_image.height)

            combined_image = Image.new('RGB', (total_width, max_height))
            combined_image.paste(cam_image, (0, 0))
            combined_image.paste(sat_with_mask_image, (cam_image.width, 0))
            combined_image.paste(sat_without_mask_image, (cam_image.width + sat_with_mask_image.width, 0))
        # Hand imageio plain arrays rather than PIL objects
        frames.append(np.asarray(combined_image))

    imageio.mimwrite(os.path.join(debug_dir, 'combined.mp4'), frames, fps=10, quality=8)

    # Remove the saved images as well
    for image_path in cam_images + sat_with_mask_images + sat_without_mask_images:
        os.remove(image_path)

