In [None]:
import os as _os
_project_root = _os.environ['PROJECT_ROOT']  # Required; a missing variable raises KeyError.
_os.chdir(_project_root)  # Relative paths used below (examples/, nb/fig/) resolve from here.

In [None]:
import strainzip as sz
import graph_tool as gt
import graph_tool.draw
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd
from contextlib import contextmanager
import xarray as xr
from itertools import product
from tqdm import tqdm
from itertools import chain
from strainzip.pandas_util import idxwhere
from graph_tool.util import find_edge

In [None]:
import jax

from jax.experimental.compilation_cache import compilation_cache as _cc
# Persistent JAX compilation cache. The location can be overridden with the
# JAX_COMPILATION_CACHE_DIR environment variable; otherwise the previous default is used.
_cc.set_cache_dir(_os.environ.get("JAX_COMPILATION_CACHE_DIR", "/tmp/jax-cache"))

import logging
# Verbose JAX logging (presumably to observe compilation / cache activity).
logging.getLogger("jax").setLevel(logging.DEBUG)

In [None]:
# Plotting parameters: 51 log-spaced bin edges (50 bins) per axis.
# Length edges span 10**0 .. 10**6.5; depth edges span 10**-1 .. 10**4.
length_bins = np.logspace(start=0, stop=6.5, num=51)
depth_bins = np.logspace(start=-1, stop=4, num=51)

In [None]:
# Load the full graph for this example.
_graph_path = 'examples/xjin_test4/r.proc.kmtricks-k111-m3-r2.ggcat.gt'
full_graph = sz.io.load_graph(_graph_path)
# FIXME: Expose `kmer_size` under the name `kmer_length`, which the rest of the notebook uses
# (works around a naming oversight that has since been fixed).
full_graph.gp['kmer_length'] = full_graph.new_graph_property(
    'int', val=full_graph.gp['kmer_size']
)

In [None]:
# Load unitig sequences from the linked FASTA that accompanies the graph. Only the
# sequence mapping is kept; the graph returned here is discarded (full_graph is loaded above).
# The k-mer length is read from the loaded graph's metadata instead of being hard-coded
# (previously 111, matching "k111" in the file name) so the two cannot drift apart.
with open('examples/xjin_test4/r.proc.kmtricks-k111-m3-r2.ggcat.fn') as f:
    _, unitig_to_sequence = sz.io.load_graph_and_sequences_from_linked_fasta(
        f, full_graph.gp['kmer_length'], sz.io.ggcat_header_tokenizer
    )

In [None]:
# Select components in a deterministic way (from largest to smallest).
# Repeatedly label the largest connected component (edge direction ignored) among the
# vertices not yet assigned, keep a view of it, and remove it from the pool. The loop stops
# right after appending the first component with <= MIN_COMPONENT_SIZE vertices (so the
# last entry is that first "small" component), or once no vertices remain (size 0).
MIN_COMPONENT_SIZE = 1000

component_graphs = []

graph_remaining = full_graph.new_vertex_property('bool', val=True)

while True:
    this_component = gt.topology.label_largest_component(
        gt.GraphView(full_graph, vfilt=graph_remaining), directed=False
    )
    component_graphs.append(gt.GraphView(full_graph, vfilt=this_component))
    # Logical mask update: still remaining AND not in this component. graph-tool stores
    # 'bool' properties as uint8, so plain subtraction would wrap to 255 if the masks
    # ever disagreed; explicit comparisons cannot.
    graph_remaining = full_graph.new_vertex_property(
        'bool', vals=(graph_remaining.a > 0) & (this_component.a == 0)
    )
    last_graph_size = this_component.a.sum()
    if last_graph_size <= MIN_COMPONENT_SIZE:
        break

len(component_graphs)

In [None]:
# The largest component has a huge fraction of the unitigs.
# Display up to the first seven components; a slice (unlike fixed indices [0]..[6])
# does not raise IndexError when fewer than seven components were selected.
component_graphs[:7]

In [None]:
c = 5  # Index into component_graphs (0 = largest) of the component to assemble.

draw_graphs = True  # Compute layouts and write graph drawings in the cells below.

# component = c
component = 18  # Only the label for plotting
# NOTE(review): `component` only names the output directory and differs from `c`;
# confirm that 18 is the intended label for component_graphs[c].

# Later cells save figures / FASTA files under this directory; create it up front so
# plt.savefig / open() do not fail with FileNotFoundError on a fresh run.
_os.makedirs(f'nb/fig/component-{component}', exist_ok=True)

# Materialize the selected component as a standalone (pruned) graph, then activate the
# graph's stored 'filter' vertex property as its vertex filter.
curr_graph = gt.Graph(component_graphs[c], prune=True)
curr_graph.set_vertex_filter(curr_graph.vp['filter'])

# Seed both numpy and graph-tool RNGs for reproducibility.
np.random.seed(1)
gt.seed_rng(1)

In [None]:
# Build the GraphManager used by gm.batch_trim / batch_unzip / batch_press below.
# Length, sequence, and vector (per-sample) depth handlers are always registered; position
# handlers are added only when graphs are drawn, in which case an initial xy-layout
# weighted by total bases (length * depth summed over samples) is computed first.
if draw_graphs:
    total_bases = curr_graph.new_vertex_property(
        'float',
        vals=curr_graph.vp.length.fa * curr_graph.vp.depth.get_2d_array(pos=range(curr_graph.gp['num_samples'])).sum(0),
    )
    sz.draw.update_xypositions(curr_graph, vweight=total_bases)

unzippers = [
    sz.graph_manager.LengthUnzipper(),
    sz.graph_manager.SequenceUnzipper(),
    sz.graph_manager.VectorDepthUnzipper(),
]
pressers = [
    sz.graph_manager.LengthPresser(),
    sz.graph_manager.SequencePresser(sep=","),
    sz.graph_manager.VectorDepthPresser(),
]
if draw_graphs:
    unzippers.append(sz.graph_manager.PositionUnzipper(offset=(0.1, 0.1)))
    pressers.append(sz.graph_manager.PositionPresser())

gm = sz.graph_manager.GraphManager(unzippers=unzippers, pressers=pressers)
gm.validate(curr_graph)

In [None]:
# Snapshot the graph before any trimming / unzipping (used for plotting later),
# then summarize its vertex degrees.
original_graph = curr_graph.copy()
sz.stats.degree_stats(curr_graph)

In [None]:
assembly_stage = 1

# Calculate flows: estimate edge flow separately for each sample, then group the
# per-sample results into a single vector-valued property.
n_samples = curr_graph.gp['num_samples']
flow = []
for sample_id in range(n_samples):
    sample_depth = gt.ungroup_vector_property(curr_graph.vp['depth'], pos=[sample_id])[0]
    one_flow, _, _ = sz.flow.estimate_flow(curr_graph, sample_depth, curr_graph.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(n_samples))
# NOTE(review): `flow` is not used in this cell; later cells re-estimate it with
# sz.flow.estimate_all_flows (presumably equivalent — confirm and consolidate).

# Total depth per vertex (summed over samples), shared by the histogram and the coloring.
total_depth = curr_graph.vp['depth'].get_2d_array(range(n_samples)).sum(0)

# Initial depths
plt.hist2d(
    curr_graph.vp['length'].fa,
    total_depth,
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar()
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (summed over samples)')
plt.title(f'Component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Square-root scaling compresses the dynamic range of the fill colors.
    _color = curr_graph.new_vertex_property('float', vals=np.sqrt(total_depth))
    sz.draw.draw_graph(
        curr_graph,
        vertex_text=curr_graph.vp['length'],
        vertex_fill_color=_color,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=mpl.cm.magma,
    )

In [None]:
# TODO: Consider dropping low depth vertices/edges
# depth_thresh = 0.1
# # Drop edges with low depth
# low_depth_edge = curr_graph.new_edge_property('float', vals=flow.get_2d_array(pos=range(curr_graph.gp['num_samples'])).sum(0) < depth_thresh)
# low_depth_edges = find_edge(curr_graph, low_depth_edge, True)
# for e in low_depth_edges:
#     curr_graph.remove_edge(e)
# low_depth_vertices = idxwhere(sz.results.extract_vertex_data(curr_graph, seqs).total_depth < depth_thresh)
# print(len(tips), len(low_depth_vertices), len(set(tips) & set(low_depth_vertices)))

In [None]:
# Trim tips: vertices reported by sz.assembly.find_tips that are also shorter than the
# k-mer length.
is_short = curr_graph.vp['length'].a < curr_graph.gp['kmer_length']
tips = sz.assembly.find_tips(curr_graph, also_required=is_short)
print(len(tips))
gm.batch_trim(curr_graph, tips)

# Snapshot after trimming but before pressing (used for plotting later).
original_graph_no_tips = curr_graph.copy()

# Press every maximal unitig path into a single vertex.
press_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(curr_graph)]
new_tigs = gm.batch_press(curr_graph, *press_jobs)
print(len(new_tigs))

In [None]:
# Second round of tip trimming: repeat the same trim-then-press step on the pressed graph.
short_vertex = curr_graph.vp['length'].a < curr_graph.gp['kmer_length']
second_round_tips = sz.assembly.find_tips(curr_graph, also_required=short_vertex)
print(len(second_round_tips))
gm.batch_trim(curr_graph, second_round_tips)
second_round_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(curr_graph)]
second_round_new_tigs = gm.batch_press(curr_graph, *second_round_jobs)
print(len(second_round_new_tigs))

In [None]:
assembly_stage = 2

# NOTE(review): `flow` is not used in this cell; the deconvolution loop below re-estimates it.
flow = sz.flow.estimate_all_flows(curr_graph)

# Total depth per vertex (summed over samples), shared by the histogram and the coloring.
total_depth = curr_graph.vp['depth'].get_2d_array(range(curr_graph.gp['num_samples'])).sum(0)

# Depth vs. length after tip trimming and unitig pressing
plt.hist2d(
    curr_graph.vp['length'].fa,
    total_depth,
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar()
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (summed over samples)')
plt.title(f'Component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Square-root scaling compresses the dynamic range of the fill colors.
    _color = curr_graph.new_vertex_property('float', vals=np.sqrt(total_depth))
    sz.draw.draw_graph(
        curr_graph,
        vertex_text=curr_graph.vp['length'],
        vertex_fill_color=_color,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=mpl.cm.magma,
    )

In [None]:
assembly_iters = 3

# Settings passed to every junction-deconvolution call.
deconvolution_kwargs = dict(
    forward_stop=0.0,
    backward_stop=0.0,
    alpha=1.0,
    score_margin_thresh=20.,
    condition_thresh=1e5,
    max_paths=20,
    processes=2,
)

# Each round: re-estimate flows, unzip the junction deconvolutions that are found, then
# press the resulting maximal unitig paths. Stop early once a round unzips nothing.
for i in range(assembly_iters):
    print(f"Deconvolution round {i}.")
    flow = sz.flow.estimate_all_flows(curr_graph)
    deconvolutions = sz.assembly.parallel_calculate_all_junction_deconvolutions(
        curr_graph, flow, **deconvolution_kwargs
    )
    new_unzipped_vertices = gm.batch_unzip(curr_graph, *deconvolutions)
    press_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(curr_graph)]
    new_pressed_vertices = gm.batch_press(curr_graph, *press_jobs)
    n_unzipped = len(new_unzipped_vertices)
    print(f"Unzipped: {n_unzipped} resulting in joining old tigs into {len(new_pressed_vertices)} new tigs out of {curr_graph.num_vertices()}.")
    if not n_unzipped:
        print("No vertices unzipped. Stopping early.")
        break

In [None]:
assembly_stage = 3

# NOTE(review): `flow` is not used in this cell; the deconvolution loop below re-estimates it.
flow = sz.flow.estimate_all_flows(curr_graph)

# Total depth per vertex (summed over samples), shared by the histogram and the coloring.
total_depth = curr_graph.vp['depth'].get_2d_array(range(curr_graph.gp['num_samples'])).sum(0)

# Depth vs. length after the first set of deconvolution rounds
plt.hist2d(
    curr_graph.vp['length'].fa,
    total_depth,
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar()
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (summed over samples)')
plt.title(f'Component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Square-root scaling compresses the dynamic range of the fill colors.
    _color = curr_graph.new_vertex_property('float', vals=np.sqrt(total_depth))
    sz.draw.draw_graph(
        curr_graph,
        vertex_text=curr_graph.vp['length'],
        vertex_fill_color=_color,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=mpl.cm.magma,
    )

In [None]:
assembly_iters = 3

# A second batch of deconvolution rounds, with the same settings as the first batch.
deconvolution_kwargs = dict(
    forward_stop=0.0,
    backward_stop=0.0,
    alpha=1.0,
    score_margin_thresh=20.,
    condition_thresh=1e5,
    max_paths=20,
    processes=2,
)

for round_idx in range(assembly_iters):
    print(f"Deconvolution round {round_idx}.")
    flow = sz.flow.estimate_all_flows(curr_graph)
    deconvolutions = sz.assembly.parallel_calculate_all_junction_deconvolutions(
        curr_graph, flow, **deconvolution_kwargs
    )
    new_unzipped_vertices = gm.batch_unzip(curr_graph, *deconvolutions)
    press_jobs = [(path, {}) for path in sz.assembly.iter_maximal_unitig_paths(curr_graph)]
    new_pressed_vertices = gm.batch_press(curr_graph, *press_jobs)
    n_unzipped = len(new_unzipped_vertices)
    print(f"Unzipped: {n_unzipped} resulting in joining old tigs into {len(new_pressed_vertices)} new tigs out of {curr_graph.num_vertices()}.")
    if not n_unzipped:
        print("No vertices unzipped. Stopping early.")
        break

In [None]:
assembly_stage = 4

# Calculate flows: estimate edge flow separately for each sample, then group the
# per-sample results into a single vector-valued property.
# NOTE(review): stages 2 and 3 use sz.flow.estimate_all_flows instead; presumably
# equivalent — confirm and consolidate. `flow` is not used in this cell.
n_samples = curr_graph.gp['num_samples']
flow = []
for sample_id in range(n_samples):
    sample_depth = gt.ungroup_vector_property(curr_graph.vp['depth'], pos=[sample_id])[0]
    one_flow, _, _ = sz.flow.estimate_flow(curr_graph, sample_depth, curr_graph.vp['length'])
    flow.append(one_flow)
flow = gt.group_vector_property(flow, pos=range(n_samples))

# Total depth per vertex (summed over samples), shared by the histogram and the coloring.
total_depth = curr_graph.vp['depth'].get_2d_array(range(n_samples)).sum(0)

# Depth vs. length after the second set of deconvolution rounds
plt.hist2d(
    curr_graph.vp['length'].fa,
    total_depth,
    bins=(length_bins, depth_bins),
    norm=mpl.colors.LogNorm(vmin=1, vmax=1e3),
)
plt.colorbar()
plt.xscale('log')
plt.yscale('log')
plt.xlabel('vertex length')
plt.ylabel('total depth (summed over samples)')
plt.title(f'Component {component}, stage {assembly_stage}')
plt.savefig(f'nb/fig/component-{component}/hist_stage{assembly_stage}.pdf')

if draw_graphs:
    # Square-root scaling compresses the dynamic range of the fill colors.
    _color = curr_graph.new_vertex_property('float', vals=np.sqrt(total_depth))
    sz.draw.draw_graph(
        curr_graph,
        vertex_text=curr_graph.vp['length'],
        vertex_fill_color=_color,
        output=f'nb/fig/component-{component}/graph_stage{assembly_stage}.pdf',
        vcmap=mpl.cm.magma,
    )

In [None]:
# WORKHERE

In [None]:
# Per-vertex results table, plus an `assembly` column produced by
# sz.results.assemble_overlapping_unitigs from each vertex's segments
# (k = the graph's k-mer length).
vertex_results0 = sz.results.extract_vertex_data(curr_graph).assign(
    assembly=lambda df: df.segments.apply(
        sz.results.assemble_overlapping_unitigs,
        unitig_to_sequence=unitig_to_sequence,
        k=curr_graph.gp['kmer_length'],
    )
)
# Vertices built from the most segments first.
vertex_results0.sort_values('num_segments', ascending=False).head(10)

In [None]:
# Vertices that list themselves among their own in- or out-neighbors (self-loops),
# largest (by number of segments) first.
has_self_loop = vertex_results0.apply(
    lambda row: (row.name in row.in_neighbors) or (row.name in row.out_neighbors),
    axis=1,
)
vertex_results0[has_self_loop].sort_values('num_segments', ascending=False).head(5)

In [None]:
# Low depth: total depth strictly between 10 and 30, longest vertices first.
low_depth_mask = vertex_results0.total_depth.between(10, 30, inclusive='neither')
vertex_results0.loc[low_depth_mask].sort_values('length', ascending=False).head(10)

In [None]:
assembly_stage = '_final'
v = 1341  # Vertex of interest in curr_graph (hard-coded; presumably picked from the tables above).

print(v)
print(curr_graph.vp.length[v])
print(curr_graph.vp.depth[v])
print(curr_graph.vp.sequence[v])
print()

# NOTE(review): `depth_table` is not defined anywhere in this notebook.
# fig = plt.figure(figsize=(5, 3))
# sns.heatmap(depth_table.sel(unitig=[int(s[:-1]) for s in curr_graph.vp.sequence[v].split(',')]).to_pandas().T, norm=mpl.colors.SymLogNorm(1e-1))

# Flag vertices of the tip-trimmed, not-yet-pressed snapshot whose sequence is one of the
# comma-separated segments that make up vertex v. Splitting once into a set avoids
# re-splitting v's sequence for every vertex visited by map_property_values.
v_segments = set(curr_graph.vp.sequence[v].split(','))
in_seq = original_graph_no_tips.new_vertex_property('bool', val=False)
gt.map_property_values(original_graph_no_tips.vp.sequence, in_seq, lambda seq: seq in v_segments)

# Mean depth across samples, a flow estimate from that single depth profile, and a
# sqrt-scaled fill color.
one_depth = original_graph_no_tips.new_vertex_property(
    'float',
    vals=original_graph_no_tips.vp['depth'].get_2d_array(pos=range(original_graph_no_tips.gp['num_samples'])).mean(0),
)
one_flow, _, _ = sz.flow.estimate_flow(original_graph_no_tips, one_depth, original_graph_no_tips.vp['length'])
_color = original_graph_no_tips.new_vertex_property('float', vals=np.sqrt(one_depth.a))

if draw_graphs:
    outpath = f'nb/fig/component-{component}/graph_stage{assembly_stage}_seq{v}_id.pdf'
    print(outpath)
    sz.draw.draw_graph(
        original_graph_no_tips,
        vertex_text=original_graph_no_tips.vp['sequence'],
        vertex_halo=in_seq,  # Highlight the segments belonging to vertex v.
        # vertex_text=in_seq,
        vertex_font_size=1,
        vertex_fill_color=_color,
        # Fifth-root scaling keeps edge widths readable across a wide range of flows.
        edge_pen_width=original_graph_no_tips.new_edge_property('float', vals=one_flow.a ** (1/5)),
        output=outpath,
        vcmap=(mpl.cm.magma, 1),
    )

In [None]:
# Segments shared by the most result vertices (ties broken by segment name).
# rename_axis + reset_index(name=...) pins the column names to ('segments', 'count')
# regardless of pandas version; bare value_counts().reset_index() names them differently
# before pandas 2.0.
(
    vertex_results0.segments.explode()
    .value_counts()
    .rename_axis('segments')
    .reset_index(name='count')
    .sort_values(['count', 'segments'], ascending=[False, True])
    .head(10)
)

In [None]:
# Deliberate stop so "Run All" halts here; the cells below are presumably exploratory.
# An explicit raise is used because `assert` statements are removed under `python -O`.
raise AssertionError("Intentional stop before the remaining cells.")

In [None]:
u = "1471216-"  # Focal segment/unitig
# Vertices whose segment list contains u, and the integer unitig IDs of all their segments
# (the last character of each segment label, e.g. the '-' in u, is dropped).
vertex_list = idxwhere(vertex_results0.segments.apply(lambda segs: u in segs))
unitigs = [
    int(segment[:-1])
    for segs in vertex_results0.loc[vertex_list].segments
    for segment in segs
]

# d1 = depth_table.sel(unitig=unitigs).to_pandas().T
# fig = plt.figure()
# sns.clustermap(d1, norm=mpl.colors.SymLogNorm(1e-1), col_cluster=False, metric='cosine')

# d2 = pd.DataFrame(np.stack([graph2.vp['depth'][i] for i in vertex_list]), index=vertex_list).T
# fig = plt.figure()
# sns.clustermap(d2, norm=mpl.colors.SymLogNorm(1e-1), col_cluster=False, metric='cosine')


# Write each matching vertex's assembled sequence as a FASTA record.
path = f'nb/fig/component-{component}/seqs_stage_final_node{u}.fn'
with open(path, 'w') as fasta:
    for vertex, row in vertex_results0.loc[vertex_list].iterrows():
        print(f">{vertex}\n{row.assembly}", file=fasta)
print(path)

vertex_results0.loc[vertex_list]

In [None]:
# Show the segments of the first two vertices containing u side by side. zip_longest pads
# the shorter list, so trailing segments of the longer path are shown rather than silently
# dropped as they would be with plain zip.
from itertools import zip_longest

if len(vertex_list) < 2:
    print(f"Fewer than two vertices contain {u}; nothing to compare.")
else:
    first_segments = vertex_results0.loc[vertex_list[0]].segments
    second_segments = vertex_results0.loc[vertex_list[1]].segments
    for x, y in zip_longest(first_segments, second_segments, fillvalue='<missing>'):
        print(x, y)