In [1]:
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from snakemake_stub import *


import gzip, json, collections
from typing import Sequence, Mapping, Collection
from Bio import SeqIO
import scipy.sparse as sp
import numpy as np
import pandas as pd
import sys
sys.path.append("scripts")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/4262b1bf4bf1ffb403c0eb7a42ad5906_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/4506eccf78279d93d0e8a34c035e91c5_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/6bda807e3967eae797c7b1b9eeaee8db_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/c2a47d89d1d34e789fdf782557bb7194_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/c6c5514ada15b890fb27d1e36371554c_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/d964a294c2d0fef56a434c021026281e_/lib/python3.12/site-packages")
sys.path.append("/home/miaocj/docker_dir/kNN-overlap-finder/.snakemake/conda/e1c932db5cd4271709e54d8028824bc9_/lib/python3.12/site-packages")
def init_reverse_complement():
    TRANSLATION_TABLE = str.maketrans("ACTGactg", "TGACtgac")

    def reverse_complement(sequence: str) -> str:
        """
        >>> reverse_complement("AATC")
        'GATT'
        >>> reverse_complement("CCANT")
        'ANTGG'
        """
        sequence = str(sequence)
        return sequence.translate(TRANSLATION_TABLE)[::-1]

    return reverse_complement


reverse_complement = init_reverse_complement()
import mmh3
import sharedmem
from sklearn.neighbors import NearestNeighbors  
import numpy as np  



In [2]:
ref_database = '/home/miaocj/docker_dir/data/metagenome/part1.fa'
query_reads = '/home/miaocj/docker_dir/data/metagenome/pbsim_ONT_95_30k_10dep_part1_reads.fasta'

In [3]:
ref_reads_tax_list = []
with open(ref_database) as file:
    for lines in file:
        if lines[0] == '>':
            line = lines.strip().split(' ')
            ref_reads_tax_list.append(line[1]+line[2])
ref_read_tax = {i:tax for i,tax in enumerate(ref_reads_tax_list)}

flag = 0
que_read_tax = {}
with open(query_reads) as file:
    for lines in file:
        if lines[0] == '>':
            start = lines.index('S')
            end = lines.index('_')
            ref_num = lines[start+1:end]
            que_read_tax[flag] = ref_read_tax[int(ref_num)-1]
            flag +=1


In [4]:
def load_reads(fasta_path: str):
    read_sequences = []
    read_names = []
    read_orientations = []

    with open(fasta_path, "rt") as handle:  # Open gzipped file in text mode
        for record in SeqIO.parse(handle, "fasta"):
            seq = str(record.seq)
            read_sequences.append(seq)
            read_names.append(record.id)
            read_orientations.append("+")

            # Include reverse complement
            read_sequences.append(reverse_complement(seq))
            read_names.append(record.id)
            read_orientations.append("-")

    return read_names, read_orientations, read_sequences

read_names, read_orientations, read_sequences = load_reads(ref_database)


In [5]:
def build_kmer_index(
    read_sequences: Sequence[str],
    k: int,
    *,
    sample_fraction: float,
    min_multiplicity: int,
    seed: int,
) -> Mapping[str, int]:
    kmer_counter = collections.Counter()
    for seq in read_sequences:
        for p in range(len(seq) - k + 1):
            kmer = seq[p : p + k]
            kmer_counter[kmer] += 1

    kmer_spectrum = collections.Counter(x for x in kmer_counter.values() if x <= 10)
    print(kmer_spectrum)

    rng = np.random.default_rng(seed=seed)
    vocabulary = set(
        x
        for x, count in kmer_counter.items()
        if count >= min_multiplicity and rng.random() <= sample_fraction
    )
    vocabulary |= set(reverse_complement(x) for x in vocabulary)
    kmer_indices = {kmer: i for i, kmer in enumerate(vocabulary)}
    return kmer_indices

sample_fraction=0.1
min_multiplicity=2
seed=562104830
kmer_indices = build_kmer_index(        
        read_sequences=read_sequences,
        k=16,
        sample_fraction=sample_fraction,
        min_multiplicity=min_multiplicity,
        seed=seed)


Counter({1: 276256424, 2: 23715451, 3: 3702956, 4: 1039535, 5: 423488, 6: 221958, 7: 134826, 8: 86799, 9: 59680, 10: 48269})


In [6]:
qread_names, qread_orientations, qread_sequences = load_reads(query_reads)

In [8]:
def build_feature_matrix(
    read_sequences: Sequence[str],
    kmer_indices: Mapping[str, int],
    k: int,
) -> tuple[sp.csr_matrix, Sequence[Sequence[int]]]:
    row_ind, col_ind, data = [], [], []
    read_features = []
    for i, seq in enumerate(read_sequences):
        features_i = []
        for p in range(len(seq) - k + 1):
            kmer = seq[p : p + k]
            j = kmer_indices.get(kmer)
            if j is None:
                continue
            features_i.append(j)

        read_features.append(features_i)

        kmer_counts = collections.Counter(features_i)
        for j, count in kmer_counts.items():
            row_ind.append(i)
            col_ind.append(j)
            data.append(count)

    feature_matrix = sp.csr_matrix(
        (data, (row_ind, col_ind)), shape=(len(read_sequences), len(kmer_indices))
    )
    return feature_matrix, read_features

In [9]:
##calculate sensitivity and precision

from sklearn.metrics import precision_score, recall_score  

def evaluate(indices):
    actual = []
    prediction = []
    for query_read_num,x in enumerate(indices):
        neighbor = x[0]
        neighbor = (neighbor-1)/2  if neighbor %2 !=0 else neighbor/2
        query_read_num = (query_read_num-1)/2  if query_read_num %2 !=0 else query_read_num/2
        prediction.append(ref_read_tax[neighbor])
        actual.append(que_read_tax[query_read_num])

    precision = precision_score(actual, prediction,average='macro')
    sensitivity = recall_score(actual, prediction,average='macro')
    ##计算每个类别的
    precision_sep = precision_score(actual, prediction, average=None)  
    sensitivity_sep = recall_score(actual, prediction, average=None)
    return precision,sensitivity,precision_sep,sensitivity_sep

In [9]:
ref_feature_matrix,ref_read_features = build_feature_matrix(read_sequences=read_sequences,kmer_indices=kmer_indices,k=16)

In [11]:
ref_feature_matrix.shape

(1000, 5647520)

In [15]:
from isal import igzip
def open_gzipped(path, mode="rt", gzipped: bool | None = None, **kw):
    if gzipped is None:
        gzipped = path.endswith(".gz")
    if gzipped:
        open_ = igzip.open
        return open_(path, mode)
    else:
        open_ = open
    return open_(path, mode, **kw)
                   
output_npz_file = '/home/miaocj/docker_dir/data/metagenome/ref_feature_matrix.npz'
output_json_file = '/home/miaocj/docker_dir/data/metagenome/ref_read_features.json.gz'
sp.save_npz(output_npz_file, ref_feature_matrix)
with open_gzipped(output_json_file, "wt") as f:
    json.dump(ref_read_features, f)

In [10]:
que_feature_matrix,que_read_features = build_feature_matrix(read_sequences=qread_sequences,kmer_indices=kmer_indices,k=16)

In [16]:
##存储
output_npz_file2 = '/home/miaocj/docker_dir/data/metagenome/que_feature_matrix.npz'
output_json_file2 = '/home/miaocj/docker_dir/data/metagenome/que_read_features.json.gz'
sp.save_npz(output_npz_file2, que_feature_matrix)
with open_gzipped(output_json_file2, "wt") as f:
    json.dump(que_read_features, f)

In [4]:
##读取
output_npz_file2 = '/home/miaocj/docker_dir/data/metagenome/que_feature_matrix.npz'
output_json_file2 = '/home/miaocj/docker_dir/data/metagenome/que_read_features.json.gz'
with gzip.open(output_json_file2, "rt") as f:
    que_read_features = json.load(f)
que_feature_matrix = sp.load_npz(output_npz_file2)

output_npz_file = '/home/miaocj/docker_dir/data/metagenome/ref_feature_matrix.npz'
output_json_file = '/home/miaocj/docker_dir/data/metagenome/ref_read_features.json.gz'
with gzip.open(output_json_file, "rt") as f:
    ref_read_features = json.load(f)
ref_feature_matrix = sp.load_npz(output_npz_file)

In [13]:
que_feature_matrix.shape

(119250, 5647520)

In [None]:
empty_indices = [index for index, sublist in enumerate(que_read_features) if len(sublist) == 0]  
empty_indices

In [13]:
sys.path.append("scripts")
sys.path.append("../../scripts")
merged_matrix = sp.vstack([ref_feature_matrix, que_feature_matrix])  
from dim_reduction import scBiMapEmbedding
dim500 = scBiMapEmbedding().transform(merged_matrix,n_dimensions=500)

In [14]:
dim500_ref = dim500[:1000]
dim500_que = dim500[1001:]

In [15]:
## 降维+精确求解

query_vectors = np.array(dim500_que) 
database_vectors = np.array(dim500_ref) 
nbrs = NearestNeighbors(n_neighbors=2, algorithm='auto',metric='cosine').fit(database_vectors)  
distances, indices = nbrs.kneighbors(query_vectors)
precision,sensitivity,precision_sep,sensitivity_sep = evaluate(indices)

In [16]:
precision,sensitivity

(0.970391944806762, 0.9575761776138907)

In [None]:
import matplotlib.pyplot as plt
categories = []
for i in range(1,13):
    categories.append(i)

plt.figure(figsize=(4,3), dpi=300) 
plt.bar(categories,precision_sep,color = 'gray')

plt.ylabel('Genus Precision')
plt.ylim(0.6,1.01)
plt.xticks(rotation=45, ha='right') 
plt.gca().spines['right'].set_color('none')  
plt.gca().spines['top'].set_color('none') 

In [None]:
import matplotlib.pyplot as plt
import seaborn
plt.figure(figsize=(6, 6), dpi=300) 
plt.scatter(precision_sep,sensitivity_sep, label='Data points', marker='o', c=sensitivity_sep, cmap='rainbow')

plt.xlim(0.8,1.01)
plt.ylim(0.8,1.01)
plt.gca().spines['right'].set_color('none')  
plt.gca().spines['top'].set_color('none') 
plt.xlabel('precision')
plt.ylabel('sensitivity')
plt.margins(0.3) 

In [114]:
##rough evaluate
right_num = 0
for query_read_num,x in enumerate(indices):
    neighbor = x[0]
    neighbor = (neighbor-1)/2  if neighbor %2 !=0 else neighbor/2
    query_read_num = (query_read_num-1)/2  if query_read_num %2 !=0 else query_read_num/2
    if ref_read_tax[neighbor] == que_read_tax[query_read_num]:
        right_num +=1

print(right_num/len(indices))

0.9584399030599837


In [65]:
precision_sep,sensitivity_sep

(array([0.87046535, 0.62569832, 0.83373867, 0.86720252, 0.89721499,
        0.87860642, 0.86374657, 0.41751232, 0.90688083, 0.54479376,
        0.83079625, 0.93578767]),
 array([0.47510944, 0.98778833, 0.61177807, 0.68116937, 0.66164303,
        0.6721266 , 0.60990641, 0.99246462, 0.41897312, 0.68407741,
        0.71957404, 0.33051104]))

In [41]:
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Sequence, Type, Mapping, Iterable, Literal
from warnings import warn
from math import ceil
from scipy.sparse._csr import csr_matrix
import numpy as np
import hnswlib

class HNSW():
    def get_neighbors(
        self,
        ref_data: csr_matrix | np.ndarray,
        que_data: csr_matrix | np.ndarray,
        n_neighbors: int,
        metric: Literal["euclidean", "cosine"] = "euclidean",
        *,
        threads: int | None = None,
        M: int = 16,
        ef_construction: int = 200,
        ef_search: int = 50,
    ) -> np.ndarray:
        if metric == "euclidean":
            space = "l2"
        else:
            space = metric

        # Initialize the HNSW index
        p = hnswlib.Index(space=space, dim=ref_data.shape[1])
        if threads is not None:
            p.set_num_threads(threads)
        p.init_index(max_elements=ref_data.shape[0], ef_construction=ef_construction, M=M)
        ids = np.arange(ref_data.shape[0])
        p.add_items(ref_data, ids)
        p.set_ef(ef_search)
        nbr_indices, _ = p.knn_query(que_data, k=n_neighbors)
        return nbr_indices
nbr_indices = HNSW().get_neighbors(ref_data=dim500_ref,que_data=dim500_que,n_neighbors=1)


In [42]:
##dim500_HNSW

indices = HNSW().get_neighbors(ref_data=dim500_ref,que_data=dim500_que,n_neighbors=1)
precision,sensitivity,precision_sep,sensitivity_sep = evaluate(indices)

In [46]:
precision,sensitivity

(0.9581952089324665, 0.952445601680645)

In [46]:
from dataclasses import dataclass, field
from functools import lru_cache
import collections
from typing import Sequence, Type, Mapping, Iterable, Literal
from warnings import warn
from math import ceil
from scipy import sparse
from scipy.sparse._csr import csr_matrix
import numpy as np
import mmh3




class LowHash():

    @staticmethod
    def _hash(x: int, seed: int) -> int:
        hash_value = mmh3.hash(str(x), seed=seed)
        return hash_value

    @staticmethod
    def _get_hash_values(data: Iterable[int], repeats: int, seed: int) -> np.ndarray:
        rng = np.random.default_rng(seed)
        hash_seeds = rng.integers(low=0, high=2**32 - 1, size=repeats, dtype=np.uint64)
        hash_values = []
        for k in range(repeats):
            s = hash_seeds[k]
            for x in data:
                hash_values.append(LowHash._hash(x, seed=s))
        hash_values = np.array(hash_values, dtype=np.int64)
        return hash_values

    @staticmethod
    def _get_lowhash_count(
        hash_count: int,
        lowhash_fraction: float | None = None,
        lowhash_count: int | None = None,
    ) -> int:
        if lowhash_fraction is None and lowhash_count is None:
            raise TypeError(
                "Either `lowhash_fraction` or `lowhash_count` must be specified."
            )
        if lowhash_fraction is not None and lowhash_count is not None:
            raise TypeError(
                f"`lowhash_fraction` and `lowhash_count` cannot be specified at the same time. {lowhash_fraction=} {lowhash_count=}"
            )

        if lowhash_fraction is not None:
            lowhash_count = ceil(hash_count * lowhash_fraction)
            lowhash_count = max(lowhash_count, 1)
        if lowhash_count is None:
            raise ValueError()
        return lowhash_count
    
    def _lowhash(
        self,
        data: csr_matrix | np.ndarray,
        repeats: int,
        lowhash_fraction: float | None,
        lowhash_count: int | None = None,
        seed: int = 5731343,
        verbose=True,
    ) -> csr_matrix:
        sample_count, feature_count = data.shape
        buckets = sparse.dok_matrix(
            (feature_count * repeats, sample_count), dtype=np.bool_
        )

        # Calculate hash values
        hash_values = self._get_hash_values(
            np.arange(feature_count), repeats=repeats, seed=seed
        )

        # For each sample, find the lowest hash values for its features
        for j in range(sample_count):
            feature_indices = sparse.find(data[j, :] > 0)[1]
            hash_count = feature_indices.shape[0]
            sample_lowhash_count = self._get_lowhash_count(
                hash_count=hash_count,
                lowhash_fraction=lowhash_fraction,
                lowhash_count=lowhash_count,
            )
            for k in range(repeats):
                bucket_indices = feature_indices + (k * feature_count)
                sample_hash_values = hash_values[bucket_indices]
                low_hash_buckets = bucket_indices[
                    np.argsort(sample_hash_values)[:sample_lowhash_count]
                ]
                buckets[low_hash_buckets, j] = 1
            if verbose and j % 1000 == 0:
                print(j, end=" ")
        if verbose:
            print("")
        buckets = sparse.csr_matrix(buckets)
        return buckets

    def _get_adjacency_matrix(
        self,
        data: csr_matrix | np.ndarray,
        buckets: csr_matrix,
        n_neighbors: int,
        min_bucket_size,
        max_bucket_size,
        min_cooccurence_count,
    ) -> np.ndarray:

        # Select neighbor candidates based on cooccurence counts
        row_sums = buckets.sum(axis=1).A1  # type: ignore
        matrix = buckets[
            (row_sums >= min_bucket_size) & (row_sums <= max_bucket_size), :
        ].astype(np.uint8)
        cooccurrence_matrix = matrix.T.dot(matrix)

        neighbor_dict = collections.defaultdict(dict)
        nonzero_indices = list(zip(*cooccurrence_matrix.nonzero()))
        for i, j in nonzero_indices:
            if i >= j: 
                continue

            count = cooccurrence_matrix[i, j]
            neighbor_dict[i][j] = count
            neighbor_dict[j][i] = count

        # Construct neighbor matrix
        n_rows = data.shape[0]
        nbr_matrix = []
        for i in range(n_rows)[1000:]:
            row_nbr_dict = {
                j: count
                for j, count in neighbor_dict[i].items()
                if count >= min_cooccurence_count and j < 1000
            }
            neighbors = list(
                sorted(row_nbr_dict, key=lambda x: row_nbr_dict[x], reverse=True)
            )[:n_neighbors]
            nbr_matrix.append(neighbors)
        return nbr_matrix

    def get_neighbors(
        self,
        data: csr_matrix | np.ndarray,
        n_neighbors: int,
        lowhash_fraction: float | None = None,
        lowhash_count: int | None = None,
        repeats=100,
        min_bucket_size=2,
        max_bucket_size=float("inf"),
        min_cooccurence_count=1,
        *,
        seed=1,
        verbose=True,
    ) -> np.ndarray:

        buckets = self._lowhash(
            data,
            repeats=repeats,
            lowhash_fraction=lowhash_fraction,
            lowhash_count=lowhash_count,
            seed=seed,
            verbose=verbose,
        )
        nbr_matrix = self._get_adjacency_matrix(
            data,
            buckets,
            n_neighbors=n_neighbors,
            min_bucket_size=min_bucket_size,
            max_bucket_size=max_bucket_size,
            min_cooccurence_count=min_cooccurence_count,
        )
        return nbr_matrix


In [None]:
COVERAGE_DEPTH = 20
max_bucket_size = COVERAGE_DEPTH * 1.5
indices = LowHash().get_neighbors(data=merged_matrix,repeats=100,lowhash_count=20,n_neighbors=1,
            max_bucket_size=max_bucket_size,
            seed=458)

In [None]:
merged_matrix.shape