In [None]:
import pickle, os, gzip, json, sys
from pathlib import Path
from importlib import reload
from dataclasses import dataclass, field
import collections
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import pysam
import scipy as sp
import seaborn
import sharedmem

# Global matplotlib defaults: white figure background, 300 dpi.
plt.rcParams.update({"figure.facecolor": "white", "figure.dpi": 300})


# Extend the import search path with both candidate script folders
# (presumably where the local modules imported further down live).
sys.path.extend(["scripts", "../../scripts"])

## Load data

In [None]:
from data_io import is_fwd_id, get_fwd_id, get_sibling_id
from nearest_neighbors import (
    ExactNearestNeighbors,
    NNDescent,
    WeightedLowHash,
    PAFNearestNeighbors,
    LowHash,
    HNSW,
    _NearestNeighbors,
)
from graph import OverlapGraph, GenomicInterval, get_overlap_statistics
from truth import get_overlaps
from evaluate import NearestNeighborsConfig, mp_evaluate_configs
from plots import plot_read_graph, mp_plot_read_graphs, get_graphviz_layout, get_umap_layout

In [None]:
# NOTE(review): `snakemake` is not defined anywhere in this notebook; presumably
# it is injected by Snakemake when the notebook is executed as part of a rule.

# Wildcards identifying the dataset being analysed.
sample, dataset, region = (
    snakemake.wildcards[name] for name in ('sample', 'platform', 'region')
)
print(sample, dataset, region)

# Input files.
npz_path, tsv_path, json_path, paf_path = (
    snakemake.input[name]
    for name in ('feature_matrix', 'metadata', 'read_features', 'paf')
)

# Folder that receives the outputs written by this notebook.
output_folder = snakemake.output['folder']

# Passed as the `processes` argument of the plotting calls further down.
threads = snakemake.threads

In [None]:
# Upper bound on the number of metadata rows kept.
MAX_SAMPLE_SIZE = 1_000_000_000

# Read metadata; after reset_index() the row labels are 0..n-1.
meta_df = pd.read_table(tsv_path).iloc[:MAX_SAMPLE_SIZE, :].reset_index()

# Map each read name to its row label in meta_df.
read_indices = dict(zip(meta_df['read_name'], meta_df.index))

# Keep only the feature-matrix rows that have a metadata row.
feature_matrix = sp.sparse.load_npz(npz_path)[meta_df.index, :]

with gzip.open(json_path, "rt") as f:
    read_features = json.load(f)
# Re-key the per-read feature collections by the meta_df row labels.
read_features = {i: read_features[i] for i in meta_df.index}

# Uniform weight of 1 for every feature column.
feature_weights = dict.fromkeys(range(feature_matrix.shape[1]), 1)

In [None]:
# Histogram of how many features each read carries.
features_per_read = list(map(len, read_features.values()))

fig, ax = plt.subplots(figsize=(8, 2.5))
ax.hist(features_per_read, bins=100)
ax.set_xlabel("Number of features per read")
ax.set_ylabel("Number of reads")
ax.grid(color='k', alpha=0.1)

In [None]:
# Shape (rows, columns) of the feature matrix after restricting it to the rows in meta_df.
feature_matrix.shape

In [None]:
# Sum of all matrix entries divided by the total number of cells
# (equals the fraction of non-zero cells if entries are 0/1 — TODO confirm).
n_rows, n_cols = feature_matrix.shape
feature_matrix.sum() / (n_rows * n_cols)

In [None]:
# Display the read metadata table loaded from tsv_path.
meta_df

In [None]:
def get_read_intervals(meta_df):
    """Build a one-element list of GenomicInterval for every row of ``meta_df``.

    Parameters
    ----------
    meta_df : pandas.DataFrame
        Must provide the columns ``reference_strand``, ``reference_start``
        and ``reference_end``.

    Returns
    -------
    dict
        Maps each index label of ``meta_df`` to ``[GenomicInterval(strand, start, end)]``.
    """
    rows = zip(
        meta_df.index,
        meta_df["reference_strand"],
        meta_df["reference_start"],
        meta_df["reference_end"],
    )
    read_intervals = {}
    for read_id, strand, start, end in rows:
        read_intervals[read_id] = [GenomicInterval(strand, start, end)]
    return read_intervals


read_intervals = get_read_intervals(meta_df)
len(read_intervals)

In [None]:
%%time
# Overlap graph derived from the reference intervals of the reads.
reference_graph = OverlapGraph.from_intervals(read_intervals)

# Edges whose 'redundant' attribute is falsy.
nr_edges = {
    (u, v)
    for u, v, attrs in reference_graph.edges(data=True)
    if not attrs['redundant']
}
connected_component_count = sum(1 for _ in nx.connected_components(reference_graph))

# (nodes, edges, non-redundant edges, connected components)
len(reference_graph.nodes), len(reference_graph.edges), len(nr_edges), connected_component_count

## Get nearest neighbours

In [None]:
# Keyword arguments shared by every config.
kw = dict(data=feature_matrix)

# Settings shared by all LowHash-style configs. Each config receives its own
# copy so that no two configs hold the same dict object.
lowhash_kw = dict(
    lowhash_fraction=0.01,
    repeats=100,
    max_bucket_size=10,
    min_cooccurence_count=2,
    seed=458,
)
euclidean_kw = dict(metric="euclidean")
# Target dimensionalities tried for the dimension-reduced variants.
reduced_dims = (100, 300)

# NOTE: the position of a config in this list is its `config_id` downstream.
configs = [
    # Minimap2 all-vs-all
    NearestNeighborsConfig(
        method=PAFNearestNeighbors,
        nearest_neighbor_kw=dict(paf_path=paf_path, read_indices=read_indices),
        **kw
    ),
    # DimReduction + HNSW
    *[
        NearestNeighborsConfig(
            method=HNSW,
            use_tfidf=True,
            dim_reduction=dim,
            nearest_neighbor_kw=dict(euclidean_kw),
            **kw
        )
        for dim in reduced_dims
    ],
    # LowHash
    NearestNeighborsConfig(
        method=LowHash,
        use_tfidf=False,
        nearest_neighbor_kw=dict(lowhash_kw),
        **kw
    ),
    # Weighted LowHash, without and with TF-IDF
    *[
        NearestNeighborsConfig(
            method=WeightedLowHash,
            use_tfidf=tfidf,
            nearest_neighbor_kw=dict(lowhash_kw),
            **kw
        )
        for tfidf in (False, True)
    ],
]


# Additional configs that are only evaluated when the dataset is small.
small_data_configs = [
    # NNdescent
    NearestNeighborsConfig(
        method=NNDescent,
        use_tfidf=True,
        dim_reduction=None,
        nearest_neighbor_kw=dict(metric="euclidean", n_jobs=None),
        **kw
    ),
    # Euclidean (exact), raw features
    NearestNeighborsConfig(
        method=ExactNearestNeighbors,
        use_tfidf=False,
        dim_reduction=None,
        nearest_neighbor_kw=dict(euclidean_kw),
        **kw
    ),
    # Euclidean (exact), TF-IDF, without and with dimension reduction
    *[
        NearestNeighborsConfig(
            method=ExactNearestNeighbors,
            use_tfidf=True,
            dim_reduction=dim,
            nearest_neighbor_kw=dict(euclidean_kw),
            **kw
        )
        for dim in (None, *reduced_dims)
    ],
]

# Only add the small-data configs when there are at most 10,000 rows.
n_matrix_rows = feature_matrix.shape[0]
if n_matrix_rows <= 10_000:
    configs.extend(small_data_configs)

In [None]:
%%time

# Largest neighbourhood size requested from every method; smaller values of k
# are evaluated later, up to this bound.
max_n_neighbors = 20

for i in range(len(configs)):
    config = configs[i]
    # Log which config is being processed (index, tab, description).
    print(i, config, sep='\t')
    config.compute_nearest_neighbors(n_neighbors=max_n_neighbors)

## Statistics

In [None]:
# Neighbourhood sizes to evaluate: every k from 2 up to max_n_neighbors.
k_values = np.arange(2, max_n_neighbors + 1)

# graphs[config index][k] -> overlap graph built with n_neighbors=k.
graphs = collections.defaultdict(dict)
for i, config in enumerate(configs):
    print(i, end=' ')  # progress indicator
    graphs_for_config = graphs[i]
    for k in k_values:
        graph = config.get_overlap_graph(
            n_neighbors=k,
            read_ids=list(read_features),
            require_mutual_neighbors=False,
        )
        graphs_for_config[k] = graph

In [None]:
# Make sure the output folder exists before writing into it.
os.makedirs(output_folder, exist_ok=True)

# One row of overlap statistics per (config, k) pair.
df_rows = []
for i, config in enumerate(configs):
    print(i, end=' ')  # progress indicator
    # Describe the config that belongs to index i. Previously `config` was a
    # stale variable left over from an earlier loop, so every row carried the
    # description of the last config.
    description = str(config)
    for k in k_values:
        graph = graphs[i][k]
        stats = get_overlap_statistics(query_graph=graph, reference_graph=reference_graph)
        stats = {"config_id": i, "description": description, "n_neighbors": k, **stats}
        df_rows.append(stats)
df = pd.DataFrame(df_rows)
df.to_csv(os.path.join(output_folder, "overlap_statistics.tsv"),  sep='\t')

In [None]:
# Recall on the "nr_recall" column as a function of k, one line per config.
lineplot_kw = dict(data=df, x='n_neighbors', hue='config_id', palette='tab20')

fig, ax = plt.subplots(figsize=(8, 6))
g = seaborn.lineplot(ax=ax, y="nr_recall", **lineplot_kw)
ax.set_xticks(k_values)
ax.set_xlabel("Number of neighbors")
ax.set_ylabel("Recall (non-redundant overlaps)")

In [None]:
# Precision as a function of k, one line per config.
lineplot_kw = dict(data=df, x='n_neighbors', hue='config_id', palette='tab20')

fig, ax = plt.subplots(figsize=(8, 6))
g = seaborn.lineplot(ax=ax, y="precision", **lineplot_kw)
ax.set_xticks(k_values)
ax.set_xlabel("Number of neighbors")
ax.set_ylabel("Precision")

## Visualisation

In [None]:
def remove_singletons(graph):
    """Remove, in place, every node of ``graph`` with at most one adjacency entry.

    Despite the name, this drops not only isolated nodes but also nodes
    that have exactly one neighbour. The candidates are collected in a single
    pass before anything is removed. Returns ``None``.
    """
    poorly_connected = [node for node in graph.nodes if len(graph[node]) <= 1]
    graph.remove_nodes_from(poorly_connected)

def remove_false_edges(graph, reference_graph):
    """Remove, in place, every edge of ``graph`` that ``reference_graph`` lacks.

    An edge ``(u, v)`` is kept only if ``reference_graph.has_edge(u, v)`` is
    true. The edges to drop are collected first, then removed in one call.
    Returns ``None``.
    """
    unsupported = [
        (u, v) for u, v in graph.edges if not reference_graph.has_edge(u, v)
    ]
    graph.remove_edges_from(unsupported)
            
    
def plot_graphs(graphs, reference_graph, metadata, *, processes=8, layout_method='stdp', figsize=(6, 6), node_size=3,
    seed: int = 4829, verbose=True, titles=None):
    """Draw every query graph twice: as is, and with false edges removed.

    For each graph one figure with two side-by-side axes is created. The left
    axis ("Raw") shows a copy of the graph after ``remove_singletons``; the
    right axis ("False edges removed") shows a further copy from which all
    edges absent from ``reference_graph`` have been removed.

    Parameters
    ----------
    graphs : iterable
        Query graphs; each must support ``copy()``.
    reference_graph :
        Graph used to decide which edges are false and passed on to
        ``plot_read_graph``.
    metadata :
        Passed through to ``plot_read_graph``.
    processes : int
        Passed as ``np`` to ``sharedmem.MapReduce``.
    layout_method : str
        ``"umap"`` selects ``get_umap_layout``; any other value is passed as
        ``method`` to ``get_graphviz_layout``.
        NOTE(review): the default ``'stdp'`` looks like a typo for Graphviz's
        ``'sfdp'`` (which the calls in this notebook pass explicitly) —
        confirm against ``get_graphviz_layout`` before changing the default.
    figsize : tuple
        Passed to ``get_graphviz_layout`` only; the matplotlib figures are
        always created with size (12, 6).
    node_size :
        Passed through to ``plot_read_graph``.
    seed : int
        Passed to ``get_graphviz_layout``.
    verbose : bool
        If true, print the index of each panel as it is drawn.
    titles : sequence, optional
        One title per input graph; ``str()`` of each entry becomes the figure
        suptitle. If omitted, the notebook-level ``configs`` list is used,
        which keeps the previous behaviour but relies on ``graphs`` being in
        the same order as ``configs``.

    Returns
    -------
    list
        The created figures, one per input graph.
    """
    if titles is None:
        # Backward-compatible fallback to the notebook global.
        titles = configs

    axes = []
    figures = []
    # Two entries per input graph; query_graphs[j] is drawn on axes[j].
    query_graphs = []
    for g in graphs:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)
        figures.append(fig)
        ax1.set_title("Raw")
        ax2.set_title("False edges removed")

        # Work on copies so the caller's graphs are not modified.
        pruned = g.copy()
        remove_singletons(pruned)
        query_graphs.append(pruned)
        axes.append(ax1)

        true_edges_only = pruned.copy()
        remove_false_edges(true_edges_only, reference_graph)
        query_graphs.append(true_edges_only)
        axes.append(ax2)

    for fig, title in zip(figures, titles):
        fig.suptitle(str(title), ha="center", va="bottom", wrap=True, size=7)

    def plot(i, pos):
        plot_read_graph(
            ax=axes[i],
            query_graph=query_graphs[i],
            reference_graph=reference_graph,
            metadata=metadata,
            pos=pos,
            node_size=node_size,
        )

    with sharedmem.MapReduce(np=processes) as pool:

        def work(i):
            # Compute the node layout for panel i.
            if layout_method == "umap":
                pos = get_umap_layout(graph=query_graphs[i])
            else:
                pos = get_graphviz_layout(
                    graph=query_graphs[i],
                    figsize=figsize,
                    seed=seed,
                    method=layout_method,
                )
            return i, pos

        def reduce(i, pos):
            # Draw panel i using the layout computed by `work`.
            if verbose:
                print(i, end=" ")
            plot(i, pos)

        pool.map(work, range(len(query_graphs)), reduce=reduce)
        if verbose:
            print("")

    return figures

In [None]:
%%time
# Overlap graphs built with k = 6 neighbours, one per config (in config order).
k6_graphs = [graphs[config_id][6] for config_id in range(len(graphs))]
plot_graphs(
    k6_graphs,
    reference_graph=reference_graph,
    metadata=meta_df,
    layout_method="sfdp",
    processes=threads,
)

In [None]:
%%time
# Overlap graphs built with k = 12 neighbours, one per config (in config order).
# This previously indexed graphs[i][6], so the k = 6 figures were drawn twice.
k12_graphs = [graphs[i][12] for i in range(len(graphs))]
plot_graphs(
    k12_graphs, reference_graph=reference_graph, metadata=meta_df, layout_method="sfdp", processes=threads
)