In [1]:
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from snakemake_stub import *


import gzip, json, collections
from typing import Sequence, Mapping, Collection
from Bio import SeqIO
import scipy.sparse as sp
import numpy as np
import pandas as pd
import sys
sys.path.append("scripts")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/4262b1bf4bf1ffb403c0eb7a42ad5906_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/4506eccf78279d93d0e8a34c035e91c5_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/6bda807e3967eae797c7b1b9eeaee8db_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/c2a47d89d1d34e789fdf782557bb7194_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/c6c5514ada15b890fb27d1e36371554c_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/d964a294c2d0fef56a434c021026281e_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/e1c932db5cd4271709e54d8028824bc9_/lib/python3.12/site-packages")
def init_reverse_complement():
    TRANSLATION_TABLE = str.maketrans("ACTGactg", "TGACtgac")

    def reverse_complement(sequence: str) -> str:
        """
        >>> reverse_complement("AATC")
        'GATT'
        >>> reverse_complement("CCANT")
        'ANTGG'
        """
        sequence = str(sequence)
        return sequence.translate(TRANSLATION_TABLE)[::-1]

    return reverse_complement


reverse_complement = init_reverse_complement()

In [2]:
ref_database = '/home/miaocj/docker_dir/data/metagenome/part1.fa'
query_reads = '/home/miaocj/docker_dir/data/metagenome/pbsim_ONT_95_30k_10dep_part1_reads.fasta'

In [3]:
def load_reads(fasta_path: str):
    read_sequences = []
    read_names = []
    read_orientations = []
    read_length = []

    with open(fasta_path, "rt") as handle:  # Open gzipped file in text mode
        for record in SeqIO.parse(handle, "fasta"):
            seq = str(record.seq)
            read_sequences.append(seq)
            read_names.append(record.id)
            read_orientations.append("+")
            read_length.append(len(record))

            # Include reverse complement
            read_sequences.append(reverse_complement(seq))
            read_names.append(record.id)
            read_orientations.append("-")
            read_length.append(len(record))

    return read_names, read_orientations, read_sequences,read_length

read_names, read_orientations, read_sequences, read_length = load_reads(ref_database)


In [9]:
def build_kmer_index(
    read_sequences: Sequence[str],
    k: int,
    *,
    sample_fraction: float,
    min_multiplicity: int,
    seed: int,
) -> Mapping[str, int]:
    kmer_counter = collections.Counter()
    for seq in read_sequences:
        for p in range(len(seq) - k + 1):
            kmer = seq[p : p + k]
            kmer_counter[kmer] += 1

    kmer_spectrum = collections.Counter(x for x in kmer_counter.values() if x <= 10)
    print(kmer_spectrum)

    rng = np.random.default_rng(seed=seed)
    vocabulary = set(
        x
        for x, count in kmer_counter.items()
        if count >= min_multiplicity and rng.random() <= sample_fraction
    )
    vocabulary |= set(reverse_complement(x) for x in vocabulary)
    kmer_indices = {kmer: i for i, kmer in enumerate(vocabulary)}
    return kmer_indices

sample_fraction=0.05
min_multiplicity=2
seed=562104830
kmer_indices = build_kmer_index(        
        read_sequences=read_sequences,
        k=16,
        sample_fraction=sample_fraction,
        min_multiplicity=min_multiplicity,
        seed=seed)


Counter({1: 6853434, 2: 90134, 3: 25438, 4: 10907, 5: 6236, 6: 3622, 7: 2564, 8: 1742, 9: 1320, 10: 810})


In [10]:
qread_names, qread_orientations, qread_sequence, qread_length  = load_reads(query_reads)

In [13]:
ref_reads_tax_list = []
with open(ref_database) as file:
    for lines in file:
        if lines[0] == '>':
            line = lines.strip().split(' ')
            ref_reads_tax_list.append(line[1]+line[2])
ref_read_tax = {i:tax for i,tax in enumerate(ref_reads_tax_list)}

flag = 0
que_read_tax = {}
indice_length = {}
with open(query_reads) as file:
    for lines in file:
        if lines[0] == '>':
            end = lines.index('_')
            ref_num = lines[2:end]
            ref_tax = ref_reads_tax_list[int(ref_num)-1]
            que_read_tax[flag] = ref_tax
            flag +=1

In [12]:
def build_feature_matrix(
    read_sequences: Sequence[str],
    kmer_indices: Mapping[str, int],
    k: int,
) -> tuple[sp.csr_matrix, Sequence[Sequence[int]]]:
    row_ind, col_ind, data = [], [], []
    read_features = []
    for i, seq in enumerate(read_sequences):
        features_i = []
        for p in range(len(seq) - k + 1):
            kmer = seq[p : p + k]
            j = kmer_indices.get(kmer)
            if j is None:
                continue
            features_i.append(j)

        read_features.append(features_i)

        kmer_counts = collections.Counter(features_i)
        for j, count in kmer_counts.items():
            row_ind.append(i)
            col_ind.append(j)
            data.append(count)

    feature_matrix = sp.csr_matrix(
        (data, (row_ind, col_ind)), shape=(len(read_sequences), len(kmer_indices))
    )
    return feature_matrix, read_features

In [49]:
len(que_read_tax)

59625

In [13]:
##calculate sensitivity and precision

from sklearn.metrics import precision_score, recall_score  

def evaluate(indices):
    actual = []
    prediction = []
    for query_read_num,x in enumerate(indices):
        neighbor = x[0]
        neighbor = (neighbor-1)/2  if neighbor %2 !=0 else neighbor/2
        query_read_num = (query_read_num-1)/2  if query_read_num %2 !=0 else query_read_num/2
        prediction.append(ref_read_tax[neighbor])
        actual.append(que_read_tax[query_read_num])

    precision = precision_score(actual, prediction,average='macro')
    sensitivity = recall_score(actual, prediction,average='macro')
    ##计算每个类别的
    precision_sep = precision_score(actual, prediction, average=None)  
    sensitivity_sep = recall_score(actual, prediction, average=None)
    return precision,sensitivity,precision_sep,sensitivity_sep

In [14]:
ref_feature_matrix,ref_read_features = build_feature_matrix(read_sequences=read_sequences,kmer_indices=kmer_indices,k=16)

In [32]:
ref_feature_matrix.shape

(1000, 2899125)

In [16]:
que_feature_matrix,que_read_features = build_feature_matrix(read_sequences=qread_sequence,kmer_indices=kmer_indices,k=16)

In [33]:
que_feature_matrix.shape

(119250, 2899125)

In [None]:
ref_feature_matrix = sp.load_npz('/home/miaocj/docker_dir/kNN-overlap-finder/data/feature_matrix/metagenome/pbsim_ONT_95_20k/kmer_16/ref_feature_matrix.npz')
que_feature_matrix = sp.load_npz('/home/miaocj/docker_dir/data/metagenome/que_feature_matrix.npz')

In [4]:
from dataclasses import dataclass
import scipy as sp
import numpy as np
from numpy.typing import NDArray
from scipy import sparse
from scipy.sparse import csr_matrix
from sklearn.preprocessing import normalize as normalize_function
import anndata
from sklearn.decomposition import TruncatedSVD
from sklearn import random_projection

class _SpectralMatrixFree:
    """
    Perform dimension reduction using Laplacian Eigenmaps.

    Matrix-free spectral embedding without computing the similarity matrix explicitly.

    Only cosine similarity is supported.

    Adapted from https://github.com/kaizhang/SnapATAC2/blob/51f040c095820fea43e9a6360d751bfc29faecc5/snapatac2-python/python/snapatac2/tools/_embedding.py#L434
    """

    def __init__(
        self,
        out_dim: int = 30,
        feature_weights=None,
    ):
        self.out_dim = out_dim
        self.feature_weights = feature_weights

    def fit(self, mat):
        if self.feature_weights is not None:
            mat = mat @ sp.sparse.diags(self.feature_weights)
        self.sample = mat
        self.in_dim = mat.shape[1]

        s = 1 / np.sqrt(np.ravel(sp.sparse.csr_matrix.power(mat, 2).sum(axis=1)))
        X = sp.sparse.diags(s) @ mat

        D = np.ravel(X @ X.sum(axis=0).T) - 1
        X = sp.sparse.diags(1 / np.sqrt(D)) @ X
        evals, evecs = self._eigen(X, 1 / D, k=self.out_dim)

        ix = evals.argsort()[::-1]
        self.evals = evals[ix]
        self.evecs = evecs[:, ix]

    def transform(self, weighted_by_sd: bool = True):
        evals = self.evals
        evecs = self.evecs

        if weighted_by_sd:
            idx = [i for i in range(evals.shape[0]) if evals[i] > 0]
            evals = evals[idx]
            evecs = evecs[:, idx] * np.sqrt(evals)
        return evals, evecs

    @staticmethod
    def _eigen(X, D, k):
        def f(v):
            return X @ (v.T @ X).T - D * v

        n = X.shape[0]
        A = sp.sparse.linalg.LinearOperator((n, n), matvec=f, dtype=np.float64)
        return sp.sparse.linalg.eigsh(A, k=k)


class _DimensionReduction:

    def transform(self, data: csr_matrix | NDArray, n_dimensions: int) -> NDArray:
        raise NotImplementedError


class SpectralEmbedding(_DimensionReduction):
    def transform(
        self, data: csr_matrix | NDArray, n_dimensions: int, weighted_by_sd: bool = True
    ) -> NDArray:
        reducer = _SpectralMatrixFree(out_dim=n_dimensions)
        reducer.fit(data)
        _, embedding = reducer.transform(weighted_by_sd=weighted_by_sd)
        return embedding


In [24]:
que_feature_matrix.shape

(4302, 14078)

In [5]:
sys.path.append("scripts")
sys.path.append("../../scripts")
merged_matrix = sp.sparse.vstack([ref_feature_matrix, que_feature_matrix])
dim500 = SpectralEmbedding().transform(merged_matrix,n_dimensions=500) 

  s = 1 / np.sqrt(np.ravel(sp.sparse.csr_matrix.power(mat, 2).sum(axis=1)))
  X = sp.sparse.diags(1 / np.sqrt(D)) @ X


In [15]:
class GaussianRandomProjection(_DimensionReduction):
    def transform(
        self, data: csr_matrix | NDArray, n_dimensions: int
    ) -> NDArray:
        reducer = random_projection.GaussianRandomProjection(n_components=n_dimensions)
        embedding = reducer.fit_transform(data)
        return embedding
dim500_rd = GaussianRandomProjection().transform(merged_matrix,n_dimensions=3200)

In [47]:
que_feature_matrix.shape

(119250, 2899125)

In [41]:
dim500_rd.shape


(120250, 500)

In [7]:
dim500_ref = dim500_rd[:1000]
dim500_que = dim500_rd[1000:]

NameError: name 'dim500_rd' is not defined

In [8]:
dim500_ref1 = dim500[:1000]
dim500_que1 = dim500[1000:]

In [81]:
## 降维+精确求解
from sklearn.neighbors import NearestNeighbors  
import numpy as np  

query_vectors = np.array(dim500_que1) 
database_vectors = np.array(dim500_ref1) 
nbrs = NearestNeighbors(n_neighbors=1, algorithm='auto',metric='cosine').fit(database_vectors)  
distances, indices_exact = nbrs.kneighbors(query_vectors)
#precision,sensitivity,precision_sep,sensitivity_sep = evaluate(indices_exact)


In [83]:
indices_exact

array([[676],
       [677],
       [ 14],
       ...,
       [540],
       [242],
       [243]])

In [25]:
from sklearn.neighbors import NearestNeighbors  
import numpy as np  

query_vectors = np.array(dim500_que1) 
database_vectors = np.array(dim500_ref1) 
nbrs = NearestNeighbors(n_neighbors=1, algorithm='auto',metric='cosine').fit(database_vectors)  
distances, indices_exact1 = nbrs.kneighbors(query_vectors)

In [97]:
len(indices_exact)

119250

In [94]:
precision_sep,sensitivity_sep,precision,sensitivity = evaluate_meta(indices_exact,ref_read_tax,que_read_tax)

KeyError: '40128'

In [86]:
from sklearn.metrics import precision_score, recall_score  

actual = []
prediction = []
for query_read_num,x in enumerate(indices_exact):
    neighbor = x[0]
    neighbor = (neighbor-1)/2  if neighbor %2 !=0 else neighbor/2
    query_read_num = (query_read_num-1)/2  if query_read_num %2 !=0 else query_read_num/2
    prediction.append(ref_read_tax[str(int(neighbor))])
    actual.append(que_read_tax[str(int(query_read_num))])

precision = precision_score(actual, prediction,average='macro')
sensitivity = recall_score(actual, prediction,average='macro')
##计算每个类别的
precision_sep = precision_score(actual, prediction, average=None)  
sensitivity_sep = recall_score(actual, prediction, average=None)

KeyError: '40128'

In [23]:
neighbor_indices

NpzFile '/home/miaocj/docker_dir/kNN-overlap-finder/data/evaluation/metagenome/pbsim_ONT_95_20k/kmer_16/Exact_Euclidean_None_TF-IDF_nbr_matrix.npz' with keys: arr_0

In [93]:
def evaluate_meta(nbr_indice,ref_read_tax,que_read_tax):
    actual = []
    prediction = []
    for query_read_num,x in enumerate(nbr_indice):
        neighbor = x[0]
        neighbor = (neighbor-1)/2  if neighbor %2 !=0 else neighbor/2
        query_read_num = (query_read_num-1)/2  if query_read_num %2 !=0 else query_read_num/2
        prediction.append(ref_read_tax[str(int(neighbor))])
        actual.append(que_read_tax[str(int(query_read_num))])
    precision_sep=[]
    sensitivity_sep=[]
    for i in set(actual):
        tp = 0
        fp = 0
        fn = 0
        for x in range(len(actual)):
            if actual[x] == i:
                if actual[x] == prediction[x]:
                    tp += 1
                else:
                    fn += 1
            else:
                if prediction[x] == i:
                    fp += 1
        pre = tp/(tp+fp)
        sen = tp/(tp+fn)
        precision_sep.append(pre)
        sensitivity_sep.append(sen)
        precision = sum(precision_sep)/len(precision_sep)
        sensitivity = sum(sensitivity_sep)/len(sensitivity_sep)
    return precision_sep,sensitivity_sep,precision,sensitivity


In [123]:
neighbor_indices = np.load('/home/miaocj/docker_dir/kNN-overlap-finder/data/evaluation/metagenome/pbsim_ONT_95_20k/kmer_16/HNSW_Euclidean_GaussianRP_500d_None_nbr_matrix.npz')

In [124]:

neighbor_indices

NpzFile '/home/miaocj/docker_dir/kNN-overlap-finder/data/evaluation/metagenome/pbsim_ONT_95_20k/kmer_16/HNSW_Euclidean_GaussianRP_500d_None_nbr_matrix.npz' with keys: arr_0

In [127]:
sys.path.append("scripts")
sys.path.append("../../../scripts")

ref_read_tax_file="/home/miaocj/docker_dir/kNN-overlap-finder/data/feature_matrix/metagenome/pbsim_ONT_95_20k/kmer_16/ref_read_tax.json.gz"
que_read_tax_file="/home/miaocj/docker_dir/kNN-overlap-finder/data/feature_matrix/metagenome/pbsim_ONT_95_20k/kmer_16/que_read_tax.json.gz"
with gzip.open(ref_read_tax_file, "rt") as f:
    ref_read_tax = json.load(f)
with gzip.open(que_read_tax_file, "rt") as f:
    que_read_tax = json.load(f)

neighbor_indices = np.load('/home/miaocj/docker_dir/kNN-overlap-finder/data/evaluation/metagenome/pbsim_ONT_95_20k/kmer_16/HNSW_Euclidean_GaussianRP_500d_None_nbr_matrix.npz')
df_rows = []
precision_sep,sensitivity_sep,precision,sensitivity = evaluate_meta(neighbor_indices['arr_0'],ref_read_tax,que_read_tax)
evaluate_dict = dict(
    precision_sep=precision_sep,
    sensitivity_sep=sensitivity_sep,
    precision=precision,
    sensitivity=sensitivity)
df_rows.append(evaluate_dict)

ZeroDivisionError: division by zero

In [None]:
ref_feature_matrix= sp.load_npz(output_ref_npz_file)

In [63]:
ref_feature_matrix = sp.sparse.load_npz('/home/miaocj/docker_dir/kNN-overlap-finder/data/feature_matrix/metagenome/pbsim_ONT_95_20k/kmer_16/ref_feature_matrix.npz')

In [76]:
line_81685 = ref_feature_matrix[81686,:]

In [77]:
num_nonzero = line_81685.nnz  # 非零元素的数量
num_elements = np.prod(line_81685.shape)  # 矩阵总元素数量
density = num_nonzero / num_elements

In [78]:
density

6.369763373026138e-05

In [75]:
density

5.9525299643126355e-05

In [53]:
len(que_read_tax)

40128

In [129]:
neighbor_indices = np.load('/home/miaocj/docker_dir/kNN-overlap-finder/data/evaluation/metagenome/pbsim_ONT_95_20k/kmer_16/HNSW_Euclidean_SparseRP_500d_None_nbr_matrix.npz')
nn = neighbor_indices['arr_0']

In [133]:
nn[1,:]

array([214806, 219581, 221108, 221717, 248074, 214379, 214383, 214392,
       214393, 214402, 214404, 214504, 214595, 214596, 214597, 214755,
       214756, 214775, 214807, 214825], dtype=uint64)

In [138]:
ref_read_tax['107403']

'QUGC01'

In [112]:
len(que_read_tax)

40128

In [122]:
collections.Counter(neighbor_indices['arr_0'].flatten().tolist())

Counter({1818: 553,
         1819: 553,
         0: 390,
         1840: 358,
         1841: 358,
         1: 340,
         1862: 339,
         1863: 339,
         1884: 295,
         1885: 295,
         848: 246,
         849: 246,
         22: 218,
         23: 218,
         44: 214,
         45: 214,
         1906: 205,
         1907: 205,
         231490: 204,
         231491: 204,
         972: 191,
         973: 191,
         1944: 189,
         1945: 189,
         840: 182,
         841: 182,
         64635: 175,
         64634: 175,
         870: 174,
         871: 174,
         54: 173,
         55: 173,
         2941: 168,
         2940: 168,
         1592: 162,
         1593: 162,
         2297: 162,
         2296: 162,
         2896: 155,
         2897: 155,
         1960: 143,
         1961: 143,
         1962: 140,
         1963: 140,
         56: 139,
         57: 139,
         64701: 133,
         64700: 133,
         2318: 127,
         2319: 127,
         237055: 124,


In [125]:
neighbor_indices = np.load('/home/miaocj/docker_dir/kNN-overlap-finder/data/evaluation/metagenome/pbsim_ONT_95_20k/kmer_16/HNSW_Euclidean_GaussianRP_500d_None_nbr_matrix.npz')
neighbor_indices['arr_0']

array([[ 77797],
       [ 77796],
       [231500],
       ...,
       [ 27582],
       [ 27582],
       [ 27582]], dtype=uint64)

In [126]:
collections.Counter(neighbor_indices['arr_0'].flatten().tolist())

Counter({27582: 9100,
         128328: 2546,
         123008: 2242,
         105105: 1753,
         128585: 1401,
         113047: 992,
         59223: 983,
         134369: 930,
         69641: 846,
         107880: 746,
         53063: 704,
         105368: 677,
         150001: 645,
         108500: 628,
         86024: 608,
         242234: 605,
         26195: 569,
         82552: 556,
         81248: 537,
         84910: 525,
         85453: 498,
         18434: 495,
         44087: 494,
         543: 494,
         28512: 491,
         4288: 480,
         27586: 456,
         26194: 455,
         216524: 450,
         197761: 446,
         60494: 441,
         66333: 430,
         44086: 400,
         242236: 397,
         178843: 396,
         1934: 390,
         18435: 384,
         2860: 369,
         219419: 356,
         29940: 353,
         206191: 343,
         45950: 336,
         248072: 335,
         20298: 331,
         215462: 327,
         251840: 319,
         22969

In [39]:
ref_read_tax['81995']

'QRHG01'

In [15]:
precision_sep

array([0.97042045, 0.56018519, 0.23458904, 0.19275701, 0.79422972,
       0.16987887, 0.53108348, 0.48534671, 0.98349835, 0.58404075,
       0.35463918, 0.98897231])

In [19]:
precision_sep

array([0.95552261, 0.77848437, 0.92774216, 0.95049795, 0.95008396,
       0.95396991, 0.96278134, 0.63651031, 0.97497366, 0.72229323,
       0.94242224, 0.98285983])

In [155]:
actual = []
prediction = []
for query_read_num,x in enumerate(indices_exact):
    neighbor = x[0]
    neighbor = (neighbor-1)/2  if neighbor %2 !=0 else neighbor/2
    query_read_num = (query_read_num-1)/2  if query_read_num %2 !=0 else query_read_num/2
    prediction.append(ref_read_tax[neighbor])
    actual.append(que_read_tax[query_read_num])

In [164]:
for i,tax in ref_read_tax.items():
    pre_indices = [index for index, value in enumerate(prediction) if value == tax]
    pre_len = sum([qread_length[x] for x in pre_indices])
    actual_len = read_length[i]
    pred_depth = pre_len/actual_len
    print(tax)
    print(pred_depth)

Agaricusbisporus
17.496957207641202
Eremotheciumcymbalariae
22.542886990411617
Kazachstaniaafricana
23.512340437859905
Phaeoacremoniumminimum
45.30084371824357


In [159]:
ref_length

{0: 1748887, 1: 897456, 2: 421466, 3: 592725}

In [109]:
target_class = 'Encephalitozoonromaleae'
true_positive = sum([1 for p, a in zip(prediction, actual) if p == target_class and a == target_class])
false_positive = sum([1 for p, a in zip(prediction, actual) if p == target_class and a != target_class])
false_negative = sum([1 for p, a in zip(prediction, actual) if p != target_class and a == target_class])

recall = true_positive / (true_positive + false_negative)
precision = true_positive / (true_positive + false_positive)
precision,true_positive,false_positive,false_negative

(0.22524752475247525, 91, 313, 1)

In [None]:
import matplotlib.pyplot as plt
import seaborn
plt.figure(figsize=(6, 6), dpi=300) 
plt.scatter(precision_sep,sensitivity_sep, label='Data points', marker='o', c=sensitivity_sep, cmap='rainbow')

plt.gca().spines['right'].set_color('none')  
plt.gca().spines['top'].set_color('none') 
plt.xlabel('precision')
plt.ylabel('sensitivity')
plt.margins(0.3) 

In [114]:
##rough evaluate
right_num = 0
for query_read_num,x in enumerate(indices):
    neighbor = x[0]
    neighbor = (neighbor-1)/2  if neighbor %2 !=0 else neighbor/2
    query_read_num = (query_read_num-1)/2  if query_read_num %2 !=0 else query_read_num/2
    if ref_read_tax[neighbor] == que_read_tax[query_read_num]:
        right_num +=1

print(right_num/len(indices))

0.9584399030599837


In [31]:
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Sequence, Type, Mapping, Iterable, Literal
from warnings import warn
from math import ceil
from scipy.sparse._csr import csr_matrix
import numpy as np
import hnswlib

class HNSW():
    def get_neighbors(
        self,
        ref_data: csr_matrix | np.ndarray,
        que_data: csr_matrix | np.ndarray,
        n_neighbors: int,
        metric: Literal["euclidean", "cosine"] = "euclidean",
        *,
        threads: int | None = None,
        M: int = 16,
        ef_construction: int = 200,
        ef_search: int = 50,
    ) -> np.ndarray:
        if metric == "euclidean":
            space = "l2"
        else:
            space = metric

        # Initialize the HNSW index
        p = hnswlib.Index(space=space, dim=ref_data.shape[1])
        if threads is not None:
            p.set_num_threads(threads)
        p.init_index(max_elements=ref_data.shape[0], ef_construction=ef_construction, M=M)
        ids = np.arange(ref_data.shape[0])
        p.add_items(ref_data, ids)
        p.set_ef(ef_search)
        nbr_indices, _ = p.knn_query(que_data, k=n_neighbors)
        return nbr_indices
nbr_indices = HNSW().get_neighbors(ref_data=dim500_ref,que_data=dim500_que,n_neighbors=1)


In [32]:
##dim500_HNSW

indices = HNSW().get_neighbors(ref_data=dim500_ref,que_data=dim500_que,n_neighbors=1)
precision,sensitivity,precision_sep,sensitivity_sep = evaluate(indices)

In [33]:
precision_sep,sensitivity_sep

(array([0.94196891, 0.22524752, 0.9897541 , 0.81634183]),
 array([0.8015873 , 0.98913043, 0.57913669, 0.96286472]))

In [61]:
from dataclasses import dataclass, field
from functools import lru_cache
import collections
from typing import Sequence, Type, Mapping, Iterable, Literal
from warnings import warn
from math import ceil
from scipy import sparse
from scipy.sparse._csr import csr_matrix
import numpy as np
import mmh3



class LowHash():

    @staticmethod
    def _hash(x: int, seed: int) -> int:
        hash_value = mmh3.hash(str(x), seed=seed)
        return hash_value

    @staticmethod
    def _get_hash_values(data: Iterable[int], repeats: int, seed: int) -> np.ndarray:
        rng = np.random.default_rng(seed)
        hash_seeds = rng.integers(low=0, high=2**32 - 1, size=repeats, dtype=np.uint64)
        hash_values = []
        for k in range(repeats):
            s = hash_seeds[k]
            for x in data:
                hash_values.append(LowHash._hash(x, seed=s))
        hash_values = np.array(hash_values, dtype=np.int64)
        return hash_values

    @staticmethod
    def _get_lowhash_count(
        hash_count: int,
        lowhash_fraction: float | None = None,
        lowhash_count: int | None = None,
    ) -> int:
        if lowhash_fraction is None and lowhash_count is None:
            raise TypeError(
                "Either `lowhash_fraction` or `lowhash_count` must be specified."
            )
        if lowhash_fraction is not None and lowhash_count is not None:
            raise TypeError(
                f"`lowhash_fraction` and `lowhash_count` cannot be specified at the same time. {lowhash_fraction=} {lowhash_count=}"
            )

        if lowhash_fraction is not None:
            lowhash_count = ceil(hash_count * lowhash_fraction)
            lowhash_count = max(lowhash_count, 1)
        if lowhash_count is None:
            raise ValueError()
        return lowhash_count
    
    def _lowhash(
        self,
        data: csr_matrix | np.ndarray,
        repeats: int,
        lowhash_fraction: float | None,
        lowhash_count: int | None = None,
        seed: int = 5731343,
        verbose=True,
    ) -> csr_matrix:
        sample_count, feature_count = data.shape
        buckets = sparse.dok_matrix(
            (feature_count * repeats, sample_count), dtype=np.bool_
        )

        # Calculate hash values
        hash_values = self._get_hash_values(
            np.arange(feature_count), repeats=repeats, seed=seed
        )

        # For each sample, find the lowest hash values for its features
        for j in range(sample_count):
            feature_indices = sparse.find(data[j, :] > 0)[1]
            hash_count = feature_indices.shape[0]
            sample_lowhash_count = self._get_lowhash_count(
                hash_count=hash_count,
                lowhash_fraction=lowhash_fraction,
                lowhash_count=lowhash_count,
            )
            for k in range(repeats):
                bucket_indices = feature_indices + (k * feature_count)
                sample_hash_values = hash_values[bucket_indices]
                low_hash_buckets = bucket_indices[
                    np.argsort(sample_hash_values)[:sample_lowhash_count]
                ]
                buckets[low_hash_buckets, j] = 1
            if verbose and j % 1000 == 0:
                print(j, end=" ")
        if verbose:
            print("")
        buckets = sparse.csr_matrix(buckets)
        return buckets

    def _get_adjacency_matrix(
        self,
        data: csr_matrix | np.ndarray,
        buckets: csr_matrix,
        n_neighbors: int,
        min_bucket_size,
        max_bucket_size,
        min_cooccurence_count,
    ) -> np.ndarray:

        # Select neighbor candidates based on cooccurence counts
        row_sums = buckets.sum(axis=1).A1  # type: ignore
        matrix = buckets[
            (row_sums >= min_bucket_size) & (row_sums <= max_bucket_size), :
        ].astype(np.uint8)
        cooccurrence_matrix = matrix.T.dot(matrix)

        neighbor_dict = collections.defaultdict(dict)
        nonzero_indices = list(zip(*cooccurrence_matrix.nonzero()))
        for i, j in nonzero_indices:
            if i >= j: 
                continue

            count = cooccurrence_matrix[i, j]
            neighbor_dict[i][j] = count
            neighbor_dict[j][i] = count

        # Construct neighbor matrix
        n_rows = data.shape[0]
        nbr_matrix = []
        for i in range(n_rows)[8:]:
            row_nbr_dict = {
                j: count
                for j, count in neighbor_dict[i].items()
                if count >= min_cooccurence_count and j < 8
            }
            neighbors = list(
                sorted(row_nbr_dict, key=lambda x: row_nbr_dict[x], reverse=True)
            )[:n_neighbors]
            nbr_matrix.append(neighbors)
        return nbr_matrix

    def get_neighbors(
        self,
        data: csr_matrix | np.ndarray,
        n_neighbors: int,
        lowhash_fraction: float | None = None,
        lowhash_count: int | None = None,
        repeats=100,
        min_bucket_size=2,
        max_bucket_size=float("inf"),
        min_cooccurence_count=1,
        *,
        seed=1,
        verbose=True,
    ) -> np.ndarray:

        buckets = self._lowhash(
            data,
            repeats=repeats,
            lowhash_fraction=lowhash_fraction,
            lowhash_count=lowhash_count,
            seed=seed,
            verbose=verbose,
        )
        nbr_matrix = self._get_adjacency_matrix(
            data,
            buckets,
            n_neighbors=n_neighbors,
            min_bucket_size=min_bucket_size,
            max_bucket_size=max_bucket_size,
            min_cooccurence_count=min_cooccurence_count,
        )
        return nbr_matrix


In [62]:
COVERAGE_DEPTH = 20
max_bucket_size = COVERAGE_DEPTH * 1.5
indices = LowHash().get_neighbors(data=merged_matrix,repeats=100,lowhash_count=20,n_neighbors=1,
            max_bucket_size=max_bucket_size,
            seed=458)

0 1000 2000 3000 


In [67]:
##calculate sensitivity and precision

from sklearn.metrics import precision_score, recall_score  

def evaluate(indices):
    actual = []
    prediction = []
    for query_read_num,x in enumerate(indices):
        if len (x) > 0:
            neighbor = x[0]
            neighbor = (neighbor-1)/2  if neighbor %2 !=0 else neighbor/2
            query_read_num = (query_read_num-1)/2  if query_read_num %2 !=0 else query_read_num/2
            prediction.append(ref_read_tax[neighbor])
            actual.append(que_read_tax[query_read_num])

    precision = precision_score(actual, prediction,average='macro')
    sensitivity = recall_score(actual, prediction,average='macro')
    ##计算每个类别的
    precision_sep = precision_score(actual, prediction, average=None)  
    sensitivity_sep = recall_score(actual, prediction, average=None)
    return precision,sensitivity,precision_sep,sensitivity_sep

In [66]:

empty_count = sum([1 for sublist in indices if len(sublist) == 0])
len(indices)

3192

In [68]:
precision,sensitivity,precision_sep,sensitivity_sep = evaluate(indices)

In [69]:
precision_sep,sensitivity_sep
##输入种类比较少的时候，一些本身reads很少的genome 可以设置一个阈值，超过阈值的才被判断为近邻/

(array([0.98910082, 0.23989218, 0.90131579, 0.89467593]),
 array([0.68815166, 1.        , 0.83664122, 0.99357326]))