In [None]:
import pickle, os, gzip, json, sys, itertools
from pathlib import Path
from importlib import reload
from dataclasses import dataclass, field
import collections
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import pysam
import scipy as sp
import seaborn
import sharedmem

# White, high-resolution figures for every plot in the notebook.
plt.rcParams.update({"figure.facecolor": "white", "figure.dpi": 300})

# Both relative locations of the project's helper modules are added, presumably so the
# imports below resolve from either working directory — TODO confirm.
sys.path.extend(["scripts", "../../scripts"])

In [None]:
from data_io import is_fwd_id, get_fwd_id, get_sibling_id
from dim_reduction import SpectralEmbedding, scBiMapEmbedding
from nearest_neighbors import (
    ExactNearestNeighbors,
    NNDescent,
    WeightedLowHash,
    PAFNearestNeighbors,
    LowHash,
    HNSW,
    ProductQuantization,
    _NearestNeighbors,
)
from graph import OverlapGraph, GenomicInterval, get_overlap_statistics, remove_false_edges
from truth import get_overlaps
from evaluate import NearestNeighborsConfig, mp_compute_nearest_neighbors
from plots import plot_read_graph, mp_plot_read_graphs, get_graphviz_layout, get_umap_layout

## Parameters


In [None]:
# Maximum number of metadata rows (reads) to load; 10**9 is effectively "no limit".
MAX_SAMPLE_SIZE = 1_000_000_000
# Expected coverage depth; used below as the maximum neighbour count and to size hash buckets.
COVERAGE_DEPTH = 20

## Load data

In [None]:
# Run context provided by the Snakemake workflow through the global `snakemake` object.
# Note: the `platform` wildcard is stored under the name `dataset`.
sample, dataset, region = (snakemake.wildcards[key] for key in ("sample", "platform", "region"))

npz_path, tsv_path, json_path, paf_path = (
    snakemake.input[key] for key in ("feature_matrix", "metadata", "read_features", "paf")
)

output_folder = snakemake.output["folder"]

threads = snakemake.threads

print(sample, dataset, region)

In [None]:
# Per-read metadata, truncated to MAX_SAMPLE_SIZE rows and re-indexed 0..n-1.
# reset_index() keeps the previous index as an extra "index" column.
meta_df = pd.read_table(tsv_path).iloc[:MAX_SAMPLE_SIZE, :].reset_index()
# read name -> positional read id (row of meta_df / feature_matrix).
read_indices = dict(zip(meta_df["read_name"], meta_df.index))

# Sparse feature matrix restricted to the same rows as meta_df.
feature_matrix = sp.sparse.load_npz(npz_path)[meta_df.index, :]

with gzip.open(json_path, "rt") as f:
    read_features = json.load(f)
# Keep only the loaded reads. NOTE(review): this assumes the JSON top level is indexable by
# integer read id (e.g. a list); a JSON object would have string keys — TODO confirm.
read_features = {i: read_features[i] for i in meta_df.index}

# Uniform weight of 1 for every feature column.
feature_weights = dict.fromkeys(range(feature_matrix.shape[1]), 1)

In [None]:
# Histogram of how many features each read carries.
fig, ax = plt.subplots(figsize=(8, 2.5))
ax.hist(list(map(len, read_features.values())), bins=100)
ax.set(xlabel="Number of features per read", ylabel="Number of reads")
ax.set_xlim(left=0)
ax.grid(color='k', alpha=0.1)

In [None]:
# (rows, columns) of the feature matrix; rows follow meta_df's read order.
feature_matrix.shape

In [None]:
# Mean entry value: the sum of all entries divided by the number of cells
# (equal to the fraction of non-zero entries only if the matrix is binary).
feature_matrix.sum() / np.prod(feature_matrix.shape)

In [None]:
# Per-read metadata table loaded above.
meta_df

In [None]:
def get_read_intervals(meta_df):
    """Return ``{read index: [GenomicInterval(strand, start, end)]}`` for every row of ``meta_df``.

    Strand, start and end come from the ``reference_strand``, ``reference_start`` and
    ``reference_end`` columns. Each read gets a one-element list.
    """
    intervals = {}
    rows = zip(
        meta_df.index,
        meta_df["reference_strand"],
        meta_df["reference_start"],
        meta_df["reference_end"],
    )
    for read_id, strand, start, end in rows:
        intervals[read_id] = [GenomicInterval(strand, start, end)]
    return intervals

# Reference (ground-truth) interval for every read, keyed by read index; show how many there are.
read_intervals = get_read_intervals(meta_df)
len(read_intervals)

In [None]:
%%time
# Ground-truth overlap graph built from the reference intervals.
reference_graph = OverlapGraph.from_intervals(read_intervals)
# Edges whose 'redundant' attribute is falsy.
nr_edges = {
    (node_1, node_2)
    for node_1, node_2, data in reference_graph.edges(data=True)
    if not data['redundant']
}
connected_component_count = nx.number_connected_components(reference_graph)
len(reference_graph.nodes), len(reference_graph.edges), len(nr_edges), connected_component_count

## Get nearest neighbours

In [None]:
# Arguments shared by every config: all methods run on the same feature matrix.
kw = dict(data=feature_matrix)
# Bucket-size cap passed to the hashing methods: 1.5x the coverage depth (a float).
max_bucket_size = COVERAGE_DEPTH * 1.5


def _hash_kw(**sketch):
    """Keyword arguments shared by every (Weighted) LowHash / MinHash config.

    ``sketch`` is either ``lowhash_fraction=...`` (LowHash) or ``lowhash_count=...`` (MinHash).
    It is placed first, so the dict keeps the same key order as a hand-written literal.
    """
    return dict(**sketch, max_bucket_size=max_bucket_size, repeats=100, seed=458)


# Weighting naming used in the descriptions:
#   TF     -> tfidf=False (raw values)
#   IDF    -> binarize=True, tfidf=True
#   TF-IDF -> tfidf=True
configs = [
    # Overlaps taken from the minimap2 all-vs-all PAF file
    NearestNeighborsConfig(
        description="Minimap2 all-vs-all",
        nearest_neighbors_method=PAFNearestNeighbors,
        nearest_neighbors_kw=dict(paf_path=paf_path, read_indices=read_indices),
        **kw
    ),

    # Dimension reduction to 100 dimensions, then HNSW search
    NearestNeighborsConfig(
        description="HNSW (IDF, Spectral 100 dim.)",
        nearest_neighbors_method=HNSW,
        binarize=True,
        tfidf=True,
        dimension_reduction_method=SpectralEmbedding,
        dimension_reduction_kw=dict(n_dimensions=100),
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        description="HNSW (IDF, scBiMap 100 dim.)",
        nearest_neighbors_method=HNSW,
        binarize=True,
        tfidf=True,
        dimension_reduction_method=scBiMapEmbedding,
        dimension_reduction_kw=dict(n_dimensions=100),
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        description="HNSW (TF-IDF, Spectral 100 dim.)",
        nearest_neighbors_method=HNSW,
        tfidf=True,
        dimension_reduction_method=SpectralEmbedding,
        dimension_reduction_kw=dict(n_dimensions=100),
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        description="HNSW (TF-IDF, scBiMap 100 dim.)",
        nearest_neighbors_method=HNSW,
        tfidf=True,
        dimension_reduction_method=scBiMapEmbedding,
        dimension_reduction_kw=dict(n_dimensions=100),
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),

    # Unweighted LowHash / MinHash
    NearestNeighborsConfig(
        description="LowHash (binary)",
        nearest_neighbors_method=LowHash,
        tfidf=False,
        nearest_neighbors_kw=_hash_kw(lowhash_fraction=0.01),
        **kw
    ),
    NearestNeighborsConfig(
        description="MinHash (binary)",
        nearest_neighbors_method=LowHash,
        tfidf=False,
        nearest_neighbors_kw=_hash_kw(lowhash_count=20),
        **kw
    ),

    # Weighted LowHash
    NearestNeighborsConfig(
        description="Weighted LowHash (TF)",
        nearest_neighbors_method=WeightedLowHash,
        tfidf=False,
        nearest_neighbors_kw=_hash_kw(lowhash_fraction=0.01),
        **kw
    ),
    NearestNeighborsConfig(
        description="Weighted LowHash (IDF)",
        nearest_neighbors_method=WeightedLowHash,
        binarize=True,
        tfidf=True,
        nearest_neighbors_kw=_hash_kw(lowhash_fraction=0.01),
        **kw
    ),
    NearestNeighborsConfig(
        description="Weighted LowHash (TF-IDF)",
        nearest_neighbors_method=WeightedLowHash,
        tfidf=True,
        nearest_neighbors_kw=_hash_kw(lowhash_fraction=0.01),
        **kw
    ),

    # Weighted MinHash
    NearestNeighborsConfig(
        description="Weighted MinHash (TF)",
        nearest_neighbors_method=WeightedLowHash,
        tfidf=False,
        nearest_neighbors_kw=_hash_kw(lowhash_count=20),
        **kw
    ),
    NearestNeighborsConfig(
        description="Weighted MinHash (IDF)",
        nearest_neighbors_method=WeightedLowHash,
        binarize=True,
        tfidf=True,
        nearest_neighbors_kw=_hash_kw(lowhash_count=20),
        **kw
    ),
    NearestNeighborsConfig(
        description="Weighted MinHash (TF-IDF)",
        nearest_neighbors_method=WeightedLowHash,
        tfidf=True,
        nearest_neighbors_kw=_hash_kw(lowhash_count=20),
        **kw
    ),
]


# Extra configs that are only added for small inputs (see the row-count check below):
# product quantization, NN-descent, and exact Euclidean search with/without dimension reduction.
# Presumably these are too slow or memory-hungry for larger inputs — TODO confirm.
small_data_configs = [
    # PQ (on a 100-dim spectral embedding)
    NearestNeighborsConfig(
        nearest_neighbors_method=ProductQuantization,
        description="PQ (IDF, 100 dim.)",
        binarize=True,
        tfidf=True,
        dimension_reduction_method=SpectralEmbedding,
        dimension_reduction_kw=dict(n_dimensions=100),
        nearest_neighbors_kw=dict(nbits=6),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=ProductQuantization,
        description="PQ (TF-IDF, 100 dim.)",
        binarize=False,
        tfidf=True,
        dimension_reduction_method=SpectralEmbedding,
        dimension_reduction_kw=dict(n_dimensions=100),
        nearest_neighbors_kw=dict(nbits=6),
        **kw
    ),
    # NNdescent
    NearestNeighborsConfig(
        nearest_neighbors_method=NNDescent,
        description="NNdescent (IDF)",
        binarize=True,
        tfidf=True,
        dimension_reduction_method=None,
        nearest_neighbors_kw=dict(metric="euclidean", n_jobs=None),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=NNDescent,
        description="NNdescent (TF-IDF)",
        tfidf=True,
        dimension_reduction_method=None,
        nearest_neighbors_kw=dict(metric="euclidean", n_jobs=None),
        **kw
    ),

    # Euclidean (exact)
    # (this header was previously missing its '#', which made the whole cell a SyntaxError)
    NearestNeighborsConfig(
        nearest_neighbors_method=ExactNearestNeighbors,
        description="Exact Euclidean (binary)",
        binarize=True,
        tfidf=False,
        dimension_reduction_method=None,
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=ExactNearestNeighbors,
        description="Exact Euclidean (TF)",
        tfidf=False,
        dimension_reduction_method=None,
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=ExactNearestNeighbors,
        description="Exact Euclidean (IDF)",
        binarize=True,
        tfidf=True,
        dimension_reduction_method=None,
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=ExactNearestNeighbors,
        description="Exact Euclidean (TF-IDF)",
        tfidf=True,
        dimension_reduction_method=None,
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=ExactNearestNeighbors,
        description="Exact Euclidean (IDF, Spectral 100 dim.)",
        binarize=True,
        tfidf=True,
        dimension_reduction_method=SpectralEmbedding,
        dimension_reduction_kw=dict(n_dimensions=100),
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=ExactNearestNeighbors,
        description="Exact Euclidean (IDF, Spectral 500 dim.)",
        binarize=True,
        tfidf=True,
        dimension_reduction_method=SpectralEmbedding,
        dimension_reduction_kw=dict(n_dimensions=500),
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=ExactNearestNeighbors,
        description="Exact Euclidean (TF-IDF, Spectral 100 dim.)",
        tfidf=True,
        dimension_reduction_method=SpectralEmbedding,
        dimension_reduction_kw=dict(n_dimensions=100),
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=ExactNearestNeighbors,
        description="Exact Euclidean (IDF, scBiMap 100 dim.)",
        binarize=True,
        tfidf=True,
        dimension_reduction_method=scBiMapEmbedding,
        dimension_reduction_kw=dict(n_dimensions=100),
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=ExactNearestNeighbors,
        description="Exact Euclidean (IDF, scBiMap 500 dim.)",
        binarize=True,
        tfidf=True,
        dimension_reduction_method=scBiMapEmbedding,
        dimension_reduction_kw=dict(n_dimensions=500),
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
    NearestNeighborsConfig(
        nearest_neighbors_method=ExactNearestNeighbors,
        description="Exact Euclidean (TF-IDF, scBiMap 100 dim.)",
        tfidf=True,
        dimension_reduction_method=scBiMapEmbedding,
        dimension_reduction_kw=dict(n_dimensions=100),
        nearest_neighbors_kw=dict(metric="euclidean"),
        **kw
    ),
]

# Only run the small-data configs when there are at most 20,000 reads (matrix rows).
# NOTE: `+=` extends `configs` in place. Re-running this cell without first re-running the
# previous one appends the small-data configs a second time.
if feature_matrix.shape[0] <= 20_000:
    configs += small_data_configs

In [None]:
# Print every config with its positional id (the same id is used as `config_id` in the statistics).
for i, config in enumerate(configs):
    print(f"{i}\t{config}")

In [None]:
%%time

# Largest neighbour count requested; smaller k are evaluated from the same results later.
max_n_neighbors = COVERAGE_DEPTH

# Use all available threads only for inputs of at most 20,000 reads; otherwise run a single process.
if feature_matrix.shape[0] <= 20_000:
    processes = threads
else:
    processes = 1

nbr_dict = mp_compute_nearest_neighbors(
    data=feature_matrix,
    configs=configs,
    n_neighbors=max_n_neighbors,
    processes=processes,
)

## Statistics

In [None]:
# Read ids in feature-matrix row order (the keys of read_features, built from meta_df.index).
read_ids = np.array(list(read_features))
# graphs[config_id][k] -> overlap graph connecting each read to its k nearest neighbours.
graphs = collections.defaultdict(dict)
k_values = np.arange(2, max_n_neighbors + 1)
for i, config in enumerate(configs):
    print(i, end=" ")  # progress indicator
    nbr_indices = nbr_dict[i]
    for k in k_values:
        # Neighbour relations need not be mutual to create an edge.
        graphs[i][k] = OverlapGraph.from_neighbor_indices(
            neighbor_indices=nbr_indices,
            n_neighbors=k,
            read_ids=read_ids,
            require_mutual_neighbors=False,
        )

In [None]:
# One row of overlap statistics per (config, k) pair, compared against the reference graph.
df_rows = []
for i, config in enumerate(configs):
    print(i, end=' ')  # progress indicator
    for k in k_values:
        graph = graphs[i][k]
        stats = get_overlap_statistics(query_graph=graph, reference_graph=reference_graph)
        stats = {"config_id": i, "description": config.description, "n_neighbors": k, **stats}
        df_rows.append(stats)
df = pd.DataFrame(df_rows)

# Snakemake may not pre-create a `directory()` output, and to_csv fails if the folder is
# missing, so create it (a no-op if it already exists).
os.makedirs(output_folder, exist_ok=True)
df.to_csv(os.path.join(output_folder, "overlap_statistics.tsv"), sep='\t')

# Plot-only columns. They are added after the TSV is written and are therefore not saved.
df['label'] = df['config_id'].map(str) + " " + df['description']
df['connected_fraction'] = 1 - df['singleton_fraction']

In [None]:
# Statistic column -> y-axis label; one panel per entry, in this order.
y_labels = dict(
    precision="Precision",
    nr_recall="Recall (non-redundant overlaps)",
    N50="Graph N50 (after removing incorrect edges)",
    connected_fraction="Connected fraction (after removing incorrect edges)"
)

# Use qualitative palettes while they have enough colours for one line per config.
palette = "tab10" if len(configs) <= 10 else "tab20" if len(configs) <= 20 else "rainbow"

fig, axes = plt.subplots(2, 2, figsize=(14, 8), constrained_layout=True)
for i, (ax, y) in enumerate(zip(axes.flat, y_labels)):
    seaborn.lineplot(ax=ax, data=df, x="n_neighbors", y=y, hue="label", palette=palette)
    ax.set_xticks(k_values)
    ax.set_xlabel("Number of neighbors")
    ax.set_ylabel(y_labels[y])
    ax.grid(axis='both', color='k', alpha=0.1)

    # A single shared legend, placed to the right of the top-right panel.
    if i != 1:
        ax.get_legend().remove()
    else:
        ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1))

    ax.spines[['top', 'right']].set_visible(False)

## Graph visualisation

In [None]:
def remove_small_components(graph, min_component_size=10):
    """Remove, in place, every node of a connected component with fewer than ``min_component_size`` nodes.

    All components are collected before any node is removed. Returns ``None``.
    """
    undersized = (
        component
        for component in nx.connected_components(graph)
        if len(component) < min_component_size
    )
    graph.remove_nodes_from(set().union(*undersized))
            
    
def plot_graphs(graphs, reference_graph, metadata, *, min_component_size=10,
                processes=8, layout_method='sfdp', figsize=(6, 6), node_size=3,
                seed: int = 4829, verbose=True, titles=None):
    """Draw each query graph twice, side by side: all edges and correct edges only.

    For every graph in ``graphs`` a new 1x2 figure (12x6 inches) is created:

    * left panel ("All edges"): a copy of the graph with small components removed;
    * right panel ("Correct edges"): that copy after ``remove_false_edges`` against
      ``reference_graph``, with small components removed again.

    The input graphs are copied and never modified. Layouts are computed through a
    ``sharedmem.MapReduce`` pool; each panel is drawn in the reducer.

    Args:
        graphs: Sequence of query graphs.
        reference_graph: Ground-truth overlap graph used for ``remove_false_edges`` and drawing.
        metadata: Per-read metadata forwarded to ``plot_read_graph``.
        min_component_size: Components with fewer nodes are removed before plotting.
        processes: Pool size for ``sharedmem.MapReduce``.
        layout_method: ``"umap"`` or a Graphviz layout engine name forwarded to
            ``get_graphviz_layout``. The default used to be ``'stdp'``, which is not a Graphviz
            engine (every caller passed ``"sfdp"`` explicitly), so it is now ``'sfdp'``.
        figsize: Forwarded to ``get_graphviz_layout`` only; it does not size the figures.
        node_size: Forwarded to ``plot_read_graph``.
        seed: Forwarded to ``get_graphviz_layout``.
        verbose: Print each panel index as it is drawn.
        titles: Optional figure titles, matched to ``graphs`` by position. When ``None``, the
            titles are ``str(config)`` for the notebook-level ``configs`` list, which assumes
            ``graphs`` is ordered like ``configs``.

    Returns:
        List of the created figures, one per input graph.
    """
    if titles is None:
        titles = [str(config) for config in configs]

    figures = []
    # Parallel flat lists: panel j is drawn on axes[j] from query_graphs[j].
    axes = []
    query_graphs = []
    for graph in graphs:
        fig, (ax_all, ax_correct) = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)
        figures.append(fig)
        ax_all.set_title("All edges")
        ax_correct.set_title("Correct edges")

        # Left panel: every predicted edge, small components removed (on a copy).
        pruned = graph.copy()
        remove_small_components(pruned, min_component_size=min_component_size)
        query_graphs.append(pruned)
        axes.append(ax_all)

        # Right panel: remove_false_edges() against the reference, then prune again,
        # because dropping edges can split off new small components.
        correct = pruned.copy()
        remove_false_edges(correct, reference_graph)
        remove_small_components(correct, min_component_size=min_component_size)
        query_graphs.append(correct)
        axes.append(ax_correct)

    for fig, title in zip(figures, titles):
        fig.suptitle(title, ha="center", va="bottom", wrap=True, size=7)

    with sharedmem.MapReduce(np=processes) as pool:

        def work(i):
            # Map step: compute node positions for panel i.
            if layout_method == "umap":
                pos = get_umap_layout(graph=query_graphs[i])
            else:
                pos = get_graphviz_layout(
                    graph=query_graphs[i],
                    figsize=figsize,
                    seed=seed,
                    layout_method=layout_method,
                )
            return i, pos

        def reduce(i, pos):
            # Reduce step: draw panel i with the positions computed by `work`.
            if verbose:
                print(i, end=" ")
            plot_read_graph(
                ax=axes[i],
                query_graph=query_graphs[i],
                reference_graph=reference_graph,
                metadata=metadata,
                pos=pos,
                node_size=node_size,
            )

        pool.map(work, range(len(query_graphs)), reduce=reduce)
        if verbose:
            print("")

    return figures

In [None]:
%%time
# The 6-nearest-neighbour graph of every config, in config order.
k6_graphs = [graphs_by_k[6] for _, graphs_by_k in sorted(graphs.items())]
plot_graphs(
    k6_graphs,
    reference_graph=reference_graph,
    metadata=meta_df,
    layout_method="sfdp",
    processes=threads,
)

In [None]:
%%time
# The 12-nearest-neighbour graph of every config, in config order. This previously
# selected k=6, duplicating the cell above. It requires max_n_neighbors (COVERAGE_DEPTH) >= 12.
k12_graphs = [graphs[i][12] for i in range(len(graphs))]
plot_graphs(
    k12_graphs, reference_graph=reference_graph, metadata=meta_df, layout_method="sfdp", processes=threads
)