In [6]:
import os
import sys
import glob
from io import StringIO
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import pandas as pd
import torch
import zstandard
from tqdm import tqdm
from pyfaidx import Fasta
from Bio import SeqIO
from dnaHMM import dnaHMM
import plotnine as p9

In [3]:
# List the last five entries under data/. The explicit `!` shell escape does not
# depend on IPython's automagic `ls` alias (which forces `--color` and leaks raw
# ANSI escape codes into the saved output).
!ls -al data/* | tail -n 5

-rw-r----- 1 knonchev INFK-Raetsch-Collab    8261157 Jan 22 17:05 [01;31mdata/ERR1738359.contigs.fa.zst[0m
-rw-r----- 1 knonchev INFK-Raetsch-Collab      98497 Jan 22 17:05 [01;31mdata/ERR1875855.contigs.fa.zst[0m
-rw-r----- 1 knonchev INFK-Raetsch-Collab    3786929 Jan 22 17:05 [01;31mdata/ERR1938089.contigs.fa.zst[0m
-rw-r----- 1 knonchev INFK-Raetsch-Collab     625470 Jan 22 17:05 [01;31mdata/ERR1950831.contigs.fa.zst[0m
-rw-r----- 1 knonchev INFK-Raetsch-Collab  721547668 Jan 22 17:05 [01;31mdata/ERR1976343.contigs.fa.zst[0m


In [10]:
# Sampling manifest for the contig collection.
# - 'acc' identifies one contig file containing the available sequences
#   (presumably stored as data/<acc>.contigs.fa.zst -- TODO confirm).
# - 'organism_kmeans' is a meta label that groups related organisms;
#   there are 500 distinct 'organism_kmeans' labels.
# - For each 'organism_kmeans' group, 50 contigs were sampled.
metadata = pd.read_csv("data/eukaryota_genomics_500_50_2025-01-22_15-14-40.csv")
metadata

Unnamed: 0,acc,kingdom,organism,organism_kmeans,mbases
0,SRR11413904,Fungi,Zymoseptoria tritici IPO323,173,2342.0
1,SRR14510856,Viridiplantae,Elaeis oleifera,175,330.0
2,SRR13694312,Viridiplantae,Lactuca saligna,247,43070.0
3,ERR615301,Viridiplantae,Oryza sativa,6,204.0
4,ERR7412375,Viridiplantae,Ambrosia artemisiifolia,239,6867.0
...,...,...,...,...,...
24995,SRR7630445,Fungi,Parastagonospora nodorum,359,2776.0
24996,SRR5042517,Viridiplantae,Cynodon dactylon x Cynodon transvaalensis,404,407.0
24997,SRR11106753,Viridiplantae,Zea mays subsp. mays,103,12020.0
24998,SRR17421512,Viridiplantae,Ceratonia siliqua,229,46.0


In [12]:
# Preview the organisms behind the first ten cluster labels, in order of first appearance.
first_ten_labels = metadata["organism_kmeans"].unique()[:10]
for cl in first_ten_labels:
    members = metadata.query(f"organism_kmeans == {cl}")
    print(cl, members.organism.unique())

173 ['Zymoseptoria tritici IPO323' 'Zymoseptoria tritici'
 'Zymoseptoria tritici STIR04_A26b']
175 ['Elaeis oleifera' 'Elaeis guineensis' 'Elaeis']
247 ['Lactuca saligna' 'Lactuca serriola' 'Lactuca sativa' 'Lactuca georgica'
 'Lactuca virosa' 'Lactuca biennis' 'Lactuca sativa x Lactuca serriola'
 'Lactuca perennis' 'Lactuca dregeana' 'Lactuca altaica']
6 ['Oryza sativa']
239 ['Ambrosia artemisiifolia']
338 ['Verticillium dahliae' 'Verticillium dahliae VDG1']
36 ['Apis mellifera' 'Apis mellifera capensis'
 'Apis mellifera capensis x Apis mellifera scutellata'
 'Apis mellifera mellifera']
18 ['Gallus gallus' 'Gallus gallus gallus' 'Gallus']
424 ['Alces alces' 'Alces americanus' 'Alces alces shirasi']
95 ['Drosophila nasuta x Drosophila albomicans' 'Drosophila albomicans'
 'Drosophila nasuta']


In [9]:
#### Loading functions

In [13]:
### load sequences from the contig file
def fasta_parsing_func(fasta_path):
    """Yield cleaned contig sequences from a zstd-compressed FASTA file.

    The file is decompressed as a stream instead of being read and
    decompressed in one shot, so large contig files are never held in memory
    twice (compressed bytes + decoded text), and frames whose header does not
    record the content size can still be read.

    Each record is post-processed in this order:
      1. characters that are not in ``ALPHABET`` are dropped;
      2. the sequence is cut at its first repeated ``KMER``-mer via
         ``chop_at_first_repeated_kmer``.

    Parameters
    ----------
    fasta_path : str or path-like
        Path to a ``.fa.zst`` file.

    Yields
    ------
    str
        One cleaned sequence per FASTA record (may be empty if the record
        contains no ``ALPHABET`` characters).
    """
    # Function-scope import: the notebook's import cell only exposes StringIO.
    from io import TextIOWrapper

    dctx = zstandard.ZstdDecompressor()
    with open(fasta_path, "rb") as compressed:
        # stream_reader() decompresses incrementally; TextIOWrapper turns the
        # byte stream into UTF-8 text (the same codec bytes.decode() defaults to),
        # which is what SeqIO.parse expects for the "fasta" format.
        with TextIOWrapper(dctx.stream_reader(compressed), encoding="utf-8") as text:
            for record in tqdm(SeqIO.parse(text, "fasta")):
                s = str(record.seq)
                s = "".join(c for c in s if c in ALPHABET)  # make sure only ALPHABET
                s = chop_at_first_repeated_kmer(s, KMER)
                yield s

In [14]:
from tqdm import tqdm
from collections import defaultdict

# Nucleotide characters kept when cleaning sequences; any other character is dropped.
ALPHABET = {"A", "T", "C", "G"}
# k-mer size: contigs are chopped at their first repeated KMER-mer, and the
# example cells pass it to the graph builder, which links contigs on a
# (KMER - 1)-character overlap.
KMER = 31
# Maximum length (in characters) of a stitched random-walk sequence.
MAX_SEQ_LENGTH = 2048


# Truncate a sequence just before it would complete a k-mer it has already contained
def chop_at_first_repeated_kmer(sequence, k):
    """Return the longest prefix of ``sequence`` in which every k-mer is unique.

    The sequence is scanned left to right with a window of width ``k``. As
    soon as a window repeats an earlier one, the sequence is cut so that the
    last character of that repeated window is excluded. If no window repeats
    (including when the sequence is shorter than ``k``), the sequence is
    returned unchanged.
    """
    seen = set()
    last_start = len(sequence) - k
    for start in range(last_start + 1):
        window = sequence[start:start + k]
        if window not in seen:
            seen.add(window)
            continue
        # Keep everything up to, but not including, the final character of the repeat.
        return sequence[:start + k - 1]
    return sequence

# Reconstruct assembly graph
def find_overlaps_and_build_graph(sequences, k_mer=3):
    """Build a directed overlap graph between sequences.

    An edge ``i -> j`` (with ``i != j``) is added when the last ``k_mer - 1``
    characters of ``sequences[i]`` equal the first ``k_mer - 1`` characters of
    ``sequences[j]``.

    Parameters
    ----------
    sequences : sequence of str
        Indexable collection of sequences; node ids are positions in it.
    k_mer : int, default 3
        k-mer size; the required overlap is ``k_mer - 1`` characters.

    Returns
    -------
    collections.defaultdict(list)
        Maps every index ``i`` to the list of indices it overlaps into, in
        increasing index order. Every index is present, possibly with ``[]``.

    Raises
    ------
    ValueError
        If ``k_mer < 2``, because an overlap of zero characters is undefined
        (``seq[-0:]`` would silently be the whole sequence).

    Notes
    -----
    Sequences shorter than the overlap are kept as isolated nodes: they cannot
    overlap by ``k_mer - 1`` characters, so they neither receive nor emit edges.
    """
    min_overlap = k_mer - 1
    if min_overlap < 1:
        raise ValueError(f"k_mer must be >= 2 to define an overlap, got {k_mer}")

    # Index every sequence by its prefix so suffix lookups are O(1).
    prefix_dict = defaultdict(list)
    for i, seq in enumerate(sequences):
        if len(seq) >= min_overlap:
            prefix_dict[seq[:min_overlap]].append(i)

    graph = defaultdict(list)

    # Check for overlaps
    for i, seq1 in tqdm(enumerate(sequences), total=len(sequences)):
        if len(seq1) < min_overlap:
            graph[i] = []
            continue
        # .get() avoids inserting an empty list into prefix_dict for every
        # suffix that matches no prefix.
        candidates = prefix_dict.get(seq1[-min_overlap:], ())
        graph[i] = [j for j in candidates if j != i]

    return graph

# Enumerate depth-limited paths from a node (candidates for a random walk)
def dfs_paths(graph, start, path=None, all_paths=None, depth=10):
    """Collect every depth-first path that begins at ``start``.

    A path ends when it reaches a node without outgoing edges, when it
    contains ``depth`` nodes, or when the next step would revisit a node
    already on the path (the revisiting step is dropped).

    Parameters
    ----------
    graph : mapping
        Maps a node to a list of neighbour nodes.
    start : hashable
        Node to explore from (the last node of ``path`` on recursive calls).
    path : list, optional
        Path walked so far, ending in ``start``. Defaults to ``[start]``.
    all_paths : list, optional
        Accumulator that finished paths are appended to. Defaults to a new list.
    depth : int, default 10
        Maximum number of nodes per path. It is forwarded to every recursive
        call, so a caller-supplied limit applies to the whole search and not
        only to the first level.

    Returns
    -------
    list of list
        ``all_paths``, containing at least one path.
    """
    if path is None:
        path = [start]  # Initialize the path with the starting node
    if all_paths is None:
        all_paths = []  # Initialize the list to store all paths

    # If we revisit a node in the current path, it's a cycle, so we stop
    if start in path[:-1]:
        all_paths.append(path[:-1])
        return all_paths

    # Check if the current node is a leaf (no neighbors)
    if start not in graph or not graph[start]:
        all_paths.append(path)  # Add the current path as a complete path
        return all_paths

    # Stop extending once the path holds `depth` nodes
    if len(path) >= depth:
        all_paths.append(path)
        return all_paths

    # Explore each neighbor recursively, keeping the same depth limit
    for neighbor in graph[start]:
        dfs_paths(graph, neighbor, path + [neighbor], all_paths, depth)

    return all_paths

def random_walk_graph_sequences(graph, sequences, k_mer=None, max_seq_length=None, rng=None):
    """Extend every sequence by stitching it along one randomly chosen graph path.

    For each node in ``graph``, all paths returned by ``dfs_paths`` are
    enumerated and one is picked uniformly at random. The sequences on that
    path are concatenated, skipping the first ``k_mer - 1`` characters of every
    sequence after the first (the part that overlaps its predecessor), and the
    result is truncated to ``max_seq_length`` characters.

    Parameters
    ----------
    graph : mapping
        Overlap graph mapping node index to a list of neighbour indices.
    sequences : sequence of str
        Sequences indexed by the node ids used in ``graph``.
    k_mer : int, optional
        k-mer size the graph was built with. Defaults to the module-level
        ``KMER`` (looked up at call time).
    max_seq_length : int, optional
        Maximum length of each returned sequence. Defaults to the module-level
        ``MAX_SEQ_LENGTH`` (looked up at call time).
    rng : numpy.random.Generator, optional
        Source of randomness for reproducible sampling. When omitted, the
        global ``np.random`` state is used.

    Returns
    -------
    list of str
        One stitched sequence per node, in the iteration order of ``graph``.
    """
    if k_mer is None:
        k_mer = KMER
    if max_seq_length is None:
        max_seq_length = MAX_SEQ_LENGTH
    overlap = k_mer - 1

    random_walk_sequences = []
    for node in tqdm(graph):
        paths = dfs_paths(graph, node)
        if rng is None:
            idx = np.random.randint(len(paths))
        else:
            idx = rng.integers(len(paths))
        path = paths[idx]
        # Drop the overlapping prefix of every follow-up sequence so the
        # shared characters are not duplicated in the stitched result.
        seq = sequences[path[0]] + "".join(sequences[p][overlap:] for p in path[1:])
        random_walk_sequences.append(seq[:max_seq_length])
    return random_walk_sequences

In [None]:
# EXAMPLE: run the pipeline on a single zstd-compressed contig FASTA file.

contig_file = "data/ERR1738359.contigs.fa.zst"

In [16]:
# RAW SEQUENCES
# dtype=object stores each contig as a plain Python str. Without it NumPy builds a
# fixed-width unicode array that pads every entry to the length of the longest
# contig, which wastes memory on files with a few very long contigs.
sequences = np.array(list(fasta_parsing_func(contig_file)), dtype=object)
sequences[0]

152127it [00:06, 22512.74it/s]


'AAAAAAAAAAAAAAAAAAAAAAAATCAAAATAATC'

In [17]:
# Link contig i -> j when the last (KMER - 1) characters of contig i equal the
# first (KMER - 1) characters of contig j.
graph = find_overlaps_and_build_graph(sequences, KMER)

100%|██████████| 152127/152127 [00:01<00:00, 102638.31it/s]


In [18]:
# SEQUENCES AFTER A RANDOM WALK IS PERFORMED TO EXTEND THEIR LENGTH
# LIKELY THIS SHOULD BE USED FOR TRAINING
# TODO: filter short sequences
# TODO: shuffle sequences
random_walk_sequences = random_walk_graph_sequences(graph, sequences)
random_walk_sequences[0]

100%|██████████| 152127/152127 [00:04<00:00, 34288.73it/s]


'AAAAAAAAAAAAAAAAAAAAAAAATCAAAATAATCA'