In [None]:
import io
import os
import random

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import segmentation_models_pytorch as smp
import torch
from PIL import Image
from sklearn import cluster
from torchvision import transforms
# import cv2

# NASA API key (https://api.nasa.gov). Read from the NASA_API_KEY environment
# variable so the key is never hardcoded in / committed with the notebook.
NASA_KEY = os.environ.get("NASA_API_KEY", "")
# Side length, in degrees, of the square area requested from the imagery API;
# also used to place image overlays / points around (lon, lat) on the map.
IMG_DIM = 0.1




########################## Download NASA sat image
def download_satellite_image(lon, lat, date, timeout=30):
    """
    Downloads a satellite image from the NASA Earth imagery API for a specific longitude, latitude, and date.

    The requested area is a square of IMG_DIM degrees centred on (lon, lat),
    authenticated with the module-level NASA_KEY.

    Args:
        lon (float): The longitude of the location to download the image for.
        lat (float): The latitude of the location to download the image for.
        date (str): The date in YYYY-MM-DD format of the image to download.
        timeout (float): Seconds to wait for the API before giving up. Defaults to 30.

    Returns:
        Image: A PIL Image object containing the downloaded satellite image.

    Raises:
        requests.HTTPError: If the API answers with an error status (e.g. a missing or
            invalid API key), instead of failing later inside PIL with an opaque
            "cannot identify image file" error.
        requests.Timeout: If the API does not respond within `timeout` seconds.
    """
    params = {
        "lon": lon,
        "lat": lat,
        "date": date,
        "dim": IMG_DIM,
        "api_key": NASA_KEY,
    }
    # Let requests build and URL-encode the query string rather than interpolating it by hand.
    response = requests.get(
        "https://api.nasa.gov/planetary/earth/imagery", params=params, timeout=timeout
    )
    # Error responses carry a JSON/HTML body that PIL cannot decode; surface the HTTP error instead.
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    # Decode now so a truncated/corrupt payload fails here, not at first use.
    image.load()
    return image

########################## Process image
def process_img(image, resize=(256, 256)):
    """
    Preprocesses a PIL image for use in a machine learning model.

    The image is converted to RGB first, so greyscale ('L'), palette ('P') and
    RGBA inputs all yield the same (height, width, 3) layout that the
    clustering / segmentation helpers index into.

    Args:
        image (PIL.Image.Image): The input image to preprocess.
        resize (tuple[int, int]): The target (width, height) of the image after resizing
            (PIL convention). Defaults to (256, 256).

    Returns:
        np.ndarray: A 3D float32 NumPy array of shape (height, width, 3), with pixel
            values normalized to [0, 1].

    """
    #image = enhance_image(image)
    # Without this, 'L' images give a 2D array, 'P' images give palette indices,
    # and RGBA gives 4 channels -- all of which break downstream code.
    image = image.convert("RGB").resize(resize)
    # Convert the image to a numpy array and normalize its values
    img_array = np.array(image).astype(np.float32) / 255
    return img_array

########################## Improve image clarity
# def enhance_image(image):
#     """
#     Enhances the contrast of a PIL image using the CLAHE (Contrast Limited Adaptive Histogram Equalization) algorithm.

#     Args:
#         image (PIL.Image.Image): The input RGB image to enhance.

#     Returns:
#         PIL.Image.Image: A new PIL image with enhanced contrast.

#     """
#     img_array = np.array(image)

#     # Convert to LAB color space
#     lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)

#     # Split into channels
#     l, a, b = cv2.split(lab)

#     # Apply histogram equalization to the L channel
#     clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
#     cl = clahe.apply(l)

#     # Merge the channels back into the LAB image
#     lab = cv2.merge((cl, a, b))

#     # Convert back to RGB color space
#     enhanced_image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
#     return Image.fromarray(enhanced_image)


In [None]:

def segment_image_restnet(image_array, use_imagenet=False):
    """
    Segments an RGB image into two classes using a ResNet-encoder U-Net (segmentation_models_pytorch).

    NOTE: the model is built fresh on every call. Only the encoder can carry
    pretrained weights; the decoder and segmentation head are randomly
    initialised and never trained here, so the mask is not a meaningful
    semantic segmentation and will differ between calls unless torch's RNG is seeded.

    Args:
        image_array (np.ndarray): A 3D (H, W, 3) NumPy array representing the input RGB image.
            smp's U-Net (default depth) expects H and W to be divisible by 32; the
            256x256 output of process_img satisfies this.
        use_imagenet (bool): If True, use a ResNet50 encoder with ImageNet weights.
            Otherwise use a ResNet18 encoder with smp's default encoder weights
            (which are also ImageNet). Defaults to False.

    Returns:
        np.ndarray: A 2D (H, W) NumPy array representing the segmentation mask of the input image, with pixel values in {0, 1}.

    """
    # Ensure the model runs on CPU
    device = torch.device('cpu')

    # 'softmax2d' normalises over the class channel (dim=1) explicitly; plain 'softmax'
    # relies on nn.Softmax's deprecated implicit-dim choice. The argmax below is the same either way.
    if use_imagenet:
        model = smp.Unet(encoder_name='resnet50', encoder_weights='imagenet', classes=2, activation='softmax2d')
    else:
        model = smp.Unet('resnet18', classes=2, activation='softmax2d')
    model = model.to(device).eval()

    # (H, W, C) -> (1, C, H, W), the layout PyTorch conv layers expect.
    input_tensor = torch.from_numpy(image_array.transpose(2, 0, 1)).unsqueeze(0).float().to(device)

    # Run the model on the input image
    with torch.no_grad():
        prediction = model(input_tensor)

    # Select the single batch item explicitly (squeeze() would also drop any
    # size-1 spatial dim), then take the most likely class per pixel.
    return prediction[0].argmax(dim=0).cpu().numpy()

def kmeans_cluster(img_array, n_clusters, band=0, random_state=0):
    """
    Performs k-means clustering on the pixel intensities of a single image band.

    Args:
        img_array (np.ndarray): Either a 2D (H, W) single-band image, or a 3D
            (H, W, C) image from which band `band` is clustered.
        n_clusters (int): The number of clusters to use for the k-means algorithm.
        band (int): Channel to cluster when `img_array` is 3D. Defaults to 0
            (the first band, as before). Ignored for 2D input.
        random_state (int | None): Seed for KMeans initialisation so re-running the
            notebook reproduces the same labels. Pass None for a random seed.

    Returns:
        np.ndarray: A 2D (H, W) NumPy array of cluster labels in [0, n_clusters).

    """
    # Accept a single-band image as documented, or pick one band of a multi-band image.
    img = img_array if img_array.ndim == 2 else img_array[:, :, band]
    # One sample per pixel, one feature (its intensity).
    X = img.reshape((-1, 1))
    # n_init is pinned because its default changed across scikit-learn releases (10 -> 'auto').
    k_means = cluster.KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state)
    k_means.fit(X)
    segmentation = k_means.labels_.reshape(img.shape)

    return segmentation


In [None]:
########################## Map overlay image
def create_colored_mask_image(segmentation, n_clusters):
    """
    Creates a color mask image from a segmentation label image.

    Each cluster id 0..n_clusters-1 gets an evenly spaced, linearly interpolated
    sample of Plotly's Viridis colorscale. Pixels whose label falls outside that
    range stay black.

    Args:
        segmentation (np.ndarray): A 2D NumPy array representing the segmentation label image.
        n_clusters (int): The number of clusters used to generate the segmentation label image.

    Returns:
        PIL.Image.Image: A new RGB PIL Image object of the same height/width as `segmentation`.

    """
    # Use the colors from the Plotly Viridis colorscale (a list of '#rrggbb' strings)
    colorscale = px.colors.sequential.Viridis

    def hex_to_rgb(hex_color):
        hex_color = hex_color.lstrip("#")
        return tuple(int(hex_color[k:k + 2], 16) for k in (0, 2, 4))

    rgb_colorscale = [hex_to_rgb(color) for color in colorscale]
    last_stop = len(rgb_colorscale) - 1

    # Interpolate the colorscale to have n_clusters colors
    interpolated_colors = []
    for cluster_id in range(n_clusters):
        # Spread cluster ids evenly over [0, last_stop]; a single cluster takes the
        # first colour instead of dividing by zero.
        position = cluster_id / (n_clusters - 1) * last_stop if n_clusters > 1 else 0.0
        lower_idx = int(position)
        upper_idx = min(lower_idx + 1, last_stop)
        ratio = position - lower_idx
        lower_color = rgb_colorscale[lower_idx]
        upper_color = rgb_colorscale[upper_idx]
        interpolated_colors.append(tuple(
            int(lower_color[c] * (1 - ratio) + upper_color[c] * ratio) for c in range(3)
        ))

    colored_mask = np.zeros((segmentation.shape[0], segmentation.shape[1], 3), dtype=np.uint8)

    for cluster_id, color in enumerate(interpolated_colors):
        colored_mask[segmentation == cluster_id] = color

    return Image.fromarray(colored_mask)


def _to_overlay_source(source):
    """Return `source` as something plotly can embed as a mapbox image layer (PIL Image or URL)."""
    if isinstance(source, np.ndarray):
        arr = source
        if arr.dtype.kind == 'f':
            # Assumes float arrays are in [0, 1], as produced by process_img.
            arr = np.clip(arr, 0.0, 1.0) * 255
        return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
    return source


def _image_layer(lon, lat, source, opacity):
    """Build one mapbox image layer covering an IMG_DIM x IMG_DIM square centred on (lon, lat)."""
    half = IMG_DIM / 2
    return {
        'sourcetype': 'image',
        'source': _to_overlay_source(source),
        # Corners in order: top-left, top-right, bottom-right, bottom-left.
        'coordinates': [
            [lon - half, lat + half],
            [lon + half, lat + half],
            [lon + half, lat - half],
            [lon - half, lat - half],
        ],
        'opacity': opacity,
    }


def create_map_overlay(lon, lat, image=None, colored_mask=None):
    """
    Creates a plotly figure with a marker at the specified longitude and latitude, and optionally overlays an image and/or a colored mask on a map.

    Overlays cover an IMG_DIM x IMG_DIM degree square centred on (lon, lat). The
    image is drawn first (opacity 0.8) and the mask on top of it (opacity 0.6).

    Args:
        lon (float): The longitude to center the map on.
        lat (float): The latitude to center the map on.
        image (Optional[np.ndarray, PIL.Image.Image]): A 3D NumPy array or PIL Image object representing an image to overlay on the map. Defaults to None.
            Float arrays are assumed to be in [0, 1].
        colored_mask (Optional[np.ndarray, PIL.Image.Image]): A 3D NumPy array or PIL Image object representing a colored mask to overlay on the map. Defaults to None.

    Returns:
        plotly.graph_objs._figure.Figure: A new plotly figure with the specified overlay(s) on the map.

    """
    fig = go.Figure(go.Scattermapbox(
        lat=[lat],
        lon=[lon],
        # 'lines' with a single point renders nothing; 'markers' shows the documented marker.
        mode='markers',
    ))
    mapbox_layers = []
    # Compare against None: truth-testing a NumPy array raises ValueError.
    if image is not None:
        mapbox_layers.append(_image_layer(lon, lat, image, 0.8))  ## OPACITY
    if colored_mask is not None:
        mapbox_layers.append(_image_layer(lon, lat, colored_mask, 0.6))  ## OPACITY
    fig.update_layout(
        mapbox_style='open-street-map',
        mapbox_center_lat=lat,
        mapbox_center_lon=lon,
        mapbox_zoom=11.5,
        mapbox_layers=mapbox_layers,
        margin={"r":0,"t":0,"l":0,"b":0},
        width=512,
        height=512,
    )

    return fig


########################## Interactive figure
def create_colored_points(lon, lat, image_array, segmentation, n_clusters):
    """
    Creates lists of latitude, longitude, and color values for plotting colored points on a map.

    One point is produced per pixel, laid out on an IMG_DIM x IMG_DIM degree grid
    whose top-left corner is (lon - IMG_DIM/2, lat + IMG_DIM/2). Each point sits at
    its pixel's top-left corner. Colours come from Plotly's qualitative palette and
    cycle when there are more clusters than palette entries.

    Args:
        lon (float): The longitude to center the map on.
        lat (float): The latitude to center the map on.
        image_array (np.ndarray): An image array; only its height and width (first two dims) are used.
        segmentation (np.ndarray): A 2D NumPy array representing the clustering labels of the input image.
        n_clusters (int): The number of clusters used to generate the clustering label image.

    Returns:
        Tuple[List[float], List[float], List[str]]: A tuple of three lists representing the latitude, longitude, and color values of the colored points.

    """
    cluster_colors = px.colors.qualitative.Plotly[:n_clusters]
    n_colors = len(cluster_colors)

    n_rows, n_cols = image_array.shape[0], image_array.shape[1]
    lat_step = IMG_DIM / n_rows
    lon_step = IMG_DIM / n_cols
    # Grid origin (top-left corner), hoisted out of the per-pixel loop.
    top = lat + IMG_DIM / 2
    left = lon - IMG_DIM / 2

    lat_points = []
    lon_points = []
    colors = []

    for i in range(n_rows):
        for j in range(n_cols):
            lat_points.append(top - i * lat_step)
            lon_points.append(left + j * lon_step)
            # Cycle through the palette so n_clusters > len(palette) doesn't raise IndexError.
            colors.append(cluster_colors[segmentation[i, j] % n_colors])

    return lat_points, lon_points, colors

def create_interactive_map(lon, lat, lat_points, lon_points, colors):
    """
    Builds an OpenStreetMap-backed plotly figure showing small coloured markers.

    Args:
        lon (float): Longitude at which the map view is centred.
        lat (float): Latitude at which the map view is centred.
        lat_points (List[float]): Latitude of every marker.
        lon_points (List[float]): Longitude of every marker.
        colors (List[str]): Colour of every marker, aligned with the coordinate lists.

    Returns:
        plotly.graph_objs._figure.Figure: A 512x512 figure with zero margins and the markers drawn on the map.
    """
    markers = go.Scattermapbox(
        lon=lon_points,
        lat=lat_points,
        mode='markers',
        marker={'size': 3, 'color': colors},
    )

    layout_options = {
        'mapbox_style': 'open-street-map',
        'mapbox_center_lat': lat,
        'mapbox_center_lon': lon,
        'mapbox_zoom': 11.5,
        'margin': {"r": 0, "t": 0, "l": 0, "b": 0},
        'width': 512,
        'height': 512,
    }

    figure = go.Figure(data=[markers])
    figure.update_layout(**layout_options)
    return figure



########################## Cluster proportion figure / stats
def calculate_class_proportions(segmentation, n_clusters):
    """
    Calculates the proportion of pixels in each cluster in a clustering label image.

    Labels outside [0, n_clusters) are not counted towards any cluster but still
    count towards the total, so the proportions may then sum to less than 1.

    Args:
        segmentation (np.ndarray): An array (typically 2D) of clustering labels.
        n_clusters (int): The number of clusters used to generate the clustering label image.

    Returns:
        np.ndarray: A 1D float NumPy array of length n_clusters with the proportion of
            pixels in each cluster. All zeros for an empty segmentation.

    """
    segmentation = np.asarray(segmentation)
    total_pixels = segmentation.size
    if total_pixels == 0:
        # Nothing to count: avoid 0/0 producing NaN (and a RuntimeWarning).
        return np.zeros(n_clusters, dtype=float)

    class_counts = np.array(
        [np.count_nonzero(segmentation == cluster_id) for cluster_id in range(n_clusters)],
        dtype=int,
    )
    return class_counts / total_pixels

def create_class_distribution_pie_chart(class_proportions):
    """
    Draws a pie chart of how pixels are distributed across clusters.

    Slices are labelled with the cluster index (0, 1, ...) in the order the
    proportions are given, and annotated with label and percentage.

    Args:
        class_proportions (np.ndarray): 1D sequence with the share of pixels in each cluster.

    Returns:
        plotly.graph_objs._figure.Figure: A 600x400 pie chart titled "Class Distribution".
    """
    slices = go.Pie(
        labels=[*range(len(class_proportions))],
        values=class_proportions,
        textinfo='label+percent',
        insidetextorientation='radial',
    )

    figure = go.Figure(data=[slices])
    figure.update_layout(title="Class Distribution", width=600, height=400)
    return figure



In [None]:


interactive = False       # True: per-pixel coloured markers; False: coloured mask image overlay
use_coordinates = True    # True: download from the NASA API at (lon, lat, date); False: load image_path

N_CLUSTERS = 4
date = "2019-01-01"
#lon, lat = 88.780777, 24.114852
lon, lat = 89.007557, 24.515067
image_path = "sample_river.jpg"
# image_path = "shasta_lake_2021_june_16.jpg"

if use_coordinates:
    if not NASA_KEY:
        raise ValueError(
            "NASA_KEY is empty: set it to your NASA API key (https://api.nasa.gov) "
            "or set use_coordinates = False to load image_path instead."
        )
    # Use the configured `date`; a hard-coded "2018-01-01" here used to silently override it.
    image = download_satellite_image(lon, lat, date)
else:
    image = Image.open(image_path)

image_array = process_img(image)
segmentation = kmeans_cluster(image_array, N_CLUSTERS)
#segmentation = segment_image_restnet(image_array, True)

if interactive:
    lat_points, lon_points, colors = create_colored_points(lon, lat, image_array, segmentation, N_CLUSTERS)
    fig_map = create_interactive_map(lon, lat, lat_points, lon_points, colors)
else:
    colored_mask = create_colored_mask_image(segmentation, N_CLUSTERS)
    colored_mask.save('segmented_image.png')
    fig_map = create_map_overlay(lon, lat, colored_mask=colored_mask)


class_proportions = calculate_class_proportions(segmentation, N_CLUSTERS)
fig_distribution = create_class_distribution_pie_chart(class_proportions)


fig_map.show()
fig_distribution.show()
