In [1]:
import argparse
import os
import sys
import tempfile
import time
from datetime import datetime, date, timedelta
from glob import glob
from pathlib import Path

import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio as rio
import requests
import torch
import tqdm.auto as tq
import xarray as xr
from dask.diagnostics import ProgressBar
from loguru import logger

from layers import *
from dataset import *

In [2]:
# Run configuration: stands in for parsed command-line arguments so the
# notebook can be executed top to bottom without a CLI.
args = argparse.Namespace(
    # Target date and number of preceding days used as model input.
    date='2022-02-03',
    timelag=92,
    # Dynamic data directories (passed to the download_* / load_* helpers via `args`).
    hrrr_dir='data/hrrr',
    modis_dir='data/modis',
    # Grid definition and static/auxiliary input files.
    grid_cells='evaluation/grid_cells.geojson',
    dem_file='data/copernicus_dem/COP90_hh.tif',
    soil_file='data/global_soil_regions/so2015v2.tif',
    hrrr_sample='hrrr_sample.grib2',
    sun_decline='sun_decline.csv',
    # Trained weights and submission output.
    model_dir='weights',
    output_file='submission_single.csv',
)

In [3]:
def create_ds(date_range, args):
    """Assemble the model-input dataset for the days in ``date_range``.

    HRRR and MODIS data are loaded and merged. Sun-duration values (``sd``,
    indexed by time and cell) and static DEM / soil images (``dem``, ``soil``,
    indexed by cell and a 2-D ``xlat`` x ``ylon`` patch) are then attached.

    Parameters
    ----------
    date_range : pandas.DatetimeIndex
        Days to load.
    args : argparse.Namespace
        Run configuration, forwarded unchanged to the loader helpers.

    Returns
    -------
    xarray.Dataset
        The merged dataset with ``sd``, ``dem`` and ``soil`` variables added.
    """
    logger.info("Loading HRRR")
    weather = load_hrrr(date_range, args)

    logger.info("Loading MODIS")
    satellite = load_modis(date_range, args)

    merged = xr.merge([weather, satellite])

    logger.info("Loading Sun Duration data")
    sun_values, sun_attrs = load_sundecline(date_range, args)

    logger.info("Loading Static data")
    dem_images, soil_images = load_static(args)

    # DEM and soil share the same per-cell 2-D patch layout.
    static_dims = ["cell_id", "xlat", "ylon"]
    return merged.assign(
        sd=(['time', 'cell_id'], sun_values, sun_attrs),
        dem=(static_dims, dem_images),
        soil=(static_dims, soil_images),
    )

In [4]:
# Build the input window, download any missing files, then assemble the dataset.
# The window holds the `args.timelag` days immediately *before* `args.date`;
# the target date itself is excluded.
start_date = datetime.strptime(args.date, '%Y-%m-%d')
# `periods=` reproduces the old `closed='left'` range. That keyword was removed
# in pandas 2.0, while `periods` works on every pandas version.
date_range = pd.date_range(
    start_date - timedelta(days=args.timelag), periods=args.timelag, freq='1D')

download_hrrr(date_range, args)
download_modis(date_range, args)

ds = create_ds(date_range, args)

  0%|          | 0/92 [00:00<?, ?it/s]

  0%|          | 0/92 [00:00<?, ?it/s]

2022-02-11 13:01:42.303 | INFO     | __main__:create_ds:4 - Loading HRRR


  0%|          | 0/3 [00:00<?, ?it/s]

2022-02-11 13:16:30.571 | INFO     | __main__:create_ds:7 - Loading MODIS


  0%|          | 0/92 [00:00<?, ?it/s]

2022-02-11 13:32:55.167 | INFO     | __main__:create_ds:13 - Loading Sun Duration data
2022-02-11 13:32:57.586 | INFO     | __main__:create_ds:16 - Loading Static data


In [5]:
# Keep only the days of the input window.
ds = ds.loc[{"time": date_range}]

# Per-day normalised input features, stacked along a new 'feature' axis.
# Order and count must match what the SnowNet checkpoints were trained on.
feature_list = [
    (ds.t00 - 273.15) / 20,
    (ds.t12 - 273.15) / 20,
    (ds.sdwe**0.25 - 1),
    (ds.pwat - 8) / 7,
    ds.refc / 10,
    ds.u / 20,
    ds.v / 20,
    ds.sdwea,
    # MODIS NDSI: forward-filled in time, gaps set to 0, then averaged over the x/y patch.
    ds.NDSI.ffill('time').fillna(0).reduce(np.nanmean, ("x", "y")),
    (ds.sd / 200) - 3.6,
]
band = xr.concat(feature_list, dim='feature')

# Fill remaining gaps (carry last value forward, then 0) and lay out as
# (cell_id, feature, time) for the model.
band_filled = band.ffill('time').fillna(0)
band_values = np.array(
    band_filled.transpose("cell_id", "feature", "time").data
)

images_dem = ds.dem.data
images_soil = ds.soil.data

In [7]:
# Grid-cell table; its `cell_id` column labels the rows of the submission.
grid_cells_path = args.grid_cells
grid_cells = gpd.read_file(grid_cells_path)

In [8]:
logger.info("Loading model")

# Hyper-parameters of the trained SnowNet checkpoints: one model per fold,
# each consuming N_FEATURES normalised series of TIMELAG daily steps.
N_FOLDS = 5
N_FEATURES = 10
TIMELAG = 92

models = []
for fold_idx in range(N_FOLDS):
    model = SnowNet(features=N_FEATURES, h_dim=64, width=92, timelag=TIMELAG)
    # map_location='cpu': inference below runs entirely on CPU tensors. Without it,
    # a checkpoint saved from a GPU session cannot be loaded on a CPU-only host.
    checkpoint = torch.load(
        f'{args.model_dir}/SnowNet_fold_{fold_idx}_last.pt', map_location='cpu'
    )
    model.load_state_dict(checkpoint['model'])
    models.append(model)
model = ModelAggregator(models)
model.eval()
logger.info("Evaluating...")

# Fail early with a readable message if the feature cube does not match the
# layout the checkpoints expect (e.g. after changing args.timelag).
expected_shape = (N_FEATURES, TIMELAG)
if band_values.shape[1:] != expected_shape:
    raise ValueError(
        f"band_values has shape {band_values.shape}, expected "
        f"(n_cells, {N_FEATURES}, {TIMELAG}) to match the SnowNet checkpoints"
    )

features = torch.from_numpy(band_values).float()
# DEM scaled by 1/1000 and shifted by -2.25 (presumably the training-time normalisation).
dem = torch.from_numpy(images_dem / 1000 - 2.25).float().unsqueeze(1)
soil = torch.from_numpy(images_soil).long()

with torch.no_grad():
    # Negative predictions are clipped to zero.
    result = model(features, dem, soil).clamp(0).cpu().numpy()

# Rows of `result` follow band_values' cell_id axis.
# NOTE(review): labelling them with grid_cells.cell_id assumes the dataset and the
# geojson list cells in the same order — confirm in the dataset loaders.
subm = pd.DataFrame(result, index=grid_cells.cell_id.values, columns=[args.date])
subm.to_csv(args.output_file)

logger.info("Evaluation completed ")

2022-02-11 13:49:23.483 | INFO     | __main__:<module>:1 - Loading model
2022-02-11 13:49:23.574 | INFO     | __main__:<module>:11 - Evaluating...
2022-02-11 13:49:41.897 | INFO     | __main__:<module>:24 - Evaluation completed 


In [9]:
# Preview the first rows of the submission (cell_id -> predicted value for args.date).
subm.head(5)

Unnamed: 0,2022-02-03
0001daba-dd41-4787-84ab-f7956f7829a8,2.00229
0006d245-64c1-475f-a989-85f4787bae6a,9.534917
000a9004-1462-4b8c-96ee-0601aff0fdf7,2.038069
000ba8d9-d6d5-48da-84a2-1fa54951fae1,2.455909
00118c37-43a4-4888-a95a-99a85218fda6,3.383721
