# Landsat Processing
Created by: Oriana Chegwidden

In [None]:
%load_ext autoreload
%autoreload 2
import boto3
from rasterio.session import AWSSession
from s3fs import S3FileSystem
aws_session = AWSSession(boto3.Session(), requester_pays=True)
fs = S3FileSystem(requester_pays=True)
from osgeo.gdal import VSICurlClearCache
VSICurlClearCache() 
import rasterio as rio
import xarray as xr
import dask
import os
import fsspec
import geopandas as gpd
import regionmask as rm
# from satsearch import Search
from matplotlib.pyplot import imshow
from intake import open_stac_item_collection
import numcodecs
import numpy as np
import rioxarray # for the extension to load
import matplotlib.pyplot as plt
import utm

In [None]:
from dask_gateway import Gateway

# Connect to the Dask Gateway using its default configuration.
gateway = Gateway()
# Start from the gateway's cluster options and override the worker settings.
options = gateway.cluster_options()
options.worker_cores = 2
# NOTE(review): the unit of worker_memory is defined by the gateway
# deployment (presumably GiB) -- confirm.
options.worker_memory = 32
# Presumably exported as an environment variable on the workers so that reads
# from the requester-pays Landsat bucket are allowed -- confirm.
options.environment = {'AWS_REQUEST_PAYER': 'requester'}
cluster = gateway.new_cluster(cluster_options=options)
# Scale adaptively between 1 and 10 workers.
cluster.adapt(minimum=1, maximum=10)
# Bare expression: displays the cluster object as the notebook cell output.
cluster

In [None]:
# Client bound to the gateway cluster; used further down to submit the
# delayed seasonal-average tasks via client.compute.
client = cluster.get_client()

Each Landsat scene is stored as a cloud optimized GeoTIFF (COG) named according to a verbose (but, once you understand it, human-readable!) convention. Landsat Collection 2 uses the same naming convention as Collection 1, which is as follows (lifted from their docs at `https://prd-wret.s3.us-west-2.amazonaws.com/assets/palladium/production/atoms/files/LSDS-1656_%20Landsat_Collection1_L1_Product_Definition-v2.pdf`):

```LXSS_LLLL_PPPRRR_YYYYMMDD_yyyymmdd_CC_TX```
where
```
L = Landsat  (constant)
X = Sensor  (C = OLI / TIRS, O = OLI-only, T= TIRS-only, E = ETM+, T = TM, M= MSS)
SS = Satellite  (e.g., 04 for Landsat 4, 05 for Landsat 5, 07 for Landsat 7, etc.) 
LLLL = Processing  level  (L1TP, L1GT, L1GS)
PPP  = WRS path
RRR  = WRS row
YYYYMMDD = Acquisition  Year (YYYY) / Month  (MM) / Day  (DD) 
yyyymmdd  = Processing  Year (yyyy) / Month  (mm) / Day (dd)
CC = Collection  number  (e.g., 01, 02, etc.) 
TX= RT for Real-Time, T1 for Tier 1 (highest quality), and T2 for Tier 2

```

Thus, we're looking for scenes coded in the following way:
`LE07_????_PPPRRR_YYYYMMDD_yyyymmdd_02_T1` for Landsat 7 and
`LT05_????_PPPRRR_YYYYMMDD_yyyymmdd_02_T1` for Landsat 5
(but T1 might be wrong there)


We are re-implementing (to the best of our abilities) the methods from Wang et al (in review). Jon Wang's paper said:

```To extend our AGB predictions through space and time, we used time series (1984 – 2014) of 30 m surface reflectance data from the Thematic Mapper onboard Landsat 5 and the Enhanced Thematic Mapper Plus onboard Landsat 7. We used the GLAS-derived estimates of AGB as a response variable and the mean growing season (June, July, August) and non-growing season values for each of Landsat’s six spectral reflectance bands as the predictors in an ensemble machine learning model```

So we'll be looking for:
* Landsat 5 (Thematic mapper) and 7 (Enhanced Thematic Mapper Plus)
* Growing season (June-August) and non-growing season (Sept-May) averages at an annual timestep. <--- will need to figure out around the calendar whether we want consecutive
* All six spectral reflectance bands
* We'll do a quality thresholding of cloudless cover for now based upon their thresholding

In orienting myself, these are the potential collection options I've figured out (by poking around on the [sat-api catalog](https://landsatlook.usgs.gov/sat-api/collections)):
* `landsat-c2l2-sr` Landsat Collection 2 Level-2 UTM Surface Reflectance (SR) Product
* `landsat-c2l2alb-sr` Landsat Collection 2 Level-2 Albers Surface Reflectance (SR) Product
* `landsat-c1l2alb-sr` Landsat Collection 1 Level-2 Albers Surface Reflectance (SR) Product <-- we don't want this one (b/c we'll go with collection 2)
* `landsat-c2l1` Landsat Collection 2 Level-1 Product <-- don't think we want this because we want surface reflectance


Run this once to apply the aws session to the rasterio environment

In [None]:
def test_credentials(
    aws_session,
    canary_file=(
        's3://usgs-landsat/collection02/level-2/standard/'
        'tm/2003/044/029/LT05_L2SP_044029_20030827_20200904_02_T1/'
        'LT05_L2SP_044029_20030827_20200904_02_T1_SR_B2.TIF'
    ),
):
    '''
    Check that the given AWS session can actually read Landsat data.

    Opens `canary_file` inside a rasterio environment built from
    `aws_session` and reads its first band. The default file is the
    "canary in the coal mine": it is known to exist, so if this call
    raises, the problem is with credentials/configuration rather than
    with the file. Nothing is returned; success means no exception.

    Anecdotally, instantiating the environment here may also help
    "switch on" the credentials for later reads.
    '''
    with rio.Env(aws_session):
        with rio.open(canary_file) as src:
            # Read real pixel data, not just the header, so the check
            # fails if the data itself cannot be fetched. The values are
            # not needed, so they are not kept.
            src.read(1)

In [None]:
def fix_link(url):
    '''
    Rewrite a landsatlook HTTPS data link into the equivalent s3:// URL
    in the usgs-landsat bucket. Strings without the HTTPS prefix are
    returned unchanged.
    '''
    https_prefix = 'https://landsatlook.usgs.gov/data'
    s3_prefix = 's3://usgs-landsat'
    return url.replace(https_prefix, s3_prefix)

There are different kinds of QA/QC bands contained in L2SP:
* SR_CLOUD_QA - I think we want this one because anything less than 2 is either just dark dense vegetation or no flags. everything above is stuff like water, snow, cloud (different levels of obscurity). This is the result of the fmask algorithm from Zhu et al.
* QA_PIXEL - this gets a little more specific and goes into different kinds of clouds. Super interesting but I don't think we want to use it.

Pull in the SR_CLOUD_QA band and use it as a mask - see Table 5-3 in https://prd-wret.s3.us-west-2.amazonaws.com/assets/palladium/production/atoms/files/LSDS-1370_L4-7_C1-SurfaceReflectance-LEDAPS_ProductGuide-v3.pdf for a description of the cloud integer values, to decide which ones to drop. For now I'll drop anything greater than 1 (0 = no QA concerns and 1 = dark dense vegetation (DDV)).

In [None]:
def cloud_qa(item):
    '''
    Open the SR_CLOUD_QA band of a scene as a mask array.

    `item` is either a string (used as-is as the path/URL of the
    SR_CLOUD_QA GeoTIFF) or an object exposing `_stac_obj.assets`, in
    which case the 'SR_CLOUD_QA.TIF' asset href is converted to an s3://
    URL with `fix_link`.

    Returns the raster with length-1 dimensions squeezed out and the
    `band` coordinate dropped.
    '''
    if isinstance(item, str):
        qa_path = item
    else:
        qa_path = fix_link(item._stac_obj.assets['SR_CLOUD_QA.TIF']['href'])
    # xr.open_rasterio was deprecated and later removed from xarray;
    # rioxarray.open_rasterio is the replacement and is what grab_ds
    # already uses for the reflectance bands.
    cog_mask = rioxarray.open_rasterio(qa_path).squeeze().drop_vars('band')
    return cog_mask

First we make the query using sat-search to find every file in the STAC catalog that we want. We'll store that list of files. We'll do this first for a single tile (in this first example just covering Washington State) but then we'll loop through in 1-degree by 1-degree tiles. 

Due to memory constraints we'll average repeated captures of the same scene. Then we'll average all of those averaged scenes together to create the full mesh. As of now we're just doing a straight average but ideally we would carry the weights of the number of repeats of each scene and do a weighted average when quilting the scenes together.


In [None]:
def grab_ds(item, bands_of_interest, cog_mask):
    '''
    Open the requested bands of one scene as a single masked dataset.

    `item` is either a string prefix, to which '_{band}.TIF' is appended
    for every band, or an object exposing `_stac_obj.assets`, whose
    '{band}.TIF' hrefs are converted to s3:// URLs with `fix_link`.
    `bands_of_interest` is the list of band names; it also becomes the
    `band` coordinate of the result. `cog_mask` is the cloud QA array
    for the same scene (see `cloud_qa`).

    Returns a dataset with a `reflectance` variable in which zeros
    (the fill value) and pixels whose QA value is 2 or more are
    replaced by NaN.
    '''
    if isinstance(item, str):
        url_list = [item+'_{}.TIF'.format(band) for band in bands_of_interest]
    else:
        url_list = [fix_link(item._stac_obj.assets['{}.TIF'.format(band)]['href'])
                    for band in bands_of_interest]
    # open every band lazily, in 1024x1024 dask chunks
    da_list = [rioxarray.open_rasterio(url, chunks={'x': 1024, 'y': 1024})
               for url in url_list]

    # combine into one dataset
    # NOTE(review): each single-band GeoTIFF presumably carries the band
    # label 1, which is why the variable produced by the split is named 1
    # before being renamed -- confirm the result keeps every band.
    ds = xr.concat(da_list, dim='band').to_dataset(dim='band').rename({1: 'reflectance'})
    ds = ds.assign_coords({'band': bands_of_interest})
    # fill value is 0; let's switch it to nan
    ds = ds.where(ds != 0)
    # keep only pixels whose cloud QA value is below 2
    ds = ds.where(cog_mask<2)

    return ds

In [None]:
def average_stack_of_scenes(ds_list):
    '''
    Collapse a list of scene datasets into their mean.

    The scenes are concatenated along a new `scene` dimension which is
    then averaged away. This gives the same kind of result whether the
    scenes overlap perfectly or are offset from one another, but offset
    scenes require a merge, so the whole pre-collapse datacube gets
    instantiated and may blow up your kernel.
    '''
    stacked = xr.concat(ds_list, dim='scene')
    scene_mean = stacked.mean(dim='scene')
    return scene_mean

In [None]:
def write_out(ds, mapper, aws_session):
    '''
    Write `ds` to the zarr store behind `mapper` using mode 'w', with the
    `reflectance` variable compressed by a default Blosc codec.

    `aws_session` is accepted but not used by this function.
    '''
    blosc_encoding = {'reflectance': {'compressor': numcodecs.Blosc()}}
    ds.to_zarr(store=mapper, mode='w', encoding=blosc_encoding)

In [None]:
@dask.delayed
def scene_seasonal_average(row, path, year, bucket, #aws_session,
                           bands_of_interest='all', season='JJA'):
    '''
    Given location/time specifications, grab all valid scenes, mask each
    according to its time-specific cloud QA, average across all masked
    scenes and write the result to
    `{bucket}{row}/{path}/{year}/{season}_reflectance.zarr`.

    Parameters
    ----------
    row, path : int
        Tile identifiers, used in that order both for the listing prefix
        `.../tm/{year}/{row:03d}/{path:03d}/` and for the output store.
        NOTE(review): the canary scene in `test_credentials` lives under
        `.../2003/044/029/` with scene id field `044029` (PPPRRR), i.e.
        the bucket is laid out year / WRS path / WRS row. `row` therefore
        has to hold the WRS *path* and `path` the WRS *row*; confirm the
        callers pass them that way.
    year : int
        Acquisition year of the scenes to average.
    bucket : str
        Prefix of the output location; concatenated directly with `row`.
    bands_of_interest : list of str or 'all'
        Band names to load; 'all' selects the six SR bands B1-B5 and B7.
    season : str
        Key into the season/month table. Only 'JJA' is defined; any
        other value raises KeyError.

    Returns
    -------
    str
        URL of the zarr store that was written.

    Raises
    ------
    ValueError
        If no scene acquired in the requested season is found.
    '''
    aws_session = AWSSession(boto3.Session(), requester_pays=True)

    with dask.config.set(scheduler='threads'): # this? **** #threads #single-threaded
        with rio.Env(aws_session):
            print('testing credentials')
            test_credentials(aws_session)
            print('it works!')

            # where the final seasonal average is saved; kept under its
            # own name so the per-scene URLs below cannot overwrite it
            store_url = f'{bucket}{row}/{path}/{year}/{season}_reflectance.zarr'
            mapper = fsspec.get_mapper(store_url)

            landsat_bucket = 's3://usgs-landsat/collection02/level-2/standard/tm/{}/{:03d}/{:03d}/'
            month_keys = {'JJA': ['06', '07', '08']}

            if bands_of_interest == 'all':
                bands_of_interest = ['SR_B1', 'SR_B2', 'SR_B3',
                                     'SR_B4', 'SR_B5', 'SR_B7']

            season_datestamps = {'{}{}'.format(year, month)
                                 for month in month_keys[season]}
            ds_list = []
            for scene_store in fs.ls(landsat_bucket.format(year, row, path)):
                scene_dir = scene_store.rstrip('/')
                scene_id = scene_dir.rsplit('/', 1)[-1]
                # Scene ids follow LXSS_LLLL_PPPRRR_YYYYMMDD_yyyymmdd_CC_TX,
                # so field 3 is the acquisition date. Compare only that
                # field: a substring search over the whole path can also
                # hit the processing date (e.g. '200807' is contained in
                # a processing date of '20200807') and can add the same
                # scene more than once.
                fields = scene_id.split('_')
                if len(fields) < 4 or fields[3][:6] not in season_datestamps:
                    continue
                scene_url = 's3://{}/{}'.format(scene_dir, scene_id)
                cog_mask = cloud_qa(scene_url+'_SR_CLOUD_QA.TIF')
                ds_list.append(grab_ds(scene_url, bands_of_interest, cog_mask))

            if not ds_list:
                # xr.concat would fail on an empty list with a message
                # that says nothing about which tile/season was empty
                raise ValueError(
                    'no {} scenes found for {}/{}/{}'.format(season, year, row, path))

            seasonal_average = average_stack_of_scenes(ds_list)
            write_out(seasonal_average.chunk({'band': 6, 'x': 1024, 'y': 1024}), mapper, aws_session)
            return store_url

Then we take the list of files for a given year to average across growing season for each of the tiles and write it out to a mapper with those specifications.

In [None]:
# Turn on dask's "array.slicing.split_large_chunks" option for this session.
# NOTE(review): presumably meant to keep indexing operations from producing
# oversized chunks -- confirm it is still needed.
dask.config.set({"array.slicing.split_large_chunks": True})

In [None]:
# WRS-2 descending footprints read straight from the zipped shapefile.
# The PATH and ROW columns of this table are what the cells below use to
# pick the tiles to process.
gdf = gpd.read_file('https://prd-wret.s3-us-west-2.amazonaws.com/assets/'
                   'palladium/production/s3fs-public/atoms/files/'
                   'WRS2_descending_0.zip')

In [None]:
# Footprints intersecting the box x in [-125, -115], y in [45, 49].
# Despite the variable name, each entry is [PATH, ROW] -- in that order.
washington_row_paths = gdf.cx[-125:-115,45:49][['PATH', 'ROW']].values

In [None]:
PANGEO_SCRATCH = os.environ['PANGEO_SCRATCH_PREFIX'] + '/orianac/'
tasks = []
rerun = True
if rerun:
    with rio.Env(aws_session): # delete
        # one task per (year, tile, season); the years span the GLAS record
        for year in np.arange(2003, 2009):
            # washington_row_paths holds [PATH, ROW] pairs; they are handed
            # to scene_seasonal_average positionally in that same order
            for wrs_path, wrs_row in washington_row_paths:
                for season in ['JJA']:
                    delayed_average = scene_seasonal_average(
                        wrs_path, wrs_row, year, PANGEO_SCRATCH, #aws_session,
                        bands_of_interest='all',
                        season=season)
                    tasks.append(client.compute(delayed_average, retries=1))

In [None]:
# Print the URL of every expected seasonal-average store that does not
# exist yet.
for year in np.arange(2003, 2009):
    # washington_row_paths holds [PATH, ROW] pairs
    for wrs_path, wrs_row in washington_row_paths:
        for season in ['JJA']:
            expected_store = f'{PANGEO_SCRATCH}{wrs_path}/{wrs_row}/{year}/{season}_reflectance.zarr'
            if fs.exists(expected_store):
                continue
            print(expected_store)
### Now let's link with GLAS. We'll loop through every 10x10 degree GLAS tile and repeat this process

In [None]:
# Upper-left corner of the GLAS tile (degrees north, degrees west).
ul_lat, ul_lon = 50, 120
biomass_store = PANGEO_SCRATCH + 'biomass/{}N_{}W.zarr'.format(ul_lat, ul_lon)
biomass = xr.open_zarr(biomass_store, consolidated=True).load()
# Flatten (record_index, shot_number) into one index and move to a table,
# discarding rows in which every value is missing.
biomass = (
    biomass
    .stack(unique_index=("record_index", "shot_number"))
    .to_dataframe()
    .dropna(how='all')
)

In [None]:
# Turn every shot into a point geometry in EPSG:4326.
shot_points = gpd.points_from_xy(biomass.lon, biomass.lat)
biomass_gdf = gpd.GeoDataFrame(biomass, geometry=shot_points).set_crs("EPSG:4326")
# Footprints intersecting the 10x10 degree tile, left-joined onto the shots
# so that every shot is kept.
tile_footprints = gdf.cx[-ul_lon:-ul_lon+10,ul_lat-10:ul_lat]
linked_gdf = gpd.sjoin(biomass_gdf, tile_footprints, how='left')

### Now you have the row and path for each shot. Let's now get the url to the appropriate COG.

In [None]:
# UTM zone number of every shot. utm.from_latlon needs both a latitude and
# a longitude and returns an (easting, northing, zone number, zone letter)
# tuple, so ask for the zone number directly, one shot at a time.
# NOTE(review): assumes lat/lon are never missing in linked_gdf -- confirm.
linked_gdf['utm_zone'] = [
    utm.latlon_to_zone_number(lat, lon)
    for lat, lon in zip(linked_gdf['lat'], linked_gdf['lon'])
]

In [None]:
# Small sample (the first rows) of the linked table to experiment on.
test = linked_gdf.head()

In [None]:
# Row-wise apply template.
# NOTE(review): segmentMatch is not defined anywhere in the visible file and
# no cell shown creates 'TimeCol' / 'ResponseCol' columns -- this looks like
# a placeholder that still has to be adapted.
test.apply(lambda x: segmentMatch(x['TimeCol'], x['ResponseCol']), axis=1)

In [None]:
# grab zone (ensuring it's the right one for the row/path) mapping of the lat/lon in 6 degree increments

In [None]:
# Convert the first few shots from lat/lon to UTM and display the result.
utm.from_latlon(biomass_gdf.head()['lat'].values, biomass_gdf.head()['lon'].values)

In [None]:
# convert time to date  which will go into the lookup (really we just need year)

In [None]:
# fix x/y w pyproj 

In [None]:
# 

In [None]:
# NOTE(review): an earlier cell rebinds `biomass` to the result of
# .to_dataframe(); if the cells run in order, this .load() is called on a
# DataFrame. Presumably written against the xarray object -- confirm.
one_shot = biomass.load()#.lat.values#, shot_number=15).load()

In [None]:
# Display the shots with (record_index, shot_number) combined into one index.
# NOTE(review): the biomass cell above already performs this same stack
# before converting to a table -- confirm this is still needed.
one_shot.stack(unique_index=("record_index", "shot_number"))

In [None]:
# project lat lon 

In [None]:
from pyproj import Proj, transform

In [None]:
# Despite its name, this is a UTM zone 10 projection on the WGS84 ellipsoid,
# not a geographic WGS84 lat/lon CRS.
wgs84 = Proj(proj="utm", zone=10, ellps="WGS84")

In [None]:
# Project the shot coordinates, passed as (lon, lat), with the UTM zone 10
# projection defined above.
proj_x, proj_y = wgs84(one_shot.lon.values,one_shot.lat.values)

In [None]:
# Bare expression: displays the projected coordinates as the cell output.
proj_x, proj_y 