MS2LDA on GNPS corpus
===

In [261]:
%load_ext autoreload
%autoreload 2
%matplotlib inline


import sys
sys.path.append('/Users/simon/git/lda/code/')
sys.path.append('/Users/simon/git/lda/gnps/')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [262]:
from lda import VariationalLDA

Load the corpus

In [263]:
import pickle

In [264]:
prefix = '/Users/simon/git/lda/gnps/gnps_'
with open(prefix+'corpus.lda','r') as f:
    corpus = pickle.load(f)

In [265]:
with open(prefix+'metadata.lda','r') as f:
    metadata = pickle.load(f)

In [266]:
with open(prefix+'fragment_masses.lda','r') as f:
    fragment_masses = pickle.load(f)
with open(prefix+'fragment_names.lda','r') as f:
    fragment_names = pickle.load(f)
with open(prefix+'fragment_counts.lda','r') as f:
    fragment_counts = pickle.load(f)


In [267]:
with open(prefix+'loss_masses.lda','r') as f:
    loss_masses = pickle.load(f)
with open(prefix+'loss_names.lda','r') as f:
    loss_names = pickle.load(f)
with open(prefix+'loss_counts.lda','r') as f:
    loss_counts = pickle.load(f)

Set a threshold on the number of occurances of a feature

In [268]:
feat_thresh = 2
to_remove = []
for i,fragment_name in enumerate(fragment_names):
    if fragment_counts[i]<feat_thresh:
        to_remove.append(fragment_name)
for i,loss_name in enumerate(loss_names):
    if loss_counts[i]<feat_thresh:
        to_remove.append(loss_name)
print "Found {} to remove (of {})".format(len(to_remove),len(loss_names)+len(fragment_names))  

Found 36881 to remove (of 57510)


In [269]:
instances_removed = 0
doc_pos = 0
sub_corpus = {}
for doc in corpus:
    sub_corpus[doc] = {}
    for word in corpus[doc]:
        if not word in to_remove:
            sub_corpus[doc][word] = corpus[doc][word]
    doc_pos += 1
    if len(sub_corpus[doc]) == 0:
        del sub_corpus[doc]
    if doc_pos % 1000 == 0:
        print "Done doc {}".format(doc_pos)

Done doc 1000
Done doc 2000
Done doc 3000
Done doc 4000
Done doc 5000


In [270]:
with open(prefix+'sub_corpus.lda','w') as f:
    pickle.dump(sub_corpus,f,-1)

In [271]:
n_words = {}
total = 0
for doc in sub_corpus:
    n_words[doc] = len(sub_corpus[doc])
    total += n_words[doc]
print "Average {} unique words per document".format(1.0*total/len(sub_corpus))

Average 66.9402087387 unique words per document


In [272]:
from lda import VariationalLDA
gnps_lda = VariationalLDA(sub_corpus,K=500,alpha=1,eta=0.1,update_alpha=True,normalise = 100)

Found 20629 unique words
Object created with 5653 documents
Normalising intensities


In [273]:
gnps_lda.run_vb(n_its = 100)

Initialising
Starting iterations
Iteration 0 (change = 567.696225611) (22.748078 seconds, I think I'll finish in 37.9134633333 minutes)
Iteration 1 (change = 22.6572508285) (20.298537 seconds, I think I'll finish in 33.49258605 minutes)
Iteration 2 (change = 16.2932976161) (22.037641 seconds, I think I'll finish in 35.9948136333 minutes)
Iteration 3 (change = 12.9047184667) (23.255546 seconds, I think I'll finish in 37.5964660333 minutes)
Iteration 4 (change = 10.8011579228) (21.543487 seconds, I think I'll finish in 34.4695792 minutes)
Iteration 5 (change = 9.40223166835) (22.079417 seconds, I think I'll finish in 34.9590769167 minutes)
Iteration 6 (change = 8.41700342279) (24.40558 seconds, I think I'll finish in 38.2354086667 minutes)
Iteration 7 (change = 7.70344957561) (19.795824 seconds, I think I'll finish in 30.6835272 minutes)
Iteration 8 (change = 7.20071723292) (18.616836 seconds, I think I'll finish in 28.5458152 minutes)
Iteration 9 (change = 6.92184732458) (17.991437 seco

In [276]:
gnps_lda.run_vb(n_its=900,initialise=False)

Starting iterations
Iteration 0 (change = 0.554966861496) (26.038453 seconds, I think I'll finish in 390.576795 minutes)
Iteration 1 (change = 0.492246582641) (20.442205 seconds, I think I'll finish in 306.292371583 minutes)
Iteration 2 (change = 0.450462222003) (20.327696 seconds, I think I'll finish in 304.237850133 minutes)
Iteration 3 (change = 0.425445175324) (20.15524 seconds, I think I'll finish in 301.320838 minutes)
Iteration 4 (change = 0.401373012242) (19.838678 seconds, I think I'll finish in 296.257591467 minutes)
Iteration 5 (change = 0.385575208245) (20.030708 seconds, I think I'll finish in 298.791394333 minutes)
Iteration 6 (change = 0.374518972168) (19.931014 seconds, I think I'll finish in 296.9721086 minutes)
Iteration 7 (change = 0.363991486806) (20.003826 seconds, I think I'll finish in 297.7236103 minutes)
Iteration 8 (change = 0.351493465007) (19.90087 seconds, I think I'll finish in 295.859600667 minutes)
Iteration 9 (change = 0.336720212354) (19.936075 seconds

In [277]:
from lda_plotters import VariationalLDAPlotter

In [278]:
vp = VariationalLDAPlotter(gnps_lda)
vp.bar_alpha()

Make the network graph

In [279]:
import networkx as nx
# Extract the topics of interest
topics = []
topic_idx = []
topic_id = []
topic_degree_thresh = 10
p_thresh = 0.05
# eth = v_lda.get_expect_theta()
# for i in range(v_lda.K):
eth = gnps_lda.get_expect_theta()
print eth.shape
for i in range(gnps_lda.K):
    s = (eth[:,i]>p_thresh).sum()
    if s > topic_degree_thresh:
        topics.append("motif_{}".format(i))
        topic_idx.append(i)
        topic_id.append(i)
        
        
        
        
G = nx.Graph()
for i,t in enumerate(topics):
    s = (eth[:,topic_idx[i]]>p_thresh).sum()
    G.add_node(t,group=2,name=t,size=5*s,
               type='circle',special=False,
              in_degree=s,score=1)
print "Added {} topics".format(len(topics))


# Add the parents
parents = []
parent_dict = {}
parent_id = {}
doc_for_later = None
j = len(topics)

eth = gnps_lda.get_expect_theta()
for doc in gnps_lda.corpus:
    parent_pos = gnps_lda.doc_index[doc]
#     parent_name = "doc_{}_{}".format(doc.mz,doc.rt)
    parent_name = metadata[doc]['compound']
    for i,t in enumerate(topics):
        topic_pos = topic_idx[i]
        this_topic_id = topic_id[i]
        if eth[parent_pos,topic_pos] > p_thresh:
            if not parent_name in parents:
                parents.append(parent_name)
                parent_dict[parent_name] = doc
                parent_id[doc] = j
                j += 1
                G.add_node(parent_name,group=1,name=parent_name,
                           size=20,type='square',peakid=parent_name,
                          special=False,in_degree=0,score=0)


            G.add_edge(t,parent_name,weight=5*eth[parent_pos,topic_pos])
            
print "Added {} parents".format(len(parents))

(5653, 500)
Added 276 topics
Added 5288 parents


In [280]:
import json
from networkx.readwrite import json_graph
d = json_graph.node_link_data(G) 
json.dump(d, open('../joegraph/gnps_graph.json','w'),indent=2)

In [283]:
# Write the topics to a file
b = gnps_lda.beta_matrix.copy()
print b.shape
with open('../gnps/topics.txt','w') as f:
    for topic in topic_id:
        f.write('TOPIC: {}\n'.format(topic))
        word_tup = []
        for word in gnps_lda.word_index:
            word_tup.append((word,b[topic,gnps_lda.word_index[word]]))
        word_tup = sorted(word_tup,key = lambda x: x[1],reverse=True)
        total_prob = 0.0
        pos = 0
        n_words = 0
        while total_prob < 0.9 and n_words <= 20:
            total_prob += word_tup[pos][1]
            n_words += 1
            f.write("\t{}: {}\n".format(word_tup[pos][0],word_tup[pos][1]))
            pos += 1
        f.write('\n\n')

(500, 20629)


In [282]:
doc = parent_dict['Lonidamine [M+H]']
precursor_mass = float(metadata[doc]['parentmass'])
title = doc + "  (" + metadata[doc]['compound'] + ")"
vp.plot_document_topic_colour(doc,precursor_mass = precursor_mass,show_losses = True,title = title)

In [230]:
for doc in G.neighbors('motif_335'):
    print doc
    for word in sub_corpus[parent_dict[doc]]:
        if word.startswith('loss_46.') or word.startswith('loss_18.'):
            print word,sub_corpus[parent_dict[doc]][word]
    print
    print

"MLS000069767-01!3'-AZIDO-2',3'-DIDEOXYURIDINE" M+H


Apratoxin F M+Na


Malyngamide C M+H


Pheophytin M+H


MLS002153946-01!Tetrandrine M+H
loss_46.0718175161 5.27081061432


Vincristine M+H
loss_18.0104235721 8.76207922256
loss_18.0100400521 0.958650614454


Pheophorbide A M+H
loss_46.0108309481 0.611833581266
loss_18.0122365401 0.74259801334


"NCGC00160317-01!8-(3,4-Dimethoxy-phenyl)-2,3,10,11-tetramethoxy-5,6,13a-tridehydro-berbinium" M+H
loss_46.0501679241 3.15272380673


Lansoprazole M+H


MLS001075533-01! M+H


NCGC00160180-01!KOPSINE M+H
loss_46.0053890801 1.07211283682
loss_18.0118481081 3.59952712894


VINCRISTINE SULFATE [M+H]
loss_18.0110677561 15.0665228335


Lansoprazole [M+H]
loss_18.0104235721 1.59841075795




In [145]:
m1 = 46.0057116041
m2 = 46.0053890801
1e6*abs(m1-m2)/m1

7.010520840830274

Plot T spectra that include a particular topic (above a threshold)

In [299]:
from lda_plotters import VariationalLDAPlotter
vp = VariationalLDAPlotter(gnps_lda)
thresh = 0.1
topic = 276
eth = gnps_lda.get_expect_theta()
max_found = 10
n_found = 0
di = []
for doc in gnps_lda.doc_index:
    di.append((doc,gnps_lda.doc_index[doc]))

for i,e in enumerate(eth[:,topic]):
    if e > thresh:
        n_found += 1
        doc = [d for d,j in di if j == i][0]
        title = metadata[doc]['compound']
        precursor_mass = float(metadata[doc]['parentmass'])
#         vp.plot_document_topic_colour(doc,show_losses = True,
#                                       precursor_mass=precursor_mass,title=title,
#                                       xlim = [130,140])
        vp.plot_document_colour_one_topic(doc,topic,show_losses = True,
                              precursor_mass=precursor_mass,title=title,
                              xlim = None)


    if n_found > max_found:
        break
        

In [176]:
print gnps_lda.corpus['CCMSLIB00000001778.ms']

{'loss_112.05007076': 487190.0, 'loss_115.112951652': 98433.0, 'loss_43.9931689361': 148866.0, 'loss_60.0806843561': 326324.0, 'loss_72.0878637441': 29377.0}


In [186]:
pos = []
for word in gnps_lda.phi_matrix['CCMSLIB00000001778.ms']:
    print word,gnps_lda.phi_matrix['CCMSLIB00000001778.ms'][word][8]
    pos.append(gnps_lda.word_index[word])
    
print pos

print gnps_lda.beta_matrix[8,pos]
print gnps_lda.beta_matrix[5,pos]

loss_112.05007076 1.0
loss_115.112951652 1.0
loss_43.9931689361 1.0
loss_60.0806843561 1.0
loss_72.0878637441 1.0
[22858, 21638, 18470, 5047, 23003]
[ 0.00201914  0.00042061  0.00061668  0.0013518   0.00012438]
[  1.13833356e-10   1.13833356e-10   1.13833356e-10   1.13833356e-10
   1.14647235e-10]


In [202]:
import time
print time.clock()

18244.207393


In [247]:
m1 = 112.076118
m2 = 112.07531
1e6*abs((m1-m2)/m1)

m1 = 183.149078
m2 = 183.112717
1e6*abs((m1-m2)/m1)


198.53225796746528

In [296]:
with open('gnps_lda.lda','w') as f:
    pickle.dump(gnps_lda,f)

In [300]:
# make the dictionary

In [388]:
min_prob_to_keep_beta = 1e-3
min_prob_to_keep_phi = 1e-2
min_prob_to_keep_theta = 1e-2

lda_dict = {}
lda_dict['corpus'] = gnps_lda.corpus
lda_dict['word_index'] = gnps_lda.word_index
lda_dict['doc_index'] = gnps_lda.doc_index
lda_dict['K'] = gnps_lda.K
lda_dict['alpha'] = list(gnps_lda.alpha)
lda_dict['beta'] = {}
lda_dict['doc_metadata'] = metadata
wi = []
for i in gnps_lda.word_index:
    wi.append((i,gnps_lda.word_index[i]))
wi = sorted(wi,key = lambda x: x[1])

di = []
for i in gnps_lda.doc_index:
    di.append((i,gnps_lda.doc_index[i]))
di = sorted(di,key=lambda x: x[1])

ri,i = zip(*wi)
ri = list(ri)
di,i = zip(*di)
di = list(di)

    

In [389]:
for k in range(gnps_lda.K):
    pos = np.where(gnps_lda.beta_matrix[k,:]>min_prob_to_keep_beta)[0]
    motif_name = 'motif_{}'.format(k)
    lda_dict['beta'][motif_name] = {}
    for p in pos:
        word_name = ri[p]
        lda_dict['beta'][motif_name][word_name] = gnps_lda.beta_matrix[k,p]


eth = gnps_lda.get_expect_theta()
lda_dict['theta'] = {}
for i,t in enumerate(eth):
    doc = di[i]
    lda_dict['theta'][doc] = {}
    pos = np.where(t > min_prob_to_keep_theta)[0]
    for p in pos:
        motif_name = 'motif_{}'.format(p)
        lda_dict['theta'][doc][motif_name] = t[p]
    

In [390]:
# lda_dict['gamma'] = []
# for d in range(len(gnps_lda.corpus)):
#     lda_dict['gamma'].append(list(gnps_lda.gamma_matrix[d,:]))
lda_dict['phi'] = {}
ndocs = 0
for doc in gnps_lda.corpus:
    ndocs += 1
    lda_dict['phi'][doc] = {}
    for word in gnps_lda.corpus[doc]:
        lda_dict['phi'][doc][word] = {}
        pos = np.where(gnps_lda.phi_matrix[doc][word] >= min_prob_to_keep_phi)[0]
        for p in pos:
            lda_dict['phi'][doc][word]['motif_{}'.format(p)] = gnps_lda.phi_matrix[doc][word][p]
    if ndocs % 500 == 0:
        print "Done {}".format(ndocs)
        

Done 500
Done 1000
Done 1500
Done 2000
Done 2500
Done 3000
Done 3500
Done 4000
Done 4500
Done 5000
Done 5500


In [391]:
sys.getsizeof(lda_dict)

1048

In [392]:
import pickle
with open('gnps_lda.dict','w') as f:
    pickle.dump(lda_dict,f,-1)

In [385]:
from lda_plotters import VariationalLDAPlotter_dict
vd = VariationalLDAPlotter_dict(lda_dict)
vd.bar_alpha()

In [386]:
doc = lda_dict['corpus'].keys()[2]
parentmass = float(lda_dict['doc_metadata'][doc]['parentmass'])
print parentmass
vd.plot_document_colour_one_topic(doc,'motif_157',precursor_mass = parentmass)


591.149702724


In [387]:
vd.plot_document_topic_colour(doc)
print lda_dict['theta'][doc]

{'motif_276': 0.079582169346298609, 'motif_320': 0.85365570485186826, 'motif_153': 0.016232241597535814}


In [None]:
500*5000

In [309]:
import numpy as np
a = np.array([1,2,3])
type(a)
a = list(a)
type(a)
print a

[1, 2, 3]


In [402]:
from lda_plotters import VariationalLDAPlotter_dict
vd = VariationalLDAPlotter_dict(lda_dict)
G = vd.make_graph_object(filename = '../joegraph/gnps2.json')

Found 280 topics
