In [1]:
import random
import sys
sys.path.insert(0, '../src/')
import joblib
import numpy as np
import torch
from torch.utils.data import DataLoader

from data_generator import train_val_test_split
from models import NonlinearSCI
from trainers import Trainer

# To make this notebook's output stable across runs, seed every RNG that the
# notebook (or the project helpers it calls) may draw from.
SEED = 2020
random.seed(SEED)                  # Python stdlib RNG, in case project helpers use `random`
np.random.seed(SEED)               # legacy NumPy global RNG
torch.manual_seed(SEED)            # PyTorch RNG on all devices
torch.cuda.manual_seed(SEED)       # silently ignored when CUDA is unavailable
torch.cuda.manual_seed_all(SEED)
# cuDNN: force deterministic kernels and disable the benchmark autotuner,
# whose algorithm selection can differ between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Load and preprocess data

In [2]:
# Each sample in the dataset is indexed positionally below:
#   x[0], x[1] -> values whose absolute difference is used as the spatial feature
#   x[2]       -> NLCD (observed confounder)
#   x[3]       -> NDVI window (2-D array); its centre pixel is the intervention
#   x[4]       -> target
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code;
# only load dataset files from a trusted source.
BATCH_SIZE = 16
data = joblib.load('../data/dataset_1.joblib')
assert len(data) > 0, "dataset is empty"

# Stack all NDVI windows into one array of shape (N, H, W, ...).
# np.stack raises if the windows do not all share the same shape.
ndvi_windows = np.stack([x[3] for x in data])
# A single centre index is used on both spatial axes, so the windows must be square.
assert ndvi_windows.ndim >= 3 and ndvi_windows.shape[1] == ndvi_windows.shape[2], \
    "NDVI windows must be square so one centre index is valid on both axes"
center_idx = ndvi_windows.shape[1] // 2

# Data must be in the order of: interventions, observed confounders, observed spatial confounders
ndvi = ndvi_windows[:, center_idx, center_idx].copy()  # intervention (centre pixel)
nlcd = np.array([x[2] for x in data])                  # confounder
ndvi_neighbors = ndvi_windows[:, np.newaxis]           # spatial confounder, channel axis added at 1
features = [ndvi, nlcd, ndvi_neighbors]
spatial_features = np.array([abs(x[0] - x[1]) for x in data])
targets = np.array([x[4] for x in data])

train_dataset, val_dataset, test_dataset = train_val_test_split(
    features, spatial_features, targets, train_size=0.6, val_size=0.2, test_size=0.2, shuffle=True
)
# Only the training loader is shuffled; val/test order stays fixed for reproducible evaluation.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
dataloaders = {'train': train_dataloader, 'val': val_dataloader, 'test': test_dataloader}

# Build and train nonlinear SCI model

In [3]:
# Compute the range of the spatial features (named as distances here).
# NOTE(review): min_dist / max_dist are not used in this cell, and `nys_space`
# below is hard-coded to [0, 2093]. Presumably it is meant to span
# [min_dist, max_dist]; confirm against NonlinearSCI before changing it.
min_dist = np.min(spatial_features)
max_dist = np.max(spatial_features)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One intervention (centre NDVI), one observed confounder (NLCD) and one
# spatial confounder (the NDVI window), plus a latent unobserved confounder.
model = NonlinearSCI(
    num_interventions=1,
    num_confounders=1,
    num_spatial_confounders=1,
    confounder_dim=nlcd.shape[-1],
    window_size=ndvi_neighbors.shape[-1],
    unobserved_confounder=True,
    nys_space=[[0, 2093]],
)

# Optimisation settings forwarded to Trainer.
optim = 'sgd'
optim_params = dict(lr=1e-5)
epochs = 1000
patience = 10  # presumably early-stopping patience in epochs; see Trainer

trainer = Trainer(
    model=model,
    data_generators=dataloaders,
    optim=optim,
    optim_params=optim_params,
    device=device,
    epochs=epochs,
    patience=patience,
)
trainer.train()

Training started:

Epoch 1/1000
Learning rate: 0.000010
11s for 12 steps - 885ms/step - loss 33.6840
Validation:
0s - loss 36.1873

Epoch 2/1000
Learning rate: 0.000010
0s for 12 steps - 6ms/step - loss 33.2468
Validation:
0s - loss 35.7493

Epoch 3/1000
Learning rate: 0.000010
0s for 12 steps - 6ms/step - loss 32.8291
Validation:
0s - loss 35.2970

Epoch 4/1000
Learning rate: 0.000010
0s for 12 steps - 6ms/step - loss 32.3977
Validation:
0s - loss 34.8489

Epoch 5/1000
Learning rate: 0.000010
0s for 12 steps - 6ms/step - loss 31.9695
Validation:
0s - loss 34.4006

Epoch 6/1000
Learning rate: 0.000010
0s for 12 steps - 6ms/step - loss 31.5406
Validation:
0s - loss 33.9658

Epoch 7/1000
Learning rate: 0.000010
0s for 12 steps - 6ms/step - loss 31.1239
Validation:
0s - loss 33.5187

Epoch 8/1000
Learning rate: 0.000010
0s for 12 steps - 6ms/step - loss 30.6949
Validation:
0s - loss 33.0693

Epoch 9/1000
Learning rate: 0.000010
0s for 12 steps - 6ms/step - loss 30.2635
Validation:
0s - lo

0s for 12 steps - 4ms/step - loss 2.9277
Validation:
0s - loss 3.3413

Epoch 76/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 2.8384
Validation:
0s - loss 3.2383

Epoch 77/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 2.7578
Validation:
0s - loss 3.1434

Epoch 78/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 2.6847
Validation:
0s - loss 3.0469

Epoch 79/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 2.6102
Validation:
0s - loss 2.9618

Epoch 80/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 2.5456
Validation:
0s - loss 2.8871

Epoch 81/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 2.4889
Validation:
0s - loss 2.8099

Epoch 82/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 2.4314
Validation:
0s - loss 2.7375

Epoch 83/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 2.3778
Validation:
0s - loss 2.6684

Epoch 84/1000
Learning rate: 0.000010
0s for 12 steps - 4

Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.7232
Validation:
0s - loss 1.6749

Epoch 151/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.7222
Validation:
0s - loss 1.6720

Epoch 152/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.7214
Validation:
0s - loss 1.6709

Epoch 153/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.7207
Validation:
0s - loss 1.6695

Epoch 154/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.7198
Validation:
0s - loss 1.6686

Epoch 155/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.7192
Validation:
0s - loss 1.6666

Epoch 156/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.7187
Validation:
0s - loss 1.6637

Epoch 157/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.7177
Validation:
0s - loss 1.6628

Epoch 158/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.7171
Validation:
0s - loss 1.6628

Epoch 159/1000
Learning r

Validation:
0s - loss 1.6003

Epoch 225/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6728
Validation:
0s - loss 1.5986

Epoch 226/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6720
Validation:
0s - loss 1.5993

Epoch 227/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6713
Validation:
0s - loss 1.5982

Epoch 228/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6706
Validation:
0s - loss 1.5980

Epoch 229/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6700
Validation:
0s - loss 1.5981

Epoch 230/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6694
Validation:
0s - loss 1.5964

Epoch 231/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6688
Validation:
0s - loss 1.5956

Epoch 232/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6682
Validation:
0s - loss 1.5948

Epoch 233/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6674
Validation

0s for 12 steps - 4ms/step - loss 1.6264
Validation:
0s - loss 1.5554

Epoch 300/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6257
Validation:
0s - loss 1.5547

Epoch 301/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6252
Validation:
0s - loss 1.5533

Epoch 302/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6247
Validation:
0s - loss 1.5534

Epoch 303/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6241
Validation:
0s - loss 1.5524

Epoch 304/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6234
Validation:
0s - loss 1.5510

Epoch 305/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6228
Validation:
0s - loss 1.5503

Epoch 306/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6221
Validation:
0s - loss 1.5480

Epoch 307/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.6214
Validation:
0s - loss 1.5467

Epoch 308/1000
Learning rate: 0.000010
0s for 12 

0s - loss 1.5044

Epoch 374/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5809
Validation:
0s - loss 1.5035

Epoch 375/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5806
Validation:
0s - loss 1.5037

Epoch 376/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5799
Validation:
0s - loss 1.5033

Epoch 377/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5792
Validation:
0s - loss 1.5014

Epoch 378/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5786
Validation:
0s - loss 1.5007

Epoch 379/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5782
Validation:
0s - loss 1.5014

Epoch 380/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5773
Validation:
0s - loss 1.5001

Epoch 381/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5768
Validation:
0s - loss 1.5020

Epoch 382/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5764
Validation:
0s - loss 

0s for 12 steps - 4ms/step - loss 1.5387
Validation:
0s - loss 1.4508

Epoch 449/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5380
Validation:
0s - loss 1.4488

Epoch 450/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5379
Validation:
0s - loss 1.4491

Epoch 451/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5370
Validation:
0s - loss 1.4489

Epoch 452/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5364
Validation:
0s - loss 1.4491

Epoch 453/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5360
Validation:
0s - loss 1.4475

Epoch 454/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5354
Validation:
0s - loss 1.4474

Epoch 455/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5349
Validation:
0s - loss 1.4482

Epoch 456/1000
Learning rate: 0.000010
0s for 12 steps - 4ms/step - loss 1.5343
Validation:
0s - loss 1.4478

Epoch 457/1000
Learning rate: 0.000010
0s for 12 