This code implements the DGME method proposed and analyzed in the paper ``DEEP BACKWARD AND GALERKIN METHODS FOR LEARNING FINITE
STATE MASTER EQUATIONS'' by Asaf Cohen, Mathieu Laurière and Ethan Zell.

The example solved here corresponds to Example 7.1 in the paper.

In [None]:
import tensorflow as tf
import numpy as np
import math, keras, random, copy
# Fix the TensorFlow and NumPy global seeds so that repeated runs of the
# notebook use the same random draws.
tf.random.set_seed(703)
np.random.seed(703)

In [None]:
# Global, static parameters.
# d is the number of states; it is the default dimension of the dataset
# generator and is read as a module-level name by the training and plotting code.
d=2

In [None]:
class DGMDatasetGenerator4:
  '''
  Generates the training samples (t, x, eta) for the DGM, together with a
  separate batch of terminal-time samples.

  d is the number of states (length of each eta vector) and horizon is the
  terminal time. The most recently generated arrays are kept on the instance.
  '''
  def __init__(self, d=d, horizon=1.):
    self.t_data = None
    self.x_data = None
    self.eta_data = None
    self.t_terminal_data = None
    self.x_terminal_data = None
    self.eta_terminal_data = None
    self.d = d
    self.horizon = horizon

  def uniformly_random_measure(self):
    '''
    Uses exponential random variables to generate a uniformly random probability vector of length d.
    '''
    pre_normalized = np.random.exponential(1, size = (self.d,))
    return pre_normalized / np.sum(pre_normalized)

  def uniformly_random_measure_vec(self, samples):
    '''
    Returns an array of shape (samples, d) whose rows are independent random probability vectors.

    All rows are drawn in a single call; the exponential variables are consumed
    from NumPy's global generator in the same row-by-row order as repeated calls
    to uniformly_random_measure would consume them.
    '''
    pre_normalized = np.random.exponential(1, size = (samples, self.d))
    return pre_normalized / pre_normalized.sum(axis = 1, keepdims = True)

  def generate_dataset(self, samples=1000, terminal_samples = 100):

    '''
    Creates the dataset which, for the DGME, needs t, x, and eta data.

    Interior times are uniform on [0, horizon], states are uniform on
    {0., ..., d-1.} (stored as floats) and measures are random probability
    vectors. Terminal samples all have t equal to horizon.
    Returns (t, x, eta, t_terminal, x_terminal, eta_terminal) as NumPy arrays
    and stores the same arrays on the instance.
    '''

    state_values = [float(m) for m in range(self.d)]

    # The order of the random draws below is kept fixed so that a given seed
    # reproduces the same dataset.
    self.t_data = np.random.uniform(0, self.horizon, size=samples)
    self.x_data = np.random.choice(state_values, size=samples)
    self.eta_data = self.uniformly_random_measure_vec(samples = samples)

    self.t_terminal_data = np.full(shape = (terminal_samples,), fill_value = self.horizon)
    self.x_terminal_data = np.random.choice(state_values, size=terminal_samples)
    self.eta_terminal_data = self.uniformly_random_measure_vec(samples = terminal_samples)

    return (self.t_data, self.x_data, self.eta_data,
            self.t_terminal_data, self.x_terminal_data, self.eta_terminal_data)

  def oversampling(self, oversample_T = True):
    '''
    Apply this function after generate_dataset (and before data_to_tensors, since it
    modifies the NumPy array in place) to push part of the terminal samples past the horizon.
    Sampling outside the domain may improve performance along the domain's boundary.

    When oversample_T is True, a random half of the terminal times is shifted by a
    uniform amount in [0, horizon]. Only the terminal times are changed; the eta
    samples are returned untouched.
    '''
    if oversample_T:
      terminal_samples = self.t_terminal_data.shape[0]
      num_shifted = int(terminal_samples * 0.5) # half of the terminal samples are oversampled

      random_mask = np.zeros(terminal_samples, dtype = bool)
      random_mask[:num_shifted] = True
      np.random.shuffle(random_mask)

      self.t_terminal_data[random_mask] += np.random.uniform(0, self.horizon, size = num_shifted)

    return  self.t_data, self.x_data, self.eta_data, self.t_terminal_data, self.x_terminal_data, self.eta_terminal_data

  def data_to_tensors(self):
    '''
    Converts the six stored arrays to float32 tensors, replaces the stored arrays
    with the tensors and returns them in the same order as generate_dataset.
    '''

    self.t_data = tf.convert_to_tensor(self.t_data, dtype = 'float32')
    self.x_data = tf.convert_to_tensor(self.x_data, dtype = 'float32')
    self.eta_data = tf.convert_to_tensor(self.eta_data, dtype = 'float32')
    self.t_terminal_data = tf.convert_to_tensor(self.t_terminal_data, dtype = 'float32')
    self.x_terminal_data = tf.convert_to_tensor(self.x_terminal_data, dtype = 'float32')
    self.eta_terminal_data = tf.convert_to_tensor(self.eta_terminal_data, dtype = 'float32')

    return  self.t_data, self.x_data, self.eta_data, self.t_terminal_data, self.x_terminal_data, self.eta_terminal_data

In [None]:
class DGMModel(tf.keras.Model):
  '''
  This class defines the neural network model: a stack of sigmoid dense layers
  followed by a single-unit dense layer with ELU activation.
  '''
  def __init__(self, architecture):
    '''
    architecture is a list specifying the number of nodes in each hidden dense layer.
    '''
    super().__init__()
    self.architecture = architecture

    # Every hidden layer is configured identically: sigmoid activation,
    # standard-normal kernel initialisation and zero biases.
    layers = [
        tf.keras.layers.Dense(units=number_of_nodes, activation='sigmoid',
                              kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.),
                              bias_initializer='zeros')
        for number_of_nodes in architecture
    ]
    layers.append(tf.keras.layers.Dense(units=1, activation = 'elu'))
    self.layer_list = layers

  def call(self, t , x, eta):
    '''
    Evaluates the network on a batch.

    t and x each get a trailing axis added and are concatenated with eta along
    axis 1, so eta must be two-dimensional with the same leading size as t and x.
    Returns the output of the last layer, one value per row.
    '''
    t = tf.expand_dims(t, axis = -1)
    x = tf.expand_dims(x, axis = -1)
    result = tf.concat([t, x, eta], 1)
    for layer in self.layer_list:
      result = layer(result)
    return result

Recall from the paper that we are interested in approximately solving the master equation:

$$
\partial_t U(t,x,\eta) = H(x,\Delta_x U(t,\cdot,\eta))+ F(x,\eta) + \sum_{y,z\in [d]} D^\eta_{yz} U(t,x,\eta) \gamma^*_z(y,\Delta_y U(t,\cdot,\eta)) \eta_y,
$$

where in this example:

$$F(x,\eta) = \eta_x,$$

$$H(x,p) := \min_{a} \Big\{\frac{1}{2}|a|^2 + a\cdot p\Big\},$$

and where $\gamma^*$ denotes the minimizer, i.e. the argument $a$ at which the minimum defining the Hamiltonian $H$ is attained. Recall that $\Delta_x b:= (b_y - b_x)_{y\in [d]}$ is a finite difference vector and that $D^\eta_{yz}$ denotes the directional derivative with respect to $\eta$ in the direction $e_z - e_y$, where $(e_x)_{x\in[d]}$ is the standard basis.

In the Loss class below, $F$ is referred to as the mean_field_cost and $H$ is the Hamiltonian.

In [None]:
class Loss():
  r'''
  This class defines the loss and involves the entire PDE: the squared residual
  of the master equation on interior samples plus the squared network output on
  terminal samples.

  model is called as model(t, x, eta). d is the number of states. a, a_l, a_u and
  b are the constants entering a_star and the running cost in Hamiltonian_z1.
  '''

  def __init__(self, model, d, a = 2., a_l = 1., a_u = 3., b = 4.):
    self.model = model
    self.d = d
    self.a = a
    self.a_l = a_l
    self.a_u = a_u
    self.b = b

  def a_star(self, psi_left, psi_right):
    r'''
    The computed value of $\gamma^*$, the minimal argument of the Hamiltonian. The paper derives this formula explicitly.

    Returns a + (psi_left - psi_right) / ((a_u - a_l) * b).
    '''
    numerator = psi_left - psi_right
    return ( numerator / ( (self.a_u - self.a_l) * self.b) ) + self.a

  def Hamiltonian_z1(self, x, z1_state_tensor, psi_output, psi_z1_output):
    '''
    The contribution of the target state z1 to the Hamiltonian, denoted H in the paper.

    psi_output is the network value at the sampled state x and psi_z1_output its
    value at z1. The running cost is only counted for samples whose state differs from z1.
    '''
    a_star_z1 = self.a_star(psi_output, psi_z1_output)
    where_unequal = tf.cast(tf.math.not_equal(x, z1_state_tensor), dtype = 'float32')
    running_cost = self.b * tf.multiply(tf.math.square(a_star_z1 - self.a), where_unequal)
    change_of_state = tf.multiply(a_star_z1, (psi_z1_output - psi_output))
    return running_cost + change_of_state

  def interaction_sum_term_z1_z2(self, z1, z2, eta, eta_derivative, psi_z1, psi_z2):
    '''
    This is an individual term from the sum part of the PDE:
    eta[:, z1] * (eta_derivative[:, z2] - eta_derivative[:, z1]) * a_star(psi_z1, psi_z2).
    '''
    mf = eta[:,int(z1)]
    directional_derivative = eta_derivative[:,int(z2)] - eta_derivative[:,int(z1)]
    control = self.a_star(psi_z1, psi_z2)
    return tf.multiply(tf.multiply(mf, directional_derivative), control)

  def mean_field_cost(self, x, eta):
    '''
    The common cost, denoted F in the paper: for every sample i it returns eta[i, int(x[i])].
    '''
    # One batched gather replaces a Python loop over the rows of the batch.
    state_index = tf.cast(x, dtype = 'int32')
    return tf.cast(tf.gather(eta, state_index, batch_dims = 1), dtype = 'float32')

  def derivatives(self, t, x, eta):

    '''
    The time and measure derivatives are needed to compute the interaction_sum_term_z1_z2 function. This then goes into the sum term of the PDE.

    Returns (time_derivative, eta_derivative), the gradients of the network output
    with respect to t and eta, and also stores them on the instance.
    '''

    # A single forward pass under one tape yields both gradients.
    with tf.GradientTape() as tape:
      tape.watch([t, eta])
      V = self.model(t,x,eta)
    time_derivative, eta_derivative = tape.gradient(V, [t, eta])

    self.time_derivative = time_derivative
    self.eta_derivative = eta_derivative

    return time_derivative, eta_derivative

  def criterion(self, t, x, eta):

    '''
    Combining all the prior terms into the loss. We need two for loops in order to compute the sum term.

    Returns the squared PDE residual per sample (not reduced).
    '''

    time_derivative, eta_derivative = self.derivatives(t, x, eta)
    output = tf.squeeze(self.model(t, x, eta), axis = -1)

    # Evaluate the network once per state instead of once per (z1, z2) pair.
    state_tensors = [z * tf.ones(shape = t.shape, dtype='float32') for z in range(self.d)]
    psi = [tf.squeeze(self.model(t, state_tensor, eta), axis = -1) for state_tensor in state_tensors]

    pde_sum = tf.zeros(t.shape)
    for z1 in range(self.d):
      pde_sum += self.Hamiltonian_z1(x, state_tensors[z1], output, psi[z1])
      for z2 in range(self.d):
        pde_sum += self.interaction_sum_term_z1_z2(z1, z2, eta, eta_derivative, psi[z1], psi[z2])

    mean_field_cost = self.mean_field_cost(x, eta)

    loss_sum = time_derivative + pde_sum + mean_field_cost

    return tf.math.square(loss_sum) # not yet reduced

  def terminal_criterion(self, t_terminal, x_terminal, eta_terminal):

    '''
    We have an additional function to compute the loss at the terminal time:
    the squared network output on the terminal samples (not reduced).
    '''

    terminal_output = tf.squeeze(self.model(t_terminal, x_terminal, eta_terminal), axis = -1)
    return tf.math.square(terminal_output)

  def total_criterion(self, t, x, eta, t_terminal, x_terminal, eta_terminal, factor=1.):
    '''
    Returns mean(interior loss) + factor * mean(terminal loss) as a scalar.
    '''
    loss = tf.reduce_mean(self.criterion(t, x, eta))
    terminal_loss = tf.reduce_mean(self.terminal_criterion(t_terminal, x_terminal, eta_terminal))
    return loss + (factor * terminal_loss)


In [None]:
class Train():
  '''
  Trains a model on data produced by a dataset generator, using the Loss class.

  b and factor are forwarded to Loss (as b, and as the weight of the terminal loss).
  oversampling controls whether the generator's oversampling step is applied to each
  freshly generated dataset. With return_losses the loss of every step is collected;
  with verbose the average loss of each epoch is printed.
  '''
  # NOTE(review): the default b=1. differs from the default b=4. of Loss and from
  # the b=4 hard-coded in VizPropagation - confirm which value is intended.
  def __init__(self, model, dataset_generator, b=1., factor = 1., oversampling = True, return_losses = False, verbose = False, visual_output = False):
    self.model = model
    self.dsg = dataset_generator
    self.return_losses = return_losses
    self.losses = []
    self.avg_losses = [] # defined up front so that step() can be called on its own
    self.verbose = verbose
    self.visual_output = visual_output
    self.oversampling = oversampling
    self.factor = factor
    self.b = b

  def loss_gradient(self):
    '''
    Returns the total loss on the current batch and its gradient with respect to
    the model's trainable variables.
    '''
    # Use the dimension the data was generated with; fall back to the global d
    # for generators that do not expose it.
    num_states = self.dsg.d if hasattr(self.dsg, 'd') else d
    loss_fn = Loss(model = self.model, d = num_states, b=self.b)
    with tf.GradientTape() as loss_tape:
      loss = loss_fn.total_criterion(self.t, self.x, self.eta, self.t_T, self.x_T, self.eta_T, factor = self.factor)
    return loss, loss_tape.gradient(loss, self.model.trainable_variables)

  def step(self, optimizer):

    '''
    A single step in the training regime of the neural network.
    Requires the current batch (self.t, self.x, ...) to have been set, as train_nn does.
    '''

    loss, loss_grad = self.loss_gradient()

    if self.verbose:
      self.avg_losses.append(loss.numpy())

    if self.return_losses:
      self.losses.append(loss)

    optimizer.apply_gradients(zip(loss_grad, self.model.trainable_variables))
    return self.model

  def train_nn(self, epochs, steps_per_epoch, learning_rate = 1e-3, verbose=False):

    '''
    The main training function to train the neural network.

    A new dataset is generated for every epoch and used for steps_per_epoch
    optimizer steps. The learning rate decays polynomially from learning_rate to
    1e-6 over all steps. Passing verbose=True enables the per-epoch report for
    this and later calls, in addition to the flag given to the constructor.
    Saves the weights to 'dgm_weights' and returns the model (and the list of
    losses when return_losses is set).
    '''

    print('Training the DGM network.')

    self.verbose = self.verbose or verbose

    lr_fn = tf.optimizers.schedules.PolynomialDecay(initial_learning_rate=learning_rate, decay_steps = int(epochs*steps_per_epoch),
                                                    end_learning_rate=1e-6, power = 0.9)
    opt = tf.keras.optimizers.Adam(lr_fn)

    for m in range(epochs):

      self.avg_losses = []

      self.dsg.generate_dataset()
      if self.oversampling:
        self.dsg.oversampling()
      self.t, self.x, self.eta, self.t_T, self.x_T, self.eta_T = self.dsg.data_to_tensors()

      for _ in range(steps_per_epoch):
        self.model = self.step(opt)

      if self.verbose:
        print(f'Avg loss for epoch {m} was: {np.mean(self.avg_losses)}')

    self.model.save_weights('dgm_weights')

    if self.return_losses:
      return self.model, self.losses
    return self.model

In [None]:
# Build the network and the sample generator for d states, then run a short training.
trainer = Train(
    model=DGMModel([d + 2, 50, 50, 50]),
    dataset_generator=DGMDatasetGenerator4(d=d),
    factor=10.,
    verbose=False,
)
trainer.train_nn(epochs=1, steps_per_epoch=2, learning_rate=1e-3)
# Persist the complete trained model, not only its weights.
trainer.model.save('model_tf', save_format='tf')

Training the DGM network.




Below are the plotting functions used for the DGME in the paper.

In [None]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import seaborn as sns
import imageio
import copy
import matplotlib as mpl
import plotly.graph_objects as go

Below there are two visualization classes, Viz2 and VizPropagation.

In [None]:
class Viz2:
  '''
  Plots of the trained network as a function of the measure: line plots for
  d=2 and heatmaps on the simplex for d=3.
  '''
  def __init__(self, model, num_measure_points =1_000):
    self.model = model
    self.num_points = num_measure_points

  @staticmethod
  def _simplex_lattice(num_points):
    '''
    Returns (idx, eta) for the regular lattice on the 3-state simplex.
    idx holds the integer pairs (i, j) with i + j <= num_points, i outermost, and
    eta the matching rows (i/num_points, j/num_points, 1 - i/num_points - j/num_points).
    '''
    idx = np.array([(i, j) for i in range(num_points + 1) for j in range(num_points - i + 1)])
    eta = np.zeros((idx.shape[0], 3))
    eta[:, 0] = idx[:, 0] / num_points
    eta[:, 1] = idx[:, 1] / num_points
    eta[:, 2] = 1. - eta[:, 0] - eta[:, 1]
    return idx, eta

  def single_graph_population_updated(self, t=0., x=0.):
    '''
    This function gets the values from the neural network, populated on the simplex. Only for d=2.

    Returns (two_simplex, values): the grid of eta[0] values on [0, 1] and the
    network output at (t, x, (eta[0], 1 - eta[0])).
    '''
    two_simplex = np.linspace(start = 0, stop = 1, num = self.num_points, endpoint = True)
    eta_data = np.zeros((self.num_points, d))
    eta_data[:,0] = two_simplex
    eta_data[:,1] = 1. - two_simplex
    eta_data = tf.convert_to_tensor(eta_data, dtype='float32')
    # Cast explicitly: t may arrive as a NumPy float64, while eta is float32.
    t = tf.cast(tf.fill((self.num_points,), t), dtype = 'float32')
    x = tf.cast(tf.fill((self.num_points,), x), dtype = 'float32')
    y = self.model(t,x,eta_data)
    y_for_graph = y.numpy()[:,0]
    return two_simplex, y_for_graph

  def display_single_graph_updated(self, t=0., x=0.):
    '''
    Uses the prior function to graph the DGME neural network's values. Only for d=2.
    '''
    two_simplex, y_for_graph = self.single_graph_population_updated(t=t, x=x)
    fig, ax = plt.subplots(figsize=(6, 4)) #, tight_layout=True)
    ax.set_ylim([0,1])
    ax.set_xlabel(r'$\mu(x=1)$')
    ax.set_ylabel(rf'$U(t={round(t,2)},x={int(x)+1},\eta=\mu)$')
    ax.set_title(r'')
    ax.plot(two_simplex, y_for_graph, color = 'black')
    return


  def display_graph_rainbow_updated(self, x=0., num_times = 15, T = .5):
    '''
    Displays the DGME result but at multiple times. Only for d=2.
    One curve is drawn for each of num_times equally spaced times in [0, T].
    '''
    fig, ax = plt.subplots(figsize=(6, 4)) #, tight_layout=True)
    ax.set_ylim([-0.01,0.34])
    ax.set_xlabel(r'$\mu(x=1)$')
    ax.set_ylabel('')
    ax.set_title(rf'$U(x={int(x)+1},\eta=\mu)$')

    time_points = np.linspace(0,T, num = num_times, endpoint = True)

    # NOTE(review): the colormap is indexed by the raw time values, so only the
    # part [0, T] of the map is used - confirm whether normalising by T is wanted.
    color = cm.rainbow(time_points)

    for k, c in enumerate(color):
      two_simplex, y_for_graph = self.single_graph_population_updated(t=time_points[k], x=x)
      if k == 0:
        # the upper limit of the y-axis follows the first (t=0) curve
        ax.set_ylim([-0.01,max(y_for_graph)*1.05])
      ax.plot(two_simplex, y_for_graph, color = c, label = f't={round(time_points[k],2)}')

    ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    return

  def display_heatmap(self, num_points = 25, t=0., x=0.):

    '''
    A specific function to graph the DGME network's values in the case d=3.

    The network is evaluated at (t, x, eta) for eta on a lattice of the simplex
    with spacing 1/num_points. Shows the figure and returns the seaborn axes.
    '''
    idx, eta_np = self._simplex_lattice(num_points)
    num_array = eta_np.shape[0]

    eta_data = tf.convert_to_tensor(eta_np, dtype='float32')
    t_data = tf.cast(tf.fill((num_array,), t), dtype = 'float32')
    x_data = tf.cast(tf.fill((num_array,), x), dtype = 'float32')
    output = self.model(t_data, x_data, eta_data)

    # Row r of the grid holds eta_1 = 1 - r/num_points, column c holds eta_2 = c/num_points.
    grid_data = np.zeros((num_points+1, num_points+1))
    grid_data[num_points - idx[:, 0], idx[:, 1]] = output.numpy()[:,0]

    skip = 3
    size = grid_data.shape[0]
    x_labels = [round(i/(size-1),2) if i % skip == 0 else '' for i in range(size)]
    y_labels = [round(1.- i/(size-1),2) if i % skip == 0 else '' for i in range(size)]

    # Hide only the cells strictly above the diagonal; the diagonal itself holds
    # the computed values for eta_3 = 0 and must stay visible.
    mask = np.triu(np.ones_like(grid_data, dtype=bool), k=1)

    ax = sns.heatmap(grid_data, mask=mask, xticklabels=x_labels, yticklabels=y_labels, vmin=0, vmax=0.225, cmap="PiYG")

    # decrease density of tick labels
    for tick_labels in (ax.get_xticklabels(), ax.get_yticklabels()):
      for index, label in enumerate(tick_labels):
        label.set_visible(index % 2 == 0)

    ax.set_xlabel(r'$\mu(x=2)$')
    ax.set_ylabel(r'$\mu(x=1)$')
    ax.set_title(rf'$U(t={round(t,2)}, x={int(x)+1},\eta=\mu)$')
    plt.show()

    return ax

  def heatmap_timelapse(self, num_points = 25, time_steps = 100, T=0.5, x=0.):
    '''
    Make a gif by using the d=3 values for multiple times.

    Draws time_steps heatmaps at equally spaced times in [0, T] for state x with
    lattice resolution num_points, saves each frame as hm_<k>.png and writes hm.gif.
    '''
    images = []
    for t_bar in range(time_steps):
      # a single frame is drawn at t=0 instead of dividing by zero
      t = (t_bar / (time_steps - 1) ) * T if time_steps > 1 else 0.
      plt.clf()
      hm = self.display_heatmap(num_points = num_points, t=t, x=x)

      fig = hm.get_figure()
      fig.savefig(f"hm_{t_bar}.png")
      images.append(imageio.imread(f"hm_{t_bar}.png"))
    imageio.mimsave('hm.gif',images,fps=55)
    return

In [None]:
class VizPropagation:
  '''
  Propagates a measure forward in time with the controls induced by the trained
  network and plots the resulting measure flow and values.
  '''
  # NOTE(review): a, au, al and b are hard-coded here and must match the values
  # the model was trained with - confirm against the training configuration.

  def __init__(self, model, d=3):
    self.model = model
    self.a = 2
    self.au = 3
    self.al = 1
    self.b = 4
    self.d = d

  def a_star(self, numerator):
    '''
    Returns a + numerator / ((au - al) * b), the same formula as Loss.a_star with
    numerator = psi_left - psi_right.
    '''
    return self.a + ( numerator / ((self.au - self.al) * self.b) )

  def uniformly_random_measure(self):
    '''
    Uses exponential random variables to generate a random probability vector of length d.
    '''
    pre_normalized = np.random.exponential(1, size = (self.d,))
    return pre_normalized / sum(pre_normalized)

  def _values_at(self, time, mu):
    '''
    Returns a length-d array with the network output at (time, x, mu) for every state x.
    '''
    t_data = tf.cast(tf.fill((self.d,), float(time)), dtype = 'float32')
    x_data = tf.cast(tf.range(self.d), dtype = 'float32')
    mu_row = tf.expand_dims(input = tf.convert_to_tensor(mu, dtype='float32'), axis = 0)
    eta_data = tf.tile(mu_row, [self.d, 1])
    return self.model(t_data, x_data, eta_data).numpy()[:,0]

  def _kolmogorov_step(self, mu, values, delta_t):
    '''
    One explicit Euler step of the forward equation:
    mu_x <- mu_x + delta_t * sum_{y != x} (mu_y * rate(y->x) - mu_x * rate(x->y)),
    with rate(y->x) = a_star(values[y] - values[x]).
    '''
    mu = np.asarray(mu, dtype=float)
    values = np.asarray(values, dtype=float)
    # rates[y, x] is the jump rate from state y to state x
    rates = self.a_star(values[:, None] - values[None, :])
    np.fill_diagonal(rates, 0.)
    inflow = mu @ rates
    # The outflow uses the rate of leaving x; using the arrival rate for both
    # terms does not conserve the total mass of the measure.
    outflow = mu * rates.sum(axis=1)
    return mu + delta_t * (inflow - outflow)

  def get_measure_points(self, delta_t, T, initial_eta):

    '''
    This function uses the network and the Fokker--Planck equation to get the corresponding points of the mean field equilibrium.

    Returns an array of shape (int(T/delta_t)+1, d); row k is the measure after k steps.
    '''
    num_steps = int(T/delta_t)+1
    measure_data = np.zeros((num_steps,self.d))
    measure_data[0, :] = initial_eta
    current_time = 0
    for step in range(1, num_steps):
      # each step propagate by the Kolmogorov equation; the network is evaluated
      # at the new time and the previous measure
      current_time += delta_t
      past_mu = measure_data[step-1,:]
      U_vals = self._values_at(current_time, past_mu)
      measure_data[step, :] = self._kolmogorov_step(past_mu, U_vals, delta_t)
    return measure_data

  def get_value_points(self, delta_t, T, initial_eta):
    '''
    Along the path from the previous function, this function gives the corresponding value.

    Returns (measure_data, value_data), where value_data[k, x] is the network
    output at time k*delta_t, state x and measure measure_data[k].
    '''
    measure_data = self.get_measure_points(delta_t=delta_t, T=T, initial_eta=initial_eta)
    value_data = np.zeros_like(measure_data)

    current_time = 0
    for step in range(measure_data.shape[0]):
      # pair each time with the measure of the same step, not the previous one
      value_data[step, :] = self._values_at(current_time, measure_data[step, :])
      current_time += delta_t
    return measure_data, value_data

  @staticmethod
  def _initial_state_title(initial_eta):
    '''
    Builds the plot title listing the entries of initial_eta rounded to two decimals.
    '''
    return 'Initial State: (' + ', '.join(str(round(entry, 2)) for entry in initial_eta) + ')'

  def plot_mf_evolution(self, delta_t, T, initial_eta):
    '''
    A plotting function for the mean field equilibrium: one curve per state of the measure over time.
    '''
    fig, ax = plt.subplots(figsize=(6, 4)) #, tight_layout=True)
    ax.set_ylim([0,1.])
    if self.d > 3:
      ax.set_ylim([0,max(initial_eta)+0.05])
    ax.set_xlabel('$t$')
    ax.set_ylabel(r'$\mu(t,x)$')
    ax.set_title(self._initial_state_title(initial_eta))

    num_steps = int(T/delta_t)+1
    time_steps = np.linspace(0, T, num = num_steps, endpoint = True)

    color = mpl.colormaps['Paired'](np.arange(0,self.d))

    measure_data = self.get_measure_points(delta_t = delta_t, T=T, initial_eta = initial_eta)

    for k, c in enumerate(color):
      ax.plot(time_steps, measure_data[:,k], color = c, label = rf'$\mu(t,{int(k)+1})$')
    ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    return

  def plot_value_propagation(self, delta_t, T, initial_eta):
    '''
    A plotting function for the cost of the mean field equilibrium: one curve per state of the value over time.
    '''
    fig, ax = plt.subplots(figsize=(6, 4)) #, tight_layout=True)
    ax.set_xlabel('$t$')
    # this figure shows the value, so the axis is labelled with U rather than mu
    ax.set_ylabel(r'$U(t,x,\mu(t))$')
    ax.set_title(self._initial_state_title(initial_eta))

    num_steps = int(T/delta_t)+1
    time_steps = np.linspace(0, T, num = num_steps, endpoint = True)

    color = mpl.colormaps['Paired'](np.arange(0,self.d))

    measure_data, value_data = self.get_value_points(delta_t = delta_t, T=T, initial_eta = initial_eta)

    ax.set_ylim([0,0.35])

    for k, c in enumerate(color):
      ax.plot(time_steps, value_data[:,k], color = c, label = rf'$U(t,{int(k)+1},\mu(t))$')
    ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    return

  def plot_mf_on_simplex(self, delta_t, T, initial_eta):
    # ONLY for dimension d=3
    '''
    A plotting function for the mean field equilibrium, only for d=3.
    Draws the path of the measure as 3D markers together with the plane x + y + z = 1.
    '''
    measure_data = self.get_measure_points(delta_t = delta_t, T=T, initial_eta = initial_eta)
    x, y, z = measure_data[:,0], measure_data[:,1], measure_data[:,2]

    fig = go.Figure(data=[go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        marker=dict(
            size=3,
            color=z,                # set color to an array/list of desired values
            colorscale='Viridis',   # choose a colorscale
            opacity=0.95
        )
    )])

    # tight layout
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
    fig.update_layout(
    scene = dict(
        xaxis = dict(range=[0,1],),
                     yaxis = dict(range=[0,1],),
                     zaxis = dict(range=[0,1],),),
    width=700,
    margin=dict(r=20, l=10, b=10, t=10))

    # semi-transparent surface of the plane containing the simplex
    x = np.outer(np.linspace(0, 1, 30), np.ones(30))
    y = x.copy().T
    z = (1-x) - y
    fig.add_trace(go.Surface(x=x, y=y, z=z, colorscale = 'gray', opacity = 0.3, showscale = False))
    fig.show()
    return

The following cells give some examples of how to use the Viz2 and VizPropagation classes:

For the d=2, rainbow graph:

In [None]:
# d=2
# loaded_model = keras.models.load_model('LOCAL PATH/Most_Models/i_dgm_max_2')
# loaded_model.compile()
# viz = Viz2(model = loaded_model)
# viz.display_graph_rainbow_updated(T=0.5)

Plotting the cost along the mean field equilibrium over time:

In [None]:
# loaded_model = keras.models.load_model('LOCAL PATH/Most_Models/i_dgm_max_2')
# loaded_model.compile()
# vizp = VizPropagation(model = loaded_model, d=2)

# for eta1 in [0.1 * k for k in range(1,6,1)]:
#   vizp.plot_value_propagation(delta_t = 0.005, T=0.5, initial_eta = [eta1,1.-eta1])

The special dimension 3 plot of the mean field equilibrium on the simplex:

In [None]:
# loaded_model = keras.models.load_model('LOCAL PATH/Most_Models/i_dgm_max_3')
# loaded_model.compile()
# vizp = VizPropagation(model = loaded_model)
# vizp.plot_mf_on_simplex(delta_t = 0.005, T = 0.5, initial_eta = [0.7, 0.2, 0.1])

The d=3 heatmaps:

In [None]:
# d=3
# loaded_model = keras.models.load_model('LOCAL PATH/Most_Models/i_dgm_max_3')
# loaded_model.compile()
# viz = Viz2(model = loaded_model)
# plt.clf()
# for t in [0., 0.1, 0.2, 0.3, 0.4, 0.5]:
#   viz.display_heatmap(num_points = 60, t=t, x=0.)