In [None]:
    %load_ext autoreload
%autoreload

In [None]:
import os as _os

# Run from the project root so relative paths resolve the same way everywhere.
# Raises KeyError if PROJECT_ROOT is not set.
_os.chdir(_os.environ['PROJECT_ROOT'])
print(_os.path.realpath(_os.curdir))

In [None]:
import graph_tool as gt
import graph_tool.draw
import numpy as np
import pandas as pd
import scipy.sparse
import scipy as sp
from collections import defaultdict
from tqdm import tqdm
from functools import reduce
import operator
from itertools import product
from collections import namedtuple
from sklearn.decomposition import non_negative_factorization
from functools import partial
import itertools
from warnings import warn
from collections import defaultdict
import difflib
import matplotlib.pyplot as plt

In [None]:
# Graph generation

def path_to_edgelist(path):
    """Return the consecutive ``(u, v)`` edges traversed by a vertex path.

    Parameters
    ----------
    path : sequence
        Vertex indices in visiting order (list, tuple or 1-D array; anything
        supporting ``len`` and slicing).

    Returns
    -------
    list of tuple
        ``[(path[0], path[1]), (path[1], path[2]), ...]``. A path with fewer
        than two vertices traverses no edges and yields ``[]``. This includes
        the empty path, which now returns ``[]`` instead of raising IndexError.
    """
    if len(path) < 2:
        return []
    return list(zip(path[:-1], path[1:]))


def single_stranded_graph_from_merged_paths(paths, lengths, depths):
    """Build a directed graph from the union of edges along ``paths``.

    Parameters
    ----------
    paths : iterable of sequences
        Vertex-index paths. Every consecutive pair becomes an edge, and
        duplicate edges across paths are merged.
    lengths : sequence of int
        Per-vertex length, indexed by vertex.
    depths : 2-D array-like
        Per-sample, per-vertex depth. It is written with ``set_2d_array``,
        and ``len(depths)`` is the number of samples, so it is presumably
        shaped ``(nsample, nvertex)``.

    Returns
    -------
    graph_tool.Graph
        The graph carries:

        * vertex property ``depth`` (vector<float>);
        * vertex property ``length`` (int);
        * vertex property ``sequence`` (object, ``[vertex_index]``);
        * graph property ``nsample``;
        * edge property ``flow``, initialised to ``[1] * nsample``.
    """
    depths = np.asarray(depths)
    g = gt.Graph()
    # Allocate one vertex per depth column before adding edges. Vertices that
    # only occur in single-vertex paths have no edges, so ``add_edge_list``
    # would never create them. The vertex count would then disagree with
    # ``depths``/``lengths``.
    g.add_vertex(depths.shape[1])
    all_edges = []
    for p in paths:
        all_edges.extend(path_to_edgelist(p))
    g.add_edge_list(set(all_edges))
    g.vp['depth'] = g.new_vp('vector<float>')
    g.vp.depth.set_2d_array(depths)
    g.vp['length'] = g.new_vp('int', lengths)
    g.gp['nsample'] = g.new_gp('int', len(depths))
    g.vp['sequence'] = g.new_vp('object', vals=[[k] for k in range(g.num_vertices())])
    g.ep['flow'] = g.new_ep('vector<float>', val=[1] * g.gp.nsample)
    return g


def single_stranded_graph_with_simulated_depth(paths, depths, length=None, scale_depth_by=1):
    """Simulate Poisson read depth along known paths and build the graph.

    Parameters
    ----------
    paths : dict
        Maps a path name to its sequence of (non-negative) vertex indices.
    depths : dict
        Maps each name in ``paths`` to its per-sample mean depth (one value
        per sample).
    length : mapping or sequence, optional
        Per-vertex length indexed by vertex. Defaults to 1 for every vertex.
    scale_depth_by : float
        Multiplier applied to every entry of ``depths``.

    Returns
    -------
    graph_tool.Graph
        Built by ``single_stranded_graph_from_merged_paths``. Vertex depths
        are drawn from ``Poisson(expected depth)`` using NumPy's global RNG.
    """
    if length is None:
        length = defaultdict(lambda: 1)

    nvertices = max(itertools.chain(*paths.values())) + 1
    nsamples = max(len(x) for x in depths.values())

    vertex_length = np.array([length[i] for i in range(nvertices)])

    expected_depths = np.zeros((nsamples, nvertices))
    for p, path in paths.items():
        # Count every visit. A vertex that one path traverses k times receives
        # k times that path's depth. Fancy-index ``+=`` would silently apply
        # repeated indices only once.
        visits = np.bincount(np.asarray(path, dtype=int), minlength=nvertices)
        expected_depths += np.outer(np.array(depths[p]) * scale_depth_by, visits)

    # TODO: Consider using nbinom
    # # See docs for sp.stats.nbinom
    # sigma_sq = expected_depths + dispersion * expected_depths**2
    # p = expected_depths / sigma_sq
    # n = expected_depths**2 / (sigma_sq - expected_depths)
    _depths = sp.stats.poisson(mu=expected_depths).rvs()

    g = single_stranded_graph_from_merged_paths(
        paths.values(),
        depths=_depths,
        lengths=vertex_length,
    )
    return g

In [None]:
# Graph statistics

def depth_matrix(g, vs=None, samples=None):
    """Return the per-sample depth of selected vertices as a 2-D array.

    Rows are samples and columns are vertices.

    Parameters
    ----------
    g : graph_tool.Graph
        Must carry the ``depth`` vertex property and the ``nsample`` graph
        property.
    vs : index array, optional
        Vertex columns to return. Defaults to ``g.get_vertices()``.
    samples : index array, optional
        Sample rows to return. Defaults to ``0 .. nsample-1``.
    """
    vertex_idx = g.get_vertices() if vs is None else vs
    sample_idx = np.arange(g.gp.nsample) if samples is None else samples
    full = g.vp.depth.get_2d_array(sample_idx)
    return full[:, vertex_idx]


def total_length_x_depth(g):
    """Sum of ``depth * length`` over every sample and every vertex of ``g``.

    This is the graph's total length-weighted depth. The notebook compares it
    before and after graph transformations.
    """
    weighted = depth_matrix(g) * g.vp.length.a
    return weighted.sum()


def edit_ratio(ref, query):
    """Fraction of ``query`` covered by blocks that also match in ``ref``.

    ``SequenceMatcher.ratio()`` equals ``2*M / (len(ref) + len(query))`` for
    ``M`` matched elements. Rescaling by ``(len(ref) + len(query)) /
    (2*len(query))`` therefore gives ``M / len(query)``. The matcher's
    default autojunk heuristic applies. Raises ZeroDivisionError for an
    empty ``query``.
    """
    matcher = difflib.SequenceMatcher(a=ref, b=query)
    combined_length = len(matcher.a) + len(matcher.b)
    return matcher.ratio() * combined_length / (2 * len(matcher.b))


def vertex_description(g, refs=None, depth_prefix=None):
    """Tabulate per-vertex degree, length and per-sample depth.

    Parameters
    ----------
    g : graph_tool.Graph
    refs : optional
        Accepted for backward compatibility. Currently unused.
    depth_prefix : str, optional
        If given, depth columns are renamed ``f"{depth_prefix}{i}"``, e.g.
        ``"d"`` gives ``d0, d1, ...``. By default they keep the integer labels
        ``0 .. nsample-1``; callers may rely on these when renaming.

    Returns
    -------
    pandas.DataFrame
        One row per vertex, joined on position. Columns are ``in_degree``,
        ``out_degree`` and ``length``, plus one depth column per sample.
    """
    vertex = pd.DataFrame(dict(
        in_degree=g.degree_property_map('in').a,
        out_degree=g.degree_property_map('out').a,
        length=g.vp.length.a,
    ))
    depth = pd.DataFrame(depth_matrix(g).T)
    # ``DataFrame.rename`` returns a new frame, so its result must be kept
    # for the rename to take effect.
    if depth_prefix is not None:
        depth = depth.rename(columns=lambda i: f"{depth_prefix}{i}")
    return vertex.join(depth)

def scale_ep(ep, maximum=2):
    """Linearly rescale an edge property so its largest value becomes ``maximum``.

    Intended for drawing, e.g. ``edge_pen_width``. A property that is empty,
    or whose largest value is 0, is returned as all zeros. Previously this
    produced NaN (0/0) or raised ValueError.

    Returns
    -------
    A new ``float`` edge property on the same graph.
    """
    g = ep.get_graph()
    values = np.asarray(ep.a, dtype=float)
    largest = values.max() if values.size else 0
    if largest == 0:
        scaled = np.zeros_like(values)
    else:
        scaled = values * maximum / largest
    return g.new_edge_property('float', vals=scaled)

def scale_vp(vp, maximum=10):
    """Linearly rescale a vertex property so its largest value becomes ``maximum``.

    Intended for drawing, e.g. ``vertex_size``. A property that is empty, or
    whose largest value is 0, is returned as all zeros. Previously this
    produced NaN (0/0) or raised ValueError.

    Returns
    -------
    A new ``float`` vertex property on the same graph.
    """
    g = vp.get_graph()
    values = np.asarray(vp.a, dtype=float)
    largest = values.max() if values.size else 0
    if largest == 0:
        scaled = np.zeros_like(values)
    else:
        scaled = values * maximum / largest
    return g.new_vertex_property('float', vals=scaled)

In [None]:
# Visualization

def draw_graph(g, output_size=(300, 300), ink_scale=0.8, **kwargs):
    """Draw ``g`` with ``graph_tool.draw.graph_draw`` using notebook-sized defaults.

    By default vertices are coloured along a 0-1 gradient by index and
    labelled with their index. Any keyword in ``kwargs`` overrides these
    defaults, and all keywords are forwarded to ``graph_draw``.

    Returns whatever ``graph_draw`` returns. Callers in this notebook reuse
    it as ``pos=`` in later calls.
    """
    n_colors = max(g.get_vertices()) + 1
    fill = g.new_vertex_property('float', vals=np.linspace(0, 1, num=n_colors))
    options = {'vertex_fill_color': fill, 'vertex_text': g.vertex_index}
    options.update(kwargs)
    return gt.draw.graph_draw(g, output_size=output_size, ink_scale=ink_scale, **options)


def dotplot(pathA, pathB, ax=None, **scatter_kws):
    """Dot plot of the positions at which two vertex paths share a vertex.

    A point is drawn at ``(j, i)`` for every pair with ``pathA[i] == pathB[j]``.
    ``pathB`` therefore runs along the x-axis and ``pathA`` along the y-axis.

    Parameters
    ----------
    pathA, pathB : sequence
        Vertex sequences to compare.
    ax : matplotlib Axes, optional
        Target axes. Defaults to ``plt.gca()``.
    **scatter_kws
        Forwarded to ``ax.scatter``. These override the defaults
        ``marker='o', s=1``.

    Returns
    -------
    The axes drawn on.
    """
    if ax is None:
        ax = plt.gca()
    pathA = np.asanyarray(pathA)
    pathB = np.asanyarray(pathB)
    # Vectorised pairwise equality. No padding is needed: only real positions
    # can match, and padding with a sentinel could only add spurious points.
    match = np.equal.outer(pathA, pathB)
    ax.scatter(*np.where(match.T), **(dict(marker='o', s=1) | scatter_kws))
    ax.set_aspect('equal')
    return ax

In [None]:
# Unitigs

def edge_has_no_siblings(g):
    """Bool edge property: True where an edge shares neither endpoint with another edge.

    For an edge ``s -> t`` this is ``in_degree(t) <= 1 and out_degree(s) <= 1``.
    No other edge enters the same target, and no other edge leaves the same
    source.
    """
    vertices = g.get_vertices()
    in_deg = g.new_vertex_property('int', vals=g.get_in_degrees(vertices))
    out_deg = g.new_vertex_property('int', vals=g.get_out_degrees(vertices))
    target_in_deg = gt.edge_endpoint_property(g, in_deg, 'target')
    source_out_deg = gt.edge_endpoint_property(g, out_deg, 'source')
    no_siblings = (target_in_deg.a <= 1) & (source_out_deg.a <= 1)
    return g.new_edge_property('bool', no_siblings)


def vertex_does_not_have_both_multiple_in_and_multiple_out(g):
    """Bool vertex property: False only for vertices with >1 in-edges AND >1 out-edges."""
    vertices = g.get_vertices()
    branches_in = g.get_in_degrees(vertices) > 1
    branches_out = g.get_out_degrees(vertices) > 1
    return g.new_vertex_property('bool', vals=~(branches_in & branches_out))


def label_maximal_unitigs(g):
    """Assign unitig indices to vertices in maximal unitigs.

    Edges that share a source or a target with another edge are hidden. So
    are vertices with more than one in-edge AND more than one out-edge. The
    connected components of the undirected view of what remains are then
    labelled.

    Returns
    -------
    labels : vertex property (int)
        Component label per vertex. It is created with default -1, presumably
        so that hidden vertices (not written by ``label_components``) can be
        told apart from component 0.
    counts : array
        The component histogram returned by ``label_components``.
    g_filt : GraphView
        The directed filtered view.
    """
    keep_edge = edge_has_no_siblings(g)
    # A vertex with both multiple in- and multiple out-edges cannot lie inside
    # a longer maximal unitig. Dropping it alongside the sibling edges may also
    # speed up component labelling. Sibling edges are flagged on the full
    # graph before these vertices are dropped.
    # TODO: Double-check whether dropping these vertices has implications for
    # the "unitig-ness" of their neighbours.
    keep_vertex = vertex_does_not_have_both_multiple_in_and_multiple_out(g)
    g_filt = gt.GraphView(g, efilt=keep_edge, vfilt=keep_vertex, directed=True)
    undirected_view = gt.GraphView(g_filt, directed=False)
    labels = g.new_vertex_property('int', val=-1)
    labels, counts = gt.topology.label_components(undirected_view, vprop=labels)
    return labels, counts, g_filt


def is_simple_cycle(g, vfilt):
    """Return True if the vertices selected by ``vfilt`` form exactly one simple cycle.

    The check restricts ``g`` to ``vfilt`` and takes its first vertex ``s``.
    It then requires that ``all_paths(s, s)`` yields exactly one path and that
    this path visits every selected vertex. Raises IndexError if ``vfilt``
    selects no vertices.
    """
    view = gt.GraphView(g, vfilt=vfilt)
    members = list(view.iter_vertices())
    start = members[0]
    loops = list(gt.topology.all_paths(view, start, start))
    return len(loops) == 1 and set(loops[0]) == set(members)


def maximal_unitigs(g):
    "Generate maximal unitigs as lists of vertices"
    # (1a) Filter edges
    # (1b) Label unitigs
    labels, counts, g_filt = label_maximal_unitigs(g)
    # (2) Find every node in the filtered graph without in-edges (these are the origins)
    is_origin = g_filt.new_vp('bool', g_filt.get_in_degrees(g_filt.get_vertices()) == 0)
    # (3) Find every node in the filtered graph without out-edges (these are the termina)
    is_terminus = g_filt.new_vp('bool', g_filt.get_out_degrees(g_filt.get_vertices()) == 0)
    # (4) Iter through the labels.
    for i, c in enumerate(counts):
        # (a) For each, select the subgraph for that unitig.
        vfilt = (labels.a == i)
        assert vfilt.sum() == c
        subgraph = gt.GraphView(g_filt, vfilt=vfilt)
        # (b) Identify the origin and terminus nodes in the path
        vs = subgraph.get_vertices()
        origin = gt.GraphView(subgraph, vfilt=is_origin).get_vertices()
        terminus = gt.GraphView(subgraph, vfilt=is_terminus).get_vertices()
        if (len(origin) == 0) and (len(terminus) == 0):
            origin, terminus = vs[0], vs[0]
        elif (len(origin) == 1) and (len(terminus) == 1):
            origin, terminus = origin[0], terminus[0]
        else:
            raise AssertionError("If there are multiple origins or termina, then it's not a unitig.")
        # (c) Trace the route from the origin to the terminus (`graph_tool.topology.all_paths`)
        unitig = list(gt.topology.all_paths(subgraph, origin, terminus))
        if len(unitig) == 0:
            continue  # Maximal unitig of length 1. No-op.
        assert len(unitig) == 1, "If there are multiple paths from origin to terminus, then it's not a unitig."
        unitig = unitig[0]
        # (d) Ask if it's a cycle.
        is_cycle = is_simple_cycle(g, vfilt)
        if is_cycle and (len(unitig) > 1) and (unitig[0] == unitig[-1]):
            unitig = unitig[:-1]
            assert len(set(unitig)) == len(unitig)
        # (e) Yield the route and whether it's a cycle.
        yield list(unitig), is_cycle
        
        
def list_unitig_neighbors(g, vs):
    """Return the in- and out-neighbours of a unitig path that lie outside it.

    Parameters
    ----------
    g : graph_tool.Graph
    vs : iterable of vertex indices
        The unitig members.

    Returns
    -------
    (list, list)
        ``(in_neighbors, out_neighbors)``, each without duplicates and
        excluding members of ``vs``. Order is unspecified.

    This runs in time linear in the total degree of ``vs``; repeated list
    concatenation was quadratic. An empty ``vs`` yields ``([], [])`` rather
    than raising TypeError.
    """
    vs = list(vs)
    members = set(vs)
    in_neighbors = {u for v in vs for u in g.iter_in_neighbors(v)}
    out_neighbors = {w for v in vs for w in g.iter_out_neighbors(v)}
    return list(in_neighbors - members), list(out_neighbors - members)


def mutate_add_compressed_unitig_vertex(g, vs, is_cycle, drop_vs=False):
    """Add one vertex standing in for the unitig ``vs`` and detach the members.

    The new vertex receives:

    * an in-edge from every in-neighbour of ``vs`` outside ``vs``, and an
      out-edge to every out-neighbour outside ``vs``;
    * ``length``: the sum of the member lengths;
    * ``sequence``: the member sequences concatenated in ``vs`` order;
    * ``depth``: the length-weighted mean of the member depths, so the
      per-sample total of ``depth * length`` is unchanged (asserted below);
    * a self-loop when ``is_cycle`` is true.

    Members are cleared of all their edges but are *not* removed from ``g``.
    The caller removes them afterwards.

    Parameters
    ----------
    g : graph_tool.Graph
        Modified in place.
    vs : sequence of int
        Unitig vertices in path order.
    is_cycle : bool
        Whether the unitig is a simple cycle.
    drop_vs : bool
        Unused; kept for signature compatibility.

    Returns
    -------
    g
    """
    v = int(g.add_vertex())
    in_neighbors, out_neighbors = list_unitig_neighbors(g, vs)
    g.add_edge_list((neighbor, v) for neighbor in in_neighbors)
    g.add_edge_list((v, neighbor) for neighbor in out_neighbors)

    member_lengths = g.vp.length.a[vs]
    g.vp.length.a[v] = member_lengths.sum()
    # Linear-time concatenation; repeated ``list + list`` is quadratic.
    g.vp.sequence[v] = [k for u in vs for k in g.vp.sequence[u]]

    # Per-sample sum of depth * length over the members.
    length_x_depth = (depth_matrix(g, vs) * member_lengths).sum(1)
    g.vp.depth[v] = length_x_depth / g.vp.length.a[v]
    assert np.allclose(
        length_x_depth,
        (depth_matrix(g, [v]) * g.vp.length.a[v]).sum(1),
    )

    if is_cycle:
        g.add_edge(v, v)
    for old_v in vs:
        g.clear_vertex(old_v)
    return g


def mutate_compress_all_unitigs(g):
    """Replace every maximal unitig of ``g`` by a single compressed vertex, in place.

    All unitigs are enumerated against the unmodified graph before any change
    is made. ``maximal_unitigs`` is a generator that holds views and property
    maps of ``g``, so consuming it lazily while adding and clearing vertices
    would compute later unitigs on a partially rewritten graph.

    The original members are removed at the end with ``fast=True``. Per the
    graph-tool docs, fast removal swaps the last vertex into each freed slot,
    so the remaining vertex indices may change.

    Returns
    -------
    g
    """
    unitig_list = list(maximal_unitigs(g))
    all_vs = []
    for vs, is_cycle in unitig_list:
        # TODO: If len(vs) == 1, this is effectively a no-op and can be dropped.
        mutate_add_compressed_unitig_vertex(g, vs, is_cycle)
        all_vs.extend(vs)
    g.remove_vertex(set(all_vs), fast=True)
    return g

In [None]:
# Flows

def estimate_flow(f0, d, weight=None, eps=1e-2, maxiter=100):
    """Iteratively adjust a flow matrix so each vertex's row and column totals match ``d``.

    Parameters
    ----------
    f0 : 2-D array or scipy sparse array, shape (n, n)
        Initial flow matrix. Row sums are treated as one side's totals
        ("out") and column sums (row sums of the transpose) as the other
        ("in").
    d : 1-D array, shape (n,)
        Target depth for each vertex.
    weight : 1-D array, shape (n,), optional
        Per-vertex weight; defaults to ones. Each step spreads every vertex's
        out-/in-error over its entries in proportion to their current share.
        For entry (i, j), vertex i's out-error share and vertex j's in-error
        share are then averaged with weights ``w_i/(w_i+w_j)`` and
        ``w_j/(w_i+w_j)``.
    eps : float
        Stop when the relative decrease in loss falls below ``eps``. This
        includes a loss increase, in which case the last (worse) iterate is
        returned.
    maxiter : int
        Maximum number of iterations. A warning is issued if the ``eps``
        criterion was not met within it.

    Returns
    -------
    The adjusted flow matrix.

    The loss is the sum of squared differences between each vertex's row and
    column totals and ``d``. A loss of exactly 0 stops immediately.
    """
    if weight is None:
        weight = np.ones_like(d)

    loss_hist = [np.inf]
    # Defined up front so the warning below also works when maxiter == 0.
    loss_ratio = np.nan
    f = f0
    for _ in range(maxiter):
        f_out = f
        f_total_out = f_out.sum(1)
        d_error_out = f_total_out - d

        f_in = f_out.T
        f_total_in = f_in.sum(1)
        d_error_in = f_total_in - d

        loss = np.square(d_error_out).sum() + np.square(d_error_in).sum()
        loss_hist.append(loss)
        if loss == 0:
            # Exact fit. Continuing would compute 0/0 = nan ratios, which never
            # satisfy ``< eps``, and the loop would run until maxiter.
            break
        previous_loss = loss_hist[-2]
        # On the first step there is no finite previous loss. Use nan so the
        # step never counts as converged, as inf/inf would, without the
        # RuntimeWarning.
        loss_ratio = (previous_loss - loss) / previous_loss if np.isfinite(previous_loss) else np.nan
        if loss_ratio < eps:
            break

        # Share of each vertex's error assigned to each of its entries, in
        # proportion to current flow. A zero total maps 1/0 -> 1.
        allocation_out = f_out.T * np.nan_to_num(1 / f_total_out, posinf=1, nan=0)
        allocated_d_error_out = (allocation_out * d_error_out).T
        allocation_in = f_in.T * np.nan_to_num(1 / f_total_in, posinf=1, nan=0)
        allocated_d_error_in = (allocation_in * d_error_in).T

        # The final step is calculated as an average of the in and out error,
        # weighted by the node weight.
        inv_weight = 1 / weight
        mean_allocated_d_error = (
            ((allocated_d_error_in * inv_weight).T + (allocated_d_error_out * inv_weight))
            * (1 / (inv_weight.reshape((-1, 1)) + (inv_weight.reshape((1, -1)))))
        )

        f = (f_out - mean_allocated_d_error)
    else:
        warn(f"loss_ratio < eps ({eps}) not achieved in maxiter ({maxiter}) steps. Final loss_ratio={loss_ratio}. Final loss={loss_hist[-1]}.")
    return f


def estimate_all_flows(g, eps=1e-3, maxiter=1000, use_weights=True):
    """Run ``estimate_flow`` independently for every sample of ``g``.

    The graph's adjacency matrix, as a CSR sparse array, is the shared
    starting point for each sample. Vertex lengths are used as weights unless
    ``use_weights`` is False.

    Returns
    -------
    list
        One flow matrix per sample, in sample order.
    """
    weight = g.vp.length.a if use_weights else None
    f0 = sp.sparse.csr_array(gt.spectral.adjacency(g))
    per_sample_depth = depth_matrix(g)
    return [
        estimate_flow(f0, per_sample_depth[s], weight=weight, eps=eps, maxiter=maxiter)
        for s in range(g.gp.nsample)
    ]


def mutate_add_flows(g, flows):
    """Attach per-sample flow matrices to ``g`` as the vector edge property ``flow``.

    Parameters
    ----------
    g : graph_tool.Graph
        Modified in place; ``g.ep['flow']`` is replaced.
    flows : sequence of matrices
        One ``(n, n)`` matrix per sample, e.g. from ``estimate_all_flows``.
        Edge ``(source, target)`` reads entry ``[target, source]``.

    Returns
    -------
    g
    """
    edge_array = g.get_edges()

    def _sample_property(flow_matrix):
        # One float edge property per sample. The transposed lookup presumably
        # mirrors the orientation of ``gt.spectral.adjacency`` used to build
        # the initial flows -- confirm before changing.
        prop = g.new_edge_property('float', val=0)
        for source, target in edge_array:
            prop[g.edge(source, target)] = flow_matrix[target, source]
        return prop

    g.ep['flow'] = gt.group_vector_property([_sample_property(f) for f in flows])
    return g

In [None]:
# Node splitting

# A proposed split of vertex ``v``: the new copy keeps in-neighbour ``u`` and
# out-neighbour ``w``. Either may be None when that side is not anchored.
Split = namedtuple('Split', 'u v w')


def splits_from_sparse_encoding(g, v, threshold=1.):
    """Propose splits of vertex ``v`` by a sparse non-negative encoding of its flows.

    Each candidate split pairs one in-neighbour ``u`` (or None) with one
    out-neighbour ``w`` (or None). Its code is the row
    ``[onehot(u) | 1 | onehot(w)]``, L2-normalised, over the columns
    ``[in-neighbour flows | v's depth | out-neighbour flows]``.

    The observations ``obs`` use the same column layout, with one row per
    sample, built from ``g.ep.flow`` and ``g.vp.depth``. They are explained
    greedily by adding one code at a time (group matching pursuit). After each
    addition, non-negative coefficients are refitted with
    ``non_negative_factorization`` against the selected codes.

    Parameters
    ----------
    g : graph_tool.Graph
        Needs ``ep.flow``, ``vp.depth`` and ``vp.length``.
    v : vertex
        The vertex to split.
        NOTE(review): when called via ``splits_for_all_vertices`` this is a
        ``Vertex`` object (from ``g.vertices()``). ``Split.v`` then holds a
        Vertex while ``u``/``w`` come from ``get_*_neighbors``. Confirm that
        keys mixing these types match in downstream dict lookups.
    threshold : float
        Pursuit stops when the best remaining code's dot product, summed
        over samples, is <= threshold.

    Yields
    ------
    (Split, length, depth)
        First, one triple per selected (u, w) atom. Its per-sample depth is
        the coefficient on the *unnormalised* code, whose ``v`` entry is 1.
        Then ``Split(None, v, None)`` with the depth not assigned to those
        atoms, if that remainder is not ~0.
    """
    # Compile tables
    # NOTE(review): ``in_neighbors_label``/``out_neighbors_label`` are never used.
    in_neighbors = list(sorted(g.get_in_neighbors(v)))
    num_in_neighbors = len(in_neighbors)
    in_neighbors_label = {k: v for k, v in enumerate(in_neighbors)}
    in_neighbors_onehot = {k: v for k, v in zip(in_neighbors, np.eye(num_in_neighbors))}
    in_neighbors_onehot[None] = np.zeros(num_in_neighbors)

    out_neighbors = list(sorted(g.get_out_neighbors(v)))
    num_out_neighbors = len(out_neighbors)
    out_neighbors_label = {k: v for k, v in enumerate(out_neighbors)}
    out_neighbors_onehot = {k: v for k, v in zip(out_neighbors, np.eye(num_out_neighbors))}
    out_neighbors_onehot[None] = np.zeros(num_out_neighbors)

    # Per-sample flow vectors on each incident edge.
    in_neighbor_flow = []
    for u in in_neighbors:
        in_neighbor_flow.append(g.ep.flow[g.edge(u, v)])

    out_neighbor_flow = []
    for w in out_neighbors:
        out_neighbor_flow.append(g.ep.flow[g.edge(v, w)])

    # One candidate per (u, w) pair, including None on either side.
    in_neighbor_code = []
    out_neighbor_code = []
    split_idx = {}
    for i, (u, w) in enumerate(product(in_neighbors + [None], out_neighbors + [None])):
        if (u, w) == (None, None):
            # The (None, None) "naked vertex" candidate stays in the
            # dictionary so its dot product can compete with the real splits.
            # Selecting it stops the pursuit (below). Its coefficient is
            # folded into the final remainder split.
            naked_vertex_idx = i
            pass
        in_neighbor_code.append(in_neighbors_onehot[u])
        out_neighbor_code.append(out_neighbors_onehot[w])
        split_idx[i] = (u, w)
    in_neighbor_code = np.stack(in_neighbor_code)
    out_neighbor_code = np.stack(out_neighbor_code)

    # obs: rows are samples; columns are in-flows, v's depth, then out-flows.
    depth_row = g.vp.depth[v].a
    obs = np.stack(in_neighbor_flow + [depth_row] + out_neighbor_flow).T
    unnormalized_code = np.concatenate([
        in_neighbor_code,
        np.ones((in_neighbor_code.shape[0], 1)),
        out_neighbor_code
    ], axis=1)
    code_magnitude = np.sqrt(np.square(unnormalized_code).sum(1, keepdims=True))
    code = unnormalized_code / code_magnitude


    # Group Matching Pursuit (GMP)
    # Inspired by https://arxiv.org/pdf/1812.10538.pdf
    # ``dictionary`` holds only the selected codes; unselected rows stay zero.
    # NOTE(review): ``loss`` and ``naked_vertex_dot`` are computed but unused.
    resid = obs
    atoms = []
    dictionary = np.zeros_like(code)
    normalized_encoding = np.zeros((obs.shape[0], dictionary.shape[0]))
    for _ in range(code.shape[0]):
        loss = np.abs(resid).sum()
        dot = resid @ code.T
        # TODO: Decide how to decide atoms
        # next_atom = np.square(dot).sum(0).argmax()
        # Column (candidate) of the single largest dot product over all samples.
        next_atom = dot.argmax() % code.shape[0]
        naked_vertex_dot = dot[:, naked_vertex_idx]
        if dot[:, next_atom].sum() <= threshold:
            break
        if next_atom in atoms:
            break
        if next_atom == naked_vertex_idx:
            break  # TODO: Decide if this is a good stopping criterion.
        atoms.append(next_atom)
        dictionary[atoms[-1]] = code[atoms[-1]]
        # Refit non-negative coefficients with the dictionary held fixed.
        normalized_encoding, _, _ = non_negative_factorization(obs, n_components=dictionary.shape[0], H=dictionary, update_H=False, alpha_W=0)
        resid = obs - normalized_encoding @ code


    # Iterate through atoms as splits.
    # Rescale coefficients from normalised to unnormalised codes, so each
    # column is the depth assigned to that candidate.
    encoding = normalized_encoding / code_magnitude.T
    for i in atoms:
        if i == naked_vertex_idx:
            # Don't return the naked vertex here. It'll be returned
            # later.
            continue
        u, w = split_idx[i]
        yield Split(u, v, w), g.vp.length[v], encoding[:, i]
    # Remaining depth must also include any depth assigned to the naked encoding.
    remaining_depth = g.vp.depth[v] - encoding.sum(1) + encoding[:, naked_vertex_idx]
    if not np.allclose(remaining_depth, 0):
        yield Split(None, v, None), g.vp.length[v], remaining_depth
        
        
def build_tables_from_splits(split_list, start_idx):
    """Index split vertices and group them for edge construction.

    Each ``(split, length, depth)`` entry has a split ``(u, v, w)``: an
    upstream neighbour, the vertex being split, and a downstream neighbour.
    Each split is assigned a new vertex index, counting up from ``start_idx``
    in list order.

    Returns
    -------
    split_idx : dict
        Maps each split to its new vertex index.
    upstream : defaultdict(list)
        ``(v, w)`` -> splits of ``v`` that keep downstream neighbour ``w``.
        A split of ``w`` looks these up to find the split vertices feeding it.
    downstream : defaultdict(list)
        ``(u, v)`` -> splits of ``v`` that keep upstream neighbour ``u``.
    lengths, depths : numpy.ndarray
        Per-split length and depth, in list order.

    Raises
    ------
    AssertionError
        If ``split_list`` contains the same split more than once.
    """
    new_index = {}
    by_vertex_and_child = defaultdict(list)
    by_parent_and_vertex = defaultdict(list)
    split_lengths = []
    split_depths = []
    for index, (split, split_length, split_depth) in enumerate(split_list, start=start_idx):
        parent, vertex, child = split
        new_index[split] = index
        by_vertex_and_child[(vertex, child)].append(split)
        by_parent_and_vertex[(parent, vertex)].append(split)
        split_depths.append(split_depth)
        split_lengths.append(split_length)
    assert len(split_list) == len(new_index)
    return (
        new_index,
        by_vertex_and_child,
        by_parent_and_vertex,
        np.array(split_lengths),
        np.array(split_depths),
    )
        
        
def new_edges_from_splits(split_list, split_idx, upstream, downstream, start_idx):
    """Yield ``(source, target)`` edges incident to the new split vertices.

    For each split ``(u, v, w)`` with new index ``x``, this yields, in order:

    * ``(u, x)`` if ``u`` is not None;
    * ``(y, x)`` for every split ``y`` of ``u`` that kept ``v`` downstream;
    * ``(x, w)`` if ``w`` is not None;
    * ``(x, z)`` for every split ``z`` of ``w`` that kept ``v`` upstream.

    If adjacent vertices' splits do not reciprocate, no split-to-split edge is
    produced. Edges may repeat; callers deduplicate. ``start_idx`` is unused.
    Lookups on ``upstream``/``downstream`` (defaultdicts) insert empty entries
    for missing keys.
    """
    for split, _length, _depth in split_list:
        u_old, v_old, w_old = split
        new_v = split_idx[split]

        if u_old is not None:
            yield (u_old, new_v)
        for parent_split in upstream[(u_old, v_old)]:
            parent_v = split_idx[parent_split]
            if parent_v is not None:
                yield (parent_v, new_v)

        if w_old is not None:
            yield (new_v, w_old)
        for child_split in downstream[(v_old, w_old)]:
            child_v = split_idx[child_split]
            if child_v is not None:
                yield (new_v, child_v)
            
            
def mutate_apply_splits(g, split_list):
    """Materialise ``split_list`` in ``g`` (in place) and drop the split parents.

    The steps are:

    1. Append one new vertex per split, with indices counting up from
       ``g.num_vertices(ignore_filter=True)``.
    2. Add the edges produced by ``new_edges_from_splits``.
    3. Copy each split's length and depth, and its parent's ``sequence``,
       onto the new vertex.
    4. Remove every parent vertex (``split.v``) with ``fast=True``.

    Parameters
    ----------
    g : graph_tool.Graph
    split_list : list of (Split, length, depth) triples
        As yielded by a split function such as ``splits_from_sparse_encoding``.

    Returns
    -------
    g
    """
    start_idx = g.num_vertices(ignore_filter=True)
    split_idx, upstream, downstream, lengths, depths = (
        build_tables_from_splits(split_list, start_idx=start_idx)
    )
    edges_to_add = list(set(new_edges_from_splits(
        split_list, split_idx, upstream, downstream, start_idx
    )))

    # Add the split vertices explicitly before the edges. A split vertex with
    # no new edges would not be created by ``add_edge_list``, especially when
    # it has the highest index, leaving the length/depth arrays below too
    # short (IndexError). ``add_vertex(n)`` returns an unconsumed generator,
    # but the vertices are created immediately; the assert below checks this.
    g.add_vertex(n=max(split_idx.values()) - max(g.get_vertices()))
    g.add_edge_list(set(edges_to_add))
    assert max(g.get_vertices()) == max(split_idx.values())

    new_vertices = np.arange(len(lengths)) + start_idx
    g.vp.length.a[new_vertices] = lengths
    new_depth = depth_matrix(g)
    new_depth[:, new_vertices] = depths.T
    g.vp.depth.set_2d_array(new_depth)
    for split, _, _ in split_list:
        g.vp.sequence[split_idx[split]] = g.vp.sequence[split.v]
    vertices_to_drop = set(split.v for (split, _, _) in split_list)
    g.remove_vertex(vertices_to_drop, fast=True)
    return g


def splits_for_all_vertices(g, split_func):
    """Yield every split proposed by ``split_func(g, v)`` for branching vertices.

    Vertices with at most one in-edge and at most one out-edge are skipped.
    Everything ``split_func`` yields for the remaining vertices is passed
    through in vertex order.
    """
    for vertex in g.vertices():
        branching = vertex.in_degree() >= 2 or vertex.out_degree() >= 2
        if branching:
            yield from split_func(g, vertex)


def mutate_split_all_nodes(g, split_func):
    """Split every branching vertex of ``g`` according to ``split_func``, in place.

    All splits are collected before ``g`` is modified. If there are none,
    ``g`` is returned untouched; otherwise ``mutate_apply_splits`` is applied.
    """
    # TODO: Vertices with <= 1 local path split into a trivial copy of
    # themselves pointing at their neighbours; those no-op splits could be
    # dropped.
    proposed = [*splits_for_all_vertices(g, split_func)]
    if not proposed:
        return g
    return mutate_apply_splits(g, split_list=proposed)

In [None]:
# Toy example: two paths sharing vertex 2 (an "X"-shaped graph), used to
# sanity-check flow estimation before the larger simulation below.
np.random.seed(1)
gt.seed_rng(2)

paths = [
    [0, 2, 3],
    [1, 2, 4],
]

nnodes = max(itertools.chain(*paths)) + 1

depths = np.array([
    [4, 4, 4, 4, 4],
])
nsamples = depths.shape[0]

_g = single_stranded_graph_from_merged_paths(
    paths,
    depths=depths,
    lengths=[1, 1, 10, 1, 1],
)

figsize = 150
pos = draw_graph(_g, output_size=(figsize, figsize))
draw_graph(_g, pos=pos, vertex_text=_g.vp.length, output_size=(figsize, figsize))

# Inputs for calling `estimate_flow` directly on sample 0.
f0 = sp.sparse.csr_array(gt.spectral.adjacency(_g))
d = _g.vp.depth.get_2d_array([0])[0]

# Flow estimation uses `estimate_all_flows`, defined in the "Flows" cell.
mutate_add_flows(_g, estimate_all_flows(_g))
draw_graph(
    _g,
    pos=pos,
    vertex_text=gt.ungroup_vector_property(_g.vp.depth, pos=[0])[0],
    edge_pen_width=gt.ungroup_vector_property(_g.ep.flow, pos=[0])[0],
    output_size=(figsize, figsize),
)
estimate_all_flows(_g)[0].toarray().round(2)

In [None]:
# Simulated community:
# - an ancestral path `x`;
# - three strains (a, b, c), each replacing half of the ancestral vertices
#   with strain-specific ones;
# - an "error" path alternating brand-new vertices with random existing ones.
np.random.seed(1)
gt.seed_rng(1)

strain_length = 500
num_mutations = strain_length // 2
ancestral_path = np.arange(strain_length)

# Same RNG call order as drawing strain a, then b, then c.
strain_paths = {}
for k, name in enumerate(['a', 'b', 'c']):
    mutated_positions = np.random.choice(ancestral_path, size=num_mutations, replace=False)
    strain_path = ancestral_path.copy()
    strain_path[mutated_positions] = np.arange(
        strain_length + k * num_mutations, strain_length + (k + 1) * num_mutations
    )
    strain_paths[name] = strain_path

num_errors = 100
# Even positions: `num_errors` new vertices; odd positions: random existing vertices.
error_path = np.arange(strain_length + 3 * num_mutations, strain_length + 3 * num_mutations + num_errors, step=0.5)
error_path[1::2] = np.random.choice(np.arange(strain_length + 3 * num_mutations), size=num_errors)
error_path = list(error_path.astype(int))

paths = dict(
    x=ancestral_path,
    a=strain_paths['a'],
    b=strain_paths['b'],
    c=strain_paths['c'],
    error=error_path,
)

sample_replicates = 1
mean_depths = dict(
    a=[0.90, 0.05, 0.05] * sample_replicates,
    b=[0.05, 0.90, 0.05] * sample_replicates,
    c=[0.05, 0.05, 0.90] * sample_replicates,
    x=[0.00, 0.00, 0.00] * sample_replicates,
    # Replicated like the others; otherwise the per-path depth lists differ
    # in length once sample_replicates > 1.
    error=[0.0001, 0.0001, 0.001] * sample_replicates,
)

g = [
    single_stranded_graph_with_simulated_depth(
        paths,
        depths=mean_depths, scale_depth_by=20,
    ),
]

figsize = 250
# Compress unitigs in place: g[0] becomes the compressed simulated graph.
mutate_compress_all_unitigs(g[-1])
pos = draw_graph(g[-1], output_size=(figsize, figsize))

In [None]:
# Keep only the compressed simulated graph so the next cells can be re-run.
g = g[:1]

# Every vertex must still carry a non-empty sequence.
assert all(g[-1].vp.sequence[v] for v in g[-1].iter_vertices())

In [None]:
# Estimate flows on a copy of the latest graph, split its branching vertices,
# then re-compress the resulting unitigs (all in place on the new g[-1]).
# `thresh`: a candidate split is only accepted while its summed dot product
# exceeds this value.
thresh = 2

g.append(g[-1].copy())
mutate_add_flows(g[-1], estimate_all_flows(g[-1], use_weights=True))
mutate_split_all_nodes(g[-1], partial(splits_from_sparse_encoding, threshold=thresh))
mutate_compress_all_unitigs(g[-1])
pos = draw_graph(g[-1], output_size=(figsize, figsize))

In [None]:
_g = g[-1]

# Change in total length-weighted depth between the first and latest graph.
print(total_length_x_depth(g[0]) - total_length_x_depth(g[-1]))
vp = vertex_description(_g).rename(columns={0: 'a', 1: 'b', 2: 'c'})
longest_seqs = (
    vp
    .sort_values('length', ascending=False)
    .head(10)
    .index
)
vp.loc[longest_seqs]

In [None]:
ref_list = ['a', 'b', 'c']
query_list = longest_seqs

fig, axs = plt.subplots(
    len(query_list), len(ref_list),
    figsize=(3 * len(ref_list), 3 * len(query_list)),
    sharex=True, sharey=True,
)
# Keep a 2-D grid of axes even when there is a single row or column.
axs = axs.reshape((len(query_list), len(ref_list)))

for (i, query_idx), (j, ref_idx) in product(enumerate(query_list), enumerate(ref_list)):
    ax = axs[i, j]
    ref = paths[ref_idx]
    query = _g.vp.sequence[query_idx]
    dotplot(query, ref, ax=ax, marker='o', s=1)
    # Top-left: the query vertex's depth in this reference's column.
    # Bottom-right: fraction of the query matched in the reference.
    ax.annotate(vp.loc[query_idx, ref_idx].round(2), xy=(0.1, 0.9), xycoords='axes fraction', va='top')
    ax.annotate(round(edit_ratio(ref, query), 2), xy=(0.7, 0.1), xycoords='axes fraction', va='bottom')

for ref_idx, bottom_row_ax in zip(ref_list, axs[-1, :]):
    bottom_row_ax.set_xlabel(ref_idx)

for query_idx, left_column_ax in zip(query_list, axs[:, 0]):
    left_column_ax.set_ylabel(query_idx)