In [None]:
import pickle, os, gzip, json, sys, itertools
from pathlib import Path
from importlib import reload
from dataclasses import dataclass, field
import collections
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import pysam
import scipy as sp
import seaborn
import sharedmem
import pygraphviz


# Figure defaults for every plot produced below: white background, 300 dpi.
plt.rcParams.update({"figure.facecolor": "white", "figure.dpi": 300})

# Extend the import search path with both candidate locations of the scripts
# directory (relative to the current working directory).
sys.path.extend(["scripts", "../../scripts"])

In [None]:
from data_io import is_fwd_id, get_fwd_id, get_sibling_id
from graph import OverlapGraph, GenomicInterval, get_overlap_statistics, remove_false_edges
from truth import get_overlaps
from evaluate import NearestNeighborsConfig, mp_compute_nearest_neighbors
from plots import plot_read_graph, mp_plot_read_graphs, get_graphviz_layout, get_umap_layout

In [None]:
# Upper bound on the number of metadata rows kept when the table is loaded.
MAX_SAMPLE_SIZE = 10**9
# NOTE(review): not referenced by any cell visible in this notebook.
COVERAGE_DEPTH = 20
# Largest neighbourhood size k for which an overlap graph is built.
max_n_neighbors = 20

In [None]:
# Wildcards identifying this run; the "platform" wildcard is exposed as `dataset`.
sample, dataset, region, method = (
    snakemake.wildcards[key] for key in ("sample", "platform", "region", "method")
)

# Named inputs supplied by the workflow rule.
npz_path, tsv_path, json_path, nbr_path, time_path = (
    snakemake.input[key]
    for key in ("feature_matrix", "metadata", "read_features", "nbr_indice", "time_use")
)

# Named outputs this notebook writes: the statistics table and the figure.
stat_path, fig_path = (snakemake.output[key] for key in ("overlap", "fig"))

# Thread budget granted by the workflow; used as the plotting worker count.
threads = snakemake.threads

print(sample, dataset, region)

In [None]:
# Per-read metadata, truncated to at most MAX_SAMPLE_SIZE rows. reset_index()
# installs a fresh 0..n-1 index and keeps the previous index as a column.
meta_df = (
    pd.read_table(tsv_path)
    .iloc[:MAX_SAMPLE_SIZE, :]
    .reset_index()
)

# Reverse lookup: read name -> row label in meta_df.
read_indices = dict(zip(meta_df["read_name"], meta_df.index))

# Keep only the feature-matrix rows that match the retained metadata rows.
feature_matrix = sp.sparse.load_npz(npz_path)[meta_df.index, :]

with gzip.open(json_path, "rt") as handle:
    read_features = json.load(handle)
# Re-key the loaded features by metadata row label (done once the file is closed).
read_features = {row: read_features[row] for row in meta_df.index}

# Uniform weight of 1 for every feature column.
feature_weights = dict.fromkeys(range(feature_matrix.shape[1]), 1)

In [None]:
# NOTE(security): pickle.load can execute arbitrary code while deserialising, so
# time_path must only ever point at a file produced by a trusted workflow step.
with open(time_path, "rb") as file:
    time_dict = pickle.load(file)

# np.load on an .npz archive returns an NpzFile that keeps the archive open.
# Reading the array inside a context manager releases the file handle; the array
# itself is fully materialised by the indexing and stays valid afterwards.
with np.load(nbr_path) as data:
    nbr_indices = data["arr_0"]

In [None]:
# Node identifiers: the keys of read_features, in insertion order.
read_ids = np.array([*read_features])
# Neighbourhood sizes to evaluate: 2 .. max_n_neighbors inclusive.
k_values = np.arange(2, max_n_neighbors + 1)

# One overlap graph per neighbourhood size, keyed by k.
graphs = collections.defaultdict(dict)
for k in k_values:
    graph = OverlapGraph.from_neighbor_indices(
        read_ids=read_ids,
        neighbor_indices=nbr_indices,
        n_neighbors=k,
        require_mutual_neighbors=False,
    )
    graphs[k] = graph

In [None]:
def get_read_intervals(meta_df):
    """Map every row label of ``meta_df`` to a one-element list of intervals.

    Each interval is built as ``GenomicInterval(strand, start, end)`` from the
    row's ``reference_strand``, ``reference_start`` and ``reference_end``
    columns.

    Returns:
        dict: ``{row_label: [GenomicInterval]}``, in row order.
    """
    coordinate_columns = ("reference_strand", "reference_start", "reference_end")
    intervals_by_read = {}
    rows = zip(meta_df.index, *(meta_df[column] for column in coordinate_columns))
    for read_id, *coordinates in rows:
        intervals_by_read[read_id] = [GenomicInterval(*coordinates)]
    return intervals_by_read

read_intervals = get_read_intervals(meta_df)
reference_graph = OverlapGraph.from_intervals(read_intervals)

# Reference edges whose 'redundant' attribute is falsy.
nr_edges = {
    (u, v)
    for u, v, attrs in reference_graph.edges(data=True)
    if not attrs["redundant"]
}
connected_component_count = sum(1 for _ in nx.connected_components(reference_graph))

# Cell output: (#nodes, #edges, #non-redundant edges, #connected components).
(
    len(reference_graph.nodes),
    len(reference_graph.edges),
    len(nr_edges),
    connected_component_count,
)

In [None]:
# The timing entries do not depend on k, so build the prefixed columns once.
# The comprehension variables are named so they do not shadow the loop's `k`.
time_stats = {f"elapsed_time:{name}": value for name, value in time_dict.items()}

df_rows = []
for k in k_values:
    graph = graphs[k]
    graph_stats = get_overlap_statistics(query_graph=graph, reference_graph=reference_graph)
    # The original referenced `config.description`, but no `config` object is ever
    # created in this notebook (NameError). The `method` wildcard identifies the
    # run, so it is used as the description.
    stats = {"description": method, "n_neighbors": k, **graph_stats, **time_stats}
    df_rows.append(stats)

# One row per k; written as a tab-separated table to the workflow's output path.
df = pd.DataFrame(df_rows)
df["connected_fraction"] = 1 - df["singleton_fraction"]
df.to_csv(stat_path, sep="\t")

In [None]:
def remove_small_components(graph, min_component_size=10):
    """Delete, in place, every node in a connected component that is too small.

    A component is dropped when it has fewer than ``min_component_size`` nodes.
    All components are enumerated before any node is removed.
    """
    undersized_nodes = set().union(
        *(
            component
            for component in nx.connected_components(graph)
            if len(component) < min_component_size
        )
    )
    graph.remove_nodes_from(undersized_nodes)
            
    
def plot_graphs(graphs, reference_graph, metadata, *, min_component_size=10,
                processes=8, layout_method='sfdp', figsize=(6, 6), node_size=3,
                seed: int = 4829, verbose=True, title=None):
    """Draw one query graph twice, side by side: all edges vs. correct edges only.

    Parameters:
        graphs: A single query graph (despite the plural name). It is copied
            before any pruning, so the caller's graph is not modified.
        reference_graph: Graph passed to ``remove_false_edges`` and to
            ``plot_read_graph``.
        metadata: Passed through to ``plot_read_graph``.
        min_component_size: Connected components with fewer nodes than this are
            removed from each panel's graph before layout.
        processes: Number of ``sharedmem`` workers computing layouts.
        layout_method: ``"umap"`` selects ``get_umap_layout``; any other value
            is forwarded to ``get_graphviz_layout``. The default used to be the
            typo ``'stdp'``, which is not a Graphviz engine; it is now ``'sfdp'``.
        figsize: Size of ONE panel. The figure is two panels wide, so the
            default ``(6, 6)`` gives the ``(12, 6)`` figure that was previously
            hard-coded. Also forwarded to ``get_graphviz_layout``.
        node_size: Passed through to ``plot_read_graph``.
        seed: Seed forwarded to ``get_graphviz_layout``.
        verbose: When true, print each panel index as it is drawn.
        title: Figure super-title. Defaults to the notebook-level ``fig_path``,
            which is what was always used before.

    Returns:
        The matplotlib figure holding both panels.
    """
    if title is None:
        # Preserve the previous behaviour of labelling the figure with the
        # notebook's output path.
        title = fig_path

    panel_width, panel_height = figsize
    fig, (ax1, ax2) = plt.subplots(
        1, 2, figsize=(2 * panel_width, panel_height), constrained_layout=True
    )
    ax1.set_title("All edges")
    ax2.set_title("Correct edges")

    # Panel 1: the query graph with small components pruned.
    all_edges_graph = graphs.copy()
    remove_small_components(all_edges_graph, min_component_size=min_component_size)

    # Panel 2: start from the pruned graph, drop edges rejected by
    # remove_false_edges, then prune again because removing edges can split
    # components into new small ones.
    correct_edges_graph = all_edges_graph.copy()
    remove_false_edges(correct_edges_graph, reference_graph)
    remove_small_components(correct_edges_graph, min_component_size=min_component_size)

    axes = [ax1, ax2]
    query_graphs = [all_edges_graph, correct_edges_graph]

    fig.suptitle(title, ha="center", va="bottom", wrap=True, size=7)

    def plot(i, pos):
        plot_read_graph(
            ax=axes[i],
            query_graph=query_graphs[i],
            reference_graph=reference_graph,
            metadata=metadata,
            pos=pos,
            node_size=node_size,
        )

    with sharedmem.MapReduce(np=processes) as pool:

        def work(i):
            # Layout computation is delegated to the pool's workers.
            if layout_method == "umap":
                pos = get_umap_layout(graph=query_graphs[i])
            else:
                pos = get_graphviz_layout(
                    graph=query_graphs[i],
                    figsize=figsize,
                    seed=seed,
                    layout_method=layout_method,
                )
            return i, pos

        def reduce(i, pos):
            # Drawing happens in the reduce step, on the axes created above.
            if verbose:
                print(i, end=" ")
            plot(i, pos)

        pool.map(work, range(len(query_graphs)), reduce=reduce)
        if verbose:
            print("")

    return fig

In [None]:
%%time
# Plot the k = 6 graph (all edges next to correct edges only) and save the figure
# to the workflow's output path.
k6_graphs = graphs[6]
figures = plot_graphs(
    k6_graphs,
    reference_graph=reference_graph,
    metadata=meta_df,
    layout_method="sfdp",
    processes=threads,
)
figures.savefig(fig_path, dpi=300, bbox_inches="tight")