In [None]:
"""
Download satellite imagery using Microsoft Planetary Computer API

"""

## Import and Setup

In [1]:
import pandas as pd
from datetime import timedelta
import numpy as np
from pathlib import Path
from tqdm import tqdm

import planetary_computer as pc
from pystac_client import Client
import geopy.distance as distance

import rioxarray
from PIL import Image
import odc.stac
import cv2

In [2]:
catalog = Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1", modifier=pc.sign_inplace
)

## Helper Functions

In [3]:
def get_bounding_box(latitute, longitude, meter_buffer=50000):
    """
    Given a latitude, longitude, and buffer in meters, returns a bounding
    box around the point with the buffer on the left, right, top, and bottom.

    Returns a list of [minx, miny, maxx, maxy]
    """
    distance_search = distance.distance(meters=meter_buffer)
    
    # calculate the lat/long bounds based on ground distance
    # bearings are cardinal directions to move (south, west, north, and east)   
    min_lat = distance_search.destination(point=(latitute, longitude), bearing=180).latitude
    min_long = distance_search.destination(point=(latitute, longitude), bearing=270).longitude
    max_lat = distance_search.destination(point=(latitute, longitude), bearing=0).latitude
    max_long = distance_search.destination(point=(latitute, longitude), bearing=90).longitude
    
    return [min_long, min_lat, max_long, max_lat]

In [4]:
def get_date_range(date, time_buffer_days=15):
    """
    Get a date range to search for in the planetary computer based
    on a sample's date. The time range will include the sample date
    and time_buffer_days days prior

    Returns a string
    """
    datetime_format = "%Y-%m-%d"
    range_start = pd.to_datetime(date) - timedelta(days=time_buffer_days)
    date_range = f"{range_start.strftime(datetime_format)}/{pd.to_datetime(date).strftime(datetime_format)}"

    return date_range

In [5]:
def crop_sentinel_image(item, bounding_box):
    """
    Given a STAC item from Sentinel-2 and a bounding box tuple in the format
    (minx, miny, maxx, maxy), return a cropped portion of the item's visual
    imagery in the bounding box.

    Returns the image as a numpy array with dimensions (color band, height, width)
    """
    (minx, miny, maxx, maxy) = bounding_box

    image = rioxarray.open_rasterio(pc.sign(item.assets["visual"].href)).rio.clip_box(
        minx=minx,
        miny=miny,
        maxx=maxx,
        maxy=maxy,
        crs="EPSG:4326",
    )

    return image.to_numpy()

In [6]:
def crop_landsat_image(item, bounding_box):
    """
    Given a STAC item from Landsat and a bounding box tuple in the format
    (minx, miny, maxx, maxy), return a cropped portion of the item's visual
    imagery in the bounding box.

    Returns the image as a numpy array with dimensions (color band, height, width)
    """
    (minx, miny, maxx, maxy) = bounding_box

    image = odc.stac.stac_load(
        [pc.sign(item)], bands=["red", "green", "blue"], bbox=[minx, miny, maxx, maxy]
    ).isel(time=0)
    image_array = image[["red", "green", "blue"]].to_array().to_numpy()

    # normalize to 0 - 255 values
    image_array = cv2.normalize(image_array, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)

    return image_array

## Prep to Get Data

In [7]:
DATA_DIR = Path.cwd() / 'data'

IMAGE_DIR = DATA_DIR / 'images'
#IMAGE_DIR = DATA_DIR / 'sentinel-images'
IMAGE_DIR.mkdir(exist_ok=True, parents=True)

In [8]:
metadata = pd.read_csv(DATA_DIR / 'metadata.csv')
metadata.date = pd.to_datetime(metadata.date)
print(metadata.shape)
metadata.head()

(23570, 5)


Unnamed: 0,uid,latitude,longitude,date,split
0,aabm,39.080319,-86.430867,2018-05-14,train
1,aabn,36.5597,-121.51,2016-08-31,test
2,aacd,35.875083,-78.878434,2020-11-19,train
3,aaee,35.487,-79.062133,2016-08-24,train
4,aaff,38.049471,-99.827001,2019-07-23,train


In [9]:
metadata.split.value_counts()

train    17060
test      6510
Name: split, dtype: int64

In [10]:

# train_subset = metadata[metadata['split'] == 'train'].sample(n=2500, random_state=42)
# locations_to_get = pd.concat([train_subset, metadata[metadata['split'] == 'test']])

locations_to_get = metadata

locations_to_get.split.value_counts()

train    17060
test      6510
Name: split, dtype: int64

## Get Data

In [11]:
def select_best_image(items, date, latitude, longitude):
    """
    Selects the best satellite item given a sample's date, latitude, and longitude.
    If any Sentinel-2 items are available, they will be selected.
    If not, then the closest Landsat-8 or Landsat-9 item will be selected.
    
    Returns a tuple of (STAC item, item platform name, item date)
    """
    # get item details
    item_details = pd.DataFrame(
        [
            {
                "datetime": item.datetime.strftime("%Y-%m-%d"),
                "platform": item.properties["platform"],
                "min_long": item.bbox[0],
                "max_long": item.bbox[2],
                "min_lat": item.bbox[1],
                "max_lat": item.bbox[3],
                "item_obj": item,
            }
            for item in items
        ]
    )
    
    # filter to items that contain the point location, or return None if none contain the point
    item_details["contains_sample_point"] = (
        (item_details.min_lat < latitude)
        & (item_details.max_lat > latitude)
        & (item_details.min_long < longitude)
        & (item_details.max_long > longitude)
    )
    item_details = item_details[item_details["contains_sample_point"] == True]
    
    if len(item_details) == 0:
        return (np.nan, np.nan, np.nan)
    
    # add time difference between each item and the sample
    item_details["time_diff"] = pd.to_datetime(date) - pd.to_datetime(
        item_details["datetime"]
    )
    
    # if we have sentinel-2, filter to sentinel-2 images only
    item_details["sentinel"] = item_details.platform.str.lower().str.contains(
        "sentinel"
    )
    if item_details["sentinel"].any():
        item_details = item_details[item_details["sentinel"] == True]
        
    # return the closest imagery by time
    best_item = item_details.sort_values(by="time_diff", ascending=True).iloc[0]
    
    return (best_item["item_obj"], best_item["platform"], best_item["datetime"])

In [12]:
#selected_items = {}
paths_dict = {}
errored_ids = []

# i = 0
for row in tqdm(locations_to_get.itertuples(), total=len(locations_to_get)):
    # i += 1
    # if i > 20: break
    
    image_pth = IMAGE_DIR / f"{row.uid}.png"
    
    if image_pth.exists():
        pass
        paths_dict[row.uid] = image_pth
    
    else:
        try:
            ## QUERY PLANETARY COMPUTER
            search_bbox = get_bounding_box(
                row.latitude,
                row.longitude,
                meter_buffer=1000
            )
            
            search_date_range = get_date_range(
                row.date,
                time_buffer_days=15
            )
            
            search_results = catalog.search(
                collections=[
                    "sentinel-2-l2a", 
                    "landsat-c2-l2",
                ],
                bbox=search_bbox,
                datetime=search_date_range,
                query={
                    "eo:cloud_cover": {"lt": 20},
                    "platform": {"in": ["Sentinel-2A", "Sentinel-2B", "landsat-8", "landsat-9"]},
                }
            )

            items = [item for item in search_results.get_items()]
            
            
            ## GET BEST IMAGE
            if len(items) == 0:
                pass
            else:
                pass
                best_item, item_platform, item_date = select_best_image(
                    items,
                    row.date,
                    row.latitude,
                    row.longitude
                )
                if best_item is np.nan:
                    raise Exception('No image found')
                
                # selected_items[row.uid] = {
                #     'item_object': best_item,
                #     'item_platform': item_platform,
                #     'item_date': item_date,
                # }
            
            
            # SAVE IMAGE DATA
            save_bbox = get_bounding_box(
                row.latitude,
                row.longitude,
                meter_buffer=1000,
            )
            
            if 'sentinel' in item_platform.lower():
                image_array = crop_sentinel_image(best_item, save_bbox)
            else:
                image_array = crop_landsat_image(best_item, save_bbox)
            
            if sum(image_array.flatten()) == 0:
                raise Exception("Image is all black")
            
            image = Image.fromarray(np.transpose(image_array, axes=[1, 2, 0]))
            image.save(image_pth)
            
            paths_dict[row.uid] = image_pth
            #selected_items[row.uid]['filename'] = image_pth
            
        except:
            errored_ids.append(row.uid)

100%|██████████| 23570/23570 [48:33<00:00,  8.09it/s]   


In [13]:
len(paths_dict)

17094

In [14]:
# see how many ran into errors
print(f"Could not pull satellite imagery for {len(errored_ids)} samples")

Could not pull satellite imagery for 6476 samples


## Determine what image files were saved

In [13]:
import os

In [14]:
imgs = os.listdir(IMAGE_DIR)
imgs.remove('.DS_Store')

In [15]:
df = pd.DataFrame(imgs, columns=['uid'])

In [16]:
df['uid'] = df['uid'].apply(lambda x: x.replace('.png', ''))

In [17]:
saved = df.merge(
    metadata,
    how='left',
    left_on='uid',
    right_on='uid',
    validate='1:1'
)

In [18]:
print(len(saved))
saved.head()

18900


Unnamed: 0,uid,latitude,longitude,date,split
0,xlix,35.07,-78.88788,2018-08-06,test
1,wadh,40.025833,-85.307944,2021-07-20,test
2,yzss,34.4069,-119.517,2017-09-28,test
3,vypp,40.969888,-80.369248,2019-07-11,test
4,zjkd,36.03,-78.688036,2020-09-16,train


In [19]:
saved.to_csv(DATA_DIR / 'c20p-metadata.csv', index=False)

In [20]:
saved.split.value_counts()

train    14009
test      4880
Name: split, dtype: int64

In [21]:
metadata.split.value_counts()

train    17060
test      6510
Name: split, dtype: int64

## Remove all black image files

In [None]:
im = Image.open(IMAGE_DIR / 'kmcg.png')

In [42]:
black_pngs = []
for i in tqdm(imgs):
    im = Image.open(IMAGE_DIR / i)
    im_arr = np.array(im)
    if sum(im_arr.flatten()) == 0:
        black_pngs.append(i)

100%|██████████| 23536/23536 [02:30<00:00, 156.69it/s]


In [44]:
len(black_pngs)

6896

In [48]:
for i in tqdm(black_pngs):
    os.remove(IMAGE_DIR / i)

100%|██████████| 6896/6896 [00:01<00:00, 4751.22it/s]
