In [None]:
import networkx as nx
import random
import numpy as np
import matplotlib.pyplot as plt
import math

In [None]:
def _node_weights_and_attacks(g, weight):
  """Extract initial node weights and the attack matrix of an argumentation graph.

  Returns (start_weights, attacks). start_weights[i] is the `weight` attribute of the
  i-th node (g's node order, restricted to nodes carrying that attribute) as a float
  array. attacks is the transpose of networkx's adjacency matrix over those nodes, so
  attacks[i, j] is non-zero when node j attacks node i (row i lists i's attackers).
  """
  node_weights = nx.get_node_attributes(g, name=weight)
  nodelist = list(node_weights)
  # dtype=float: with integer weights, np.zeros_like(weights) in card would be an int
  # buffer and np.divide(..., out=...) would fail to cast its float result into it.
  start_weights = np.array(list(node_weights.values()), dtype=float)
  if not nodelist:
    return start_weights, np.zeros((0, 0))
  attacks = nx.convert_matrix.to_numpy_array(g, nodelist).T
  return start_weights, attacks


def _relative_change(new_weights, weights):
  """Per-node change |new - old| / new used by the stopping criterion.

  Where a new weight is 0 the absolute change is used instead: dividing by 0 would give
  NaN, `NaN < epsilon` is never true, and the iteration would never terminate.
  """
  change = np.abs(new_weights - weights)
  return np.divide(change, new_weights, out=change.copy(), where=new_weights != 0)


def _iterate_to_fixpoint(start_weights, update, epsilon):
  """Apply `update` from `start_weights` until the norm of the relative change between
  two successive iterates is below `epsilon`; return the last iterate."""
  weights = start_weights
  while True:
    new_weights = update(weights)
    if np.linalg.norm(_relative_change(new_weights, weights)) < epsilon:
      return new_weights
    weights = new_weights


def card(g,**kwargs):
  """Card-based (Cb) degrees of the arguments of a weighted argumentation graph.

  Each node must carry a weight attribute (name given by the `weight` keyword, default
  "weight"). The update is
      Cb(a) = w(a) / (1 + n(a) + sum(Cb(b) for attackers b of a) / n(a))
  where n(a) is the row sum of the attack matrix (the number of attackers when edges carry
  no "weight" attribute); the last term is taken as 0 when n(a) == 0. Iteration stops once
  the norm of the relative change between iterates is below `epsilon` (default 1e-7).

  Returns a float numpy array of degrees in g's node order (only nodes carrying the weight
  attribute are included).
  """
  weight=kwargs.get("weight","weight")
  epsilon=kwargs.get("epsilon",0.0000001)
  start_weights, attacks = _node_weights_and_attacks(g, weight)
  attack_count = attacks.sum(axis=1)
  has_attackers = attack_count != 0

  def update(weights):
    # mean degree of the attackers, 0 for unattacked nodes (avoids 0/0)
    mean_attack = np.divide(attacks @ weights, attack_count,
                            out=np.zeros_like(weights), where=has_attackers)
    return start_weights / (1 + attack_count + mean_attack)

  return _iterate_to_fixpoint(start_weights, update, epsilon)


def h_cat(g,**kwargs):
  """h-categorizer degrees of the arguments of a weighted argumentation graph.

  Each node must carry a weight attribute (name given by the `weight` keyword, default
  "weight"). The update is
      Hc(a) = w(a) / (1 + sum(Hc(b) for attackers b of a))
  and iteration stops once the norm of the relative change between iterates is below
  `epsilon` (default 1e-7).

  Returns a float numpy array of degrees in g's node order (only nodes carrying the weight
  attribute are included).
  """
  weight=kwargs.get("weight","weight")
  epsilon=kwargs.get("epsilon",0.0000001)
  start_weights, attacks = _node_weights_and_attacks(g, weight)

  def update(weights):
    return start_weights / (attacks @ weights + 1)

  return _iterate_to_fixpoint(start_weights, update, epsilon)

In [None]:
# Memo of subset -> sum of its elements, shared by ssp_dfs calls. Keys are frozensets of
# values; 0 is a sentinel "empty" element so that the root candidate has a key.
cache={frozenset([0]):0}

def init_ssp():
  """Reset the subset-sum cache to just the root candidate {0} with sum 0.

  The solvers in this notebook call this before every search, which keeps the cache
  from growing without bound.
  """
  global cache
  cache={frozenset([0]):0}

def ssp_dfs(t,s,candidate,**kwargs):
  """Depth-first subset-sum search over the values in s.

  Starting from `candidate` (normally frozenset([0]); 0 is a sentinel so the empty
  selection has a cache key), adds values from s one at a time and returns the first set
  whose sum is within `tolerance` (keyword, default 1e-7) of the target t. Branches whose
  sum has already reached t are not extended, which assumes the values in s are positive.
  Keyword `k` (default -1, meaning no limit) caps how many values from s may be taken;
  the 0 sentinel is not counted. Both keywords apply at every depth of the search.

  The result is a frozenset of values (including the sentinel if it was in candidate),
  not positions: equal values in s collapse into one element. Returns None if no subset
  matches. Subset sums are memoised in the module-level `cache` (reset by init_ssp).

  Values are only added in increasing position order, so each subset is visited once
  instead of once per ordering of its elements.
  """
  global cache
  tol=kwargs.get("tolerance",0.0000001)
  k=kwargs.get("k",-1)

  def size(subset):
    # number of values taken from s, not counting the 0 sentinel
    return len(subset) - (1 if 0 in subset else 0)

  def search(current, start):
    for i in range(start, len(s)):
      a = s[i]
      if a in current:
        continue
      extended = current | {a}
      if k != -1 and size(extended) > k:  # too many elements, don't explore
        continue
      if extended not in cache:
        cache[extended] = cache[current] + a
      total = cache[extended]
      if abs(total - t) < tol:  # found a solution within tolerance
        return extended
      if total < t:  # still below the target: try adding later values
        found = search(extended, i + 1)
        if found is not None:
          return found
      # otherwise the sum already exceeds t, so this branch ends here
    return None

  root = frozenset(candidate)
  if root not in cache:
    cache[root] = sum(root)
  return search(root, 0)


In [None]:
def SSP_Solve_hc(g):
  """Try to rebuild an attack graph whose h-categorizer degrees equal the "FAD" values on g.

  Every node of g must carry a "weight" (initial weight) and a "FAD" (degree) attribute.
  h_cat iterates sigma = w / (1 + S), where S is the sum of the attackers' degrees, so each
  attacked node needs attackers whose degrees add up to S = (w - sigma) / sigma. ssp_dfs
  searches the degrees of all nodes (a node may be picked as its own attacker) for such a set.

  Returns a new nx.DiGraph containing g's nodes with their data plus the rebuilt attack
  edges, or False as soon as some node's target sum cannot be matched. Attackers are
  identified by their degree value, so nodes with equal degrees are indistinguishable and
  the first such node in g's order is used.
  """
  my_graph=nx.DiGraph()  # same nodes (and weight/FAD data) as g, edges added below
  my_graph.add_nodes_from(g.nodes(data=True))

  nodes=list(g.nodes)
  s=[g.nodes[n]["FAD"] for n in nodes]  # the values the subset-sum search chooses from
  # Map a degree value back to its node label. Using the value's position in s as the
  # label would only be right for nodes labelled 0..n-1 in iteration order.
  owner={}
  for node,fad in zip(nodes,s):
    owner.setdefault(fad,node)
  tol=0.0000001  # same as ssp_dfs's default tolerance

  for n in nodes:
    t=(g.nodes[n]["weight"]-g.nodes[n]["FAD"])/g.nodes[n]["FAD"]
    # ssp_dfs never tests the empty set against t, so a target of 0 (within the search
    # tolerance) is handled here: the node is unattacked and needs no edges.
    if abs(t)<tol:
      continue

    init_ssp()
    ans=ssp_dfs(t,s,frozenset([0]))
    if ans is None:  # no set of degrees sums to the target
      return False

    for a in ans:
      if a!=0:  # 0 is ssp_dfs's sentinel, not an attacker
        my_graph.add_edge(owner[a],n)
  return my_graph

In [None]:
def SSP_Solve_cb(g):
  """Try to rebuild an attack graph whose card-based degrees equal the "FAD" values on g.

  Every node of g must carry a "weight" (initial weight) and a "FAD" (degree) attribute.
  card iterates sigma = w / (1 + k + S/k) for a node with k attackers whose degrees sum to
  S, so (w - sigma)/sigma = k + S/k. Assuming every degree lies in [0, 1) (true when all
  weights do, as in the test cells), S/k lies in [0, 1), so k is the integer part of
  (w - sigma)/sigma and S = k * ((w - sigma)/sigma - k). ssp_dfs then looks for k degrees
  (a node may be picked as its own attacker) that sum to S.

  Returns a new nx.DiGraph containing g's nodes with their data plus the rebuilt attack
  edges, or False as soon as some attacked node cannot be matched by exactly k attackers.
  Attackers are identified by their degree value, so nodes with equal degrees are
  indistinguishable and the first such node in g's order is used.
  """
  my_graph=nx.DiGraph()  # same nodes (and weight/FAD data) as g, edges added below
  my_graph.add_nodes_from(g.nodes(data=True))

  nodes=list(g.nodes)
  s=[g.nodes[n]["FAD"] for n in nodes]  # the values the subset-sum search chooses from
  sums=sum(s)  # no subset can sum to more than this
  # Map a degree value back to its node label. Using the value's position in s as the
  # label would only be right for nodes labelled 0..n-1 in iteration order.
  owner={}
  for node,fad in zip(nodes,s):
    owner.setdefault(fad,node)

  for n in nodes:
    sigma=g.nodes[n]["FAD"]
    w=g.nodes[n]["weight"]

    if sigma==w:  # an unattacked argument (also avoids dividing by sigma == 0 when w == 0)
      continue

    # k + S/k = upper with S/k in [0, 1), so k lies between
    # lower = (w-2*sigma)/sigma = upper - 1 and upper = (w-sigma)/sigma
    upper=(w-sigma)/sigma
    lower=(w-2*sigma)/sigma
    k=math.floor(upper)

    if k<=0:
      # no attacker fits: sigma equals w up to rounding
      continue

    if math.floor(upper)!=math.ceil(lower):
      # only when upper is an exact integer, i.e. attacker degrees summing to 0
      print("error",upper,lower)

    t=-k*(sigma*k + sigma - w)/sigma  # = k*(upper - k), the attackers' degree sum S

    if t<0 or sums<t:  # no set of degrees can reach this target
      return False

    init_ssp()
    ans=ssp_dfs(t,s,frozenset([0]),k=k)
    if ans is None:  # no set of at most k degrees sums to the target
      return False

    attackers=[a for a in ans if a!=0]  # drop ssp_dfs's 0 sentinel
    if len(attackers)!=k:
      # card depends on the attacker count as well as the sum, so a set of the wrong
      # size would not reproduce sigma
      return False

    for a in attackers:
      my_graph.add_edge(owner[a],n)
  return my_graph

In [None]:
#TEST HCAT; set parameters for reproducibility
# Round trip: give a random directed graph random weights, compute its h-categorizer
# degrees (stored as "FAD"), rebuild a graph from those degrees alone with SSP_Solve_hc,
# and check that the rebuilt graph yields the same degrees.
seed=1
random.seed(seed)
np.random.seed(seed)

num_args=13
erdos_p=0.3
g=nx.erdos_renyi_graph(num_args,p=erdos_p,seed=seed,directed=True)

for n in g.nodes:
  g.nodes[n]["weight"]=random.random()

# h_cat returns degrees in g's node order; erdos_renyi_graph labels its nodes
# 0..num_args-1 in that order, so a node's label doubles as its array index.
w=h_cat(g)
for n in g.nodes:
  g.nodes[n]["FAD"]=w[n]

mg=SSP_Solve_hc(g)

if mg is False:
  print("SSP_Solve_hc found no attacker set for at least one node")
else:
  w2=h_cat(mg)
  # Edge counts usually agree, but any attacker set with the right degree sum is
  # accepted, so the rebuilt graph is not guaranteed to be g itself.
  print(len(mg.edges()),"\n",len(g.edges()))
  # per-node difference between rebuilt and original degrees; all should be close to 0
  for n in g.nodes:
    print(abs(w2[n]-g.nodes[n]["FAD"]))

In [None]:
#TEST CB; set parameters for reproducibility
# Round trip: give a random directed graph random weights, compute its card-based
# degrees (stored as "FAD"), rebuild a graph from those degrees alone with SSP_Solve_cb,
# and check that the rebuilt graph yields the same degrees.
seed=1
random.seed(seed)
np.random.seed(seed)

num_args=13
erdos_p=0.3
g=nx.erdos_renyi_graph(num_args,p=erdos_p,seed=seed,directed=True)

for n in g.nodes:
  g.nodes[n]["weight"]=random.random()

# card returns degrees in g's node order; erdos_renyi_graph labels its nodes
# 0..num_args-1 in that order, so a node's label doubles as its array index.
w=card(g)
for n in g.nodes:
  g.nodes[n]["FAD"]=w[n]

mg=SSP_Solve_cb(g)

if mg is False:
  print("SSP_Solve_cb found no attacker set for at least one node")
else:
  w2=card(mg)
  # Edge counts usually agree, but any attacker set with the right size and degree sum
  # is accepted, so the rebuilt graph is not guaranteed to be g itself.
  print(len(mg.edges()),"\n",len(g.edges()))
  # per-node difference between rebuilt and original degrees; all should be close to 0
  for n in g.nodes:
    print(abs(w2[n]-g.nodes[n]["FAD"]))