In [None]:
import xarray as xr
import zarr
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import glob
from tqdm import tqdm

## II. Aggregating MERRA-2 data
### Written by Claire Wilson (cvwilson@andrew.cmu.edu)

Once you have all the daily files downloaded from MERRA-2, you'll need to aggregate them. This notebook will walk you through concatenating files into a larger zarr. 

If you saved a large region of interest, you will then extract individual grid cells for model simulations using `3_get_tile.py`. In that case, it is convenient to keep the larger zarr on an external hard drive and run `3_get_tile.py` to pull out a smaller subset of data to store on your local machine.

If you are only interested in a single grid cell to begin with, name your region of interest "{lat}_{lon}", where lat and lon are the center point of the MERRA-2 grid cell you downloaded. For example:

`roi = '62.5_-145.625'`

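The model expects .nc files. If you've created .zarrs containing a single grid cell, save each one as a netCDF file (a zarr is a directory store, so renaming it is not enough; e.g. `xr.open_zarr(fn).to_netcdf(fn.replace('.zarr', '.nc'))`) and then delete the .zarrs. The files are then ready for the model to digest.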

### 1. Specify the base filepath and name for the region of interest.

The zarr_store filepath points to where you want to store the zarrs. The region of interest is a descriptive name for the region you downloaded which will be included in the file name of the zarr store.

In [None]:
# File naming: where the zarr stores are written and the descriptive region
# name that is embedded in every zarr filename
base_fp = '../MERRA-2/'
fp_zarr_store = f'{base_fp}zarr_store/'
roi = 'reg01'

# Bounding box copied from notebook #1. Every daily file is subset to these
# bounds so that all appended slices share exactly the same lat/lon grid.
lat_min, lat_max = 50, 72
lon_min, lon_max = -180, -133.25

# =============================================================================================================================
# Do not edit these: MERRA-2 variables extracted from each dataset collection
dataset_variables = dict(
    slv=['T2M', 'U2M', 'V2M', 'QV2M', 'PS'],
    adg=['BCDP002', 'BCWT002', 'OCDP002', 'OCWT002', 'DUDP003', 'DUWT003'],
    rad=['SWGDN', 'LWGAB', 'CLDTOT'],
    flx=['PRECTOTCORR'],
)

### 2. Aggregate files by variable.

The following code loops through all the .nc4 files in the dataset folder (adg, slv, rad, or flx) where you saved the daily files from MERRA-2. Make sure this folder contains no other .nc4 files.

All you should need to do with this code is specify the dataset filepath and let it chug away. Expect this to take anywhere from 1 to 10 minutes per dataset, depending on how much data you are concatenating. In my experience, concatenating 20 years of data for a single variable takes about 4 minutes. 

In [None]:
# Shouldn't need to make any edits to this cell
def _zarr_path(var, roi):
    """Return the per-variable zarr path ``{fp_zarr_store}/{var}/{var}_{roi}.zarr``."""
    return os.path.join(fp_zarr_store, var, f'{var}_{roi}.zarr')


def _existing_times(fn_zarr_var):
    """Return the set of timestamps already in a zarr, or an empty set if it does not exist."""
    if not os.path.exists(fn_zarr_var):
        return set()
    # close the store again: process_files renames it later when re-sorting
    with xr.open_zarr(fn_zarr_var) as ds_existing:
        return set(pd.to_datetime(ds_existing.time.values))


def _file_start_time(fn):
    """Return the first timestamp of a daily file, parsed from its name without opening it.

    Expects file names like ``...Nx.YYYYMMDD.nc4``. The first hourly value of
    the day is assumed to be stamped at 00:30, hence the +30 minutes.

    Raises
    ======
    ValueError
        If the file name does not contain 'Nx.' followed by a YYYYMMDD date.
    """
    name = os.path.basename(fn)
    if 'Nx.' not in name:
        raise ValueError(
            f'File name {name!r} is in an unexpected format: manually specify '
            'how to pull timestamp from the filename'
        )
    date_str = name.rpartition('Nx.')[2][:8]
    return pd.to_datetime(date_str, format='%Y%m%d') + pd.Timedelta(minutes=30)


def _resort_zarr(fn_zarr_var):
    """Rewrite a zarr store de-duplicated, sorted by time, rechunked and consolidated.

    The existing store is first renamed to ``*_unsorted_{i}.zarr`` (the first
    unused i) and kept on disk as a backup; it is not deleted.
    """
    if not os.path.exists(fn_zarr_var):
        # nothing was ever written for this variable (e.g. the folder was empty)
        print(f'No zarr found at {fn_zarr_var}; nothing to sort')
        return

    # pick a backup name that does not exist yet, in case this is run more than once
    stem, ext = os.path.splitext(fn_zarr_var)
    ii = 0
    while os.path.exists(f'{stem}_unsorted_{ii}{ext}'):
        ii += 1
    fn_zarr_unsorted = f'{stem}_unsorted_{ii}{ext}'
    os.rename(fn_zarr_var, fn_zarr_unsorted)

    with xr.open_zarr(fn_zarr_unsorted) as ds_unsorted:
        # drop repeated timestamps (keeping the first occurrence), then sort by time
        keep = ~ds_unsorted.time.to_index().duplicated()
        ds_sorted = ds_unsorted.isel(time=keep).sortby('time')

        # uniform chunks; clear the encoding inherited from the source store so
        # its old chunk sizes do not conflict with the new chunking on write
        ds_sorted = ds_sorted.chunk({'time': 8760, 'lat': 20, 'lon': 16})
        for v in ds_sorted.variables:
            ds_sorted[v].encoding.clear()

        # consolidated=True writes consolidated metadata, so no separate
        # zarr.consolidate_metadata call is needed afterwards
        ds_sorted.to_zarr(fn_zarr_var, mode='w', consolidated=True)


def process_files(dataset, data_fp, roi=roi, filetype='.nc4'):
    """
    This function loops through a user-specified folder which
    contains daily files downloaded from MERRA-2 in Notebook #1.
    Each variable of the dataset is subset to the module-level
    bounding box (lat_min/lat_max/lon_min/lon_max) and appended via
    time onto ``{fp_zarr_store}/{var}/{var}_{roi}.zarr``; the zarr
    is created if it does not exist yet. Days already present in a
    zarr are skipped, so the function can be re-run after more files
    are downloaded. Finally each zarr is de-duplicated, sorted by
    time, rechunked and rewritten (the previous store is kept as
    ``*_unsorted_{i}.zarr``).

    Parameters
    ==========
    dataset : str
        Name of dataset from set [slv, adg, rad, flx]
    data_fp : str
        Filepath to folder containing files to append
    roi : str
        Descriptive name for region of interest for output zarr
    filetype : str
        File extension of the files to append (should be .nc4 or .nc)

    Raises
    ======
    ValueError
        If a file name does not contain 'Nx.' followed by a YYYYMMDD date.
    """
    # get list of daily files in the folder
    daily_files = sorted(glob.glob(data_fp + '*' + filetype))
    variables = dataset_variables[dataset]

    # timestamps already present in each variable's zarr
    times_var = {var: _existing_times(_zarr_path(var, roi)) for var in variables}

    # number of files that contributed at least one new variable
    n_added = 0
    for f in tqdm(daily_files, desc=f'Processing {dataset}'):
        time_file = _file_start_time(f)

        # only the variables that do not have this day yet
        new_vars = [var for var in variables if time_file not in times_var[var]]
        if not new_vars:
            continue  # whole day already stored; no need to open the file

        with xr.open_dataset(f) as ds:
            for var in new_vars:
                fn_zarr_var = _zarr_path(var, roi)
                da = ds[var].sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))
                ds_var = da.to_dataset(name=var)
                if os.path.exists(fn_zarr_var):
                    # append the daily dataset onto the existing zarr
                    ds_var.to_zarr(fn_zarr_var, mode='a', append_dim='time')
                else:
                    # first day for this variable: create the zarr
                    ds_var.to_zarr(fn_zarr_var, mode='w', consolidated=True)
                # record the day (also guards against duplicate files in the folder)
                times_var[var].add(time_file)
        n_added += 1

    print(f'Successfully concatenated {n_added} files to {dataset}_{roi}')

    # rewrite each variable's zarr de-duplicated and sorted by time
    for var in variables:
        _resort_zarr(_zarr_path(var, roi))

    print(f'Resorted and reconsolidated all vars in {dataset}')
    return

In [None]:
# Process every dataset listed here (any key of dataset_variables: 'slv', 'adg', 'rad', 'flx')
for dataset in ('slv',):
    # the daily files for this dataset live in <base_fp>/<dataset>/
    data_fp = f'{base_fp}{dataset}/'

    # Append all files in the folder onto the per-variable zarrs.
    # If files are not in .nc4 format, pass filetype='.XXXX' to process_files.
    process_files(dataset, data_fp)

### 3. (Recommended) Visualize all the data to ensure files are saved properly and do not contain gaps.

The following code will help you check that this notebook worked correctly by producing the following:

1. Map of the region of interest containing missing data percentage for each grid cell. Purple = no missing data = good.
2. Timeseries aggregated over all lat/lon cells (summed for the OC, BC, DU and PREC variables, averaged for the others). It should not contain any gaps.
3. Printout of the dataset itself for you to check the coordinates.

In [None]:
# Loop through all datasets and all variables and make two diagnostic figures
# for each: a map of the fraction of missing timesteps per grid cell, and a
# timeseries aggregated over lat/lon. Figures are saved to <base_fp>/Figs/.
fig_dir = os.path.join(base_fp, 'Figs')
os.makedirs(fig_dir, exist_ok=True)

for dataset in dataset_variables:
    for var in dataset_variables[dataset]:
        # open the zarr written by process_files; skip variables not aggregated yet
        fn_zarr_var = os.path.join(fp_zarr_store, var, f'{var}_{roi}.zarr')
        if not os.path.exists(fn_zarr_var):
            print(f'{fn_zarr_var} not found; skipping {var}')
            continue
        ds = xr.open_zarr(fn_zarr_var)
        n_time = ds.sizes['time']
        if n_time == 0:
            print(f'{var} has no timesteps; skipping')
            ds.close()
            continue

        # Figure 1: fraction (0-1) of missing timesteps per lat/lon cell
        fig, ax = plt.subplots()
        frac_missing = ds[var].isnull().sum(dim='time') / n_time
        frac_missing.plot(ax=ax, cmap='viridis', vmin=0, vmax=1)
        ax.set_title(f'{var}: fraction of timesteps missing ({roi})')
        fig.savefig(os.path.join(fig_dir, f'{var}_missing_area.png'), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

        # Figure 2: timeseries aggregated over lat/lon. The OC/BC/DU/PREC
        # variables are summed over the region; all others are averaged.
        if any(key in var for key in ('OC', 'PREC', 'BC', 'DU')):
            da = ds[var].sum(dim=['lat', 'lon'])
        else:
            da = ds[var].mean(dim=['lat', 'lon'])
        times = da.time.values

        fig, ax = plt.subplots()
        ax.plot(times, da.values)

        # descriptive title with the covered date range
        start = pd.to_datetime(times[0]).strftime('%d %b %Y')
        end = pd.to_datetime(times[-1]).strftime('%d %b %Y')
        ax.set_ylabel(var)
        ax.set_title(f'{var} timeseries for {roi}\n({start} to {end})')
        fig.savefig(os.path.join(fig_dir, f'{var}_full_timeseries.png'), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

        # print the dataset so the coordinates can be checked
        print(ds)
        ds.close()

## Congratulations, your zarrs are complete!