# Pre-process module

Includes:
- A class that returns a list of S3 URLs for a dataset based on short name, temporal range, and bounding box.
- Function to return temporary credentials for an S3 Bucket.
- Function to clean and concatenate the data.
- Function to write the data.
- Function to run all operations and serve as container entrypoint.

In [None]:
# Standard imports
from http.cookiejar import CookieJar
import netrc
from socket import gethostname, gethostbyname
from urllib import request

# Third-party imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pylab as plt
import requests
import s3fs
from tqdm import tqdm
import xarray as xr

In [None]:
class S3List:
    """Query PO.DAAC's CMR API for the direct-access (S3) URLs of granules.

    Typical use is a single call to ``login_and_run_query``, which logs in,
    obtains a CMR token, runs the granule search and deletes the token again.
    """

    CMR = "cmr.earthdata.nasa.gov"
    URS = "urs.earthdata.nasa.gov"

    def __init__(self):
        # CMR token id; populated by get_token().
        self._token = None

    def _credentials(self):
        """Return ``(username, password)`` for the URS host from .netrc.

        Raises
        ------
        Exception
            If the .netrc file is missing or has no entry for the URS host.
        """

        try:
            # authenticators() returns None for an unknown host, so the
            # unpacking raises TypeError in that case.
            username, _, password = netrc.netrc().authenticators(self.URS)
        except (FileNotFoundError, TypeError) as error:
            raise Exception(
                "ERROR: There not .netrc file or endpoint indicated in .netrc file."
            ) from error
        return username, password

    def _token_url(self):
        """Return the URL of the CMR token service."""

        return f"https://{self.CMR}/legacy-services/rest/tokens"

    @staticmethod
    def _local_ip():
        """Return this host's IP address, or 127.0.0.1 if it can't be resolved."""

        try:
            return gethostbyname(gethostname())
        except OSError:
            # The hostname is frequently unresolvable inside containers; the
            # address is only sent along with the token request, so fall back
            # instead of aborting the whole run.
            return "127.0.0.1"

    def login(self):
        """Log into Earthdata and set up request library to track cookies.

        Installs a global ``urllib.request`` opener that performs HTTP basic
        authentication against the URS host and keeps cookies in memory.

        Raises
        ------
        Exception
            If it can't authenticate with the .netrc file.
        """

        username, password = self._credentials()

        # Create Earthdata authentication request
        manager = request.HTTPPasswordMgrWithDefaultRealm()
        manager.add_password(None, self.URS, username, password)
        auth = request.HTTPBasicAuthHandler(manager)

        # Set up the storage of cookies
        processor = request.HTTPCookieProcessor(CookieJar())

        # Define an opener to handle fetching auth request
        request.install_opener(request.build_opener(auth, processor))

    def get_token(self, client_id, ip_address):
        """Get CMR authentication token for searching records.

        The token id is stored on the instance for later use by
        ``run_query`` and ``delete_token``.

        Parameters
        ----------
        client_id: str
            client identifier to obtain token
        ip_address: str
            client's IP address

        Raises
        ------
        Exception
            If credentials cannot be read from the .netrc file.
        requests.HTTPError
            If the token service answers with an error status.
        """

        from xml.sax.saxutils import escape

        username, password = self._credentials()

        # Escape every value: a password containing '<' or '&' would
        # otherwise produce malformed XML.
        fields = {
            "username": username,
            "password": password,
            "client_id": client_id,
            "user_ip_address": ip_address,
        }
        token_xml = "<token>{}</token>".format(
            "".join(f"<{tag}>{escape(str(value))}</{tag}>"
                    for tag, value in fields.items())
        )

        # Post a token request and keep the token id from the response
        headers = {"Content-Type": "application/xml", "Accept": "application/json"}
        response = requests.post(url=self._token_url(), data=token_xml, headers=headers)
        response.raise_for_status()
        self._token = response.json()["token"]["id"]

    def delete_token(self):
        """Delete CMR authentication token.

        Returns
        -------
        int or None
            HTTP status code of the delete request, or None when no token
            has been obtained (nothing is sent in that case).

        Raises
        ------
        Exception
            If the delete request could not be made.
        """

        if self._token is None:
            return None

        headers = {"Content-Type": "application/xml", "Accept": "application/json"}
        try:
            res = requests.request("DELETE", f"{self._token_url()}/{self._token}",
                                   headers=headers)
        except Exception as e:
            raise Exception(f"Failed to delete token: {e}.") from e
        return res.status_code

    def run_query(self, shortname, provider, temporal_range, bbox):
        """Run query on collection referenced by shortname from provider.

        Parameters
        ----------
        shortname: str
            collection short name
        provider: str
            CMR provider identifier
        temporal_range: str
            value passed as the CMR ``temporal`` search parameter
        bbox: str
            value passed as the CMR ``bounding_box`` search parameter

        Returns
        -------
        list of str
            URLs of every related URL whose type is
            "GET DATA VIA DIRECT ACCESS".

        Raises
        ------
        requests.HTTPError
            If CMR answers with an error status.
        """

        url = f"https://{self.CMR}/search/granules.umm_json"
        params = {
            "provider": provider,
            "ShortName": shortname,
            "token": self._token,
            "scroll": "true",
            "page_size": 2000,
            "sort_key": "start_date",
            "temporal": temporal_range,
            "bounding_box": bbox,
        }
        response = requests.get(url=url, params=params)
        response.raise_for_status()

        # NOTE(review): only this single response (at most `page_size`
        # granules) is read; larger result sets look like they would be
        # truncated - confirm against the CMR paging documentation.
        granules = response.json()["items"]
        return [
            related["URL"]
            for granule in granules
            for related in granule["umm"].get("RelatedUrls", [])
            if related.get("Type") == "GET DATA VIA DIRECT ACCESS"
        ]

    def login_and_run_query(self, short_name, provider, temporal_range, bbox):
        """Log into CMR and run query to retrieve a list of S3 URLs.

        Parameters
        ----------
        short_name: str
            collection short name
        provider: str
            CMR provider identifier
        temporal_range: str
            temporal search constraint
        bbox: str
            bounding box search constraint

        Returns
        -------
        list of str
            Sorted list of S3 URLs.
        """

        # Login and retrieve token
        self.login()
        self.get_token("podaac_cmr_client", self._local_ip())

        try:
            # Run query
            s3_urls = self.run_query(short_name, provider, temporal_range, bbox)
        finally:
            # Clean up and delete token, even when the query failed, so the
            # token is not left behind.
            self.delete_token()

        s3_urls.sort()
        return s3_urls

In [None]:
def init_S3FileSystem(provider):
    """Create an S3 filesystem using temporary credentials from a DAAC.

    Requests temporary AWS S3 credentials from the provider's
    ``s3credentials`` endpoint (for PO.DAAC:
    https://archive.podaac.earthdata.nasa.gov/s3credentials). No
    authentication is passed explicitly; the Earthdata Login credentials are
    expected to be picked up from the .netrc file.

    Parameters
    ----------
    provider: str
        Provider name, matched case-insensitively against the known
        endpoints ('pocloud', 'lpdaac', 'ornldaac', 'gesdisc').

    Returns
    -------
    s3: an AWS S3 filesystem

    Raises
    ------
    KeyError
        If the provider has no known credentials endpoint.
    requests.HTTPError
        If the credentials endpoint answers with an error status.
    """

    s3_cred_endpoint = {
        'pocloud': 'https://archive.podaac.earthdata.nasa.gov/s3credentials',
        'lpdaac': 'https://data.lpdaac.earthdatacloud.nasa.gov/s3credentials',
        'ornldaac': 'https://data.ornldaac.earthdata.nasa.gov/s3credentials',
        'gesdisc': 'https://data.gesdisc.earthdata.nasa.gov/s3credentials',
    }

    # Resolve the endpoint before any network activity so a bad provider
    # name fails immediately with a helpful message.
    try:
        endpoint = s3_cred_endpoint[provider.lower()]
    except KeyError:
        raise KeyError(
            f"Unknown provider {provider!r}; expected one of {sorted(s3_cred_endpoint)}"
        ) from None

    response = requests.get(endpoint)
    # Surface authentication/server errors here instead of as a confusing
    # JSON or missing-key error further down.
    response.raise_for_status()
    creds = response.json()

    return s3fs.S3FileSystem(anon=False,
                             key=creds['accessKeyId'],
                             secret=creds['secretAccessKey'],
                             token=creds['sessionToken'])

In [None]:
def clean_l2p_da(data, qc_threshold, bbox):
    """Filter one dataset by quality level and bounding box.

    Parameters
    ----------
    data: xarray.Dataset
        Dataset to clean. The code relies on the variables `quality_level`,
        `lat`, `lon`, `time` and `sst_dtime`, and on the dimensions `ni`
        and `nj`.
    qc_threshold: int
        Points with `quality_level` below this value are discarded.
    bbox: str
        Comma-separated "West,South,East,North"; values may be fractional.

    Returns
    -------
    xarray.Dataset
        Data flattened onto a single `pt` dimension, with a per-point
        `time` variable computed as `time + sst_dtime`.
    """
    # Parse as floats: int() rejects fractional coordinates such as "-45.5".
    west, south, east, north = (float(edge) for edge in bbox.split(','))

    # sst_dtime is kept here because it is needed for the per-point time.
    da = data.drop_vars(['satellite_zenith_angle',
                         'sea_ice_fraction',
                         'sst_gradient_magnitude',
                         'sst_front_position'])

    # Mask low-quality points, then flatten the grid and drop masked points.
    data_threshold = da.where(da.quality_level >= qc_threshold)
    all_data = data_threshold.stack(pt=('ni', 'nj')).dropna(dim='pt')

    # Keep only the points inside the bounding box.
    inside = ((all_data.lat >= south) & (all_data.lat <= north)
              & (all_data.lon <= east) & (all_data.lon >= west))
    loc_data = all_data.where(inside)
    neat_data = loc_data.dropna(dim='pt').reset_index('pt').drop_vars(['ni', 'nj'])

    # Replace the `time` variable by a per-point time.
    time = neat_data.time.values
    to_return = neat_data.drop_vars('time').squeeze()
    to_return['time'] = xr.DataArray(data=(time + to_return['sst_dtime'].values), dims=['pt'])
    return to_return.drop_vars(['sst_dtime'])

In [None]:
def clean_and_concat(s3sys, s3_url_list, qc_threshold, bbox):
    """Clean every dataset in a list of S3 URLs and concatenate the results.

    Parameters
    ----------
    s3sys:
        Initialized S3FileSystem used to open each URL.
    s3_url_list:
        S3 URLs that contain data to clean and concat.
    qc_threshold:
        Quality control threshold passed on to `clean_l2p_da`.
    bbox:
        Bounding box (w,s,e,n) passed on to `clean_l2p_da`.

    Returns
    -------
    xarray.Dataset
        Cleaned data concatenated along `pt`, with `time` set as a
        coordinate, sorted by `lon` and then by `lat`.

    Raises
    ------
    ValueError
        If `s3_url_list` is empty.
    """

    urls = list(s3_url_list)
    if not urls:
        raise ValueError("No S3 URLs were supplied; there is nothing to clean and concatenate.")

    # Pass the caller's threshold through (it used to be hard-coded to 4).
    cleaned_data = [clean_l2p_da(xr.open_dataset(s3sys.open(url)),
                                 qc_threshold=qc_threshold,
                                 bbox=bbox) for url in tqdm(urls)]

    concat_data = xr.concat(cleaned_data, dim='pt').set_coords('time').sortby('lon').sortby('lat')

    return concat_data

In [None]:
def write_data(data, nc_file, zarr_file):
    """Persist a dataset as both a NetCDF file and a Zarr store.

    data: xarray.Dataset to write.
    nc_file: destination of the NetCDF4 file (written with h5netcdf).
    zarr_file: destination of the Zarr store.

    Both targets are opened in 'w' mode and written eagerly.
    """

    netcdf_options = {
        "path": nc_file,
        "mode": 'w',
        "format": "NETCDF4",
        "engine": "h5netcdf",
        "compute": True,
    }
    zarr_options = {"store": zarr_file, "mode": 'w', "compute": True}

    data.to_netcdf(**netcdf_options)
    data.to_zarr(**zarr_options)

In [None]:
def plot_data(data):
    """Show a lon/lat scatter plot coloured by sea surface temperature.

    Intended as a quick visual check of the data.
    """

    scatter_options = {
        "x": 'lon',
        "y": 'lat',
        "hue": 'sea_surface_temperature',
        "cmap": 'Spectral_r',
        "s": 0.01,
    }
    data.plot.scatter(**scatter_options)
    plt.show()

In [None]:
def run_preprocessing(short_name, provider, temporal_range, bbox, qc_threshold, out_dir):
    """Run the full pre-processing chain for the dataset named short_name.

    Steps: query the S3 URLs, open an S3 filesystem for the provider, clean
    and concatenate the data, then write `<short_name>.nc` and
    `<short_name>.zarr` inside out_dir (which must provide `joinpath`).
    """

    # Locate the data and get access to it
    urls = S3List().login_and_run_query(short_name, provider, temporal_range, bbox)
    filesystem = init_S3FileSystem(provider)

    # Clean and concatenate
    combined = clean_and_concat(filesystem, urls, qc_threshold, bbox)

    # Write both output formats
    write_data(
        combined,
        out_dir.joinpath(f"{short_name}.nc"),
        out_dir.joinpath(f"{short_name}.zarr"),
    )