In [1]:
import pandas as pd
import os
import json
import numpy as np
import torch
from steams.data.KVyQVx import KVyQVx
from steams.models.mads import mads
from steams.tepe.steams import attention_steams
from steams.tepe.tepe import train, evaluation,prediction_prime,ensemble_prime
from steams.utils.criterion import R2,RMSE,variance
import matplotlib.pyplot as plt

In [2]:
# Project root: the directory one level above the notebook's working directory.
parent_dir = os.path.dirname(os.path.abspath(os.getcwd()))

In [3]:
# Select the compute device and DataLoader options.
# The experiment targets the second GPU ("cuda:1"); when fewer than two GPUs are
# visible, fall back to "cuda:0" instead of an ordinal that would fail as soon as
# tensors are moved to it.
num_workers = 0  # same value for CPU and GPU runs
if torch.cuda.is_available():
    cuda_index = 1 if torch.cuda.device_count() > 1 else 0
    cuda_name = str(cuda_index)
    pin_memory = True  # page-locked host memory for host->GPU copies
    device = torch.device('cuda' + ":" + cuda_name)
else:
    pin_memory = False
    device = torch.device('cpu')

In [4]:
# Display the torch.device chosen in the previous cell.
device

device(type='cuda', index=1)

## train/valid dataset

In [5]:
# Folder holding the JSON experiment configurations.
config_dir = os.path.join(parent_dir, "config")

In [6]:
# Load the base experiment configuration. The context manager closes the file
# handle once parsing is done instead of leaving it open for the kernel's lifetime.
with open(os.path.join(config_dir, 'train_eval_attention_cpu.json'), encoding='utf8') as f_config:
    params = json.load(f_config)

In [7]:
# Training split: point the Y (KEY/VALUE) and X (QUERY/VALUE) sources at the
# training files, then build the dataset from params['data'].
params['data']['Y'].update({
    'path': os.path.join(parent_dir, "session", "synth_1000_Y_train_6s"),
    'VALUE': ['value'],
    'KEY': ['x', 'y', 'rmse', 'variance'],
    'nb_location': 6000,  # 6x1000, 6 sources
    'nb_sampling': 100,
})
params['data']['X'].update({
    'path': os.path.join(parent_dir, "session", "synth_1000_X_train_6s"),
    'VALUE': ['ref'],
    'QUERY': ['x', 'y', 'rmse_ref', 'variance_ref'],
    'nb_location': 6000,  # 6x1000, 6 sources
    'nb_sampling': 100,
})

train_dataset = KVyQVx(params['data'])

In [8]:
# Validation split: same layout as the training split, different files.
params['data']['Y'].update({
    'path': os.path.join(parent_dir, "session", "synth_1000_Y_valid_6s"),
    'VALUE': ['value'],
    'KEY': ['x', 'y', 'rmse', 'variance'],
    'nb_location': 6000,  # 6x1000, 6 sources
    'nb_sampling': 100,
})
params['data']['X'].update({
    'path': os.path.join(parent_dir, "session", "synth_1000_X_valid_6s"),
    'VALUE': ['ref'],
    'QUERY': ['x', 'y', 'rmse_ref', 'variance_ref'],
    'nb_location': 6000,  # 6x1000, 6 sources
    'nb_sampling': 100,
})

valid_dataset = KVyQVx(params['data'])

## Model, optimizer and criterion

In [12]:
# Model: mads of type "krig" with an "exp" kernel.
# input_k=4 presumably matches the four KEY/QUERY columns — confirm.
model = mads(device, type="krig", kernel="exp", input_k=4)

In [13]:
# Training criterion: mean squared error.
criterion = torch.nn.MSELoss()

# Optimizer over the model's W parameter only (8e-3 was noted as an alternative lr).
optimizer = torch.optim.Adam([model.W], lr=1e-4)

# Wrap the model in a steams object and attach the optimizer and criterion.
obj = attention_steams(model, device)
obj.init_optimizer(optimizer)
obj.init_criterion(criterion)

## Training

In [14]:
# Train on train_dataset, validating on valid_dataset; no results directory.
# NOTE(review): n_iter_stop (20) exceeds niter (10); if n_iter_stop is an
# early-stopping patience it can never trigger — confirm intent.
train(obj, train_dataset, valid_dataset,
      niter=10, n_iter_stop=20,
      batch_size=1, shuffle=True,
      num_workers=num_workers, pin_memory=pin_memory,
      resdir=None)

KeyboardInterrupt: 

In [None]:
# Persist the trained model under parent_dir with the tag "3a-krig".
obj.save_model(parent_dir, "3a-krig")

## Evaluation

In [None]:
# Evaluation split: same layout as the training split, evaluation files.
params['data']['Y'].update({
    'path': os.path.join(parent_dir, "session", "synth_1000_Y_eval_6s"),
    'VALUE': ['value'],
    'KEY': ['x', 'y', 'rmse', 'variance'],
    'nb_location': 6000,  # 6x1000, 6 sources
    'nb_sampling': 100,
})
params['data']['X'].update({
    'path': os.path.join(parent_dir, "session", "synth_1000_X_eval_6s"),
    'VALUE': ['ref'],
    'QUERY': ['x', 'y', 'rmse_ref', 'variance_ref'],
    'nb_location': 6000,  # 6x1000, 6 sources
    'nb_sampling': 100,
})

eval_dataset = KVyQVx(params['data'])

In [None]:
# Swap the steams object's criterion to R2 for evaluation.
criterion = R2()
obj.init_criterion(criterion)

In [None]:
# Score eval_dataset with the current criterion (R2). Evaluation does not need
# shuffled batches; a fixed iteration order keeps the run deterministic.
evaluation(obj, eval_dataset, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, resdir=None)

In [None]:
# Swap the steams object's criterion to RMSE.
criterion = RMSE()
obj.init_criterion(criterion)

In [None]:
# Score eval_dataset with the current criterion (RMSE). Evaluation does not need
# shuffled batches; a fixed iteration order keeps the run deterministic.
evaluation(obj, eval_dataset, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, resdir=None)

In [None]:
# Swap the steams object's criterion to the variance metric.
criterion = variance()
obj.init_criterion(criterion)

In [None]:
# Score eval_dataset with the current criterion (variance). Evaluation does not
# need shuffled batches; a fixed iteration order keeps the run deterministic.
evaluation(obj, eval_dataset, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, resdir=None)

## QQ plot: observations vs. predictions

In [None]:
# Rebuild the evaluation dataset (same configuration as in the Evaluation section)
# so this section can be run on its own.
params['data']['Y'].update({
    'path': os.path.join(parent_dir, "session", "synth_1000_Y_eval_6s"),
    'VALUE': ['value'],
    'KEY': ['x', 'y', 'rmse', 'variance'],
    'nb_location': 6000,  # 6x1000, 6 sources
    'nb_sampling': 100,
})
params['data']['X'].update({
    'path': os.path.join(parent_dir, "session", "synth_1000_X_eval_6s"),
    'VALUE': ['ref'],
    'QUERY': ['x', 'y', 'rmse_ref', 'variance_ref'],
    'nb_location': 6000,  # 6x1000, 6 sources
    'nb_sampling': 100,
})

eval_dataset = KVyQVx(params['data'])

In [None]:
# Predictions of the trained model on eval_dataset; results is indexed with .loc below.
results = prediction_prime(obj, eval_dataset)

In [None]:
# Observed vs. predicted values on the eval split, with the 1:1 line for reference.
pred_name = ['pred_' + v for v in eval_dataset.VALUE_X]
observed = results.loc[:, eval_dataset.VALUE_X]
predicted = results.loc[:, pred_name]

fig, ax = plt.subplots(figsize=(14, 12))
ax.scatter(observed, predicted, linewidth=2.0, c="black")
ax.axline((0, 0), slope=1., color='blue')
ax.set_xlim(0, 2.5)
# nanmin/nanmax over the raw values work for any number of prediction columns,
# whereas DataFrame.min().item() raises once there is more than one column.
ax.set_ylim(np.nanmin(predicted.to_numpy()) - 0.5, np.nanmax(predicted.to_numpy()) + 0.5)
ax.set(xlabel='observations', ylabel='predictions', title='Eval split: observations vs. predictions')

fig_dir = os.path.join(parent_dir, 'fig')
os.makedirs(fig_dir, exist_ok=True)  # savefig raises if the directory is missing
fig_filename = os.path.join(fig_dir, '3a-qqplot.png')
fig.savefig(fig_filename, dpi=300)

## Wq

In [None]:
# Learned W parameter, detached from the autograd graph for inspection.
W = obj.model.W.detach()
W

## Ensemble, quantiles and p-values

In [None]:
# Ensemble dataset: evaluation Y files, X taken from the 6400-location file.
params['data']['Y'].update({
    'path': os.path.join(parent_dir, "session", "synth_1000_Y_eval_6s"),
    'VALUE': ['value'],
    'KEY': ['x', 'y', 'rmse', 'variance'],
    'nb_location': 6000,  # 6x1000, 6 sources
    'nb_sampling': 100,
})
params['data']['X'].update({
    'path': os.path.join(parent_dir, "session", "synth_all_6400_6s"),
    'VALUE': ['ref'],
    'QUERY': ['x', 'y', 'rmse_ref', 'variance_ref'],
    'nb_location': 6400,
    'nb_sampling': 100,
})

eval_dataset = KVyQVx(params['data'])

In [None]:
# Build the ensemble table for eval_dataset (N=2 is passed straight to
# ensemble_prime; presumably the ensemble size — confirm). !! might take some time
ensemble = ensemble_prime(obj, eval_dataset, N=2)

In [None]:
# Save the ensemble table at the project root.
ensemble_csv = os.path.join(parent_dir, '3a-ensemble.csv')
ensemble.to_csv(ensemble_csv)

### metrics

In [None]:
# Observation means and ensemble medians as (1, n, 1) tensors, the layout passed
# to the R2/RMSE/variance criteria in the following cells.
obs = torch.tensor(ensemble[('sensor_hq', 'mean')].to_numpy())[None, :, None]
q0_5 = torch.tensor(ensemble[('pred_sensor_hq', 'q0_5')].to_numpy())[None, :, None]

In [None]:
# R2 between the observation means and the ensemble medians.
criterion = R2()
criterion(obs, q0_5)

In [None]:
# RMSE between the observation means and the ensemble medians.
criterion = RMSE()
criterion(obs, q0_5)

In [None]:
# Variance metric between the observation means and the ensemble medians.
criterion = variance()
criterion(obs, q0_5)

### QQ plot of the observations against the ensemble median at each location

In [None]:
# Observation mean vs. ensemble median at each location, with the 1:1 line.
ens_median = ensemble[('pred_sensor_hq', 'q0_5')]

fig, ax = plt.subplots(figsize=(14, 12))
ax.scatter(ensemble[('sensor_hq', 'mean')], ens_median, linewidth=2.0, c="black")
ax.axline((0, 0), slope=1., color='blue')
ax.set_xlim(0, 2.5)
ax.set_ylim(float(ens_median.min()) - 0.5, float(ens_median.max()) + 0.5)
ax.set(xlabel='observations', ylabel='median-ensemble', title='Observations vs. ensemble median')

fig_dir = os.path.join(parent_dir, 'fig')
os.makedirs(fig_dir, exist_ok=True)  # savefig raises if the directory is missing
fig_filename = os.path.join(fig_dir, '3a-ensemble_q0_5_qqplot.png')
fig.savefig(fig_filename, dpi=300)

### QQ plot of each observation's p-value within its ensemble against a uniform distribution

In [None]:
# Sort rows by the observation's p-value within its ensemble and attach the
# matching uniform plotting positions (i + 0.5) / n for i = 0..n-1. These are the
# midpoints of n equal-probability bins of U[0, 1]. Unlike linspace(0, 1, n), they
# do not pin the smallest and largest p-values to exactly 0 and 1.
ensemble = ensemble.sort_values(by='p_sensor_hq')
n_rows = ensemble.shape[0]
ensemble["U"] = (np.arange(n_rows) + 0.5) / n_rows

In [None]:
# Sorted p-values against uniform plotting positions; points close to the 1:1
# line indicate p-values consistent with U[0, 1].
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(ensemble['U'], ensemble['p_sensor_hq'], linewidth=2.0, c="black")
ax.axline((0, 0), slope=1., color='blue')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set(xlabel='U[0,1]', ylabel='p-value of observation', title='p-values of observations within their ensembles')

fig_dir = os.path.join(parent_dir, 'fig')
os.makedirs(fig_dir, exist_ok=True)  # savefig raises if the directory is missing
fig_filename = os.path.join(fig_dir, '3a-p_qqplot.png')
fig.savefig(fig_filename, dpi=300)

## Illustration of ensemble quantiles on maps

In [None]:
# Dataset for the quantile maps: evaluation Y files, X from the 6400-location file.
params['data']['Y'].update({
    'path': os.path.join(parent_dir, "session", "synth_1000_Y_eval_6s"),
    'VALUE': ['value'],
    'KEY': ['x', 'y', 'rmse', 'variance'],
    'nb_location': 6000,  # 6x1000, 6 sources
    'nb_sampling': 100,
})
params['data']['X'].update({
    'path': os.path.join(parent_dir, "session", "synth_all_6400_6s"),
    'VALUE': ['ref'],
    'QUERY': ['x', 'y', 'rmse_ref', 'variance_ref'],
    'nb_location': 6400,
    'nb_sampling': 100,
})

eval_dataset = KVyQVx(params['data'])

In [None]:
# Build the ensemble table for the maps (N=2 passed straight to ensemble_prime).
# !! might take some time
ensemble = ensemble_prime(obj, eval_dataset, N=2)

In [None]:
# Maps of the observation mean and the q0_05 / q0_5 / q0_95 ensemble columns at
# each (x, y) location. All panels share one color scale so they are comparable.
map_columns = [
    (('sensor_hq', 'mean'), 'observation (mean)'),
    (('pred_sensor_hq', 'q0_05'), 'ensemble q0_05'),
    (('pred_sensor_hq', 'q0_5'), 'ensemble q0_5'),
    (('pred_sensor_hq', 'q0_95'), 'ensemble q0_95'),
]

# Shared color limits across all panels (nan-aware in case a column is all-NaN).
vmin = np.nanmin([ensemble[col].min() for col, _ in map_columns])
vmax = np.nanmax([ensemble[col].max() for col, _ in map_columns])

fig, axes = plt.subplots(len(map_columns), figsize=(14, 12), sharex=True)
for ax, (col, title) in zip(axes, map_columns):
    cs = ax.scatter(x=ensemble["x"], y=ensemble["y"], c=ensemble[col], vmin=vmin, vmax=vmax)
    ax.set(ylabel='Y', title=title)  # title tells the reader which quantity each panel shows
axes[-1].set(xlabel='X')

# Every panel uses the same vmin/vmax, so the last scatter's color mapping
# is valid for the whole figure.
fig.colorbar(cs, ax=list(axes))

fig_dir = os.path.join(parent_dir, 'fig')
os.makedirs(fig_dir, exist_ok=True)  # savefig raises if the directory is missing
fig_filename = os.path.join(fig_dir, '3a-illustration_map_quantile.png')
fig.savefig(fig_filename, dpi=300)
