In [2]:
from scipy.io import loadmat
from coclust.CoclustMod import CoclustMod
import numpy as np

# Load the document-term matrix and its vocabulary from the classic3 MATLAB file.
file_name = "../datasets/classic3.mat"
matlab_dict = loadmat(file_name)

# 'A' is the matrix that gets co-clustered; 'ms' carries one label per term.
# Each label comes back wrapped in nested arrays, hence the [0][0] unwrapping.
X = matlab_dict['A']
terms = [str(entry[0][0]) for entry in matlab_dict['ms']]

# Fit a 3-co-cluster model; a single initialisation with a fixed seed keeps
# the result reproducible from one run to the next.
model = CoclustMod(n_clusters=3, n_init=1, random_state=0)
model.fit(X)

In [6]:
def get_neighbors(sim, term_index, n_neighbors):
    """Return the column indices of the largest entries in one row of ``sim``.

    Parameters
    ----------
    sim : sparse matrix
        Similarity matrix; row ``term_index`` is densified with ``toarray()``.
    term_index : int
        Row whose strongest entries are wanted.
    n_neighbors : int
        Maximum number of indices to return.

    Returns
    -------
    numpy.ndarray
        Up to ``n_neighbors`` column indices, ordered by decreasing value.
        Nothing excludes ``term_index`` itself from the result.
    """
    similarities = sim[term_index, :].toarray()[0]
    by_decreasing_value = similarities.argsort()[::-1]
    return by_decreasing_value[:n_neighbors]

def get_all_neighbors(sim, top_terms_indices, n_neighbors):
    """Collect the neighbours of several terms into a single set.

    Parameters
    ----------
    sim : sparse matrix
        Similarity matrix handed over to ``get_neighbors``.
    top_terms_indices : iterable of int
        Rows of ``sim`` whose neighbours are gathered.
    n_neighbors : int
        Number of neighbours requested for each row.

    Returns
    -------
    set
        Union of the indices returned by ``get_neighbors`` for every row;
        empty when ``top_terms_indices`` is empty.
    """
    per_term_neighbors = (
        get_neighbors(sim, term_index, n_neighbors)
        for term_index in top_terms_indices
    )
    return set().union(*per_term_neighbors)

In [7]:
def get_graph(X, model, terms, n_cluster, n_top_terms=10, n_neighbors=5):
    """Build a node-link description of the main terms of one co-cluster.

    The most frequent terms of the cluster form group 0. Their strongest
    co-occurring terms (as returned by ``get_all_neighbors``) form group 1.
    Every top term is linked to every neighbour and to every other top term,
    the link value being the corresponding entry of the co-occurrence matrix.

    Parameters
    ----------
    X : sparse matrix
        Matrix the model was fitted on (presumably documents x terms - the
        code only relies on terms being the columns).
    model : object
        Fitted co-clustering model exposing ``get_indices(n_cluster)``, which
        returns a ``(row_indices, col_indices)`` pair, and
        ``get_submatrix(X, n_cluster)``.
    terms : sequence of str
        Label of every column of ``X``.
    n_cluster : int
        Index of the co-cluster to describe.
    n_top_terms : int, optional
        Maximum number of most frequent terms to keep (group 0).
    n_neighbors : int, optional
        Number of neighbours requested for each top term.

    Returns
    -------
    dict
        ``{"nodes": [...], "links": [...]}``. Each node is
        ``{"name": str, "group": 0 or 1}``; each link is
        ``{"source": int, "target": int, "value": float}`` where ``source``
        and ``target`` are positions in the ``"nodes"`` list. Only built-in
        Python types are used so that the result serialises cleanly.
    """
    _, col_indices = model.get_indices(n_cluster)
    cluster = model.get_submatrix(X, n_cluster)

    # Labels of the terms (columns) that belong to the cluster.
    cluster_terms = np.array(terms)[col_indices]

    # Term-by-term co-occurrence matrix of the cluster.
    sim = cluster.T * cluster

    # Total number of occurrences of each term inside the cluster.
    term_counts = np.asarray(cluster.sum(0)).flatten()
    # Indices of the terms with the largest number of occurrences.
    top_terms_indices = [int(k) for k in term_counts.argsort()[::-1][:n_top_terms]]

    # A term is normally its own strongest neighbour (diagonal of ``sim``), so
    # the top terms show up in the neighbour set. Drop them to avoid emitting
    # the same term twice, and sort so that node order is reproducible.
    top_terms = set(top_terms_indices)
    neighbor_indices = sorted(
        int(k)
        for k in get_all_neighbors(sim, top_terms_indices, n_neighbors)
        if int(k) not in top_terms
    )

    nodes = [{"name": str(cluster_terms[k]), "group": 0} for k in top_terms_indices]
    nodes.extend({"name": str(cluster_terms[k]), "group": 1} for k in neighbor_indices)

    # Neighbour nodes come right after the top-term nodes actually emitted,
    # which can be fewer than ``n_top_terms`` for a small cluster.
    first_neighbor_position = len(top_terms_indices)

    links = []
    # link: top_term -- neighbor
    for i, top_term in enumerate(top_terms_indices):
        for j, neighbor in enumerate(neighbor_indices):
            links.append({
                "source": i,
                "target": first_neighbor_position + j,
                "value": float(sim[top_term, neighbor]),
            })

    # link: top_term -- top_term (both directions, as before)
    for i, top_term in enumerate(top_terms_indices):
        for j, other_top_term in enumerate(top_terms_indices):
            if i != j:
                links.append({
                    "source": i,
                    "target": j,
                    "value": float(sim[top_term, other_top_term]),
                })

    return {"nodes": nodes, "links": links}

In [15]:
# Describe co-cluster 1 with its 10 most frequent terms and 5 neighbours each.
graph = get_graph(X, model, terms, n_cluster=1, n_top_terms=10, n_neighbors=5)

In [16]:
# TODO: replace this random filler with the real graph only.
#
# Pads `graph` with random demo nodes and links. The nodes are appended to
# whatever `graph` already contains, so running this cell again appends
# another batch.

import random

n_nodes = 30
p_edge = 0.05
# Local seeded generator: the demo is reproducible and the global `random`
# state is left untouched.
demo_rng = random.Random(0)
# Uncomment to display a purely random graph instead of padding the real one.
#graph = {"nodes": [], "links": []}

# Position of the first random node. Link endpoints are positions in
# graph["nodes"], so they must be shifted past the nodes already present;
# otherwise the random links would connect the pre-existing nodes.
first_random_node = len(graph["nodes"])
for i in range(n_nodes):
    graph["nodes"].append( {"name": "i" + str(i), "group": int(demo_rng.uniform(1,11))} )
for i in range(n_nodes):
    for j in range(n_nodes):
        if demo_rng.uniform(0,1) < p_edge:
            graph["links"].append( {"source": first_random_node + i,
                                    "target": first_random_node + j,
                                    "value": demo_rng.uniform(0.5,3)} )

In [11]:
%%javascript
// Register D3 3.4.8 (served by cdnjs) with RequireJS under the module name
// "d3", so that it can be loaded with require(['d3'], ...).
var d3Location = '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min';
require.config({paths: {d3: d3Location}});

<IPython.core.display.Javascript object>

In [12]:
from IPython.display import HTML

# Stylesheet for the graph drawing: white outline around node circles and
# semi-transparent grey links. Displaying the HTML object injects it in the page.
graph_css = """
<style>
.node_circle {
  stroke: #fff;
  stroke-width: 1.5px;
}

.link {
  stroke: #999;
  stroke-opacity: .6;
}
</style>
"""
HTML(graph_css)

In [18]:
import json

from IPython.display import Javascript


def _to_builtin(value):
    """Fallback for ``json.dumps``: unwrap numpy scalars into Python values."""
    if hasattr(value, "item"):
        return value.item()
    raise TypeError("%r is not JSON serializable" % (value,))


# Runs arbitrary JavaScript client-side: publish the graph as `window.graph`.
# The graph is serialised as JSON rather than through its Python repr, which
# is not JavaScript (numpy scalar reprs, True/False/None, u'...' prefixes).
Javascript("window.graph = {};".format(json.dumps(graph, default=_to_builtin)))

<IPython.core.display.Javascript object>

In [28]:
%%javascript
// Draw `window.graph` ({nodes: [...], links: [...]}) as a D3 force-directed
// layout inside this cell's output area.
require(['d3'], function(d3){
  // Remove the chart left by a previous run of this cell, so that re-running
  // it replaces the drawing instead of adding a second one.
  $("#chart1").remove();
  // Create the container holding the drawing.
  element.append("<div id='chart1'></div>");
  $("#chart1").width("1160px");
  $("#chart1").height("800px");
  var margin = {top: 20, right: 20, bottom: 30, left: 40};
  var width = 1280 - margin.left - margin.right;
  var height = 800 - margin.top - margin.bottom;
  var svg = d3.select("#chart1").append("svg")
    .style("position", "relative")
    .style("max-width", "960px")
    .attr("width", width + "px")
    .attr("height", (height + 50) + "px")
    .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

  // One colour per node group.
  var color = d3.scale.category20();

  var force = d3.layout.force()
    .charge(-800)
    .linkDistance(600)
    .size([width, height]);

  // Graph published by the Python side.
  var graph = window.graph;

  force
    .nodes(graph.nodes)
    .links(graph.links)
    .start();

  // One line per link, its thickness growing with the link value.
  var link = svg.selectAll(".link")
    .data(graph.links)
    .enter().append("line")
    .attr("class", "link")
    // The colour must not end with ";": such a value is not valid CSS and
    // the inline stroke would be dropped.
    .style("stroke", "#999")
    .style("stroke-width", function(d) { return Math.sqrt(d.value); });

  // One draggable group per node, holding a circle, a label and a tooltip.
  var node = svg.selectAll(".node")
    .data(graph.nodes)
    .enter().append("g")
    .attr("class", "node")
    .call(force.drag);

  node.append("circle")
    .attr("class", "node_circle")
    .attr("r", 8)
    .style("fill", function(d) { return color(d.group); });

  node.append("text")
    .attr("class", "node_text")
    .attr("dx", 12)
    .attr("dy", ".35em")
    .text(function(d) { return d.name; });

  node.append("title")
    .text(function(d) { return d.name; });

  var node_text = svg.selectAll(".node_text");
  var node_circle = svg.selectAll(".node_circle");

  // On every simulation step, move links, circles and labels to the positions
  // computed by the layout.
  force.on("tick", function() {
    link.attr("x1", function(d) { return d.source.x; })
        .attr("y1", function(d) { return d.source.y; })
        .attr("x2", function(d) { return d.target.x; })
        .attr("y2", function(d) { return d.target.y; });

    node_circle.attr("cx", function(d) { return d.x; })
        .attr("cy", function(d) { return d.y; });

    node_text.attr("x", function(d) { return d.x; })
        .attr("y", function(d) { return d.y; });
  });


});

<IPython.core.display.Javascript object>