# 🕸️ 🔬 INSA - Interaction Graph Inference via Sensitivity Analysis

### Todos

- Clear and tidy outputs.
- Make it runnable without wandb.
- Ensure consistent indentation.


### WandB Setup

Currently, a WandB account is required to run the code. Insert your WandB API key in `get_wand_api_key()` in the cell below, at the `# ADD TOKEN HERE!` marker.

In [None]:
import os, wandb

# Name of the wandb project used for sweeps; also the prefix of experiment names in `start_agent`.
PROJECT_NAME = "insa"
# Exported as the WANDB_MODE environment variable by `start_with_wandb`.
WANDB_MODE = "online"   
# When False, `get_config` returns the first value of each parameter in `sweep_config`
# and `start_agent` skips all wandb logging. NOTE: `start_with_wandb` forces this to True.
USE_WANDB = True
# Number of repetitions of every (graph, dynamics) experiment in `start_agent`.
NUM_RUNS = 10

# If True, `create_snapshots` offsets the per-agent RNG seed by the snapshot index, so
# agent parameters differ between snapshots; if False, all snapshots share them.
DIFFERENT_PARAMS_FOR_SNAPSHOTS = False

# NOTE(review): presumably tells wandb which notebook file this code comes from — confirm.
os.environ["WANDB_NOTEBOOK_NAME"] = "main.ipynb"

def get_wand_api_key():
    """Return the WandB API key passed to `wandb.login`.

    Paste your key into the string below, or leave it empty and set the
    WANDB_API_KEY environment variable instead. Returns "" if neither is set.
    """
    token = ""  # ADD TOKEN HERE!
    return token or os.environ.get("WANDB_API_KEY", "")


# Only authenticate when wandb is enabled, so the notebook can run offline with USE_WANDB = False.
if USE_WANDB:
    wandb.login(key=get_wand_api_key())

# wandb sweep definition. Every parameter lists a single value, so the sweep is
# effectively one fixed configuration; `get_config` falls back to the first listed
# value when wandb is not in use. The metric "position" is logged by `start_agent`
# (rank of our method among the hierachical_down variants; lower is better).
sweep_config = {
    "name": "insa",
    "method": "random",
    "metric": {
        "name": "position",
        "goal": "minimize",
    },
    "parameters": {
        "num_epoch": {"values": [15]},
        "num_train_sample": {"values": [20000]},
        "num_test_sample": {"values": [10000]},
        "batch_size": {"values": [2048]},
        "dropout": {"values": [0.15]},
        "layer_num": {"values": [5]},
        "hidden_size_scale":  {"values": [2]},
        "use_train_as_test": {"values": [False]},
    },
}

def get_config(name):
    """Return the value of hyper-parameter `name`.

    When USE_WANDB is set and `wandb.config[name]` can be read, that value is
    returned. Otherwise (wandb disabled, no active run, or missing key), the
    first value listed for `name` in `sweep_config["parameters"]` is used.
    """
    if USE_WANDB:
        try:
            return wandb.config[name]
        except Exception:
            # wandb.config lookups can fail outside an active run; fall back to defaults.
            pass
    return sweep_config["parameters"][name]["values"][0]

## Imports

In [None]:
import math, random, time, traceback
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import scipy
import pickle
from tqdm import tqdm
import time
import netrd
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.utils.data import DataLoader, TensorDataset
from torch import optim

from sklearn.cluster import AgglomerativeClustering
from sklearn.mixture import GaussianMixture

## Utils

In [None]:
def exclude_index(vector, index):
    """Split column `index` (axis 1) out of `vector`.

    Returns a pair ``(rest, removed)``: ``rest`` is ``vector`` without column
    ``index``; ``removed`` is a copy of that column, keeping axis 1 with width 1.

    Raises:
        TypeError: if ``vector`` is neither a ``torch.Tensor`` nor a ``np.ndarray``.
    """
    if isinstance(vector, torch.Tensor):
        join = lambda parts: torch.cat(parts, dim=1)
        duplicate = lambda piece: piece.clone()
    elif isinstance(vector, np.ndarray):
        join = lambda parts: np.concatenate(parts, axis=1)
        duplicate = np.copy
    else:
        raise TypeError("Input vector must be a PyTorch Tensor or a numpy ndarray")
    head, tail = vector[:, :index], vector[:, index + 1:]
    return join((head, tail)), duplicate(vector[:, index:index + 1])
        
        
def state_to_color(state):
    """Map one node's state vector to a matplotlib color.

    A single-element state is treated as a scalar and mapped through the
    ``viridis`` colormap. Otherwise the color is picked by the index of the
    largest entry (``np.argmax``), cycling through a fixed palette.
    """
    if len(state) == 1:
        cmap = plt.get_cmap("viridis")
        return cmap(state[0])
    colors = ['blue', 'red', 'black', 'green', 'orange', 'yellow', 'pink']
    # Modulo the palette size so every entry (including 'pink') is reachable.
    return colors[np.argmax(state) % len(colors)]

        
def plot_snapshots(G, snapshots, exp_name, max_plot_num=10):
    """Save spectral- and Kamada-Kawai-layout drawings of the first snapshots.

    For each snapshot index i (up to ``max_plot_num``; ``None`` means no limit),
    two PNGs are written: ``{exp_name}_snapshot_{i:03}_spec.png`` and
    ``{exp_name}_snapshot_{i:03}_kam.png``. Node colors come from
    ``state_to_color``. Any exception, including a snapshot whose length does
    not match the node count, is printed and stops plotting.
    """
    try:
        layouts = (("spec", nx.spectral_layout), ("kam", nx.kamada_kawai_layout))
        for idx, snap in enumerate(snapshots):
            assert len(snap) == G.number_of_nodes()
            if max_plot_num is not None and idx >= max_plot_num:
                break
            colors = [state_to_color(snap[j]) for j in range(len(snap))]
            for suffix, layout in layouts:
                plt.clf()
                nx.draw_networkx(G, node_color=colors, pos=layout(G))
                plt.savefig(f'{exp_name}_snapshot_{idx:03}_{suffix}.png')
    except Exception as err:
        print(f"Could not draw snapshot: {err}")
        


def get_position(graph_distance_dict):
    """Return the 0-based rank of our method among all "hierachical_down" methods.

    Only keys containing "hierachical_down" are ranked, by ascending distance
    (stable sort, so ties keep dict order).

    Raises:
        ValueError: if "our method (hierachical_down)" is not among them.
    """
    target = "our method (hierachical_down)"
    candidates = [item for item in graph_distance_dict.items() if "hierachical_down" in item[0]]
    candidates.sort(key=lambda item: item[1])
    for rank, (method, _) in enumerate(candidates):
        if method == target:
            return rank
    raise ValueError("method not found")

    
def get_node_num_and_dim_from_loader(traintest_loader):
    """Return ``(node_num, node_dim)`` read from the shape of the first batch.

    Batches are expected as ``(batch, node_num)`` or
    ``(batch, node_num, node_dim)``; for rank-2 batches ``node_dim`` is 1.
    Note that this creates an iterator over ``traintest_loader`` and draws
    one batch from it.
    """
    # Single definition (previously defined twice; the later copy silently won).
    first_data_batch = next(iter(traintest_loader))
    shape = first_data_batch.shape
    node_num = shape[1]
    # A rank-2 batch has no explicit per-node feature axis.
    node_dim = shape[2] if len(shape) > 2 else 1
    return node_num, node_dim


def clean_shuffle_graph(G):
    """Randomly relabel the nodes of `G` and return a copy labeled 0..N-1.

    The permutation is drawn with the global `random` module.

    Raises:
        ValueError: if the graph is not connected.
    """
    original_nodes = sorted(G.nodes())
    shuffled_nodes = sorted(G.nodes(), key=lambda _: random.random())
    G = nx.relabel_nodes(G, dict(zip(original_nodes, shuffled_nodes)))
    G = nx.convert_node_labels_to_integers(G)
    if not nx.is_connected(G):
        # Explicit raise: an `assert` would be stripped under `python -O`.
        raise ValueError('Graph is not connected, try a different one.')
    return G


def set_seed(exp_name):
    """Seed numpy, torch and `random` deterministically from the config and `exp_name`.

    The seed hashes (via `str_to_int`) the string form of all sweep parameter
    names, in sorted order, with their current `get_config` values, followed by
    `exp_name`.
    """
    param_names = sorted(sweep_config["parameters"])
    config_repr = str([(key, repr(get_config(key))) for key in param_names])
    seed = str_to_int(config_repr + exp_name)
    for seeder in (np.random.seed, torch.random.manual_seed, random.seed):
        seeder(seed)

# Adapted from a Stack Overflow answer.
def str_to_float(s, encoding="utf-8"):
    """Hash string `s` to a float in [0, 1) using the CRC32 of its encoded bytes."""
    import zlib
    checksum = zlib.crc32(s.encode(encoding)) & 0xFFFFFFFF
    return float(checksum) / 2 ** 32


def str_to_int(s):
    """Hash string `s` to an integer in [0, 10_000_000) via `str_to_float`."""
    scaled = str_to_float(s) * 10000000
    return int(scaled)

def visualize_matrix(matrix, filepath):
    """Print `filepath`, draw `matrix` as a grayscale image with a colorbar, save it there and show it."""
    print(filepath)
    figure_api = plt
    figure_api.clf()
    figure_api.imshow(matrix, cmap='gray')
    figure_api.colorbar()  # shows the value scale of the gray levels
    figure_api.savefig(filepath)
    figure_api.show()
    

    
def add_exp_to_results(exp_name, results, graph_distance_dict, graph_position_dict, time_elapsed_dict, filename="results.csv"):
    """Append one row per method of experiment `exp_name` and rewrite `filename`.

    Args:
        exp_name: experiment identifier stored in every new row.
        results: dict of column lists ("exp_name", "method", "graph_distance",
            "position", "runtime"); mutated in place.
        graph_distance_dict: method -> graph distance; methods are added in sorted order.
        graph_position_dict: method -> rank (must contain every method above).
        time_elapsed_dict: method -> runtime; missing methods get -1.0.
        filename: CSV path that receives the full `results` table.

    Returns:
        The same `results` dict.
    """
    for method in sorted(graph_distance_dict):
        results["exp_name"].append(exp_name)
        results["method"].append(method)
        results["graph_distance"].append(graph_distance_dict[method])
        results["position"].append(graph_position_dict[method])
        # -1.0 is the sentinel for "no runtime recorded".
        results["runtime"].append(time_elapsed_dict.get(method, -1.0))

    pd.DataFrame(results).to_csv(filename)
    return results


def compute_graph_distance(adj_matrix_gt, adj_matrix_pred):
    """Return the sum of absolute off-diagonal differences between two adjacency matrices.

    Either argument may be a numpy array or a networkx graph (converted with
    `nx.to_numpy_array`). The inputs are not modified.
    """
    if not isinstance(adj_matrix_gt, np.ndarray):
        adj_matrix_gt = nx.to_numpy_array(adj_matrix_gt)
    if not isinstance(adj_matrix_pred, np.ndarray):
        adj_matrix_pred = nx.to_numpy_array(adj_matrix_pred)

    z_ij = np.abs(adj_matrix_gt - adj_matrix_pred)
    # Self-loops are ignored (predicted matrices should have a zero diagonal anyway).
    np.fill_diagonal(z_ij, 0.0)
    return np.sum(z_ij)

## Thresholding

In [None]:
def cluster_elements_gmm(vector):
    """Assign each value of `vector` to one of two Gaussian-mixture components.

    Returns 0/1 labels; which component is 0 is arbitrary (callers orient it).
    """
    samples = np.array(vector).reshape(-1, 1)
    mixture = GaussianMixture(n_components=2)
    mixture.fit(samples)
    return mixture.predict(samples)

def cluster_elements_hierachical(vector):
    """Split the values of `vector` into two agglomerative clusters (sklearn default linkage).

    Returns 0/1 labels; which cluster is 0 is arbitrary (callers orient it).
    """
    samples = np.array(vector).reshape(-1, 1)
    model = AgglomerativeClustering(n_clusters=2)
    model.fit(samples)
    return model.labels_


def cluster2d_elements_normalize(matrix):
    """Binarize `matrix` after 100 rounds of alternating row/column normalization.

    Each round divides every row by its sum and then every column by its sum
    (Sinkhorn-style balancing). The balanced matrix is then binarized row by
    row with the "hierachical" method via `binarize_weight_matrix`.

    NOTE(review): rows/columns summing to zero, or negative weights (e.g. from a
    correlation baseline), give inf/nan or sign flips here — confirm inputs are
    non-negative.
    """
    balanced = matrix
    n_rounds = 100
    for _ in range(n_rounds):
        balanced = balanced / balanced.sum(axis=1)[:, np.newaxis]
        balanced = balanced / balanced.sum(axis=0)[np.newaxis, :]
    return binarize_weight_matrix(balanced, cluster_method="hierachical")


def cluster2d_elements_multiply(matrix):
    """Binarize a square weight matrix using row-times-column agreement per node.

    For each node i, the off-diagonal entries of row i and of column i are each
    divided by their maximum and multiplied element-wise. The products are split
    into two groups with `cluster_elements(..., "hierachical")`, so label 1 is
    always the group with the larger combined score. The labels are written into
    both row i and column i of the result. Later nodes overwrite entries written
    by earlier ones, so entry (a, b) is finally set by node max(a, b).

    NOTE(review): a row/column whose maximum is 0 yields nan, and negative
    weights flip signs — confirm inputs are non-negative with a positive maximum.
    """
    binary_adj = np.zeros_like(matrix)
    for i in range(matrix.shape[0]):
        row = np.delete(matrix[i, :], i)
        column = np.delete(matrix[:, i], i)
        combined = (row / np.max(row)) * (column / np.max(column))
        # Orient the labels (1 = stronger group): raw AgglomerativeClustering labels are arbitrary.
        labels = cluster_elements(combined, cluster_method="hierachical").flatten()
        binary_adj[i, :i] = labels[:i]
        binary_adj[i, i + 1:] = labels[i:]

        binary_adj[:i, i] = labels[:i]
        binary_adj[i + 1:, i] = labels[i:]

    return binary_adj

def cluster_elements_maxmargin(vector, num_thresholds=1000):
    """Split `vector` at the threshold with the largest square-root margin.

    Candidate thresholds are `num_thresholds` evenly spaced points strictly
    inside (min, max). For each one, the score is
    sqrt(threshold - largest value below) + sqrt(smallest value above - threshold).
    Values above the best threshold get label 1.

    Returns:
        int array with one 0/1 label per element. It is empty for empty input and
        all zeros when no threshold separates the values (e.g. constant input).
    """
    values = np.asarray(vector, dtype=float).flatten()
    if values.size == 0:
        return np.zeros(0, dtype=int)

    epsilon = 1e-10  # keeps candidate thresholds off the boundary values
    grid = np.linspace(np.min(values) + epsilon, np.max(values) - epsilon, num_thresholds)

    best_threshold = None
    best_score = float('-inf')
    for threshold in grid:
        below = values[values < threshold]
        above = values[values > threshold]
        if below.size == 0 or above.size == 0:
            # Threshold does not separate anything (constant or near-constant input).
            continue
        score = np.sqrt(threshold - np.max(below)) + np.sqrt(np.min(above) - threshold)
        if score > best_score:
            best_score = score
            best_threshold = threshold

    if best_threshold is None:
        return np.zeros(values.size, dtype=int)
    return (values > best_threshold).astype(int)




def get_cluster_methods():
    """Return the registry of binarization methods, keyed by name.

    Names containing "2d" take a whole weight matrix; the others take one
    vector of scores and return 0/1 labels.
    """
    registry = dict()
    registry["gmm"] = cluster_elements_gmm
    registry["hierachical"] = cluster_elements_hierachical
    registry["maxmargin"] = cluster_elements_maxmargin
    registry["2dmultiply"] = cluster2d_elements_multiply
    registry["2dnormalize"] = cluster2d_elements_normalize
    return registry

def cluster_elements(vector, cluster_method=None):
    """Split the values of `vector` into two groups and return 0/1 labels.

    Args:
        vector: array-like of scores, treated as one value per element.
        cluster_method: a name from `get_cluster_methods()` (case-insensitive),
            a callable mapping `vector` to labels in {0, 1}, or None
            (meaning "hierachical").

    Returns:
        Labels where 1 marks the group with the larger mean value. If the
        clustering produced a single group, every element is labeled 1 (a node
        connected to all others is assumed rather than an isolated node).
    """
    if cluster_method is None:
        cluster_method = "hierachical"
    if isinstance(cluster_method, str):
        cluster_method = get_cluster_methods()[cluster_method.lower()]

    X = np.array(vector).reshape(-1, 1)
    labels = cluster_method(vector)

    if np.min(labels) == np.max(labels):
        # Single group: handled before the mean comparison to avoid averaging an empty slice.
        return labels * 0 + 1
    # Orient so that label 1 is the higher-valued group.
    if np.mean(X[labels == 0]) > np.mean(X[labels == 1]):
        labels = 1 - labels
    return labels

## Graph Generation

In [None]:
def get_graph_generators():
    """Build the ground-truth graphs, each randomly relabeled by `clean_shuffle_graph`.

    Returns:
        dict mapping graph name -> networkx graph with nodes labeled 0..N-1,
        in the order "G_erdos_50", "G_wsn_50", "G_grid_50". This order fixes
        how the global `random` stream is consumed by the relabeling.

    NOTE: the names become part of every experiment name (and therefore of the
    seeds and cache file names), so they are kept as-is even though "G_grid_50"
    is a 5x5 grid with only 25 nodes.
    """
    # Build the graphs directly instead of looking up local variable names with `eval`.
    grid_graph = nx.grid_2d_graph(5, 5)
    erdos_graph = nx.erdos_renyi_graph(50, 0.15, seed=42)
    wsn_graph = nx.newman_watts_strogatz_graph(50, 4, 0.15, seed=42)

    ground_truth_graphset = {
        "G_erdos_50": erdos_graph,
        "G_wsn_50": wsn_graph,
        "G_grid_50": grid_graph,
    }
    return {graph_name: clean_shuffle_graph(g) for graph_name, g in ground_truth_graphset.items()}


## Baseline

In [None]:
def binarize_weight_matrix(W, cluster_method = None):
    """Turn a square weight matrix into a 0/1 adjacency matrix.

    Args:
        W: square numpy weight matrix; row i holds the scores of all nodes for node i.
        cluster_method: a "2d..." method name (applied to the whole matrix), any
            other name or callable accepted by `cluster_elements` (applied row
            by row), or None (the `cluster_elements` default).

    Returns:
        Binary matrix of W's shape. For row-wise methods the diagonal is 0.
    """
    if isinstance(cluster_method, str) and "2d" in cluster_method.lower():
        # Matrix-level methods handle the whole matrix at once.
        return get_cluster_methods()[cluster_method.lower()](W)

    adj_matrix = np.zeros_like(W)
    for i in range(W.shape[0]):
        # Drop the self-score before clustering the row.
        row, _ = exclude_index(W[i, :].reshape(1, -1), i)
        row_binary = cluster_elements(row, cluster_method=cluster_method).flatten()

        adj_matrix[i, :i] = row_binary[:i]
        adj_matrix[i, i + 1:] = row_binary[i:]

    return adj_matrix


def binarize_weight_matrix_multi(impact_scores):
    """Binarize `impact_scores` with every registered method and symmetrize the results.

    Returns a dict with 'weighted' (the input itself) plus, for each method name
    m from `get_cluster_methods()` in sorted order:
      - f"{m}_directed": the binary matrix as produced,
      - f"{m}_up": 1 where A + A.T > 0.5 (for 0/1 matrices: an edge in either direction),
      - f"{m}_down": 1 where A + A.T > 1.5 (for 0/1 matrices: an edge in both directions).
    """
    symmetric = {'weighted': impact_scores}
    binarized = {name: binarize_weight_matrix(impact_scores, cluster_method=name)
                 for name in sorted(get_cluster_methods())}
    for name, adj in binarized.items():
        both_ways = adj + adj.T
        symmetric[name + '_directed'] = adj
        symmetric[name + '_up'] = (both_ways > 0.5).astype(float)
        symmetric[name + '_down'] = (both_ways > 1.5).astype(float)
    return symmetric



def reduce_to_argmax(snapshot_tensor):
    """Collapse a snapshot tensor to one value per node.

    The tensor is squeezed first. If it is still multi-dimensional, each row is
    replaced by the index of its largest entry (`torch.argmax` along dim 1;
    ties are resolved by torch). The result is flattened to 1-D.
    """
    squeezed = snapshot_tensor.squeeze()
    reduced = torch.argmax(squeezed, dim=1) if squeezed.dim() > 1 else squeezed
    return reduced.flatten()


def convert_snapshots_to_TS(train_loader, test_loader):
    """Stack all train and test snapshots as columns of a (node_num, n_snapshots) array.

    Each snapshot is reduced with `reduce_to_argmax`; train snapshots come first,
    in dataset order.
    """
    print("each snapshot has dim:", test_loader.dataset[0].shape)
    node_num, _ = get_node_num_and_dim_from_loader(train_loader)
    train_ds, test_ds = train_loader.dataset, test_loader.dataset
    snapshots = [train_ds[k] for k in range(len(train_ds))] + [test_ds[k] for k in range(len(test_ds))]
    TS = np.zeros([node_num, len(snapshots)])
    for col, snapshot in enumerate(snapshots):
        TS[:, col] = reduce_to_argmax(snapshot)
    return TS


# netrd reconstruction baselines, fitted in `get_specific_baseline`. The instances
# are created once when this cell runs and are re-fitted for every experiment
# (via `get_baseline`).
BASELINE_METHODS = {
          'CorrelationMatrix': netrd.reconstruction.CorrelationMatrix(),
          'MutualInformationMatrix': netrd.reconstruction.MutualInformationMatrix(),
          'PartialCorrelationMatrix': netrd.reconstruction.PartialCorrelationMatrix()}

def get_specific_baseline(TS, recon):
    """Fit netrd reconstructor `recon` on time series `TS`.

    Returns:
        (weights_matrix, seconds): netrd's raw weight matrix and the wall-clock fit time.
    """
    started = time.time()
    # threshold_type/avg_k presumably only affect netrd's thresholded graph;
    # only the raw weights are read here ("avg_k should not matter").
    recon.fit(TS, threshold_type='degree', avg_k=5)
    weights = recon.results['weights_matrix']  # TODO: is this always symmetric?
    return weights, time.time() - started

def get_baseline(train_loader, test_loader):
    """Run the netrd baselines on all snapshots and binarize their weight matrices.

    Returns:
        (binary_adj_matrix_dict, time_elapsed_dict): keys are
        f"{baseline}_{variant}" for every baseline in BASELINE_METHODS and every
        variant from `binarize_weight_matrix_multi`. The runtime is the netrd fit
        time of the baseline; binarization time is not included.
    """
    TS = convert_snapshots_to_TS(train_loader, test_loader)
    print("TS is", TS.shape, TS )
    weighted_adj_matrix_dict = dict()
    fit_seconds = dict()
    for name, recon in BASELINE_METHODS.items():
        # Named `elapsed` (was `time`) so the `time` module is not shadowed here.
        weighted_adj_matrix, elapsed = get_specific_baseline(TS, recon)
        weighted_adj_matrix_dict[name] = weighted_adj_matrix
        fit_seconds[name] = elapsed

    binary_adj_matrix_dict = dict()
    time_elapsed_dict = dict()
    for name, weighted_adj_matrix in weighted_adj_matrix_dict.items():
        for variant, binary_adj_matrix in binarize_weight_matrix_multi(weighted_adj_matrix).items():
            key = name + '_' + variant
            binary_adj_matrix_dict[key] = binary_adj_matrix
            time_elapsed_dict[key] = fit_seconds[name]

    return binary_adj_matrix_dict, time_elapsed_dict

## Data Generation

In [None]:
def gen_cascade(G, agent_rng, steps=None, become_active_prob = 0.5):
  """Simulate one cascade on G and return the final node states.

  States are one-hot lists: S = [1, 0, 0] (susceptible), R = [0, 1, 0]
  (removed) and I = [0, 0, 1] (infected). One node, chosen with the global
  `random` module, starts infected. In each round, infected nodes become R.
  A susceptible node with at least one infected neighbour becomes I with its
  own probability `become_active_prob_of_n[n]` (drawn once per call from
  `agent_rng`) and R otherwise. The loop stops when no node is infected, or
  after 100000 rounds.

  NOTE(review): `steps` (set to 5000 when None) and `become_active_prob` are
  never used below; the round limit is the hard-coded 100000.

  Returns:
    list with one entry per node: the first two components of its final state
    ([1, 0] = never infected, [0, 1] = removed; [0, 0] would mean still
    infected, which can only happen if the round limit is hit). States are
    indexed by node label, so nodes are assumed to be labeled 0..N-1.
  """
  record = list()
  S = [1, 0, 0]
  R = [0, 1, 0]
  I = [0, 0, 1]


  if steps is None:
    steps = 5000
  states = [S] * G.number_of_nodes()
  states[random.choice(range(G.number_of_nodes()))] = I
  # Per-node infection probability (heterogeneous agents).
  become_active_prob_of_n = [agent_rng.random() for i in range(G.number_of_nodes())]


  for _ in range(100000):
    # Synchronous update: decisions are based on the previous round's states.
    new_states = list(states)
    for n in G.nodes():
      if states[n] == I:
        new_states[n] = R
        continue
      if states[n] == S and len([n_j for n_j in G.neighbors(n) if states[n_j] == I]) > 0:
        # A single trial per exposed node; failure makes it R (no second chance).
        if random.random() < become_active_prob_of_n[n]:
            new_states[n] = I
        else:
            new_states[n] = R

    states = list(new_states)
    record.append(list(states))
    if len([n_j for n_j in G.nodes() if states[n_j] == I]) == 0:
        break

  last_state = record[-1]
  # Drop the I component; the loop normally ends with no infected nodes.
  last_state = [s[0:2] for s in last_state]
  return last_state 


def gen_mixedsis(G, agent_rng, inf_rate=1.0, rec_rate=2.0, noise=0.1):
  """Simulate a heterogeneous SIS process on G and return the final states.

  States are one-hot S = [1., 0.] / I = [0., 1.], initialised uniformly at
  random with the global `random` module. Per-node infection and recovery
  rates are drawn with `agent_rng` from exponentials with means `inf_rate` and
  `rec_rate`. The process runs for 1000-1999 steps (the count is random). In
  each step every node gets a total rate: `noise`, plus its recovery rate if
  infected, plus the sum of its infected neighbours' infection rates if
  susceptible. One waiting time per node is sampled from the global
  `np.random`, and only the node with the earliest jump flips state.

  Returns:
    list of per-node state lists. States are indexed by node label, so nodes
    are assumed to be labeled 0..N-1.
  """
  S = [1., 0.]
  I = [0., 1.]
  steps = 1000 + random.choice(range(1000))
  states = [random.choice([S, I]) for i in range(G.number_of_nodes())]
  inf_rate_list = agent_rng.exponential(scale=inf_rate, size=G.number_of_nodes())
  rec_rate_list = agent_rng.exponential(scale=rec_rate, size=G.number_of_nodes())
  for _ in range(steps):
    rates = np.zeros(G.number_of_nodes())
    for n in range(G.number_of_nodes()):
      rates[n] = noise
      if states[n] == I:
        rates[n] += rec_rate_list[n]
      if states[n] == S:
        # Infection pressure uses the *neighbours'* infection rates.
        rates[n] += np.sum([inf_rate_list[n_j] for n_j in G.neighbors(n) if states[n_j] == I])
      # np.random.exponential takes the scale (mean = 1/rate).
      rates[n] = 1.0/rates[n]
    jump_time = np.random.exponential(rates)
    change_n = np.argmin(jump_time)
    states[change_n] = S if states[change_n] == I else I
  return states


# opinion dynamics
def gen_voterpartinv(G, agent_rng, noise=0.01):
  """Simulate a mixed voter / inverse-voter opinion model and return the final states.

  Opinions are one-hot A = [1., 0.] / B = [0., 1.], initialised uniformly at
  random with the global `random` module. A node is a voter if its `agent_rng`
  draw is < 0.5 and an inverse voter otherwise. In each of 1000-1999 steps,
  a voter's flip rate is the number of neighbours that disagree with it
  (+ `noise`); an inverse voter's rate is the number that agree (+ `noise`).
  Waiting times are sampled with the global `np.random`, and only the earliest
  node flips.

  Returns:
    list of per-node opinion lists. States are indexed by node label, so nodes
    are assumed to be labeled 0..N-1.
  """
  A = [1., 0.]
  B = [0., 1.]
  steps = 1000 + random.choice(range(1000))
  states = [random.choice([A, B]) for i in range(G.number_of_nodes())]
  type_of_node = [agent_rng.random() for i in range(G.number_of_nodes())] # <0.5 is voter, >=0.5 is inv voter
  for _ in range(steps):
    rates = np.zeros(G.number_of_nodes())
    for n in range(G.number_of_nodes()):
      if type_of_node[n] >= 0.5:
        rates[n] = len([n_j for n_j in G.neighbors(n) if states[n_j] == states[n]]) + noise
      else:
        rates[n] = len([n_j for n_j in G.neighbors(n) if states[n_j] != states[n]]) + noise
      # np.random.exponential takes the scale (mean = 1/rate).
      rates[n] = 1.0/rates[n] 
    jump_time = np.random.exponential(rates)
    change_n = np.argmin(jump_time)
    states[change_n] = A if states[change_n] == B else B
  return states


def gen_majority(G, agent_rng, change_rate=1.0, noise=0.05):
    """Simulate a noisy majority-rule opinion model and return the final states.

    Opinions are one-hot A = [1., 0.] / B = [0., 1.], initialised uniformly at
    random with the global `random` module. A node's flip rate is `noise`, plus
    `change_rate` if strictly more neighbours hold the other opinion. The rate
    is then multiplied by a per-node factor drawn from `agent_rng.random()`.
    Waiting times are sampled with the global `np.random`, and only the earliest
    node flips. Runs for 1000-1999 steps.

    NOTE(review): a scaling factor of exactly 0 makes the rate 0, so
    `1.0/rates[n]` becomes inf (numpy warns instead of raising).

    Returns:
        list of per-node opinion lists. States are indexed by node label, so
        nodes are assumed to be labeled 0..N-1.
    """
    A = [1., 0.]
    B = [0., 1.]
    steps = 1000 + random.choice(range(1000))
    states = [random.choice([A, B]) for i in range(G.number_of_nodes())]
    slow_down = [agent_rng.random() for i in range(G.number_of_nodes())]
    for _ in range(steps):
        rates = np.zeros(G.number_of_nodes())
        for n in range(G.number_of_nodes()):
            rates[n] = noise
            neig_A = len([n_j for n_j in G.neighbors(n) if states[n_j] == A])
            neig_B = len([n_j for n_j in G.neighbors(n) if states[n_j] == B])
            if states[n] == A and neig_B > neig_A:
                rates[n] += change_rate
            if states[n] == B and neig_A > neig_B:
                rates[n] += change_rate
            rates[n] *= slow_down[n]
            rates[n] = 1.0/rates[n] # np.random.exponential takes the scale (mean = 1/rate)
        jump_time = np.random.exponential(rates)
        change_n = np.argmin(jump_time)
        states[change_n] = A if states[change_n] == B else B
    return states

#### Data Loader

In [None]:
def create_snapshots(graph_nx, dynamics, num_samples, exp_name):
    """Simulate `num_samples` final states of `dynamics` on `graph_nx`.

    Args:
        graph_nx: networkx graph; the gen_* simulators index states by node
            label, so nodes should be labeled 0..N-1.
        dynamics: suffix of a module-level `gen_<dynamics>` simulator
            (letters only), e.g. "cascade".
        num_samples: number of independent snapshots.
        exp_name: hashed with `str_to_int` to seed the per-agent parameter RNG;
            also the filename prefix for `plot_snapshots`.

    Returns:
        list of torch tensors, one per snapshot. Snapshots whose per-node
        states have two components are mapped to one value in {-1, +1} per
        node (0 is reserved for masked values); others are converted unchanged.

    Raises:
        ValueError: if `dynamics` is not alphabetic or no `gen_<dynamics>` exists.
    """
    if not dynamics or not dynamics.isalpha():
        raise ValueError(f"invalid dynamics name: {dynamics!r}")
    # Look the simulator up by name instead of eval-ing a constructed expression.
    func = globals().get("gen_" + dynamics)
    if not callable(func):
        raise ValueError(f"unknown dynamics: {dynamics!r}")

    base_seed = str_to_int(exp_name)  # loop-invariant
    snapshots = list()
    print("create snapshots for", dynamics)
    for i in tqdm(range(num_samples)):
        seed = base_seed + i if DIFFERENT_PARAMS_FOR_SNAPSHOTS else base_seed
        agent_rng = np.random.RandomState(seed)
        snapshots.append(func(graph_nx, agent_rng))

    plot_snapshots(graph_nx, snapshots, exp_name)

    dataset = list()
    for s in snapshots:
        if len(s[0]) == 2:
            # Map {0, 1} -> {-1, +1}; 0 is reserved for masked values.
            s = torch.Tensor(s) * 2.0 - 1.0
            s = s[:, :-1].squeeze()  # the second column is redundant for one-hot pairs
        else:
            s = torch.Tensor(s)
        dataset.append(s)
    return dataset

def get_dataloder(graph_nx, dynamics, exp_name):
    """Build shuffled train and test DataLoaders of freshly simulated snapshots.

    Sample counts and batch size come from `get_config`. Both splits are
    simulated with the same `exp_name`.
    """
    batch_size = get_config("batch_size")
    loaders = []
    for size_key in ("num_train_sample", "num_test_sample"):
        samples = create_snapshots(graph_nx, dynamics, get_config(size_key), exp_name)
        loaders.append(DataLoader(samples, batch_size=batch_size, shuffle=True))
    return loaders[0], loaders[1]


def run_get_dataloder(exp_name, graph_nx, dynamics):
    """Return (train_loader, test_loader), cached in `<exp_name>_snapshots.pickle`.

    On a cache miss the loaders are built with `get_dataloder` and pickled.
    NOTE: the cache is read with `pickle.load`; only use cache files produced
    by this code.
    """
    cache_path = f"{exp_name}_snapshots.pickle"
    if not os.path.exists(cache_path):
        print(cache_path, "not found")
        train_loader, test_loader = get_dataloder(graph_nx, dynamics, exp_name)
        with open(cache_path, "wb") as handle:
            pickle.dump((train_loader, test_loader), handle)
        return train_loader, test_loader

    with open(cache_path, "rb") as handle:
        train_loader, test_loader = pickle.load(handle)
        print("Found ", cache_path)
    return train_loader, test_loader


## NN Model

In [None]:
class MLP(nn.Module):
    """Fully connected network that predicts one node's state from the other nodes.

    The input, of shape (batch, nodes, node_dim), is flattened to
    (batch, nodes * node_dim). It passes through input dropout and the linear
    layers, with ReLU between them, and is mapped to `node_dim` outputs.

    Args:
        node_num: number of input nodes.
        node_dim: per-node feature size (input and output).
        num_layers: number of linear layers; values below 2 still build the
            input and output layers. Defaults to `get_config("layer_num")`,
            read when the model is constructed.
        clamp_output: if True, squash outputs with `sigmoid(x) * 2.1 - 1.0`;
            otherwise apply ReLU.
    """

    def __init__(self, node_num=9, node_dim=1, num_layers=None, clamp_output=True):
        super(MLP, self).__init__()
        if num_layers is None:
            # Resolve at construction time. A default expression in the signature is
            # evaluated once at class-definition time, before any wandb run exists,
            # so sweep values for "layer_num" would be ignored.
            num_layers = get_config("layer_num")
        print(f"create model with node_num {node_num} and node dim {node_dim}")
        self.clamp_output = clamp_output
        self.node_dim = node_dim
        self.input_size = node_num * node_dim
        self.hidden_dim = get_config("hidden_size_scale") * self.input_size

        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(self.input_size, self.hidden_dim))
        for _ in range(num_layers - 2):
            self.layers.append(nn.Linear(self.hidden_dim, self.hidden_dim))
        self.layers.append(nn.Linear(self.hidden_dim, node_dim))

        self.dropout = nn.Dropout(get_config("dropout"))

    def forward(self, x):
        """Map x of shape (nodes,), (batch, nodes) or (batch, nodes, node_dim) to (batch, node_dim)."""
        # Bring x to (batch, nodes, features).
        if len(x.shape) == 1:
            x = x.view(1, -1, 1)
        if len(x.shape) == 2:
            x = x.view(x.shape[0], x.shape[1], 1)
        assert len(x.shape) == 3
        batch_dim, node_num, node_dim = x.shape
        x = x.view(-1, node_num * node_dim)
        assert x.shape[0] == batch_dim

        x = self.dropout(x)
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        x = self.layers[-1](x)

        if self.clamp_output:
            # Output range is (-1.0, 1.1): the extra head-room is only above +1.
            x = (torch.sigmoid(x) * 2.1) - 1.0
        else:
            x = F.relu(x)
        assert batch_dim == x.shape[0] and node_dim == x.shape[1]
        return x



# Smoke test, run when this cell executes: build an MLP for 20 nodes with 3
# features each and push a random batch of 10 snapshots through it; the printed
# output has shape (10, 3).
# NOTE(review): this binds module-level `model`, `dataset` and `output`.
# Instantiate the MLP
model = MLP(node_num=20, node_dim=3)

# Test it with random data
dataset = torch.randn(10,20,3)
output = model(dataset)
print(output)

In [None]:
def train_model(node_index, train_loader):
    """Train an MLP that predicts node `node_index` from all other nodes.

    Trains with MSE loss and Adam for `get_config("num_epoch")` epochs. The
    input is each batch without column `node_index` (via `exclude_index`); the
    target is that column. Prints the mean batch loss of the final epoch.

    Returns:
        The trained MLP, still in training mode.
    """
    node_num, node_dim = get_node_num_and_dim_from_loader(train_loader)
    model = MLP(node_num=node_num - 1, node_dim=node_dim)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters())
    num_epochs = get_config("num_epoch")

    model.train()

    epoch_loss = list()
    for epoch in tqdm(range(num_epochs)):
        epoch_loss = list()
        for data in train_loader:
            data_in, data_gt = exclude_index(data, node_index)

            optimizer.zero_grad()
            output = model(data_in)
            loss = criterion(data_gt.squeeze(), output.squeeze())
            epoch_loss.append(loss.item())

            loss.backward()
            optimizer.step()

    # Report only the last epoch to keep the output short; skip the report if no
    # epoch or batch ran (previously `epoch` would be undefined for num_epoch == 0).
    if epoch_loss:
        print(f"Epoch {num_epochs}, Loss: {np.mean(epoch_loss)}")

    return model


In [None]:
def reorder_rows(data):
    """Return `data` with its columns (axis 1) permuted by a random non-identity permutation.

    Used to perturb one-hot per-node features. The permutation is drawn with the
    global `random` module, and the same permutation is applied to every row.

    Raises:
        ValueError: if `data` is not 2-D or has fewer than 2 columns (no
            non-identity permutation exists).
    """
    if len(data.shape) != 2:
        raise ValueError(f"reorder_rows expects a 2-D tensor, got shape {tuple(data.shape)}")
    num_columns = data.shape[1]
    if num_columns < 2:
        raise ValueError("reorder_rows needs at least 2 columns to permute")

    identity = list(range(num_columns))
    # Start from [0, k-1, ..., 1]. For k == 3 this is the previously hard-coded
    # [0, 2, 1], so seeded runs shuffle exactly as before.
    perm = [0] + list(range(num_columns - 1, 0, -1))
    random.shuffle(perm)
    while perm == identity:
        random.shuffle(perm)
    return torch.index_select(data, 1, torch.LongTensor(perm))


def test_model_saliency(model, node_index, test_loader):
    """Score each input node by the accumulated absolute input gradient of |output|.

    For every test batch, node `node_index` is removed from the input and the
    sum of |model(input)| is back-propagated. |d/d input| is then summed over
    the batch and feature axes.

    Returns:
        numpy array of length node_num - 1 with the accumulated scores, ordered
        like the inputs with `node_index` removed.
    """
    node_num, _ = get_node_num_and_dim_from_loader(test_loader)

    impact_score = torch.zeros(node_num - 1)
    # NOTE(review): the model stays in train mode, so input dropout is active while
    # computing gradients — confirm this is intended (eval() would be deterministic).
    model.train()

    for data in test_loader:
        model.zero_grad()
        data_in, _ = exclude_index(data, node_index)
        data_in = data_in.view(data_in.shape[0], data_in.shape[1], -1)
        data_in.requires_grad = True
        output = torch.sum(torch.abs(model(data_in)))
        output.backward()
        # Sum over batch (0) and feature (2) axes -> one score per remaining node.
        # (The per-batch debug print was removed to keep the output tidy.)
        impact_score = impact_score + torch.sum(torch.abs(data_in.grad), dim=(0, 2))

    return impact_score.detach().cpu().flatten().numpy()



def test_model(model, node_index, test_loader, method):
    """Estimate how much each input node matters to `model`'s prediction of node `node_index`.

    Args:
        model: trained MLP for node `node_index` (its `node_dim` attribute is read).
        node_index: index of the predicted node; it is removed from the inputs.
        test_loader: iterable of snapshot batches.
        method: "saliency" (delegates to `test_model_saliency`), "masking"
            (the input node is set to 0.0, the value reserved for masked
            entries) or "permutation". For node_dim == 1, "permutation" flips
            the sign of the ±1 state; otherwise it permutes the feature
            columns with `reorder_rows`.

    Returns:
        numpy array of length node_num - 1: for each input node, the sum over
        batches of altered MSE / baseline MSE.

    Raises:
        ValueError: for an unknown `method`.
    """
    if method not in ("permutation", "masking", "saliency"):
        raise ValueError(f"unknown method: {method!r}")
    if method == "saliency":
        return test_model_saliency(model, node_index, test_loader)
    node_num, _ = get_node_num_and_dim_from_loader(test_loader)

    impact_score = torch.zeros(node_num - 1)
    criterion = nn.MSELoss()
    model.eval()

    # No gradients are needed. Without no_grad, every loss and the in-place updates
    # of `impact_score` keep an autograd graph alive for the whole evaluation.
    with torch.no_grad():
        for data in test_loader:
            data_in, data_gt = exclude_index(data, node_index)
            loss_baseline = criterion(data_gt.squeeze(), model(data_in).squeeze())
            for altered_i in range(node_num - 1):
                data_in_altered = data_in.clone()
                if method == "masking":
                    data_in_altered[:, altered_i] = 0.0
                elif model.node_dim == 1:
                    data_in_altered[:, altered_i] = -1.0 * data_in_altered[:, altered_i]
                else:
                    # NOTE: one column permutation is applied to the whole batch
                    # (flagged by the original author as "still wrong").
                    data_in_altered[:, altered_i, :] = reorder_rows(data_in_altered[:, altered_i, :].squeeze())
                assert data_in_altered.shape == data_in.shape
                loss_i = criterion(data_gt.squeeze(), model(data_in_altered).squeeze())
                impact_score[altered_i] += loss_i / (loss_baseline + 0.00000000001)

    return impact_score.detach().cpu().flatten().numpy()

## Graph Inference

In [None]:
def infer_graph(train_loader, test_loader, method="permutation"):
    """Infer a weighted interaction graph and all of its binarizations.

    For every node v_i, an MLP is trained to predict v_i from the other nodes
    (`train_model`), and `test_model` scores each other node's influence with
    `method`. The scores are normalized to sum to 1 and stored as row v_i of a
    weighted adjacency matrix (diagonal 0).

    Args:
        train_loader, test_loader: snapshot loaders (test is replaced by train
            when `get_config("use_train_as_test")` is set).
        method: "permutation", "masking" or "saliency".

    Returns:
        (binary_adj_matrix_dict, time_elapsed_dict): the outputs of
        `binarize_weight_matrix_multi`, keyed as "our method (<variant>)", and the
        training + scoring wall-clock time (binarization excluded), the same
        value for every key.

    Raises:
        ValueError: for an unknown `method`.
    """
    if method not in ("permutation", "masking", "saliency"):
        raise ValueError(f"unknown method: {method!r}")
    node_num, _ = get_node_num_and_dim_from_loader(train_loader)
    weighted_adj_matrix = np.zeros([node_num, node_num])
    time_start = time.time()
    if get_config("use_train_as_test"):
        test_loader = train_loader  # TODO: concatenate train and test sets instead?

    for v_i in range(node_num):
        model = train_model(v_i, train_loader)
        impact_score = test_model(model, v_i, test_loader, method=method).flatten()
        total = np.sum(impact_score)
        if total != 0:
            # Normalize so row v_i sums to 1; an all-zero row stays zero instead of NaN.
            impact_score = impact_score / total
        weighted_adj_matrix[v_i, :v_i] = impact_score[:v_i]
        weighted_adj_matrix[v_i, v_i + 1:] = impact_score[v_i:]

    time_elapsed = time.time() - time_start

    binary_adj_matrix_dict = binarize_weight_matrix_multi(weighted_adj_matrix)
    binary_adj_matrix_dict = {f'our method ({k})': v for k, v in binary_adj_matrix_dict.items()}
    time_elapsed_dict = {name: time_elapsed for name in binary_adj_matrix_dict}

    return binary_adj_matrix_dict, time_elapsed_dict

## Experiments

In [None]:
def get_graph_distance_dict(adj_dict, adj_matrix_gt):
    """Return {name: distance} between each predicted adjacency in `adj_dict` and `adj_matrix_gt`."""
    return {
        name: compute_graph_distance(adj_matrix_pred, adj_matrix_gt)
        for name, adj_matrix_pred in adj_dict.items()
    }


In [None]:
def experiment(exp_name, graph, dynamics):
    """Run all inference variants and the baselines for one (graph, dynamics) pair.

    Returns:
        graph_distance_dict: method name -> distance to the ground-truth
            adjacency. Permutation results use the plain "our method (...)"
            names; masking and saliency results get " (masked)" / " (saliency)"
            suffixes; baseline names come from `get_baseline`.
        time_elapsed_dict: method name -> runtime in seconds, with the same keys.
        results: {'adj_matrix_gt': ground truth,
                  'adj_matrix_pred': permutation-variant predictions,
                  'adj_matrix_baseline': baseline predictions}.
    """
    set_seed(f"{exp_name}_snapshot")
    adj_matrix_gt = nx.to_numpy_array(graph)

    # our method
    train_loader, test_loader = run_get_dataloder(exp_name, graph, dynamics)
    set_seed(exp_name + "_infergraph")

    graph_distance_dict = dict()
    time_elapsed_dict = dict()
    binary_adj_dict = dict()
    # Order matters: the three runs share one seeded RNG stream.
    for method, suffix in (("permutation", ""), ("masking", " (masked)"), ("saliency", " (saliency)")):
        adj_dict, elapsed_dict = infer_graph(train_loader, test_loader, method=method)
        distance_dict = get_graph_distance_dict(adj_dict, adj_matrix_gt)
        print("graph_distance_dict: ", distance_dict)
        for name, dist in distance_dict.items():
            graph_distance_dict[name + suffix] = dist
        # Keep runtimes for every variant (masked/saliency rows used to fall back to -1.0).
        for name, elapsed in elapsed_dict.items():
            time_elapsed_dict[name + suffix] = elapsed
        if method == "permutation":
            binary_adj_dict = adj_dict

    # visualize our method (permutation variant)
    for name, adj_matrix in binary_adj_dict.items():
        print(name)
        visualize_matrix(adj_matrix, exp_name + f"_impact_scores_{name}.png")
    visualize_matrix(adj_matrix_gt, exp_name + "_adj_matrix_gt.png")

    # baseline
    set_seed(exp_name + "_baseline")
    binary_adj_dict_baseline, time_elapsed_dict_baseline = get_baseline(train_loader, test_loader)
    graph_distance_dict_baseline = get_graph_distance_dict(binary_adj_dict_baseline, adj_matrix_gt)

    # merge
    graph_distance_dict.update(graph_distance_dict_baseline)
    time_elapsed_dict.update(time_elapsed_dict_baseline)

    # Store the baseline adjacency matrices; their distances are already in graph_distance_dict.
    results = {'adj_matrix_gt': adj_matrix_gt, "adj_matrix_pred": binary_adj_dict, "adj_matrix_baseline": binary_adj_dict_baseline}

    return graph_distance_dict, time_elapsed_dict, results


def run_experiment(exp_name, graph, dynamics):
    """Return the `experiment(...)` triple, cached in `<exp_name>_solution.pickle`.

    NOTE: the cache is read with `pickle.load`; only use cache files produced
    by this code.
    """
    cache_path = exp_name + '_solution.pickle'
    if not os.path.exists(cache_path):
        distances, timings, extra = experiment(exp_name, graph, dynamics)
        with open(cache_path, "wb") as handle:
            pickle.dump((distances, timings, extra), handle)
        return distances, timings, extra

    with open(cache_path, "rb") as handle:
        distances, timings, extra = pickle.load(handle)
        print("Found ", cache_path)
        print("Graph distance: ", distances)
    return distances, timings, extra

def save_src_file():
    """Log every *.ipynb and then every *.py file in the working directory as a wandb artifact.

    Artifact names are "src_ipynb_<SWEEP_ID>" / "src_py_<SWEEP_ID>".
    """
    for pattern, label in (('*.ipynb', 'ipynb'), ('*.py', 'py')):
        for src_path in sorted(glob.glob(pattern)):
            wandb.log_artifact(src_path, name=f"src_{label}_{SWEEP_ID}", type="my_dataset")

    
def extract_positions(graph_distance_dict):
    """Map each method to its 0-based rank by ascending graph distance.

    The sort is stable, so ties keep dict order. The result keeps the key order
    of `graph_distance_dict`. Intermediate values are printed.
    """
    print("extract pos")
    print("graph dist dict ", graph_distance_dict)
    ranking = [name for name, _ in sorted(graph_distance_dict.items(), key=lambda item: item[1])]
    print("gd_list ", ranking)
    rank_of = {name: rank for rank, name in enumerate(ranking)}
    pos_dict = {method: rank_of[method] for method in graph_distance_dict}
    print("pos_dict ", pos_dict)
    return pos_dict

In [None]:
def start_agent():
    """Run every (graph, dynamics) experiment NUM_RUNS times and aggregate the results.

    After each experiment, the per-method rows are appended to `results` and
    written to results_{prefix}.csv. The hierachical_down rows are written,
    sorted, to results_sorted_{prefix}.csv, and the rank of
    "our method (hierachical_down)" is recorded. At the end, the mean rank per
    method is printed and written to results_meanpos_{prefix}.csv. When
    USE_WANDB is set, metrics, CSV artifacts and tables are also logged to wandb.
    """
    results = {"exp_name": list(), "method": list(), "runtime": list(), "graph_distance": list(), "position": list()}
    graphs = get_graph_generators()
    graph_names = sorted(graphs.keys())
    prefix = PROJECT_NAME + str(int(random.random() * 10000))
    print("running", prefix)

    # Explicit list (equal to the sorted names of the gen_* simulators in this file).
    dynamical_models = ['cascade', 'majority', 'mixedsis', 'voterpartinv']
    position_store = list()
    results_csv = f"results_{prefix}.csv"
    sorted_csv = f"results_sorted_{prefix}.csv"

    # actual runs
    counter = 0
    for run_i in range(NUM_RUNS):
        for graphname in graph_names:
            graph = graphs[graphname]
            for dynamics in dynamical_models:
                counter += 1
                exp_name = f"exp_{prefix}_{graphname}_{dynamics}_{run_i + 1:03}"
                print("start experiment:", exp_name)
                graph_distance_dict, time_elapsed_dict, _ = run_experiment(exp_name, graph, dynamics)
                graph_position_dict = extract_positions(graph_distance_dict)
                results = add_exp_to_results(exp_name, results, graph_distance_dict, graph_position_dict, time_elapsed_dict, filename=results_csv)
                p = get_position(graph_distance_dict)
                position_store.append(p)
                print("position", p)
                if USE_WANDB:
                    wandb.log({"position": p, "mean position": np.mean(position_store), "counter": counter})
                    wandb.log_artifact(results_csv, name=results_csv, type="my_dataset")
                    wandb.log({"results": wandb.Table(dataframe=pd.read_csv(results_csv))})

                # Sorted view of the hierachical_down variants, rewritten after every experiment.
                df = pd.DataFrame(results)
                df = df[df['method'].str.contains('hierachical_down')]
                df = df.sort_values(by=['exp_name', 'graph_distance'])
                df.to_csv(sorted_csv)
                if USE_WANDB:
                    # results_csv was already uploaded above; only the sorted view is new here.
                    wandb.log_artifact(sorted_csv, name=sorted_csv, type="my_dataset")
                    wandb.log({"table_sorted": wandb.Table(dataframe=df)})

    # final positions (only once)
    df = pd.DataFrame(results)
    mean_values = df.groupby('method')['position'].mean()
    result_list = list(mean_values.items())

    # Sort by method name, with "our method" entries first (a leading space sorts before letters).
    result_list = sorted(result_list, key=lambda x: x[0] if "our method" not in x[0] else ' ' + x[0])
    print("final absolute positions\n", result_list, "\n and now line by line:")
    for k, v in result_list:
        print(k, v)
    if USE_WANDB:
        for i, (_, value) in enumerate(result_list):
            wandb.log({"absolute positions": value, "counter_ap": i})

    # sort by mean position
    result_list = sorted(result_list, key=lambda x: x[1])
    print("methods sorted")
    for method, mean_pos in result_list:
        print(method, mean_pos)
    df = pd.DataFrame(result_list, columns=['Method', 'Mean Pos'])
    meanpos_csv = f"results_meanpos_{prefix}.csv"
    df.to_csv(meanpos_csv, index=False)
    if USE_WANDB:
        wandb.log_artifact(meanpos_csv, name=meanpos_csv, type="my_dataset")
        wandb.log({"table_mean_pos": wandb.Table(dataframe=df)})




def main():
    """wandb agent entry point: open a run, upload the sources and run all experiments.

    Any exception from `start_agent` is printed, appended to _error_log.txt and
    swallowed after a 1 s pause, so the agent keeps going; None is returned then.
    """
    with wandb.init():
        save_src_file()
        try:
            return start_agent()
        except Exception:
            trace = traceback.format_exc()
            print("final in main error:\n", trace)
            with open('_error_log.txt', 'a') as log_file:
                log_file.write(trace + '\n')
            time.sleep(1)

def start_with_wandb():
    """Create a wandb sweep from `sweep_config` and run one agent trial of `main`.

    Sets the globals USE_WANDB = True and SWEEP_ID, and exports WANDB_MODE.
    Errors are printed, appended to _error_log.txt and swallowed after a 10 s pause.
    """
    global SWEEP_ID, USE_WANDB
    print("start experiments")
    USE_WANDB = True
    os.environ["WANDB_MODE"] = WANDB_MODE
    try:
        SWEEP_ID = wandb.sweep(sweep_config, project=PROJECT_NAME)
        wandb.agent(SWEEP_ID, function=main, count=1)
    except Exception:
        trace = traceback.format_exc()
        print("final error:\n", trace)
        with open('_error_log.txt', 'a') as log_file:
            log_file.write(trace + '\n')
        time.sleep(10)

def start_without_wandb():
    """Run all experiments locally: disable wandb logging and use a dummy SWEEP_ID."""
    global SWEEP_ID, USE_WANDB
    USE_WANDB = False
    SWEEP_ID = "0" * 8
    start_agent()

set_seed('42')
# With wandb: 20 sweep-agent trials, each running start_agent once.
# Without wandb (USE_WANDB = False): a single local run.
if USE_WANDB:
    for _ in range(20):
        start_with_wandb()
else:
    start_without_wandb()