In [1]:
import pandas as pd
import numpy as np
import geopandas as gpd

import time
import os, sys
from glob import glob
from pathlib import Path
from datetime import datetime, date, timedelta
from hashlib import md5

import xarray as xr
import rasterio as rio

import requests
import argparse

from loguru import logger
import tqdm.auto as tq

from layers import *
from dataset import *

In [2]:
# Run configuration, gathered in one place and exposed as attributes
# (args.date, args.timelag, ...) through an argparse.Namespace.
run_config = {
    # Target date (YYYY-MM-DD) and the number of preceding days used as input.
    'date': '2022-02-10',
    'timelag': 92,
    # Locations of the input data.
    'hrrr_dir': 'data/hrrr',
    'modis_dir': 'data/modis',
    'grid_cells': 'evaluation/grid_cells.geojson',
    'dem_file': 'data/copernicus_dem/COP90_hh.tif',
    'soil_file': 'data/global_soil_regions/so2015v2.tif',
    'hrrr_sample': 'hrrr_sample.grib2',
    'sun_decline': 'sun_decline.csv',
    # Directory holding the per-fold model checkpoints.
    'model_dir': 'weights',
    # NetCDF cache of the assembled feature dataset.
    'dataset_file': 'evaluation/evaluation_dataset.nc',
    # Submission table: read from format_file, written to output_file
    # (the same path here, so the table is updated in place).
    'format_file': 'evaluation/submission_evaluation.csv',
    'output_file': 'evaluation/submission_evaluation.csv',
}
args = argparse.Namespace(**run_config)

In [3]:
# Fingerprint the grid-cell definition file with MD5. The digest is stored on
# `args` so it can be compared with the one recorded in the cached dataset.
with open(args.grid_cells, "rb") as f:
    md5_hash = md5(f.read())
MD5_HASH = md5_hash.hexdigest()
args.md5 = MD5_HASH

# Read the grid-cell geometries themselves.
grid_cells = gpd.read_file(args.grid_cells)

## Check Dataset

In [4]:
def create_ds(date_range, args):
    """Assemble the complete dataset for ``date_range``.

    Loads the HRRR and MODIS layers, merges them, and attaches the sun
    duration series together with the static DEM and soil images. This
    function does not call the download helpers itself.

    Parameters
    ----------
    date_range
        Dates to load; passed unchanged to the HRRR, MODIS and sun loaders.
    args
        Configuration namespace forwarded to every loader.

    Returns
    -------
    The merged xarray dataset with the variables ``sd`` (time, cell_id),
    ``dem`` and ``soil`` (cell_id, xlat, ylon) added.
    """
    logger.info("Loading HRRR")
    hrrr = load_hrrr(date_range, args)

    logger.info("Loading MODIS")
    modis = load_modis(date_range, args)

    # Combine the two time-dependent sources into one dataset.
    merged = xr.merge([hrrr, modis])

    logger.info("Loading Sun Duration data")
    sun_duration, sun_attrs = load_sundecline(date_range, args)

    logger.info("Loading Static data")
    dem_images, soil_images = load_static(args)

    # DEM and soil carry no time dimension; both share the same layout.
    static_dims = ["cell_id", "xlat", "ylon"]
    return merged.assign(
        sd=(['time', 'cell_id'], sun_duration, sun_attrs),
        dem=(static_dims, dem_images),
        soil=(static_dims, soil_images),
    )

def update_dataset(ds, date_range, args):
    """Extend ``ds`` with the dates of ``date_range`` it does not contain yet.

    Dates already present in ``ds.time`` are skipped. For the remaining
    ones the HRRR and MODIS files are downloaded and loaded, the sun
    duration series is attached as ``sd``, and the result is merged into
    ``ds``. The static ``dem``/``soil`` layers are not reloaded here.

    Parameters
    ----------
    ds
        Existing dataset; only ``ds.time.data`` is inspected directly.
    date_range
        Requested dates (must support boolean-mask indexing).
    args
        Configuration namespace forwarded to the download/load helpers.

    Returns
    -------
    ``ds`` itself when nothing is missing, otherwise the merged dataset.
    """
    already_present = np.isin(date_range, ds.time.data)
    missing_dates = date_range[~already_present]
    if not len(missing_dates):
        # Every requested date is already covered.
        return ds

    # Fetch the raw inputs for the missing dates only.
    download_hrrr(missing_dates, args)
    download_modis(missing_dates, args)

    logger.info("Loading HRRR")
    hrrr = load_hrrr(missing_dates, args)

    logger.info("Loading MODIS")
    modis = load_modis(missing_dates, args)

    addition = xr.merge([hrrr, modis])

    logger.info("Loading Sun Duration data")
    sun_duration, sun_attrs = load_sundecline(missing_dates, args)

    addition = addition.assign(
        sd=(['time', 'cell_id'], sun_duration, sun_attrs),
    )

    return xr.merge([ds, addition])

In [5]:
# Reuse the cached dataset when one exists. It is read fully into memory and
# the file handle is released (even if loading fails), so the same path can
# be overwritten at the end of this cell.
ds = None
if os.path.isfile(args.dataset_file):
    with xr.open_dataset(args.dataset_file, engine='netcdf4') as cached_ds:
        ds = cached_ds.load()

# Input window: the `timelag` days that precede (and exclude) the target date.
# Built from end/periods rather than `closed='left'`, because that keyword was
# deprecated in pandas 1.4 and removed in pandas 2.0; the dates are identical.
start_date = datetime.strptime(args.date, '%Y-%m-%d')
date_range = pd.date_range(
    end=start_date - timedelta(days=1), periods=args.timelag, freq='1D')

# Rebuild from scratch when there is no cache, or when the cache was created
# for a different grid-cell file (or carries no md5 stamp at all).
if (ds is None) or (ds.attrs.get('md5') != args.md5):
    download_hrrr(date_range, args)
    download_modis(date_range, args)

    ds = create_ds(date_range, args)
    ds.attrs['md5'] = args.md5

# Check dates: add whatever part of the window the dataset is still missing.
ds = update_dataset(ds, date_range, args)
# update_dataset may rebuild the dataset through xr.merge, and whether dataset
# attrs survive a merge depends on xarray's combine_attrs default. Re-stamp
# explicitly so the saved cache always carries the md5 checked above.
ds.attrs['md5'] = args.md5

# Save to file:
ds.to_netcdf(args.dataset_file, format="NETCDF4", engine='netcdf4')
ds.close()

  0%|          | 0/21 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

100% [..........................................................................] 8604413 / 8604413

2022-02-11 12:50:19.303 | INFO     | __main__:update_dataset:36 - Loading HRRR


  0%|          | 0/1 [00:00<?, ?it/s]

2022-02-11 12:53:45.382 | INFO     | __main__:update_dataset:39 - Loading MODIS


  0%|          | 0/21 [00:00<?, ?it/s]

2022-02-11 12:57:37.538 | INFO     | __main__:update_dataset:44 - Loading Sun Duration data


In [6]:
# Keep only the input window for the target date.
ds = ds.sel({"time": date_range})

# NDSI: fill gaps along time, zero what is still missing, then average each
# cell's image over its spatial axes.
ndsi_mean = ds.NDSI.ffill('time').fillna(0).reduce(np.nanmean, ("x", "y"))

# One scaled layer per model input feature. The order fixes the index along
# the new 'feature' dimension; the constants are fixed normalisation values
# (presumably matching the ones used in training - TODO confirm).
feature_layers = [
    (ds.t00 - 273.15) / 20,   # shift by 273.15 (presumably Kelvin -> degC)
    (ds.t12 - 273.15) / 20,
    (ds.sdwe**0.25 - 1),
    (ds.pwat - 8) / 7,
    ds.refc / 10,
    ds.u / 20,
    ds.v / 20,
    ds.sdwea,
    ndsi_mean,
    (ds.sd / 200) - 3.6,
]
band = xr.concat(feature_layers, dim='feature')

# Forward-fill remaining gaps along time, zero the rest, and lay the array
# out as (cell_id, feature, time) before converting to a plain numpy array.
band_filled = band.ffill('time').fillna(0)
band_values = np.array(
    band_filled.transpose("cell_id", "feature", "time").data)

# Static per-cell images, taken as raw arrays.
images_dem = ds.dem.data
images_soil = ds.soil.data

In [7]:
# Load the submission table; its first column becomes the index, which is
# later addressed with the dataset's cell ids when predictions are written.
subm = pd.read_csv(args.format_file, index_col=[0])

In [8]:
logger.info("Loading model")
# One checkpoint per cross-validation fold; the fold models are wrapped in a
# single ModelAggregator. features=10 matches the ten layers stacked into
# `band`; width/timelag are fixed at 92, the same value as args.timelag
# (NOTE(review): presumably dictated by the trained weights - confirm before
# changing either).
models = []
for fold_idx in range(5):
    model = SnowNet(features=10, h_dim=64, width=92, timelag=92)
    # map_location='cpu': inference below runs on CPU tensors, and without it
    # a checkpoint saved from a GPU cannot be deserialized on a CPU-only host.
    checkpoint = torch.load(
        f'{args.model_dir}/SnowNet_fold_{fold_idx}_last.pt',
        map_location='cpu',
    )
    model.load_state_dict(checkpoint['model'])
    models.append(model)
model = ModelAggregator(models)
model.eval()
logger.info("Evaluating...")

# Model inputs: the stacked feature series, the rescaled DEM with an extra
# axis inserted at position 1, and the soil images as integer tensors.
features = torch.from_numpy(band_values).float()
dem = torch.from_numpy(images_dem / 1000 - 2.25).float().unsqueeze(1)
soil = torch.from_numpy(images_soil).long()

with torch.no_grad():
    # clamp(0) floors the predictions at zero.
    result = model(features, dem, soil).clamp(0)
    result = result.detach().cpu().squeeze().numpy()

# Write the predictions into the column of the target date, row-aligned by
# cell id, and save the submission table.
subm.loc[ds.cell_id.data, f"{start_date:%Y-%m-%d}"] = result
subm.to_csv(args.output_file)

logger.info("Evaluation completed ")

2022-02-11 12:57:41.484 | INFO     | __main__:<module>:1 - Loading model
2022-02-11 12:57:41.618 | INFO     | __main__:<module>:11 - Evaluating...
2022-02-11 12:58:01.132 | INFO     | __main__:<module>:24 - Evaluation completed 


In [9]:
# Preview the first rows of the submission table after the update.
subm.head()

Unnamed: 0,2022-01-13,2022-01-20,2022-01-27,2022-02-03,2022-02-10,2022-02-17,2022-02-24,2022-03-03,2022-03-10,2022-03-17,...,2022-04-28,2022-05-05,2022-05-12,2022-05-19,2022-05-26,2022-06-02,2022-06-09,2022-06-16,2022-06-23,2022-06-30
0001daba-dd41-4787-84ab-f7956f7829a8,0.0,3.546582,1.829914,0.831555,0.848542,,,,,,...,,,,,,,,,,
0006d245-64c1-475f-a989-85f4787bae6a,13.955076,6.358151,10.343756,9.55324,9.379189,,,,,,...,,,,,,,,,,
000a9004-1462-4b8c-96ee-0601aff0fdf7,10.017718,9.181601,0.936452,4.023512,3.084391,,,,,,...,,,,,,,,,,
000ba8d9-d6d5-48da-84a2-1fa54951fae1,8.177724,5.406429,1.136543,0.535009,3.114652,,,,,,...,,,,,,,,,,
00118c37-43a4-4888-a95a-99a85218fda6,11.490028,9.043668,9.718908,7.942723,5.224966,,,,,,...,,,,,,,,,,
