# GINA
### Neural Graph Inference From Independent Snapshots of Interacting System   
Modeling and Simulation Group @ Saarland University 

https://arxiv.org/abs/2105.14329  
https://github.com/gerritgr/gina  


# Setup

## Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
import math, random, os, time
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import scipy
import seaborn as sns
sns.set_style("white")
import pickle
from tqdm import tqdm
import time
import netrd

## Hyperparameters

In [None]:
# Global hyperparameters / feature flags read by the functions and classes below.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HOME = './'   # where to save experiments output (used as a plain string prefix, keep the trailing '/')

USE_VALIDATIONSET = True     # when False, split_snapshots overrides TRAINING_FRAC with 1.0 (no test split)
TRAINING_FRAC = 0.8          # only relevant when using test data
LEARNING_RATE =  0.0001 
TEMPERATURE = 5.0            # sharpness parameter v, higher temperature => sharper step function (mu)
RANDOM_TEMPERATURE = False   # activates stochastic noise in step function (unused)
INCREASE_TEMPERATURE = 3.0   # intended to increase sharpness over course of training
                             # NOTE(review): MyNetworkLayer's thresholding lambda captures the initial
                             # temperature, so this increase does not seem to reach the step function -- confirm.
USE_LAYER_3 = False          # activate additional MLP layer
INDIVIDUAL_INPUT = False     # activates custom embedding for each node (unused)
INDIVIDUAL_OUTPUT = False    # custom output projection for each node
MINIBATCH_SIZE = 500
EPOCHNUM = 100*100
OUTPUT_EACH_X_EPOCHS = 100   # how often evaluate GINA
INIT_GRAPH_SHIFT = 1         # scale applied to (init_graph - 0.5) in MyNetworkLayer; only used when an init_graph is passed
USESHIFT = True              # whether the learnable weight_shift of MyNetworkLayer is trained
COLLAPSE_IN_TEST = True      # only relevant when using training/test set
TRAIN_NETWORK = True         # set to False to freeze the graph weights and only train the MLP weights
EARLY_STOPPING = True        # only used in Exp3
GRAPH_DROPOUT = True         # adds decaying Gaussian noise to the soft adjacency matrix (meant to increase numerical stability)

print('Use  device: ', DEVICE)

# Misc

## Utils

In [None]:
def edgelist_to_graph(num_nodes, edgelist):
  """Build an undirected networkx graph on nodes 0..num_nodes-1 containing `edgelist`.

  Nodes are added first so isolated nodes are kept and the node order is 0..n-1.
  """
  graph = nx.Graph()
  node_ids = range(num_nodes)
  graph.add_nodes_from(node_ids)
  graph.add_edges_from(edgelist)
  return graph

def float2str(x):
  """Return str(x) left-padded with zeros to 20 characters (sign stays in front)."""
  text = str(x)
  return str.zfill(text, 20)

def int2str(x):
  """Round x (via int(x + 0.5)) and zero-pad the result to 20 characters.

  The fixed width makes lexicographic order of file names match numeric order.
  """
  rounded = int(x + 0.5)
  return f"{rounded}".zfill(20)

def read_snapshots(filepath):
  """Load snapshots written by write_snapshots (one repr() per line).

  Lines of length <= 2 (e.g. the trailing empty line) are skipped; every
  other line is evaluated as a Python expression.

  SECURITY: this uses eval(), so it executes whatever the file contains.
  Only read snapshot files this notebook produced.
  """
  with open(filepath, 'r') as handle:
      content = handle.read()
  snapshots = []
  for row in content.split('\n'):
      if len(row) <= 2:
          continue
      snapshots.append(eval(row))
  return snapshots

def write_snapshots(snapshots, filepath):
  """Write each snapshot as one repr() line (the format read_snapshots parses)."""
  with open(filepath, 'w') as out:
      out.writelines(repr(state) + '\n' for state in snapshots)


def sample_temperature(mean_temperature, cut=3.0, var=0.2):
    """Draw a temperature from a normal distribution, clipped from below at `cut`.

    Note: `var` is passed to np.random.normal as `scale`, i.e. it acts as a
    standard deviation, not a variance.
    """
    sampled = np.random.normal(mean_temperature, var)
    return cut if sampled < cut else sampled
    
def binarize_state(s, cut=None):
    """Collapse a state vector into two entries: sum(s[:cut]) and sum(s[cut:]).

    Vectors of length 2 are returned unchanged (even if `cut` is given).
    The default cut is int(len(s) / 2 + 0.6), i.e. the first half rounded up.
    """
    if len(s) != 2:
        split_at = int(len(s) / 2 + 0.6) if cut is None else cut
        return [np.sum(s[:split_at]), np.sum(s[split_at:])]
    return s

def binarize_networkstate(networkstate, cut=None):
    """Apply binarize_state to every node state of one network state."""
    return list(map(lambda node_state: binarize_state(node_state, cut=cut), networkstate))

def binarize_snapshot(snapshot, cut=None):
    """Apply binarize_networkstate to every element of `snapshot`."""
    return list(map(lambda network_state: binarize_networkstate(network_state, cut=cut), snapshot))

def snapshots_to_netrd(snapshots):
  """Convert snapshots to netrd's time-series layout.

  Each node state is replaced by the index of its largest entry (np.argmax).
  Returns a float array of shape (network_size, num_snapshots), i.e. one row
  per node and one column per snapshot. The network size is taken from the
  first snapshot.
  """
  n_nodes = len(snapshots[0])
  table = np.array(
      [[np.argmax(snap[node]) for node in range(n_nodes)] for snap in snapshots],
      dtype=float,
  )
  return table.T


def nx_to_adj(G):
  """Return G's binarized adjacency matrix as a float32 tensor on DEVICE.

  Entries with weight > 1e-4 become 1.0, all others 0.0.
  """
  dense = np.asarray(nx.adjacency_matrix(G).todense())
  binary = np.where(dense > 0.0001, 1.0, 0.0)
  return torch.FloatTensor(binary).to(DEVICE)

def graph_dist(G1, G2):
  """Count the undirected edges on which G1 and G2 disagree.

  Edge weights are binarized (weight > 1e-4 counts as an edge) and self-loops
  are ignored. If both graphs share the same node set, G2's adjacency matrix
  is built in G1's node order, so edges are matched by node label rather
  than by node insertion order. Otherwise each graph uses its own node order
  (positional comparison, the previous behaviour).

  The result is divided by 2, which assumes symmetric (undirected) adjacency.

  Returns:
    Number of differing edges (numpy float).
  """
  nodelist = list(G1.nodes())
  a_ij = nx.adjacency_matrix(G1, nodelist=nodelist).todense()
  if set(G2.nodes()) == set(nodelist):
    b_ij = nx.adjacency_matrix(G2, nodelist=nodelist).todense()
  else:
    b_ij = nx.adjacency_matrix(G2).todense()

  # threshold
  a_ij = (np.asarray(a_ij) > 0.0001).astype('float')
  b_ij = (np.asarray(b_ij) > 0.0001).astype('float')

  z_ij = np.abs(a_ij - b_ij)
  np.fill_diagonal(z_ij, 0.0)  # ignore self-loops
  return np.sum(z_ij) / 2

def gen_folders(exp_name, home=HOME):
    """Create the output folders for an experiment and archive the notebook.

    Creates home+exp_name and its sub-folders graph_evol, snapshots, weights
    and dynamics_prediction; existing folders are kept. If the experiment
    folder does not yet contain NeuralNetworkReconstruction.ipynb, it is copied
    from the current working directory. This is best effort: a failed copy is
    reported, not raised.
    """
    import shutil
    exp_dir = home + exp_name
    os.makedirs(exp_dir, exist_ok=True)  # also creates `home` if needed

    orig_path = "NeuralNetworkReconstruction.ipynb"
    copy_path = exp_dir + '/' + 'NeuralNetworkReconstruction.ipynb'
    # Copy the notebook only if the experiment folder does not already hold one.
    if not os.path.exists(copy_path):
        print("cp '{}' '{}'".format(orig_path, copy_path))
        try:
            shutil.copy(orig_path, copy_path)
        except OSError as err:
            print('could not copy notebook: {}'.format(err))

    for out_folder in ['graph_evol', 'snapshots', 'weights', 'dynamics_prediction']:
        os.makedirs(exp_dir + '/' + out_folder, exist_ok=True)

def clean_shuffle_graph(G):
  """Reproducibly permute the node labels of G, then relabel nodes as 0..n-1.

  The permutation is driven by a fixed seed (42). The global `random` state
  is saved before and restored afterwards (even on error), so callers' random
  streams are unaffected. The keys for the shuffle are drawn in G.nodes()
  iteration order, so the result depends on that order for a given graph.

  Raises:
    AssertionError: if the resulting graph is not connected.
  """
  saved_state = random.getstate()
  try:
    random.seed(42)
    node_mapping = dict(zip(sorted(G.nodes()), sorted(G.nodes(), key=lambda _: random.random())))
    G = nx.relabel_nodes(G, node_mapping)
    G = nx.convert_node_labels_to_integers(G)
  finally:
    random.setstate(saved_state)
  if not nx.is_connected(G):
    print('Graph is not connected, try a differnt one.')
    assert(nx.is_connected(G))
  return G


def split_snapshots(snapshots, training_frac=TRAINING_FRAC):
  """Shuffle snapshots reproducibly and split them into train/test tensor lists.

  The list is shuffled *in place* with a private random.Random(42). This
  gives the same permutation as random.seed(42) + random.shuffle, but leaves
  the global `random` state untouched.

  Args:
    snapshots: list of snapshots (each convertible by torch.tensor).
    training_frac: fraction used for training; forced to 1.0 when
      USE_VALIDATIONSET is False.

  Returns:
    (training_data, test_data): lists of float tensors on DEVICE;
    test_data is None when USE_VALIDATIONSET is False.
  """
  if not USE_VALIDATIONSET:
    training_frac = 1.0
  random.Random(42).shuffle(snapshots)
  cut_off = int(len(snapshots)*training_frac)

  def to_tensor(snapshot):
    return torch.tensor(snapshot, dtype=torch.float, device=DEVICE)

  training_data = [to_tensor(snapshot) for snapshot in snapshots[:cut_off]]
  if not USE_VALIDATIONSET:
    return training_data, None
  test_data = [to_tensor(snapshot) for snapshot in snapshots[cut_off:]]
  return training_data, test_data

def split_single_snapshot(snapshot, pos = None):
  """Hide one node's state in a copy of `snapshot`.

  Args:
    snapshot: 2-D tensor (nodes x states); it is not modified.
    pos: row to hide; drawn uniformly with random.choice when None.

  Returns:
    (masked_copy, hidden_row, pos), where row `pos` of masked_copy is zeroed
    and hidden_row is a copy of the original row.
  """
  masked = snapshot.clone()
  node_count, _state_count = masked.shape[0], masked.shape[1]
  if pos is None:
    pos = random.choice(range(node_count))
  hidden_row = masked[pos, :].clone()
  masked[pos, :] = 0.0
  return masked, hidden_row, pos


def state_to_color(state):
  """Map a state vector to a plotting colour by the index of its largest entry.

  Indices wrap around the palette. The palette length is used (previously a
  hard-coded 6, which made 'pink' unreachable).
  """
  colors = ['blue', 'red', 'black', 'green', 'orange', 'yellow', 'pink']
  return colors[np.argmax(state) % len(colors)]

def plot_graph(G, name, exp_name, ground_truth=None, is_final=False):
  """Draw G and save it to HOME+exp_name+'/graph_evol/<name>.png'.

  The Kamada-Kawai layout is computed on `ground_truth` when given (so
  successive plots share node positions), otherwise on G. With is_final=True,
  G is also written as final_graph.gml and the figure saved as
  graph_evol/final_graph.png.
  """
  plt.clf()
  layout_source = G if ground_truth is None else ground_truth
  pos = nx.kamada_kawai_layout(layout_source)
  nx.draw(G, node_color='black', pos=pos, alpha=0.5)
  exp_dir = HOME + exp_name
  plt.savefig(exp_dir + '/graph_evol/{}.png'.format(name))
  if not is_final:
    return
  nx.write_gml(G, exp_dir + '/final_graph.gml')
  plt.savefig(exp_dir + '/graph_evol/final_graph.png')


def plot_dynamics(name, exp_name, model, num=10, state_num=2):
  """Heatmap of the learned local update rule.

  For every pair (row, col) in range(num)^2 the neighbourhood-count vector
  [row, col, 1, 1, ...] (length state_num) is fed through
  model.forward_counts. The first output probability is drawn, and the plot
  is saved to HOME+exp_name+'/dynamics_prediction/<name>.png'.
  """
  grid = np.zeros([num, num])
  for row in range(num):
    for col in range(num):
      counts = np.ones(state_num)
      counts[0], counts[1] = row, col
      prediction = model.forward_counts(torch.tensor(counts, dtype=torch.float, device=DEVICE))
      grid[row, col] = float(prediction.view(-1)[0])
  plt.clf()
  sns.heatmap(grid, vmin=0, vmax=1, cmap='vlag', square=True)
  plt.ylim(0, num)
  plt.xlim(0, num)
  plt.savefig(HOME+exp_name+'/dynamics_prediction/{}.png'.format(name))


def plot_snapshots(G, snapshots, exp_name, max_plot_num=30):
  """Save coloured drawings of up to `max_plot_num` snapshots (None = all).

  Each snapshot is drawn twice, with a spectral layout (snapshot_<i>.png)
  and a Kamada-Kawai layout (l_snapshot_<i>.png), into HOME+exp_name+'/snapshots/'.
  """
  print("plot_snapshots")
  exp_dir = HOME + exp_name
  layouts = ((nx.spectral_layout, '/snapshots/snapshot_{}.png'),
             (nx.kamada_kawai_layout, '/snapshots/l_snapshot_{}.png'))
  for idx, snapshot in enumerate(snapshots):
    assert(len(snapshot) == G.number_of_nodes())
    if max_plot_num is not None and idx >= max_plot_num:
      break
    colors = [state_to_color(node_state) for node_state in snapshot]
    for layout_fn, path_template in layouts:
      plt.clf()
      nx.draw(G, node_color=colors, pos=layout_fn(G))
      plt.savefig(exp_dir + path_template.format(int2str(idx)))


def datasets_form_snapshots(snapshots, training_frac=TRAINING_FRAC):
    """Build torch DataLoaders from snapshots (legacy path, noted as too slow to use).

    The snapshot list is shuffled in place with a private random.Random(42),
    leaving the global `random` state untouched. Each split is converted to an
    array reshaped to rows of length network_size and wrapped in a DataLoader.

    Returns:
      (train_loader, test_loader); the test loader serves the whole test split
      as one batch.
    """
    random.Random(42).shuffle(snapshots)
    cut_off = int(len(snapshots)*training_frac)
    training_data = snapshots[:cut_off]
    test_data = snapshots[cut_off:]

    network_size = len(snapshots[0])

    # torch.tensor(..., device=...) works for any device; the legacy
    # torch.FloatTensor(data, device=...) constructor only accepts CPU.
    train = np.array(training_data).reshape([-1, network_size])
    train = torch.tensor(train, dtype=torch.float, device=DEVICE)
    train_loader = data_utils.DataLoader(train, batch_size=MINIBATCH_SIZE, shuffle=True)

    test = np.array(test_data).reshape([-1, network_size])
    test = torch.tensor(test, dtype=torch.float, device=DEVICE)
    # DataLoader rejects batch_size=0, which an empty test split would produce.
    test_loader = data_utils.DataLoader(test, batch_size=max(1, test.shape[0]))

    return train_loader, test_loader

def fill_symmetric_matrix(values, node_num):
  """Build a symmetric node_num x node_num float64 tensor with a zero diagonal on DEVICE.

  `values` lists the strict upper triangle row by row, i.e. the pairs
  (0,1), (0,2), ..., (0,n-1), (1,2), ...; each value is mirrored to (j,i).
  """
  assert(len(values) == node_num*(node_num-1)/2)
  matrix = torch.zeros([node_num, node_num], dtype=torch.double, device=DEVICE)
  upper_pairs = ((row, col) for row in range(node_num) for col in range(row + 1, node_num))
  for value, (row, col) in zip(values, upper_pairs):
    matrix[row, col] = value
    matrix[col, row] = value
  return matrix
  
def create_all_graphs(node_num, check_if_connected=True):
  """Enumerate all labelled simple undirected graphs on node_num nodes.

  Every 0/1 assignment of the node_num*(node_num-1)/2 possible edges is turned
  into a networkx graph, in itertools.product order. With check_if_connected,
  only connected graphs are kept. Graphs are built one at a time instead of
  materialising every candidate first.
  """
  import itertools
  len_values = int(node_num*(node_num-1)/2)
  graphs = []
  for edge_bits in itertools.product([0.0, 1.0], repeat=len_values):
    adjacency = fill_symmetric_matrix(edge_bits, node_num).cpu().detach().numpy()
    # nx.from_numpy_matrix was removed in networkx 3.0; from_numpy_array replaces it.
    graph = nx.from_numpy_array(adjacency)
    if check_if_connected and not nx.is_connected(graph):
      continue
    graphs.append(graph)
  return graphs

class SnapshotLoader(): # todo make to iterator
    """Minimal mini-batch provider over a list of snapshots.

    Each snapshot is stored as a float tensor on DEVICE. get_data() returns a
    list of batches. Each batch concatenates up to `batch_size` snapshots along
    dim 1, the same layout training() builds with torch.cat(..., dim=1).
    """
    def __init__(self, snapshots, batch_size=None, shuffle=True):
        # torch.tensor accepts any device; the legacy torch.FloatTensor(..., device=...)
        # constructor only accepts CPU and fails when DEVICE is CUDA.
        self.snapshots = [torch.tensor(snapshot, dtype=torch.float, device=DEVICE) for snapshot in snapshots]
        if batch_size is None:
          batch_size = len(snapshots)
        self.batch_size = batch_size
        self.shuffle = shuffle

    def get_data(self):
      """Return the list of batches (shuffling the stored order first if enabled).

      Returns an empty list when there are no snapshots. A non-positive
      batch_size is treated as 1.
      """
      if self.shuffle:
        random.shuffle(self.snapshots)
      if not self.snapshots:
        return []
      step = max(1, self.batch_size)
      return [torch.cat(self.snapshots[start:start + step], dim=1)
              for start in range(0, len(self.snapshots), step)]

def snapshots_to_loader(snapshots, training_frac=TRAINING_FRAC):
    """Split snapshots into a shuffling training loader and a full-batch test loader.

    The list is shuffled in place with a private random.Random(42). This is the
    same permutation as random.seed(42) + random.shuffle, without reseeding the
    global `random` state.
    """
    random.Random(42).shuffle(snapshots)
    cut_off = int(len(snapshots)*training_frac)
    training_loader = SnapshotLoader(snapshots[:cut_off], batch_size=MINIBATCH_SIZE)
    test_loader = SnapshotLoader(snapshots[cut_off:], shuffle=False)
    return training_loader, test_loader

def index_set(l, minibatch_size=MINIBATCH_SIZE):
    """Split the index range [0, l) into consecutive half-open (start, end) pairs.

    Every pair spans `minibatch_size` indices except possibly the last one.
    Returns [] for l <= 0.

    Raises:
      ValueError: if minibatch_size <= 0 (the loop would never terminate).
    """
    if minibatch_size <= 0:
        raise ValueError('minibatch_size must be positive, got {}'.format(minibatch_size))
    bounds = list()
    start = 0
    while l > 0:
        cut = min(l, minibatch_size)
        bounds.append((start, start + cut))
        l -= cut
        start += cut
    return bounds

def plot_atlas(outpath, df):
  """Plot loss_train, loss_test and acc_test against graphdist.

  Violin plots are saved to `outpath`. Scatter plots are saved to `outpath`
  with '.pdf' replaced by '_scatter.pdf'.
  """
  metrics = ['loss_train', 'loss_test', 'acc_test']
  plt.clf()
  plt.close()
  f, axes = plt.subplots(3, 1)
  for ax, metric in zip(axes, metrics):
    sns.violinplot(data=df, x='graphdist', y=metric, inner="stick", ax=ax, cut=0)
  plt.title('loss_train loss_test acc_test')
  plt.savefig(outpath)
  plt.clf()
  f, axes = plt.subplots(3, 1)
  for ax, metric in zip(axes, metrics):
    sns.scatterplot(data=df, x='graphdist', y=metric, ax=ax, alpha=0.5)
  plt.title('loss_train loss_test acc_test')
  plt.savefig(outpath.replace('.pdf', '_scatter.pdf'))


def plot_exp2(outpath, df, grid_dims=(5, 7, 10)):
  """Plot graphdist over epochs (one line per sample_num), one figure per grid size.

  For each grid dimension d, the rows with grid_dim == d are plotted and saved
  to `outpath` with '.pdf' replaced by '_{d:02d}.pdf' (e.g. '_05.pdf',
  '_07.pdf', '_10.pdf'). Previously d=5 was saved as '_03.pdf' and d=10 as
  '0_10.pdf'.

  Args:
    outpath: path template ending in '.pdf'.
    df: DataFrame with columns grid_dim, epoch, graphdist, sample_num.
    grid_dims: integer grid dimensions to plot.
  """
  plt.clf()
  plt.close()
  for dim in grid_dims:
    sns.lineplot(data=df[df.grid_dim.eq(dim)], x='epoch', y='graphdist', hue="sample_num", palette='bright')
    plt.savefig(outpath.replace('.pdf', '_{:02d}.pdf'.format(dim)))
    plt.clf()
    
def adj_size(n):
    """Number of unordered node pairs, n*(n-1)/2, as a float."""
    pairs = n * (n - 1)
    return pairs / 2.0

def plot_exp3(df, outpath):
    """Bar plot of edge-level accuracy per graph and method, one figure per dynamics.

    Accuracy = (adj_size - graphdist) / adj_size, with adj_size = n*(n-1)/2
    the number of node pairs. `outpath` must contain one '{}' placeholder,
    which is filled with the dynamics name. `df` itself is not modified.
    """
    for dynamicsname in sorted(set(df.dynamicsname)):
        # .copy(): adding columns to a filtered view triggers pandas'
        # chained-assignment (SettingWithCopy) problems.
        df_i = df[df.dynamicsname == dynamicsname].copy()
        df_i['adj_size'] = df_i['node_num'].map(adj_size)
        df_i['acc'] = (df_i['adj_size'] - df_i['graphdist']) / df_i['adj_size']
        sns.catplot(x='graphname', y='acc', hue='method', data=df_i, kind="bar")
        plt.savefig(outpath.format(dynamicsname))
        plt.close()  # catplot opens a new figure each time; free it

## Baseline

In [None]:
def find_threshold(datalist):
  """Split 1-D values into two groups with k-means (k=2) and return a threshold.

  The values are sorted first. In 1-D, k-means clusters are contiguous
  intervals, so the cluster of the smallest value is a prefix of the sorted
  list. The returned threshold is that prefix's largest value plus 1e-12.
  """
  from sklearn.cluster import KMeans
  ordered = sorted(datalist)
  model = KMeans(n_clusters=2, random_state=0).fit([[value] for value in ordered])
  assignments = list(model.labels_)
  lower_cluster_size = assignments.count(assignments[0])
  return ordered[lower_cluster_size - 1] + 0.000000000001


def check_baseline_autothreshold(snapshots, recon, ground_truth_graph):
  """Score a netrd reconstruction using a k-means threshold instead of netrd's own.

  The weight matrix from recon.results['weights_matrix'] is symmetrised and
  thresholded with find_threshold. The result is compared edge-wise with
  ground_truth_graph.

  NOTE(review): np.triu_indices includes the diagonal, so the n zero diagonal
  entries take part in the clustering -- confirm this is intended (k=1 would
  exclude them).

  Returns:
    (number of differing edges, -1.0). The -1.0 fills the runtime slot the
    other baseline helpers return.
  """
  import netrd
  series = snapshots_to_netrd(snapshots)
  # netrd's own degree-threshold graph is discarded, so avg_k should not matter.
  recon.fit(series, threshold_type='degree', avg_k=5)
  weights = recon.results['weights_matrix']
  weights = weights + weights.transpose()
  np.fill_diagonal(weights, 0.0)
  upper = weights[np.triu_indices(weights.shape[0])].flatten()
  threshold = find_threshold(list(upper))
  print('treshold: ', threshold, ' min ', np.min(upper), ' max ', np.max(upper))
  predicted = (weights > threshold).astype('float')

  truth = nx.adjacency_matrix(ground_truth_graph).todense()
  truth = (truth > 0.0001).astype('float')

  mismatch = np.abs(predicted - truth)
  np.fill_diagonal(mismatch, 0.0)
  return np.sum(mismatch)/2, -1.0


def chec_spec_baseline(snapshots, ground_truth_graph, recon):
  """Run one netrd reconstruction method and score it against the ground truth.

  netrd keeps as many edges as the ground truth's mean degree implies
  (threshold_type='degree').

  Returns:
    (graph_dist to ground_truth_graph, seconds spent in fit + scoring).
  """
  series = snapshots_to_netrd(snapshots)
  degrees = [ground_truth_graph.degree(node) for node in ground_truth_graph.nodes()]
  avg_k = np.mean(degrees)
  tic = time.time()
  predicted = recon.fit(series, threshold_type='degree', avg_k=avg_k)
  return graph_dist(ground_truth_graph, predicted), time.time() - tic

def baseline_summary(snapshots, ground_truth_graph):
  """Run the enabled netrd baselines and collect (graph distance, runtime) per method.

  A method that raises is reported and recorded as (-1.0, -1.0); the other
  methods still run.

  Returns:
    dict mapping method name -> (graph_dist, seconds).
  """
  import netrd
  results = dict()
  # Other netrd.reconstruction methods (ConvergentCrossMapping, CorrelationSpanningTree,
  # FreeEnergyMinimization, GrangerCausality, MarchenkoPastur, MaximumLikelihoodEstimation,
  # MeanField, NaiveTransferEntropy, OptimalCausationEntropy, PartialCorrelationInfluence,
  # OUInference, ThoulessAndersonPalmer) are available but disabled here.
  baseline_methods = {
      'CorrelationMatrix': netrd.reconstruction.CorrelationMatrix(),
      'GraphicalLasso': netrd.reconstruction.GraphicalLasso(),
      'MutualInformationMatrix': netrd.reconstruction.MutualInformationMatrix(),
      'PartialCorrelationMatrix': netrd.reconstruction.PartialCorrelationMatrix(),
  }

  for name, recon in baseline_methods.items():
    # Reset per method so a failure never records another method's runtime
    # (and never leaves the value unbound). `elapsed` avoids shadowing `time`.
    dist, elapsed = -1.0, -1.0
    try:
      dist, elapsed = chec_spec_baseline(snapshots, ground_truth_graph, recon)
      print(name, dist, elapsed)
    except Exception as err:  # best effort: one failing baseline must not abort the rest
      print(name, 'failed:', err)
    results[name] = (dist, elapsed)
    # check_baseline_autothreshold could be added here as '<name>_auto' entries.
  return results
    

def check_baseline(snapshots, ground_truth_graph):
  """Score GraphicalLasso and MutualInformationMatrix reconstructions.

  Both use netrd's degree threshold with the ground truth's mean degree.

  Returns:
    (graph_dist for GraphicalLasso, graph_dist for MutualInformationMatrix).
  """
  import netrd
  series = snapshots_to_netrd(snapshots)
  avg_k = np.mean([ground_truth_graph.degree(node) for node in ground_truth_graph.nodes()])
  method_classes = (netrd.reconstruction.GraphicalLasso, netrd.reconstruction.MutualInformationMatrix)
  predictions = [cls().fit(series, threshold_type='degree', avg_k=avg_k) for cls in method_classes]
  return tuple(graph_dist(ground_truth_graph, predicted) for predicted in predictions)

def gen_init_graph(snapshots):
  """Initial adjacency guess from netrd's MutualInformationMatrix.

  netrd thresholds the mutual-information weights with threshold_type='quantile'
  and quantile=0.7 (presumably keeping roughly the strongest 30% of pairs --
  see netrd's threshold docs). The resulting graph is returned as a binarized
  float32 tensor on DEVICE.
  """
  import netrd
  TS = snapshots_to_netrd(snapshots)
  recon = netrd.reconstruction.MutualInformationMatrix()
  G_pred = recon.fit(TS, threshold_type='quantile', quantile = 0.7)
  # Same binarization / tensor conversion as nx_to_adj; reuse it instead of duplicating.
  return nx_to_adj(G_pred)

# PyTorch Model

In [None]:
import torch.nn.functional as F
class MyNetworkLayer(nn.Module):
    """Learnable soft adjacency matrix that aggregates neighbour states.

    The raw parameter `weights` (network_size x network_size) is mapped into
    (0, 1) by a sharpened sigmoid. It is then symmetrised from its upper
    triangle and given a zero diagonal (see diag_graph_matrix). forward()
    multiplies this matrix with the node-state matrix, giving (soft) counts of
    neighbour states per node.
    """
    def __init__(self, network_size, state_num=2, init_graph = None, train_network = TRAIN_NETWORK):
        super().__init__()
        self.network_size = network_size
        self.state_num = state_num
        self.size_out = state_num # 2 local states S and I
        self.temperature = TEMPERATURE
        weights = torch.Tensor(self.network_size, self.network_size)
        self.weights = nn.Parameter(weights)  # uninitialised here; kaiming_uniform_ below
        self.sigmoid = torch.nn.Sigmoid()
        self.weight_shift = nn.Parameter(torch.FloatTensor([0.0]))
        if not USESHIFT:
          self.weight_shift.requires_grad = False

        # Sharpened step function mu(x) = sigmoid((sigmoid(x) - 0.5) * temperature).
        # NOTE(review): the lambda captures the local `temperature` value, so
        # increase_temperature() (which only updates self.temperature) does not
        # change the sharpness actually used. Left as-is because changing it
        # alters training results -- confirm the intended schedule.
        temperature = self.temperature
        if RANDOM_TEMPERATURE:
          temperature = sample_temperature(self.temperature)
        self.thresholding = lambda x: self.sigmoid((self.sigmoid(x)-0.5) * temperature)

        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) # weight init
        self.dropout = nn.Dropout(p=0.005)

        # Std of the Gaussian noise added in diag_graph_matrix, indexed by step:
        # decays linearly from 0.1 to 0.025 over EPOCHNUM steps, then stays 0.0.
        self.random_noise = list(np.linspace(0.1,0.025,int(EPOCHNUM/1.0))) + [0.0]*EPOCHNUM

        if init_graph is not None:
          init_graph = (init_graph - 0.5)*INIT_GRAPH_SHIFT
          self.weights = nn.Parameter(init_graph)
        # Honour the constructor argument (it defaults to TRAIN_NETWORK, so
        # callers that do not pass it behave exactly as before).
        if not train_network:
          self.weights.requires_grad = False
          self.weight_shift.requires_grad = False

    def diag_graph_matrix(self, step_i):
        """Return the symmetric soft adjacency matrix (zero diagonal) used at step `step_i`.

        When GRAPH_DROPOUT is set, Gaussian noise with std random_noise[step_i]
        is added after symmetrisation (so the noise itself is not symmetric).
        """
        m = self.thresholding(self.weights-self.weight_shift)
        m = torch.triu(m)
        m = (m + m.t())
        m.fill_diagonal_(0.0) #  TODO use diagonal=1 instead at triu(m)
        if GRAPH_DROPOUT:
            noise = self.random_noise[step_i]
            m = m + (torch.randn(m.shape)*noise).to(DEVICE)
        return m

    def increase_temperature(self, step_i, increase=None):
        """Add `increase` (default INCREASE_TEMPERATURE) to self.temperature.

        See the NOTE in __init__: thresholding does not read self.temperature.
        """
        increase = INCREASE_TEMPERATURE if increase is None else float(increase)
        self.temperature += increase

    def forward(self, x, step_i=0, graph_collapse = False):
        """Aggregate neighbour states: adjacency.mm(x), reshaped to rows of `state_num`.

        With graph_collapse=True the soft adjacency is rounded to 0/1 at 0.5.
        NOTE(review): with GRAPH_DROPOUT on, noise is still added here,
        including evaluation calls with the default step_i=0 (std 0.1) --
        confirm whether evaluation should be noise-free.
        """
        adj_m = self.diag_graph_matrix(step_i)
        if graph_collapse:
          adj_m = (adj_m > 0.5).float()
        env_count = adj_m.mm(x)
        env_count = env_count.view(-1,self.state_num)
        return env_count

    def collapse(self, stochastic=True):
      """Overwrite weights with +/-0.1 according to the sign of the upper triangle (diagonal 0)."""
      # TODO fix: `stochastic` only draws an unused cut_off (kept because it consumes
      # numpy's global RNG, which later draws depend on).
      with torch.no_grad():
        weights = self.weights
        weights_new = torch.zeros(weights.shape)

        for i in range(weights.shape[0]):
          for j in range(weights.shape[1]):
            if i==j:
              weights_new[i,j] = 0.0
              continue
            if i<j:
              w_ij = weights[i,j]
            else:
              w_ij = weights[j,i]
            if stochastic:
              cut_off = np.random.random()
            weights_new[i,j] = 0.1 if w_ij > 0.0 else -0.1

        for i in range(weights.shape[0]):
          for j in range(weights.shape[1]):
            self.weights[i,j] = weights_new[i,j]

    def threshold_egelist(self):
      """Return edges [i, j] (i <= j) whose soft adjacency value exceeds 0.5.

      NOTE(review): diag_graph_matrix(EPOCHNUM-1) still adds noise with std
      random_noise[EPOCHNUM-1] (= 0.025) when GRAPH_DROPOUT is on; an index
      >= EPOCHNUM would be noise-free -- confirm which was intended.
      """
      edgelist = list()
      with torch.no_grad():
        weights = self.diag_graph_matrix(EPOCHNUM-1)
        for i in range(weights.shape[0]):
          for j in range(weights.shape[1]):
            if i>j:
              continue
            if weights[i,j] > 0.5:
              edgelist.append([i, j])
      return edgelist

    def threshold_nx(self):
      """Return threshold_egelist() as a networkx graph on nodes 0..network_size-1."""
      import networkx as nx
      edgelist = self.threshold_egelist()
      G = nx.Graph()
      G.add_nodes_from(range(self.network_size))
      G.add_edges_from(edgelist)
      return G

    def set_graph(self, edge_list):
      """Overwrite the weights with a hard graph: 3.0 on edges, -3.0 elsewhere, 0.0 on the diagonal."""
      with torch.no_grad():
        # (The previous loop referenced an undefined `weights` and raised NameError.)
        self.weights.fill_(-3.0)
        self.weights.fill_diagonal_(0.0)
        for i, j in edge_list:
          self.weights[i,j] = 3.0
          self.weights[j,i] = 3.0


class MyGraphNetwork(nn.Module):
    """GINA model: a learnable graph layer (MyNetworkLayer) followed by a small MLP.

    The graph layer turns node states into (soft) neighbour-state counts. The
    MLP (fc1, fc2, optional fc3, fc4) maps each count vector to a softmax
    distribution over the `state_num` local states.
    """
    def __init__(self, network_size = 4, state_num=2, init_graph=None):
        super().__init__()
        self.state_num = state_num
        self.network_size = network_size
        state_num_latent = state_num
        if INDIVIDUAL_INPUT:
            state_num_latent *= 2
        self.state_num_latent = state_num_latent
        self.dropout = nn.Dropout(p=0.05)  # only referenced in commented-out code
        
        #Create Network Layer
        self.network_push = MyNetworkLayer(network_size,state_num=state_num_latent, init_graph=init_graph)
        
        # Fully Connected
        self.fc1 = nn.Linear(state_num_latent, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 10)  #optional (only used when USE_LAYER_3)
        self.fc4 = nn.Linear(10, state_num)
        if INDIVIDUAL_OUTPUT:
          # One state_num x state_num projection per node, initialised to identity.
          individual_output_weight = torch.eye(state_num)
          individual_output_weight = individual_output_weight.reshape((1, state_num, state_num))
          individual_output_weight =  individual_output_weight.repeat(network_size, 1, 1)
          self.individual_output_weight = nn.Parameter(individual_output_weight) 
        if INDIVIDUAL_INPUT:
          # One state_num x state_num_latent embedding per node (the same random
          # init is tiled over nodes before kaiming_uniform_ re-initialises it).
          individual_input_weight = torch.Tensor(state_num, state_num_latent)
          individual_input_weight = individual_input_weight.reshape((1, state_num, state_num_latent))
          individual_input_weight =  individual_input_weight.repeat(network_size, 1, 1)
          self.individual_input_weight = nn.Parameter(individual_input_weight)
          nn.init.kaiming_uniform_(self.individual_input_weight, a=math.sqrt(5))
            
            
    def forward(self, x, step_i=0, graph_collapse=False):
        """Predict each node's state distribution from its neighbours' states.

        Args:
          x: node-state matrix with network_size rows. training() passes
             snapshots concatenated along dim 1, i.e. presumably shape
             (network_size, num_snapshots * state_num) -- TODO confirm for other callers.
          step_i: training step; selects the noise level in MyNetworkLayer.
          graph_collapse: round the soft adjacency to 0/1 (used for evaluation).

        Returns:
          Tensor of shape (-1, state_num) with a softmax over the last dim. Row
          order follows MyNetworkLayer.forward's view(-1, state_num).
        """
        
        # Prepare Network pass
        if INDIVIDUAL_INPUT:
          x = self.compute_individual_input(x)

        # Network Pass
        x =  self.network_push(x, step_i, graph_collapse=graph_collapse)
        # MLP Pass
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        #x = self.dropout(x)
        if USE_LAYER_3:
            x = self.fc3(x)
            x = F.relu(x)
        x = self.fc4(x)

        output = F.softmax(x, dim=1)
        
        # `temperature > 0` is always true with the default schedule, so this
        # branch is effectively controlled by INDIVIDUAL_OUTPUT alone.
        if INDIVIDUAL_OUTPUT and self.network_push.temperature > 0:
          # version 1
          #sample_num = x.shape[0]
          #repeat = int(sample_num/self.network_size)
          #sample_range = list(range(self.network_size)) * repeat
          #x_new = x.clone().detach()
          #for node_i in range(self.network_size):
          #  node_i_mask = [1 if i == node_i else 0 for i in sample_range]
          #  weight_i = self.individual_output_weight[node_i,:,:]
          #  x_new[node_i_mask,:] = x[node_i_mask,:].matmul(weight_i)
          #output = F.softmax(x, dim=1)
            
          # version 2
          # NOTE(review): .repeat(snapshot_num,1,1) tiles the per-node weights as
          # node0..nodeN-1, node0..., but with x of shape
          # (network_size, snapshots*state_num) the rows after view(-1, state_num)
          # look node-major (node0 for all snapshots first). If so, rows are paired
          # with the wrong node's weights -- verify before enabling INDIVIDUAL_OUTPUT.
          line_num = int(x.shape[0])
          snapshot_num = int(line_num/self.network_size) 
          #print('line_num ', line_num, ' snapshot_num ',snapshot_num, '  output shape ', output.shape)
          individual_output_weight_r = self.individual_output_weight.repeat(snapshot_num,1,1)
          #print('individual_output_weight_r ',individual_output_weight_r.shape)
          output = output.view(line_num,1,self.state_num)
          output = output.matmul(individual_output_weight_r).view(line_num,self.state_num)
          output = F.softmax(output, dim=1)
        return output

    def compute_individual_input(self, x):
      """Apply the per-node input embedding (INDIVIDUAL_INPUT, marked unused in the hyperparameters).

      NOTE(review): after the matmul each row has state_num_latent entries, but
      it is viewed back with state_num columns before the softmax -- confirm the
      intended shape; this path is not exercised with the default flags.
      """
      x = x.view(-1, self.state_num)
      line_num = int(x.shape[0])
      snapshot_num = int(line_num/self.network_size) 
      individual_input_weight_r = self.individual_input_weight.repeat(snapshot_num,1,1)
      x = x.view(line_num,1,self.state_num)
      x = x.matmul(individual_input_weight_r)
      #x = x.view(line_num,self.state_num_latent)
      #x = F.relu(x)
      x = x.view(-1,self.state_num)
      x  =  F.softmax(x, dim=1)
      x = x.view(self.network_size,-1)
      return x
    
    # for debugging
    def forward_counts(self, x):
      """Run only the MLP on a single neighbour-count vector (no graph layer, no grad).

      Used by plot_dynamics. With INDIVIDUAL_INPUT it returns a zero tensor
      created with torch's default device, without consulting the model.
      """
      with torch.no_grad():
        x = x.view([1,self.state_num])
        if INDIVIDUAL_INPUT:
            #x = x.view(-1,self.state_num)
            #x = self.compute_individual_input(x) #stupid hack, not really valid
            #x = x.view([1,self.state_num])
            return torch.zeros(x.shape)
        # MLP Pass
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        if USE_LAYER_3:
          x = self.fc3(x)
          x = F.relu(x)
            
        x = self.fc4(x)
        output = F.softmax(x, dim=1)
        return output

# Training

### Training Tools

In [None]:
def test_net_new(model, test_loader):
  """Evaluate `model` on every batch yielded by a SnapshotLoader.

  Returns:
    (accuracy, mse_loss). Accuracy is the fraction of rows whose argmax
    matches the ground truth. The loss is the MSE weighted by element count.
    Both are aggregated over all batches; with a single batch they equal that
    batch's values.

  The model is put back into training mode afterwards, even if evaluation raises.

  Raises:
    ValueError: if the loader yields no data.
  """
  criterion = nn.MSELoss()
  model.eval()
  total_rows = 0
  total_elems = 0
  correct = 0
  weighted_loss = 0.0
  try:
    with torch.no_grad():
      for snapshots in test_loader.get_data():
        output = model(snapshots, graph_collapse=COLLAPSE_IN_TEST)
        ground_truth = snapshots.view(-1, model.state_num)
        n_elems = ground_truth.numel()
        weighted_loss += float(criterion(output, ground_truth)) * n_elems
        total_elems += n_elems
        hits = torch.argmax(output, dim=1) == torch.argmax(ground_truth, dim=1)
        correct += int(torch.sum(hits))
        total_rows += hits.shape[0]
  finally:
    model.train()
  if total_rows == 0:
    raise ValueError('test_loader yielded no data')
  return correct / total_rows, weighted_loss / total_elems

def test_net(model, test_data):
  criterion = nn.MSELoss()
  model.eval()
  with torch.no_grad():
    acc_list = list()
    loss_list = list()
    snapshots = torch.cat(test_data, dim=1)
    output = model(snapshots, graph_collapse=True)
    ground_truth = snapshots.view(-1,model.state_num)
    loss_i = float(criterion(output, ground_truth))
    acc_i = torch.argmax(output, dim=1) == torch.argmax(ground_truth, dim=1)
    acc_i = float(torch.sum(acc_i) / acc_i.shape[0])
    model.train()
    return acc_i, loss_i

def load_model(model, exp_name):
  """Resume `model` from the latest checkpoint in HOME+exp_name+'/weights/'.

  Returns:
    (record_df, start_epoch). record_df is the summary.csv history and
    start_epoch is its last 'epoch' value. Returns (None, 0) when no
    checkpoint exists.
  """
  import glob
  weight_list = glob.glob(HOME+exp_name+'/weights/train_*.weight')
  if len(weight_list) == 0:
    print('start without init weights')
    return None, 0
  # Checkpoint names embed int2str(epoch) (zero-padded), so lexicographic
  # order equals epoch order.
  weight_path = sorted(weight_list)[-1]
  print('load: '+weight_path)
  # map_location lets checkpoints saved on a GPU be restored on a CPU-only machine.
  model.load_state_dict(torch.load(weight_path, map_location=DEVICE))

  record_df = pd.read_csv(HOME+exp_name+'/summary.csv', sep=';')
  start_shift = list(record_df['epoch'])[-1]

  return record_df, int(start_shift)


def start_training(*args, **kwargs):
  """Return cached results for experiment args[0], or run training() and cache them.

  Results live in HOME+exp_name+'/result.pickel'. If that file is missing or
  cannot be unpickled, training(*args, **kwargs) is run and its result pickled
  there.

  SECURITY: pickle.load can execute arbitrary code from the file; only use
  result files produced by this notebook.
  """
  import pickle
  exp_name = args[0]
  assert isinstance(exp_name, str)
  filepath = HOME+exp_name+'/result.pickel'

  try:
    with open(filepath, "rb") as f:
      return_values = pickle.load(f)
  except Exception:  # missing/corrupt cache: fall through and (re)train
    pass
  else:
    print('read results: ',filepath)
    return return_values

  return_values = training(*args, **kwargs)
  with open(filepath, "wb") as f:
    pickle.dump(return_values, f)
  return return_values

### Training Rounds

In [None]:
def training(exp_name, snapshots, ground_truth_graph=None, state_num=2, init_graph=None):
  """Train GINA on `snapshots`, periodically evaluating, plotting and checkpointing.

  Every OUTPUT_EACH_X_EPOCHS epochs (and at epoch 0 and EPOCHNUM-1) the model
  is evaluated and the thresholded graph extracted. A row is then appended to
  HOME+exp_name+'/summary.csv' and the weights are saved. Training resumes
  from the latest checkpoint if one exists (load_model). With EARLY_STOPPING,
  training stops once the extracted graph is unchanged over the last 10
  evaluations (after epoch 50).

  Returns:
    (record_df, graphdist): the summary DataFrame and the graph distance at
    the last evaluation. graphdist is -1.0 if ground_truth_graph is None or no
    evaluation ran (e.g. resuming after the final epoch).
  """
  from torch.optim import Adam
  assert(len(snapshots[0][0]) == state_num) # is model.state_num
  network_size = len(snapshots[0])
  lr = LEARNING_RATE
  epochs = EPOCHNUM
  training_snapshots, test_snapshots = split_snapshots(snapshots)
  model = MyGraphNetwork(network_size=network_size, state_num=state_num, init_graph=init_graph)
  model = model.to(DEVICE)
  record_df, start_shift = load_model(model, exp_name)
  optimizer = Adam(model.parameters(), lr = lr)
  criterion = nn.MSELoss()
  if record_df is None:
    record_df = {'epoch': list(), 'loss_train': list(), 'loss_test': list(), 'acc_test': list(), 'graph_edges': list(), 'weightpath': list(), 'graphdist': list()}
    record_df = pd.DataFrame(record_df)

  if ground_truth_graph is not None:
    plot_graph(ground_truth_graph, 'ground_truth', exp_name=exp_name)

  start_time = time.time() # measure time per epoch
  forward_count = 0
  best_mean_test_acc_sofar = -1.0
  decr_epoch_count = 0
  best_graphdiff_sofar = 0
  processed_samples = 0
  overall_loss = 0.0
  graph_epoch_lists = [list() for _ in range(10)] # start with empty graphs
  graphdist = -1.0  # returned as-is if no evaluation happens
  weightpath = HOME+exp_name+'/weights/train_{}.weight'

  for step_j in range(epochs):
    step_i = step_j+start_shift
    if step_i >= EPOCHNUM:
        break
    random.shuffle(training_snapshots)
    epoch_loss = list()
    for cut in index_set(len(training_snapshots)):
      if step_i == 0: # to make sure 1st evaluation is based on untrained weight matrix
        break

      snapshot_minibatch = torch.cat(training_snapshots[cut[0]:cut[1]], dim=1)
      model.zero_grad()
      processed_samples += cut[1] - cut[0]
      output = model(snapshot_minibatch, step_i)
      forward_count += 1
      ground_truth = snapshot_minibatch.view(-1,model.state_num)
      loss = criterion(output, ground_truth)
      loss.backward()
      overall_loss += float(loss)
      epoch_loss.append(float(loss))
      optimizer.step()

    if (step_i == 0 or step_i % OUTPUT_EACH_X_EPOCHS == 0) or step_i == EPOCHNUM-1:
      # dummy values to avoid div by zero
      if forward_count == 0:
         forward_count = 0.00001
      if processed_samples == 0:
        processed_samples = 0.00001
      if USE_VALIDATIONSET:
        mean_test_acc, mean_test_loss = test_net(model, test_snapshots)
      else:
        mean_test_acc, mean_test_loss = -1, -1
      graph_edges = model.network_push.threshold_egelist()
      weightpath_i = weightpath.format(int2str(step_i))
      if ground_truth_graph is not None:
        graphdist = graph_dist(ground_truth_graph, edgelist_to_graph(ground_truth_graph.number_of_nodes(), graph_edges))
      else:
        graphdist = -1.0

      record_i = {'epoch': step_i, 'loss_train': overall_loss/forward_count, 'loss_test': mean_test_loss, 'acc_test': mean_test_acc, 'graph_edges': str(graph_edges), 'weightpath': weightpath_i, 'graphdist': graphdist}
      # DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
      record_df = pd.concat([record_df, pd.DataFrame([record_i])], ignore_index=True)

      plot_graph(model.network_push.threshold_nx(), 'trainplos_'+int2str(step_i), exp_name=exp_name, ground_truth=ground_truth_graph)

      plot_dynamics('dyn_'+str(10000000000000+step_i), exp_name, model, state_num = state_num)
      torch.save(model.state_dict(), weightpath_i)
      record_df['exp_name'] = exp_name
      record_df.to_csv(HOME+exp_name+'/summary.csv', sep=';', index=False)
      best = ''
      mean_epoch_loss = np.mean(epoch_loss) if len(epoch_loss)>0 else -1.0

      if mean_test_acc > best_mean_test_acc_sofar:
        best_mean_test_acc_sofar = mean_test_acc
        best = '  !'
        decr_epoch_count = 0
        best_graphdiff_sofar = graphdist
      decr_epoch_count += 1
      print('rolling loss: {:.6f}'.format(overall_loss/forward_count), '  epoch loss: {:.6f}'.format(mean_epoch_loss), '  test loss: {:.6f}'.format(mean_test_loss) , '  test acc: {:.6f}'.format(mean_test_acc) ,'graphdist: ',graphdist, '   mean time: {:.8f}'.format((time.time() - start_time)/processed_samples) + best)
      model.network_push.increase_temperature(step_i)

      # Stop when the extracted graph has not changed over the last 10 evaluations.
      if EARLY_STOPPING and step_i > 50 and not False in [graph_edges == g for g in graph_epoch_lists]:
        print('early stopping')
        break
      graph_epoch_lists = graph_epoch_lists[-9:] + [sorted(graph_edges)]

      # plot weight distribution
      plt.clf()
      plt.close()
      sns.displot(model.network_push.weights.cpu().detach().flatten(), kde=True)
      plt.savefig(weightpath.format(int2str(step_i))+'_w.png')

  print('finish training')
  plot_graph(model.network_push.threshold_nx(), 'final_graph', exp_name=exp_name, ground_truth=ground_truth_graph, is_final=True)
  record_df['exp_name'] = exp_name

  return record_df, graphdist


# Graph Generators


In [None]:
# Candidate ground-truth graphs. Random generators use fixed seeds, so every
# run produces the same graphs.

# deterministic
G_karate = nx.karate_club_graph()
G_grid5x5 = nx.grid_2d_graph(5,5)
G_grid10x10 = nx.grid_2d_graph(10,10)
G_grid20x20 = nx.grid_2d_graph(20,20)
G_lollipop = nx.lollipop_graph(10,10)
G_circular_ladder = nx.circular_ladder_graph(100)

# random (seeded)
G_erdos_small = nx.erdos_renyi_graph(25, 0.15, seed=43)
G_erdos_small_wellcon = nx.erdos_renyi_graph(20, 0.35, seed=42)
G_erdos = nx.erdos_renyi_graph(50, 0.1, seed=42)
G_ba_small = nx.barabasi_albert_graph(25,3, seed=42)
G_ba = nx.barabasi_albert_graph(100,4, seed=42)
G_geom_small = nx.random_geometric_graph(50, 0.3, seed=42)
G_geom = nx.random_geometric_graph(200, 0.125, seed=42)
G_geom_large = nx.random_geometric_graph(400, 0.1, seed=42)
G_wsn = nx.newman_watts_strogatz_graph(50, 4, 0.15, seed=42)

# tiny
G_erdos_tiny= nx.erdos_renyi_graph(10, 0.5, seed=42)
G_lollipop_tiny = nx.lollipop_graph(5,5)
G_circular_ladder_tiny = nx.circular_ladder_graph(10)
G_grid_tiny = nx.grid_2d_graph(3,3)

# use these graphs:
# Each one is passed through clean_shuffle_graph (fixed-seed node permutation,
# labels converted to 0..n-1; raises AssertionError if not connected).
# clean_shuffle_graph touches the global `random` state, so keep this cell's order.
ground_truth_graphset = {'G_erdos_small': G_erdos_small, 'G_geom': G_geom, 'G_grid10x10': G_grid10x10, 'G_wsn':G_wsn}
ground_truth_graphset = {graph_name: clean_shuffle_graph(g) for graph_name, g in ground_truth_graphset.items()}

#ground_atlas = [g for g in nx.graph_atlas_g() if g.number_of_nodes()==7 and nx.is_connected(g)]
#ground_atlas = {'atlas'+str(100000+i):g for i,g in enumerate(ground_atlas)}

# Data Generation

### Inverted Voter

In [None]:
def gen_invvoter(G, noise=0.01):
    """Sample one snapshot of the inverted voter model on ``G``.

    Every node starts in a uniformly random state. The process then performs
    1000-1999 single-node flips. Before each flip, node ``n`` has rate
    ``(#neighbours sharing n's state) + noise``. The node with the smallest
    exponential waiting time switches to the other state.

    Assumes nodes are labelled ``0..N-1``.

    Returns:
        list with one one-hot state per node, ``[1., 0.]`` or ``[0., 1.]``.
    """
    first, second = [1., 0.], [0., 1.]
    node_count = G.number_of_nodes()
    step_count = 1000 + random.choice(range(1000))
    states = [random.choice([first, second]) for _ in range(node_count)]

    for _ in range(step_count):
        rates = np.zeros(node_count)
        for node in range(node_count):
            agreeing = sum(states[nb] == states[node] for nb in G.neighbors(node))
            rates[node] = agreeing + noise
        # numpy's exponential is parameterised by the scale (mean = 1 / rate)
        waiting_times = np.random.exponential(1.0 / rates)
        flipper = np.argmin(waiting_times)
        states[flipper] = second if states[flipper] == first else first
    return states

### SIS

In [None]:
def gen_sis(G, inf_rate=1.0, rec_rate=2.0, noise=0.1):
    """Sample one snapshot of a stochastic SIS process on ``G``.

    Every node starts uniformly at random in S or I. The process then performs
    1000-1999 single-node flips. A node's rate is ``noise``, plus ``rec_rate``
    if it is infected, or plus ``inf_rate * (#infected neighbours)`` if it is
    susceptible. The node with the smallest exponential waiting time toggles
    between S and I.

    Assumes nodes are labelled ``0..N-1``.

    Returns:
        list with one one-hot state per node: ``[1., 0.]`` (S) or ``[0., 1.]`` (I).
    """
    susceptible, infected = [1., 0.], [0., 1.]
    node_count = G.number_of_nodes()
    step_count = 1000 + random.choice(range(1000))
    states = [random.choice([susceptible, infected]) for _ in range(node_count)]

    for _ in range(step_count):
        rates = np.zeros(node_count)
        for node in range(node_count):
            rate = noise
            if states[node] == infected:
                rate += rec_rate
            if states[node] == susceptible:
                rate += inf_rate * sum(states[nb] == infected for nb in G.neighbors(node))
            rates[node] = rate
        # numpy's exponential is parameterised by the scale (mean = 1 / rate)
        waiting_times = np.random.exponential(1.0 / rates)
        flipper = np.argmin(waiting_times)
        states[flipper] = infected if states[flipper] == susceptible else susceptible
    return states

### Rock-Paper-Scissors

In [None]:

def gen_rps(G, change_rate=1.0, noise=0.1):
    """Sample one snapshot of a stochastic rock-paper-scissors process on ``G``.

    The states cycle rock -> paper -> scissors -> rock, and each state is
    beaten by the next one in the cycle. A node's rate is
    ``noise + change_rate * (#neighbours in the state that beats it)``. Over
    1000-1999 steps, the node with the smallest exponential waiting time
    advances to the next state in the cycle.

    Assumes nodes are labelled ``0..N-1``.

    Returns:
        list with one one-hot state per node (``[1,0,0]`` R, ``[0,1,0]`` P, ``[0,0,1]`` S).
    """
    rock, paper, scissors = [1, 0, 0], [0, 1, 0], [0, 0, 1]
    cycle = [rock, paper, scissors]

    def beaten_by(state):
        # the successor in the cycle wins against `state`
        return cycle[(cycle.index(state) + 1) % len(cycle)]

    node_count = G.number_of_nodes()
    step_count = 1000 + random.choice(range(1000))
    states = [random.choice(cycle) for _ in range(node_count)]

    for _ in range(step_count):
        rates = np.zeros(node_count)
        for node in range(node_count):
            predator = beaten_by(states[node])
            attackers = sum(states[nb] == predator for nb in G.neighbors(node))
            rates[node] = noise + change_rate * attackers
        # numpy's exponential is parameterised by the scale (mean = 1 / rate)
        waiting_times = np.random.exponential(1.0 / rates)
        mover = np.argmin(waiting_times)
        states[mover] = beaten_by(states[mover])
    return states

### Forest Fire Model

In [None]:
def gen_forestfire(G, growth_rate=1.0, lightning_rate=0.1, firespread_rate = 2.0, fireextinct_rate = 2.0, noise=0.01):
    """Sample one snapshot of a stochastic forest-fire model on ``G``.

    The states cycle empty -> tree -> fire -> empty. Every rate includes ``noise``:
      * empty cells grow a tree with ``growth_rate``;
      * trees ignite with ``lightning_rate``, plus ``firespread_rate`` if at
        least one neighbour is burning;
      * fires burn out with ``fireextinct_rate``.

    Over 1000-1999 steps, the node with the smallest exponential waiting time
    advances one state in the cycle.

    Assumes nodes are labelled ``0..N-1``.

    Returns:
        list with one one-hot state per node (``[1,0,0]`` empty, ``[0,1,0]`` tree, ``[0,0,1]`` fire).
    """
    empty, tree, fire = [1, 0, 0], [0, 1, 0], [0, 0, 1]
    cycle = [empty, tree, fire]
    node_count = G.number_of_nodes()
    step_count = 1000 + random.choice(range(1000))
    states = [random.choice(cycle) for _ in range(node_count)]

    for _ in range(step_count):
        rates = np.zeros(node_count)
        for node in range(node_count):
            current = states[node]
            rate = noise
            if current == empty:
                rate += growth_rate
            if current == tree:
                burning_nearby = any(states[nb] == fire for nb in G.neighbors(node))
                rate += firespread_rate if burning_nearby else 0.0
                rate += lightning_rate
            if current == fire:
                rate += fireextinct_rate
            rates[node] = rate
        # numpy's exponential is parameterised by the scale (mean = 1 / rate)
        waiting_times = np.random.exponential(1.0 / rates)
        mover = np.argmin(waiting_times)
        states[mover] = cycle[(cycle.index(states[mover]) + 1) % len(cycle)]
    return states

### Coupled Map Lattice

In [None]:
def solve_cmp(G, s=0.1, r=3.57):
  """Iterate a coupled logistic-map lattice on ``G`` and return the final node values.

  Each node starts at ``random.random()``. The lattice then runs 100-199
  synchronous steps. In each step, node ``n`` becomes
  ``(1 - s) * f(x_n) + s / deg(n) * sum(f(x_j) for neighbours j)``, where
  ``f(x) = r * x * (1 - x)``.

  Values are stored in a list indexed by node, so nodes must be labelled
  ``0..N-1``. A node without neighbours raises ``ZeroDivisionError``.
  """
  def logistic(x):
    return r * x * (1.0 - x)

  values = [random.random() for _ in range(G.number_of_nodes())]
  step_count = 100 + random.choice(range(100))

  for _ in range(step_count):
    updated = values.copy()
    for node in G.nodes():
      neighbours = list(G.neighbors(node))
      coupling = s / len(neighbours) * np.sum([logistic(values[j]) for j in neighbours])
      updated[node] = (1.0 - s) * logistic(values[node]) + coupling
    values = updated

  return values

def gen_cmp(G, s=0.1, r=3.57, levels=2):
    """Discretise one coupled-map-lattice snapshot into one-hot vectors.

    ``[0, 1]`` is split into ``levels`` equal bins. A value's bin index is the
    number of upper bin edges strictly below it.

    Returns:
        list with one one-hot list of length ``levels`` per node.
    """
    values = solve_cmp(G, s, r)
    upper_edges = np.linspace(0, 1, levels + 1)[1:]
    bin_indices = [int(np.sum(upper_edges < values[j])) for j in range(G.number_of_nodes())]

    encoded = []
    for idx in bin_indices:
        row = [0] * levels
        row[idx] = 1
        encoded.append(row)
    return encoded

def gen_cmp10(G, s=0.1, r=3.57):
    """Coupled-map-lattice snapshot discretised into 10 one-hot levels."""
    return gen_cmp(G, s, r, levels=10)

### Majority-Flip

In [None]:
def gen_majorityflip(G, change_rate=1.0, noise=0.01):
    """Sample one snapshot of the majority-flip model on ``G``.

    Each node has rate ``noise``. It additionally gets ``change_rate`` in two cases:
      * it is alive and its alive-neighbour fraction is outside ``[0.2, 0.8]``;
      * it is dead and that fraction is outside ``[0.3, 0.7]``.

    Over 1000-1999 steps, the node with the smallest exponential waiting time
    toggles between alive and dead. A node without neighbours raises
    ``ZeroDivisionError``. Assumes nodes are labelled ``0..N-1``.

    Returns:
        list with one one-hot state per node: ``[1., 0.]`` (alive) or ``[0., 1.]`` (dead).
    """
    alive, dead = [1., 0.], [0., 1.]
    node_count = G.number_of_nodes()
    step_count = 1000 + random.choice(range(1000))
    states = [random.choice([alive, dead]) for _ in range(node_count)]

    for _ in range(step_count):
        rates = np.zeros(node_count)
        for node in range(node_count):
            around = [states[nb] for nb in G.neighbors(node)]
            n_alive, n_dead = around.count(alive), around.count(dead)
            frac_alive = n_alive / (n_alive + n_dead)
            rate = noise
            if states[node] == alive and not (0.2 <= frac_alive <= 0.8):
                rate += change_rate
            if states[node] == dead and not (0.3 <= frac_alive <= 0.7):
                rate += change_rate
            rates[node] = rate
        # numpy's exponential is parameterised by the scale (mean = 1 / rate)
        waiting_times = np.random.exponential(1.0 / rates)
        flipper = np.argmin(waiting_times)
        states[flipper] = dead if states[flipper] == alive else alive
    return states

### Summary Dynamics

In [None]:
# dynamics name -> (snapshot generator, number of node states)
ground_truth_dynamicsset = dict(
    inv_voter=(gen_invvoter, 2),
    sis=(gen_sis, 2),
    rock_paper_scissors=(gen_rps, 3),
    forest_fire=(gen_forestfire, 3),
    cmp=(gen_cmp10, 10),
    majorityflip=(gen_majorityflip, 2),
)

# Putting Everything Together

In [None]:
def experiment(exp_name, G, dynamics, snapshots_num = 10000, state_num=2, overwrite=False, init=False, snapshots = None):
  """Run one GINA reconstruction experiment on the ground-truth graph ``G``.

  Steps:
    1. Write ``G`` to ``HOME + exp_name`` (GML and edge list).
    2. Obtain snapshots: use ``snapshots`` if given; otherwise reuse
       ``snapshots.txt`` from disk (unless ``overwrite``); otherwise simulate
       ``snapshots_num`` snapshots with ``dynamics`` and save them.
    3. Call ``start_training``.

  Args:
    exp_name: experiment name; used as sub-folder of ``HOME``.
    G: ground-truth networkx graph; must be connected with nodes ``0..N-1``.
    dynamics: callable ``dynamics(G) -> snapshot``. Only called when
      snapshots have to be generated.
    snapshots_num: number of snapshots to simulate when generating.
    state_num: number of node states, forwarded to ``start_training``.
    overwrite: if True, never reuse ``snapshots.txt`` from disk.
    init: initial graph for training. A networkx graph is converted with
      ``nx_to_adj``; ``True`` uses ``get_baseline_graph(snapshots)``; ``False``
      means no initial graph.
    snapshots: precomputed snapshots. If given, nothing is read or simulated.

  Returns:
    ``(score, snapshots, record)``, where ``record, score`` are the two values
    returned by ``start_training``.

  Raises:
    ValueError: if ``G`` is not connected or its nodes are not ``0..N-1``.
  """
  gen_folders(exp_name, home=HOME)
  time.sleep(0.1)
  exp_dir = HOME + exp_name + '/'

  # Always persist the ground-truth graph. The previous version guarded a
  # "reuse cached graph" branch with `assert(False)`. That branch was dead
  # normally but became live under `python -O` (asserts stripped).
  print('write ground truth graph: ', exp_dir + 'ground_truth_graph.gml')
  nx.write_gml(G, exp_dir + 'ground_truth_graph.gml')
  nx.write_edgelist(G, exp_dir + 'ground_truth_graph.edgelist', data=False)

  # Explicit checks rather than `assert`, so they also run under `python -O`.
  if set(range(G.number_of_nodes())) != set(G.nodes()):
    raise ValueError('ground-truth graph nodes must be labelled 0..N-1')
  if not nx.is_connected(G):
    raise ValueError('ground-truth graph must be connected')

  if snapshots is None:
    snapshot_path = exp_dir + 'snapshots.txt'
    if not overwrite:
      try:
        snapshots = read_snapshots(snapshot_path)
        print('found snapshot data, using this instead. Length is: ', len(snapshots))
      except Exception:
        # Best effort: a missing or unreadable cache just triggers regeneration.
        snapshots = None
    if snapshots is None:
      print('write snapshots')
      time.sleep(0.1)
      snapshots = [dynamics(G) for _ in tqdm(range(snapshots_num))]
      write_snapshots(snapshots, snapshot_path)
      plot_snapshots(G, snapshots, exp_name)

  print('start training')
  init_graph = None
  if isinstance(init, nx.Graph):
    init_graph = nx_to_adj(init)
  elif isinstance(init, (bool, np.bool_)) and init:
    init_graph = get_baseline_graph(snapshots)
  record, score = start_training(exp_name, snapshots, ground_truth_graph=G, state_num=state_num, init_graph=init_graph)
  return score, snapshots, record



# Run Experiments

### Exp 0: Test Random Model

In [None]:
def example_run(state_num=2):
  """Smoke test: reconstruct a shuffled small Barabasi-Albert graph from 3000 SIS snapshots.

  Also prints the graph distance reached by each baseline in ``baseline_summary``.
  """
  graph = clean_shuffle_graph(G_ba_small)
  _, snapshots, _ = experiment("SIS_5k", graph, gen_sis, snapshots_num=3000, state_num=state_num, init=False)
  baselines = baseline_summary(snapshots, graph)
  for method, (distance, _elapsed) in baselines.items():
    print(method, ':   ', distance)

 
# Exp 0 configuration (rebinds the global hyperparameters defined at the top).
EPOCHNUM = 20 * 100
OUTPUT_EACH_X_EPOCHS = 100
MINIBATCH_SIZE = 100
USE_VALIDATIONSET = False
INDIVIDUAL_INPUT = False
INDIVIDUAL_OUTPUT = True      # custom output projection for each node
RANDOM_TEMPERATURE = True     # stochastic noise in the step function
INCREASE_TEMPERATURE = 1.0    # sharpness increase over the course of training

example_run()

### Exp 1: Identifiability


In [None]:
def _append_atlas_record(atlas_summary, record):
  """Append, for each tracked column, the best (acc/loss) or final value of ``record``."""
  for key in atlas_summary:
    values = list(record[key])
    if key == 'acc_test':
      atlas_summary[key].append(np.max(values[1:]))
    elif 'loss' in key:
      atlas_summary[key].append(np.min(values[1:]))
    else:
      atlas_summary[key].append(values[-1])


def _run_atlas(dynamic_func, state_num):
  """Run (or load from ``HOME + 'atlas_exp.pickle'``) the atlas sweep using the current global ``HOME``."""
  graph_id = 426  # only used in the ground-truth experiment name
  picklepath = HOME + 'atlas_exp.pickle'
  try:
    with open(picklepath, "rb") as f:
      return_values = pickle.load(f)
    print('read atlas results: ', picklepath)
    return return_values
  except Exception:
    pass  # best effort: no readable cached results, run the experiment

  atlas_summary = {'exp_name': list(), 'epoch': list(), 'loss_train': list(), 'loss_test': list(), 'acc_test': list(), 'graph_edges': list(), 'weightpath': list(), 'graphdist': list()}

  ground_atlas = create_all_graphs(node_num=5, check_if_connected=True)

  G = clean_shuffle_graph(nx.generators.small.bull_graph())
  samples = 40000  # sample num should be 40000

  # Reference run: the ground truth itself is the initial graph. Its snapshots
  # are reused for every candidate graph below.
  _, snapshots, record = experiment('atlasAll_'+str(graph_id)+'_gt', G, dynamic_func, snapshots_num=samples, state_num=state_num, init=G, snapshots=None)
  _append_atlas_record(atlas_summary, record)

  # Fixed-seed shuffle for a reproducible candidate order, then reseed from system entropy.
  atlas_random = list(enumerate(ground_atlas))
  random.seed(42)
  random.shuffle(atlas_random)
  random.seed(None)

  for counter, (i, G_i) in enumerate(atlas_random, start=1):
    _, _, record = experiment('atlasAll_'+str(i), G, dynamic_func, snapshots_num=None, state_num=state_num, init=G_i, snapshots=snapshots)
    _append_atlas_record(atlas_summary, record)
    if counter % 20 == 0:
      atlas_summary_df = pd.DataFrame(atlas_summary)
      atlas_summary_df.to_csv(HOME+'atlas_summary.csv')
      plot_atlas(HOME+'atlas_summary.pdf', atlas_summary_df)
      print(counter)

  atlas_summary_df = pd.DataFrame(atlas_summary)
  atlas_summary_df.to_csv(HOME+'atlas_summary.csv')
  print(atlas_summary_df)
  plot_atlas(HOME+'atlas_summary.pdf', atlas_summary_df)
  with open(picklepath, "wb") as f:
    pickle.dump(atlas_summary_df, f)
  return atlas_summary_df


def test_atlas(dynamic_func, exp_name, indiv_out=True, state_num=2, individual_input= False):
  """Exp 1 (identifiability): train GINA with a fixed candidate graph for every candidate.

  First, snapshots of ``dynamic_func`` are simulated once on the shuffled bull
  graph, with the ground truth as initial graph. Then training is repeated
  with each graph from ``create_all_graphs(node_num=5, check_if_connected=True)``
  as initial graph. ``TRAIN_NETWORK = False`` keeps the graph fixed and only
  trains the MLP weights (per the hyperparameter comment).

  Rebinds several global hyperparameters. ``HOME`` is set to
  ``./EXP1/exp1_<exp_name>/`` for the duration of the call and restored
  afterwards.

  Returns:
    DataFrame with the best/final metrics per candidate graph, or the cached
    pickle from a previous run.
  """
  global GRAPH_DROPOUT, EARLY_STOPPING, TRAIN_NETWORK, MINIBATCH_SIZE, USESHIFT, TEMPERATURE, LEARNING_RATE, INIT_GRAPH_SHIFT, HOME, INDIVIDUAL_OUTPUT, EPOCHNUM, INCREASE_TEMPERATURE, USE_VALIDATIONSET, INDIVIDUAL_INPUT
  TRAIN_NETWORK = False
  MINIBATCH_SIZE = 1000  # was misspelled MINI_BATCHSIZE, so it never took effect
  USESHIFT = True
  TEMPERATURE = 20.0
  INCREASE_TEMPERATURE = 0.0
  LEARNING_RATE = 0.001
  INIT_GRAPH_SHIFT = 10
  INDIVIDUAL_OUTPUT = indiv_out
  EPOCHNUM = 8 if INDIVIDUAL_OUTPUT else 6 * 100
  USE_VALIDATIONSET = False
  INDIVIDUAL_INPUT = individual_input
  EARLY_STOPPING = False
  GRAPH_DROPOUT = False

  previous_home = HOME
  os.makedirs('EXP1', exist_ok=True)
  HOME = './EXP1/exp1_{}/'.format(exp_name)
  print("run atlas: ", exp_name)
  try:
    return _run_atlas(dynamic_func, state_num)
  finally:
    # Previously HOME was overwritten with a Colab-only path; restore the caller's value instead.
    HOME = previous_home

# Exp 1 for every dynamics model, without individual output layers.
for dynamicsname, (dyn_func, state_num) in ground_truth_dynamicsset.items():
    test_atlas(dyn_func, dynamicsname + "_exp1", indiv_out=False, state_num=state_num)


### Exp 2: Reconstruction Accuracy w.r.t. epoch num/sample num/node num


In [None]:
def exp_samplenum(dynamic_func, exp_name, state_num=2, random_temperature=False, indiv_out=False):
    """Exp 2: reconstruction accuracy versus graph size and number of snapshots.

    Covers square grid graphs of side 5, 7 and 10 and sample sizes of
    100 / 1000 / 10000 snapshots, with 5 runs each. ``MINIBATCH_SIZE`` is set
    to a tenth of the sample size. The summary CSV and plot are refreshed
    every second run. Results are pickled to ``HOME + 'EXP2data.pickle'``,
    which is returned directly if it already exists.

    Note: ``indiv_out`` is currently ignored; ``INDIVIDUAL_OUTPUT`` is always
    True in this experiment.
    """
    global EPOCHNUM, GRAPH_DROPOUT, USE_VALIDATIONSET, HOME, MINIBATCH_SIZE, INCREASE_TEMPERATURE, RANDOM_TEMPERATURE, OUTPUT_EACH_X_EPOCHS, EARLY_STOPPING, INDIVIDUAL_OUTPUT, INDIVIDUAL_INPUT
    INCREASE_TEMPERATURE = 1.0
    RANDOM_TEMPERATURE = random_temperature  # default False, as before
    OUTPUT_EACH_X_EPOCHS = 50
    MINIBATCH_SIZE = 100  # overridden per sample size below
    USE_VALIDATIONSET = False
    INDIVIDUAL_OUTPUT = True
    INDIVIDUAL_INPUT = False  # previously a dead local: it was missing from `global`
    GRAPH_DROPOUT = True
    EPOCHNUM = 40*100
    EARLY_STOPPING = False

    os.makedirs('EXP2', exist_ok=True)
    HOME = './EXP2/exp2_{}/'.format(exp_name)

    print("run Exp2: ", exp_name)

    picklepath = HOME + 'EXP2data.pickle'
    try:
        with open(picklepath, "rb") as f:
            return_values = pickle.load(f)
        print('read EXP2 results: ', picklepath)
        return return_values
    except Exception:
        pass  # best effort: no readable cached results, run the experiment

    # DataFrame.append was removed in pandas 2.0: collect records and concat instead.
    records = []
    counter = 0
    for grid_dim in [5, 7, 10]:
        G = nx.convert_node_labels_to_integers(nx.grid_2d_graph(grid_dim, grid_dim))
        for sample_num in [100, 1000, 10*1000]:
            MINIBATCH_SIZE = int(sample_num/10)
            for run_id in range(5):
                counter += 1
                exp_name_i = exp_name+'_'+str(grid_dim).zfill(10)+'_'+str(sample_num).zfill(10)+'_'+str(run_id)
                _, _, record = experiment(exp_name_i, G, dynamic_func, snapshots_num=sample_num, state_num=state_num)
                record['sample_num'] = sample_num
                record['grid_dim'] = grid_dim
                record['run_id'] = run_id
                records.append(record)
                if counter % 2 == 0:
                    exp_summary = pd.concat(records, ignore_index=True)
                    exp_summary.to_csv(HOME+'EXP2_summary.csv')
                    plot_exp2(HOME+'EXP2_summary.pdf', exp_summary)

    exp_summary = pd.concat(records, ignore_index=True)
    exp_summary.to_csv(HOME+'EXP2_summary.csv')
    plot_exp2(HOME+'EXP2_summary.pdf', exp_summary)
    print(exp_summary)
    with open(picklepath, "wb") as f:
        pickle.dump(exp_summary, f)
    return exp_summary

# Exp 2 for every dynamics model (no validation split).
USE_VALIDATIONSET = False
for dynamicsname, (dyn_func, state_num) in ground_truth_dynamicsset.items():
    exp_samplenum(dyn_func, dynamicsname + "_exp2", state_num=state_num)


### Exp 3: GBN vs Baselines

In [None]:
def load_baseline(snapshots, graph, exp_name_i):
    """Return baseline results for experiment ``exp_name_i``, computing them on a cache miss.

    Results of ``baseline_summary(snapshots, graph)`` are cached as a pickle at
    ``HOME + exp_name_i + '/baseline.pickle'``. A missing or unreadable cache
    is recomputed and rewritten.
    """
    picklepath = HOME + exp_name_i + '/baseline.pickle'
    try:
        with open(picklepath, "rb") as f:
            return pickle.load(f)
    except Exception:
        pass  # best effort: fall through and recompute
    bl = baseline_summary(snapshots, graph)
    with open(picklepath, "wb") as f:
        pickle.dump(bl, f)
    return bl


def _exp3_short_row(method, node_num, elapsed, graphdist, graphname, dynamicsname):
    """Build one row of the compact Exp 3 summary table."""
    return {'method': method, 'node_num': node_num, 'time': elapsed, 'graphdist': graphdist,
            'final_graph': 'None', 'graphname': graphname, 'dynamicsname': dynamicsname}


def exp_GBCvsBaseline(exp_name, random_temperature=False, individual_output=True, individual_input=False):
    """Exp 3: compare GINA with the baselines on every (dynamics, graph) pair.

    For each entry of ``ground_truth_dynamicsset`` and ``ground_truth_graphset``:
      * simulate 50000 snapshots and train GINA via ``experiment``;
      * load or compute the baselines via ``load_baseline``;
      * refresh ``EXP3_summary.csv``, ``EXP3_summaryshort.csv`` and the Exp 3 plot.

    Results are pickled to ``HOME + 'EXP3data.pickle'``. That pickle is
    returned directly if it already exists.

    Returns:
        DataFrame with the concatenated per-experiment training records and
        baseline columns, or None if no (dynamics, graph) pair is configured.
    """
    global EPOCHNUM, GRAPH_DROPOUT, USE_VALIDATIONSET, HOME, MINIBATCH_SIZE, INCREASE_TEMPERATURE, RANDOM_TEMPERATURE, OUTPUT_EACH_X_EPOCHS, INDIVIDUAL_INPUT, INDIVIDUAL_OUTPUT
    os.makedirs('EXP3', exist_ok=True)
    HOME = './EXP3/exp3_{}/'.format(exp_name)
    INCREASE_TEMPERATURE = 1.0
    RANDOM_TEMPERATURE = random_temperature
    sample_num = 50000  # should be 50k
    OUTPUT_EACH_X_EPOCHS = 50
    MINIBATCH_SIZE = 100
    USE_VALIDATIONSET = False
    INDIVIDUAL_OUTPUT = individual_output  # previously a dead local: it was missing from `global`
    INDIVIDUAL_INPUT = individual_input    # default False, as before
    GRAPH_DROPOUT = True
    EPOCHNUM = 100*100

    print("run Exp3: ", exp_name)

    picklepath = HOME + 'EXP3data.pickle'
    try:
        with open(picklepath, "rb") as f:
            return_values = pickle.load(f)
        print('read EXP3 results: ', picklepath)
        return return_values
    except Exception:
        pass  # best effort: no readable cached results, run the experiment

    # DataFrame.append was removed in pandas 2.0: collect pieces and build frames instead.
    records = []
    short_rows = []
    exp_summary = None
    exp_summary_short = None
    exp_id = 0
    for dynamicsname, (dyn_func, state_num) in ground_truth_dynamicsset.items():
        for graphname, graph in ground_truth_graphset.items():
            for run_id in range(1):
                exp_id += 1
                exp_name_i = exp_name+'_'+str(dynamicsname)+'_'+str(graphname)+'_'+str(run_id).zfill(3)

                start_time = time.time()  # time of the complete reconstruction
                _, snapshots, record = experiment(exp_name_i, graph, dyn_func, snapshots_num=sample_num, state_num=state_num)
                time_elapsed = time.time() - start_time
                record['dynamicsname'] = dynamicsname
                record['graphname'] = graphname
                record['run_id'] = run_id
                record['exp_id'] = exp_id
                record['time_elapsed'] = time_elapsed
                final_graph_dist = list(record.graphdist)[-1]
                record['final_graphdiff'] = final_graph_dist

                bl = load_baseline(snapshots, graph, exp_name_i)
                for name, (value, bl_time) in bl.items():
                    record['Baseline_'+name] = value
                    record['Baseline_'+name+'_time'] = bl_time
                records.append(record)

                node_num = graph.number_of_nodes()
                # GINA's own reconstruction time; previously the stale baseline loop variable was used here.
                short_rows.append(_exp3_short_row('out_method', node_num, time_elapsed, final_graph_dist, graphname, dynamicsname))
                for name, (value, bl_time) in bl.items():
                    short_rows.append(_exp3_short_row('Baseline_'+name, node_num, bl_time, value, graphname, dynamicsname))

                exp_summary = pd.concat(records, ignore_index=True)
                exp_summary_short = pd.DataFrame(short_rows)
                exp_summary.to_csv(HOME+'EXP3_summary.csv')
                exp_summary_short.to_csv(HOME+'EXP3_summaryshort.csv')
                plot_exp3(exp_summary_short, HOME+'EXP3_{}.jpg')

    if exp_summary is None:
        return None  # no (dynamics, graph) pair configured

    exp_summary.to_csv(HOME+'EXP3_summary.csv')
    exp_summary_short.to_csv(HOME+'EXP3_summaryshort.csv')
    # Previously never written, so the cache check above could never hit.
    with open(picklepath, "wb") as f:
        pickle.dump(exp_summary, f)
    return exp_summary

    

# Launch Exp 3; pandas may print some warnings here, which can be ignored.
exp_GBCvsBaseline(exp_name='run1')
