In [1]:
import random
import sys
sys.path.insert(0, '../src/')
import joblib
import numpy as np
import torch
from torch.utils.data import DataLoader

from data_generator import train_val_test_split
from models import NonlinearSCI
from trainers import Trainer

# To make this notebook's output stable across runs, seed every random number
# generator the notebook may touch from one single constant.
SEED = 2020

random.seed(SEED)                 # Python's built-in RNG (helper modules may draw from it)
np.random.seed(SEED)              # NumPy's global RNG
torch.manual_seed(SEED)           # PyTorch RNG
torch.cuda.manual_seed(SEED)      # current CUDA device
torch.cuda.manual_seed_all(SEED)  # every CUDA device

# Request deterministic cuDNN kernels and switch off the cuDNN auto-tuner, which
# may otherwise select different convolution algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Load and preprocess data

In [2]:
# NOTE(review): joblib.load unpickles the file, so only load datasets that come
# from a trusted source.
data = joblib.load('../data/dataset_1.joblib')

# Records are read positionally: [2] -> NLCD, [3] -> NDVI window, [4] -> target,
# and [0], [1] -> the two values whose absolute difference is the spatial feature.
# The centre index comes from the first record's window and is applied to both
# axes (assumes all windows are square and equally sized — TODO confirm).
center_idx = data[0][3].shape[0] // 2

# Spatial confounder: the full NDVI window, with an extra axis inserted at position 1.
ndvi_neighbors = np.array([record[3] for record in data])[:, np.newaxis]
# Confounder: the NLCD entry of each record.
nlcd = np.array([record[2] for record in data])
# Intervention: the NDVI value at the centre of each window.
ndvi = np.array([record[3][center_idx, center_idx] for record in data])

# Feature order is significant:
# interventions, observed confounders, observed spatial confounders.
features = [ndvi, nlcd, ndvi_neighbors]

spatial_features = np.array([abs(record[0] - record[1]) for record in data])
targets = np.array([record[4] for record in data])

# 60 / 20 / 20 shuffled split into train / validation / test datasets.
train_dataset, val_dataset, test_dataset = train_val_test_split(
    features,
    spatial_features,
    targets,
    train_size=0.6,
    val_size=0.2,
    test_size=0.2,
    shuffle=True,
)

# Only the training loader reshuffles its samples.
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=16, shuffle=True),
    'val': DataLoader(val_dataset, batch_size=16, shuffle=False),
    'test': DataLoader(test_dataset, batch_size=16, shuffle=False),
}
train_dataloader = dataloaders['train']
val_dataloader = dataloaders['val']
test_dataloader = dataloaders['test']

# Build and train nonlinear SCI model

In [None]:
# Smallest and largest value found in the spatial features.
# NOTE(review): min_dist / max_dist are not used anywhere in this cell, while
# nys_space below is hard-coded to [[0, 2093]] — confirm whether the model is
# meant to receive the computed range instead.
min_dist = np.min(spatial_features)
max_dist = np.max(spatial_features)

# Optimisation settings.
optim = 'sgd'
optim_params = dict(lr=1e-5)
epochs = 1000
patience = 10

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model dimensions are read off the arrays built in the data-loading cell.
model = NonlinearSCI(
    num_interventions=1,
    num_confounders=1,
    num_spatial_confounders=1,
    confounder_dim=nlcd.shape[-1],
    window_size=ndvi_neighbors.shape[-1],
    unobserved_confounder=True,
    nys_space=[[0, 2093]],
)

trainer = Trainer(
    model=model,
    data_generators=dataloaders,
    optim=optim,
    optim_params=optim_params,
    device=device,
    epochs=epochs,
    patience=patience,
)
trainer.train()