In [495]:
import networkx as nx
import sys
from collections import Counter
import pickle
import numpy as np

class Cluster:
    """A numbered cluster described by two graphs over the same nodes.

    ``a_graph`` is the graph whose edges are removed when the cluster is
    split; ``b_graph`` carries the ``'weight'`` edge attribute used to score
    the cluster.  ``a_index`` / ``b_index`` map every edge tuple, in the
    orientation reported by ``edges()`` at construction time, to its position
    in that enumeration.  The indices are NOT refreshed when the graphs are
    modified later (e.g. by ``split_at`` / ``split_within``).
    """

    def __init__(self, number, a_graph, b_graph):
        self.number = number
        self.a_graph = a_graph
        self.b_graph = b_graph
        self.a_index = {edge: i for i, edge in enumerate(a_graph.edges())}
        self.b_index = {edge: i for i, edge in enumerate(b_graph.edges())}

    def nodes(self):
        """Return the nodes of the A graph as a list (callers index into it)."""
        return list(self.a_graph.nodes())

    def a_edges(self, nbunch=None):
        """Return the A-graph edges, optionally restricted to ``nbunch``, as a list."""
        return list(self.a_graph.edges(nbunch))

    def b_edges(self, nbunch=None):
        """Return the B-graph edges, optionally restricted to ``nbunch``, as a list."""
        return list(self.b_graph.edges(nbunch))

    def _remove_a_edge(self, e1, e2):
        """Remove the A edge joining e1 and e2 (either orientation).

        Returns True when an edge was removed, False when none exists.
        """
        for u, v in ((e1, e2), (e2, e1)):
            if self.a_graph.has_edge(u, v):
                self.a_graph.remove_edge(u, v)
                return True
        return False

    def _copy_b_edges(self, target, nodes):
        """Copy into ``target`` every weighted B edge joining two of ``nodes``.

        Pairs without a B edge (or without a ``'weight'`` attribute) are
        skipped, so nodes with no such edge never appear in ``target``.
        """
        members = list(nodes)
        for i, n1 in enumerate(members):
            for n2 in members[i + 1:]:
                data = self.b_graph.get_edge_data(n1, n2)
                if data is not None and 'weight' in data:
                    target.add_edge(n1, n2, weight=data['weight'])

    def split_at(self, e1, e2, l):
        """Cut the A edge (e1, e2) and return one new Cluster per component.

        The edge is removed from this cluster's own ``a_graph`` (in place).
        Each connected component of the remaining A graph becomes a new
        Cluster whose B graph holds the weighted B edges between the
        component's nodes.  New clusters are named
        ``"<number>.<letter>"`` where the letter is ``chr(l + i + 97)`` for
        the i-th component.

        Returns an empty list when the edge is not present.
        """
        if not self._remove_a_edge(e1, e2):
            return []
        new_clusters = []
        # connected_components + subgraph().copy() is the portable spelling of
        # the connected_component_subgraphs helper that newer NetworkX dropped.
        for i, component in enumerate(nx.connected_components(self.a_graph)):
            a = self.a_graph.subgraph(component).copy()
            b = nx.Graph()
            self._copy_b_edges(b, a)
            name = str(self.number) + "." + chr(l + i + 97)
            new_clusters.append(Cluster(name, a, b))
        return new_clusters

    def split_within(self, e1, e2):
        """Cut the A edge (e1, e2) and drop B edges that now cross components.

        Works in place: ``a_graph`` loses the edge and ``b_graph`` is replaced
        by a graph holding only the weighted B edges whose endpoints lie in
        the same connected component of the A graph.  Does nothing when the
        edge is not present.  Always returns None.
        """
        if not self._remove_a_edge(e1, e2):
            return
        b = nx.Graph()
        for component in nx.connected_components(self.a_graph):
            self._copy_b_edges(b, component)
        self.b_graph = b

    def weight(self):
        """Return ``avg_cc_weight`` of the B graph (module-level helper)."""
        return avg_cc_weight(self.b_graph)

    def id_as(self):
        """Return the set of identifications recorded for this cluster's nodes.

        Looks nodes up in the module-level ``ids`` mapping; nodes without an
        entry are skipped.
        """
        return {ids[n] for n in self.nodes() if n in ids}

    def lowest_weight(self):
        """Return ``(weight, edge)`` for the lightest B edge.

        Weights are converted with ``float``.  When the B graph has no edges
        the bare value ``1`` is returned instead of a tuple.
        """
        weight_edges = [
            (float(self.b_graph[e[0]][e[1]]['weight']), e)
            for e in self.b_edges()
        ]
        if not weight_edges:
            return 1
        return min(weight_edges, key=lambda pair: pair[0])
        

def pickle_clusters(clusters, filename="clusters.p"):
    """Pickle ``clusters`` into ``filename``, overwriting any existing file."""
    with open(filename, "wb") as sink:
        pickle.Pickler(sink).dump(clusters)

def load_clusters(filename="clusters.p"):
    """Return the object unpickled from ``filename``.

    NOTE: unpickling executes code embedded in the file; only load files
    produced by a trusted run of this program.
    """
    with open(filename, "rb") as source:
        return pickle.Unpickler(source).load()

def pickle_ids(ids, filename="ids.p"):
    """Pickle the ``ids`` mapping into ``filename``, overwriting any existing file."""
    with open(filename, "wb") as sink:
        pickle.Pickler(sink).dump(ids)

def load_ids(filename="ids.p"):
    """Return the ids object unpickled from ``filename``.

    NOTE: unpickling executes code embedded in the file; only load files
    produced by a trusted run of this program.
    """
    with open(filename, "rb") as source:
        return pickle.Unpickler(source).load()

def generate_identification(filename):
    """Parse a comma-separated identification file.

    Every non-blank line must hold at least three comma-separated fields:
    the first is used as the key, the second is stored as that key's
    identification and the third is converted with ``int`` and stored as its
    cluster number.  Blank lines (such as a trailing empty line) are skipped.

    Returns:
        ``(identification, cluster_number)`` -- two dicts keyed by the first
        field.

    Raises:
        IndexError: a non-blank line has fewer than three fields.
        ValueError: the third field is not an integer.
    """
    identification = {}
    cluster_number = {}
    with open(filename) as f:
        for line in f:
            if not line.strip():
                # An empty line would otherwise fail with IndexError below.
                continue
            fields = line.replace("\n", "").split(",")
            identification[fields[0]] = fields[1]
            cluster_number[fields[0]] = int(fields[2])
    return identification, cluster_number

def generate_a_edges(filename):
    """Build an undirected graph from a comma-separated edge list.

    Every non-blank line contributes one edge between its first two fields;
    any further fields are ignored.  Blank lines are skipped.

    Raises:
        IndexError: a non-blank line has fewer than two fields.
    """
    G = nx.Graph()
    with open(filename) as f:
        for line in f:
            if not line.strip():
                # An empty line would otherwise fail with IndexError below.
                continue
            fields = line.replace("\n", "").split(",")
            G.add_edge(fields[0], fields[1])
    return G

def generate_b_edges(filename):
    """Build an undirected weighted graph from a comma-separated edge list.

    Every non-blank line contributes one edge between its first two fields;
    the third field is stored, unconverted (as a string), in the edge's
    ``'weight'`` attribute.  Blank lines are skipped.

    Raises:
        IndexError: a non-blank line has fewer than three fields.
    """
    G = nx.Graph()
    with open(filename) as f:
        for line in f:
            if not line.strip():
                # An empty line would otherwise fail with IndexError below.
                continue
            fields = line.replace("\n", "").split(",")
            G.add_edge(fields[0], fields[1], weight=fields[2])
    return G

def cc_weights(cc):
    """Return the edge weights of ``cc`` that are strictly below 1, as floats.

    Each edge's ``'weight'`` attribute is converted with ``float``; the result
    keeps the order in which ``cc.edges()`` reports the edges.
    """
    as_floats = (float(cc[edge[0]][edge[1]]['weight']) for edge in cc.edges())
    return [value for value in as_floats if value < 1]

def avg_cc_weight(cc):
    """Return the mean of ``cc_weights(cc)``, or 1 when that list is empty.

    Only weights strictly below 1 take part in the mean, because that is all
    ``cc_weights`` returns.
    """
    weights = cc_weights(cc)
    if not weights:
        # No weight below 1: report the neutral value instead of dividing by 0.
        return 1
    return sum(weights) / len(weights)

def cluster_ids(cc, ids):
    """Return the set of identifications of the nodes of ``cc``.

    Args:
        cc: object exposing ``nodes()`` (a graph or a Cluster).
        ids: mapping from node to identification.

    Nodes that have no entry in ``ids`` are skipped.
    """
    return {ids[n] for n in cc.nodes() if n in ids}

def mix(cc,ids):
    """Return True when ``cc`` carries more than one identification.

    The label 'PEPTIDE' is not counted toward that number.
    """
    labels = cluster_ids(cc, ids)
    counted = len(labels) - (1 if 'PEPTIDE' in labels else 0)
    return counted > 1

def an_id(cc,ids):
    """Return False only when 'PEPTIDE' is the sole identification found in ``cc``.

    A cluster with no identifications at all yields True.
    """
    labels = cluster_ids(cc, ids)
    return 'PEPTIDE' not in labels or len(labels) > 1

def all_id(cc,ids):
    """Return True when none of the identifications found in ``cc`` is 'PEPTIDE'."""
    return 'PEPTIDE' not in cluster_ids(cc, ids)

def link_edges(a_clusters, b_clusters):
    """Pair up A and B graphs that share a key and wrap each pair in a Cluster.

    Args:
        a_clusters: dict mapping cluster key -> A graph.
        b_clusters: dict mapping cluster key -> B graph.

    Keys present in ``a_clusters`` but missing from ``b_clusters`` are
    skipped.  Errors raised while constructing a Cluster are no longer
    swallowed.

    Returns:
        list of Cluster objects, in the iteration order of ``a_clusters``.
    """
    return [
        Cluster(cluster_id, a_graph, b_clusters[cluster_id])
        for cluster_id, a_graph in a_clusters.items()
        if cluster_id in b_clusters
    ]


def seperate_mixtures(graph):
    """Collect the connected components of ``graph`` that ``mix`` flags.

    Reads the module-level ``ids`` and ``cluster_number`` mappings.  Each
    connected component is copied into its own subgraph; components for which
    ``mix(component, ids)`` is true are stored under the cluster number of
    their first node.  A flagged component whose first node has no entry in
    ``cluster_number`` is dropped.  All other components (a single
    identification, only 'PEPTIDE', 'PEPTIDE' plus one other identification,
    or no identification at all) are ignored.

    Returns:
        dict mapping cluster number -> component subgraph.
    """
    mix_graphs = {}
    # connected_components + subgraph().copy() is the portable spelling of the
    # connected_component_subgraphs helper that newer NetworkX dropped.
    for component in nx.connected_components(graph):
        cluster = graph.subgraph(component).copy()
        if not mix(cluster, ids):
            continue
        # nodes() is not indexable on every NetworkX version; take the first
        # node through iteration instead.
        first_node = next(iter(cluster.nodes()))
        if first_node in cluster_number:
            mix_graphs[cluster_number[first_node]] = cluster
    return mix_graphs

# Build the mixture clusters from the three files named on the command line
# (A edges, B edges, identifications) and cache the result as pickles.  When
# that fails for any reason -- e.g. the arguments are missing or are not
# readable input files -- fall back to the previously cached pickles.
clusters = []
try:
    a_graph = generate_a_edges(sys.argv[1])
    b_graph = generate_b_edges(sys.argv[2])
    ids, cluster_number = generate_identification(sys.argv[3])
    mixture_a = seperate_mixtures(a_graph)
    mixture_b = seperate_mixtures(b_graph)
    clusters = link_edges(mixture_a,mixture_b)
    pickle_clusters(clusters)
    pickle_ids(ids)
except Exception as err:
    # Deliberate best-effort fallback, but report it: a silent fallback hides
    # malformed input behind stale cached data.  Unlike a bare ``except``,
    # this lets KeyboardInterrupt / SystemExit propagate.
    sys.stderr.write(
        "Could not build clusters from the input files (%r); "
        "loading cached pickles instead.\n" % (err,)
    )
    clusters = load_clusters()
    ids = load_ids()

# for n in a_graph.nodes():
#     try:
#         m += 1
#         print(ids[n])
#     except:
#         s += 1


# Scratch setup: take the second cluster and allocate a zero matrix with one
# row per B edge and one column per A edge.
c = clusters[1]
# The builtin ``int`` replaces the ``np.int`` alias, which NumPy has removed.
E = np.zeros((len(c.b_edges()),len(c.a_edges())),dtype=int)
# print(E)
# print(c.nodes())
# # print(c.lowest_weight())
# # print(c.id_as())
# # for cluster in c.split_at(c.a_edges()[0][0], c.a_edges()[0][1]):
# #     print(cluster.id_as())
# #     # print(cluster.nodes())
# #     # print(cluster.b_edges())
# #     # print(cluster.weight())
# #     print(cluster.lowest_weight())


# print(ids)


In [460]:
def _edge_position(index, edge):
    """Return ``index[edge]``, trying the reversed tuple when ``edge`` is absent."""
    try:
        return index[edge]
    except KeyError:
        return index[(edge[1], edge[0])]


def dfs(clust, node, E, in_edge):
    """Depth-first walk over the A graph that fills the columns of ``E``.

    Args:
        clust: Cluster whose ``a_graph`` is walked and whose ``a_index`` /
            ``b_index`` give the column / row of each edge in ``E``.
        node: node to visit.
        E: 2-D integer array (rows: B edges, columns: A edges).  Modified in
            place and also returned.
        in_edge: the A edge ``(parent, node)`` used to reach ``node``, or a
            falsy value for the starting node.

    For the visited node a 0/1 vector marks the B edges incident to it.  The
    column of every A edge followed to a not-yet-visited neighbour is merged
    into that vector with ``combine``, and the result is stored in the column
    of ``in_edge``.  Columns of A edges that are never followed are left
    untouched.

    The caller must set the ``'visited'`` attribute of every A-graph node to
    False beforehand; visited nodes are marked True.  The walk is recursive,
    so very deep graphs can hit Python's recursion limit.
    """
    graph = clust.a_graph
    # ``Graph.node`` only exists on older NetworkX; ``Graph.nodes[...]`` is the
    # replacement on newer releases.
    node_attrs = graph.node if hasattr(graph, "node") else graph.nodes
    # The builtin ``int`` replaces the ``np.int`` alias, which NumPy has removed.
    E_v = np.zeros((len(clust.b_edges())), dtype=int)
    a_idx = None
    if in_edge:
        a_idx = _edge_position(clust.a_index, in_edge)
    node_attrs[node]['visited'] = True
    for b_edge in clust.b_edges(node):
        E_v[_edge_position(clust.b_index, b_edge)] = 1
    for n in graph.neighbors(node):
        if not node_attrs[n]['visited']:
            edge_idx = _edge_position(clust.a_index, (node, n))
            E = dfs(clust, n, E, (node, n))
            E_v = combine(E_v, E[:, edge_idx].T)
    if in_edge:
        E[:, a_idx] = E_v.T
    return E

In [400]:
def xor(x,y):
    """Return 0 when x and y are both 1 or both 0, and 1 in every other case."""
    both_one = x == 1 and y == 1
    both_zero = x == 0 and y == 0
    return 0 if (both_one or both_zero) else 1

In [401]:
def combine(E_x, E_y):
    """Element-wise exclusive-or of two equally long 0/1 vectors.

    Position i of the result is 0 when ``E_x[i]`` and ``E_y[i]`` are both 1
    or both 0, and 1 otherwise (so any value other than 0/1 yields 1, the
    same rule as the scalar ``xor`` helper).

    Returns:
        a new 1-D integer NumPy array; the inputs are not modified.
    """
    x = np.asarray(E_x)
    y = np.asarray(E_y)
    agree = ((x == 1) & (y == 1)) | ((x == 0) & (y == 0))
    # The builtin ``int`` replaces the ``np.int`` alias, which NumPy has removed.
    return np.where(agree, 0, 1).astype(int)

In [468]:
def score_E(clust, cc, E):
    """Return the A edge of ``cc`` with the lowest average flagged B weight.

    Args:
        clust: object providing ``a_index`` / ``b_index`` (edge -> column /
            row of ``E``) and ``b_graph`` (source of the ``'weight'`` values).
        cc: object providing ``a_edges()`` and ``b_edges()``; its edges must
            be present in the indices of ``clust`` in either orientation.
        E: 2-D array; ``E[b, a] == 1`` flags B edge ``b`` for A edge ``a``.

    For every A edge of ``cc`` the weights (converted with ``float``) of the
    B edges of ``cc`` flagged in that edge's column are averaged; an A edge
    with no flagged B edge scores 0.  The first A edge with the smallest
    score is returned, exactly as ``cc.a_edges()`` reported it.

    Raises:
        IndexError: ``cc`` has no A edges.
        KeyError: an edge of ``cc`` is unknown to ``clust``.
    """
    def position_of(index, edge):
        # Indices store one orientation per edge; try the reverse as well.
        try:
            return index[edge]
        except KeyError:
            return index[(edge[1], edge[0])]

    def weight_of(edge):
        try:
            return float(clust.b_graph[edge[0]][edge[1]]['weight'])
        except KeyError:
            return float(clust.b_graph[edge[1]][edge[0]]['weight'])

    # Row lookups do not depend on the A edge, so resolve them only once.
    b_rows = [(position_of(clust.b_index, b_edge), b_edge)
              for b_edge in cc.b_edges()]

    scored = []
    for a_edge in cc.a_edges():
        a_idx = position_of(clust.a_index, a_edge)
        weights = [weight_of(b_edge) for b_idx, b_edge in b_rows
                   if E[(b_idx, a_idx)] == 1]
        average = sum(weights) / len(weights) if weights else 0
        scored.append((a_edge, average))
    if not scored:
        raise IndexError("cluster has no A edges to score")
    return min(scored, key=lambda pair: pair[1])[0]

In [462]:
def find_min_cut(cluster, E, theta):
    """Split ``cluster`` repeatedly until every piece reaches ``theta``.

    Args:
        cluster: the Cluster to split; its indices are used to read ``E``
            for every piece produced along the way.
        E: matrix passed through to ``score_E``.
        theta: pieces whose ``weight()`` is below this value are split
            further.

    A piece below the threshold is cut at the A edge chosen by ``score_E``
    (via ``split_at``, which modifies the piece's A graph in place).  A piece
    that cannot be split any further -- it has no A edges, or ``split_at``
    returns nothing -- is kept as it is instead of raising.

    Returns:
        list of the resulting Cluster pieces.
    """
    pending, finished = cc_below_threshold([cluster], theta)
    while pending:
        clust = pending.pop()
        if len(clust.a_edges()) == 0:
            # Nothing left to cut; score_E would fail on an empty edge list.
            finished.append(clust)
            continue
        edge = score_E(cluster, clust, E)
        pieces = clust.split_at(edge[0], edge[1], 0)
        if not pieces:
            finished.append(clust)
            continue
        # split_at may return a single piece when removing the edge does not
        # disconnect the A graph, so never assume exactly two results.
        below, above = cc_below_threshold(pieces, theta)
        pending.extend(below)
        finished.extend(above)
    return finished
    

In [440]:
def cc_below_threshold(ccs,theta):
    """Partition ``ccs`` by comparing each item's ``weight()`` with ``theta``.

    Returns ``(below, at_or_above)``: two lists that keep the input order,
    the first holding the items whose weight is strictly less than ``theta``.
    """
    below, at_or_above = [], []
    for candidate in ccs:
        bucket = below if candidate.weight() < theta else at_or_above
        bucket.append(candidate)
    return below, at_or_above

In [None]:
# Split every cluster with find_min_cut (threshold .8) and count for how many
# of them effectiveness() holds on the resulting pieces.
correct = 0
total = 0
for c in clusters:
    if total%250 == 0:
        print(total)  # progress indicator
    # The builtin ``int`` replaces the ``np.int`` alias, which NumPy has removed.
    E = np.zeros((len(c.b_index),len(c.a_index)),dtype=int)
    # ``Graph.node`` only exists on older NetworkX; ``Graph.nodes[...]`` is the
    # replacement on newer releases.
    node_attrs = c.a_graph.node if hasattr(c.a_graph, "node") else c.a_graph.nodes
    for a in c.a_graph:
        node_attrs[a]["visited"] = False
    # nodes() is not indexable on every NetworkX version; take the first node
    # through iteration instead.
    E = dfs(c,next(iter(c.nodes())),E,None)
    new_c = find_min_cut(c,E,.8)
    total += 1
    if effectiveness(new_c):
        correct += 1
correct/total

0
250
500
750
1000
1250
1500
1750
2000
2250
2500
2750
3000
3250
3500
3750
4000
4250
4500
4750
5000
5250
5500
5750
6000

In [457]:
# Scratch cell: look at the first node of the cluster currently bound to `c`.
c.nodes()[0]

'00522_D02_P003811_B0L_A00_R1.mzXML:8141'

In [453]:
# Scratch cell: concatenating an empty list leaves the contents unchanged.
[1,2] + []

[1, 2]

In [473]:
def effectiveness(clusters):
    """Return True when ``mix`` is false for every cluster in ``clusters``.

    Uses the module-level ``ids`` mapping and stops at the first cluster for
    which ``mix`` is true.  An empty input yields True.
    """
    return not any(mix(candidate, ids) for candidate in clusters)

In [481]:
# Scratch cell: show the number of clusters counted so far by the evaluation loop.
total

313

In [None]:
def spectra_recovered(clusters):
    """Return True when ``mix`` is false for every cluster in ``clusters``.

    Uses the module-level ``ids`` mapping.  An empty input yields True.

    NOTE(review): this currently performs the same check as
    ``effectiveness``.  The name and the removed, never-used
    ``total_spectra`` counter suggest it was meant to count spectra --
    TODO confirm the intended behaviour.
    """
    return not any(mix(cluster, ids) for cluster in clusters)

In [None]:
# Fraction of processed clusters for which effectiveness() held.
correct/total