# Graph and entropy computations

In [None]:
import collections
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

import scripts.boolean_helper
from scripts.grid_class import GridParms
from scripts.tree_class import Tree, plotReactionGraph

In [None]:
# Seed handed to plotReactionGraph for the graph plots.
seed = 42

# Model whose rule file is analysed; swap the two assignments to switch models.
# model_name = "pancreatic_cancer"
model_name = "apoptosis"

In [None]:
# Build the reaction system from the rule file that belongs to `model_name`.
reaction_system = scripts.boolean_helper.convertRulesToReactions(
    "scripts/models/boolean_rulefiles/{}.hpp".format(model_name)
)

In [None]:
# Maps every species name to its position in reaction_system.species_names.
species_dict = {name: index for index, name in enumerate(reaction_system.species_names)}


def sort_key(x):
    """Return the index of species ``x`` in ``species_dict``.

    Intended as a ``key=`` function for sorting species names by their
    position in the reaction system. Raises ``KeyError`` for unknown names.
    """
    return species_dict[x]

In [None]:
def counterEntryToPartitionString(counter_entry):
    """Return the species indices of ``counter_entry`` joined by single spaces.

    Every species name in ``counter_entry`` is translated to its index through
    the module-level ``species_dict``; unknown names raise ``KeyError``.
    """
    indices = [species_dict[name] for name in counter_entry]
    return " ".join(map(str, indices))

In [None]:
def printStats(S, S0, S1, S_sorted, S0_sorted, S1_sorted):
    """Print the statistics of a two-level partition and return its string.

    Parameters
    ----------
    S, S0, S1 : dict
        Statistics of the root, left-child and right-child bisections. Each
        maps a bisection to a dict with the keys "entropy", "count", "cuts".
    S_sorted, S0_sorted, S1_sorted
        The bisection (a key of ``S``, ``S0``, ``S1`` respectively) to report.
        ``S0_sorted`` and ``S1_sorted`` must each hold two groups of species
        names, which form the four leaves of the partition string.

    Returns
    -------
    str
        Partition string of the form "((a)(b))((c)(d))".
    """
    p = "(({})({}))(({})({}))".format(
        counterEntryToPartitionString(S0_sorted[0]),
        counterEntryToPartitionString(S0_sorted[1]),
        counterEntryToPartitionString(S1_sorted[0]),
        counterEntryToPartitionString(S1_sorted[1]))

    print(p)

    # Read the statistics by key so the output does not depend on the order in
    # which the entries were inserted into the stat dictionaries.
    stat_keys = ("entropy", "count", "cuts")
    S_stats = [S[S_sorted][key] for key in stat_keys]
    S0_stats = [S0[S0_sorted][key] for key in stat_keys]
    S1_stats = [S1[S1_sorted][key] for key in stat_keys]

    print("""
        entropy_root:\t{:.3f},\tcount_root:\t{},\tcuts_root:\t{}
        entropy_child0:\t{:.3f},\tcount_child0:\t{},\tcuts_child0:\t{}
        entropy_child1:\t{:.3f},\tcount_child1:\t{},\tcuts_child1:\t{}""".format(*S_stats, *S0_stats, *S1_stats))

    # Totals over the root and both children (index 0: entropy, index 2: cuts).
    total_entropy = S_stats[0] + S0_stats[0] + S1_stats[0]
    total_cuts = S_stats[2] + S0_stats[2] + S1_stats[2]

    print("""
        TOTAL
        -----
        entropy:\t{:.3f},\tcuts:\t{}""".format(total_entropy, total_cuts))

    return p

In [None]:
def kernighanLinCounter(G: nx.Graph, n=10000, seed=None):
    """Count the distinct Kernighan-Lin bisections of ``G`` found in ``n`` runs.

    Parameters
    ----------
    G : nx.Graph
        Graph whose nodes are species names known to ``species_dict``.
    n : int
        Number of times ``kernighan_lin_bisection`` is run.
    seed : int or None
        If given, seeds one random generator that is shared by all runs, which
        makes the returned counts reproducible. With ``None`` (the default)
        the call to networkx is unchanged and uses its default randomness.

    Returns
    -------
    collections.Counter
        Maps ``(b0, b1)`` to the number of runs that produced this bisection.
        Each side is a tuple of species names sorted with ``sort_key``; the
        side whose first species has the smaller index comes first, so the
        same bisection always maps to the same key.

    Notes
    -----
    Both sides of every bisection must be non-empty, because the first element
    of each side is read to decide their order.
    """
    import random  # local import: only needed for the optional seeding

    # A single generator for all runs: passing the same integer seed to each
    # run would reproduce the same random draws n times.
    extra_kwargs = {} if seed is None else {"seed": random.Random(seed)}

    counter = collections.Counter()
    for _ in range(n):
        bisection = nx.algorithms.community.kernighan_lin_bisection(G, max_iter=2**32, **extra_kwargs)
        b0 = tuple(sorted(bisection[0], key=sort_key))
        b1 = tuple(sorted(bisection[1], key=sort_key))
        if species_dict[b0[0]] < species_dict[b1[0]]:
            counter[(b0, b1)] += 1
        else:
            counter[(b1, b0)] += 1

    return counter

In [None]:
# Grid parameters: d = reaction_system.d() dimensions, each with n = 2,
# bin size 1 and lower limit 0.
d = reaction_system.d()
liml = np.zeros(d)
binsize = np.ones(d, dtype=int)
n = np.full(d, 2, dtype=int)
grid = GridParms(n, binsize, liml)

## Find best partition

In [None]:
# Initial partition string "(0 1 ... len0-1)(len0 ... d-1)": the species
# indices are split into a first and a second half.
len0 = int(reaction_system.d()) // 2
nrange = tuple(range(reaction_system.d()))
partition = "({})({})".format(
    " ".join(map(str, nrange[:len0])),
    " ".join(map(str, nrange[len0:])),
)

In [None]:
# Grid parameters: n = 2, bin size 1 and lower limit 0 in each of the d dimensions.
d = reaction_system.d()
n = np.full(d, 2, dtype=int)
binsize = np.ones(d, dtype=int)
liml = np.zeros(d)
grid = GridParms(n, binsize, liml)

# Tree for the initial partition string; r_out holds the value 5 for every
# internal node.
tree = Tree(partition, grid)
r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
tree.initialize(reaction_system, r_out)

# Run Kernighan-Lin repeatedly on the tree's graph and rank the bisections by
# how often they were found.
counter = kernighanLinCounter(tree.G)
most_common = counter.most_common()

print("total number of partitions found:", len(most_common))

In [None]:
# For every root bisection in `most_common`: build the two-leaf tree and record
# its root entropy, how often the bisection was found, and its cut size.
# The loop variable is deliberately not named `partition`, so the partition
# string used to build the initial tree is not overwritten.
S = {}
for bisection, count in most_common:
    half0, half1 = bisection
    p = "({})({})".format(counterEntryToPartitionString(half0), counterEntryToPartitionString(half1))
    tree = Tree(p, grid)
    r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
    tree.initialize(reaction_system, r_out)

    S[bisection] = {
        "entropy": tree.calculateEntropy(tree.root),
        "count": count,
        "cuts": nx.cut_size(tree.G, half0, half1),
    }

## Find best subpartitions

### By entropy

In [None]:
# Root bisections ordered by ascending entropy; index 0 is treated as the best,
# index -1 as the worst.
S_entropy_sorted = [key for key, _ in sorted(S.items(), key=lambda item: item[1]["entropy"])]

#### Left subpartition

In [None]:
# Subgraph induced by the first half of the lowest-entropy root bisection.
G0_entropy = tree.G.subgraph(S_entropy_sorted[0][0])

In [None]:
# The second half stays unsplit here; keep its partition string for the trees
# that are built from the bisections of the first half.
p1 = counterEntryToPartitionString(S_entropy_sorted[0][1])

# Sample bisections of the first half, most frequent first.
counter = kernighanLinCounter(G0_entropy)
most_common = counter.most_common()

In [None]:
# For every bisection of the first half: build the tree "((p00)(p01))(p1)" and
# record the entropy of its first child, the sample count and the cut size
# within G0_entropy. The loop variable is not named `partition`, so the
# partition string of the initial tree is left untouched.
S0_entropy = {}
for bisection, count in most_common:
    p00, p01 = (counterEntryToPartitionString(part) for part in bisection)
    p = "(({})({}))({})".format(p00, p01, p1)
    tree = Tree(p, grid)
    r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
    tree.initialize(reaction_system, r_out)

    S0_entropy[bisection] = {
        "entropy": tree.calculateEntropy(tree.root.child[0]),
        "count": count,
        "cuts": nx.cut_size(G0_entropy, bisection[0], bisection[1]),
    }

#### Right subpartition

In [None]:
# Subgraph induced by the second half of the lowest-entropy root bisection.
G1_entropy = tree.G.subgraph(S_entropy_sorted[0][1])

In [None]:
# The first half stays unsplit here; keep its partition string for the trees
# that are built from the bisections of the second half.
p0 = counterEntryToPartitionString(S_entropy_sorted[0][0])

# Sample bisections of the second half, most frequent first.
counter = kernighanLinCounter(G1_entropy)
most_common = counter.most_common()

In [None]:
# For every bisection of the second half: build the tree "(p0)((p10)(p11))" and
# record the entropy of its second child, the sample count and the cut size
# within G1_entropy. The loop variable is not named `partition`, so the
# partition string of the initial tree is left untouched.
S1_entropy = {}
for bisection, count in most_common:
    p10, p11 = (counterEntryToPartitionString(part) for part in bisection)
    p = "({})(({})({}))".format(p0, p10, p11)
    tree = Tree(p, grid)
    r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
    tree.initialize(reaction_system, r_out)

    S1_entropy[bisection] = {
        "entropy": tree.calculateEntropy(tree.root.child[1]),
        "count": count,
        "cuts": nx.cut_size(G1_entropy, bisection[0], bisection[1]),
    }

#### Total

In [None]:
# Order the sub-bisections of both halves by ascending entropy.
S0_entropy_sorted = [key for key, _ in sorted(S0_entropy.items(), key=lambda item: item[1]["entropy"])]
S1_entropy_sorted = [key for key, _ in sorted(S1_entropy.items(), key=lambda item: item[1]["entropy"])]

# Report the combination of the lowest-entropy entries on every level.
print("best partition (entropy):")
p_best_entropy = printStats(
    S, S0_entropy, S1_entropy,
    S_entropy_sorted[0], S0_entropy_sorted[0], S1_entropy_sorted[0],
)

In [None]:
# Build the tree for the lowest-entropy partition and save a plot of its graph.
tree = Tree(p_best_entropy, grid)
r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
tree.initialize(reaction_system, r_out)

fig, ax = plotReactionGraph(tree.G, seed=seed)
ax.set_title("best entropy")
plt.savefig("plots/{}_graph_best_entropy.pdf".format(model_name))

### By counts

In [None]:
# Root bisections ordered by descending sample count (most frequent first).
S_count_sorted = [key for key, _ in sorted(S.items(), key=lambda item: -item[1]["count"])]

#### Left subpartition

In [None]:
# Subgraph induced by the first half of the most frequent root bisection.
G0_count = tree.G.subgraph(S_count_sorted[0][0])

In [None]:
# The second half stays unsplit here; keep its partition string for the trees
# that are built from the bisections of the first half.
p1 = counterEntryToPartitionString(S_count_sorted[0][1])

# Sample bisections of the first half, most frequent first.
counter = kernighanLinCounter(G0_count)
most_common = counter.most_common()

In [None]:
# For every bisection of the first half: build the tree "((p00)(p01))(p1)" and
# record the entropy of its first child, the sample count and the cut size
# within G0_count. The loop variable is not named `partition`, so the
# partition string of the initial tree is left untouched.
S0_count = {}
for bisection, count in most_common:
    p00, p01 = (counterEntryToPartitionString(part) for part in bisection)
    p = "(({})({}))({})".format(p00, p01, p1)
    tree = Tree(p, grid)
    r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
    tree.initialize(reaction_system, r_out)

    S0_count[bisection] = {
        "entropy": tree.calculateEntropy(tree.root.child[0]),
        "count": count,
        "cuts": nx.cut_size(G0_count, bisection[0], bisection[1]),
    }

#### Right subpartition

In [None]:
# Subgraph induced by the second half of the most frequent root bisection.
G1_count = tree.G.subgraph(S_count_sorted[0][1])

In [None]:
# The first half stays unsplit here; keep its partition string for the trees
# that are built from the bisections of the second half.
p0 = counterEntryToPartitionString(S_count_sorted[0][0])

# Sample bisections of the second half, most frequent first.
counter = kernighanLinCounter(G1_count)
most_common = counter.most_common()

In [None]:
# For every bisection of the second half: build the tree "(p0)((p10)(p11))" and
# record the entropy of its second child, the sample count and the cut size
# within G1_count. The loop variable is not named `partition`, so the
# partition string of the initial tree is left untouched.
S1_count = {}
for bisection, count in most_common:
    p10, p11 = (counterEntryToPartitionString(part) for part in bisection)
    p = "({})(({})({}))".format(p0, p10, p11)
    tree = Tree(p, grid)
    r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
    tree.initialize(reaction_system, r_out)

    S1_count[bisection] = {
        "entropy": tree.calculateEntropy(tree.root.child[1]),
        "count": count,
        "cuts": nx.cut_size(G1_count, bisection[0], bisection[1]),
    }

#### Total

In [None]:
# Order the sub-bisections of both halves by descending sample count.
S0_count_sorted = [key for key, _ in sorted(S0_count.items(), key=lambda item: -item[1]["count"])]
S1_count_sorted = [key for key, _ in sorted(S1_count.items(), key=lambda item: -item[1]["count"])]

# Report the combination of the most frequent entries on every level.
print("best partition (count):")
p_best_counts = printStats(
    S, S0_count, S1_count,
    S_count_sorted[0], S0_count_sorted[0], S1_count_sorted[0],
)

In [None]:
# Build the tree for the most frequent partition and save a plot of its graph.
tree = Tree(p_best_counts, grid)
r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
tree.initialize(reaction_system, r_out)

fig, ax = plotReactionGraph(tree.G, seed=seed)
ax.set_title("best counts")
plt.savefig("plots/{}_graph_best_counts.pdf".format(model_name))

### Find worst subpartition (by entropy)

#### Left subpartition

In [None]:
# Subgraph induced by the first half of the highest-entropy root bisection.
G0_entropy_worst = tree.G.subgraph(S_entropy_sorted[-1][0])

In [None]:
# The second half stays unsplit here; keep its partition string for the trees
# that are built from the bisections of the first half.
p1 = counterEntryToPartitionString(S_entropy_sorted[-1][1])

# Sample bisections of the first half, most frequent first.
counter = kernighanLinCounter(G0_entropy_worst)
most_common = counter.most_common()

In [None]:
# For every bisection of the first half: build the tree "((p00)(p01))(p1)" and
# record the entropy of its first child, the sample count and the cut size
# within G0_entropy_worst. The loop variable is not named `partition`, so the
# partition string of the initial tree is left untouched.
S0_entropy_worst = {}
for bisection, count in most_common:
    p00, p01 = (counterEntryToPartitionString(part) for part in bisection)
    p = "(({})({}))({})".format(p00, p01, p1)
    tree = Tree(p, grid)
    r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
    tree.initialize(reaction_system, r_out)

    S0_entropy_worst[bisection] = {
        "entropy": tree.calculateEntropy(tree.root.child[0]),
        "count": count,
        "cuts": nx.cut_size(G0_entropy_worst, bisection[0], bisection[1]),
    }

#### Right subpartition

In [None]:
# Subgraph induced by the second half of the highest-entropy root bisection.
G1_entropy_worst = tree.G.subgraph(S_entropy_sorted[-1][1])

In [None]:
# The first half stays unsplit here; keep its partition string for the trees
# that are built from the bisections of the second half.
p0 = counterEntryToPartitionString(S_entropy_sorted[-1][0])

# Sample bisections of the second half, most frequent first.
counter = kernighanLinCounter(G1_entropy_worst)
most_common = counter.most_common()

In [None]:
# For every bisection of the second half: build the tree "(p0)((p10)(p11))" and
# record the entropy of its second child, the sample count and the cut size
# within G1_entropy_worst. The loop variable is not named `partition`, so the
# partition string of the initial tree is left untouched.
S1_entropy_worst = {}
for bisection, count in most_common:
    p10, p11 = (counterEntryToPartitionString(part) for part in bisection)
    p = "({})(({})({}))".format(p0, p10, p11)
    tree = Tree(p, grid)
    r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
    tree.initialize(reaction_system, r_out)

    S1_entropy_worst[bisection] = {
        "entropy": tree.calculateEntropy(tree.root.child[1]),
        "count": count,
        "cuts": nx.cut_size(G1_entropy_worst, bisection[0], bisection[1]),
    }

#### Total

In [None]:
# Order the sub-bisections of both halves by ascending entropy.
S0_entropy_worst_sorted = [
    key for key, _ in sorted(S0_entropy_worst.items(), key=lambda item: item[1]["entropy"])
]
S1_entropy_worst_sorted = [
    key for key, _ in sorted(S1_entropy_worst.items(), key=lambda item: item[1]["entropy"])
]

# Report the combination of the highest-entropy entries on every level.
print("worst partition (entropy):")
p_worst_entropy = printStats(
    S, S0_entropy_worst, S1_entropy_worst,
    S_entropy_sorted[-1], S0_entropy_worst_sorted[-1], S1_entropy_worst_sorted[-1],
)

In [None]:
# Build the tree for the highest-entropy partition and save a plot of its graph.
tree = Tree(p_worst_entropy, grid)
r_out = np.full(tree.n_internal_nodes, 5, dtype="int")
tree.initialize(reaction_system, r_out)

fig, ax = plotReactionGraph(tree.G, seed=seed)
ax.set_title("worst entropy")
plt.savefig("plots/{}_graph_worst_entropy.pdf".format(model_name))