<a href="https://colab.research.google.com/github/dbizzaro/Minesweeper/blob/main/SSRF_Algorithm_for_causal_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install igraph

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
!git clone https://github.com/xunzheng/notears.git

fatal: destination path 'notears' already exists and is not an empty directory.


In [3]:
import numpy as np
import igraph as ig

In [4]:
import scipy.linalg as slin
import scipy.optimize as sopt
from scipy.special import expit as sigmoid
import random

import sys
sys.path.insert(0,'/notears/notears/') 

from notears.notears.linear import notears_linear
from notears.notears.utils import is_dag, simulate_dag, simulate_parameter, simulate_linear_sem, count_accuracy

In [5]:
def minplus_sum(A,B):
  '''
  Matrix sum in the min-plus (tropical) semiring, i.e. the element-wise minimum.

  A, B: arrays of the same shape (NumPy broadcasting rules apply otherwise).
  Returns a new float array Y with Y[i,j] = min(A[i,j], B[i,j]); neither input is modified.
  A NaN in either operand propagates to the result.
  '''
  # np.minimum replaces the element-by-element Python loop. The cast keeps the
  # float result that the loop version produced by filling an np.zeros array.
  return np.minimum(A, B).astype(float)

def minplus_product(A,B):
  '''
  Matrix product in the min-plus (tropical) semiring:
    Y[i,j] = min over k of (A[i,k] + B[k,j])

  A: array of shape (m, k)
  B: array of shape (k, n)
  Returns a new float array of shape (m, n).
  Raises ValueError if the operands are not 2-D or their inner dimensions differ.
  When the inner dimension k is 0 every entry is +inf (the semiring's additive zero).

  see https://stackoverflow.com/questions/47359743/how-to-make-min-plus-matrix-multiplication-in-python-faster for faster implementations
  '''
  A = np.asarray(A)
  B = np.asarray(B)
  if A.ndim != 2 or B.ndim != 2 or A.shape[1] != B.shape[0]:
    raise ValueError("minplus_product: incompatible shapes {} and {}".format(A.shape, B.shape))
  m = A.shape[0]
  n = B.shape[1]
  Y = np.empty((m,n))
  for i in range(m):
    # A[i,:,None] + B has shape (k, n): entry (k, j) is A[i,k] + B[k,j].
    # Reducing over axis 0 yields the whole i-th output row at once, while
    # keeping the temporary at O(k*n) instead of O(m*k*n).
    Y[i,:] = np.min(A[i,:,None] + B, axis=0, initial=np.inf)
  return Y

def minplus_identity(n):
  '''
  Build the n x n identity matrix of the min-plus semiring:
  0 on the diagonal and +inf everywhere else.
  '''
  on_diagonal = np.eye(n, dtype=bool)
  return np.where(on_diagonal, 0.0, np.inf)

In [6]:
def transitive_closure(W):
  '''
  Floyd-Warshall algorithm for computing the "shortest path" weighted transitive closure.

  W: square matrix of edge weights in the min-plus semiring (+inf = no edge).
  Returns a new matrix W_bar where W_bar[i,j] is the smallest total weight of
  any path i -> ... -> j (+inf when j is unreachable from i). W is not modified.
  Raises ValueError if W is not a square 2-D matrix.

  As usual for Floyd-Warshall, the result is only meaningful when the graph
  has no negative-weight cycles.
  '''
  W_bar = np.array(W, copy=True)
  if W_bar.ndim != 2 or W_bar.shape[0] != W_bar.shape[1]:
    raise ValueError("transitive_closure: expected a square matrix, got shape {}".format(W_bar.shape))
  n = W_bar.shape[0]
  for k in range(n):
    # Cost of every path i -> k -> j, as the outer min-plus product of
    # column k and row k. The sum is materialised before W_bar is updated.
    via_k = W_bar[:, k, None] + W_bar[None, k, :]
    np.minimum(W_bar, via_k, out=W_bar)
  return W_bar

In [7]:
def signals_from_spectra(W, C, W_transitively_closed = False):
  '''
  Returns the matrix of signals X, given the matrix of weights W and the matrix of spectra C.

  In min-plus notation X = C (+) (C (x) W_bar), i.e.
    X[s,j] = min( C[s,j], min over k of (C[s,k] + W_bar[k,j]) )
  where W_bar is the transitive closure of W.

  W: square weight matrix, +inf where there is no edge
  C: spectra matrix, one row per sample and one column per node
  W_transitively_closed: set to True only if sure that W is already
    transitively closed; the closure computation is then skipped
  '''
  W_bar = W if W_transitively_closed else transitive_closure(W)
  return minplus_sum(C, minplus_product(C, W_bar))

In [8]:
def from_semiring(W):
  '''
  Convert a min-plus matrix to the ordinary convention:
  every +inf entry ("no edge") becomes 0, all other entries are kept.
  '''
  no_edge = (W == np.inf)
  return np.where(no_edge, 0, W)

def to_semiring(W):
  '''
  Convert an ordinary weight matrix to the min-plus convention:
  every entry equal to 0 ("no edge") becomes +inf, all other entries are kept.
  Note that a genuine zero weight cannot be told apart from a missing edge.
  '''
  no_edge = (W == 0)
  return np.where(no_edge, np.inf, W)

In [9]:
def simulate_weighted_dag(B, range = (0.4,0.8)):
  '''
  Generate a matrix of random edge weights from an adjacency matrix.

  B: 2-D adjacency matrix; each random draw is multiplied by the matching entry of B
  range: pair (low, high); draws are uniform in [low, high)
         (the parameter name shadows the builtin inside this function)

  Returns the weights in min-plus form: entries that come out as 0
  (e.g. where B is 0) are replaced by +inf via to_semiring.
  '''
  low, high = range[0], range[1]
  uniform01 = np.random.rand(B.shape[0], B.shape[1])
  W = (low + (high - low) * uniform01) * B
  return to_semiring(W)

In [17]:
def uniform_random_spectra(m, n, range=2, prob = 0.1):
  '''
  Generate an m x n matrix of random spectra.

  Each element is finite with probability "prob"; a finite element is
  uniformly distributed between 0 and "range". All other elements are +inf.
  (The parameter name "range" shadows the builtin inside this function.)
  '''
  # Draw order matters for reproducibility under a fixed seed:
  # magnitudes first, then the Bernoulli mask.
  magnitudes = range * np.random.rand(m,n)
  is_finite = np.random.binomial(1, prob, (m,n))
  C = magnitudes * is_finite
  return to_semiring(C)


def uniform_random_noise(m, n, range=0.1):
  '''
  Generate an m x n matrix of noise, uniformly distributed in [0, range).
  '''
  unit_noise = np.random.rand(m,n)
  return unit_noise * range

def gaussian_random_noise(m, n, sigma=0.03):
  '''
  Generate an m x n matrix of zero-mean Gaussian noise with standard deviation sigma.
  '''
  shape = (m, n)
  return np.random.normal(loc=0, scale=sigma, size=shape)

In [11]:
def core_algorithm(X, eps = 1e-12):
  '''
  Core of the causal learning algorithm: from a set of signals, compute candidate
  edge weights and how many samples support each of them.

  X: array of shape (m, n), one row per sample and one column per node;
     +inf marks a node that carries no signal in that sample
  eps: tolerance wrt equality

  Returns (maximums, counters), both of shape (n, n):
    maximums[i,j]  the largest delay X[k,j] - X[k,i] over the samples k in which
                   X[k,i] is finite (never below 0). It is +inf as soon as a sample
                   has X[k,i] finite but X[k,j] infinite; scanning of the pair
                   stops there.
    counters[i,j]  the number of samples whose delay is within eps of maximums[i,j]
                   and that are not explained by a third node l preceding both
                   i and j with matching maximal delays. It is only computed when
                   at least two samples attain a finite, strictly positive maximum.

  No thresholding happens here: the minimum number of supporting samples required
  before an edge is drawn is applied afterwards by reconstruct_W.
  '''
  m, n = X.shape
  maximums = np.zeros((n,n), dtype=X.dtype)
  # signals[i][j]: indices of the samples whose delay is (within eps) the current maximum
  signals = [[[] for _ in range(n)] for _ in range(n)]
  for i in range(n):
    for j in range(n):
      if i == j:
        continue
      for k in range(m):
        if np.isposinf(X[k,i]):
          # i carries no signal in this sample: it says nothing about the pair
          continue
        if np.isposinf(X[k,j]):
          # i fired but j never did: give up on this pair
          maximums[i,j] = np.inf
          break
        w = X[k,j] - X[k,i]
        if w > maximums[i,j]:
          maximums[i,j] = w
          # keep only the earlier samples that are still within eps of the new maximum
          signals[i][j] = [k2 for k2 in signals[i][j] if X[k2,j]-X[k2,i] > w-eps]
        if abs(w - maximums[i,j]) < eps:
          signals[i][j].append(k)
  counters = np.zeros((n,n), dtype=np.int_)
  for i in range(n):
    for j in range(n):
      if i == j or len(signals[i][j]) <= 1:
        continue
      if not (maximums[i,j] > 0) or np.isposinf(maximums[i,j]):
        continue
      for k in signals[i][j]:
        # The sample is "explained" when some other node l reaches both i and j
        # with exactly its maximal delays; the delay between i and j may then
        # stem from l rather than from a direct influence.
        explained = any(
          l != i and l != j
          and abs(X[k,j]-X[k,l]-maximums[l,j]) < eps
          and abs(X[k,i]-X[k,l]-maximums[l,i]) < eps
          and X[k,j]-X[k,l] > 0
          and X[k,i]-X[k,l] > 0
          for l in range(n)
        )
        if not explained:
          counters[i,j] += 1
  return maximums, counters


def reconstruct_W(maximums, counters, threshold):
  '''
  Build the reconstructed weight matrix from the output of core_algorithm.

  An entry of "maximums" is kept only where the matching entry of "counters"
  is strictly greater than "threshold"; every other entry becomes +inf.
  The transitive closure of the resulting matrix is returned.
  '''
  enough_support = counters > threshold
  W_hat = np.where(enough_support, maximums, np.inf)
  return transitive_closure(W_hat)


def algorithm(X, threshold=1, eps=1e-12):
  '''
  Full pipeline: run core_algorithm on the signals X with tolerance eps,
  then turn its (maximums, counters) output into a weight matrix with
  reconstruct_W using the given threshold.
  '''
  return reconstruct_W(*core_algorithm(X, eps), threshold)

In [12]:
def print_differences(W_real, W_reconstructed, eps=1e-12):
  '''
  Print a comparison of a reconstructed weight matrix against the real one.

  An entry is an edge when it is not +inf. The report contains:
    - the total number of edges in "W_real"
    - edges of "W_real" that are missing from "W_reconstructed"
    - edges of "W_reconstructed" that are not in "W_real"
    - common edges whose reconstructed weight exceeds the real one by more than eps
    - common edges whose reconstructed weight is below the real one by more than eps
  Nothing is returned.
  '''
  # Only the region covered by W_real is compared.
  rec = W_reconstructed[:W_real.shape[0], :W_real.shape[1]]
  real_absent = np.isposinf(W_real)
  rec_absent = np.isposinf(rec)
  in_both = ~real_absent & ~rec_absent

  # Differences are evaluated on common edges only.
  too_large = (rec[in_both] - W_real[in_both]) > eps
  too_small = ((-rec[in_both] + W_real[in_both]) > eps) & ~too_large

  counter_added = int(np.count_nonzero(real_absent & ~rec_absent))
  counter_missed = int(np.count_nonzero(~real_absent & rec_absent))
  counter_greater = int(np.count_nonzero(too_large))
  counter_lower = int(np.count_nonzero(too_small))

  print("Total number of edges:                    {}".format((W_real != np.inf).sum()))
  print("Edges missed:                             {}".format(counter_missed))
  print("Edges that are not there:                 {}".format(counter_added))
  print("Edges found with weight larger than real: {}".format(counter_greater))
  print("Edges found with weight lower than real:  {}".format(counter_lower))


In [13]:
# Synthetic experiment: m samples (rows of X) over n nodes (columns of X).
# NOTE(review): no random seed is set, so every run produces different data.
m, n = 500, 100
# Random DAG from notears' simulate_dag, weighted with the default range (0.4, 0.8);
# presumably 500 is the expected number of edges and 'ER' selects an
# Erdos-Renyi graph -- TODO confirm against notears.utils.
W = simulate_weighted_dag(simulate_dag(n, 500, 'ER'))
W_bar = transitive_closure(W)
# W_bar is already closed, hence the True flag; uniform noise (default range 0.1) is added on top.
X = signals_from_spectra(W_bar, uniform_random_spectra(m,n), True) + uniform_random_noise(m,n)
W_bar

array([[        inf, 14.37473463,         inf, ..., 12.53764271,
                inf,         inf],
       [        inf,         inf,         inf, ...,         inf,
                inf,         inf],
       [        inf,         inf,         inf, ...,  9.22188749,
                inf,         inf],
       ...,
       [        inf,         inf,         inf, ...,         inf,
                inf,         inf],
       [        inf, 10.75856064,         inf, ..., 10.50662009,
                inf,         inf],
       [        inf,         inf,         inf, ...,         inf,
                inf,         inf]])

In [14]:
# eps=2 is a far looser equality tolerance than the function default of 1e-12.
# NOTE(review): magic number -- presumably tuned by hand for this noise level; confirm.
maximums, counters = core_algorithm(X, eps=2)

In [15]:
# Keep a candidate edge only when more than 7 supporting samples were counted for it.
W_reconstructed = reconstruct_W(maximums, counters, 7)
# is_dag comes from notears.utils; presumably it checks that the matrix
# (with +inf mapped back to 0) is acyclic -- TODO confirm.
print(is_dag(from_semiring(W_reconstructed)))
W_reconstructed

True


array([[        inf, 15.26491593,         inf, ..., 13.09793552,
                inf,         inf],
       [        inf,         inf,         inf, ...,         inf,
                inf,         inf],
       [        inf,         inf,         inf, ..., 10.02888993,
                inf,         inf],
       ...,
       [        inf,         inf,         inf, ...,         inf,
                inf,         inf],
       [        inf, 11.57794915,         inf, ..., 11.13974881,
                inf,         inf],
       [        inf,         inf,         inf, ...,         inf,
                inf,         inf]])

In [16]:
# Compare against the true closure; weights within eps=2 of each other count as equal.
print_differences(W_bar, W_reconstructed, eps=2)

Total number of edges:                    2812
Edges missed:                             96
Edges that are not there:                 0
Edges found with weight larger than real: 87
Edges found with weight lower than real:  0
