In [None]:
import graph_tool.all as gt
import random
import numpy as np

In [None]:
# Build a synthetic three-layer network used as input for trisbm.
# Every kind-0 vertex is linked to every kind-1 and every kind-2 vertex, and
# each edge carries a random integer weight stored in the "count" edge map.

# Seed the stdlib RNG so the random weights are identical on every
# Restart & Run All (the original cell was unseeded).
SEED = 42
random.seed(SEED)

g = gt.Graph(directed=False)
name = g.vp["name"] = g.new_vp("string")
kind = g.vp["kind"] = g.new_vp("int")
weight = g.ep["count"] = g.new_ep("float")
edge_kind = g.ep["edge_kind"] = g.new_ep("int")

# Layer sizes. NOTE(review): presumably documents / words / keywords,
# following the D / W naming -- confirm against trisbm.
D = 100  # number of kind-0 vertices
W = 20   # number of kind-1 vertices
W1 = 10  # number of kind-2 vertices

# Vertices are appended layer by layer, so kind-0 vertices occupy indices
# [0, D), kind-1 [D, D+W) and kind-2 [D+W, D+W+W1). Each vertex is named after
# its own index; the value is converted to str explicitly because "name" is a
# string property map.
for layer_kind, layer_size in enumerate((D, W, W1)):
    for _ in range(layer_size):
        n = g.add_vertex()
        name[n] = str(int(n))
        kind[n] = layer_kind

# Connect each kind-0 vertex to all kind-1 vertices (edge_kind 1) and to all
# kind-2 vertices (edge_kind 2), with weights drawn uniformly from 1..5.
# add_missing=False: both endpoints must already exist.
for i in range(D):
    for j in range(D, D + W):
        e = g.add_edge(i, j, add_missing=False)
        weight[e] = random.randint(1, 5)
        edge_kind[e] = 1
    for j in range(D + W, D + W + W1):
        e = g.add_edge(i, j, add_missing=False)
        weight[e] = random.randint(1, 5)
        edge_kind[e] = 2

In [None]:
# Overlay 1000 extra edges (edge_kind 0) between a uniformly chosen vertex in
# [0, D) and a uniformly chosen vertex in [D, D+W), weighted 1..100.
# No check is made for existing edges, so these are added as parallel edges.
for _ in range(1000):
    src = random.randrange(D)
    dst = random.randrange(D, D + W)
    e = g.add_edge(src, dst, add_missing=False)
    weight[e] = random.randint(1, 100)
    edge_kind[e] = 0

In [None]:
%load_ext autoreload
%autoreload 2
from trisbm import trisbm

In [None]:
# Sanity check: (vertex count, edge count) of the generated graph.
# Uses the graph's own counters instead of materialising every descriptor.
g.num_vertices(), g.num_edges()

In [None]:
# Keyword arguments prepared for a block-model state: the vertex "kind" map is
# supplied as both `clabel` and `pclabel`, and the "count" edge map as `eweight`.
# NOTE(review): `state_args` is not referenced by any later cell visible here --
# confirm whether it should be handed to the model or dropped.
clabel = g.vp["kind"]
state_args = dict(clabel=clabel, pclabel=clabel, eweight=g.ep["count"])

In [None]:
# Write the graph to disk (the model below loads it back from this file),
# then render it inline with graph-tool's default drawing.
g.save("graph.xml.gz")
gt.graph_draw(g)

In [None]:
# Create a trisbm model and load the graph previously saved to "graph.xml.gz".
model = trisbm()
model.load_graph("graph.xml.gz")

In [None]:
# Fit the model with trisbm's default settings; later cells read the result
# from `model.state`.
model.fit()

In [None]:
# Draw the fitted model using trisbm's own drawing routine.
model.draw()

In [None]:
# Copy the fitted state with four extra single-entry levels appended to its
# `bs` list, then run 100 multiflip MCMC sweeps at beta = inf.
# NOTE(review): this mirrors the refinement recipe in graph-tool's inference
# docs (extra levels give the hierarchy room to grow; infinite beta makes the
# sweeps greedy) -- confirm against the installed graph-tool version.
padded_bs = model.state.get_bs() + [np.zeros(1)] * 4
state = model.state.copy(bs=padded_bs, sampling=True)

for _ in range(100):
    state.multiflip_mcmc_sweep(beta=np.inf)

In [None]:
# Show the entropy of the fitted state next to that of the refined copy
# `state`, to see whether the extra sweeps changed it.
model.state.entropy(), state.entropy()

In [None]:
# Colour every vertex by its "kind" and draw the fitted hierarchy to network.png.
colmap = model.g.vertex_properties["color"] = model.g.new_vertex_property("vector<double>")

# (exclusive upper bound on kind, RGB in 0-255); the first bound that the kind
# falls below wins, anything at or above the last bound uses the fallback.
kind_palette = (
    (1, (112, 140, 195)),
    (2, (131, 209, 80)),
    (3, (210, 82, 58)),
)
fallback_rgb = (114, 124, 206)
vertex_kind = model.g.vertex_properties['kind']

for v in model.g.vertices():
    k = vertex_kind[v]
    rgb = next((shade for bound, shade in kind_palette if k < bound), fallback_rgb)
    # Components are rescaled from 0-255 to 0-1 before being stored.
    colmap[v] = np.array(rgb) / 255.

hierarchy_draw_opts = dict(
    layout="bipartite",
    subsample_edges=500,
    edge_pen_width=model.g.ep["count"],
    # edge_color=edge_kind,        (left disabled, as in the original)
    # edge_fill_color=edge_kind,   (left disabled, as in the original)
    # NOTE(review): graph-tool spells the vertex property `fill_color`, so
    # `hvertex_fillcolor` may be a typo for `hvertex_fill_color` -- confirm
    # against the draw_hierarchy documentation before changing it.
    hvertex_fillcolor="black",
    vertex_color=colmap,
    vertex_fill_color=colmap,
    vertex_size=20,
    vertex_shape="square",
    output="network.png",
)
gt.draw_hierarchy(model.state, **hierarchy_draw_opts)

# Consensus

In [None]:
# Run trisbm's consensus search; the result `pv` is passed as per-vertex pie
# fractions to the drawing call in the next cell.
# NOTE(review): the meaning of force_niter / niter is defined by trisbm -- confirm.
pv = model.search_consensus(force_niter=10, niter=5)

In [None]:
# Visualize the consensus result: each vertex is drawn as a pie chart whose
# slices come from `pv` (presumably the per-vertex marginals -- confirm against
# trisbm's search_consensus). Saving to file is left disabled.
model.state.draw(layout="bipartite", 
            subsample_edges=5000,
            vertex_shape="pie", 
            vertex_pie_fractions=pv,
            #output="network_consensus.pdf"
            )