In [None]:
# Timestamp this run of the notebook (IPython `!` runs a shell command).
!date

In [None]:
%load_ext autoreload
%load_ext line_profiler

In [None]:
import os as _os

# Paths used in later cells ('examples/...', 'nb/fig/...') are relative to the
# project root, so move there first. A missing PROJECT_ROOT raises KeyError.
_project_root = _os.environ['PROJECT_ROOT']
_os.chdir(_project_root)

In [None]:
from contextlib import contextmanager
from itertools import product
from itertools import chain
from pathlib import Path

import graph_tool as gt
import graph_tool.draw
from graph_tool.util import find_edge
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import xarray as xr
from tqdm import tqdm
import scipy as sp
import scipy.cluster.hierarchy  # makes sp.cluster.hierarchy available explicitly

import strainzip as sz
from strainzip.pandas_util import idxwhere

In [None]:
# Plotting parameters

# 51 log-spaced bin edges (i.e. 50 bins) each: 10**0 .. 10**6.5 for lengths,
# 10**-1 .. 10**4 for depths.
length_bins = np.logspace(start=0, stop=6.5, num=51)
depth_bins = np.logspace(start=-1, stop=4, num=51)

draw_graphs = True  # not referenced in the cells shown; presumably gates graph drawing
run_number = 20  # Label for output files/figures (used in the nb/fig/run-<N>/ paths)

In [None]:
# Load unitig sequences from the ggcat linked FASTA (presumably a
# unitig -> sequence mapping, as consumed by assemble_overlapping_unitigs below).
# The graph the loader also returns is discarded. 111 appears to be the k-mer
# length and must agree with the 'k111' in the file name.
with open('examples/xjin_test5/r.proc.kmtricks-k111-m3-r2.ggcat.fn') as f:
    _, unitig_to_sequence = sz.io.load_graph_and_sequences_from_linked_fasta(f, 111, sz.io.ggcat_header_tokenizer)

In [None]:
# Final graph after tip removal and deconvolution (per the '.notips.deconvolve'
# file name); its vertices are the final paths analysed below.
final_graph = sz.io.load_graph('examples/xjin_test5/r.proc.kmtricks-k111-m3-r2.ggcat.notips.deconvolve.sz')

In [None]:
# Per-path table for the final graph, plus each path's sequence assembled from
# its segments via sz.results.assemble_overlapping_unitigs.
final_results = (
    sz.results.extract_vertex_data(final_graph)
    .assign(
        assembly=lambda d: d.segments.apply(
            sz.results.assemble_overlapping_unitigs,
            unitig_to_sequence=unitig_to_sequence,
            k=final_graph.gp['kmer_length'],
        )
    )
)
# Preview the five longest paths.
final_results.sort_values(by=['length'], ascending=False).head(n=5)

In [None]:
# Original input graph (the '.gt' file; presumably the graph before tip removal
# and deconvolution, cf. the '.notips.deconvolve.sz' file loaded above). Used
# below for per-unitig depths and for drawing the focal path's neighborhood.
original_graph = sz.io.load_graph('examples/xjin_test5/r.proc.kmtricks-k111-m3-r2.ggcat.gt')

In [None]:
# Per-vertex table for the original graph; preview its five longest vertices.
original_results = sz.results.extract_vertex_data(original_graph)
original_results.sort_values(by='length', ascending=False).iloc[:5]

In [None]:
# Depth table for the original graph, with its index relabelled from vertex
# (presumably vertex ids) to that vertex's first segment, so that later cells
# can select rows by segment.
unitig_depth_table = sz.results.full_depth_table(original_graph).rename(
    index=original_results.segments.str[0],
)
unitig_depth_table.shape

In [None]:
# Pick one final path, gather every final path sharing a segment with it, write
# their assemblies to a FASTA file, and draw their neighborhood in the original
# graph (one PDF per related path, highlighting that path's vertices).
focal_path = 128053
# Vertices whose distance to the core (from sz.topology.get_shortest_distance,
# weighted by the 'length' vertex property) is below this are included.
neighborhood_radius = 1000

related_paths = list(sz.results.iter_find_vertices_with_any_segment(final_graph, final_results.loc[focal_path].segments))
# De-duplicate while keeping first-seen order. A set would give a run-dependent
# order when segment labels are strings (str hashing is randomized per process),
# reshuffling the columns used for the heatmaps in the next cell.
focal_segments = list(dict.fromkeys(chain.from_iterable(final_results.loc[related_paths].segments)))

original_graph_core_vertices = list(sz.results.iter_find_vertices_with_any_segment(original_graph, focal_segments))
original_graph_distance_to_core = sz.topology.get_shortest_distance(original_graph, original_graph_core_vertices, original_graph.vp['length'])
in_neighborhood = original_graph.new_vertex_property('bool', vals=original_graph_distance_to_core.a < neighborhood_radius)

neighborhood_graph = gt.GraphView(original_graph, vfilt=in_neighborhood)
sz.draw.update_xypositions(neighborhood_graph)

# Square root of total depth compresses the color range.
vertex_color = neighborhood_graph.new_vertex_property('float', vals=sz.results.total_depth_property(neighborhood_graph).a**(1/2))

# Create the output directory up front: open() fails if it does not exist
# (and presumably so does the PDF output below).
fig_dir = f'nb/fig/run-{run_number}'
Path(fig_dir).mkdir(parents=True, exist_ok=True)

outpath = f'{fig_dir}/final_paths.neighborhood-{focal_path}.fn'
with open(outpath, 'w') as f:
    for path, row in final_results.loc[related_paths].iterrows():
        print(f">{path}\n{row.assembly}", file=f)
print(outpath)

for path in related_paths:
    original_graph_vertices = list(sz.results.iter_find_vertices_with_any_segment(original_graph, final_results.loc[path].segments))
    print(f'path {path}: {len(original_graph_vertices)} original-graph vertices')
    in_path = original_graph.new_vertex_property('bool', val=False)
    in_path.a[original_graph_vertices] = True
    outpath = f'{fig_dir}/final_paths.neighborhood-{focal_path}-{path}.pdf'
    sz.draw.draw_graph(
        neighborhood_graph,
        vertex_text=neighborhood_graph.vp['sequence'],
        vertex_halo=in_path,
        vertex_font_size=5,
        vertex_fill_color=vertex_color,
        output=outpath,
        vcmap=(mpl.cm.magma, 1),
        output_size=(1000, 1000),
    )
    print(outpath)

final_results.loc[related_paths]

In [None]:
# Compare observed unitig depths with the depths predicted by the related final
# paths (path depth @ path membership), and show the residuals.
unitig_depth = unitig_depth_table.loc[focal_segments].T  # rows: depth-table columns (presumably samples); cols: segments
# rows: paths; cols: segments; values: occurrence counts of each segment in each path.
path_membership = final_results.loc[related_paths].segments.explode().reset_index().value_counts().unstack(fill_value=0)
path_depth = pd.DataFrame({p: final_graph.vp['depth'][p] for p in related_paths})  # cols: paths

predicted_unitig_depth = path_depth @ path_membership

obs = unitig_depth
expect = predicted_unitig_depth.loc[obs.index, obs.columns]
membership = path_membership.loc[:, obs.columns]
resid = obs - expect

# A precomputed linkage refers to rows/columns by POSITION, so every heatmap
# below must be drawn from a frame whose axis is in the same order as the frame
# the linkage was computed from.
unitig_linkage = sp.cluster.hierarchy.linkage(membership.T, metric='cosine', method='average')
path_linkage = sp.cluster.hierarchy.linkage(membership, metric='cosine', method='average')
sample_linkage = sp.cluster.hierarchy.linkage(unitig_depth, metric='cosine', method='average')

max_obs = obs.max().max()
# Symmetric color limits must cover the largest residual in either direction.
max_resid = resid.abs().max().max()

# Plot `membership`, not `path_membership`: its columns follow `obs` (the order
# unitig_linkage was built from), whereas path_membership's columns are in the
# order produced by value_counts().unstack().
sns.clustermap(membership, row_linkage=path_linkage, col_linkage=unitig_linkage, figsize=(10, 5))
sns.clustermap(obs, row_linkage=sample_linkage, col_linkage=unitig_linkage, figsize=(10, 5), norm=mpl.colors.SymLogNorm(1, vmin=0, vmax=max_obs))
sns.clustermap(expect, row_linkage=sample_linkage, col_linkage=unitig_linkage, figsize=(10, 5), norm=mpl.colors.SymLogNorm(1, vmin=0, vmax=max_obs))
sns.clustermap(resid, row_linkage=sample_linkage, col_linkage=unitig_linkage, figsize=(10, 5), norm=mpl.colors.SymLogNorm(1, vmin=-max_resid, vmax=max_resid), cmap='coolwarm')
# Reorder rows to match `obs`, the frame sample_linkage was built from.
sns.clustermap(path_depth.loc[obs.index], row_linkage=sample_linkage, figsize=(3, 3), norm=mpl.colors.PowerNorm(1/2, vmin=0, vmax=max_obs));