# TopSBM: Topic Modeling with Stochastic Block Models

In [None]:
%load_ext autoreload
%autoreload 2

import os
import pylab as plt
%matplotlib inline  

from sbmtm import sbmtm
import graph_tool.all as gt

import numpy as np
from matplotlib import pyplot as plt

gt.seed_rng(42) ## seed for graph-tool's random number generator --> same results

In [None]:
## print the OpenMP thread count reported by graph-tool
print(gt.openmp_get_num_threads())

# Fitting the model

In [None]:
## we create an instance of the sbmtm-class (constructed without arguments);
## every later cell works on this `model` object
model = sbmtm()

In [None]:
## we can also skip the previous step by saving/loading a graph:
## the active line loads the graph from 'graph.xml.gz' into the model,
## the commented-out line is the matching save call
#model.save_graph(filename = 'graph.xml.gz')
model.load_graph(filename = 'graph.xml.gz')

In [None]:
## show the graph object stored on the model (last expression of the cell)
model.g

In [None]:
## run the model fit; fit_overlap() is kept commented out as an alternative call
model.fit()
#model.fit_overlap()

In [None]:
## shorthand for the model's `state` attribute (used for drawing further down)
state = model.state

# Plotting the result

The output shows the (hierarchical) community structure in the word-document network as inferred by the stochastic block model:

- document-nodes are on the left
- word-nodes are on the right
- different colors correspond to the different groups

The result is a grouping of nodes into groups on multiple levels in the hierarchy:

- on the uppermost level, each node belongs to the same group (square in the middle)
- on the next-lower level, we split the network into two groups: the document-nodes and the word-nodes (blue squares to the left and right, respectively). This is a trivial structure due to the bipartite character of the network.
- only the next-lower levels constitute a non-trivial structure: we now further divide nodes into smaller groups (document-nodes into document-groups on the left and word-nodes into word-groups on the right)

In [None]:
## draw the structure described above; nedges=1000 is passed on to model.plot
model.plot(nedges=1000)

In [None]:
## draw the state with the bipartite layout and write it to 'bipartite_overlap.png';
## the commented-out call would write 'circular_overlap.png' without a layout argument
state.draw(layout='bipartite', output='bipartite_overlap.png')
#state.draw(output='circular_overlap.png')

# The basics

## Topics
For each word-group on a given level in the hierarchy, we retrieve the $n$ most common words in each group -- these are the topics!


In [None]:
## hierarchy level passed as `l=l` to the model calls in the cells below
l=1

In [None]:
## topics at level l, asking for n=200 words per topic
model.topics(l=l,n=200)

In [None]:
## histogram of topic sizes at level l
## query the topics once: the original re-ran model.topics() on every loop
## iteration just to look up a single entry
topics_at_level = model.topics(l=l,n=2000)
## number of entries in each topic (n=2000 caps how many words are requested)
topic_lenghts = [len(topics_at_level[topic]) for topic in topics_at_level]

fig=plt.figure()
plt.hist(topic_lenghts, histtype='step', bins=11, lw=2)
plt.xlabel("topic size (# genes)", fontsize=16)
plt.ylabel("# topic of that size", fontsize=16)
plt.show()
## the level is part of the file name, so different levels do not overwrite each other
fig.savefig("topic_size_%d.png"%l)

In [None]:
## print the first field of every entry of topic 7 at level l
## NOTE(review): entries look like (gene identifier, probability) pairs -- confirm
for ensg in model.topics(l=l,n=200)[7]:
    print(ensg[0])

### DAVID compatible format

In [None]:
## print the topics of level l using the 'tsv' format option
model.print_topics(l=l, format='tsv')

## Topic-distribution in each document
Which topics contribute to each document?

In [None]:
## select a document (by its index)
i_doc = 814
## print the corresponding entry of model.documents to check the selection
print(model.documents[i_doc])
## get a list of tuples (topic-index, probability)
#model.topicdist(i_doc,l=0)

In [None]:
## pie chart of the topic mixture of document i_doc at level l
## fetch the (topic-index, probability) tuples once; the original called
## model.topicdist() twice, once for the values and once for the labels
topic_dist = model.topicdist(i_doc,l=l)
data = [el[1] for el in topic_dist]
## topic labels are shifted by one relative to the topic index
labels = [el[0]+1 for el in topic_dist]
## looked up once, used for both the title and the file name
doc_name = model.documents[i_doc]
fig=plt.figure()
plt.pie(data, labels=labels)
plt.title("Topic distribution: %s"%doc_name)
plt.show()
fig.savefig("topic_distr_%s.png"%doc_name)

# Extra: Clustering of documents - for free.
The stochastic block model clusters the documents into groups.
We do not need to run an additional clustering to obtain this grouping.


In [None]:
## document clusters at level l (n=500 is passed on to model.clusters)
model.clusters(l=l,n=500)

Application -- Finding similar articles:

For a query-article, we return all articles from the same group

In [None]:
## select a document (index)
i_doc = 2
## show the index together with the matching entry of model.documents
print(i_doc,model.documents[i_doc])
## find all articles from the same group
## print: (doc-index, doc-title)
## note: this query uses level 0, not the level `l` used by the other cells
model.clusters_query(i_doc,l=0)

# More technical: Group membership
In the stochastic block model, word (-nodes) and document (-nodes) are clustered into different groups.

The group membership can be represented by the conditional probability $P(\text{group}\, |\, \text{node})$. Since words and documents belong to different groups (the word-document network is bipartite) we can show separately:

- $P(bd\, |\, d)$, the probability that document $d$ belongs to document group $bd$
- $P(bw\, |\, w)$, the probability that word $w$ belongs to word group $bw$.

In [None]:
## group-membership matrices at level l, shown as two side-by-side heat-maps
## (the original's `l=l` self-assignment was a no-op and has been dropped)
p_td_d,p_tw_w = model.group_membership(l=l)
fig = plt.figure(figsize=(15,4))

## one entry per panel: (subplot position, matrix, title, x-label, y-label);
## both panels are drawn by the same code instead of two copy-pasted sections
panels = [
    (121, p_td_d, r'Document group membership $P(bd | d)$',
     'Document d (index)', 'Document group, bd'),
    (122, p_tw_w, r'Word group membership $P(bw | w)$',
     'Word w (index)', 'Word group, bw'),
]
for position, membership, panel_title, x_label, y_label in panels:
    plt.subplot(position)
    plt.imshow(membership,origin='lower',aspect='auto',interpolation='none')
    plt.title(panel_title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.colorbar()
plt.show()

In [None]:
## save the figure held in `fig` twice, as PDF and as PNG;
## the level l is part of both file names
fig.savefig("group_membership_%d.pdf"%l)
fig.savefig("group_membership_%d.png"%l)

In [None]:
## for every row of p_tw_w.T: how many entries are non-zero
overlaplenghts = [len(np.nonzero(membership)[0]) for membership in p_tw_w.T]
## positions of the rows with more than one non-zero entry
overlap_index = [
    position
    for position, size in enumerate(overlaplenghts)
    if size > 1
]

In [None]:
## histogram of the per-gene counts collected in `overlaplenghts`
## (six unit-wide bins centred on 0..5)
fig = plt.figure()
plt.hist(overlaplenghts, bins=6, range=(-0.5,5.5), histtype='step', lw=2)
plt.ylabel("genes with that number of topics")
plt.xlabel("# topic")
plt.title("How many topics in a single gene")
plt.show()
fig.savefig("overlap_size_%d.png"%l)

In [None]:
## plot the membership profile of the first ten entries of overlap_index,
## one figure (and one PNG) per gene
## (the dead `gene=300` assignment was removed: the loop rebinds `gene` immediately)
for gene in overlap_index[:10]:
    ## looked up once, used for both the title and the file name
    gene_name = model.words[gene]
    fig=plt.figure()
    plt.title("Topic distribution of a single gene: %s"%gene_name)
    plt.xlabel("topic tw")
    plt.ylabel("probability")
    ## fixed y-range so the figures of different genes are comparable
    plt.ylim((0,1.1))
    plt.plot(p_tw_w.T[gene])
    plt.show()
    fig.savefig("distribution_single_gene_%s.png"%gene_name)

## overlapping genes

In [None]:
## entries of model.words whose row in p_tw_w.T has more than one non-zero value
overlappinggenes = [
    model.words[position]
    for position, membership in enumerate(p_tw_w.T)
    if len(np.nonzero(membership)[0]) > 1
]

In [None]:
## list the overlapping genes, one per line
for gene_name in overlappinggenes:
    print(gene_name)

# state analysis

In [None]:
## shorthand for the model's `state` attribute, used by the cells below
state = model.state

In [None]:
## draw the matrix returned by get_matrix() for every level of the hierarchy
## and save each one to its own file.
## The original built the file name from the global `l`, which does not change
## inside the loop, so every level overwrote the same image; the position of
## the level in state.get_levels() is used instead.
for level_index, level in enumerate(state.get_levels()):
    e=level.get_matrix()
    ## matshow opens a new figure, which savefig then writes out
    plt.matshow(e.todense())
    plt.savefig("mat_%d.png"%level_index)

In [None]:
## print the topics level by level, from the highest index down to 0;
## the last two entries of state.get_levels() are left out
n_printed_levels = len(state.get_levels())-2
for level_index in reversed(range(n_printed_levels)):
    print("doing %d"%level_index)
    model.print_topics(l=level_index)

In [None]:
## print the model's summary
model.print_summary()

In [None]:
## number of words and number of documents held by the model
print(len(model.words))
print(len(model.documents))

### topicdist

In [None]:
## group data stored on the model for level l (read by string key in the cells below)
groups = model.groups[l]

In [None]:
## heat-map of the 'p_w_tw' matrix of level l, saved as PNG
fig = plt.figure(figsize=(12,10))
p_w_tw = groups['p_w_tw']
plt.imshow(p_w_tw, interpolation='none', aspect='auto', origin='lower')
## per the labels, rows are words and columns are topics
plt.ylabel('Word w (index)')
plt.xlabel('Topic, tw')
plt.title(r'Word group membership $P(w | tw)$')
plt.colorbar()
fig.savefig("p_w_tw_%d.png"%l)

In [None]:
## heat-map of the 'p_tw_d' matrix of level l, saved as PNG
fig = plt.figure(figsize=(12,10))
p_tw_d = groups['p_tw_d']
plt.imshow(p_tw_d, interpolation='none', aspect='auto', origin='lower')
## per the labels, rows are topics and columns are documents
plt.ylabel('Topic, tw')
plt.xlabel('Document (index)')
## NOTE(review): the title text says "Word group membership" but the quantity
## shown is P(tw | d) -- looks copy-pasted from the previous cell; confirm wording
plt.title(r'Word group membership $P(tw | d)$')
plt.colorbar()
fig.savefig("p_tw_d_%d.png"%l)

In [None]:
## line plot of row `topic` of p_w_tw.T, i.e. one value per word
topic = 2
composition = p_w_tw.T[topic]
fig = plt.figure()
plt.plot(composition)
plt.ylabel("probability")
plt.xlabel("word w")
plt.title("Topic%d composition"%topic)
plt.show()
fig.savefig("Topic%d_composition.png"%topic)

In [None]:
## line plot of row `doc` of p_tw_d.T, i.e. one value per topic;
## the title uses the matching entry of model.documents
doc = 12
title = model.documents[doc]
sample_profile = p_tw_d.T[doc]
fig = plt.figure()
plt.plot(sample_profile)
plt.ylabel("probability")
plt.xlabel("topic tw")
plt.title("Topic distribution of sample: %s"%title)
plt.show()
## the file name uses the numeric index rather than the title
fig.savefig("distribution_single_sample_%d.png"%doc)