<font color='red'>WARNING: This notebook assumes that normalizer.pkl and h2mgnode.pkl are present in the current directory. Run the previous notebooks to generate them.</font> 

In [None]:
import ML4PS as ml
import numpy as np
from matplotlib import pyplot as plt

# Loading and preprocessing the data

In [None]:
# Directory holding the case14 test data.
data_dir = '../../data/case14'

# Feature normalizer restored from the file written by a previous notebook.
normalizer = ml.Normalizer(file='normalizer.pkl')

# Backend interface used below to iterate over test batches and to run the
# reference load flows (pandapower backend, one sample per batch).
interface = ml.Interface(
    data_dir=data_dir,
    backend_name='pandapower',
    batch_size=1,
)

# Loading a trained method

In [None]:
# Trained H2MGNODE model restored from the file produced by a previous notebook.
# NOTE(review): if this is a pickle, only load files from a trusted source.
h2mgnode = ml.H2MGNODE(file='h2mgnode.pkl')

# Compute metrics

In [None]:
class PostProcessor:
    """Turn raw model outputs into bus voltage magnitudes.

    The only transformation is adding 1 to ``y['bus']['res_vm_pu']``
    (presumably the model predicts a deviation from 1 p.u. -- TODO confirm).
    """

    def __call__(self, y):
        """Return a fresh ``{'bus': {'res_vm_pu': ...}}`` dict.

        Only ``y['bus']['res_vm_pu']`` is read; any other entry of ``y`` is
        not carried over to the result.
        """
        raw_vm = y['bus']['res_vm_pu']
        shifted_vm = self.bus_v_mag(raw_vm)
        return {'bus': {'res_vm_pu': shifted_vm}}

    def bus_v_mag(self, y):
        """Offset ``y`` by +1."""
        shifted = 1. + y
        return shifted


postprocessor = PostProcessor()

In [None]:
def loss(params, init_state, y_truth):
    """Mean squared error on post-processed bus voltage magnitudes.

    Args:
        params: model weights passed to ``h2mgnode.solve_and_decode_batch``.
        init_state: batched initial latent state for the same call.
        y_truth: ground-truth features dict; ``y_truth['bus']['res_vm_pu']``
            is compared with the prediction.

    Returns:
        ``ml.mean`` of the squared difference between ground truth and the
        post-processed prediction.
    """
    y_hat = h2mgnode.solve_and_decode_batch(params, init_state)
    y_post = postprocessor(y_hat)
    return ml.mean((y_truth['bus']['res_vm_pu'] - y_post['bus']['res_vm_pu'])**2)


# Build the value-and-gradient function once, bound to the function object
# above, so that rebinding the global name ``loss`` elsewhere in the notebook
# cannot change (or break) what ``update`` differentiates.
loss_value_and_grad = ml.value_and_grad(loss)


@ml.jit
def update(params, init_state, y_truth, opt_state, step=0):
    """Perform one optimizer step on ``loss``.

    Args:
        params: current model weights.
        init_state: batched initial latent state.
        y_truth: ground-truth features dict (see ``loss``).
        opt_state: optimizer state from ``opt_init`` or a previous ``update``.
        step: iteration index forwarded to ``opt_update``. Assuming
            ``ml.optimizers`` follows ``jax.example_libraries.optimizers``,
            Adam uses it for bias correction and the step-size schedule, so
            training loops should pass an increasing counter. The default 0
            reproduces the former behaviour.

    Returns:
        ``(new_params, new_opt_state, loss_value)``.

    ``opt_update`` and ``get_params`` are module-level names created by the
    optimizer cell; that cell must have run before ``update`` is called.
    """
    value, grads = loss_value_and_grad(params, init_state, y_truth)
    opt_state = opt_update(step, grads, opt_state)
    return get_params(opt_state), opt_state, value

In [None]:
# Adam learning rate.
step_size = 0.0003

# Optimizer triple (init / update / read-params), with its state initialised
# from the weights of the loaded model.
opt_init, opt_update, get_params = ml.optimizers.adam(step_size)
opt_state = opt_init(h2mgnode.weights)

In [None]:
def print_error_summary(title, errors):
    """Print max, 90th/50th/10th percentiles and min of ``errors``.

    Args:
        title: header line printed before the statistics.
        errors: non-empty flat sequence of per-bus error values.
    """
    print(title)
    print('    max        = {:e}'.format(np.max(errors)))
    print('    90th perc. = {:e}'.format(np.percentile(errors, 90)))
    print('    50th perc. = {:e}'.format(np.percentile(errors, 50)))
    print('    10th perc. = {:e}'.format(np.percentile(errors, 10)))
    print('    min        = {:e}'.format(np.min(errors)))


losses = []  # per-bus squared errors pooled over the whole test set
maes = []    # per-bus absolute errors pooled over the whole test set

for a, x, nets in interface.get_test_batch():

    # Perform prediction
    x_norm = normalizer(x)
    y_hat = h2mgnode.forward_batch(h2mgnode.weights, a, x_norm)
    y_post = postprocessor(y_hat)

    # Get ground truth
    interface.run_load_flow_batch(nets)
    y_truth = interface.get_features_dict(nets, {'bus': ['res_vm_pu']})

    # Compute metrics. The squared error is stored in ``sq_err`` rather than
    # ``loss`` so the training loss function of that name is not clobbered.
    err = np.asarray(y_post['bus']['res_vm_pu'] - y_truth['bus']['res_vm_pu'])
    sq_err = err ** 2
    mae = np.abs(err)
    # Flatten so every bus of every sample is pooled whatever the batch shape.
    losses.extend(np.ravel(sq_err))
    maes.extend(np.ravel(mae))

if not losses:
    raise RuntimeError('interface.get_test_batch() yielded no data; nothing to evaluate.')

print_error_summary('Loss', losses)
print('')
print_error_summary('MAE', maes)

# Plot prediction against ground truth

In [None]:
# Predict on the first test batch.
a, x, nets = next(iter(interface.test))
x_norm = normalizer(x)
y_hat = h2mgnode.forward_batch(h2mgnode.weights, a, x_norm)
y_post = postprocessor(y_hat)

# Get ground truth
interface.run_load_flow_batch(nets)
y_truth = interface.get_features_dict(nets, {'bus': ['res_vm_pu']})

# Compare results: flatten so every bus of every sample becomes one point.
vm_truth = np.ravel(np.asarray(y_truth['bus']['res_vm_pu']))
vm_pred = np.ravel(np.asarray(y_post['bus']['res_vm_pu']))

fig, ax = plt.subplots()
ax.scatter(vm_truth, vm_pred)
# y = x reference: points on this line are perfect predictions.
lo = min(vm_truth.min(), vm_pred.min())
hi = max(vm_truth.max(), vm_pred.max())
ax.plot([lo, hi], [lo, hi], 'k--', linewidth=1, label='y = x')
ax.set_xlabel('Ground truth')
ax.set_ylabel('Prediction')
ax.set_title('Bus voltage magnitude (res_vm_pu): prediction vs ground truth')
ax.legend()
plt.show()

# Plot the evolution of latent variables

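The H2MGNODE architecture is based on the interaction of latent variables defined at both hyper-edges and addresses (namely $h_e$ and $h_v$). We propose to visualize how these latent variables evolve by integrating the learned dynamics over the time interval $[0, 1]$ and plotting some of their components, both against time and against each other (phase portraits).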

In [None]:
# Time grid on which latent trajectories are sampled: 50 points in [0, 1].
start_and_end_times = ml.linspace(0., 1., 50)


def odenet(params, init_state):
    """Integrate ``h2mgnode.dynamics`` from ``init_state`` over ``start_and_end_times``.

    Returns the ``ml.odeint`` output, i.e. the latent state at every time of
    the grid rather than only the final one.
    """
    return ml.odeint(h2mgnode.dynamics, init_state, start_and_end_times, params)


# Vectorise over axis 0 of ``init_state``; ``params`` are shared by all samples.
batched_odenet = ml.vmap(odenet, in_axes=(None, 0))


def plot_latent_vs_time(states, title, component=0, sample=0):
    """Plot one latent component of every object against integration time.

    ``states`` is indexed as ``states[sample, :, :, component]``; presumably
    the axes are (batch, time, object, feature) -- TODO confirm.
    """
    fig, ax = plt.subplots()
    ax.plot(start_and_end_times, states[sample, :, :, component])
    ax.set_xlabel('t')
    ax.set_ylabel('component {}'.format(component))
    ax.set_title(title)
    plt.show()


def plot_latent_phase(states, title, comp_x, comp_y, sample=0):
    """Plot latent component ``comp_y`` against ``comp_x`` along each trajectory."""
    fig, ax = plt.subplots()
    ax.plot(states[sample, :, :, comp_x], states[sample, :, :, comp_y])
    ax.set_xlabel('component {}'.format(comp_x))
    ax.set_ylabel('component {}'.format(comp_y))
    ax.set_title(title)
    plt.show()


a, x, nets = next(iter(interface.test))
x_norm = normalizer(x)
init_state = h2mgnode.init_state_batch(a, x_norm)
intermediate_states = batched_odenet(h2mgnode.weights, init_state)

h_e_bus = intermediate_states['h_e']['bus']
h_v = intermediate_states['h_v']

plot_latent_vs_time(h_e_bus, '$h_e$ (bus): trajectories', component=0)
plot_latent_phase(h_e_bus, '$h_e$ (bus): phase portrait', comp_x=0, comp_y=2)
plot_latent_vs_time(h_v, '$h_v$: trajectories', component=0)
plot_latent_phase(h_v, '$h_v$: phase portrait', comp_x=0, comp_y=1)