# Methane Plume Detector

In [None]:
import numpy as np
import pandas as pd
import lightning.pytorch as pl 
import wandb
import ee
import json
import pickle
import torch
import torchvision

from src.models.methane_unet import UNetMethane

In [None]:
# --- Observation to query ----------------------------------------------------
obs_date = "2022-10-28"
obs_lat = 32.28470055
obs_lon = -108.284939
# Maps an observation date to its (latitude, longitude) pair.
query_dict = {obs_date: (obs_lat, obs_lon)}

# --- Scene query settings ----------------------------------------------------
# Half-width of the queried window, in meters: the scene is a square of
# (2 * buffer_distance) x (2 * buffer_distance).
buffer_distance = 2000
# Cloud threshold for the queried scene, in percent.
cloud_threshold = 0
# Satellite to query.
satellite = "S2"

# --- Model -------------------------------------------------------------------
# Checkpoint file name. Per the original note, the checkpoint is stored as an
# object with this name in the `methane-detector` S3 bucket, which lives in the
# same account as this repo.
model_name = "unet_64_standardized.ckpt"
# Should be a power of two; it can be read off the model name.
base_filters = 64

## Query EO data

In [None]:
# Authenticate to the GEE API.
# The function must actually be called: a bare `ee.Authenticate` only evaluates
# to the function object and never starts the authentication flow.
ee.Authenticate()

In [None]:
from src.tools.gee import query_ee_band_data
ee.Initialize()

# One raw scene and one source rate are collected per entry of `query_dict`,
# in the same order, so the two lists stay index-aligned.
raw_scene_list = []
source_rates = []
for date, coords in query_dict.items():
    lat = coords[0]
    lon = coords[1]
    # The source rate is an optional third element of the coordinate tuple and
    # defaults to 0 when absent. An explicit length check replaces the former
    # bare `except:`, which would also have swallowed unrelated errors
    # (including KeyboardInterrupt).
    source_rates.append(coords[2] if len(coords) > 2 else 0)

    # NOTE(review): `successful_obs` is returned but never checked here, so a
    # scene is appended even if the query did not succeed -- confirm whether
    # failed observations should be skipped.
    raw_scene, scene_date, successful_obs = query_ee_band_data(
        lat,
        lon,
        lat_shift=0,
        lon_shift=0,
        buffer_distance=buffer_distance,
        start_date=date,
        n_days=1,
        satellite_name=satellite,
        cloud_threshold=cloud_threshold,
        get_no2_bands=True,
        get_ch4_bands=True,
        get_aux_bands=True,
        verbose=True,
        cache=None,
    )

    raw_scene_list.append(raw_scene)

## Preprocess data

In [None]:
# Define preprocessing function
from src.tools.image import antialiasing_filter, change_resolution

def preprocess_scene(scene):
    """Resample, mask and smooth the spectral bands of a scene.

    For each of the bands B2, B3, B4, B8, B11 and B12 the band is passed
    through ``change_resolution`` with the current shape of ``scene["B2"]`` as
    target, and every value equal to 0 is replaced by NaN. The scene is then
    run through ``antialiasing_filter`` (sigma=0.5) over those same bands.

    Args:
        scene: Mapping from band name to array. It is updated in place for the
            listed bands before being handed to the anti-aliasing filter.

    Returns:
        Whatever ``antialiasing_filter`` returns for the updated scene.
    """
    band_names = ["B2", "B3", "B4", "B8", "B11", "B12"]

    for band in band_names:
        # The target shape is re-read on every iteration, exactly as the band
        # dictionary stands at that point.
        resized = change_resolution(scene[band], target_resolution=scene["B2"].shape)
        scene[band] = resized
        # Zeros are replaced by NaN.
        scene[band] = np.where(scene[band] == 0, np.nan, scene[band])

    return antialiasing_filter(scene, band_names, sigma=0.5)

def compute_ndi(scene):
    """Add four normalized-difference indices to a scene.

    Each index has the form ``(a - b) / (a + b)``:

    * ``ndmi``: a = B11, b = B12
    * ``ndbi``: a = B11, b = B8
    * ``ndvi``: a = B8,  b = B4
    * ``bsi``:  a = B11 + B4, b = B8 + B2

    Args:
        scene: Mapping holding the bands "B2", "B4", "B8", "B11" and "B12".

    Returns:
        The same ``scene`` object, with the keys "ndmi", "ndbi", "ndvi" and
        "bsi" added (in that order).
    """
    def _normalized_difference(first, second):
        return (first - second) / (first + second)

    b2, b4, b8 = scene["B2"], scene["B4"], scene["B8"]
    b11, b12 = scene["B11"], scene["B12"]

    indices = {
        "ndmi": _normalized_difference(b11, b12),
        "ndbi": _normalized_difference(b11, b8),
        "ndvi": _normalized_difference(b8, b4),
        "bsi": _normalized_difference(b11 + b4, b8 + b2),
    }
    for key, value in indices.items():
        scene[key] = value

    return scene

In [None]:
from src.tools.image import standardize

# trusted_scene_list: preprocessed scenes with the index bands added.
# refined_scene_list: the same scenes after `standardize`; these feed the model.
trusted_scene_list = []
refined_scene_list = []

for raw_data in raw_scene_list:
    # Process the loop variable. The previous code passed `raw_scene`, a
    # leftover from the query cell, so every iteration re-processed the last
    # queried scene instead of the current one.
    trusted_scene = preprocess_scene(raw_data)
    trusted_scene_ndi = compute_ndi(trusted_scene)

    trusted_scene_list.append(trusted_scene_ndi)
    refined_scene = standardize(trusted_scene_ndi)
    refined_scene_list.append(refined_scene)

In [None]:
# Display the first queried scene's B12 and NDMI bands side by side
import matplotlib.pyplot as plt

scene = refined_scene_list[0]

f = plt.figure(figsize=(8, 4))

# NOTE(review): `scene_date` is whatever the query loop last assigned, so the
# titles only match `refined_scene_list[0]` when a single scene was queried.

# Left panel: B12 band.
subplot1 = f.add_subplot(1, 2, 1)
im1 = subplot1.imshow(scene["B12"], cmap="Greys_r")
subplot1.set_title(f"SWIR 12 - {scene_date}")

# Right panel: NDMI. It uses the second slot of the same 1x2 grid and draws on
# its own axes; previously it was created as (1, 3, 1) and the image was drawn
# on `subplot1`, overwriting the B12 panel.
subplot2 = f.add_subplot(1, 2, 2)
im2 = subplot2.imshow(scene["ndmi"], cmap="viridis")
subplot2.set_title(f"NDMI - {scene_date}")

plt.show()

## Build dataset for inference

In [None]:
from src.model import TestDataset, collate_fn

# Wrap the standardized scenes in the project dataset class.
test_dataset = TestDataset(refined_scene_list)

# One scene per batch, a single worker, and the project collate function.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, num_workers=1, collate_fn=collate_fn
)

# Number of batches the loader will yield (shown as the cell output).
len(test_dataloader)

In [None]:
import os

import boto3

# Local destination of the checkpoint; make sure its directory exists before
# downloading into it.
model_path = f"models/{model_name}"
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Fetch the checkpoint object (its key is the model file name) from the
# `methane-detector` bucket. `download_file` writes to disk and returns None,
# so the checkpoint has to be loaded from the downloaded file afterwards.
s3 = boto3.resource('s3')
s3.Bucket('methane-detector').download_file(model_name, model_path)

# NOTE: `torch.load` unpickles the file, so only load checkpoints that come
# from a trusted bucket. Loading onto the CPU keeps this cell usable on
# machines without a GPU.
checkpoint = torch.load(model_path, map_location="cpu")
checkpoint.keys()

In [None]:
from src.model import CH4UNet

# Define model backbone
model = CH4UNet(in_channels=10, n_classes=1, base_filters=base_filters)

# Populate model with weights
model.load_state_dict(checkpoint["state_dict"])

# Move model to gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# This notebook only runs inference: switch to evaluation mode so that layers
# which behave differently during training (e.g. dropout, batch norm) use
# their inference behavior.
model.eval()

## Inference

In [None]:
# Run the model on every scene and plot, for each one, the first input channel,
# the predicted mask and the predicted probabilities.
for batch_idx, batch in enumerate(test_dataloader):
    inputs = batch["image"].to(device)
    pixel_mask = batch["pixel_mask"].to(device)
    scene_date = batch["scene_date"][0]

    # No gradients are needed for inference; disabling autograd avoids building
    # the graph and keeping activations in memory.
    with torch.no_grad():
        proba, mask = model(inputs)
        # Multiply inputs and outputs by the pixel mask.
        inputs = inputs * pixel_mask
        proba = proba * pixel_mask
        mask = mask * pixel_mask

    # Take the first (only) element of the batch and convert it to numpy.
    pred_mask = mask[0, :, :].squeeze().cpu().numpy().astype(np.uint8)
    prediction = proba[0, :, :].squeeze().cpu().numpy()
    # NOTE(review): channel 0 is plotted as NDMI; the original comment says it
    # "should show the ndmi channel" -- confirm the channel order produced by
    # the dataset.
    image = inputs[0, 0, :, :].cpu().numpy()

    f = plt.figure(figsize=(15, 4))

    input_ax = f.add_subplot(1, 3, 1)
    input_im = input_ax.imshow(image)
    input_ax.set_title(f"NDMI - {scene_date}")
    f.colorbar(input_im, ax=input_ax)

    mask_ax = f.add_subplot(1, 3, 2)
    mask_im = mask_ax.imshow(pred_mask, cmap="binary")
    # `source_rates` is index-aligned with the queried scenes (batch_size=1).
    mask_ax.set_title(f"Predicted Mask - {source_rates[batch_idx]} t/h")
    f.colorbar(mask_im, ax=mask_ax)

    proba_ax = f.add_subplot(1, 3, 3)
    proba_im = proba_ax.imshow(prediction, cmap="plasma")
    proba_ax.set_title("Probabilities")
    f.colorbar(proba_im, ax=proba_ax)

    plt.show()