In [None]:
# Record when this notebook was run (shell `date` via IPython's `!` escape), for provenance.
!date

In [None]:
%load_ext autoreload
%load_ext line_profiler

In [None]:
# Work from the project root so the relative 'examples/...' and 'nb/fig/...'
# paths used throughout this notebook resolve.
import os as _os

if 'PROJECT_ROOT' not in _os.environ:
    raise KeyError("PROJECT_ROOT environment variable is not set; point it at the project root before running this notebook.")
_os.chdir(_os.environ['PROJECT_ROOT'])

In [None]:
import strainzip as sz
import graph_tool as gt
import graph_tool.draw
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd
from contextlib import contextmanager
import xarray as xr
from itertools import product
from tqdm import tqdm
from itertools import chain
from strainzip.pandas_util import idxwhere
from graph_tool.util import find_edge

In [None]:
# k-mer size of the input assembly; it selects which input files are loaded below.
k = 111

# Log-spaced histogram edges (50 bins each) for vertex length (1 .. 10^6.5)
# and total depth (0.1 .. 10^4), shared by every per-stage histogram.
length_bins = np.logspace(0, 6.5, num=51)
depth_bins = np.logspace(-1, 4, num=51)

In [None]:
# Load the per-unitig depth table (netCDF) for this k and show its dimension sizes.
_depth_path = f'examples/xjin_test4/r.proc.kmtricks-k{k}-m3-r2.ggcat.unitig_depth.nc'
depth_table = xr.load_dataarray(_depth_path)
depth_table.sizes

In [None]:
# Load unitig sequences from the linked FASTA. The first return value (presumably
# a graph) is discarded; the graph used below comes from the .gt file instead.
_fasta_path = f'examples/xjin_test4/r.proc.kmtricks-k{k}-m3-r2.ggcat.fn'
with open(_fasta_path) as _fasta:
    _, seqs = sz.io.load_graph_and_sequences_from_linked_fasta(
        _fasta,
        k=k,
        header_tokenizer=sz.io.ggcat_header_tokenizer,
    )

In [None]:
# Load graph
graph = sz.io.load_graph(f'examples/xjin_test4/r.proc.kmtricks-k{k}-m3-r2.ggcat.gt')

# FIXME: These annotations should go into the loading app:
for _name, _value in [('num_samples', depth_table.sizes['sample']), ('kmer_length', k)]:
    graph.gp[_name] = graph.new_graph_property('int', val=_value)

# Set depth on graph: each vertex's 'sequence' label, minus its final character
# (presumably an orientation marker — TODO confirm), is parsed as the unitig id.
# The depth table is then reordered to vertex order and transposed before being
# stored as a vector<float> vertex property.
vertex_unitig_order = [int(_label[:-1]) for _label in graph.vp['sequence']]
graph.vp['depth'] = graph.new_vertex_property('vector<float>')
_depth_by_vertex = depth_table.sel(unitig=vertex_unitig_order).T.values
graph.vp['depth'].set_2d_array(_depth_by_vertex)

In [None]:
# Select components in a deterministic way (from largest to smallest).
# Repeatedly take the largest remaining component (edge direction ignored) until
# one with <= 1000 vertices has been taken; that last, smaller component is kept too.

component_graphs = []

graph_remaining = graph.new_vertex_property('bool', val=True)

last_graph_size = 1_000_000
while last_graph_size > 1000:
    this_component = gt.topology.label_largest_component(gt.GraphView(graph, vfilt=graph_remaining), directed=False)
    last_graph_size = this_component.a.sum()
    if last_graph_size == 0:
        break  # Every vertex is already assigned; don't append an empty view.
    component_graphs.append(gt.GraphView(graph, vfilt=this_component))
    # Drop the component from the remaining set (remaining AND NOT component).
    # Subtracting the two masks only works while the arrays are integer-typed and
    # the component is a subset of `graph_remaining`; logical ops rely on neither.
    graph_remaining = graph.new_vertex_property(
        'bool', vals=np.logical_and(graph_remaining.a, np.logical_not(this_component.a))
    )

len(component_graphs)

In [None]:
# The largest component holds a huge fraction of the unitigs.
# Show (up to) the five largest components; slicing avoids an IndexError
# when fewer than five components were selected.
tuple(component_graphs[:5])

In [None]:
c = 0  # Index into component_graphs (0 = largest component).

draw_graphs = False  # Toggle vertex layout computation and per-stage graph drawings.

# component = c
component = 16  # Only the label for plotting
# NOTE(review): `component` (used in every figure path) differs from `c` (the component
# actually assembled), so figures for component 0 are written under component-16 — confirm.

# plt.savefig / draw_graph do not create missing directories, so create it up front.
_os.makedirs(f'nb/fig/component-{component}', exist_ok=True)

# Standalone copy of the chosen component (prune=True materializes the view's filter).
graph2 = gt.Graph(component_graphs[c], prune=True)
# TODO: Think about filtering edges instead of removing them entirely
#       (e.g. via an edge property like graph2.ep['filter']).
# Presumably 'filter' is a vertex property carried over from the loaded graph — confirm.
graph2.set_vertex_filter(graph2.vp['filter'])

np.random.seed(1)
gt.seed_rng(1)

In [None]:
# Build the GraphManager. Length, sequence and vector depth always have an
# unzipper/presser; xy-positions are tracked only when graphs will be drawn.
_unzippers = [
    sz.graph_manager.LengthUnzipper(),
    sz.graph_manager.SequenceUnzipper(),
    sz.graph_manager.VectorDepthUnzipper(),
]
_pressers = [
    sz.graph_manager.LengthPresser(),
    sz.graph_manager.SequencePresser(sep=","),
    sz.graph_manager.VectorDepthPresser(),
]
if draw_graphs:
    # Initial layout, weighting each vertex by length * total depth over samples.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * graph2.vp.depth.get_2d_array(pos=range(graph2.gp['num_samples'])).sum(0))
    sz.draw.update_xypositions(graph2, vweight=total_bases)
    _unzippers.append(sz.graph_manager.PositionUnzipper(offset=(0.1, 0.1)))
    _pressers.append(sz.graph_manager.PositionPresser())

gm = sz.graph_manager.GraphManager(unzippers=_unzippers, pressers=_pressers)
gm.validate(graph2)

In [None]:
graph3 = graph2.copy()  # Save for later plotting (snapshot before depth smoothing and tip trimming).
# Show degree statistics of the untouched component.
sz.stats.degree_stats(graph3)

In [None]:
assembly_stage = 0

# Calculate flows: one sz.flow.estimate_flow call per sample (on that sample's depth
# slice), then stack the per-sample edge flows into one vector-valued edge property.
flow = []
for sample_id in range(graph2.gp['num_samples']):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(graph2.gp['num_samples']))

# Initial depths: vertex length vs. total depth (summed over samples).
plt.hist2d(
    graph2.vp['length'].fa,
    graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0),
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar(label='number of vertices')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (sum over samples)')
plt.title(f'component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Vertex weight for the (currently disabled) layout refresh: length * total depth.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * graph2.vp.depth.get_2d_array(pos=range(graph2.gp['num_samples'])).sum(0))
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth); edge width ~ sqrt(total flow) / 2 (edge styling currently disabled).
    _color = graph2.new_vertex_property('float', vals=graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
# Depth Smoothing
# Smooth each sample's vertex depths separately with sz.flow.smooth_depth
# (inertia=0.5, 50 iterations) and print the change value it returns. The results
# are regrouped into one vector-valued vertex property, `smoothed_depths`.
smoothed_depths = []
for _sample in range(graph2.gp['num_samples']):
    _sample_depth = gt.ungroup_vector_property(graph2.vp.depth, pos=[_sample])[0]
    _smoothed, _change = sz.flow.smooth_depth(
        graph2, _sample_depth, graph2.vp.length, inertia=0.5, num_iter=50,
    )
    print(_change)
    smoothed_depths.append(_smoothed)

smoothed_depths = gt.group_vector_property(smoothed_depths)

In [None]:
assembly_stage = 1

# NOTE(review): graph2.vp['depth'] is only replaced by `smoothed_depths` in a later
# cell, so when run top-to-bottom this stage uses the same unsmoothed depths as
# stage 0. Confirm whether the smoothed depths were meant to be shown here.

# Calculate flows: one sz.flow.estimate_flow call per sample (on that sample's depth
# slice), then stack the per-sample edge flows into one vector-valued edge property.
flow = []
for sample_id in range(graph2.gp['num_samples']):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(graph2.gp['num_samples']))

# Vertex length vs. total depth (summed over samples) at this assembly stage.
plt.hist2d(
    graph2.vp['length'].fa,
    graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0),
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar(label='number of vertices')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (sum over samples)')
plt.title(f'component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Vertex weight for the (currently disabled) layout refresh: length * total depth.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * graph2.vp.depth.get_2d_array(pos=range(graph2.gp['num_samples'])).sum(0))
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth); edge width ~ sqrt(total flow) / 2 (edge styling currently disabled).
    _color = graph2.new_vertex_property('float', vals=graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
# Actually push smoothed depths to the graph: this replaces graph2's 'depth' vertex
# property (graph3, copied earlier, still holds the unsmoothed depths).
graph2.vp['depth'] = smoothed_depths  # TODO: Experiment with and without this.

# FIXME: Long tips lose too much depth?
# NOTE: It's possible that depth smoothing introduces artifacts at junctions that affect how they're split...?

In [None]:
# TODO: Consider dropping low depth vertices/edges
# depth_thresh = 0.1
# # Drop edges with low depth
# low_depth_edge = graph2.new_edge_property('float', vals=flow.get_2d_array(pos=range(graph2.gp['num_samples'])).sum(0) < depth_thresh)
# low_depth_edges = find_edge(graph2, low_depth_edge, True)
# for e in low_depth_edges:
#     graph2.remove_edge(e)
# low_depth_vertices = idxwhere(sz.results.extract_vertex_data(graph2, seqs).total_depth < depth_thresh)
# print(len(tips), len(low_depth_vertices), len(set(tips) & set(low_depth_vertices)))

In [None]:
# Trim tips. Candidates are also restricted by an `also_required` mask of vertices
# shorter than the k-mer length (presumably only tips satisfying it are returned —
# confirm in sz.assembly.find_tips). Trim them, then re-press maximal unitig paths.
_short_vertex = graph2.vp['length'].a < graph2.gp['kmer_length']
tips = sz.assembly.find_tips(graph2, also_required=_short_vertex)
print(len(tips))
gm.batch_trim(graph2, tips)

graph4 = graph2.copy()  # Save for later plotting (after trimming, before pressing).

_press_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)]
_new_tigs = gm.batch_press(graph2, *_press_jobs)
len(_new_tigs)

In [None]:
# Second round of tip trimming, using the same short-vertex restriction; presumably
# pressing in the previous cell can expose new short tips.
_short_vertex = graph2.vp['length'].a < graph2.gp['kmer_length']
tips = sz.assembly.find_tips(graph2, also_required=_short_vertex)
print(len(tips))
gm.batch_trim(graph2, tips)
_press_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)]
_new_tigs = gm.batch_press(graph2, *_press_jobs)
print(len(_new_tigs))

In [None]:
# Degree statistics of graph2 after both rounds of tip trimming.
sz.stats.degree_stats(graph2)

In [None]:
assembly_stage = 2

# Calculate flows: one sz.flow.estimate_flow call per sample (on that sample's depth
# slice), then stack the per-sample edge flows into one vector-valued edge property.
# `flow` is also consumed by the junction deconvolution in the next cell.
flow = []
for sample_id in range(graph2.gp['num_samples']):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(graph2.gp['num_samples']))

# Vertex length vs. total depth (summed over samples) at this assembly stage.
plt.hist2d(
    graph2.vp['length'].fa,
    graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0),
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar(label='number of vertices')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (sum over samples)')
plt.title(f'component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Vertex weight for the (currently disabled) layout refresh: length * total depth.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * graph2.vp.depth.get_2d_array(pos=range(graph2.gp['num_samples'])).sum(0))
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth); edge width ~ sqrt(total flow) / 2 (edge styling currently disabled).
    _color = graph2.new_vertex_property('float', vals=graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
junctions = sz.assembly.find_junctions(graph2)
print(len(junctions))

_verbose = False  # Set True to print the outcome for every junction.

batch = []
for j in tqdm(junctions):
    in_edges = graph2.get_in_edges(j)
    out_edges = graph2.get_out_edges(j)
    in_edge_vertices = [edge[0] for edge in in_edges]
    out_edge_vertices = [edge[1] for edge in out_edges]

    # There are n * m candidate in->out paths; skip junctions that are too big to fit.
    n, m = len(in_edge_vertices), len(out_edge_vertices)
    if n * m > 20:
        if _verbose:
            print(f"[junc={j} / {n}x{m}] Too many possible paths.")
        continue

    in_edge_flows = np.stack([flow[edge] for edge in in_edges])
    out_edge_flows = np.stack([flow[edge] for edge in out_edges])

    # Balance flows before fitting: rescale both sides so total in-flow equals total
    # out-flow (each becomes the geometric mean of the two totals). Multiplying by
    # exp(-/+ r/2) equals exp(log(x) -/+ r/2) but avoids log(0) warnings for zero flows.
    log_offset_ratio = np.log(in_edge_flows.sum()) - np.log(out_edge_flows.sum())
    in_edge_flows = in_edge_flows * np.exp(-log_offset_ratio / 2)
    out_edge_flows = out_edge_flows * np.exp(log_offset_ratio / 2)

    X = sz.deconvolution.design_paths(n, m)[0]
    fit, paths, named_paths, score_margin = sz.deconvolution.deconvolve_junction(
        in_edge_vertices,
        in_edge_flows,
        out_edge_vertices,
        out_edge_flows,
        model=sz.depth_model,
        forward_stop=0,
        backward_stop=0,
        alpha=1.,
    )
    # Accept the selected model only if it passes every check below, in order.
    if not (score_margin > 20):  # TODO: Consider selecting non-best models that have a small enough score margin, after using a more negative backward_stop threshold.
        _outcome = f"Cannot pick best model. (Selected model had {len(paths)} paths; score margin: {score_margin})"
    elif not X[:, paths].sum(1).min() == 1:
        # NOTE(review): `== 1` also flags models that cover every row of X two or more
        # times; `>= 1` may be the intended completeness test — confirm.
        _outcome = f"Non-complete. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not len(paths) <= max(n, m):
        _outcome = f"Non-minimal. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not (np.linalg.cond(fit.hessian_beta) < 1e5):
        _outcome = f"Non-identifiable. (Best model had {len(paths)} paths; score margin: {score_margin})"
    else:
        _outcome = f"SUCCESS! Selected {len(paths)} paths; score margin: {score_margin}"
        batch.append((j, named_paths, {"path_depths": fit.beta.clip(0)}))
    if _verbose:
        print(f"[junc={j} / {n}x{m}] {_outcome}")

# Fraction of junctions queued for unzipping (NaN when there are no junctions).
print(len(batch) / len(junctions) if len(junctions) > 0 else float('nan'))

In [None]:
# Unzip every junction accepted in the previous cell, then re-press the maximal
# unitig paths of the updated graph.
_new_tigs = gm.batch_unzip(graph2, *batch)
print(len(_new_tigs))

_press_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)]
_new_tigs = gm.batch_press(graph2, *_press_jobs)
len(_new_tigs)

In [None]:
# _new_tigs = gm.batch_press(graph2, *[(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)])
# len(_new_tigs)

In [None]:
# Degree statistics of graph2 after the first unzip/press round.
sz.stats.degree_stats(graph2)

In [None]:
assembly_stage = 3

# Calculate flows: one sz.flow.estimate_flow call per sample (on that sample's depth
# slice), then stack the per-sample edge flows into one vector-valued edge property.
# `flow` is also consumed by the junction deconvolution in the next cell.
flow = []
for sample_id in range(graph2.gp['num_samples']):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(graph2.gp['num_samples']))

# Vertex length vs. total depth (summed over samples) at this assembly stage.
plt.hist2d(
    graph2.vp['length'].fa,
    graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0),
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar(label='number of vertices')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (sum over samples)')
plt.title(f'component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Vertex weight for the (currently disabled) layout refresh: length * total depth.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * graph2.vp.depth.get_2d_array(pos=range(graph2.gp['num_samples'])).sum(0))
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth); edge width ~ sqrt(total flow) / 2 (edge styling currently disabled).
    _color = graph2.new_vertex_property('float', vals=graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
junctions = sz.assembly.find_junctions(graph2)
print(len(junctions))

_verbose = True  # Print the outcome for every junction.

batch = []
for j in junctions:
    in_edges = graph2.get_in_edges(j)
    out_edges = graph2.get_out_edges(j)
    in_edge_vertices = [edge[0] for edge in in_edges]
    out_edge_vertices = [edge[1] for edge in out_edges]

    # There are n * m candidate in->out paths; skip junctions that are too big to fit.
    n, m = len(in_edge_vertices), len(out_edge_vertices)
    if n * m > 20:
        if _verbose:
            print(f"[junc={j} / {n}x{m}] Too many possible paths.")
        continue

    in_edge_flows = np.stack([flow[edge] for edge in in_edges])
    out_edge_flows = np.stack([flow[edge] for edge in out_edges])

    # Balance flows before fitting: rescale both sides so total in-flow equals total
    # out-flow (each becomes the geometric mean of the two totals). Multiplying by
    # exp(-/+ r/2) equals exp(log(x) -/+ r/2) but avoids log(0) warnings for zero flows.
    log_offset_ratio = np.log(in_edge_flows.sum()) - np.log(out_edge_flows.sum())
    in_edge_flows = in_edge_flows * np.exp(-log_offset_ratio / 2)
    out_edge_flows = out_edge_flows * np.exp(log_offset_ratio / 2)

    X = sz.deconvolution.design_paths(n, m)[0]
    fit, paths, named_paths, score_margin = sz.deconvolution.deconvolve_junction(
        in_edge_vertices,
        in_edge_flows,
        out_edge_vertices,
        out_edge_flows,
        model=sz.depth_model,
        forward_stop=0,
        backward_stop=0,
        alpha=1.,
    )
    # Accept the selected model only if it passes every check below, in order.
    if not (score_margin > 20):  # TODO: Consider selecting non-best models that have a small enough score margin, after using a more negative backward_stop threshold.
        _outcome = f"Cannot pick best model. (Selected model had {len(paths)} paths; score margin: {score_margin})"
    elif not X[:, paths].sum(1).min() == 1:
        # NOTE(review): `== 1` also flags models that cover every row of X two or more
        # times; `>= 1` may be the intended completeness test — confirm.
        _outcome = f"Non-complete. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not len(paths) <= max(n, m):
        _outcome = f"Non-minimal. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not (np.linalg.cond(fit.hessian_beta) < 1e5):
        _outcome = f"Non-identifiable. (Best model had {len(paths)} paths; score margin: {score_margin})"
    else:
        _outcome = f"SUCCESS! Selected {len(paths)} paths; score margin: {score_margin}"
        batch.append((j, named_paths, {"path_depths": fit.beta.clip(0)}))
    if _verbose:
        print(f"[junc={j} / {n}x{m}] {_outcome}")

# Fraction of junctions queued for unzipping (NaN when there are no junctions).
print(len(batch) / len(junctions) if len(junctions) > 0 else float('nan'))

In [None]:
# Unzip every junction accepted in the previous cell, then re-press the maximal
# unitig paths of the updated graph.
_new_tigs = gm.batch_unzip(graph2, *batch)
print(len(_new_tigs))

_press_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)]
_new_tigs = gm.batch_press(graph2, *_press_jobs)
len(_new_tigs)

In [None]:
# Degree statistics of graph2 after the second unzip/press round.
sz.stats.degree_stats(graph2)

In [None]:
assembly_stage = 4

# Calculate flows: one sz.flow.estimate_flow call per sample (on that sample's depth
# slice), then stack the per-sample edge flows into one vector-valued edge property.
# `flow` is also consumed by the junction deconvolution in the next cell.
flow = []
for sample_id in range(graph2.gp['num_samples']):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(graph2.gp['num_samples']))

# Vertex length vs. total depth (summed over samples) at this assembly stage.
plt.hist2d(
    graph2.vp['length'].fa,
    graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0),
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar(label='number of vertices')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (sum over samples)')
plt.title(f'component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Vertex weight for the (currently disabled) layout refresh: length * total depth.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * graph2.vp.depth.get_2d_array(pos=range(graph2.gp['num_samples'])).sum(0))
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth); edge width ~ sqrt(total flow) / 2 (edge styling currently disabled).
    _color = graph2.new_vertex_property('float', vals=graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
junctions = sz.assembly.find_junctions(graph2)
print(len(junctions))

_verbose = True  # Print the outcome for every junction.

batch = []
for j in junctions:
    in_edges = graph2.get_in_edges(j)
    out_edges = graph2.get_out_edges(j)
    in_edge_vertices = [edge[0] for edge in in_edges]
    out_edge_vertices = [edge[1] for edge in out_edges]

    # There are n * m candidate in->out paths; skip junctions that are too big to fit.
    n, m = len(in_edge_vertices), len(out_edge_vertices)
    if n * m > 20:
        if _verbose:
            print(f"[junc={j} / {n}x{m}] Too many possible paths.")
        continue

    in_edge_flows = np.stack([flow[edge] for edge in in_edges])
    out_edge_flows = np.stack([flow[edge] for edge in out_edges])

    # Balance flows before fitting: rescale both sides so total in-flow equals total
    # out-flow (each becomes the geometric mean of the two totals). Multiplying by
    # exp(-/+ r/2) equals exp(log(x) -/+ r/2) but avoids log(0) warnings for zero flows.
    log_offset_ratio = np.log(in_edge_flows.sum()) - np.log(out_edge_flows.sum())
    in_edge_flows = in_edge_flows * np.exp(-log_offset_ratio / 2)
    out_edge_flows = out_edge_flows * np.exp(log_offset_ratio / 2)

    X = sz.deconvolution.design_paths(n, m)[0]
    fit, paths, named_paths, score_margin = sz.deconvolution.deconvolve_junction(
        in_edge_vertices,
        in_edge_flows,
        out_edge_vertices,
        out_edge_flows,
        model=sz.depth_model,
        forward_stop=0,
        backward_stop=0,
        alpha=1.,
    )
    # Accept the selected model only if it passes every check below, in order.
    if not (score_margin > 20):  # TODO: Consider selecting non-best models that have a small enough score margin, after using a more negative backward_stop threshold.
        _outcome = f"Cannot pick best model. (Selected model had {len(paths)} paths; score margin: {score_margin})"
    elif not X[:, paths].sum(1).min() == 1:
        # NOTE(review): `== 1` also flags models that cover every row of X two or more
        # times; `>= 1` may be the intended completeness test — confirm.
        _outcome = f"Non-complete. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not len(paths) <= max(n, m):
        _outcome = f"Non-minimal. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not (np.linalg.cond(fit.hessian_beta) < 1e5):
        _outcome = f"Non-identifiable. (Best model had {len(paths)} paths; score margin: {score_margin})"
    else:
        _outcome = f"SUCCESS! Selected {len(paths)} paths; score margin: {score_margin}"
        batch.append((j, named_paths, {"path_depths": fit.beta.clip(0)}))
    if _verbose:
        print(f"[junc={j} / {n}x{m}] {_outcome}")

# Fraction of junctions queued for unzipping (NaN when there are no junctions).
print(len(batch) / len(junctions) if len(junctions) > 0 else float('nan'))

In [None]:
# Unzip every junction accepted in the previous cell, then re-press the maximal
# unitig paths of the updated graph.
_new_tigs = gm.batch_unzip(graph2, *batch)
print(len(_new_tigs))

_press_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)]
_new_tigs = gm.batch_press(graph2, *_press_jobs)
len(_new_tigs)

In [None]:
# Degree statistics of graph2 after the third unzip/press round.
sz.stats.degree_stats(graph2)

In [None]:
assembly_stage = 5

# Calculate flows: one sz.flow.estimate_flow call per sample (on that sample's depth
# slice), then stack the per-sample edge flows into one vector-valued edge property.
# `flow` is also consumed by the junction deconvolution in the next cell.
flow = []
for sample_id in range(graph2.gp['num_samples']):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(graph2.gp['num_samples']))

# Vertex length vs. total depth (summed over samples) at this assembly stage.
plt.hist2d(
    graph2.vp['length'].fa,
    graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0),
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar(label='number of vertices')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (sum over samples)')
plt.title(f'component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Vertex weight for the (currently disabled) layout refresh: length * total depth.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * graph2.vp.depth.get_2d_array(pos=range(graph2.gp['num_samples'])).sum(0))
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth); edge width ~ sqrt(total flow) / 2 (edge styling currently disabled).
    _color = graph2.new_vertex_property('float', vals=graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
junctions = sz.assembly.find_junctions(graph2)
print(len(junctions))

_verbose = True  # Print the outcome for every junction.

batch = []
for j in junctions:
    in_edges = graph2.get_in_edges(j)
    out_edges = graph2.get_out_edges(j)
    in_edge_vertices = [edge[0] for edge in in_edges]
    out_edge_vertices = [edge[1] for edge in out_edges]

    # There are n * m candidate in->out paths; skip junctions that are too big to fit.
    n, m = len(in_edge_vertices), len(out_edge_vertices)
    if n * m > 20:
        if _verbose:
            print(f"[junc={j} / {n}x{m}] Too many possible paths.")
        continue

    in_edge_flows = np.stack([flow[edge] for edge in in_edges])
    out_edge_flows = np.stack([flow[edge] for edge in out_edges])

    # Balance flows before fitting: rescale both sides so total in-flow equals total
    # out-flow (each becomes the geometric mean of the two totals). Multiplying by
    # exp(-/+ r/2) equals exp(log(x) -/+ r/2) but avoids log(0) warnings for zero flows.
    log_offset_ratio = np.log(in_edge_flows.sum()) - np.log(out_edge_flows.sum())
    in_edge_flows = in_edge_flows * np.exp(-log_offset_ratio / 2)
    out_edge_flows = out_edge_flows * np.exp(log_offset_ratio / 2)

    X = sz.deconvolution.design_paths(n, m)[0]
    fit, paths, named_paths, score_margin = sz.deconvolution.deconvolve_junction(
        in_edge_vertices,
        in_edge_flows,
        out_edge_vertices,
        out_edge_flows,
        model=sz.depth_model,
        forward_stop=0,
        backward_stop=0,
        alpha=1.,
    )
    # Accept the selected model only if it passes every check below, in order.
    if not (score_margin > 20):  # TODO: Consider selecting non-best models that have a small enough score margin, after using a more negative backward_stop threshold.
        _outcome = f"Cannot pick best model. (Selected model had {len(paths)} paths; score margin: {score_margin})"
    elif not X[:, paths].sum(1).min() == 1:
        # NOTE(review): `== 1` also flags models that cover every row of X two or more
        # times; `>= 1` may be the intended completeness test — confirm.
        _outcome = f"Non-complete. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not len(paths) <= max(n, m):
        _outcome = f"Non-minimal. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not (np.linalg.cond(fit.hessian_beta) < 1e5):
        _outcome = f"Non-identifiable. (Best model had {len(paths)} paths; score margin: {score_margin})"
    else:
        _outcome = f"SUCCESS! Selected {len(paths)} paths; score margin: {score_margin}"
        batch.append((j, named_paths, {"path_depths": fit.beta.clip(0)}))
    if _verbose:
        print(f"[junc={j} / {n}x{m}] {_outcome}")

# Fraction of junctions queued for unzipping (NaN when there are no junctions).
print(len(batch) / len(junctions) if len(junctions) > 0 else float('nan'))

In [None]:
# Unzip every junction accepted in the previous cell, then re-press the maximal
# unitig paths of the updated graph.
_new_tigs = gm.batch_unzip(graph2, *batch)
print(len(_new_tigs))

_press_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)]
_new_tigs = gm.batch_press(graph2, *_press_jobs)
len(_new_tigs)

In [None]:
# Degree statistics of graph2 after the fourth unzip/press round.
sz.stats.degree_stats(graph2)

In [None]:
assembly_stage = 6

# Calculate flows: one sz.flow.estimate_flow call per sample (on that sample's depth
# slice), then stack the per-sample edge flows into one vector-valued edge property.
# `flow` is also consumed by the junction deconvolution in the next cell.
flow = []
for sample_id in range(graph2.gp['num_samples']):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(graph2.gp['num_samples']))

# Vertex length vs. total depth (summed over samples) at this assembly stage.
plt.hist2d(
    graph2.vp['length'].fa,
    graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0),
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar(label='number of vertices')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (sum over samples)')
plt.title(f'component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Vertex weight for the (currently disabled) layout refresh: length * total depth.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * graph2.vp.depth.get_2d_array(pos=range(graph2.gp['num_samples'])).sum(0))
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth); edge width ~ sqrt(total flow) / 2 (edge styling currently disabled).
    _color = graph2.new_vertex_property('float', vals=graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
junctions = sz.assembly.find_junctions(graph2)
print(len(junctions))

_verbose = True  # Print the outcome for every junction.

batch = []
for j in junctions:
    in_edges = graph2.get_in_edges(j)
    out_edges = graph2.get_out_edges(j)
    in_edge_vertices = [edge[0] for edge in in_edges]
    out_edge_vertices = [edge[1] for edge in out_edges]

    # There are n * m candidate in->out paths; skip junctions that are too big to fit.
    n, m = len(in_edge_vertices), len(out_edge_vertices)
    if n * m > 20:
        if _verbose:
            print(f"[junc={j} / {n}x{m}] Too many possible paths.")
        continue

    in_edge_flows = np.stack([flow[edge] for edge in in_edges])
    out_edge_flows = np.stack([flow[edge] for edge in out_edges])

    # Balance flows before fitting: rescale both sides so total in-flow equals total
    # out-flow (each becomes the geometric mean of the two totals). Multiplying by
    # exp(-/+ r/2) equals exp(log(x) -/+ r/2) but avoids log(0) warnings for zero flows.
    log_offset_ratio = np.log(in_edge_flows.sum()) - np.log(out_edge_flows.sum())
    in_edge_flows = in_edge_flows * np.exp(-log_offset_ratio / 2)
    out_edge_flows = out_edge_flows * np.exp(log_offset_ratio / 2)

    X = sz.deconvolution.design_paths(n, m)[0]
    fit, paths, named_paths, score_margin = sz.deconvolution.deconvolve_junction(
        in_edge_vertices,
        in_edge_flows,
        out_edge_vertices,
        out_edge_flows,
        model=sz.depth_model,
        forward_stop=0,
        backward_stop=0,
        alpha=1.,
    )
    # Accept the selected model only if it passes every check below, in order.
    if not (score_margin > 20):  # TODO: Consider selecting non-best models that have a small enough score margin, after using a more negative backward_stop threshold.
        _outcome = f"Cannot pick best model. (Selected model had {len(paths)} paths; score margin: {score_margin})"
    elif not X[:, paths].sum(1).min() == 1:
        # NOTE(review): `== 1` also flags models that cover every row of X two or more
        # times; `>= 1` may be the intended completeness test — confirm.
        _outcome = f"Non-complete. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not len(paths) <= max(n, m):
        _outcome = f"Non-minimal. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not (np.linalg.cond(fit.hessian_beta) < 1e5):
        _outcome = f"Non-identifiable. (Best model had {len(paths)} paths; score margin: {score_margin})"
    else:
        _outcome = f"SUCCESS! Selected {len(paths)} paths; score margin: {score_margin}"
        batch.append((j, named_paths, {"path_depths": fit.beta.clip(0)}))
    if _verbose:
        print(f"[junc={j} / {n}x{m}] {_outcome}")

# Fraction of junctions queued for unzipping (NaN when there are no junctions).
print(len(batch) / len(junctions) if len(junctions) > 0 else float('nan'))

In [None]:
# Unzip every junction accepted in the previous cell, then re-press the maximal
# unitig paths of the updated graph.
_new_tigs = gm.batch_unzip(graph2, *batch)
print(len(_new_tigs))

_press_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)]
_new_tigs = gm.batch_press(graph2, *_press_jobs)
len(_new_tigs)

In [None]:
# Degree statistics of graph2 after the fifth unzip/press round.
sz.stats.degree_stats(graph2)

In [None]:
assembly_stage = 7

# Calculate flows: one sz.flow.estimate_flow call per sample (on that sample's depth
# slice), then stack the per-sample edge flows into one vector-valued edge property.
# `flow` is also consumed by the junction deconvolution in the next cell.
flow = []
for sample_id in range(graph2.gp['num_samples']):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(graph2.gp['num_samples']))

# Vertex length vs. total depth (summed over samples) at this assembly stage.
plt.hist2d(
    graph2.vp['length'].fa,
    graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0),
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar(label='number of vertices')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (sum over samples)')
plt.title(f'component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Vertex weight for the (currently disabled) layout refresh: length * total depth.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * graph2.vp.depth.get_2d_array(pos=range(graph2.gp['num_samples'])).sum(0))
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth); edge width ~ sqrt(total flow) / 2 (edge styling currently disabled).
    _color = graph2.new_vertex_property('float', vals=graph2.vp['depth'].get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(graph2.gp['num_samples'])).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
junctions = sz.assembly.find_junctions(graph2)
print(len(junctions))

_verbose = True  # Print the outcome for every junction.

batch = []
for j in junctions:
    in_edges = graph2.get_in_edges(j)
    out_edges = graph2.get_out_edges(j)
    in_edge_vertices = [edge[0] for edge in in_edges]
    out_edge_vertices = [edge[1] for edge in out_edges]

    # There are n * m candidate in->out paths; skip junctions that are too big to fit.
    n, m = len(in_edge_vertices), len(out_edge_vertices)
    if n * m > 20:
        if _verbose:
            print(f"[junc={j} / {n}x{m}] Too many possible paths.")
        continue

    in_edge_flows = np.stack([flow[edge] for edge in in_edges])
    out_edge_flows = np.stack([flow[edge] for edge in out_edges])

    # Balance flows before fitting: rescale both sides so total in-flow equals total
    # out-flow (each becomes the geometric mean of the two totals). Multiplying by
    # exp(-/+ r/2) equals exp(log(x) -/+ r/2) but avoids log(0) warnings for zero flows.
    log_offset_ratio = np.log(in_edge_flows.sum()) - np.log(out_edge_flows.sum())
    in_edge_flows = in_edge_flows * np.exp(-log_offset_ratio / 2)
    out_edge_flows = out_edge_flows * np.exp(log_offset_ratio / 2)

    X = sz.deconvolution.design_paths(n, m)[0]
    fit, paths, named_paths, score_margin = sz.deconvolution.deconvolve_junction(
        in_edge_vertices,
        in_edge_flows,
        out_edge_vertices,
        out_edge_flows,
        model=sz.depth_model,
        forward_stop=0,
        backward_stop=0,
        alpha=1.,
    )
    # Accept the selected model only if it passes every check below, in order.
    if not (score_margin > 20):  # TODO: Consider selecting non-best models that have a small enough score margin, after using a more negative backward_stop threshold.
        _outcome = f"Cannot pick best model. (Selected model had {len(paths)} paths; score margin: {score_margin})"
    elif not X[:, paths].sum(1).min() == 1:
        # NOTE(review): `== 1` also flags models that cover every row of X two or more
        # times; `>= 1` may be the intended completeness test — confirm.
        _outcome = f"Non-complete. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not len(paths) <= max(n, m):
        _outcome = f"Non-minimal. (Best model had {len(paths)} paths; score margin: {score_margin})"
    elif not (np.linalg.cond(fit.hessian_beta) < 1e5):
        _outcome = f"Non-identifiable. (Best model had {len(paths)} paths; score margin: {score_margin})"
    else:
        _outcome = f"SUCCESS! Selected {len(paths)} paths; score margin: {score_margin}"
        batch.append((j, named_paths, {"path_depths": fit.beta.clip(0)}))
    if _verbose:
        print(f"[junc={j} / {n}x{m}] {_outcome}")

# Fraction of junctions queued for unzipping (NaN when there are no junctions).
print(len(batch) / len(junctions) if len(junctions) > 0 else float('nan'))

In [None]:
# Unzip the junctions accepted in `batch` along their selected paths, then press each
# maximal unitig path of the updated graph into a single vertex.
_new_tigs = gm.batch_unzip(graph2, *batch)
print(len(_new_tigs))

_new_tigs = gm.batch_press(
    graph2,
    *[(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)],
)
len(_new_tigs)

In [None]:
# Degree statistics of graph2 after this round of unzip + press (shown as the cell output).
sz.stats.degree_stats(graph2)

In [None]:
assembly_stage = 8

# Estimate edge flows separately for each sample from that sample's vertex depths,
# then pack them into one vector-valued edge property `flow` (one entry per sample).
num_samples = graph2.gp['num_samples']
flow = []
for sample_id in range(num_samples):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(num_samples))

# Per-vertex depth summed over samples. Computed once and reused by the histogram and the drawing.
vertex_total_depth = graph2.vp['depth'].get_2d_array(range(num_samples)).sum(0)

# Vertex length vs. total depth at this stage.
fig, ax = plt.subplots()
hist_image = ax.hist2d(
    graph2.vp['length'].fa,
    vertex_total_depth,
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)[3]
fig.colorbar(hist_image, ax=ax, label='vertex count')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set(
    xlabel='vertex length',
    ylabel='total depth (summed over samples)',
    title=f'component {component}, stage {assembly_stage}',
)
fig.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Length x total depth. Only consumed by the (disabled) layout update below.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * vertex_total_depth)
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth). Edge width ~ sqrt(total flow) / 2 (width currently disabled below).
    _color = graph2.new_vertex_property('float', vals=vertex_total_depth ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(num_samples)).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
junctions = sz.assembly.find_junctions(graph2)
print(len(junctions))

# Try to resolve every junction into a small set of through-paths using the per-sample
# edge flows estimated in the previous cell. Only junctions that pass every acceptance
# check below are queued in `batch` for unzipping.
batch = []
for j in junctions:
    # Fetch each edge list once and reuse it for both the neighbor ids and the flows.
    in_edges = graph2.get_in_edges(j)
    out_edges = graph2.get_out_edges(j)
    in_edge_vertices = [edge[0] for edge in in_edges]
    out_edge_vertices = [edge[1] for edge in out_edges]
    n, m = len(in_edge_vertices), len(out_edge_vertices)

    in_edge_flows = np.stack([flow[edge] for edge in in_edges])
    out_edge_flows = np.stack([flow[edge] for edge in out_edges])
    in_total, out_total = in_edge_flows.sum(), out_edge_flows.sum()
    if not (in_total > 0 and out_total > 0):
        # log(0) would make the balancing factor infinite and turn every flow into NaN.
        print(f"[junc={j} / {n}x{m}] Zero total in- or out-flow; cannot balance.")
        continue
    log_offset_ratio = np.log(in_total) - np.log(out_total)

    # Balance flows before fitting: rescale both sides so their totals both equal the
    # geometric mean of the two original totals. Multiplying by exp(...) matches
    # exp(log(flow) - ...) for non-negative flows without warning on zero-flow edges.
    in_edge_flows = in_edge_flows * np.exp(-log_offset_ratio / 2)
    out_edge_flows = out_edge_flows * np.exp(log_offset_ratio / 2)

    if n * m > 20:  # n*m in->out pairings are candidate paths; skip very wide junctions.
        print(f"[junc={j} / {n}x{m}] Too many possible paths.")
        continue
    X = sz.deconvolution.design_paths(n, m)[0]
    fit, paths, named_paths, score_margin = sz.deconvolution.deconvolve_junction(
        in_edge_vertices,
        in_edge_flows,
        out_edge_vertices,
        out_edge_flows,
        model=sz.depth_model,
        forward_stop=0,
        backward_stop=0,
        alpha=1.,
    )
    if not (score_margin > 20):  # TODO: Consider selecting non-best models that have a small enough score margin, after using a more negative backward_stop threshold.
        print(f"[junc={j} / {n}x{m}] Cannot pick best model. (Selected model had {len(paths)} paths; score margin: {score_margin})")
    elif not X[:, paths].sum(1).min() == 1:
        # Every row of the design matrix (presumably one per in/out edge) must be covered by the selected paths.
        print(f"[junc={j} / {n}x{m}] Non-complete. (Best model had {len(paths)} paths; score margin: {score_margin})")
    elif not len(paths) <= max(n, m):
        print(f"[junc={j} / {n}x{m}] Non-minimal. (Best model had {len(paths)} paths; score margin: {score_margin})")
    elif not (np.linalg.cond(fit.hessian_beta) < 1e5):
        # An ill-conditioned Hessian means the path depths are not well determined.
        print(f"[junc={j} / {n}x{m}] Non-identifiable. (Best model had {len(paths)} paths; score margin: {score_margin})")
    else:
        print(f"[junc={j} / {n}x{m}] SUCCESS! Selected {len(paths)} paths; score margin: {score_margin}")
        batch.append((j, named_paths, {"path_depths": fit.beta.clip(0)}))

# Fraction of junctions resolved this round. Guarded because later rounds may find no junctions.
print(len(batch) / len(junctions) if len(junctions) > 0 else float('nan'))

In [None]:
# Unzip the junctions accepted in `batch` along their selected paths, then press each
# maximal unitig path of the updated graph into a single vertex.
_new_tigs = gm.batch_unzip(graph2, *batch)
print(len(_new_tigs))

_new_tigs = gm.batch_press(
    graph2,
    *[(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)],
)
len(_new_tigs)

In [None]:
# Degree statistics of graph2 after this round of unzip + press (shown as the cell output).
sz.stats.degree_stats(graph2)

In [None]:
assembly_stage = 9

# Estimate edge flows separately for each sample from that sample's vertex depths,
# then pack them into one vector-valued edge property `flow` (one entry per sample).
num_samples = graph2.gp['num_samples']
flow = []
for sample_id in range(num_samples):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(num_samples))

# Per-vertex depth summed over samples. Computed once and reused by the histogram and the drawing.
vertex_total_depth = graph2.vp['depth'].get_2d_array(range(num_samples)).sum(0)

# Vertex length vs. total depth at this stage.
fig, ax = plt.subplots()
hist_image = ax.hist2d(
    graph2.vp['length'].fa,
    vertex_total_depth,
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)[3]
fig.colorbar(hist_image, ax=ax, label='vertex count')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set(
    xlabel='vertex length',
    ylabel='total depth (summed over samples)',
    title=f'component {component}, stage {assembly_stage}',
)
fig.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Length x total depth. Only consumed by the (disabled) layout update below.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * vertex_total_depth)
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth). Edge width ~ sqrt(total flow) / 2 (width currently disabled below).
    _color = graph2.new_vertex_property('float', vals=vertex_total_depth ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(num_samples)).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
junctions = sz.assembly.find_junctions(graph2)
print(len(junctions))

# Try to resolve every junction into a small set of through-paths using the per-sample
# edge flows estimated in the previous cell. Only junctions that pass every acceptance
# check below are queued in `batch` for unzipping.
batch = []
for j in junctions:
    # Fetch each edge list once and reuse it for both the neighbor ids and the flows.
    in_edges = graph2.get_in_edges(j)
    out_edges = graph2.get_out_edges(j)
    in_edge_vertices = [edge[0] for edge in in_edges]
    out_edge_vertices = [edge[1] for edge in out_edges]
    n, m = len(in_edge_vertices), len(out_edge_vertices)

    in_edge_flows = np.stack([flow[edge] for edge in in_edges])
    out_edge_flows = np.stack([flow[edge] for edge in out_edges])
    in_total, out_total = in_edge_flows.sum(), out_edge_flows.sum()
    if not (in_total > 0 and out_total > 0):
        # log(0) would make the balancing factor infinite and turn every flow into NaN.
        print(f"[junc={j} / {n}x{m}] Zero total in- or out-flow; cannot balance.")
        continue
    log_offset_ratio = np.log(in_total) - np.log(out_total)

    # Balance flows before fitting: rescale both sides so their totals both equal the
    # geometric mean of the two original totals. Multiplying by exp(...) matches
    # exp(log(flow) - ...) for non-negative flows without warning on zero-flow edges.
    in_edge_flows = in_edge_flows * np.exp(-log_offset_ratio / 2)
    out_edge_flows = out_edge_flows * np.exp(log_offset_ratio / 2)

    if n * m > 20:  # n*m in->out pairings are candidate paths; skip very wide junctions.
        print(f"[junc={j} / {n}x{m}] Too many possible paths.")
        continue
    X = sz.deconvolution.design_paths(n, m)[0]
    fit, paths, named_paths, score_margin = sz.deconvolution.deconvolve_junction(
        in_edge_vertices,
        in_edge_flows,
        out_edge_vertices,
        out_edge_flows,
        model=sz.depth_model,
        forward_stop=0,
        backward_stop=0,
        alpha=1.,
    )
    if not (score_margin > 20):  # TODO: Consider selecting non-best models that have a small enough score margin, after using a more negative backward_stop threshold.
        print(f"[junc={j} / {n}x{m}] Cannot pick best model. (Selected model had {len(paths)} paths; score margin: {score_margin})")
    elif not X[:, paths].sum(1).min() == 1:
        # Every row of the design matrix (presumably one per in/out edge) must be covered by the selected paths.
        print(f"[junc={j} / {n}x{m}] Non-complete. (Best model had {len(paths)} paths; score margin: {score_margin})")
    elif not len(paths) <= max(n, m):
        print(f"[junc={j} / {n}x{m}] Non-minimal. (Best model had {len(paths)} paths; score margin: {score_margin})")
    elif not (np.linalg.cond(fit.hessian_beta) < 1e5):
        # An ill-conditioned Hessian means the path depths are not well determined.
        print(f"[junc={j} / {n}x{m}] Non-identifiable. (Best model had {len(paths)} paths; score margin: {score_margin})")
    else:
        print(f"[junc={j} / {n}x{m}] SUCCESS! Selected {len(paths)} paths; score margin: {score_margin}")
        batch.append((j, named_paths, {"path_depths": fit.beta.clip(0)}))

# Fraction of junctions resolved this round. Guarded because later rounds may find no junctions.
print(len(batch) / len(junctions) if len(junctions) > 0 else float('nan'))

In [None]:
# Unzip the junctions accepted in `batch` along their selected paths, then press each
# maximal unitig path of the updated graph into a single vertex.
_new_tigs = gm.batch_unzip(graph2, *batch)
print(len(_new_tigs))

_new_tigs = gm.batch_press(
    graph2,
    *[(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)],
)
len(_new_tigs)

In [None]:
# Degree statistics of graph2 after this round of unzip + press (shown as the cell output).
sz.stats.degree_stats(graph2)

In [None]:
assembly_stage = 10

# Estimate edge flows separately for each sample from that sample's vertex depths,
# then pack them into one vector-valued edge property `flow` (one entry per sample).
num_samples = graph2.gp['num_samples']
flow = []
for sample_id in range(num_samples):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(num_samples))

# Per-vertex depth summed over samples. Computed once and reused by the histogram and the drawing.
vertex_total_depth = graph2.vp['depth'].get_2d_array(range(num_samples)).sum(0)

# Vertex length vs. total depth at this stage.
fig, ax = plt.subplots()
hist_image = ax.hist2d(
    graph2.vp['length'].fa,
    vertex_total_depth,
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)[3]
fig.colorbar(hist_image, ax=ax, label='vertex count')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set(
    xlabel='vertex length',
    ylabel='total depth (summed over samples)',
    title=f'component {component}, stage {assembly_stage}',
)
fig.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Length x total depth. Only consumed by the (disabled) layout update below.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * vertex_total_depth)
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth). Edge width ~ sqrt(total flow) / 2 (width currently disabled below).
    _color = graph2.new_vertex_property('float', vals=vertex_total_depth ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(num_samples)).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
junctions = sz.assembly.find_junctions(graph2)
print(len(junctions))

# Try to resolve every junction into a small set of through-paths using the per-sample
# edge flows estimated in the previous cell. Only junctions that pass every acceptance
# check below are queued in `batch` for unzipping.
batch = []
for j in junctions:
    # Fetch each edge list once and reuse it for both the neighbor ids and the flows.
    in_edges = graph2.get_in_edges(j)
    out_edges = graph2.get_out_edges(j)
    in_edge_vertices = [edge[0] for edge in in_edges]
    out_edge_vertices = [edge[1] for edge in out_edges]
    n, m = len(in_edge_vertices), len(out_edge_vertices)

    in_edge_flows = np.stack([flow[edge] for edge in in_edges])
    out_edge_flows = np.stack([flow[edge] for edge in out_edges])
    in_total, out_total = in_edge_flows.sum(), out_edge_flows.sum()
    if not (in_total > 0 and out_total > 0):
        # log(0) would make the balancing factor infinite and turn every flow into NaN.
        print(f"[junc={j} / {n}x{m}] Zero total in- or out-flow; cannot balance.")
        continue
    log_offset_ratio = np.log(in_total) - np.log(out_total)

    # Balance flows before fitting: rescale both sides so their totals both equal the
    # geometric mean of the two original totals. Multiplying by exp(...) matches
    # exp(log(flow) - ...) for non-negative flows without warning on zero-flow edges.
    in_edge_flows = in_edge_flows * np.exp(-log_offset_ratio / 2)
    out_edge_flows = out_edge_flows * np.exp(log_offset_ratio / 2)

    if n * m > 20:  # n*m in->out pairings are candidate paths; skip very wide junctions.
        print(f"[junc={j} / {n}x{m}] Too many possible paths.")
        continue
    X = sz.deconvolution.design_paths(n, m)[0]
    fit, paths, named_paths, score_margin = sz.deconvolution.deconvolve_junction(
        in_edge_vertices,
        in_edge_flows,
        out_edge_vertices,
        out_edge_flows,
        model=sz.depth_model,
        forward_stop=0,
        backward_stop=0,
        alpha=1.,
    )
    if not (score_margin > 20):  # TODO: Consider selecting non-best models that have a small enough score margin, after using a more negative backward_stop threshold.
        print(f"[junc={j} / {n}x{m}] Cannot pick best model. (Selected model had {len(paths)} paths; score margin: {score_margin})")
    elif not X[:, paths].sum(1).min() == 1:
        # Every row of the design matrix (presumably one per in/out edge) must be covered by the selected paths.
        print(f"[junc={j} / {n}x{m}] Non-complete. (Best model had {len(paths)} paths; score margin: {score_margin})")
    elif not len(paths) <= max(n, m):
        print(f"[junc={j} / {n}x{m}] Non-minimal. (Best model had {len(paths)} paths; score margin: {score_margin})")
    elif not (np.linalg.cond(fit.hessian_beta) < 1e5):
        # An ill-conditioned Hessian means the path depths are not well determined.
        print(f"[junc={j} / {n}x{m}] Non-identifiable. (Best model had {len(paths)} paths; score margin: {score_margin})")
    else:
        print(f"[junc={j} / {n}x{m}] SUCCESS! Selected {len(paths)} paths; score margin: {score_margin}")
        batch.append((j, named_paths, {"path_depths": fit.beta.clip(0)}))

# Fraction of junctions resolved this round. Guarded because later rounds may find no junctions.
print(len(batch) / len(junctions) if len(junctions) > 0 else float('nan'))

In [None]:
# Unzip the junctions accepted in `batch` along their selected paths, then press each
# maximal unitig path of the updated graph into a single vertex.
_new_tigs = gm.batch_unzip(graph2, *batch)
print(len(_new_tigs))

_new_tigs = gm.batch_press(
    graph2,
    *[(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(graph2)],
)
len(_new_tigs)

In [None]:
# Degree statistics of graph2 after this round of unzip + press (shown as the cell output).
sz.stats.degree_stats(graph2)

In [None]:
assembly_stage = 11

# Estimate edge flows separately for each sample from that sample's vertex depths,
# then pack them into one vector-valued edge property `flow` (one entry per sample).
num_samples = graph2.gp['num_samples']
flow = []
for sample_id in range(num_samples):
    one_flow, _, _ = sz.flow.estimate_flow(graph2, gt.ungroup_vector_property(graph2.vp['depth'], pos=[sample_id])[0], graph2.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(num_samples))

# Per-vertex depth summed over samples. Computed once and reused by the histogram and the drawing.
vertex_total_depth = graph2.vp['depth'].get_2d_array(range(num_samples)).sum(0)

# Vertex length vs. total depth at this stage.
fig, ax = plt.subplots()
hist_image = ax.hist2d(
    graph2.vp['length'].fa,
    vertex_total_depth,
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)[3]
fig.colorbar(hist_image, ax=ax, label='vertex count')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set(
    xlabel='vertex length',
    ylabel='total depth (summed over samples)',
    title=f'component {component}, stage {assembly_stage}',
)
fig.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Length x total depth. Only consumed by the (disabled) layout update below.
    total_bases = graph2.new_vertex_property('float', vals=graph2.vp.length.fa * vertex_total_depth)
    # sz.draw.update_xypositions(graph2, vweight=total_bases, max_iter=100, init_step=1)

    # Fill color ~ sqrt(total depth). Edge width ~ sqrt(total flow) / 2 (width currently disabled below).
    _color = graph2.new_vertex_property('float', vals=vertex_total_depth ** (1/2))
    _width = graph2.new_edge_property('float', vals=flow.get_2d_array(range(num_samples)).sum(0) ** (1/2) / 2)
    sz.draw.draw_graph(
        graph2,
        vertex_text=graph2.vp['length'],
        vertex_fill_color=_color,
        # edge_color=flow,
        # edge_pen_width=_width,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=(mpl.cm.magma),
    )

In [None]:
# WORKHERE

In [None]:
# Tabulate per-vertex results for the final graph, then stitch each vertex's
# overlapping unitig segments into one assembled sequence (column `assembly`).
vertex_results0 = (
    sz.results.extract_vertex_data(graph2)
    .assign(
        assembly=lambda frame: frame.segments.apply(
            sz.results.assemble_overlapping_unitigs,
            unitig_to_sequence=seqs,
            k=graph2.gp['kmer_length'],
        )
    )
)
# vertex_results = sz.results.deduplicate_vertex_data(vertex_results0)

In [None]:
# Low depth: vertices with 2 < total_depth < 10, longest ten first.
(
    vertex_results0
    .query('2 < total_depth < 10')
    .sort_values('length', ascending=False)
    .head(10)
)

In [None]:
# Find self-loops (length-1 cycles only): vertices listed among their own in- or
# out-neighbors, ordered by number of segments (most first).
vertex_results0[
    vertex_results0.apply(
        lambda row: (row.name in row.in_neighbors) or (row.name in row.out_neighbors),
        axis=1,
    )
].sort_values('num_segments', ascending=False)

In [None]:
assembly_stage = '_final'
v = 202452  # NOTE(review): hard-coded vertex id from a specific run; confirm it still exists in graph2.

print(v)
print(graph2.vp.length[v])
print(graph2.vp.depth[v])
print(graph2.vp.sequence[v])
print()

# Comma-separated oriented unitig ids making up vertex v. Split once and reuse below.
focal_segments = graph2.vp.sequence[v].split(',')
# Set for O(1) membership tests while scanning every graph4 vertex.
focal_segment_set = set(focal_segments)

# Per-sample depth of each constituent unitig (orientation suffix stripped), in path order.
sns.heatmap(depth_table.sel(unitig=[int(s[:-1]) for s in focal_segments]).to_pandas().T, norm=mpl.colors.SymLogNorm(1e-1))

# Flag nodes in sequence v
in_seq = graph4.new_vertex_property('bool', val=False)
gt.map_property_values(graph4.vp.sequence, in_seq, lambda x: x in focal_segment_set)

# Mean depth over samples, the flow it implies, and a sqrt-depth fill color for drawing.
one_depth = graph4.new_vertex_property('float', graph4.vp['depth'].get_2d_array(pos=range(graph4.gp['num_samples'])).mean(0))
one_flow, _, _ = sz.flow.estimate_flow(graph4, one_depth, graph4.vp['length'])
_color = graph4.new_vertex_property('float', vals=np.sqrt(one_depth.a))

if draw_graphs:
    outpath = f'nb/fig/component-{component}/graph_stage{assembly_stage}_seq{v}_id.pdf'
    print(outpath)
    sz.draw.draw_graph(
        graph4,
        vertex_text=graph4.vp['sequence'],
        vertex_halo=in_seq,
        # vertex_text=in_seq,
        vertex_font_size=1,
        vertex_fill_color=_color,
        edge_pen_width=graph4.new_edge_property('float', vals=one_flow.a ** (1/5)),
        output=outpath,
        vcmap=(mpl.cm.magma, 1),
    )

In [None]:
# Most frequently occurring segments: count of each segment across all vertices' segment lists.
(
    vertex_results0['segments']
    .explode()
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)

In [None]:
# Deliberate halt: stops "Run All" here so the exploratory cells below only run when
# executed by hand. (Note: `assert` statements are skipped under `python -O`.)
assert False

In [None]:
u = "71703-"  # Focal segment/unitig

# Result vertices whose segment list contains the focal segment; select their rows once.
vertex_list = idxwhere(vertex_results0.segments.apply(lambda segs: u in segs))
focal_results = vertex_results0.loc[vertex_list]

# Numeric unitig ids (orientation suffix dropped) for every segment of those vertices, in order.
unitigs = [int(seg[:-1]) for seg in chain.from_iterable(focal_results.segments)]

d = depth_table.sel(unitig=unitigs).to_pandas().T
sns.clustermap(d, norm=mpl.colors.SymLogNorm(1e-1))

# Write each matching vertex's assembled sequence as a FASTA record.
path = f'nb/fig/component-{component}/seqs_stage_final_node{u}.fn'
with open(path, 'w') as f:
    for vertex, d1 in focal_results.iterrows():
        print(f">{vertex}\n{d1.assembly}", file=f)
print(path)

focal_results