In [1]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
from mvmm.multi_view.block_diag.graph.linalg import get_adjmat_bp
from mvmm.simulation.sim_viz import save_fig

# Output location and resolution for every saved figure. File paths are built
# as `fig_dir + filename`, so fig_dir must keep its trailing slash.
fig_dir, dpi = './figures/', 200


In [2]:
def get_bipt_graph(X):
    """Turn the nonzero pattern of a matrix into a drawable bipartite graph.

    Parameters
    ----------
    X : array-like, 2D
        Matrix whose nonzero pattern defines the edges. It is cast to bool and
        passed to ``get_adjmat_bp`` to obtain the graph's adjacency matrix.

    Returns
    -------
    G : networkx.Graph
        Graph built from ``get_adjmat_bp(X)``. Node ``i`` (``i < X.shape[0]``)
        is renamed ``'Row {i+1}'``. Node ``X.shape[0] + j`` is renamed
        ``'Col {j+1}'``.
    pos : dict
        Node name -> ``[x, y]``. Rows sit at x=0 and columns at x=0.1. The
        k-th node of each side sits at y=-k, so the two sides form two
        parallel vertical stacks.
    kws : dict
        Styling options (red row nodes, grey column nodes, bold labels, ...)
        meant to be unpacked as ``nx.draw(G, pos, **kws)``.
    """
    spacing = 1  # vertical distance between consecutive nodes on the same side
    X = X.astype(bool)
    n_rows, n_cols = X.shape[0], X.shape[1]

    G = nx.from_numpy_array(get_adjmat_bp(X))

    row_names = ['Row {}'.format(k + 1) for k in range(n_rows)]
    col_names = ['Col {}'.format(k + 1) for k in range(n_cols)]

    # Layout: rows stacked at x=0, columns stacked just to the right at x=.1.
    pos = {name: [0, -k * spacing] for k, name in enumerate(row_names)}
    pos.update({name: [.1, -k * spacing] for k, name in enumerate(col_names)})

    kws = {
        'node_color': np.concatenate([['red'] * n_rows, ['grey'] * n_cols]),
        'node_size': 3000,
        'with_labels': True,
        'font_weight': 'bold',
        'font_size': 15,
        'alpha': .9,
        'width': 3,
    }

    # NOTE(review): the renaming assumes get_adjmat_bp orders its nodes as all
    # rows first, then all columns -- TODO confirm against its implementation.
    idx2name = dict(enumerate(row_names + col_names))
    return nx.relabel_nodes(G, idx2name), pos, kws



In [3]:
# Nonzero pattern of the 10 x 10 example matrix, allocated directly as booleans.
# Comments use the 1-based Row/Col numbering shown in the figures.
X = np.zeros((10, 10), dtype=bool)
X[0, 0] = True        # row 1  <-> col 1
X[1, 1] = True        # row 2  <-> col 2
X[2, 2:5] = True      # row 3  <-> cols 3-5
X[3:5, 5] = True      # rows 4-5  <-> col 6
X[5:7, 6:8] = True    # rows 6-7  <-> cols 7-8
X[7:10, 8] = True     # rows 8-10 <-> col 9
X[7:9, 9] = True      # rows 8-9  <-> col 10


# Graph view of the same matrix: labeled graph, node positions and nx.draw options.
G, pos, kws = get_bipt_graph(X)

# Figure 1: nonzero pattern of X. Cells where X is False are masked out, so only
# the nonzero entries are drawn.
plt.figure(figsize=(5, 5))
sns.heatmap(~X, mask=~X, cbar=False, square=True, linewidths=.5)
# Label rows/columns from 1; ticks go at 0.5, 1.5, ... (the heatmap cell centres).
y_ticks = np.arange(X.shape[0]) + 1
x_ticks = np.arange(X.shape[1]) + 1
plt.xticks(ticks=x_ticks - .5, labels=x_ticks)
plt.yticks(ticks=y_ticks - .5, labels=y_ticks)
plt.xlabel("View 1 clusters")
plt.ylabel("View 2 clusters")
save_fig(fig_dir + 'motiv_ex_pi.png', dpi=dpi)

# Figure 2: the bipartite graph, using the layout and styling from get_bipt_graph.
plt.figure(figsize=(6, 10))
nx.draw(G, pos, **kws)
save_fig(fig_dir + 'motiv_ex_pi_bipt_graph.png', dpi=dpi)