In [139]:
import logging
logging.captureWarnings(True)

import deepsensor.torch
from deepsensor.model import ConvNP
from deepsensor.data import DataProcessor, TaskLoader, construct_circ_time_ds
from deepsensor.data.sources import get_era5_reanalysis_data, get_earthenv_auxiliary_data, get_gldas_land_mask
from deepsensor.train import set_gpu_default_device
from deepsensor.train import Trainer

import cartopy.crs as ccrs
import pandas as pd
import xarray as xr
import numpy as np
from tqdm import notebook

In [128]:
dat15 ='/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA3_NETCDF/GLSEA3_2015.nc'
dat14 ='/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA3_NETCDF/GLSEA3_2014.nc'
dat16 ='/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA3_NETCDF/GLSEA3_2016.nc'
mask = xr.open_dataset('/home/erinredd/CIGLRProj/lakemask2.nc')


In [129]:
dat = xr.open_mfdataset([dat14, dat15, dat16],
                                concat_dim='time',
                                combine='nested',
                                chunks={'lat': 'auto', 'lon': 'auto'})

In [130]:
mdat = dat.where(np.isnan(dat.sst) == False, -0.009)
climatology = mdat.groupby('time.dayofyear').mean('time')
anomalies = mdat.groupby('time.dayofyear') - climatology

In [131]:
data_processor = DataProcessor(x1_name="lat", x2_name="lon")
anom_ds = data_processor(anomalies)

In [141]:
task_loader = TaskLoader(
    context = anom_ds,
    target = anom_ds, 
#     aux=auxiliary_data_placeholder
)
val_tasks = []
for date in pd.date_range('2016-01-01T12:00:00.000000000', '2016-12-31T12:00:00.000000000'):
    N_context = np.random.randint(0, 100)
    task = task_loader(date, context_sampling="all", target_sampling="all")
    val_tasks.append(task)

In [133]:
model = ConvNP(data_processor, task_loader)

dim_yc inferred from TaskLoader: (1,)
dim_yt inferred from TaskLoader: 1
dim_aux_t inferred from TaskLoader: 0
internal_density inferred from TaskLoader: 1180
encoder_scales inferred from TaskLoader: [0.00042372880852781236]
decoder_scale inferred from TaskLoader: 0.000847457627118644


In [134]:
def compute_val_rmse(model, val_tasks):
    errors = []
    target_var_ID = task_loader.target_var_IDs[0][0]  # assume 1st target set and 1D
    for task in np.random.choice(val_tasks, 50, replace = False):
#         print("im in for loop")
        mean = data_processor.map_array(model.mean(task), target_var_ID, unnorm=True)
#         print("mean calc")
        true = data_processor.map_array(task["Y_t"][0], target_var_ID, unnorm=True)
#         print("true calc")
        errors.extend(np.abs(mean - true))
    return np.sqrt(np.mean(np.concatenate(errors) ** 2))
def gen_tasks(dates, progress=True):
    tasks = []
    for date in notebook.tqdm(dates, disable=not progress):
#         N_c = np.random.randint(0, 500)
        task = task_loader(date, context_sampling=["all"], target_sampling="all")
        tasks.append(task)
    return tasks

In [142]:
set_gpu_default_device()
losses = []
val_rmses = []
train_range = pd.date_range('2015-01-02T12:00:00.000000000', '2015-12-31T12:00:00.000000000')
val_range = pd.date_range('2016-01-01T12:00:00.000000000', '2016-12-31T12:00:00.000000000')
val_rmse_best = np.inf
trainer = Trainer(model, lr=5e-5)
for epoch in range(5):
#     print("step1")
    train_tasks = gen_tasks(pd.date_range(train_range[0], train_range[1])[::5], progress=False)

    batch_losses = trainer(train_tasks)
#     print("step3")
    losses.append(np.mean(batch_losses))
    val_rmses.append(compute_val_rmse(model, val_tasks))  
    if val_rmses[-1] < val_rmse_best:
        val_rmse_best = val_rmses[-1]

In [143]:
from deepsensor.active_learning import GreedyAlgorithm

alg = GreedyAlgorithm(
    model,
    X_s = anomalies,
    X_t = anomalies,
    context_set_idx=0,
    target_set_idx=0,
    N_new_context=3,
    progress_bar=True,
)

In [144]:
from deepsensor.active_learning.acquisition_fns import Stddev

acquisition_fn = Stddev(model)

The cell below consistently runs the alg() for 33% and then gives the error "GriddedDataError: Cannot append to gridded data"

In [146]:
val_dates = pd.date_range('2016-01-01T12:00:00.000000000', '2016-12-31T12:00:00.000000000')[::5]
placement_dates = val_dates
placement_tasks = task_loader(placement_dates, context_sampling="all")

X_new_df, acquisition_fn_ds = alg(acquisition_fn, placement_tasks)

 33%|███▎      | 74/222 [01:35<03:11,  1.29s/it]


GriddedDataError: Cannot append to gridded data