In [None]:
%load_ext autoreload
%autoreload

In [None]:
# Run the notebook from the project root. PROJECT_ROOT must be set in the
# environment before the kernel starts (a KeyError is raised otherwise).
import os as _os
_os.chdir(_os.environ['PROJECT_ROOT'])
# Echo the resolved working directory as a sanity check.
print(_os.path.realpath(_os.path.curdir))

In [None]:
from collections import defaultdict

import graph_tool as gt
import graph_tool.draw
import graph_tool.spectral
import graph_tool.topology
import numpy as np
import pandas as pd
import scipy as sp
import scipy.sparse
from sklearn.decomposition import non_negative_factorization
from tqdm import tqdm

In [None]:
# Functions for constructing graphs
def path_to_edgelist(path):
    """Turn a vertex path into its list of consecutive (source, target) pairs.

    A path with a single vertex yields an empty list; an empty path raises
    IndexError because the first vertex is read unconditionally.
    """
    first, rest = path[0], path[1:]
    # Every edge starts at the vertex that precedes its target in the path.
    sources = [first, *rest[:-1]]
    return list(zip(sources, rest))

def new_graph_from_merged_paths(paths, lengths, depths):
    """Build a directed graph from the union of the edges of several paths.

    Parameters
    ----------
    paths : iterable of vertex-index sequences; each is expanded with
        `path_to_edgelist` and duplicate edges are merged.
    lengths : per-vertex values stored in the int vertex property `length`.
    depths : 2-D array passed to `set_2d_array`; `len(depths)` is recorded as
        the graph property `nsample`.

    Returns
    -------
    The new graph, carrying the vertex properties `depth`, `length` and
    `sequence` (each vertex starts as the one-element list `[index]`) and the
    graph property `nsample`.
    """
    edge_pairs = [edge for path in paths for edge in path_to_edgelist(path)]

    graph = gt.Graph()
    graph.add_edge_list(set(edge_pairs))

    graph.vp['depth'] = graph.new_vp('vector<float>')
    graph.vp['depth'].set_2d_array(depths)
    graph.vp['length'] = graph.new_vp('int', lengths)
    graph.gp['nsample'] = graph.new_gp('int', len(depths))

    initial_sequences = [[index] for index in range(graph.num_vertices())]
    graph.vp['sequence'] = graph.new_vp('object', vals=initial_sequences)
    return graph

def get_depth_matrix(g, vs=None):
    """Return vertex depths as a 2-D array with one column per vertex.

    Parameters
    ----------
    g : graph with a vector vertex property `depth` and a graph property
        `nsample`.
    vs : optional iterable of vertex indices (list, tuple, numpy array or
        generator). When omitted, or when it is empty, the depths of every
        vertex are returned via `get_2d_array`; returning the full matrix for
        an empty selection is retained from the original behaviour.

    Returns
    -------
    Array whose columns follow the order of `vs` (or vertex order when `vs`
    is not given).
    """
    if vs is not None:
        # Materialise first: `not vs` is ambiguous for multi-element numpy
        # arrays and always False for generators.
        vs = list(vs)
    if not vs:
        return g.vp.depth.get_2d_array(range(g.gp.nsample))
    return np.stack([g.vp.depth[i] for i in vs], axis=1)

In [None]:
from warnings import warn

def estimate_flow(f0, d, sample_idx, eps=1e-2, maxiter=100):
    """Iteratively adjust a flow matrix so its row and column sums match `d`.

    Each step measures how far the row sums and the column sums of the
    current matrix are from `d`, allocates each error to the individual
    entries in proportion to their share of the corresponding total, and
    subtracts the mean of the two allocations.

    Parameters
    ----------
    f0 : square 2-D matrix used as the starting flow (the notebook passes a
        scipy sparse adjacency matrix).
    d : 1-D array of target totals, one per row/column of `f0`.
    sample_idx : accepted for interface compatibility; not used.
    eps : stop once the relative loss improvement drops below this value.
    maxiter : maximum number of update steps.

    Returns
    -------
    The flow matrix at the step where iteration stopped (`f0` itself if no
    update was applied).

    Warns
    -----
    Emits a warning when `maxiter` steps complete without convergence.
    """
    loss_hist = [np.inf]
    # Pre-set so the warning below is well-defined even when maxiter == 0.
    loss_ratio = np.nan
    f = f0
    for step_i in range(maxiter):
        f_out = f
        f_total_out = f_out.sum(1)
        d_error_out = f_total_out - d
        f_in = f_out.T
        f_total_in = f_in.sum(1)
        d_error_in = f_total_in - d
        loss_hist.append(np.square(d_error_out).sum() + np.square(d_error_in).sum())
        if loss_hist[-1] == 0:
            # Exact fit: nothing left to improve. Without this check the
            # ratio below becomes 0/0 = NaN on the next step, the comparison
            # with eps is False, and the loop spins until maxiter.
            break
        if np.isfinite(loss_hist[-2]):
            loss_ratio = (loss_hist[-2] - loss_hist[-1]) / loss_hist[-2]
            if loss_ratio < eps:
                break
        # On the first step there is no previous loss to compare with, so the
        # update is always applied.
        f_frac_out = (f_out.T * np.nan_to_num(1 / f_total_out, nan=0.0, posinf=1.0)).T
        allocated_d_error_out = (d_error_out * f_frac_out.T).T
        f_frac_in = (f_in.T * np.nan_to_num(1 / f_total_in, nan=0.0, posinf=1.0)).T
        allocated_d_error_in = (d_error_in * f_frac_in.T).T
        # TODO: Consider weighting the shift by length. A very short unitig shouldn't
        # be equally influential as a very long one.
        mean_allocated_d_error = (allocated_d_error_in.T + allocated_d_error_out) / 2
        f = (f_out - mean_allocated_d_error)
    else:
        warn(f"loss_ratio < eps ({eps}) not achieved in maxiter ({maxiter}) steps. Final loss_ratio={loss_ratio}. Final loss={loss_hist[-1]}.")
    return f


def estimate_all_flows(g, eps=1e-3, maxiter=1000):
    """Run `estimate_flow` once per sample, starting from the adjacency matrix.

    Parameters
    ----------
    g : graph with a vector vertex property `depth` and a graph property
        `nsample`.
    eps, maxiter : forwarded to `estimate_flow`.

    Returns
    -------
    List with one estimated flow matrix per sample, in sample order.
    """
    # The starting point depends only on the graph, so build it once.
    f0 = sp.sparse.csr_array(gt.spectral.adjacency(g))
    flows = []
    for sample_idx in range(g.gp.nsample):
        d = g.vp.depth.get_2d_array([sample_idx])[0]
        # Pass the real sample index (the original always passed 0).
        flows.append(
            estimate_flow(f0, d, sample_idx=sample_idx, eps=eps, maxiter=maxiter)
        )
    return flows


def mutate_add_flows(g, flows):
    """Store per-sample flows on `g` as the vector edge property `flow`.

    For every edge (source, target) the value is read from
    `flow_matrix[target, source]`. One float edge property is built per
    matrix in `flows`; they are then grouped into a single vector property.

    Returns `g`, which is modified in place.
    """
    per_sample_props = []
    for flow_matrix in flows:
        edge_flow = g.new_edge_property('float', val=0)
        for source, target in g.get_edges():
            edge_flow[g.edge(source, target)] = flow_matrix[target, source]
        per_sample_props.append(edge_flow)
    g.ep['flow'] = gt.group_vector_property(per_sample_props)
    return g

In [None]:
def draw_graph(g, **kwargs):
    """Draw `g` with notebook-friendly defaults; keyword arguments override them.

    Defaults: a fill colour spread linearly over [0, 1] by vertex index, the
    vertex index as label, a 300x300 canvas and an ink scale of 0.8.
    Returns whatever `gt.draw.graph_draw` returns.
    """
    num_colors = max(g.get_vertices()) + 1
    defaults = {
        'vertex_fill_color': g.new_vertex_property(
            'float', vals=np.linspace(0, 1, num=num_colors)
        ),
        'vertex_text': g.vertex_index,
        'output_size': (300, 300),
        'ink_scale': 0.8,
    }
    return gt.draw.graph_draw(g, **{**defaults, **kwargs})

In [None]:
import itertools

# Toy graph: two 3-cycles through vertex 0 (0-1-2-0 and 0-3-4-0) plus a
# self-loop on vertex 0.
paths = [
    [0, 1, 2, 0],
    [0, 3, 4, 0],
    [0, 0],
]

nnodes = max(itertools.chain(*paths)) + 1

# One row per sample, one column per vertex.
depths = np.array([
    [5, 5, 5, 0, 0],
    [4, 0, 0, 4, 4],
    [5, 1, 1, 1, 1],
    [11, 2, 2, 2, 2],
])
nsamples = depths.shape[0]

# Every vertex gets length 1.
g0 = new_graph_from_merged_paths(
    paths,
    depths=depths,
    lengths=np.array([1] * nnodes),
)
mutate_add_flows(g0, estimate_all_flows(g0))


# Draw once to fix a layout, then redraw per sample with depth as the vertex
# label and estimated flow as the edge width.
g0_pos = draw_graph(g0, vertex_text=g0.vertex_index)
for i in range(g0.gp.nsample):
    draw_graph(
        g0,
        pos=g0_pos,
        vertex_text=gt.ungroup_vector_property(g0.vp.depth, pos=[i])[0],
        edge_pen_width=gt.ungroup_vector_property(g0.ep.flow, pos=[i])[0]
    )

In [None]:
def edge_has_no_siblings(g):
    """Return a bool edge property that is True for edges without siblings.

    An edge has no siblings when its target has at most one in-edge and its
    source has at most one out-edge.
    """
    vertices = g.get_vertices()
    in_degree = g.new_vertex_property('int', vals=g.get_in_degrees(vertices))
    out_degree = g.new_vertex_property('int', vals=g.get_out_degrees(vertices))
    # Project the vertex degrees onto edges via the relevant endpoint.
    target_in_degree = gt.edge_endpoint_property(g, in_degree, 'target')
    source_out_degree = gt.edge_endpoint_property(g, out_degree, 'source')
    shares_an_endpoint = (target_in_degree.a > 1) | (source_out_degree.a > 1)
    return g.new_edge_property('bool', ~shares_an_endpoint)

def vertex_does_not_have_both_multiple_in_and_multiple_out(g):
    """Return a bool vertex property that is False only for two-sided branch points.

    A vertex is marked False when it has more than one in-edge and more than
    one out-edge; every other vertex is marked True.
    """
    vertices = g.get_vertices()
    branches_in = g.get_in_degrees(vertices) > 1
    branches_out = g.get_out_degrees(vertices) > 1
    return g.new_vertex_property('bool', vals=~(branches_in & branches_out))


def label_maximal_unitigs(g):
    """Label each vertex with the index of the maximal unitig it belongs to.

    Returns
    -------
    (labels, counts, g_filt) where `labels` is an int vertex property on `g`
    (initialised to -1 before component labelling), `counts` is the second
    value returned by `label_components`, and `g_filt` is the directed
    filtered view the labelling was computed on.
    """
    sibling_free_edges = edge_has_no_siblings(g)
    # A vertex that branches on both sides cannot be part of a longer unitig,
    # so it is filtered out together with the edges that have siblings. This
    # may also speed up component labelling.
    # TODO: Double check whether filtering these vertices affects the
    # "unitig-ness" of their neighbours. Marking sibling edges before the
    # vertex filter is applied should make this safe.
    not_two_sided_branch = vertex_does_not_have_both_multiple_in_and_multiple_out(g)
    g_filt = gt.GraphView(
        g,
        efilt=sibling_free_edges,
        vfilt=not_two_sided_branch,
        directed=True,
    )
    undirected_view = gt.GraphView(g_filt, directed=False)
    # Filtered-out vertices would otherwise keep the property default (0) and
    # look like members of unitig 0, so start every label at -1: the marker
    # for "definitely not in a maximal unitig".
    unitig_labels = g.new_vertex_property('int', val=-1)
    unitig_labels, counts = gt.topology.label_components(
        undirected_view, vprop=unitig_labels
    )
    return unitig_labels, counts, g_filt


from functools import reduce
import operator


def is_simple_cycle(g, vfilt):
    """Tell whether the vertices selected by `vfilt` form exactly one cycle.

    True only if, inside the filtered view, there is exactly one path from
    the first selected vertex back to itself and that path visits every
    selected vertex. Raises IndexError when the filter selects no vertex.
    """
    view = gt.GraphView(g, vfilt=vfilt)
    members = list(view.iter_vertices())
    start = members[0]
    round_trips = list(gt.topology.all_paths(view, start, start))
    return len(round_trips) == 1 and set(round_trips[0]) == set(members)

def maximal_unitigs(g):
    """Generate maximal unitigs as (list of vertices, is_cycle) pairs.

    This is a lazy generator: labels, origin flags and terminus flags are
    computed when iteration starts, but each unitig is traced only when the
    consumer asks for it. Components for which `all_paths` finds no route
    from origin to terminus are skipped.

    Raises AssertionError when a labelled component has an origin/terminus
    combination other than none-and-none or one-and-one, or more than one
    route between them.
    """
    # (1a) Filter edges
    # (1b) Label unitigs
    labels, counts, g_filt = label_maximal_unitigs(g)
    # (2) Find every node in the filtered graph without in-edges (these are the origins)
    is_origin = g_filt.new_vp('bool', g_filt.get_in_degrees(g_filt.get_vertices()) == 0)
    # (3) Find every node in the filtered graph without out-edges (these are the termina)
    is_terminus = g_filt.new_vp('bool', g_filt.get_out_degrees(g_filt.get_vertices()) == 0)
    # (4) Iter through the labels.
    for i, c in enumerate(counts):
        # (a) For each, select the subgraph for that unitig.
        vfilt = (labels.a == i)
        # Sanity check: the mask must select exactly the component's size.
        assert vfilt.sum() == c
        subgraph = gt.GraphView(g_filt, vfilt=vfilt)
        # (b) Identify the origin and terminus nodes in the path
        vs = subgraph.get_vertices()
        origin = gt.GraphView(subgraph, vfilt=is_origin).get_vertices()
        terminus = gt.GraphView(subgraph, vfilt=is_terminus).get_vertices()
        if (len(origin) == 0) and (len(terminus) == 0):
            # No endpoints at all: start and end the trace at an arbitrary
            # member (the first one).
            origin, terminus = vs[0], vs[0]
        elif (len(origin) == 1) and (len(terminus) == 1):
            origin, terminus = origin[0], terminus[0]
        else:
            raise AssertionError("If there are multiple origins or termina, then it's not a unitig.")
        # (c) Trace the route from the origin to the terminus (`graph_tool.topology.all_paths`)
        unitig = list(gt.topology.all_paths(subgraph, origin, terminus))
        if len(unitig) == 0:
            continue  # Maximal unitig of length 1. No-op.
        assert len(unitig) == 1, "If there are multiple paths from origin to terminus, then it's not a unitig."
        unitig = unitig[0]
        # (d) Ask if it's a cycle.
        # NOTE: the cycle test runs on the unfiltered graph `g` restricted to
        # this component's vertices, not on `g_filt`.
        is_cycle = is_simple_cycle(g, vfilt)
        if is_cycle and (len(unitig) > 1) and (unitig[0] == unitig[-1]):
            # Drop the repeated closing vertex so each vertex appears once.
            unitig = unitig[:-1]
            assert len(set(unitig)) == len(unitig)
        # (e) Yield the route and whether it's a cycle.
        yield list(unitig), is_cycle
        
        
def list_unitig_neighbors(g, vs):
    """Return the in- and out-neighbours of a unitig path, excluding its members.

    Parameters
    ----------
    g : graph providing `iter_in_neighbors` and `iter_out_neighbors`.
    vs : sequence of vertex indices forming the unitig; may be empty.

    Returns
    -------
    (in_neighbors, out_neighbors): two lists of distinct vertices outside
    `vs` that point into, or are pointed to by, some vertex in `vs`.
    An empty `vs` yields two empty lists.
    """
    members = set(vs)
    # Flat comprehensions avoid the quadratic list concatenation of a
    # reduce(operator.add, ...) and also work when `vs` is empty.
    all_ins = [u for v in vs for u in g.iter_in_neighbors(v)]
    all_outs = [w for v in vs for w in g.iter_out_neighbors(v)]
    return list(set(all_ins) - members), list(set(all_outs) - members)

def mutate_add_compressed_unitig_vertex(g, vs, is_cycle, drop_vs=False):
    """Add one vertex that stands in for the unitig `vs` and detach the originals.

    The new vertex inherits the unitig's external in/out neighbours, the sum
    of the member lengths, the concatenation of the member sequences, and the
    length-weighted mean of the member depths. The original vertices have
    all their edges cleared but are not removed here.

    Parameters
    ----------
    g : graph, modified in place.
    vs : list of vertex indices making up the unitig, in path order.
    is_cycle : when true, a self-loop is added on the new vertex.
    drop_vs : not used by this function.

    Returns
    -------
    `g`.
    """
    v = int(g.add_vertex())
    nsample = g.gp.nsample  # NOTE: not used below.
    in_neighbors, out_neighbors = list_unitig_neighbors(g, vs)
    g.add_edge_list((neighbor, v) for neighbor in in_neighbors)
    g.add_edge_list((v, neighbor) for neighbor in out_neighbors)
    g.vp.length.a[v] = g.vp.length.a[vs].sum()
    g.vp.sequence[v] = reduce(operator.add, (g.vp.sequence[u] for u in vs), [])
    # Length-weighted mean depth, computed separately for every sample.
    g.vp.depth[v] = (
        (
            get_depth_matrix(g, vs)
            * g.vp.length.a[vs]
        ).sum(1) / g.vp.length.a[v]
    )
    # Invariant: depth * length summed over the members equals depth * length
    # of the replacement vertex.
    assert np.allclose((get_depth_matrix(g, vs) * g.vp.length.a[vs]).sum(1), (get_depth_matrix(g, [v]) * g.vp.length.a[v]).sum(1))
    if is_cycle:
        g.add_edge(v, v)
    # Detach the originals; indices stay valid until the caller removes them.
    for old_v in vs:
        g.clear_vertex(old_v)
    return g

# FIXME: For some reason, when length-1 unitigs are compressed, the total depth of the graph increases... hmm...?
# Could this be mostly due to self-looping nodes?
# I think one possibility is that self-looping nodes are somehow being expanded out into ever-longer
# chains of loops without correctly adjusting the depth.

def mutate_compress_all_unitigs(g):
    """Replace every maximal unitig of `g` with a single vertex, in place.

    Unitigs are pulled lazily from `maximal_unitigs` while the graph is being
    modified; the replaced vertices are removed in one batch at the end.
    Returns `g`.
    """
    replaced = []
    for members, is_cycle in maximal_unitigs(g):
        # TODO: If len(members) == 1, this is effectively a no-op and can be dropped.
        mutate_add_compressed_unitig_vertex(g, members, is_cycle)
        replaced += members
    g.remove_vertex(set(replaced), fast=True)
    # I think, but am not sure, that the number of nodes removed will always equal the number of edges removed.
    return g

In [None]:
# Compress the unitigs of a copy of g0, then re-estimate flows on the result.
g1 = mutate_compress_all_unitigs(g0.copy())
mutate_add_flows(g1, estimate_all_flows(g1))

# One drawing per sample: depth as vertex label, flow as edge width.
g1_pos = draw_graph(g1, vertex_text=g1.vertex_index)
for i in range(g1.gp.nsample):
    draw_graph(
        g1,
        pos=g1_pos,
        vertex_text=gt.ungroup_vector_property(g1.vp.depth, pos=[i])[0],
        edge_pen_width=gt.ungroup_vector_property(g1.ep.flow, pos=[i])[0]
    )

In [None]:
from itertools import product
from sklearn.decomposition import sparse_encode


def compile_code_inference_inputs(g, v):
    """Assemble the observation matrix and atom dictionary for vertex `v`.

    Returns
    -------
    obs : flows of the edges incident to `v`, stacked so that each column is
        one edge (in-edges first, then out-edges, duplicates skipped).
    dictionary : one row per (in-neighbour, out-neighbour) pair, the sum of
        the indicator rows of the pair's in-edge and out-edge.
    edge_idx : maps each incident edge `(source, target)` to its column in
        `obs`.
    split_idx : maps each dictionary row to its `(u, w)` pair.
    """
    incident_edges = (
        [(u, v) for u in g.get_in_neighbors(v)]
        + [(v, w) for w in g.get_out_neighbors(v)]
    )
    edge_idx = {}
    obs = []
    for edge in incident_edges:
        if edge in edge_idx:
            continue
        edge_idx[edge] = len(edge_idx)
        obs.append(g.ep.flow[g.edge(*edge)])

    indicator = np.eye(len(edge_idx))
    dictionary = []
    split_idx = {}
    local_paths = product(g.get_in_neighbors(v), g.get_out_neighbors(v))
    for row, (u, w) in enumerate(local_paths):
        split_idx[row] = (u, w)
        # The atom is built the same way whether or not u and w coincide.
        dictionary.append(indicator[edge_idx[(u, v)]] + indicator[edge_idx[(v, w)]])
    # TODO: Consider how to weight these estimates by unitig length.
    return np.stack(obs, axis=1), np.stack(dictionary), edge_idx, split_idx


def sparse_encode_gmp(X, dictionary, eps=0.1):
    """Greedy sparse coding of `X` that activates one dictionary atom at a time.

    Group Matching Pursuit (GMP), inspired by
    https://arxiv.org/pdf/1812.10538.pdf

    At most one atom per dictionary row is activated. Each round stops early
    if the absolute residual, relative to the absolute sum of `X`, is below
    `eps`; otherwise the atom holding the largest residual dot product is
    activated and `X` is re-encoded against the active atoms with a
    non-negative `lasso_lars` fit.

    Returns
    -------
    (encoding, loss_hist): the last encoding (all zeros if no atom was
    activated) and the absolute residual recorded at the start of each round.
    """
    full_dictionary = dictionary
    active_dictionary = np.zeros_like(full_dictionary)
    num_atoms = full_dictionary.shape[0]
    initial_loss = np.abs(X).sum()
    encoding = np.zeros((X.shape[0], num_atoms))
    residual = X
    selected = []
    loss_hist = []
    for _ in range(num_atoms):
        current_loss = np.abs(residual).sum()
        loss_hist.append(current_loss)
        if current_loss / initial_loss < eps:
            break
        scores = residual @ full_dictionary.T
        # Flat argmax reduced modulo the column count gives the atom index.
        best_atom = scores.argmax() % scores.shape[1]
        selected.append(best_atom)
        active_dictionary[best_atom] = full_dictionary[best_atom]
        encoding = sparse_encode(
            X,
            dictionary=active_dictionary,
            algorithm='lasso_lars',
            positive=True,
            alpha=0.0,
        )
        # TODO: How do I do it this way?
        # encoding[:, best_atom] += scores[:, best_atom]
        residual = X - encoding @ active_dictionary
    return encoding, loss_hist

def print_split_details_from_sparse_encoding(g, v, eps=1e-2):
    """Print the sparse-coding inputs and result for vertex `v` (debug helper).

    Prints the rounded observations, the dictionary, the edge index, the
    rounded encoding, and then one line per atom whose encoding column has a
    non-zero sum.
    """
    obs, dictionary, edge_idx, split_idx = compile_code_inference_inputs(g, v=v)
    for item in (obs.round(3), dictionary, edge_idx):
        print(item)
    split_depth, _ = sparse_encode_gmp(obs, dictionary, eps=eps)
    print(split_depth.round(3))
    for i in np.flatnonzero(split_depth.sum(0) != 0):
        print(i, split_idx[i], split_depth[:, i].round(2))
        
        
from itertools import product
from collections import namedtuple

# One local path through vertex v: it enters from u and leaves to w.
Split = namedtuple('Split', ['u', 'v', 'w'])
# TODO: Add functionality to split a node with only an in or only an out.
# TODO: Use this functionality to keep "residual" nodes, that keep all of the depth unaccounted for
# by the sparse coding.

def all_local_paths_as_splits(g, v):
    """Yield one split per (in-neighbour, out-neighbour) pair of `v`.

    Each item is `(Split(u, v, w), length, depth)` where `length` is the
    length of `v` and `depth` is the depth of `v` divided evenly among all
    pairs. Nothing is yielded when `v` lacks in- or out-neighbours.
    """
    assert v < g.num_vertices(ignore_filter=True)
    in_neighbors = list(g.iter_in_neighbors(v))
    out_neighbors = list(g.iter_out_neighbors(v))
    num_paths = len(in_neighbors) * len(out_neighbors)
    length = g.vp.length[v]
    depth = get_depth_matrix(g, vs=[v])
    for u in in_neighbors:
        for w in out_neighbors:
            # NOTE: This (dummy) splitting rule shares depth equally across paths.
            yield Split(u, v, w), length, depth / num_paths

def splits_from_sparse_encoding(g, v, eps=1e-2):
    """Yield the splits of `v` selected by `sparse_encode_gmp`.

    Each item is `(Split(u, v, w), length of v, per-sample depth)`; only
    atoms whose encoding column has a non-zero sum are emitted.
    """
    obs, dictionary, _, split_idx = compile_code_inference_inputs(g, v=v)
    split_depth, _ = sparse_encode_gmp(obs, dictionary, eps=eps)
    is_active = split_depth.sum(0) != 0
    for i, active in enumerate(is_active):
        if not active:
            continue
        u, w = split_idx[i]
        yield Split(u, v, w), g.vp.length[v], split_depth[:, i]
        
# Smoke test: splits proposed for vertex 0 of the compressed toy graph.
list(splits_from_sparse_encoding(g1, 0))

In [None]:
def splits_from_sparse_encoding2(g, v):
    """Yield splits of `v` chosen greedily with a non-negative fit.

    Group Matching Pursuit (GMP), inspired by
    https://arxiv.org/pdf/1812.10538.pdf

    Every (in-neighbour, out-neighbour) pair of `v` is a candidate atom. Each
    round activates the atom with the largest summed dot product against the
    residual flows, then refits the observed in/out flows against the active
    atoms with `non_negative_factorization` (from `sklearn.decomposition`,
    which must be in scope). The loop stops when no atom scores above 1e-5.

    Yields `(Split(u, v, w), length of v, per-sample depth)` for every atom
    whose encoding column has a non-zero sum. Nothing is yielded when `v` has
    no in-neighbours or no out-neighbours, because then no local path passes
    through it.
    """
    in_neighbors = sorted(g.get_in_neighbors(v))
    out_neighbors = sorted(g.get_out_neighbors(v))
    if not in_neighbors or not out_neighbors:
        # No candidate pairs; stacking empty lists below would raise.
        return

    in_neighbors_onehot = dict(zip(in_neighbors, np.eye(len(in_neighbors))))
    out_neighbors_onehot = dict(zip(out_neighbors, np.eye(len(out_neighbors))))

    in_neighbor_flow = np.stack([g.ep.flow[g.edge(u, v)] for u in in_neighbors])
    out_neighbor_flow = np.stack([g.ep.flow[g.edge(v, w)] for w in out_neighbors])

    in_neighbor_code = []
    out_neighbor_code = []
    split_idx = {}
    for i, (u, w) in enumerate(product(in_neighbors, out_neighbors)):
        in_neighbor_code.append(in_neighbors_onehot[u])
        out_neighbor_code.append(out_neighbors_onehot[w])
        split_idx[i] = (u, w)
    in_neighbor_code = np.stack(in_neighbor_code)
    out_neighbor_code = np.stack(out_neighbor_code)

    obs = np.concatenate([in_neighbor_flow, out_neighbor_flow]).T
    code = np.concatenate([in_neighbor_code, out_neighbor_code], axis=1)

    resid = obs
    dictionary = np.zeros_like(code)
    # Start from an all-zero encoding so the result is defined even when the
    # loop stops before the first fit.
    encoding = np.zeros((obs.shape[0], code.shape[0]))

    for _ in range(code.shape[0]):
        dot_sum = (resid @ code.T).sum(0)
        if dot_sum.max() <= 1e-5:
            break
        atom = dot_sum.argmax()
        dictionary[atom] = code[atom]
        encoding, _, _ = non_negative_factorization(obs, n_components=dictionary.shape[0], H=dictionary, update_H=False)
        resid = obs - encoding @ code

    for i in np.flatnonzero(encoding.sum(0) != 0):
        u, w = split_idx[i]
        yield Split(u, v, w), g.vp.length[v], encoding[:, i]

# Smoke test: NMF-based splits proposed for vertex 0 of the compressed toy graph.
list(splits_from_sparse_encoding2(g1, 0))

In [None]:
def build_tables_from_splits(split_list, start_idx):
    """Index a list of splits and collect their lengths and depths.

    Parameters
    ----------
    split_list : iterable of `(split, length, depth)` where `split` unpacks
        to `(u, v, w)`.
    start_idx : index assigned to the first split; later splits count up
        from there.

    Returns
    -------
    split_idx : dict mapping each split to its assigned index.
    upstream : defaultdict(list) mapping `(v, w)` to the splits of `v` that
        leave towards `w`.
    downstream : defaultdict(list) mapping `(u, v)` to the splits of `v`
        that arrive from `u`.
    lengths, depths : numpy arrays of the lengths and depths in input order.
    """
    split_idx = {}
    upstream, downstream = defaultdict(list), defaultdict(list)
    lengths, depths = [], []
    next_idx = start_idx
    for split, split_length, split_depth in split_list:
        u, v, w = split
        split_idx[split] = next_idx
        next_idx += 1
        upstream[(v, w)].append(split)
        downstream[(u, v)].append(split)
        lengths.append(split_length)
        depths.append(split_depth)
    return split_idx, upstream, downstream, np.array(lengths), np.array(depths)
        
        
def new_edges_from_splits(split_list, split_idx, upstream, downstream, start_idx):
    """Yield the edges that connect each new split vertex to its surroundings.

    For a split `(u, v, w)` with new index `n`, in this order:
    `(u, n)` if `u` is not None; `(m, n)` for every split `m` listed in
    `upstream[(u, v)]`; `(n, w)` if `w` is not None; `(n, m)` for every split
    `m` listed in `downstream[(v, w)]`. Duplicates are not removed here.

    `upstream` and `downstream` are indexed with `[]`, so they should be the
    defaultdicts built by `build_tables_from_splits`. `start_idx` is accepted
    for interface compatibility and does not influence the result.
    """
    for split, _, _ in split_list:
        old_u, old_v, old_w = split
        new_v = split_idx[split]

        # Edges arriving at the new vertex.
        if old_u is not None:
            yield (old_u, new_v)
        yield from (
            (split_idx[parent], new_v)
            for parent in upstream[(old_u, old_v)]
            if split_idx[parent] is not None
        )

        # Edges leaving the new vertex.
        if old_w is not None:
            yield (new_v, old_w)
        yield from (
            (new_v, split_idx[child])
            for child in downstream[(old_v, old_w)]
            if split_idx[child] is not None
        )
            
            
def mutate_apply_splits(g, split_list):
    """Add the split vertices and their edges, then drop the split parents.

    Parameters
    ----------
    g : graph with `length`, `depth` and `sequence` vertex properties;
        modified in place.
    split_list : iterable of `(Split, length, depth)`. It is materialised
        once, so generators are accepted. When it is empty the graph is
        returned untouched.

    Returns
    -------
    `g`.
    """
    # Materialise first: the splits are traversed several times below.
    split_list = list(split_list)
    if not split_list:
        # Nothing to apply; the max() calls below would fail on empty tables.
        return g

    start_idx = g.num_vertices(ignore_filter=True)
    split_idx, upstream, downstream, lengths, depths = (
        build_tables_from_splits(split_list, start_idx=start_idx)
    )
    edges_to_add = list(set(new_edges_from_splits(
        split_list, split_idx, upstream, downstream, start_idx
    )))

    # NOTE: Vertices are added explicitly before the edges. A split vertex
    # without any new edge is not created implicitly by `g.add_edge_list`,
    # and if it has one of the highest indices the property assignments
    # below would raise an IndexError.
    g.add_vertex(n=max(split_idx.values()) - max(g.get_vertices()))
    assert max(g.get_vertices()) == max(split_idx.values())
    g.add_edge_list(set(edges_to_add))

    new_positions = np.arange(len(lengths)) + start_idx
    g.vp.length.a[new_positions] = lengths
    new_depth = get_depth_matrix(g)
    new_depth[:, np.arange(len(depths)) + start_idx] = depths.T
    g.vp.depth.set_2d_array(new_depth)
    # Every split vertex inherits the sequence of the vertex it came from.
    for split, _, _ in split_list:
        g.vp.sequence[split_idx[split]] = g.vp.sequence[split.v]
    vertices_to_drop = set(split.v for (split, _, _) in split_list)
    g.remove_vertex(vertices_to_drop, fast=True)
    return g


def splits_for_all_vertices(g, split_func):
    """Yield the splits proposed by `split_func` for every branching vertex.

    Vertices with fewer than two in-edges and fewer than two out-edges are
    skipped. `split_func` is called as `split_func(g, v)` with `v` given as a
    plain integer index, matching the direct calls made elsewhere in this
    notebook, so the `Split` tuples it builds hold only integer indices (or
    None) rather than vertex descriptors.
    """
    for vertex in g.vertices():
        if (vertex.in_degree() < 2) and (vertex.out_degree() < 2):
            continue
        yield from split_func(g, int(vertex))
        
def mutate_split_all_nodes(g, split_func):
    """Split every branching vertex of `g` using `split_func`, in place.

    Collects the splits from `splits_for_all_vertices` and applies them in a
    single pass; when none are proposed, `g` is returned untouched.
    """
    # TODO: Vertices with <= 1 local path are split into just themselves
    # pointing trivially at their neighbours. These are effectively a no-op
    # and could be dropped.
    split_list = list(splits_for_all_vertices(g, split_func))
    if not split_list:
        return g
    return mutate_apply_splits(g, split_list=split_list)

In [None]:
def total_depth_length(g):
    """Return, per sample, the sum over vertices of depth times length."""
    length_weighted_depth = get_depth_matrix(g) * g.vp.length.a
    return length_weighted_depth.sum(1)

In [None]:
# Redraw g1 with its saved layout: once labelled by vertex index, then once
# per sample labelled by depth with flow as the edge width.
draw_graph(g1, pos=g1_pos, vertex_text=g1.vertex_index)

for i in range(g1.gp.nsample):
    draw_graph(
        g1,
        pos=g1_pos,
        vertex_text=gt.ungroup_vector_property(g1.vp.depth, pos=[i])[0],
        edge_pen_width=gt.ungroup_vector_property(g1.ep.flow, pos=[i])[0]
    )
    
# Vertex lengths after unitig compression.
print(g1.vp.length.a)

In [None]:
# Split vertex 0 of a copy of g1 with the even-share rule, then re-estimate flows.
g2 = mutate_apply_splits(g1.copy(), list(all_local_paths_as_splits(g1, 0)))
mutate_add_flows(g2, estimate_all_flows(g2))

g2_pos = draw_graph(g2, vertex_text=g2.vertex_index)

for i in range(g2.gp.nsample):
    draw_graph(
        g2,
        pos=g2_pos,
        # vertex_text=gt.ungroup_vector_property(g2.vp.depth, pos=[i])[0],
        edge_pen_width=gt.ungroup_vector_property(g2.ep.flow, pos=[i])[0]
    )
    
# Compare total depth * length before and after the split.
print(
    (get_depth_matrix(g1) * g1.vp.length.a).sum(),
    (get_depth_matrix(g2) * g2.vp.length.a).sum(),
)

In [None]:
# Show the splits that the next cell applies to vertex 0.
list(splits_from_sparse_encoding2(g1, 0))

In [None]:
# Split vertex 0 of a copy of g1 with the NMF-based splits, then re-estimate flows.
g3 = mutate_apply_splits(g1.copy(), list(splits_from_sparse_encoding2(g1, 0)))
mutate_add_flows(g3, estimate_all_flows(g3))

g3_pos = draw_graph(g3, vertex_text=g3.vertex_index)

for i in range(g3.gp.nsample):
    draw_graph(
        g3,
        pos=g3_pos,
        # vertex_text=gt.ungroup_vector_property(g3.vp.depth, pos=[i])[0],
        edge_pen_width=gt.ungroup_vector_property(g3.ep.flow, pos=[i])[0]
    )

# Per-sample depth * length totals: even-share split (g2) vs. NMF split (g3).
print(
    total_depth_length(g2),
    total_depth_length(g3),
)

In [None]:
# Re-compress the unitigs of the split graph g3 and re-estimate flows.
g4 = mutate_compress_all_unitigs(g3.copy())
mutate_add_flows(g4, estimate_all_flows(g4))
g4_pos = draw_graph(g4, vertex_text=g4.vertex_index)

# for i in range(g4.gp.nsample):
#     draw_graph(
#         g4,
#         pos=g4_pos,
#         # vertex_text=gt.ungroup_vector_property(g3.vp.depth, pos=[i])[0],
#         edge_pen_width=gt.ungroup_vector_property(g4.ep.flow, pos=[i])[0]
#     )
    
# FIXME: Why isn't total depth (mostly) invariant?
# Oddly: This happens much more during unitig compression than node splitting.
# Per-sample depth * length totals before (g3) and after (g4) compression.
print(
    total_depth_length(g3),
    total_depth_length(g4),
)

In [None]:
import itertools

# Second toy graph: a linear path 0-1-2-3-4 with a self-loop on vertex 2.
paths = [
    [0, 1, 2, 3, 4],
    [2, 2],
]

nnodes = max(itertools.chain(*paths)) + 1

# One row per sample, one column per vertex.
depths = np.array([
    [5, 5, 5, 5, 5],
    [1, 1, 2, 1, 1],
])
nsamples = depths.shape[0]

g5 = new_graph_from_merged_paths(
    paths,
    depths=depths,
    lengths=np.array([1] * nnodes),
)
mutate_add_flows(g5, estimate_all_flows(g5))


g5_pos = draw_graph(g5, vertex_text=g5.vertex_index)
for i in range(g5.gp.nsample):
    draw_graph(
        g5,
        pos=g5_pos,
        vertex_text=gt.ungroup_vector_property(g5.vp.depth, pos=[i])[0],
        edge_pen_width=gt.ungroup_vector_property(g5.ep.flow, pos=[i])[0]
    )
    
# Baseline depth * length totals, to compare against the compressed graph.
print(total_depth_length(g5))

In [None]:
# Compress the unitigs of a copy of g5 and re-estimate flows.
g6 = mutate_compress_all_unitigs(g5.copy())
mutate_add_flows(g6, estimate_all_flows(g6))
g6_pos = draw_graph(g6, vertex_text=g6.vertex_index)
for i in range(g6.gp.nsample):
    draw_graph(
        g6,
        pos=g6_pos,
        vertex_text=gt.ungroup_vector_property(g6.vp.depth, pos=[i])[0],
        edge_pen_width=gt.ungroup_vector_property(g6.ep.flow, pos=[i])[0]
    )
    
# Depth * length totals after compression.
print(total_depth_length(g6))

In [None]:
# Redraw g1 without a fixed layout (pos is not passed).
draw_graph(g1, vertex_text=g1.vertex_index)

In [None]:
# Split every branching vertex of a copy of g1 with the NMF-based splits.
g7 = mutate_split_all_nodes(g1.copy(), splits_from_sparse_encoding2)
mutate_add_flows(g7, estimate_all_flows(g7))

g7_pos = draw_graph(g7, vertex_text=g7.vertex_index)

# for i in range(g7.gp.nsample):
#     draw_graph(
#         g7,
#         pos=g7_pos,
#         # vertex_text=gt.ungroup_vector_property(g7.vp.depth, pos=[i])[0],
#         edge_pen_width=gt.ungroup_vector_property(g7.ep.flow, pos=[i])[0]
#     )

# Per-sample depth * length totals before (g1) and after (g7) splitting.
print(
    total_depth_length(g1),
    total_depth_length(g7),
)

In [None]:
# graph building

# NOTE(review): this repeats the definition of `path_to_edgelist` given
# earlier in the notebook; running this cell rebinds the name.
def path_to_edgelist(path):
    """Return the consecutive (source, target) pairs along `path`.

    A one-vertex path gives an empty list; an empty path raises IndexError.
    """
    u = path[0]
    edges = []
    for v in path[1:]:
        edges.append((u, v))
        u = v
    return edges

# NOTE(review): this repeats the definition of `new_graph_from_merged_paths`
# given earlier in the notebook; running this cell rebinds the name.
def new_graph_from_merged_paths(paths, lengths, depths):
    """Build a directed graph from the merged edges of `paths`.

    Sets the vertex properties `depth` (from the 2-D array `depths`),
    `length` (from `lengths`) and `sequence` (each vertex starts as
    `[index]`), and the graph property `nsample` (`len(depths)`).
    """
    g = gt.Graph()
    all_edges = []
    for p in paths:
        all_edges.extend(path_to_edgelist(p))
    # Passing a set merges edges shared by several paths.
    g.add_edge_list(set(all_edges))
    g.vp['depth'] = g.new_vp('vector<float>')
    g.vp.depth.set_2d_array(depths)
    g.vp['length'] = g.new_vp('int', lengths)  
    g.gp['nsample'] = g.new_gp('int', len(depths))
    g.vp['sequence'] = g.new_vp('object', vals=[[k] for k in range(g.num_vertices())])
    return g

In [None]:
# Same toy graph as g0 (two 3-cycles through vertex 0 plus a self-loop),
# rebuilt with a tighter flow tolerance (eps=1e-4).
paths = [
    [0, 1, 2, 0],
    [0, 3, 4, 0],
    [0, 0],
]

nnodes = max(itertools.chain(*paths)) + 1

# One row per sample, one column per vertex.
depths = np.array([
    [5, 5, 5, 0, 0],
    [4, 0, 0, 4, 4],
    [5, 1, 1, 1, 1],
    [11, 2, 2, 2, 2],
])
nsamples = depths.shape[0]

g = new_graph_from_merged_paths(
    paths,
    depths=depths,
    lengths=np.array([1] * nnodes),
)
mutate_add_flows(g, estimate_all_flows(g, eps=1e-4, maxiter=1000))
# NOTE(review): module-level `v`; the call at the end of this cell passes 0 directly.
v = 0

def splits_from_sparse_encoding3(g, v, eps=1e-2):
    """Yield splits of `v`, including one-sided splits and a residual split.

    Group Matching Pursuit (GMP), inspired by
    https://arxiv.org/pdf/1812.10538.pdf

    Candidate atoms are all (in-neighbour or None, out-neighbour or None)
    pairs except (None, None). The observations are the in-edge flows, the
    depth of `v`, and the out-edge flows; every atom contributes to the depth
    column. Each round activates the atom holding the single largest dot
    product with the residual, unless that atom's summed dot product is at
    most `eps`, and refits with `non_negative_factorization` (from
    `sklearn.decomposition`, which must be in scope).

    Yields
    ------
    `(Split(u, v, w), length of v, per-sample depth)` once per activated
    atom, in activation order, followed by `Split(None, v, None)` carrying
    whatever depth the atoms did not account for (omitted when that
    remainder is numerically zero).
    """
    in_neighbors = sorted(g.get_in_neighbors(v))
    num_in_neighbors = len(in_neighbors)
    in_neighbors_onehot = dict(zip(in_neighbors, np.eye(num_in_neighbors)))
    in_neighbors_onehot[None] = np.zeros(num_in_neighbors)

    out_neighbors = sorted(g.get_out_neighbors(v))
    num_out_neighbors = len(out_neighbors)
    out_neighbors_onehot = dict(zip(out_neighbors, np.eye(num_out_neighbors)))
    out_neighbors_onehot[None] = np.zeros(num_out_neighbors)

    in_neighbor_flow = [g.ep.flow[g.edge(u, v)] for u in in_neighbors]
    out_neighbor_flow = [g.ep.flow[g.edge(v, w)] for w in out_neighbors]

    in_neighbor_code = []
    out_neighbor_code = []
    split_idx = {}
    for u, w in product(in_neighbors + [None], out_neighbors + [None]):
        if u is None and w is None:
            # NOTE: No node-only atom. The residual split at the end always
            # covers it, and including it here would double-count.
            continue
        # Key by row position so the mapping stays aligned with `code`.
        split_idx[len(in_neighbor_code)] = (u, w)
        in_neighbor_code.append(in_neighbors_onehot[u])
        out_neighbor_code.append(out_neighbors_onehot[w])

    # Copy so the values stay valid independently of the property storage.
    depth_row = np.array(g.vp.depth[v].a)
    atoms = []
    encoding = np.zeros((depth_row.shape[0], len(split_idx)))

    # A vertex without neighbours has no atoms; only the residual is emitted.
    if split_idx:
        in_neighbor_code = np.stack(in_neighbor_code)
        out_neighbor_code = np.stack(out_neighbor_code)
        obs = np.stack(in_neighbor_flow + [depth_row] + out_neighbor_flow).T
        code = np.concatenate([
            in_neighbor_code,
            np.ones((in_neighbor_code.shape[0], 1)),
            out_neighbor_code
        ], axis=1)

        resid = obs
        dictionary = np.zeros_like(code)
        for _ in range(code.shape[0]):
            dot = resid @ code.T
            next_atom = dot.argmax() % code.shape[0]
            # TODO: Decide if I want to add the atom with the single element
            # or the atom with the largest _summed_ dot-product.
            if dot[:, next_atom].sum() <= eps:
                break
            if next_atom in atoms:
                # Re-selecting an active atom leaves the dictionary unchanged,
                # so further rounds cannot make progress. Recording it again
                # would emit the same split twice and double-count its depth.
                break
            atoms.append(next_atom)
            dictionary[next_atom] = code[next_atom]
            encoding, _, _ = non_negative_factorization(obs, n_components=dictionary.shape[0], H=dictionary, update_H=False, alpha_W=0)
            resid = obs - encoding @ code

    for i in atoms:
        u, w = split_idx[i]
        yield Split(u, v, w), g.vp.length[v], encoding[:, i]
    remaining_depth = depth_row - encoding.sum(1)
    if not np.allclose(remaining_depth, 0):
        yield Split(None, v, None), g.vp.length[v], remaining_depth

# Smoke test on vertex 0 of the toy graph, using a summed-dot-product cut-off of 0.5.
list(splits_from_sparse_encoding3(g, 0, eps=0.5))

In [None]:
from functools import partial

np.random.seed(1)

# Strain A is the reference path 0..strain_length-1.
strain_length = 100
num_mutations = 50
strainA_path = np.arange(strain_length)

# Strain B: replace num_mutations random positions with new, private vertices.
strainB_mutations = np.random.choice(strainA_path, size=num_mutations, replace=False)
strainB_path = strainA_path.copy()
strainB_path[strainB_mutations] = np.arange(strain_length, strain_length + num_mutations)

# Strain C: invert the middle half of the path, then add private mutations.
strainC_mutations = np.random.choice(strainA_path, size=num_mutations, replace=False)
strainC_path = strainA_path.copy()
inversion_start = strain_length // 4
inversion_end = 3 * strain_length // 4
# Reverse exactly the half-open segment [inversion_start, inversion_end).
# The previous slice `[inversion_end : inversion_start : -1]` was shifted by
# one, which duplicated vertex `inversion_end` and dropped `inversion_start`.
strainC_path[inversion_start:inversion_end] = (
    strainC_path[inversion_start:inversion_end][::-1].copy()
)
strainC_path_without_mutations = strainC_path.copy()
strainC_path[strainC_mutations] = np.arange(strain_length + num_mutations, strain_length + 2 * num_mutations)

paths = dict(
    a=strainA_path,
    b=strainB_path,
    c=strainC_path,
)
# Per-strain depth in each of the three samples: each strain dominates one sample.
depths = dict(
    a=[10, 1, 1],
    b=[1, 10, 1],
    c=[1, 1, 10],
)

nnodes = max(itertools.chain(*paths.values())) + 1
nsamples = 3

# Vertex depth = sum of the depths of the strains whose path visits it.
_depths = np.zeros((nsamples, nnodes))
for p in paths:
    _depths[:, paths[p]] += np.outer(depths[p], np.ones(len(paths[p])))
    
# NOTE(review): a bare expression in the middle of a cell is not displayed.
_depths

g8 = new_graph_from_merged_paths(
    paths.values(),
    depths=_depths,
    lengths=np.array([1] * nnodes),
)
draw_graph(g8)
print(g8)
mutate_compress_all_unitigs(g8)

# Three rounds of: estimate flows -> split branching vertices -> re-compress.
# g8 is modified in place throughout.
for i in range(3):
    draw_graph(g8)
    print(g8)
    # print(i, g8)
    mutate_add_flows(g8, estimate_all_flows(g8, maxiter=1000, eps=1e-3))
    # print(total_depth_length(g8))
    mutate_split_all_nodes(g8, partial(splits_from_sparse_encoding3, eps=5))
    # print(total_depth_length(g8))
    mutate_compress_all_unitigs(g8)
    # print(total_depth_length(g8))
draw_graph(g8)
print(g8)

In [None]:
# The ten (vertex index, length) pairs with the greatest length in g8.
list(reversed(sorted(enumerate(g8.vp.length.a), key=lambda x: x[1])))[:10]

In [None]:
# Inspect one vertex of g8.
# NOTE(review): the index 98 is hard-coded — presumably picked from the
# ranking in the previous cell; confirm after any change upstream.
i = 98
s = np.array(g8.vp.sequence[i])
# Length, depth, and a tally of the steps between consecutive ids in the
# vertex's sequence.
g8.vp.length[i], g8.vp.depth[i], dict(zip(*np.unique(s[1:] - s[:-1], return_counts=True)))

In [None]:
from functools import partial

# Experiment g9: strains A, B, C derived from a 1000-node ancestral path,
# three samples, Poisson-distributed node depths.
np.random.seed(1)

strain_length = 1000
num_mutations = strain_length // 2
ancestral_path = np.arange(strain_length)

# Each strain replaces `num_mutations` randomly chosen ancestral positions
# with a block of node ids that is private to that strain.
strainA_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainA_path = ancestral_path.copy()
strainA_path[strainA_mutations] = np.arange(strain_length + 0 * num_mutations, strain_length + 1 * num_mutations)

strainB_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainB_path = ancestral_path.copy()
strainB_path[strainB_mutations] = np.arange(strain_length + 1 * num_mutations, strain_length + 2 * num_mutations)

# Strain C additionally carries an inversion of [inversion_start, inversion_end).
strainC_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainC_path = ancestral_path.copy()
inversion_start = strain_length // 4
inversion_end = 3 * strain_length // 4
# Reverse exactly the slice being assigned. Taking
# [inversion_end : inversion_start : -1] instead would be shifted by one
# position: node `inversion_start` would vanish from the path and node
# `inversion_end` would be visited twice.
strainC_path[inversion_start:inversion_end] = strainC_path[inversion_start:inversion_end][::-1].copy()
strainC_path_without_mutations = strainC_path.copy()
strainC_path[strainC_mutations] = np.arange(strain_length + 2 * num_mutations, strain_length + 3 * num_mutations)

paths = dict(
    a=strainA_path,
    b=strainB_path,
    c=strainC_path,
    x=ancestral_path,
)
# Mean depth of each path in each sample; the ancestral path `x` contributes
# edges to the graph but no depth.
mean_depths = dict(
    a=[ 100,  0.1,  0.1 ],
    b=[ 0.1,  10 ,  0.1 ],
    c=[ 0.1,  0.1,  5   ],
    x=[ 0  ,  0  ,  0   ],
)


nnodes = max(itertools.chain(*paths.values())) + 1
nsamples = 3

expected_depths = np.zeros((nsamples, nnodes))
for p in paths:
    expected_depths[:, paths[p]] += np.outer(mean_depths[p], np.ones(len(paths[p])))

# Observed depths: Poisson draws with mean 10x the expected depth.
_depths = sp.stats.poisson(mu=expected_depths * 10).rvs()

g9 = new_graph_from_merged_paths(
    paths.values(),
    depths=_depths,
    lengths=np.array([1] * nnodes),
)
draw_graph(g9)
print(g9)
mutate_compress_all_unitigs(g9)

# Three rounds of: add estimated flows, split nodes, re-compress unitigs.
for i in range(3):
    draw_graph(g9)
    print(g9)
    mutate_add_flows(g9, estimate_all_flows(g9, maxiter=1000, eps=1e-3))
    mutate_split_all_nodes(g9, partial(splits_from_sparse_encoding3, eps=50))
    mutate_compress_all_unitigs(g9)
draw_graph(g9)
print(g9)

In [None]:
# Per-vertex table for g9: degrees, length, one depth column per sample, and
# depth-times-length columns carrying the `_x_len` suffix.
vp = (
    pd.DataFrame({
        'in_degree': g9.degree_property_map('in').a,
        'out_degree': g9.degree_property_map('out').a,
        'length': g9.vp.length.a,
    })
    .join(pd.DataFrame(get_depth_matrix(g9).T, columns=['a', 'b', 'c']))
    .pipe(lambda table: table.join(
        table[['a', 'b', 'c']].multiply(table['length'], axis=0),
        rsuffix='_x_len',
    ))
)
vp.sort_values('a_x_len')

In [None]:
# Depth histograms: one row of panels per sample, one column per class of
# total (in + out) degree.
bins = np.linspace(-200, 1200, num=100)

fig, axs = plt.subplots(nrows=3, ncols=4, sharex=True, sharey=True, figsize=(15, 5))
degree_classes = [
    ('=0', (vp.in_degree + vp.out_degree) == 0),
    ('=1', (vp.in_degree + vp.out_degree) == 1),
    ('=2', (vp.in_degree + vp.out_degree) == 2),
    ('>2', (vp.in_degree + vp.out_degree) > 2),
]
for column_axes, (label, predicate) in zip(axs.T, degree_classes):
    for ax, sample in zip(column_axes, ['a', 'b', 'c']):
        ax.hist(vp.loc[predicate, sample], bins=bins, alpha=0.5, label=label)
        ax.legend(loc='upper right')
# The y axis is shared, so setting the scale on one panel is enough.
axs[0, 0].set_yscale('symlog')
fig.tight_layout()

In [None]:
# Index labels of the ten rows of `vp` with the largest `length`.
longest_seqs = vp['length'].sort_values(ascending=False).index[:10]
vp.loc[longest_seqs]

In [None]:
import matplotlib.pyplot as plt

# Dot plots for g9: one row per query vertex, one column per reference path.
ref_list = ['a', 'b', 'c']
query_list = longest_seqs

# squeeze=False always returns a 2-D array of axes, even for a single row.
fig, axs = plt.subplots(
    len(query_list), len(ref_list),
    figsize=(3 * len(ref_list), 3 * len(query_list)),
    sharex=True, sharey=True, squeeze=False,
)

for (i, query_idx), (j, ref_idx) in product(enumerate(query_list), enumerate(ref_list)):
    ax = axs[i, j]
    ref = np.asanyarray(paths[ref_idx])
    query = np.asanyarray(g9.vp.sequence[query_idx])
    # Pad the shorter of the two with -1 so both have the same length.
    length = max(len(ref), len(query))
    query = np.pad(query, (0, length - len(query)), constant_values=-1)
    ref = np.pad(ref, (0, length - len(ref)), constant_values=-1)
    # Equality matrix (query position x reference position), computed by
    # broadcasting instead of one Python callback per matrix cell.
    dotplot = query[:, np.newaxis] == ref[np.newaxis, :]
    ax.scatter(*np.where(dotplot.T), marker='o', s=1)
    ax.set_aspect('equal')
    # Label each panel with the rounded `vp` entry for this query and sample.
    ax.annotate(int(vp.loc[query_idx][ref_idx].round()), xy=(0.1, 0.9), xycoords='axes fraction', va='top')

for ref_idx, bottom_row_ax in zip(ref_list, axs[-1, :]):
    bottom_row_ax.set_xlabel(ref_idx)

for query_idx, left_column_ax in zip(query_list, axs[:, 0]):
    left_column_ax.set_ylabel(query_idx)

In [None]:
from functools import partial

# Experiment g10: strains A, B, C derived from a 1000-node ancestral path,
# seven samples, Poisson-distributed node depths, six refinement rounds.
np.random.seed(1)

strain_length = 1000
num_mutations = strain_length // 2
ancestral_path = np.arange(strain_length)

# Each strain replaces `num_mutations` randomly chosen ancestral positions
# with a block of node ids that is private to that strain.
strainA_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainA_path = ancestral_path.copy()
strainA_path[strainA_mutations] = np.arange(strain_length + 0 * num_mutations, strain_length + 1 * num_mutations)

strainB_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainB_path = ancestral_path.copy()
strainB_path[strainB_mutations] = np.arange(strain_length + 1 * num_mutations, strain_length + 2 * num_mutations)

# Strain C additionally carries an inversion of [inversion_start, inversion_end).
strainC_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainC_path = ancestral_path.copy()
inversion_start = strain_length // 4
inversion_end = 3 * strain_length // 4
# Reverse exactly the slice being assigned. Taking
# [inversion_end : inversion_start : -1] instead would be shifted by one
# position: node `inversion_start` would vanish from the path and node
# `inversion_end` would be visited twice.
strainC_path[inversion_start:inversion_end] = strainC_path[inversion_start:inversion_end][::-1].copy()
strainC_path_without_mutations = strainC_path.copy()
strainC_path[strainC_mutations] = np.arange(strain_length + 2 * num_mutations, strain_length + 3 * num_mutations)

paths = dict(
    a=strainA_path,
    b=strainB_path,
    c=strainC_path,
    x=ancestral_path,
)
# Mean depth of each path in each of the seven samples: A only in sample 0,
# B only in sample 1, C at low depth in samples 2-6, ancestral path nowhere.
mean_depths = dict(
    a=[100, 0, 0, 0, 0, 0, 0],
    b=[0, 10, 0, 0, 0, 0, 0],
    c=[0, 0, 1, 1, 1, 1, 1],
    x=[0, 0, 0, 0, 0, 0, 0],
)


nnodes = max(itertools.chain(*paths.values())) + 1
nsamples = 7

expected_depths = np.zeros((nsamples, nnodes))
for p in paths:
    expected_depths[:, paths[p]] += np.outer(mean_depths[p], np.ones(len(paths[p])))

# Observed depths: Poisson draws with mean 10x the expected depth.
_depths = sp.stats.poisson(mu=expected_depths * 10).rvs()

g10 = new_graph_from_merged_paths(
    paths.values(),
    depths=_depths,
    lengths=np.array([1] * nnodes),
)
draw_graph(g10)
print(g10)
mutate_compress_all_unitigs(g10)

# Six rounds of: add estimated flows, split nodes, re-compress unitigs.
for i in range(6):
    draw_graph(g10)
    print(g10)
    mutate_add_flows(g10, estimate_all_flows(g10, maxiter=1000, eps=1e-3))
    mutate_split_all_nodes(g10, partial(splits_from_sparse_encoding3, eps=50))
    mutate_compress_all_unitigs(g10)
draw_graph(g10)
print(g10)

In [None]:
# Per-vertex table for g10: degrees, length, one depth column per sample, and
# depth-times-length columns carrying the `_x_len` suffix.
vp = (
    pd.DataFrame({
        'in_degree': g10.degree_property_map('in').a,
        'out_degree': g10.degree_property_map('out').a,
        'length': g10.vp.length.a,
    })
    .join(pd.DataFrame(get_depth_matrix(g10).T, columns=['a', 'b', 'c', 'c1', 'c2', 'c3', 'c4']))
    .pipe(lambda table: table.join(
        table[['a', 'b', 'c', 'c1', 'c2', 'c3', 'c4']].multiply(table['length'], axis=0),
        rsuffix='_x_len',
    ))
)
vp.sort_values('length', ascending=False).head(10)

In [None]:
# Index labels of the ten rows of `vp` with the largest `length`.
longest_seqs = vp['length'].sort_values(ascending=False).index[:10]
vp.loc[longest_seqs]

In [None]:
import matplotlib.pyplot as plt

# Dot plots for g10: one row per query vertex, one column per reference path.
ref_list = ['a', 'b', 'c']
query_list = longest_seqs

# squeeze=False always returns a 2-D array of axes, even for a single row.
fig, axs = plt.subplots(
    len(query_list), len(ref_list),
    figsize=(3 * len(ref_list), 3 * len(query_list)),
    sharex=True, sharey=True, squeeze=False,
)

for (i, query_idx), (j, ref_idx) in product(enumerate(query_list), enumerate(ref_list)):
    ax = axs[i, j]
    ref = np.asanyarray(paths[ref_idx])
    query = np.asanyarray(g10.vp.sequence[query_idx])
    # Pad both arrays with -1 to a common length. Because of the +1 each
    # array receives at least one pad entry, and pad entries match each other
    # in the comparison below.
    length = max(len(ref), len(query)) + 1
    query = np.pad(query, (0, length - len(query)), constant_values=-1)
    ref = np.pad(ref, (0, length - len(ref)), constant_values=-1)
    # Equality matrix (query position x reference position), computed by
    # broadcasting instead of one Python callback per matrix cell.
    dotplot = query[:, np.newaxis] == ref[np.newaxis, :]
    ax.scatter(*np.where(dotplot.T), marker='o', s=1)
    ax.set_aspect('equal')
    # Label each panel with the rounded `vp` entry for this query and sample.
    ax.annotate(int(vp.loc[query_idx][ref_idx].round()), xy=(0.1, 0.9), xycoords='axes fraction', va='top')

for ref_idx, bottom_row_ax in zip(ref_list, axs[-1, :]):
    bottom_row_ax.set_xlabel(ref_idx)

for query_idx, left_column_ax in zip(query_list, axs[:, 0]):
    left_column_ax.set_ylabel(query_idx)

In [None]:
# Index labels of the ten rows of `vp` with the largest `c_x_len`.
most_clike_seqs = vp['c_x_len'].sort_values(ascending=False).index[:10]
vp.loc[most_clike_seqs]

In [None]:
import matplotlib.pyplot as plt

# Dot plots for g10: one row per query vertex, one column per reference path.
ref_list = ['a', 'b', 'c']
query_list = most_clike_seqs

# squeeze=False always returns a 2-D array of axes, even for a single row.
fig, axs = plt.subplots(
    len(query_list), len(ref_list),
    figsize=(3 * len(ref_list), 3 * len(query_list)),
    sharex=True, sharey=True, squeeze=False,
)

for (i, query_idx), (j, ref_idx) in product(enumerate(query_list), enumerate(ref_list)):
    ax = axs[i, j]
    ref = np.asanyarray(paths[ref_idx])
    query = np.asanyarray(g10.vp.sequence[query_idx])
    # Pad the shorter of the two with -1 so both have the same length.
    length = max(len(ref), len(query))
    query = np.pad(query, (0, length - len(query)), constant_values=-1)
    ref = np.pad(ref, (0, length - len(ref)), constant_values=-1)
    # Equality matrix (query position x reference position), computed by
    # broadcasting instead of one Python callback per matrix cell.
    dotplot = query[:, np.newaxis] == ref[np.newaxis, :]
    ax.scatter(*np.where(dotplot.T), marker='o', s=1)
    ax.set_aspect('equal')
    # Label each panel with the rounded `vp` entry for this query and sample.
    ax.annotate(int(vp.loc[query_idx][ref_idx].round()), xy=(0.1, 0.9), xycoords='axes fraction', va='top')

for ref_idx, bottom_row_ax in zip(ref_list, axs[-1, :]):
    bottom_row_ax.set_xlabel(ref_idx)

for query_idx, left_column_ax in zip(query_list, axs[:, 0]):
    left_column_ax.set_ylabel(query_idx)

In [None]:
from functools import partial

# Experiment g11: strains A, B, C derived from a 1000-node ancestral path,
# seven samples, Poisson-distributed node depths, seven refinement rounds.
np.random.seed(1)

strain_length = 1000
num_mutations = strain_length // 2
ancestral_path = np.arange(strain_length)

# Each strain replaces `num_mutations` randomly chosen ancestral positions
# with a block of node ids that is private to that strain.
strainA_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainA_path = ancestral_path.copy()
strainA_path[strainA_mutations] = np.arange(strain_length + 0 * num_mutations, strain_length + 1 * num_mutations)

strainB_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainB_path = ancestral_path.copy()
strainB_path[strainB_mutations] = np.arange(strain_length + 1 * num_mutations, strain_length + 2 * num_mutations)

# Strain C additionally carries an inversion of [inversion_start, inversion_end).
strainC_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainC_path = ancestral_path.copy()
inversion_start = strain_length // 4
inversion_end = 3 * strain_length // 4
# Reverse exactly the slice being assigned. Taking
# [inversion_end : inversion_start : -1] instead would be shifted by one
# position: node `inversion_start` would vanish from the path and node
# `inversion_end` would be visited twice.
strainC_path[inversion_start:inversion_end] = strainC_path[inversion_start:inversion_end][::-1].copy()
strainC_path_without_mutations = strainC_path.copy()
strainC_path[strainC_mutations] = np.arange(strain_length + 2 * num_mutations, strain_length + 3 * num_mutations)

paths = dict(
    a=strainA_path,
    b=strainB_path,
    c=strainC_path,
    x=ancestral_path,
)
# Mean depth of each path in each of the seven samples: A only in sample 0,
# B only in sample 1, C at low depth in samples 2-6, ancestral path nowhere.
mean_depths = dict(
    a=[100, 0, 0, 0, 0, 0, 0],
    b=[0, 10, 0, 0, 0, 0, 0],
    c=[0, 0, 1, 1, 1, 1, 1],
    x=[0, 0, 0, 0, 0, 0, 0],
)


nnodes = max(itertools.chain(*paths.values())) + 1
nsamples = 7

expected_depths = np.zeros((nsamples, nnodes))
for p in paths:
    expected_depths[:, paths[p]] += np.outer(mean_depths[p], np.ones(len(paths[p])))

# Observed depths: Poisson draws with mean 10x the expected depth.
_depths = sp.stats.poisson(mu=expected_depths * 10).rvs()

g11 = new_graph_from_merged_paths(
    paths.values(),
    depths=_depths,
    lengths=np.array([1] * nnodes),
)
draw_graph(g11)
print(g11)
mutate_compress_all_unitigs(g11)

# Seven rounds of: add estimated flows, split nodes, re-compress unitigs.
for i in range(7):
    draw_graph(g11)
    print(g11)
    mutate_add_flows(g11, estimate_all_flows(g11, maxiter=1000, eps=1e-3))
    mutate_split_all_nodes(g11, partial(splits_from_sparse_encoding3, eps=20))
    mutate_compress_all_unitigs(g11)
draw_graph(g11)
print(g11)

In [None]:
# Per-vertex table for g11: degrees, length, one depth column per sample, and
# depth-times-length columns carrying the `_x_len` suffix.
vp = (
    pd.DataFrame({
        'in_degree': g11.degree_property_map('in').a,
        'out_degree': g11.degree_property_map('out').a,
        'length': g11.vp.length.a,
    })
    .join(pd.DataFrame(get_depth_matrix(g11).T, columns=['a', 'b', 'c', 'c1', 'c2', 'c3', 'c4']))
    .pipe(lambda table: table.join(
        table[['a', 'b', 'c', 'c1', 'c2', 'c3', 'c4']].multiply(table['length'], axis=0),
        rsuffix='_x_len',
    ))
)
vp.sort_values('length', ascending=False).head(10)

In [None]:
# Index labels of the ten rows of `vp` with the largest `a_x_len`.
most_alike_seqs = vp['a_x_len'].sort_values(ascending=False).index[:10]
vp.loc[most_alike_seqs]

In [None]:
# Index labels of the ten rows of `vp` with the largest `b_x_len`.
most_blike_seqs = vp['b_x_len'].sort_values(ascending=False).index[:10]
vp.loc[most_blike_seqs]

In [None]:
# Index labels of the ten rows of `vp` with the largest `c_x_len`.
most_clike_seqs = vp['c_x_len'].sort_values(ascending=False).index[:10]
vp.loc[most_clike_seqs]

In [None]:
# Index labels of the ten rows of `vp` with the largest `c_x_len`.
most_clike_seqs = vp['c_x_len'].sort_values(ascending=False).index[:10]
vp.loc[most_clike_seqs]

In [None]:
import matplotlib.pyplot as plt

# Dot plots for g11: one row per query vertex, one column per reference path.
ref_list = ['a', 'b', 'c']
# NOTE(review): hard-coded vertex ids of g11, presumably copied from an
# earlier run's output; verify they still refer to the intended vertices.
query_list = [698, 743, 439, 584, 1218, 845, 1173, 957, 141, 48, 473]

# squeeze=False always returns a 2-D array of axes, even for a single row.
fig, axs = plt.subplots(
    len(query_list), len(ref_list),
    figsize=(3 * len(ref_list), 3 * len(query_list)),
    sharex=True, sharey=True, squeeze=False,
)

for (i, query_idx), (j, ref_idx) in product(enumerate(query_list), enumerate(ref_list)):
    ax = axs[i, j]
    ref = np.asanyarray(paths[ref_idx])
    query = np.asanyarray(g11.vp.sequence[query_idx])
    # Pad both arrays with -1 to a common length. Because of the +1 each
    # array receives at least one pad entry, and pad entries match each other
    # in the comparison below.
    length = max(len(ref), len(query)) + 1
    query = np.pad(query, (0, length - len(query)), constant_values=-1)
    ref = np.pad(ref, (0, length - len(ref)), constant_values=-1)
    # Equality matrix (query position x reference position), computed by
    # broadcasting instead of one Python callback per matrix cell.
    dotplot = query[:, np.newaxis] == ref[np.newaxis, :]
    ax.scatter(*np.where(dotplot.T), marker='o', s=1)
    ax.set_aspect('equal')
    # Label each panel with the rounded `vp` entry for this query and sample.
    ax.annotate(int(vp.loc[query_idx][ref_idx].round()), xy=(0.1, 0.9), xycoords='axes fraction', va='top')

for ref_idx, bottom_row_ax in zip(ref_list, axs[-1, :]):
    bottom_row_ax.set_xlabel(ref_idx)

for query_idx, left_column_ax in zip(query_list, axs[:, 0]):
    left_column_ax.set_ylabel(query_idx)

In [None]:
import matplotlib.pyplot as plt

# Dot plots for g11: one row per query vertex, one column per reference path.
ref_list = ['a', 'b', 'c']
query_list = most_clike_seqs

# squeeze=False always returns a 2-D array of axes, even for a single row.
fig, axs = plt.subplots(
    len(query_list), len(ref_list),
    figsize=(3 * len(ref_list), 3 * len(query_list)),
    sharex=True, sharey=True, squeeze=False,
)

for (i, query_idx), (j, ref_idx) in product(enumerate(query_list), enumerate(ref_list)):
    ax = axs[i, j]
    ref = np.asanyarray(paths[ref_idx])
    query = np.asanyarray(g11.vp.sequence[query_idx])
    # Pad both arrays with -1 to a common length. Because of the +1 each
    # array receives at least one pad entry, and pad entries match each other
    # in the comparison below.
    length = max(len(ref), len(query)) + 1
    query = np.pad(query, (0, length - len(query)), constant_values=-1)
    ref = np.pad(ref, (0, length - len(ref)), constant_values=-1)
    # Equality matrix (query position x reference position), computed by
    # broadcasting instead of one Python callback per matrix cell.
    dotplot = query[:, np.newaxis] == ref[np.newaxis, :]
    ax.scatter(*np.where(dotplot.T), marker='o', s=1)
    ax.set_aspect('equal')
    # Label each panel with the rounded `vp` entry for this query and sample.
    ax.annotate(int(vp.loc[query_idx][ref_idx].round()), xy=(0.1, 0.9), xycoords='axes fraction', va='top')

for ref_idx, bottom_row_ax in zip(ref_list, axs[-1, :]):
    bottom_row_ax.set_xlabel(ref_idx)

for query_idx, left_column_ax in zip(query_list, axs[:, 0]):
    left_column_ax.set_ylabel(query_idx)

In [None]:
# Toy graph: two loops through vertex 0 (via 1, 2 and via 3, 4) plus a
# self-loop on vertex 0.
paths = [
    [0, 1, 2, 0],
    [0, 3, 4, 0],
    [0, 0],
]

nnodes = 1 + max(itertools.chain.from_iterable(paths))

# One row per sample, one column per node.
depths = np.array([
    [5, 5, 5, 0, 0],
    [4, 0, 0, 4, 4],
    [5, 1, 1, 1, 1],
    [11, 2, 2, 2, 2],
])
nsamples = len(depths)

g = new_graph_from_merged_paths(paths, depths=depths, lengths=np.ones(nnodes, dtype=int))
mutate_add_flows(g, estimate_all_flows(g, eps=1e-4, maxiter=1000))
v = 0

def splits_from_sparse_encoding4(g, v, eps=1e-2):
    """Propose splits of vertex `v` by greedily fitting (in, out) neighbour atoms.

    Group Matching Pursuit (GMP).
    Inspired by https://arxiv.org/pdf/1812.10538.pdf

    Each candidate atom pairs one in-neighbour of `v` (or None) with one
    out-neighbour of `v` (or None); the (None, None) pair is excluded. The
    observation matrix holds, per sample, the flow on every in-edge, the depth
    of `v`, and the flow on every out-edge. Atoms are added one at a time, the
    atom with the largest dot product with the current residual (summed over
    samples) first, until that summed dot product is at most `eps`. After each
    addition the non-negative encoding is refit against the selected atoms.

    Parameters
    ----------
    g : graph
        Must provide `g.vp.depth`, `g.vp.length` and `g.ep.flow`.
    v : vertex index
        The vertex to split.
    eps : float
        Stopping threshold on the summed dot product of the best atom.

    Yields
    ------
    (Split, length, depth)
        One `Split(u, v, w)` per selected atom with that atom's per-sample
        encoding as depth, followed by `Split(None, v, None)` carrying the
        depth of `v` that no atom accounts for, if it is not close to zero.
    """
    in_neighbors = list(sorted(g.get_in_neighbors(v)))
    num_in_neighbors = len(in_neighbors)
    in_neighbors_onehot = dict(zip(in_neighbors, np.eye(num_in_neighbors)))
    in_neighbors_onehot[None] = np.zeros(num_in_neighbors)

    out_neighbors = list(sorted(g.get_out_neighbors(v)))
    num_out_neighbors = len(out_neighbors)
    out_neighbors_onehot = dict(zip(out_neighbors, np.eye(num_out_neighbors)))
    out_neighbors_onehot[None] = np.zeros(num_out_neighbors)

    in_neighbor_flow = [g.ep.flow[g.edge(u, v)] for u in in_neighbors]
    out_neighbor_flow = [g.ep.flow[g.edge(v, w)] for w in out_neighbors]

    in_neighbor_code = []
    out_neighbor_code = []
    split_idx = {}
    for i, (u, w) in enumerate(product(in_neighbors + [None], out_neighbors + [None])):
        if (u, w) == (None, None):
            # NOTE: I DON'T include a node-only atom,
            # because I'll return this always and I don't want
            # to double-count.
            # (None, None) is the last pair produced, so skipping it keeps
            # `i` equal to the row index of every atom that is kept.
            continue
        in_neighbor_code.append(in_neighbors_onehot[u])
        out_neighbor_code.append(out_neighbors_onehot[w])
        split_idx[i] = (u, w)

    if not split_idx:
        # `v` has no neighbours, so there is no atom to fit (np.stack would
        # raise on the empty lists); only the node-only split can be yielded.
        isolated_depth = np.array(g.vp.depth[v].a)
        if not np.allclose(isolated_depth, 0):
            yield Split(None, v, None), g.vp.length[v], isolated_depth
        return

    in_neighbor_code = np.stack(in_neighbor_code)
    out_neighbor_code = np.stack(out_neighbor_code)

    depth_row = g.vp.depth[v].a
    # Rows: samples. Columns: in-edge flows, depth of `v`, out-edge flows.
    obs = np.stack(in_neighbor_flow + [depth_row] + out_neighbor_flow).T
    # Rows: atoms. Every atom also contributes to the depth column.
    code = np.concatenate([
        in_neighbor_code,
        np.ones((in_neighbor_code.shape[0], 1)),
        out_neighbor_code
    ], axis=1)

    resid = obs
    atoms = []
    # Rows of `dictionary` stay zero until the matching atom is selected.
    dictionary = np.zeros_like(code)
    encoding = np.zeros((obs.shape[0], dictionary.shape[0]))
    for _ in range(code.shape[0]):
        dot = resid @ code.T
        next_atom = dot.sum(0).argmax()
        # TODO: Decide if I want to add the atom with the single element
        # or the atom with the largest _summed_ dot-product.
        if dot[:, next_atom].sum() <= eps:
            break
        atoms.append(next_atom)
        dictionary[next_atom] = code[next_atom]
        # Refit the non-negative encoding with the dictionary held fixed.
        encoding, _, _ = non_negative_factorization(obs, n_components=dictionary.shape[0], H=dictionary, update_H=False, alpha_W=0)
        resid = obs - encoding @ code

    for i in atoms:
        u, w = split_idx[i]
        yield Split(u, v, w), g.vp.length[v], encoding[:, i]
    remaining_depth = g.vp.depth[v] - encoding.sum(1)
    if not np.allclose(remaining_depth, 0):
        yield Split(None, v, None), g.vp.length[v], remaining_depth

# Materialise the splits proposed for vertex 0 of the toy graph.
[*splits_from_sparse_encoding4(g, 0, eps=0.5)]

In [None]:
### from functools import partial

# Experiment g12: strains A, B, C derived from a 1000-node ancestral path,
# seven samples, Poisson-distributed node depths, seven refinement rounds
# using splits_from_sparse_encoding4.
np.random.seed(1)

strain_length = 1000
num_mutations = strain_length // 2
ancestral_path = np.arange(strain_length)

# Each strain replaces `num_mutations` randomly chosen ancestral positions
# with a block of node ids that is private to that strain.
strainA_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainA_path = ancestral_path.copy()
strainA_path[strainA_mutations] = np.arange(strain_length + 0 * num_mutations, strain_length + 1 * num_mutations)

strainB_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainB_path = ancestral_path.copy()
strainB_path[strainB_mutations] = np.arange(strain_length + 1 * num_mutations, strain_length + 2 * num_mutations)

# Strain C additionally carries an inversion of [inversion_start, inversion_end).
strainC_mutations = np.random.choice(ancestral_path, size=num_mutations, replace=False)
strainC_path = ancestral_path.copy()
inversion_start = strain_length // 4
inversion_end = 3 * strain_length // 4
# Reverse exactly the slice being assigned. Taking
# [inversion_end : inversion_start : -1] instead would be shifted by one
# position: node `inversion_start` would vanish from the path and node
# `inversion_end` would be visited twice.
strainC_path[inversion_start:inversion_end] = strainC_path[inversion_start:inversion_end][::-1].copy()
strainC_path_without_mutations = strainC_path.copy()
strainC_path[strainC_mutations] = np.arange(strain_length + 2 * num_mutations, strain_length + 3 * num_mutations)

paths = dict(
    a=strainA_path,
    b=strainB_path,
    c=strainC_path,
    x=ancestral_path,
)
# Mean depth of each path in each of the seven samples: A only in sample 0,
# B only in sample 1, C at low depth in samples 2-6, ancestral path nowhere.
mean_depths = dict(
    a=[100, 0, 0, 0, 0, 0, 0],
    b=[0, 10, 0, 0, 0, 0, 0],
    c=[0, 0, 1, 1, 1, 1, 1],
    x=[0, 0, 0, 0, 0, 0, 0],
)


nnodes = max(itertools.chain(*paths.values())) + 1
nsamples = 7

expected_depths = np.zeros((nsamples, nnodes))
for p in paths:
    expected_depths[:, paths[p]] += np.outer(mean_depths[p], np.ones(len(paths[p])))

# Observed depths: Poisson draws with mean 10x the expected depth.
_depths = sp.stats.poisson(mu=expected_depths * 10).rvs()

g12 = new_graph_from_merged_paths(
    paths.values(),
    depths=_depths,
    lengths=np.array([1] * nnodes),
)
draw_graph(g12)
print(g12)
mutate_compress_all_unitigs(g12)

# Seven rounds of: add estimated flows, split nodes, re-compress unitigs.
for i in range(7):
    draw_graph(g12)
    print(g12)
    mutate_add_flows(g12, estimate_all_flows(g12, maxiter=1000, eps=1e-3))
    mutate_split_all_nodes(g12, partial(splits_from_sparse_encoding4, eps=20))
    mutate_compress_all_unitigs(g12)
draw_graph(g12)
print(g12)

In [None]:
# Per-vertex table for g12: degrees, length, one depth column per sample, and
# depth-times-length columns carrying the `_x_len` suffix.
vp = pd.DataFrame(dict(
    in_degree=g12.degree_property_map('in').a,
    out_degree=g12.degree_property_map('out').a,
    length=g12.vp.length.a,
))
# The depth matrix must come from g12, the same graph as the degrees and
# lengths above; rows are joined by vertex index.
vp = vp.join(pd.DataFrame(get_depth_matrix(g12).T, columns=['a', 'b', 'c', 'c1', 'c2', 'c3', 'c4']))
vp = vp.join(vp[['a', 'b', 'c', 'c1', 'c2', 'c3', 'c4']].multiply(vp.length, axis=0), rsuffix='_x_len')
vp.sort_values('length', ascending=False).head(10)

In [None]:
# Index labels of the ten rows of `vp` with the largest `c_x_len`.
most_clike_seqs = vp['c_x_len'].sort_values(ascending=False).index[:10]
vp.loc[most_clike_seqs]

In [None]:
import matplotlib.pyplot as plt

# Dot plots for g12: one row per query vertex, one column per reference path.
ref_list = ['a', 'b', 'c']
query_list = most_clike_seqs

# squeeze=False always returns a 2-D array of axes, even for a single row.
fig, axs = plt.subplots(
    len(query_list), len(ref_list),
    figsize=(3 * len(ref_list), 3 * len(query_list)),
    sharex=True, sharey=True, squeeze=False,
)

for (i, query_idx), (j, ref_idx) in product(enumerate(query_list), enumerate(ref_list)):
    ax = axs[i, j]
    ref = np.asanyarray(paths[ref_idx])
    # Query sequences are read from g12, the graph whose vertices index `vp`
    # in this section.
    query = np.asanyarray(g12.vp.sequence[query_idx])
    # Pad both arrays with -1 to a common length. Because of the +1 each
    # array receives at least one pad entry, and pad entries match each other
    # in the comparison below.
    length = max(len(ref), len(query)) + 1
    query = np.pad(query, (0, length - len(query)), constant_values=-1)
    ref = np.pad(ref, (0, length - len(ref)), constant_values=-1)
    # Equality matrix (query position x reference position), computed by
    # broadcasting instead of one Python callback per matrix cell.
    dotplot = query[:, np.newaxis] == ref[np.newaxis, :]
    ax.scatter(*np.where(dotplot.T), marker='o', s=1)
    ax.set_aspect('equal')
    # Label each panel with the rounded `vp` entry for this query and sample.
    ax.annotate(int(vp.loc[query_idx][ref_idx].round()), xy=(0.1, 0.9), xycoords='axes fraction', va='top')

for ref_idx, bottom_row_ax in zip(ref_list, axs[-1, :]):
    bottom_row_ax.set_xlabel(ref_idx)

for query_idx, left_column_ax in zip(query_list, axs[:, 0]):
    left_column_ax.set_ylabel(query_idx)