In [None]:
import pickle, os, gzip, json, sys
from pathlib import Path
from importlib import reload
from dataclasses import dataclass, field
import collections
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import pysam
import scipy as sp

# White figure background and high-resolution output for every figure in this notebook.
plt.rcParams.update({"figure.facecolor": "white", "figure.dpi": 300})

# Make the project helper modules (data_io, nearest_neighbors, graph, ...) importable.
# Both paths are relative to the current working directory; presumably one matches when
# run from the repository root and the other when run from a nested directory.
for scripts_dir in ("scripts", "../../scripts"):
    sys.path.append(scripts_dir)

## Load data

In [None]:
from data_io import is_fwd_id, get_fwd_id, get_sibling_id
from nearest_neighbors import (
    ExactNearestNeighbors,
    NNDescent,
    WeightedLowHash,
    PAFNearestNeighbors,
    LowHash,
    HNSW,
    _NearestNeighbors,
)
from graph import ReadGraph, GenomicInterval
from truth import get_overlaps
from evaluate import NearestNeighborsConfig, mp_evaluate_configs

In [None]:
# Run parameters injected by Snakemake's `notebook:` directive.
wildcards = snakemake.wildcards
sample = wildcards["sample"]
dataset = wildcards["platform"]  # note: `dataset` holds the `platform` wildcard
region = wildcards["region"]
print(sample, dataset, region)

# Input files for this sample/platform/region.
inputs = snakemake.input
npz_path = inputs["feature_matrix"]
tsv_path = inputs["metadata"]
json_path = inputs["read_features"]
paf_path = inputs["paf"]

output_folder = snakemake.output["folder"]

threads = snakemake.threads

In [None]:
# Upper bound on the number of reads loaded; int(1e9) effectively means "all reads".
MAX_SAMPLE_SIZE = int(1e9)

# reset_index() (without drop=True) keeps the previous index as an "index" column and
# gives meta_df a fresh 0..n-1 index, which serves as the read id for everything below.
meta_df = (
    pd.read_table(tsv_path)
    .iloc[:MAX_SAMPLE_SIZE, :]
    .reset_index()
)

# read_name -> read id (row index). If a read_name occurs more than once, the last row wins.
read_indices = dict(zip(meta_df["read_name"], meta_df.index))

# Keep only the feature-matrix rows for the loaded reads.
feature_matrix = sp.sparse.load_npz(npz_path)[meta_df.index, :]

with gzip.open(json_path, "rt") as json_file:
    read_features = json.load(json_file)
# NOTE(review): indexing by integer read id assumes the JSON top level is a list;
# JSON object keys are always strings, so a top-level object would raise KeyError here.
read_features = {read_id: read_features[read_id] for read_id in meta_df.index}

# Uniform weight of 1 for every feature column.
feature_weights = dict.fromkeys(range(feature_matrix.shape[1]), 1)

In [None]:
# Distribution of how many features each loaded read has.
features_per_read = [len(features) for features in read_features.values()]

fig, ax = plt.subplots(figsize=(8, 2.5))
ax.hist(features_per_read, bins=100)
ax.set(xlabel="Number of features per read", ylabel="Number of reads")
ax.grid(color="k", alpha=0.1)

In [None]:
# (number of loaded reads, number of features): rows were selected by meta_df.index above.
feature_matrix.shape

In [None]:
# Mean entry value over all (read, feature) cells. This equals the matrix density only if
# the matrix is binary; NOTE(review): if entries can exceed 1, use feature_matrix.nnz for density.
n_rows, n_cols = feature_matrix.shape
feature_matrix.sum() / (n_rows * n_cols)

In [None]:
# Per-read metadata; the 0..n-1 index is the read id used by every structure below.
meta_df

In [None]:
def get_read_intervals(meta_df):
    """Map each read id to the reference interval it aligns to.

    Parameters
    ----------
    meta_df : pandas.DataFrame
        Must contain ``reference_strand``, ``reference_start`` and ``reference_end``
        columns; its index is used as the read id.

    Returns
    -------
    dict
        ``{read_id: [GenomicInterval(strand, start, end)]}``. Each value is a
        one-element list (presumably because ``ReadGraph.from_intervals`` accepts
        several intervals per read — confirm).
    """
    strands = meta_df["reference_strand"]
    starts = meta_df["reference_start"]
    ends = meta_df["reference_end"]

    intervals = {}
    for read_id, strand, start, end in zip(meta_df.index, strands, starts, ends):
        intervals[read_id] = [GenomicInterval(strand, start, end)]
    return intervals

read_intervals = get_read_intervals(meta_df)
len(read_intervals)

In [None]:
# Graph built from the reads' reference intervals.
reference_graph = ReadGraph.from_intervals(read_intervals)

# Edges whose "redundant" attribute is falsy.
nr_edges = {
    (u, v)
    for u, v, attrs in reference_graph.edges(data=True)
    if not attrs["redundant"]
}
connected_component_count = nx.number_connected_components(reference_graph)

# Summary: nodes, all edges, non-redundant edges, connected components.
(
    len(reference_graph.nodes),
    len(reference_graph.edges),
    len(nr_edges),
    connected_component_count,
)

## Get nearest neighbours

In [None]:
# Nearest-neighbour methods to benchmark. Configs are built by small factories instead of
# copy-pasted constructor calls; order and keyword sets match the original hand-written list.


def _lowhash_kw():
    """Return a fresh kwargs dict shared by all (Weighted)LowHash configs.

    A new dict per call, so no two configs share mutable state.
    """
    return dict(
        lowhash_fraction=0.01,
        repeats=100,
        max_bucket_size=10,
        min_cooccurence_count=2,
        seed=458,  # fixed seed for reproducible hashing
    )


def _lowhash_config(method, n_neighbors, use_tfidf=False):
    """Build a LowHash-family config (no ``dim_reduction`` argument is passed)."""
    return NearestNeighborsConfig(
        method=method,
        use_tfidf=use_tfidf,
        n_neighbors=n_neighbors,
        require_mutual_neighbors=False,
        nearest_neighbor_kw=_lowhash_kw(),
    )


def _vector_config(method, n_neighbors, *, use_tfidf, dim_reduction, **nearest_neighbor_kw):
    """Build a config for vector-space methods (NNDescent, HNSW, exact kNN).

    Extra keyword arguments become a fresh ``nearest_neighbor_kw`` dict.
    """
    return NearestNeighborsConfig(
        method=method,
        use_tfidf=use_tfidf,
        dim_reduction=dim_reduction,
        n_neighbors=n_neighbors,
        require_mutual_neighbors=False,
        nearest_neighbor_kw=nearest_neighbor_kw,
    )


configs = [
    # Minimap2 all-vs-all
    NearestNeighborsConfig(
        method=PAFNearestNeighbors,
        n_neighbors=6,
        nearest_neighbor_kw=dict(paf_path=paf_path, read_indices=read_indices),
    ),
    # LowHash
    *(_lowhash_config(LowHash, k) for k in (6, 12, 1_000)),
    # Weighted LowHash: raw features, then TF-IDF
    *(
        _lowhash_config(WeightedLowHash, k, use_tfidf=tfidf)
        for tfidf in (False, True)
        for k in (6, 12)
    ),
    # NNDescent
    *(
        _vector_config(
            NNDescent, k, use_tfidf=True, dim_reduction=None, metric="euclidean", n_jobs=1
        )
        for k in (6, 12)
    ),
    # DimReduction + HNSW
    *(
        _vector_config(HNSW, k, use_tfidf=True, dim_reduction=dims, metric="euclidean")
        for dims in (100, 1000)
        for k in (6, 12)
    ),
]


small_data_configs = [
    # Euclidean (exact): raw features, TF-IDF, and TF-IDF reduced to 100 dimensions
    _vector_config(
        ExactNearestNeighbors, k, use_tfidf=tfidf, dim_reduction=dims, metric="euclidean"
    )
    for tfidf, dims in ((False, None), (True, None), (True, 100))
    for k in (6, 12)
]

# Exact kNN is only attempted on inputs up to this many reads (presumably too costly beyond).
SMALL_DATA_MAX_READS = 100_000
if feature_matrix.shape[0] <= SMALL_DATA_MAX_READS:
    configs += small_data_configs

In [None]:
%%time
# Evaluate every config in parallel. Pairwise alignments are stored at/read from `pickle_file`.
# NOTE: `configs` is rebound to the evaluated result, so re-running this cell alone feeds
# already-evaluated configs back in; re-run the config cell first.
pickle_file = os.path.join(output_folder, "alignment_dict.pickle.gz")
evaluation_kw = dict(
    alignment_pickle_path=pickle_file,
    feature_matrix=feature_matrix,
    feature_weights=feature_weights,
    read_features=read_features,
    pairwise_alignment=True,
    post_align_n_neighbors=6,
    processes=threads,
    batch_size=1_000,
    reference_graph=reference_graph,
)
configs = mp_evaluate_configs(configs, **evaluation_kw)

## Visualisation

In [None]:
# Plots
# `plots` is a local helper module (presumably found via the sys.path entries added in the
# first cell). reload() re-executes it so that, in a live kernel, edits to the module are
# picked up before the names below are re-bound from it.
import plots
reload(plots)
from plots import plot_read_graph, mp_plot_read_graphs, get_graphviz_layout, get_umap_layout

In [None]:
def remove_singletons(graph):
    """Remove, in place, every node with at most one entry in its adjacency.

    Despite the name, this drops degree-1 (leaf) nodes as well as isolated ones.
    ``len(graph[node])`` counts neighbours, including the node itself if it has a
    self-loop. NOTE(review): confirm whether ``<= 1`` is meant to account for
    self-loops or should be ``== 0``.

    Returns None; ``graph`` is modified in place.
    """
    to_drop = [node for node in graph.nodes if len(graph[node]) <= 1]
    graph.remove_nodes_from(to_drop)

def plot_configs(configs, reference_graph, metadata, *, show_singletons=False, **kw):
    """Draw pre- and post-alignment read graphs side by side for each config.

    One 12x6 figure with two panels ("Pre alignment" / "Post alignment") is created
    per config. All graphs are then rendered together by ``mp_plot_read_graphs``, and
    each figure is titled with ``str(config)``.

    Parameters
    ----------
    configs : iterable
        Evaluated configs exposing ``pre_align_graph`` and ``post_align_graph``.
        Either may be ``None``, in which case that panel is left empty. Any
        iterable is accepted; it is materialized once.
    reference_graph, metadata :
        Forwarded unchanged to ``mp_plot_read_graphs``.
    show_singletons : bool, default False
        If False, nodes with at most one neighbour are removed (via
        ``remove_singletons``) before plotting. If True, graphs are plotted as-is.
        Previously this flag was accepted but ignored.
    **kw :
        Extra keyword arguments forwarded to ``mp_plot_read_graphs``
        (e.g. ``layout_method``, ``processes``).
    """
    # Iterated twice (plotting + titles); a generator would otherwise lose all titles.
    configs = list(configs)

    graphs = []
    axes = []
    figures = []
    for config in configs:
        fig, (ax_pre, ax_post) = plt.subplots(1, 2, figsize=(12, 6))
        figures.append(fig)
        ax_pre.set_title("Pre alignment")
        ax_post.set_title("Post alignment")

        for ax, graph in ((ax_pre, config.pre_align_graph), (ax_post, config.post_align_graph)):
            if graph is None:
                continue
            # Copy so the graphs stored on the config are never mutated.
            graph = graph.copy()
            if not show_singletons:
                remove_singletons(graph)
            graphs.append(graph)
            axes.append(ax)

    mp_plot_read_graphs(
        axes, graphs, reference_graph=reference_graph, metadata=metadata, figsize=(6, 6), **kw
    )

    for fig, config in zip(figures, configs):
        fig.suptitle(str(config), ha="center", va="bottom", wrap=True, size=6)

In [None]:
%%time
# Graph layouts are only drawn for small inputs (same 100_000-read limit as the exact-kNN configs).
n_reads = feature_matrix.shape[0]
if n_reads <= 100_000:
    plot_configs(
        configs,
        reference_graph=reference_graph,
        metadata=meta_df,
        layout_method="sfdp",
        processes=threads,
    )