In [None]:
%load_ext autoreload
%autoreload

In [None]:
# Change the working directory to the path stored in the PROJECT_ROOT
# environment variable. A missing variable raises KeyError here.
import os as _os
_os.chdir(_os.environ['PROJECT_ROOT'])
# Echo the resolved working directory as a sanity check.
print(_os.path.realpath(_os.path.curdir))

In [None]:
import graph_tool as gt
import graph_tool.draw
import numpy as np
import pandas as pd
import scipy.sparse
import scipy as sp
from collections import defaultdict

In [None]:
# Functions for constructing graphs
def path_to_edgelist(path):
    """Convert a vertex path into its list of consecutive directed edges.

    Parameters
    ----------
    path : sequence
        Ordered vertices along the path. Must support slicing.

    Returns
    -------
    list of tuple
        One ``(u, v)`` pair per consecutive pair of vertices. A path with
        fewer than two vertices (including an empty one) yields ``[]``.
    """
    # Pair each vertex with its successor; zip stops at the shorter sequence,
    # so empty and single-vertex paths naturally produce no edges.
    return list(zip(path, path[1:]))

def new_graph_from_merged_paths(paths, lengths, depths):
    """Build a directed graph from vertex paths and attach per-vertex properties.

    Parameters
    ----------
    paths : iterable of sequences
        Each path contributes the edges between its consecutive vertices.
    lengths, depths
        Values handed to ``g.new_vp`` as its second positional argument to
        initialise the ``'length'`` (int) and ``'depth'`` (float) vertex
        properties. The notebook passes both scalars and arrays here.

    Returns
    -------
    The newly created graph with ``vp['depth']`` and ``vp['length']`` set.
    """
    graph = gt.Graph()
    for vertex_path in paths:
        edge_pairs = path_to_edgelist(vertex_path)
        graph.add_edge_list(edge_pairs)
    depth_prop = graph.new_vp('float', depths)
    graph.vp['depth'] = depth_prop
    length_prop = graph.new_vp('int', lengths)
    graph.vp['length'] = length_prop
    return graph

In [None]:
def draw_graph(g, **kwargs):
    """Draw ``g`` with the notebook's default size and ink scale.

    Parameters
    ----------
    g
        Graph to draw.
    **kwargs
        Forwarded to ``gt.draw.graph_draw``. ``output_size`` and ``ink_scale``
        default to ``(300, 300)`` and ``0.8`` but may be overridden by the
        caller.

    Returns
    -------
    Whatever ``gt.draw.graph_draw`` returns; the notebook reuses it as the
    ``pos`` argument of later calls.
    """
    # Merge caller options over the defaults so that passing output_size or
    # ink_scale overrides them instead of raising a duplicate-keyword TypeError.
    options = {'output_size': (300, 300), 'ink_scale': 0.8}
    options.update(kwargs)
    return gt.draw.graph_draw(g, **options)

In [None]:
# Seed NumPy's global RNG so the random vertex depths (and therefore every
# depth-labelled drawing below) are identical on each Restart & Run All.
np.random.seed(0)

# Toy graph over vertices 0..12:
paths = [
    [0, 1, 2, 0],            # 3-cycle
    [3, 4, 5, 3],            # second 3-cycle
    [0, 6, 3],               # bridge from the first cycle to the second
    [5, 7, 8, 9, 10, 11],    # linear tail leaving the second cycle
    [9, 12],                 # 2-cycle between 9 and 12 ...
    [12, 9],                 # ... closed here
]

# One random integer depth in [0, 12) per vertex (13 vertices); every length is 1.
g0 = new_graph_from_merged_paths(paths, depths=np.random.randint(0, 12, size=13), lengths=1)

In [None]:
g0_pos = draw_graph(g0, vertex_text=g0.vertex_index)

In [None]:
draw_graph(g0, pos=g0_pos, vertex_text=g0.vp.length)

In [None]:
draw_graph(g0, pos=g0_pos, vertex_text=g0.vp.depth)

In [None]:
def edge_has_no_siblings(g):
    """Flag every edge that has no sibling edge at either endpoint.

    An edge is flagged True when its target has in-degree <= 1 and its
    source has out-degree <= 1.

    Returns
    -------
    A boolean edge property map on ``g``.
    """
    vertices = g.get_vertices()
    in_deg = g.new_vertex_property('int', vals=g.get_in_degrees(vertices))
    out_deg = g.new_vertex_property('int', vals=g.get_out_degrees(vertices))
    # Project the endpoint degrees onto the edges themselves.
    target_in_deg = gt.edge_endpoint_property(g, in_deg, 'target')
    source_out_deg = gt.edge_endpoint_property(g, out_deg, 'source')
    sole_in_edge = target_in_deg.a <= 1
    sole_out_edge = source_out_deg.a <= 1
    return g.new_edge_property('bool', sole_in_edge & sole_out_edge)

def vertex_does_not_have_both_multiple_in_and_multiple_out(g):
    """Flag vertices that are unbranched on at least one side.

    Returns
    -------
    A boolean vertex property map on ``g`` that is True where the vertex has
    in-degree <= 1 or out-degree <= 1, and False only where both degrees
    exceed 1.
    """
    vertices = g.get_vertices()
    at_most_one_in = g.get_in_degrees(vertices) <= 1
    at_most_one_out = g.get_out_degrees(vertices) <= 1
    return g.new_vertex_property('bool', vals=(at_most_one_in | at_most_one_out))

def label_maximal_unitigs(g):
    """Label each vertex with the index of the maximal unitig it belongs to.

    Returns
    -------
    (labels, counts, g_filt)
        ``labels`` is an int vertex property on ``g``, ``counts`` is the second
        value returned by ``gt.topology.label_components`` (used by callers as
        per-label sizes), and ``g_filt`` is the undirected, filtered view the
        components were computed on.
    """
    unbranched_edge = edge_has_no_siblings(g)
    # A vertex with both multiple in-edges _and_ multiple out-edges cannot be
    # part of a larger maximal unitig, so such vertices are filtered out
    # together with the edges that have siblings. Potentially this also makes
    # the component labeling step much faster.
    # Note the mask is True for the vertices we KEEP.
    keep_vertex = vertex_does_not_have_both_multiple_in_and_multiple_out(g)
    # TODO: Double check whether this has any implications for the
    # "unitig-ness" of the neighbors. I _think_ that marking edges with
    # siblings before filtering out these nodes is sufficient.
    pruned = gt.GraphView(
        g,
        efilt=unbranched_edge,
        vfilt=keep_vertex,
        directed=False,
    )
    # Vertices filtered out above would otherwise carry the default value (0)
    # in the labels map. Start everything at -1 instead: a magic value for
    # nodes that are definitely not in a maximal unitig.
    unitig_label = g.new_vertex_property('int', val=-1)
    unitig_label, unitig_sizes = gt.topology.label_components(pruned, vprop=unitig_label)
    return unitig_label, unitig_sizes, pruned

In [None]:
edge_has_no_siblings(g0).a

In [None]:
labels, sizes, _g = label_maximal_unitigs(g0)
labels.a

In [None]:
draw_graph(_g, pos=g0_pos, vertex_text=g0.vertex_index)

In [None]:
draw_graph(g0, vertex_text=label_maximal_unitigs(g0)[0], pos=g0_pos)

In [None]:
from functools import reduce
import operator

def maximal_unitigs(g):
    """Lazily yield ``(vertices, is_cycle)`` for every maximal unitig of ``g``.

    ``vertices`` is the list produced by ``iter_vertices()`` on the subgraph
    induced by one unitig label.
    NOTE(review): presumably that is vertex-index order rather than path
    order; confirm before relying on the ordering.
    ``is_cycle`` is True when that induced subgraph is not a DAG.
    """
    labels, counts, _ = label_maximal_unitigs(g)
    # *labels* maps vertices to unitig ids and *counts* maps those ids to
    # their sizes, so walking *counts* visits every maximal unitig once.
    # NOTE: label_maximal_unitigs also returns the filtered, undirected view
    # (sibling edges and multi-in-and-multi-out vertices removed), but the
    # _original_ graph must be used to build the unitig subgraphs below.
    for unitig_id, expected_size in enumerate(counts):
        member_mask = (labels.a == unitig_id)
        assert member_mask.sum() == expected_size
        induced = gt.GraphView(g, vfilt=member_mask)
        members = [vertex for vertex in induced.iter_vertices()]
        has_cycle = not gt.topology.is_DAG(induced)
        yield members, has_cycle

In [None]:
unitigs, cycles = list(zip(*maximal_unitigs(g0)))
unitigs, cycles

In [None]:
def list_unitig_neighbors(g, vs):
    """Return the external in- and out-neighbors of a unitig.

    Parameters
    ----------
    g
        Graph exposing ``iter_in_neighbors(v)`` and ``iter_out_neighbors(v)``.
    vs : iterable of vertex indices
        The vertices making up the unitig. May be empty.

    Returns
    -------
    (list, list)
        Unique in-neighbors and unique out-neighbors of any vertex in ``vs``,
        excluding the members of ``vs`` themselves. Order is unspecified.
    """
    members = set(vs)
    # Collect straight into sets: this avoids quadratic list concatenation
    # and, unlike reduce() without an initial value, works for empty ``vs``.
    all_ins = {u for v in vs for u in g.iter_in_neighbors(v)}
    all_outs = {w for v in vs for w in g.iter_out_neighbors(v)}
    return list(all_ins - members), list(all_outs - members)

In [None]:
list_unitig_neighbors(g0, [3, 4, 5])

In [None]:
def mutate_add_compressed_unitig_vertex(g, vs, is_cycle, drop_vs=False):
    """Add one vertex that stands in for the unitig ``vs`` and detach ``vs``.

    The new vertex is wired to the unitig's external in- and out-neighbors.
    Its length is the sum of the member lengths and its depth is the
    length-weighted mean of the member depths. If ``is_cycle`` is true it
    also gets a self-loop. All edges of the original members are then
    cleared; the members themselves are not removed here.

    ``drop_vs`` is accepted but currently unused.

    Returns the mutated graph ``g``.
    """
    merged = int(g.add_vertex())
    upstream, downstream = list_unitig_neighbors(g, vs)
    g.add_edge_list((src, merged) for src in upstream)
    g.add_edge_list((merged, dst) for dst in downstream)

    length_arr = g.vp.length.a
    depth_arr = g.vp.depth.a
    length_arr[merged] = length_arr[vs].sum()
    weighted_depth = (depth_arr[vs] * length_arr[vs]).sum()
    depth_arr[merged] = weighted_depth / length_arr[merged]

    if is_cycle:
        g.add_edge(merged, merged)
    for member in vs:
        g.clear_vertex(member)
    return g

In [None]:
def mutate_compress_all_unitigs(g):
    """Replace every maximal unitig in ``g`` with a single compressed vertex.

    Mutates ``g`` in place and returns it. Each unitig gets a new vertex (see
    ``mutate_add_compressed_unitig_vertex``); the original member vertices
    are removed in one batch at the end.
    """
    # maximal_unitigs is a lazy generator that reads from ``g``. Materialize it
    # before mutating so every unitig is computed on the unmodified graph
    # rather than on a graph that changes between iterations.
    unitig_list = list(maximal_unitigs(g))
    all_vs = []
    for vs, is_cycle in unitig_list:
        # TODO: If len(vs) == 1, this is effectively a no-op and can be dropped.
        mutate_add_compressed_unitig_vertex(g, vs, is_cycle)
        all_vs.extend(vs)

    # Nothing to remove when the graph had no unitigs at all.
    if all_vs:
        g.remove_vertex(set(all_vs), fast=True)
    # I think, but am not sure, that the number of nodes removed will always equal the number of edges removed.
    return g

In [None]:
g1 = mutate_compress_all_unitigs(g0.copy())

In [None]:
g1_pos = draw_graph(g1, vertex_text=g1.vertex_index)

In [None]:
draw_graph(g1, pos=g1_pos, vertex_text=g1.vp.length)
draw_graph(g0, pos=g0_pos, vertex_text=g0.vp.length)

In [None]:
draw_graph(g1, pos=g1_pos, vertex_text=g1.vp.depth)
draw_graph(g0, pos=g0_pos, vertex_text=g0.vp.depth)

In [None]:
from itertools import product
from collections import namedtuple

Split = namedtuple('Split', ['u', 'v', 'w', 'l', 'd'])

def all_local_paths_as_splits(g, v):
    """Yield one ``Split`` per (in-neighbor, out-neighbor) pair of vertex ``v``.

    Each split carries the length of ``v`` and an equal share of its depth
    (depth divided by the number of pairs). Nothing is yielded when ``v`` has
    no in-neighbors or no out-neighbors.
    """
    assert v < g.num_vertices(ignore_filter=True)
    predecessors = list(g.iter_in_neighbors(v))
    successors = list(g.iter_out_neighbors(v))
    n_local_paths = len(predecessors) * len(successors)
    lengths = g.vp.length.a
    depths = g.vp.depth.a
    for pred in predecessors:
        for succ in successors:
            # NOTE: This splitting function evenly distributes across all paths.
            yield Split(pred, v, succ, lengths[v], depths[v] / n_local_paths)

def build_tables_from_splits(split_list, start_idx):
    """Index a collection of splits and tabulate their lengths and depths.

    Each split is a 5-tuple ``(u, v, w, l, d)``. Splits are numbered
    consecutively from ``start_idx`` in iteration order.

    Returns
    -------
    split_idx : dict
        Maps each split to its assigned index.
    upstream : defaultdict(list)
        Splits grouped under the key ``(v, w)``.
    downstream : defaultdict(list)
        Splits grouped under the key ``(u, v)``.
    length, depth : numpy.ndarray
        The ``l`` and ``d`` fields of the splits, in iteration order.
    """
    index_of = {}
    by_vw = defaultdict(list)
    by_uv = defaultdict(list)
    seg_lengths = []
    seg_depths = []
    next_index = start_idx
    for record in split_list:
        src, mid, dst, seg_len, seg_depth = record
        index_of[record] = next_index
        next_index += 1
        by_vw[(mid, dst)].append(record)
        by_uv[(src, mid)].append(record)
        seg_depths.append(seg_depth)
        seg_lengths.append(seg_len)
    return index_of, by_vw, by_uv, np.array(seg_lengths), np.array(seg_depths)
        
        
def new_edges_from_splits(split_list, split_idx, upstream, downstream, start_idx):
    """Generate the ``(source, target)`` edges that attach new split vertices.

    For every split ``(u, v, w, ...)`` with new index ``s`` this yields:

    * ``(u, s)`` and ``(s, w)``, linking the split to the original neighbors;
    * ``(t, s)`` for each split ``t`` registered in ``upstream[(u, v)]``;
    * ``(s, t)`` for each split ``t`` registered in ``downstream[(v, w)]``.

    If no split is registered under a key, no split-to-split edge is
    produced for it. Edges may be yielded more than once.

    Parameters
    ----------
    split_list : iterable of 5-tuples ``(u, v, w, l, d)``
    split_idx : mapping from split to its new vertex index
    upstream, downstream : mappings from an edge key to a list of splits
    start_idx
        Unused; kept for backward compatibility with existing callers.
    """
    for split in split_list:
        u_old, v_old, w_old, _, _ = split
        v = split_idx[split]

        # Upstream edges. Always a 2-tuple: a stray third element would make
        # the edge list ragged and defeat de-duplication against (u_old, v).
        yield (u_old, v)
        # .get() rather than [] so defaultdict tables are not grown by lookups.
        for upstream_split in upstream.get((u_old, v_old), ()):
            yield (split_idx[upstream_split], v)

        # Downstream edges
        yield (v, w_old)
        for downstream_split in downstream.get((v_old, w_old), ()):
            yield (v, split_idx[downstream_split])

In [None]:
# Step-by-step demonstration of splitting a single vertex by hand.
# A linear path 3 -> 4 -> 5 -> 0 -> 1 -> 2 -> 6 -> 7 -> 8 plus a self-loop on vertex 1.
paths = [
    [3, 4, 5, 0, 1, 2, 6, 7, 8],
    [1, 1],
]

# The vertex to be split.
v_split = 1

_g = new_graph_from_merged_paths(paths, lengths=1, depths=1)
draw_graph(_g, vertex_text=_g.vertex_index)
# Enumerate every (in-neighbor, out-neighbor) pair through v_split.
split_list = list(all_local_paths_as_splits(_g, v_split))
# New split vertices are numbered starting after the existing vertices.
start_idx = _g.num_vertices(ignore_filter=True)
split_idx, upstream, downstream, lengths, depths = build_tables_from_splits(split_list, start_idx=start_idx)
# De-duplicate the generated edges before adding them.
edges_to_add = list(set(new_edges_from_splits(split_list, split_idx, upstream, downstream, start_idx)))
# Boolean marker property, used below to colour the vertex about to be removed.
_g.vertex_properties['_to_drop'] = _g.new_vertex_property('bool', False)
_g.add_edge_list(edges_to_add)
# Copy each split's length and depth onto its new vertex.
_g.vp.length.a[np.arange(len(lengths)) + start_idx] = lengths
_g.vp.depth.a[np.arange(len(depths)) + start_idx] = depths

for k in [v_split]:
    _g.vp._to_drop[k] = True
# Draw the intermediate state with the original vertex still present, keeping
# the returned value so the next drawing can reuse it as `pos`.
_g_pos = draw_graph(_g, vertex_text=_g.vertex_index, vertex_color=_g.vp._to_drop)
# Remove the original vertex now that its splits are in place.
_g.remove_vertex([v_split])
draw_graph(_g, pos=_g_pos, vertex_text=_g.vertex_index)

In [None]:
draw_graph(_g, pos=_g_pos, vertex_text=_g.vp.length)
draw_graph(_g, pos=_g_pos, vertex_text=_g.vp.depth)

In [None]:
def mutate_apply_splits(g, split_list):
    """Add split vertices and their edges, then drop the parents that were split.

    Parameters
    ----------
    g
        Graph to mutate in place. Must carry ``length`` and ``depth`` vertex
        properties.
    split_list : iterable of Split
        The splits to apply. Any iterable is accepted, including a generator.
        An empty iterable leaves ``g`` untouched.

    Returns
    -------
    The mutated graph ``g``.
    """
    # Materialize once: the splits are traversed several times below, which
    # would silently exhaust a generator after the first pass.
    split_list = list(split_list)
    if not split_list:
        return g

    start_idx = g.num_vertices(ignore_filter=True)
    split_idx, upstream, downstream, lengths, depths = (
        build_tables_from_splits(split_list, start_idx=start_idx)
    )
    # The edge generator may emit the same edge more than once; add each once.
    edges_to_add = set(new_edges_from_splits(
        split_list, split_idx, upstream, downstream, start_idx
    ))
    g.add_edge_list(edges_to_add)

    # Splits were numbered consecutively from start_idx, in list order.
    new_vertices = np.arange(len(lengths)) + start_idx
    g.vp.length.a[new_vertices] = lengths
    g.vp.depth.a[new_vertices] = depths

    # Several splits share the same parent vertex. Pass each parent exactly
    # once, as mutate_compress_all_unitigs does, so a fast removal is never
    # requested twice for the same index.
    g.remove_vertex({split.v for split in split_list}, fast=True)
    return g

In [None]:
# Same toy graph as the manual demonstration: a linear path plus a self-loop on vertex 1.
paths = [
    [3, 4, 5, 0, 1, 2, 6, 7, 8],
    [1, 1],
]

_g = new_graph_from_merged_paths(paths, lengths=1, depths=1)
# Draw without binding the result to g0_pos: that name holds the layout of g0,
# and overwriting it with a layout of _g would break re-running the g0 cells.
draw_graph(_g, vertex_text=_g.vertex_index)

# Split vertex 1 along every (in-neighbor, out-neighbor) pair and apply.
split_list = list(all_local_paths_as_splits(_g, 1))
print(split_list)
mutate_apply_splits(_g, split_list)
draw_graph(_g, vertex_text=_g.vp.depth)

In [None]:
draw_graph(g1, vertex_text=g1.vertex_index)

# Split vertex 0, but drop one of the potential splits
# (the one reflecting a linear path with no repeats.)
# Splits are 5-field tuples (u, v, w, l, d), so they have to be matched on
# their (u, v, w) fields; subtracting the bare 3-tuple (1, 0, 3) from the set
# never removes anything.
split_list = {
    split for split in all_local_paths_as_splits(g1, 0)
    if (split.u, split.v, split.w) != (1, 0, 3)
}
print(split_list)
g2 = mutate_apply_splits(g1.copy(), split_list=split_list)
draw_graph(g2, vertex_text=g2.vp.depth)

In [None]:
# Scratch experiment: repeatedly split "vertex 2" of a 3-cycle with a self-loop
# on vertex 2, then compress unitigs.
# NOTE(review): mutate_apply_splits removes vertices with fast=True, which
# presumably reassigns vertex indices; confirm that index 2 is the intended
# vertex on each round.
paths = [
    [0, 1, 2, 0],
    [2, 2]
]
_g = new_graph_from_merged_paths(paths, lengths=1, depths=1)
split_list = set(all_local_paths_as_splits(_g, 2))
_g = mutate_apply_splits(_g, split_list=split_list)
split_list = set(all_local_paths_as_splits(_g, 2))
_g = mutate_apply_splits(_g, split_list=split_list)

draw_graph(_g, vertex_text=_g.vertex_index)

split_list = set(all_local_paths_as_splits(_g, 2))
_g = mutate_apply_splits(_g, split_list=split_list)
# # draw_graph(_g, vertex_text=_g.vertex_index)
# split_list = set(all_local_paths_as_splits(_g, 8))
# _g = mutate_apply_splits(_g, split_list=split_list)
# # draw_graph(_g, vertex_text=_g.vertex_index)
# split_list = set(all_local_paths_as_splits(_g, 11))
# _g = mutate_apply_splits(_g, split_list=split_list)
# # draw_graph(_g, vertex_text=_g.vertex_index)

# Collapse whatever unitigs the splitting produced.
_g = mutate_compress_all_unitigs(_g)
# # draw_graph(_g, vertex_text=_g.vertex_index)

# split_list = set(all_local_paths_as_splits(_g, 0))
# _g = mutate_apply_splits(_g, split_list=split_list)
# # draw_graph(_g, vertex_text=_g.vertex_index)

draw_graph(_g, vertex_text=_g.vertex_index)


In [None]:
def splits_for_all_vertices(g, split_func):
    """Yield every split that ``split_func(g, v)`` produces, over all vertices of ``g``.

    Vertices are visited in the order given by ``g.iter_vertices()``.
    """
    for vertex in g.iter_vertices():
        for split in split_func(g, vertex):
            yield split

In [None]:
paths = [
    [0, 1, 2, 0],
    [2, 2]
]
_g = new_graph_from_merged_paths(paths, lengths=1, depths=1)
draw_graph(_g, vertex_text=_g.vertex_index)
split_list = set(splits_for_all_vertices(_g, all_local_paths_as_splits))
_g = mutate_apply_splits(_g, split_list=split_list)
draw_graph(_g, vertex_text=_g.vertex_index)

In [None]:
def mutate_split_all_nodes(g, split_func):
    """Split every vertex of ``g`` using ``split_func`` and apply the result in place.

    The splits from all vertices are gathered into a set before being
    applied. Returns the mutated graph.
    """
    # TODO: Vertices with <= 1 local path are split into just themselves,
    # pointing trivially at their neighbors. These are effectively a no-op
    # and could be dropped.
    unique_splits = set(splits_for_all_vertices(g, split_func))
    return mutate_apply_splits(g, split_list=unique_splits)

In [None]:
paths = [
    [0, 1, 2, 0],
    [0, 0],
    [1, 1],
    [2, 2],
]
_g = new_graph_from_merged_paths(paths, lengths=1, depths=1)
draw_graph(_g, vertex_text=_g.vertex_index)
mutate_split_all_nodes(_g, split_func=all_local_paths_as_splits)
draw_graph(_g, vertex_text=_g.vertex_index)
mutate_split_all_nodes(_g, split_func=all_local_paths_as_splits)
mutate_split_all_nodes(_g, split_func=all_local_paths_as_splits)
mutate_split_all_nodes(_g, split_func=all_local_paths_as_splits)
mutate_split_all_nodes(_g, split_func=all_local_paths_as_splits)
mutate_split_all_nodes(_g, split_func=all_local_paths_as_splits)
mutate_split_all_nodes(_g, split_func=all_local_paths_as_splits)
%prun mutate_split_all_nodes(_g, split_func=all_local_paths_as_splits)

In [None]:
%prun mutate_compress_all_unitigs(_g)

In [None]:
vs = list(range(1_000))

paths = [
    vs,
    list(np.random.choice(vs, 200)),
]
_g = new_graph_from_merged_paths(paths, lengths=1, depths=1)
draw_graph(_g, vertex_text=_g.vertex_index)
%prun mutate_compress_all_unitigs(_g)

In [None]:
vs = list(range(10_000))

paths = (
    [
        vs,  # A long genome
        list(np.random.choice(vs, 500)), # Long-range interconnects
        list(np.random.choice(vs, 500)),
        list(np.random.choice(vs, 500)),
    ]
    + [[c, c] for c in np.random.choice(vs, 500)] # Self-loops
)
_g = new_graph_from_merged_paths(paths, lengths=1, depths=1)
print(_g)
%prun mutate_compress_all_unitigs(_g)
print(_g)

In [None]:
%prun mutate_split_all_nodes(_g, split_func=all_local_paths_as_splits)

In [None]:
import matplotlib.pyplot as plt

plt.hist(_g.vp.length.a)
plt.yscale('log')