<a href="https://colab.research.google.com/github/ctorney/learning-to-simulate-tf2/blob/main/test-files/ProcessFishData.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ### Imports { form-width: "30%" }

import os, sys
import numpy as np
from math import *
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import animation, rc
import datetime
import json
import random

import tensorflow as tf

import scipy as sp
from scipy import stats

import pickle

import functools
from tqdm import tqdm

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input, Concatenate
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K

import tensorflow_probability as tfp

from spektral.layers import ECCConv, GlobalAvgPool, MessagePassing, XENetConv, GlobalAttentionPool, GlobalMaxPool, GlobalSumPool,GlobalAttnSumPool
# Short aliases for the TensorFlow Probability sub-packages.
tfpl = tfp.layers
tfd = tfp.distributions

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old spellings. Use whichever name this
# installation provides so the cell runs on old and new Matplotlib alike.
for _style in ('ggplot', 'seaborn-paper', 'seaborn-whitegrid'):
    if _style not in plt.style.available:
        _style = _style.replace('seaborn-', 'seaborn-v0_8-', 1)
    plt.style.use(_style)

In [None]:
# `import scipy as sp` alone does not guarantee that the `io` subpackage is
# loaded (older SciPy releases raise AttributeError on `sp.io`), so import it
# explicitly before use.
import scipy.io

mat = scipy.io.loadmat('test-files/fish_and_frames_alpha (1).mat')
# The 'frames' entry holds the per-frame tracking records used below
# (fields such as 'px', 'py', 'vx', 'vy', 'speed' and 'onfish').
full_data = mat['frames']

In [None]:
# Gather the 'speed' entries of every frame (skipping the first two frames,
# as the rest of the notebook does) and plot their distribution.
vxs = []
for d in full_data['speed'][2:]:
    for i in d:
        # ravel instead of squeeze: a frame with a single fish would squeeze
        # to a 0-d array, which extend() cannot iterate.
        vxs.extend(np.ravel(i))
print(len(vxs))
plt.hist(vxs,bins=20)

In [None]:
#@title ### Metadata { form-width: "30%" }
n_frames = len(full_data['px'])

# Smallest and largest fish id seen in any frame.
max_id = 0
min_id = 100000000
# The first two frames are skipped, so the loop visits n_frames-2 entries;
# the progress bar total has to match or it never reaches 100%.
for d in tqdm(full_data['onfish'][2:],total=n_frames-2):
    # ravel instead of squeeze so a frame with a single id stays iterable
    for i in np.ravel(d[0]):
        if i> max_id: max_id = i
        if i< min_id: min_id = i

total_id = max_id-min_id+1

print(f'max_id = {max_id}, min_id = {min_id}, total_id = {total_id}')

In [None]:
#@title ### Datatypes { form-width: "30%" }

# Speeds with a larger magnitude than this are replaced by NaN
# (presumably tracking glitches -- confirm the threshold with the data owner).
SPEED_LIMIT = 5

# Normalise every per-fish entry to a float in place: None becomes NaN and
# out-of-range speeds are blanked. Rows with NaN are dropped later by wind().
for key in ['px','py','vx','vy','speed']:
    for layer in tqdm(full_data[key][2:], total=n_frames-2):
        # layer[0] is the frame's own array object, so writing into it
        # updates full_data itself.
        frame_values = layer[0]
        # Iterate over the row count rather than over a squeezed copy, which
        # would be 0-d (not iterable) for a frame with exactly one fish.
        for i in range(frame_values.shape[0]):
            raw = frame_values[i,0]
            value = np.nan if raw is None else float(raw)
            if key=='speed' and abs(value)>SPEED_LIMIT:
                value = np.nan
            frame_values[i,0] = value

In [None]:
WINDOW_SIZE=3  # number of consecutive frames per window (the `w` passed to wind / split_targets)
REPEATS=10  # argument of Dataset.repeat() in the input pipeline
BATCH_SIZE=128  # windows per ragged batch
DOMAIN_SIZE=(1920,1080)  # (Lx, Ly) extent handed to the graph pre-processing; presumably the video frame size in pixels -- TODO confirm
domain_max = max(DOMAIN_SIZE)  # larger of the two domain extents
#append_speeds=True
reject_length = 40  # windows with fewer than this many fish are discarded

def wind(data,i,w,append_speed=True):
    """Assemble one window of `w` consecutive frames for the fish present in all of them.

    Args:
        data: mapping / structured array with the fields 'onfish', 'px', 'py',
            'vx', 'vy' and 'speed'; ``data[field][frame][0]`` must be the
            per-fish array of that frame.
        i: index of the first frame of the window, counted after the two
            leading frames that are always skipped.
        w: number of consecutive frames in the window.
        append_speed: unused; kept so existing calls keep working.

    Returns:
        Float array of shape ``(w, n, 4)`` holding ``px, py, vx*speed,
        vy*speed`` for the ``n`` fish ids that occur in every frame of the
        window and have no NaN in any frame or feature. For ``w >= 2`` the
        fish are ordered by ascending id.
    """
    first_frame = 2 + i
    # ravel (rather than squeeze) so a frame with a single fish stays 1-D
    ids_per_frame = [np.ravel(data['onfish'][first_frame + j][0]) for j in range(w)]

    # ids that are tracked in every frame of the window
    true_indices = functools.reduce(np.intersect1d, ids_per_frame)
    n = len(true_indices)
    arr = np.zeros((w, n, 4))
    for j, frame_ids in enumerate(ids_per_frame):
        frame = first_frame + j
        # Row of each shared id inside this frame. Searching through an
        # argsort keeps the lookup correct even when the frame's id list is
        # not sorted; a bare searchsorted would silently pick wrong rows.
        order = np.argsort(frame_ids, kind='stable')
        idx = order[np.searchsorted(frame_ids, true_indices, sorter=order)]

        speed = np.ravel(data['speed'][frame][0])[idx]
        arr[j,:,0] = np.ravel(data['px'][frame][0])[idx]
        arr[j,:,1] = np.ravel(data['py'][frame][0])[idx]
        # the stored vx / vy are scaled by the speed column
        arr[j,:,2] = np.ravel(data['vx'][frame][0])[idx] * speed
        arr[j,:,3] = np.ravel(data['vy'][frame][0])[idx] * speed

    # drop every fish that has a NaN in any frame or any feature
    arr = arr[:, ~np.isnan(arr).any(axis=(0,2)), :]

    return arr

# Slide a WINDOW_SIZE-frame window over the recording and keep only the
# windows that still contain at least `reject_length` fish.
n_candidate_windows = n_frames-WINDOW_SIZE-1
windows_as_list = []
for start in tqdm(range(n_candidate_windows), total=n_candidate_windows):
    candidate = wind(full_data, start, WINDOW_SIZE)
    if candidate.shape[1] < reject_length:
        continue
    windows_as_list.append(candidate)
    
# windows=[]
# for i in tqdm(windows_as_list):
#     windows.append(tf.ragged.constant(i))

In [None]:
# Distribution of the number of fish per retained window.
lens = np.array([window.shape[1] for window in windows_as_list], dtype=float)
plt.hist(lens,bins=20)

In [None]:
# Keep only the windows that contain at least `reject_length` fish.
windows_as_list = [kept for kept in windows_as_list if kept.shape[1] >= reject_length]

In [None]:
# Shell escape: run the nvidia-smi command-line tool and show its output in the notebook.
!nvidia-smi

In [None]:
#@title ### Naive Parse Graph Function { form-width: "30%"}
#@markdown Edge distance not normalised to L

#@tf.function
def naive_pre_process_function(X, V, _Lx=DOMAIN_SIZE[0], _Ly=DOMAIN_SIZE[1], interaction_radius=100, train_mode=True, add_noise=True):
    """Convert a ragged batch of trajectory windows into one stacked graph.

    Every window becomes a graph whose nodes are the agents at the last input
    step; two distinct agents are connected when their distance is below
    `interaction_radius`. The graphs of the batch are stacked into a single
    node list with one block-structured sparse adjacency matrix.

    Args:
        X: ragged positions indexed [batch, steps, agents, (x, y)].
        V: ragged velocities indexed [batch, steps, agents, (vx, vy)].
        _Lx: domain extent along x.
        _Ly: domain extent along y.
        interaction_radius: distance (same units as X) below which two agents
            share an edge.
        train_mode: unused.
        add_noise: unused.

    Returns:
        Tuple ``(output_x, output_a, output_e)``:
          * output_x -- per-node features: the velocities of all input steps
            followed by ``(x, y, _Lx - x, _Ly - y) / max(_Lx, _Ly)``.
          * output_a -- ``tf.sparse.SparseTensor`` of shape
            [total_nodes, total_nodes] with a 1 for every edge.
          * output_e -- per-edge features: distance / max(_Lx, _Ly), cos and
            sin of the neighbour's bearing minus the agent's heading angle,
            and the normalised velocity difference (x and y).
    """
    max_dist = max(_Lx, _Ly)

    # Only the most recent input step is used to build the graph.
    X_current = X[:,-1:,:,:]
    V_current = V[:,-1:,:,:]

    # Fold the singleton step axis away -> [batch, agents, 2]
    X_current = X_current.merge_dims(1,2)
    V_current = V_current.merge_dims(1,2)

    def transpose_tensors_pos(i):
        # Pairwise offsets, entry (r, c) = coordinate[c] - coordinate[r].
        # Offsets larger than half the domain are shifted by one domain length.
        ii = tf.expand_dims(i[...,0], -1)
        dx = tf.linalg.matrix_transpose(ii)-ii
        dx = tf.where(dx>0.5*_Lx, dx-_Lx, dx)
        dx = tf.where(dx<-0.5*_Lx, dx+_Lx, dx)
        dx = tf.RaggedTensor.from_tensor(dx)

        jj = tf.expand_dims(i[...,1], -1)
        dy = tf.linalg.matrix_transpose(jj)-jj
        dy = tf.where(dy>0.5*_Ly, dy-_Ly, dy)
        dy = tf.where(dy<-0.5*_Ly, dy+_Ly, dy)
        dy = tf.RaggedTensor.from_tensor(dy)

        dd = tf.math.sqrt(tf.square(dx)+tf.square(dy))

        return dx, dy, dd

    def transpose_tensors_vel(i):
        # Pairwise velocity differences, normalised to unit length
        # (zero differences stay zero thanks to divide_no_nan).
        ii = tf.expand_dims(i[...,0], -1)
        dx = tf.linalg.matrix_transpose(ii)-ii
        dx = tf.RaggedTensor.from_tensor(dx)

        jj = tf.expand_dims(i[...,1], -1)
        dy = tf.linalg.matrix_transpose(jj)-jj
        dy = tf.RaggedTensor.from_tensor(dy)

        dnorm = tf.math.sqrt(tf.square(dx)+tf.square(dy))

        dx = tf.math.divide_no_nan(dx,dnorm)
        dy = tf.math.divide_no_nan(dy,dnorm)

        return dx, dy

    dx, dy, dist = tf.map_fn(fn=transpose_tensors_pos,
                      elems=X_current,
                      fn_output_signature=(tf.RaggedTensorSpec(shape=[None, None],
                                                               ragged_rank=1,
                                                               dtype=tf.float64),
                                           tf.RaggedTensorSpec(shape=[None, None],
                                                               ragged_rank=1,
                                                               dtype=tf.float64),
                                           tf.RaggedTensorSpec(shape=[None, None],
                                                               ragged_rank=1,
                                                               dtype=tf.float64))
                      )

    dvx, dvy = tf.map_fn(fn=transpose_tensors_vel,
                      elems=V_current,
                      fn_output_signature=(tf.RaggedTensorSpec(shape=[None, None],
                                                               ragged_rank=1,
                                                               dtype=tf.float64),
                                           tf.RaggedTensorSpec(shape=[None, None],
                                                               ragged_rank=1,
                                                               dtype=tf.float64))
                      )

    # Node features derived from the position inside the domain.
    bx = _Lx - X_current[...,0:1]
    by = _Ly - X_current[...,1:2]
    boundary_dists = tf.concat([X_current, bx, by], axis=-1)
    boundary_dists = boundary_dists/max_dist

    def heading_angle(x):
        # direction of the velocity vector, kept as a trailing axis of size 1
        return tf.expand_dims(tf.math.atan2(x[...,1],x[...,0]),-1)

    angles = tf.map_fn(fn=heading_angle,
                       elems=V_current,
                       fn_output_signature=tf.RaggedTensorSpec(shape=[None, None],
                                                               ragged_rank=0,
                                                               dtype=tf.float64)
                       )
    angle_to_neigh = tf.math.atan2(dy, dx)

    rel_angle_to_neigh = angle_to_neigh - angles

    def set_diag_func(x):
        # remove self-loops
        x = x.to_tensor()
        x = tf.linalg.set_diag(x, tf.zeros(tf.shape(x)[0],dtype=tf.int32))
        return tf.RaggedTensor.from_tensor(x)

    adj_matrix = tf.where(dist<interaction_radius, tf.ones_like(dist,dtype=tf.int32), tf.zeros_like(dist,dtype=tf.int32))
    adj_matrix = tf.map_fn(fn=set_diag_func,
                           elems=adj_matrix,
                           fn_output_signature=(tf.RaggedTensorSpec(shape=[None,None],
                                                                    ragged_rank=1,
                                                                    dtype=tf.int32)))

    # One row per edge: (graph index, sender, receiver), indices local to the graph.
    sender_recv_list = tf.where(adj_matrix)
    n_node = adj_matrix.row_lengths(axis=1)

    # Column slicing keeps these rank-1 even when there is exactly one edge
    # (tf.squeeze would collapse them to scalars and break the stack below).
    batch_index = sender_recv_list[:, 0]
    s2 = sender_recv_list[:, 1]
    s3 = sender_recv_list[:, 2]

    # The graphs are stacked one after another in output_x, so node k of
    # graph b sits at row (number of nodes in graphs 0..b-1) + k. That shift
    # is the exclusive cumulative sum of the node counts -- not the node
    # count of graph b itself, which would misplace every edge and can point
    # past total_nodes.
    node_offsets = tf.cumsum(tf.cast(n_node, sender_recv_list.dtype), exclusive=True)
    indice_update = tf.gather(node_offsets, batch_index)
    senders = s2 + indice_update
    receivers = s3 + indice_update

    total_nodes = tf.reduce_sum(n_node, axis=0)
    output_a = tf.sparse.SparseTensor(indices=tf.stack([senders,receivers],axis=1), values = tf.ones_like(senders),dense_shape=[total_nodes, total_nodes])

    edge_distance = tf.expand_dims(tf.gather_nd(dist/max_dist,sender_recv_list),-1)
    # bearing of the neighbour minus the heading angle, as cos / sin
    edge_x_distance = tf.expand_dims(tf.gather_nd(tf.math.cos(rel_angle_to_neigh),sender_recv_list),-1)
    edge_y_distance = tf.expand_dims(tf.gather_nd(tf.math.sin(rel_angle_to_neigh),sender_recv_list),-1)
    # normalised velocity difference between the two agents
    edge_x_orientation = tf.expand_dims(tf.gather_nd(dvx,sender_recv_list),-1)
    edge_y_orientation = tf.expand_dims(tf.gather_nd(dvy,sender_recv_list),-1)

    output_e = tf.concat([edge_distance,edge_x_distance,edge_y_distance,edge_x_orientation,edge_y_orientation],axis=-1)

    def vel_transpose_func(x):
        # [steps, agents, 2] -> [agents, steps*2]: one row of past velocities per agent
        x = x.to_tensor()
        x = tf.transpose(x, perm=[1,0,2])
        return tf.reshape(x, (-1, 2*(WINDOW_SIZE-1)))

    node_velocities = tf.map_fn(fn=vel_transpose_func,
                                elems=V,
                                fn_output_signature=tf.RaggedTensorSpec(shape=[None, None],
                                                                        ragged_rank=0,
                                                                        dtype=tf.float64),
                                infer_shape=False
                                )

    output_x = tf.concat([node_velocities, boundary_dists], axis=-1)
    # stack the nodes of all graphs along the first axis
    output_x = output_x.merge_dims(0,1)

    return output_x, output_a, output_e

# Pre-processing function with the domain extent bound to DOMAIN_SIZE.
_domain_kwargs = dict(_Lx=DOMAIN_SIZE[0], _Ly=DOMAIN_SIZE[1])
naive_ppfunc = functools.partial(naive_pre_process_function, **_domain_kwargs)

In [None]:
# Peek at the shape of the first window, if there is one.
if windows_as_list:
    print(windows_as_list[0].shape)

In [None]:
def split_targets(x, w=WINDOW_SIZE):
    """Split one window into model inputs and the prediction target.

    Args:
        x: window indexed [step, agent, feature], where features 0:2 are the
            position columns and 2:4 the velocity columns.
        w: number of steps in the window.

    Returns:
        ``((positions, velocities), targets)``: the first ``w-1`` steps of the
        position and velocity columns, and the velocity columns of the last
        step with a leading axis of length 1.
    """
    history = x[:w-1]
    positions = history[..., 0:2]
    velocities = history[..., 2:4]
    targets = x[-1:, :, 2:4]
    return ((positions, velocities), targets)

split_with_window = functools.partial(
  split_targets,
  w=WINDOW_SIZE)

train_total = 500  # training batches drawn per epoch
valid_total = 20   # validation batches


def _raw_window_dataset(windows):
    """Dataset that yields the (WINDOW_SIZE, n_fish, 4) arrays in `windows` one by one."""
    return tf.data.Dataset.from_generator(lambda: windows,
                                          output_signature=tf.TensorSpec(shape=(WINDOW_SIZE,None,4), dtype=tf.float64))


def _batched_graph_dataset(windows, n_batches):
    """Repeat, shuffle, ragged-batch and graph-pre-process `windows`; yield at most `n_batches` batches."""
    dataset = _raw_window_dataset(windows)
    dataset = dataset.map(split_with_window)
    dataset = dataset.repeat(REPEATS)
    dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
    dataset = dataset.apply(tf.data.experimental.dense_to_ragged_batch(BATCH_SIZE, drop_remainder=True))
    dataset = dataset.take(n_batches)
    return dataset.map(lambda x, y: (naive_ppfunc(x[0], x[1]), y))


# Sanity check: the generator output matches the declared signature.
window_dataset = _raw_window_dataset(windows_as_list)
for i in window_dataset:
    print(i.shape)
    break

# Split the windows BEFORE repeating and shuffling. Taking/skipping batches
# from a single stream that is repeated and reshuffled on every iteration
# would hand the same windows to both training and validation, so the
# validation loss would not measure held-out performance.
n_windows = len(windows_as_list)
valid_share = valid_total / (train_total + valid_total)
# at least enough windows to fill one batch once they are repeated
min_valid_windows = -(-BATCH_SIZE // REPEATS)
n_valid_windows = max(int(round(n_windows * valid_share)), min_valid_windows)
n_valid_windows = min(n_valid_windows, n_windows)
# Consecutive windows share frames; leaving WINDOW_SIZE-1 windows unused
# between the two parts guarantees that no frame appears in both.
split_gap = WINDOW_SIZE - 1
valid_windows = windows_as_list[n_windows - n_valid_windows:]
train_windows = windows_as_list[:max(n_windows - n_valid_windows - split_gap, 0)]
print(f'{len(train_windows)} training windows, {len(valid_windows)} validation windows')

train_dataset = _batched_graph_dataset(train_windows, train_total)
valid_dataset = _batched_graph_dataset(valid_windows, valid_total)
# for i in window_dataset:
#     print(i[1])
#     break

#print(len(list(window_dataset)))

#valid_dataset = window_dataset.take(valid_total)
#train_dataset = window_dataset.skip(valid_total)
#train_total = train_dataset.reduce(0, lambda x, _: x + 1)

In [None]:
#@title ### Model { form-width: "30%" }

min_lr = 1e-6
# Learning rate decays exponentially from (1e-3 - min_lr) by a factor 0.1 every 5e6 steps.
# NOTE(review): min_lr is subtracted from the initial rate, but the matching
# "+ min_lr" at the end of the statement is commented out, so the floor is
# never added back -- confirm whether that is intended.
lr = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=1e-3 - min_lr,
                                decay_steps=int(5e6),
                                decay_rate=0.1) #+ min_lr
MLP_SIZE=32  # channel width used by the XENetConv layers of GNNNet
n_out=2  # passed to GNNNet(n_out=...) when the model is built

class GNNNet(Model):
    """Graph network that maps (nodes, adjacency, edges) to a per-node Normal distribution.

    One XENetConv layer encodes the inputs, a second XENetConv layer (shared
    weights) is applied `mp_steps` times, and two Dense heads produce the
    parameters that are fed to a tfp IndependentNormal layer.
    """

    def __init__(self,n_out=4,mp_steps=2):
        """
        Args:
            n_out: number of output dimensions per node.
            mp_steps: how many times the shared processing layer is applied.
        """
        super().__init__()

        def make_xenet(activation):
            # both graph layers share the same sizes and differ only in activation
            return XENetConv([MLP_SIZE,MLP_SIZE], MLP_SIZE, 2*MLP_SIZE,
                             use_bias=False,
                             node_activation=activation,
                             edge_activation=activation)

        self.encoder = make_xenet("tanh")
        self.process = make_xenet("relu")

        # Head for the first n_out distribution parameters.
        self.mean_decoder = Dense(n_out, activation="linear",use_bias=False)
        # Head for the remaining n_out parameters: starts as a constant (zero
        # kernel, bias 5) and carries an L2 activity regulariser.
        self.std_decoder = Dense(n_out,
                                 activation="linear",
                                 kernel_initializer=tf.keras.initializers.Zeros(),
                                 bias_initializer=tf.keras.initializers.Constant(5.),
                                 activity_regularizer=tf.keras.regularizers.l2(1000))

        self.distribution = tfp.layers.IndependentNormal(n_out)
        self.mp_steps=mp_steps

    def call(self, inputs, train_mode=True):
        """Run the network on one stacked graph.

        Args:
            inputs: tuple ``(x, a, e)`` of node features, adjacency and edge
                features.
            train_mode: unused.

        Returns:
            The output of the IndependentNormal layer for all nodes.
        """
        nodes, adjacency, edges = inputs
        nodes, edges = self.encoder([nodes, adjacency, edges])
        for _ in range(self.mp_steps):
            nodes, edges = self.process([nodes, adjacency, edges])
        mu = self.mean_decoder(nodes)
        std = self.std_decoder(nodes)
        params = tf.keras.layers.concatenate([mu,std])
        return self.distribution(params)


# Instantiate the network and an Adam optimiser driven by the `lr` schedule.
model = GNNNet(n_out=n_out)
optimizer = Adam(learning_rate=lr)
#loss_fn = MeanSquaredError()

def loss_fn(target,predicted_dist):
    """Negative log-likelihood of `target` under `predicted_dist`, summed over all entries.

    Args:
        target: ragged tensor whose first three axes are merged into one
            before it is scored.
        predicted_dist: distribution object exposing ``log_prob``.
    """
    flat_target = target.merge_dims(0,2)
    log_likelihood = predicted_dist.log_prob(flat_target)
    return -tf.reduce_sum(log_likelihood)

In [None]:
# Restore previously saved weights; expects a 'fish_weights_1' checkpoint in the working directory.
model.load_weights('fish_weights_1')

In [None]:
#@title ### Train Step

#@tf.function()#input_signature=[[tf.TensorSpec(shape=(None,None,2), dtype=tf.float32),tf.TensorSpec(shape=(None,None,2), dtype=tf.float32),tf.TensorSpec(shape=(None,None,2), dtype=tf.float32)],tf.TensorSpec(shape=(None,4), dtype=tf.float32)], experimental_relax_shapes=True)
def train_step(inputs, target):
    """Run one optimisation step on a single batch.

    Args:
        inputs: model input tuple for the batch.
        target: ragged target tensor; cast to float32 before scoring.

    Returns:
        ``(loss, regularisation)``: the total loss that was minimised
        (negative log-likelihood plus the model's regularisation losses) and
        the sum of the regularisation losses alone.
    """
    with tf.GradientTape() as tape:
        predicted_dist = model(inputs, training=True)
        nll = loss_fn(tf.cast(target, dtype=tf.float32), predicted_dist)
        loss = nll + sum(model.losses)
    # read the variables only after the forward pass, once the model is built
    variables = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
    return loss, sum(model.losses)

In [None]:
#@title ### Train Model { form-width: "30%" }

step = losses = 0
epochs = 30  # Number of training epochs
batch_size = BATCH_SIZE  # Batch size
print(f"total = {train_total}, batch size = {batch_size}")
divisor=20  # refresh the progress-bar text every `divisor` steps
# per-step total loss and regularisation loss, one row per epoch
loss_values = np.zeros((epochs,train_total))
reg_values = np.zeros((epochs,train_total))

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    pbar = tqdm(enumerate(train_dataset), total=train_total)

    for step, batch in pbar:
        inputs, target = batch
        losses, regs = train_step(inputs,target)
        loss_values[epoch,step]=losses.numpy()
        reg_values[epoch,step]=regs.numpy()
        if step%divisor==0:
            # running mean over the most recent steps of this epoch
            recent = slice(max(0,step-divisor), step+1)
            # Adjacent f-strings instead of a backslash continuation, which
            # would embed the source indentation in the displayed text.
            pbar.set_description(f"loss {np.mean(loss_values[epoch,recent])}, "
                                 f"regs {np.mean(reg_values[epoch,recent])}")
    print(f'Epoch Averages: -Loss {np.mean(loss_values[epoch,0:step+1])} '
          f'-Regs {np.mean(reg_values[epoch,0:step+1])}')

In [None]:
# Save the trained weights under the prefix 'fish_weights_2' (a different prefix from the 'fish_weights_1' checkpoint loaded earlier).
model.save_weights('fish_weights_2')

In [None]:
#@title ### Loss Plot

# Number of leading training steps left out of the plot
# (presumably to hide the initial transient -- adjust as needed).
skip_steps = 100

# Concatenate the per-epoch rows into one curve per quantity.
loss = loss_values.flatten()
reg = reg_values.flatten()
plt.plot(loss[skip_steps:], label='total loss')
plt.plot(reg[skip_steps:], label='regularisation loss')
plt.xlabel(f'training step (first {skip_steps} omitted)')
plt.ylabel('loss')
plt.legend()
plt.yscale('log')

In [None]:
#@title ### Validation

pred_list = []      # one sample from the predicted distribution per batch
true_values = []    # matching targets, flattened to [1, nodes, 2]
valid_loss=0
c=0
max_valid_batches = 50  # hard cap on the number of batches evaluated

for databatch in tqdm(valid_dataset, total=valid_total):

    target = databatch[1]
    true_values.append(tf.expand_dims(target.merge_dims(0,-2),0).numpy())

    predictions = model(databatch[0],train_mode=False)
    pred_list.append((predictions.sample(1).numpy()))
    target = tf.cast(target, dtype=tf.float32)
    loss_value = loss_fn(target, predictions).numpy()
    valid_loss+= loss_value
    c+=1
    if c>=max_valid_batches:
        break

# Guard against an empty validation set instead of dividing by zero.
if c:
    print('validation loss', valid_loss/c)
else:
    print('validation loss', float('nan'), '(no validation batches were produced)')
#print(pred_list)

In [None]:
from IPython.core.pylabtools import figsize
#@title ### Sample Predictions v Targets Plot { form-width: "30%" }
# plt.plot(tf.expand_dims(target.merge_dims(0,-2),0).numpy()[0,:,1],
#          predictions.sample(1).numpy()[0,:,1],
#          '.')
# plt.ylim((-1,1))
#target.numpy().shape

# Scatter one sample of the predicted distribution against the target for the
# last validation batch, one panel per output component.
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(10,4))
for panel, component in ((ax1, 0), (ax2, 1)):
    observed = tf.expand_dims(target.merge_dims(0,-2),0).numpy()[0,:,component]
    sampled = predictions.sample(1).numpy()[0,:,component]
    panel.plot(observed, sampled, '.')
ax1.set_xlim(-5,5)
ax1.set_ylim(-5,5)  # the y axis is shared, so this applies to both panels
ax2.set_xlim(-5,5)
ax1.set_title('X-coords')
ax2.set_title('Y-coords')
#ax2.scatter(x, y)

In [None]:
#@title ### Predictions v Targets Mean

fig, axs = plt.subplots(1,2, figsize=(16, 8), facecolor='w', edgecolor='k')
print(pred_list[0].shape)
axs = axs.ravel()
for pred_i, ax in enumerate(axs):
    # Flatten all validation batches for this component; both series go
    # through the same tanh(.)/0.7 transform before binning.
    pred_vals = np.concatenate([(np.tanh(batch[:,:,pred_i])/0.7).squeeze() for batch in pred_list])
    true_vals = np.concatenate([(np.tanh(batch[:,:,pred_i])/0.7).squeeze() for batch in true_values])

    # Mean and standard deviation of the prediction inside 100 bins of the true value.
    binned = functools.partial(stats.binned_statistic, true_vals, pred_vals, bins=100, range=(-1,1))
    mean_result = binned()
    bin_means = mean_result.statistic
    bin_edges = mean_result.bin_edges
    bin_width = bin_edges[1] - bin_edges[0]
    bin_centers = bin_edges[1:] - bin_width/2

    std_result = binned(statistic='std')
    bin_stds = std_result.statistic
    bin_edges = std_result.bin_edges

    ax.plot(bin_centers,bin_means,c='C0')
    ax.fill_between(bin_centers,bin_means-bin_stds,bin_means+bin_stds,color='C0',alpha=0.5)

    # dashed identity line for reference
    xx = np.linspace(bin_edges.min(),bin_edges.max(),10)
    ax.plot(xx,xx,c='k',ls='--')

    ax.set_ylabel('GNN prediction of parameter')
    ax.set_xlabel('True parameter that generated the microstate')


#plt.savefig('gnn_' + str(run) + '.png',dpi=300)
#plt.show()