# Training a graph neural network to imitate a simulator

In this notebook, we explain how to use our package to train a simple neural network to imitate the output of an AC power flow simulator.

In [1]:
import numpy as np
import tqdm
import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt

plt.style.use('dark_background')
mpl.rcParams['axes.prop_cycle'] = plt.cycler("color", plt.cm.tab10.colors)

%load_ext autoreload
%autoreload 2

import sys; sys.path.insert(0, '../../..')
import ml4ps as mp

from torch.utils.data import DataLoader


## Downloading a dataset

First, we need a dataset. We use a small dataset of power grids derived from the case60nordic file (also known as nordic32), randomly generated using [powerdatagen](https://github.com/bdonon/powerdatagen).

The dataset is available on Zenodo [here](https://zenodo.org/record/7077699). The following code downloads it if it is not already present, and does nothing otherwise. Please be patient, as the download may take several minutes (though no more than 10).

In [2]:
%%bash
if [ ! -d data/case60/ ]
then
    zenodo_get '10.5281/zenodo.7077699' -o data/
    unzip -qq data/case60.zip -d data/
    rm data/case60.zip data/md5sums.txt
fi

## Backend instantiation

We need to instantiate a backend, which is used to read power grid data. In more complex problems, it is also used to perform power grid simulations.

In this case, we are considering a dataset of .json files that can be read by pandapower. We thus choose the backend that uses pandapower.

In [3]:
backend = mp.PandaPowerBackend()

In [4]:
train_dir = '../../../../powerdatagen/data/case60nordic_small/train'
train_dir_pkl = '../../../../powerdatagen/data/case60nordic_small/train_pkl'

#train_dir = 'data/case60/train'
#train_dir_pkl = 'data/case60/train_pkl'

In [5]:
# Uncomment and run once to create the pickled copy of the train set that is loaded below.
#mp.pickle_dataset(train_dir, train_dir_pkl, backend=backend)

## Building a normalizer

In [6]:
normalizer = mp.Normalizer(data_dir=train_dir, backend=backend, n_samples=100)

Loading power grids.: 100%|██████████| 100/100 [00:05<00:00, 19.71it/s]
Extracting features.: 100%|██████████| 100/100 [00:00<00:00, 120.57it/s]
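The exact transformation applied by `mp.Normalizer` is not shown in this notebook. As a rough mental model only, the sketch below fits a per-feature standardization on a few samples and applies it to nested feature dictionaries. The helpers `fit_stats` and `apply_stats` are hypothetical, and plain mean/std scaling is an assumption, not necessarily what the package does.

```python
import numpy as np

def fit_stats(samples):
    """Compute (mean, std) for every (object class, feature) pair over all samples."""
    stats = {}
    for obj_class in samples[0]:
        stats[obj_class] = {}
        for feature in samples[0][obj_class]:
            values = np.concatenate([s[obj_class][feature] for s in samples])
            # A small epsilon avoids division by zero for constant features.
            stats[obj_class][feature] = (values.mean(), values.std() + 1e-8)
    return stats

def apply_stats(sample, stats):
    """Standardize every feature of one sample with the fitted statistics."""
    out = {}
    for obj_class, features in sample.items():
        out[obj_class] = {}
        for feature, values in features.items():
            mean, std = stats[obj_class][feature]
            out[obj_class][feature] = (values - mean) / std
    return out

# Synthetic stand-in for 100 power grids with 5 buses each (made-up values).
rng = np.random.default_rng(0)
demo_samples = [{'bus': {'vn_kv': rng.normal(130., 20., size=5)}} for _ in range(100)]
demo_stats = fit_stats(demo_samples)
demo_normalized = [apply_stats(s, demo_stats) for s in demo_samples]
all_values = np.concatenate([s['bus']['vn_kv'] for s in demo_normalized])
```

After this transformation the pooled feature values have approximately zero mean and unit standard deviation, which is the usual motivation for normalizing inputs of very different magnitudes (kV, MW, percentages).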


## Building a train set and a data loader

The normalizer is passed to the dataset, so that the features of each power grid are normalized before they reach the neural network.

In [7]:
train_set = mp.PowerGridDataset(data_dir=train_dir_pkl, backend=backend, normalizer=normalizer,
                                pickle=True, load_in_memory=True)
#train_set = mp.PowerGridDataset(data_dir=train_dir, backend=backend, normalizer=normalizer)
train_loader = DataLoader(train_set,
                          batch_size=8,
                          shuffle=True,
                          collate_fn=mp.collate_power_grid)

Loading the dataset in memory.: 100%|██████████| 10000/10000 [00:54<00:00, 182.67it/s]
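The `collate_fn` argument tells the `DataLoader` how to merge individual samples into a batch. The sketch below shows one plausible way to do this for `(features, net)` pairs whose features are nested dictionaries. The function `collate_nested` is hypothetical and is not the implementation of `mp.collate_power_grid`. It also assumes that every grid has the same number of objects per class, so that arrays can be stacked.

```python
import numpy as np

def collate_nested(batch):
    """Stack a list of (features, net) pairs into batched features and a list of nets."""
    features_list, nets = zip(*batch)
    batched = {}
    for obj_class in features_list[0]:
        batched[obj_class] = {
            # New leading axis: one entry per sample of the batch.
            name: np.stack([f[obj_class][name] for f in features_list])
            for name in features_list[0][obj_class]
        }
    return batched, list(nets)

# Eight fake samples with three buses each; strings stand in for the network objects.
demo_batch = [({'bus': {'vn_kv': np.full(3, float(i))}}, 'net_{}'.format(i)) for i in range(8)]
x_demo, nets_demo = collate_nested(demo_batch)
```

Keeping the network objects in a plain list alongside the stacked arrays is what allows the ground truth to be extracted from them later in the training loop.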


## Building a Hyper Heterogeneous Multi Graph Neural Ordinary Differential Equation

First of all, we need to tell the neural network which features it should take as input, and which features it should output. In this case, we want it to predict the voltage magnitude at each bus.

Moreover, since we are working with a graph neural network, we need to tell it which fields hold object addresses, i.e. the buses each object is connected to. These addresses define the structure of the graph.

In [8]:
local_input_feature_names = {
    'bus': ['in_service', 'max_vm_pu', 'min_vm_pu', 'vn_kv'],
    'load': ['const_i_percent', 'const_z_percent', 'controllable', 'in_service', 
                                'p_mw', 'q_mvar', 'scaling', 'sn_mva'],
    'sgen': ['controllable', 'in_service', 'p_mw', 'q_mvar', 'scaling', 'sn_mva',
                                'current_source'],
    'gen': ['controllable', 'in_service', 'p_mw', 'scaling', 'sn_mva', 'vm_pu',
                'slack', 'max_p_mw', 'min_p_mw', 'max_q_mvar', 'min_q_mvar', 'slack_weight'],
    'shunt': ['q_mvar', 'p_mw', 'vn_kv', 'step', 'max_step', 'in_service'],
    'ext_grid': ['in_service', 'va_degree', 'vm_pu', 'max_p_mw', 'min_p_mw', 'max_q_mvar',
                                'min_q_mvar', 'slack_weight'],
    'line': ['c_nf_per_km', 'df', 'g_us_per_km', 'in_service', 'length_km', 'max_i_ka',
                                'max_loading_percent', 'parallel', 'r_ohm_per_km', 'x_ohm_per_km'],
    'trafo': ['df', 'i0_percent', 'in_service', 'max_loading_percent', 'parallel', 
                                'pfe_kw', 'shift_degree', 'sn_mva', 'tap_max', 'tap_neutral', 'tap_min',
                                'tap_phase_shifter', 'tap_pos', 'tap_side', 'tap_step_degree', 
                                'tap_step_percent', 'vn_hv_kv', 'vn_lv_kv', 'vk_percent', 'vkr_percent']
}
local_output_feature_names = {'bus': ['res_vm_pu']}
local_address_names = {
    'bus': ['id'],
    'load': ['bus'],
    'sgen': ['bus'],
    'gen': ['bus'],
    'shunt': ['bus'],
    'ext_grid': ['bus'],
    'line': ['from_bus', 'to_bus'],
    'trafo': ['hv_bus', 'lv_bus']
}
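To make the role of addresses more concrete, here is a generic message-passing sketch: each object sends a value to the buses listed in its address fields, and each bus sums what it receives. This only illustrates the idea with NumPy on a made-up 4-bus grid. It is not the internal computation of `H2MGNODE`, and all arrays below are invented for the example.

```python
import numpy as np

n_bus = 4
# Address fields: a line touches two buses, a load touches one.
demo_line = {'from_bus': np.array([0, 1, 2]), 'to_bus': np.array([1, 2, 3])}
demo_load = {'bus': np.array([1, 3])}

# One scalar "message" per object (in a GNN these would be learned vectors).
line_message = np.array([1.0, 2.0, 3.0])
load_message = np.array([10.0, 20.0])

# Sum incoming messages at each bus; np.add.at handles repeated indices correctly.
bus_aggregate = np.zeros(n_bus)
np.add.at(bus_aggregate, demo_line['from_bus'], line_message)
np.add.at(bus_aggregate, demo_line['to_bus'], line_message)
np.add.at(bus_aggregate, demo_load['bus'], load_message)
# bus_aggregate is now [1., 13., 5., 23.]
```

Because the aggregation only depends on the address arrays, the same weights can be applied to grids with different topologies.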

We also need to pass a sample `x` to the constructor, so that the model can initialize its weights with the right dimensions.

In [9]:
x, nets = next(iter(train_loader))

In [11]:
h2mgnode = mp.H2MGNODE(x=x,
                       local_input_feature_names=local_input_feature_names,
                       local_output_feature_names=local_output_feature_names,
                       local_address_names=local_address_names,
                       phi_hidden_dimensions=[64],
                       psi_hidden_dimensions=[32, 32],
                       phi_scale_init=[[1e-2, 1e-2], [1e-2, 0]],
                       psi_scale_init=[[1e-2, 1e-2], [1e-2, 1e-2], [1e-2, 0]],
                       latent_dimension=8
                      )

In addition, we need to specify post-processing functions, so that our model starts its training in a reasonable range. Here, we know that voltage magnitudes should be around 1 p.u., so we post-process the neural network output by adding an offset of 1.

In [12]:
# functions = {'bus': {'res_vm_pu': [mp.AffineTransform(offset=1.)]}}
# postprocessor = mp.PostProcessor(functions=functions)

In [13]:
postprocessor = mp.PostProcessor(config={'bus':{'res_vm_pu':[['affine', {'offset':1.}]]}})
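The effect of such an offset can be sketched without the package. The helpers `affine` and `postprocess` below are hypothetical stand-ins, not the `mp.PostProcessor` implementation. They apply a list of functions to each output feature of a nested dictionary.

```python
import numpy as np

def affine(values, slope=1., offset=0.):
    """Affine map applied elementwise."""
    return slope * values + offset

def postprocess(y_hat, functions):
    """Apply, in order, the functions registered for each (object class, feature)."""
    out = {}
    for obj_class, features in y_hat.items():
        out[obj_class] = {}
        for name, values in features.items():
            for fn in functions.get(obj_class, {}).get(name, []):
                values = fn(values)
            out[obj_class][name] = values
    return out

# A freshly initialized network typically outputs values close to zero.
raw_demo = {'bus': {'res_vm_pu': np.array([-0.02, 0.0, 0.03])}}
demo_functions = {'bus': {'res_vm_pu': [lambda v: affine(v, offset=1.)]}}
post_demo = postprocess(raw_demo, demo_functions)
# post_demo['bus']['res_vm_pu'] is now [0.98, 1.0, 1.03]
```

With the offset, the initial predictions already lie near 1 p.u., so the network only has to learn small deviations around that value.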

## Training loop

Here, we train our neural network using the Adam optimizer. The loss function is the mean squared error between the neural network prediction and the output of the simulator.

In [14]:
from jax.example_libraries import optimizers

learning_rate = 3e-3
opt_init, opt_update, get_params = optimizers.adam(learning_rate)
opt_state = opt_init(h2mgnode.params)

In [15]:
def loss_function(params, start_state, y):
    # Mean squared error on the post-processed voltage magnitude predictions.
    y_hat = h2mgnode.solve_and_decode_batch(params, start_state)
    y_post = postprocessor(y_hat)
    return jnp.mean((y_post['bus']['res_vm_pu'] - y['bus']['res_vm_pu'])**2)

@jax.jit
def update_jit(params, start_state, y, opt_state, step):
    # One jit-compiled Adam step: compute the loss and its gradients, then update the parameters.
    loss, grads = jax.value_and_grad(loss_function)(params, start_state, y)
    opt_state = opt_update(step, grads, opt_state)
    return get_params(opt_state), opt_state, loss

def update(params, x, y, opt_state, step):
    # The initial state is built outside of the jit-compiled function.
    start_state = h2mgnode.init_state_batch(x)
    return update_jit(params, start_state, y, opt_state, step)
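The same update pattern can be tried in isolation on a toy problem. The sketch below fits a linear model with `jax.example_libraries.optimizers.adam` and `jax.value_and_grad`. The data, the model and every `toy_*` name are made up for illustration, and the names are kept distinct so that running it does not overwrite the notebook's optimizer state.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Synthetic regression problem: targets are a known linear function of the inputs.
toy_inputs = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
toy_true_w = jnp.array([0.5, -1.0, 2.0])
toy_targets = toy_inputs @ toy_true_w

def toy_loss(params, inputs, targets):
    # Mean squared error, as in the notebook's loss function.
    return jnp.mean((inputs @ params - targets) ** 2)

toy_init, toy_opt_update, toy_get_params = optimizers.adam(1e-1)
toy_state = toy_init(jnp.zeros(3))

@jax.jit
def toy_update(i, state, inputs, targets):
    # Read the current parameters, differentiate the loss, apply one Adam step.
    params = toy_get_params(state)
    loss, grads = jax.value_and_grad(toy_loss)(params, inputs, targets)
    return toy_opt_update(i, grads, state), loss

toy_losses = []
for i in range(200):
    toy_state, toy_loss_value = toy_update(i, toy_state, toy_inputs, toy_targets)
    toy_losses.append(float(toy_loss_value))
```

Here the parameters live only inside the optimizer state and are read back with `toy_get_params`, which is a slight variation on the notebook's `update` function that also returns them explicitly.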

In [16]:
step = -1
losses = []

In [17]:
# test_dir = '../../../../powerdatagen/data/case60nordic/test'
# #test_dir = 'data/case60/train'
# test_set = mp.PowerGridDataset(data_dir=test_dir, backend=backend, normalizer=normalizer)
# test_loader = DataLoader(test_set,
#                           batch_size=8,
#                           shuffle=True,
#                           num_workers=2,
#                           collate_fn=mp.collate_power_grid)
# x_fixed, nets = next(iter(test_loader))
# init_state = h2mgnode.init_state_batch(x_fixed)

# from jax.experimental.ode import odeint
# start_and_end_times = jnp.linspace(0., 1., 50)

# def odenet(params, init_state):
#     intermediate_states = odeint(h2mgnode.dynamics, init_state, start_and_end_times, params)
#     return intermediate_states

# batched_odenet = jax.vmap(odenet, in_axes=(None, 0))
# import matplotlib.pyplot as plt
# import os
# #os.mkdir('trajectory')
# #os.mkdir('output')

# y_truth = backend.get_data_batch(nets, feature_names={'bus': ['res_vm_pu']})
# y_truth = np.reshape(y_truth['bus']['res_vm_pu'], [-1])

In [18]:
for epoch in range(10):
    for x, nets in (pbar := tqdm.tqdm(train_loader)):
        step += 1
        
        #backend.run_batch(nets) # AC power flow simulation 
        y = backend.get_data_batch(nets, feature_names={'bus': ['res_vm_pu']}) # Ground truth extraction
        h2mgnode.params, opt_state, loss = update(h2mgnode.params, x, y, opt_state, step)
        
        pbar.set_description("Epoch {}, Loss = {:.2e}".format(epoch, loss))
        losses.append(loss)

Epoch 0, Loss = 5.34e-04:   4%|▍         | 46/1250 [00:32<14:04,  1.43it/s]


KeyboardInterrupt: 

In [None]:
# x, nets = next(iter(train_loader))
# y = backend.get_data_batch(nets, feature_names={'bus': ['res_vm_pu']})
# for _ in tqdm.tqdm(range(10000)):
#     h2mgnode.params, opt_state, loss = update(h2mgnode.params, x, y, opt_state, step)

In [None]:
plt.plot(losses, linewidth=0.2)
plt.yscale('log')
plt.show()

## Testing the model

We now wish to take a look at how well our model performs on the test data. In this notebook we propose to plot the ground truth versus the prediction for a sample of data.

In [None]:
test_dir = '../../../../powerdatagen/data/case60nordic_small/test'
#test_dir = 'data/case60/train'
test_set = mp.PowerGridDataset(data_dir=test_dir, backend=backend, normalizer=normalizer)
test_loader = DataLoader(test_set,
                          batch_size=8,
                          shuffle=True,
                          num_workers=2,
                          collate_fn=mp.collate_power_grid)

In [None]:
x, nets = next(iter(test_loader))

In [None]:
# Perform prediction
y_hat = h2mgnode.forward_batch(h2mgnode.params, x)
y_post = postprocessor(y_hat)
y_post = np.reshape(y_post['bus']['res_vm_pu'], [-1])

# Get ground truth
y_truth = backend.get_data_batch(nets, feature_names={'bus': ['res_vm_pu']})
y_truth = np.reshape(y_truth['bus']['res_vm_pu'], [-1])

# Compare results
plt.scatter(y_truth, y_post, s=0.4)
plt.xlabel('Ground truth')
plt.ylabel('Prediction')
plt.show()

# Pearson correlation between ground truth and prediction
np.corrcoef(y_truth, y_post)[0,1]
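A correlation coefficient alone can hide a systematic bias, so it may help to report a few error metrics next to it. The helper `regression_metrics` below is a hypothetical addition that is not part of the package, and the demo arrays are made-up values standing in for `y_truth` and `y_post`.

```python
import numpy as np

def regression_metrics(truth, prediction):
    """Summary statistics of the prediction error on flattened arrays."""
    error = prediction - truth
    return {
        'mae': float(np.mean(np.abs(error))),
        'rmse': float(np.sqrt(np.mean(error ** 2))),
        'max_abs_error': float(np.max(np.abs(error))),
        'pearson': float(np.corrcoef(truth, prediction)[0, 1]),
    }

truth_demo = np.array([0.98, 1.00, 1.02, 1.05])
prediction_demo = np.array([0.99, 1.00, 1.01, 1.05])
metrics_demo = regression_metrics(truth_demo, prediction_demo)
```

In the notebook, the same function could be called as `regression_metrics(y_truth, np.asarray(y_post))`, with errors expressed in p.u.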

## Bonus: visualization of latent variables

In [None]:
from jax.experimental.ode import odeint
start_and_end_times = jnp.linspace(0., 1., 100)

def odenet(params, init_state):
    intermediate_states = odeint(h2mgnode.dynamics, init_state, start_and_end_times, params,
                                rtol=1e-4, atol=1e-4)
    return intermediate_states

batched_odenet = jax.vmap(odenet, in_axes=(None, 0))
    
x, nets = next(iter(test_loader))
init_state = h2mgnode.init_state_batch(x)
intermediate_states = batched_odenet(h2mgnode.params, init_state)

y_plot = intermediate_states['h_v'][0,:,:,0]
plt.plot(start_and_end_times, y_plot)
plt.show()

x_plot = intermediate_states['h_v'][0,:,:,0]
y_plot = intermediate_states['h_v'][0,:,:,1]
plt.plot(x_plot, y_plot)
plt.show()

x_plot = intermediate_states['h_v'][0,:,:,0]
y_plot = intermediate_states['h_v'][0,:,:,2]
plt.plot(x_plot, y_plot)
plt.show()

x_plot = intermediate_states['h_v'][0,:,:,2]
y_plot = intermediate_states['h_v'][0,:,:,1]
plt.plot(x_plot, y_plot)
plt.show()
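For readers unfamiliar with `jax.experimental.ode.odeint`, the sketch below integrates a simple exponential decay and compares the result with the closed-form solution. The `decay` dynamics and all `demo` names are invented for the example. The calling convention is the one used in the cell above: the dynamics take `(state, t, *args)`, and the result has one entry per requested time.

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint

def decay(state, t, rate):
    # dy/dt = -rate * y, whose exact solution is y0 * exp(-rate * t).
    return -rate * state

demo_times = jnp.linspace(0., 1., 100)
demo_state0 = jnp.array([1.0, 2.0])
demo_rate = jnp.array(3.0)

# The leading axis of the result indexes time: shape (100, 2).
demo_trajectory = odeint(decay, demo_state0, demo_times, demo_rate, rtol=1e-6, atol=1e-6)
demo_exact = demo_state0 * jnp.exp(-demo_rate * demo_times)[:, None]
demo_max_error = float(jnp.max(jnp.abs(demo_trajectory - demo_exact)))
```

Looser tolerances make the solver take fewer steps at the cost of accuracy, which is the trade-off behind the `rtol` and `atol` values passed in the notebook.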

In [None]:
start_and_end_times = jnp.array([0.,1.])
#start_and_end_times = jnp.linspace(0., 1., 50)

def odenet_(params, init_state, atol, rtol):
    intermediate_states = odeint(h2mgnode.dynamics, init_state, start_and_end_times, params,
                                rtol=rtol, atol=atol)
    return intermediate_states

# Batch over init_state only; params and tolerances are shared across the batch.
batched_odenet_ = jax.vmap(odenet_, in_axes=(None, 0, None, None))

# Single-sample version of the loss (no batch dimension).
# Note that this overrides the batched loss_function defined for training.
def loss_function(params, start_state, y):
    y_hat = h2mgnode.solve_and_decode(params, start_state)
    y_post = postprocessor(y_hat)
    return jnp.mean((y_post['bus']['res_vm_pu'] - y['bus']['res_vm_pu'])**2)

x, nets = test_set[0]
init_state = h2mgnode.init_state(x)
y = backend.get_data_network(nets, feature_names={'bus': ['res_vm_pu']})

In [None]:
loss, grads = jax.value_and_grad(loss_function)(h2mgnode.params, init_state, y)

In [None]:
grads
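Gradients returned by `jax.value_and_grad` have the same nested structure as the parameters, so printing them raw can be hard to read. One option is to reduce each leaf to its norm and to check that no value is NaN or infinite. The two-layer toy model below is made up for the example and assumes nothing about the actual layout of `h2mgnode.params`.

```python
import jax
import jax.numpy as jnp

def demo_loss(params, inputs):
    hidden = jnp.tanh(inputs @ params['layer_0']['w'] + params['layer_0']['b'])
    return jnp.mean((hidden @ params['layer_1']['w']) ** 2)

demo_params = {
    'layer_0': {'w': jnp.ones((3, 4)) * 0.1, 'b': jnp.zeros(4)},
    'layer_1': {'w': jnp.ones((4, 1)) * 0.1},
}
demo_inputs = jnp.ones((5, 3))
demo_grads = jax.grad(demo_loss)(demo_params, demo_inputs)

# One norm per parameter array, keeping the nested dictionary structure.
grad_norms = jax.tree_util.tree_map(lambda g: float(jnp.linalg.norm(g)), demo_grads)
# True if every gradient entry is finite.
grads_finite = all(bool(jnp.all(jnp.isfinite(g)))
                   for g in jax.tree_util.tree_leaves(demo_grads))
```

The same two lines could be applied to `grads` from the cell above to spot vanishing, exploding or non-finite gradients at a glance.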

In [None]:
intermediate_states = odeint(h2mgnode.dynamics, init_state, start_and_end_times,
                             h2mgnode.params, rtol=1e-12, atol=1e-12)

In [None]:
intermediate_states['h_v'][-1]

In [None]:
import os

# Assumes `intermediate_states` holds the batched 100-step trajectory
# computed in the first cell of this section.
os.makedirs('latent_traj', exist_ok=True)
for t in range(101):
    x_plot = intermediate_states['h_v'][0,:t,:,0]
    y_plot = intermediate_states['h_v'][0,:t,:,1]
    plt.figure(figsize=[10,10], dpi=200)
    plt.plot(x_plot, y_plot)
    plt.axis('off')
    plt.xlim([-1, 1])
    plt.ylim([-1, 1])
    plt.savefig('latent_traj/step_{}'.format(t))
    plt.show()

In [None]:
#h2mgnode.save('model_0')