# Imports

In [51]:
import datetime
import itertools
import json
import os
import pickle
import random
from collections import Counter, defaultdict
from multiprocessing import Pool, Process

import networkx as nx
import numpy as np
import torch
from torch_geometric.utils.convert import from_networkx

# Load reads from .fastq files

In [37]:
def generate_kmers(sequence, k):
    """
    Return all overlapping length-``k`` windows of ``sequence``, in left-to-right order.

    An empty list is returned when ``k`` is longer than ``sequence``.
    """
    window_count = len(sequence) - k + 1
    return [sequence[pos:pos + k] for pos in range(window_count)]

def generate_kmer_to_index_dict(k):
    """Map every length-``k`` string over the letters A, C, G, N, T to its rank in sorted order.

    The result has ``5**k`` entries; for ``k=2`` it is ``{'AA': 0, 'AC': 1, ..., 'TT': 24}``.
    """
    letter_tuples = itertools.product('ATGCN', repeat=k)
    ordered_kmers = sorted(map(''.join, letter_tuples))
    return dict(zip(ordered_kmers, range(len(ordered_kmers))))
    
def kmer_to_index(kmer, kmer_to_index_dict):
    """Look up the integer index of ``kmer``; raises ``KeyError`` if it is not in the dict."""
    index = kmer_to_index_dict[kmer]
    return index

def subkmer_frequencies_in_kmer(kmer, subkmer_length, kmer_to_index_dict):
    """Count the overlapping sub-k-mers of ``kmer``.

    Returns a float numpy vector of length ``5**subkmer_length`` whose entry at
    ``kmer_to_index_dict[s]`` holds how many times sub-k-mer ``s`` occurs in ``kmer``.
    Raises ``KeyError`` if a sub-k-mer is missing from ``kmer_to_index_dict``.
    """
    windows = (kmer[pos:pos + subkmer_length] for pos in range(len(kmer) - subkmer_length + 1))
    occurrences_by_window = Counter(windows)
    frequencies = np.zeros(5**subkmer_length)
    for window, occurrences in occurrences_by_window.items():
        frequencies[kmer_to_index_dict[window]] = occurrences
    return frequencies

In [53]:
# Build the 2-mer -> index lookup (sorted over A, C, G, N, T) and display it as the cell output.
sub2mer_dict = generate_kmer_to_index_dict(2)
sub2mer_dict

{'AA': 0,
 'AC': 1,
 'AG': 2,
 'AN': 3,
 'AT': 4,
 'CA': 5,
 'CC': 6,
 'CG': 7,
 'CN': 8,
 'CT': 9,
 'GA': 10,
 'GC': 11,
 'GG': 12,
 'GN': 13,
 'GT': 14,
 'NA': 15,
 'NC': 16,
 'NG': 17,
 'NN': 18,
 'NT': 19,
 'TA': 20,
 'TC': 21,
 'TG': 22,
 'TN': 23,
 'TT': 24}

In [39]:
def parse_train_labels(data_path: str, outdir: str = None, save_to_json=True,
                       data_exts: tuple[str, ...] = ('.gz', '.fastq', '.fasta', '.fq', '.fa')) -> tuple[dict, dict]:
    """Derive integer class labels from the city codes in the sample file names of ``data_path``.

    The city code is the 4th ``_``-separated field of a file name, e.g. ``BCN`` in
    ``CAMDA20_MetaSUB_CSD16_BCN_012_1_kneaddata_subsampled_20_percent.fastq``. Only entries
    whose last extension (per ``os.path.splitext``) is in ``data_exts`` are considered.

    Codes are numbered in sorted order so the same directory always yields the same
    mapping. Numbering by ``set`` iteration order would differ between interpreter runs
    because of string hash randomization.

    Args:
        data_path: directory containing the sample files.
        outdir: directory for the JSON files; the current directory when falsy.
        save_to_json: when True, write ``id_to_code.json`` and ``code_to_id.json``.
            JSON object keys are strings, so the ids in ``id_to_code.json`` are stored as strings.
        data_exts: accepted file extensions, including the leading dot.

    Returns:
        ``(id_to_code, code_to_id)``: a ``{int: str}`` dict and its inverse ``{str: int}``.

    Raises:
        IndexError: if an accepted file name has fewer than four ``_``-separated fields.
    """
    unique_codes = set()
    # os.listdir already returns bare names, so no basename() is needed.
    for sample_filename in os.listdir(data_path):
        if os.path.splitext(sample_filename)[1] not in data_exts:
            continue
        # Example: CAMDA20_MetaSUB_CSD16_BCN_012_1_kneaddata_subsampled_20_percent.fastq
        unique_codes.add(sample_filename.split('_')[3])

    id_to_code = dict(enumerate(sorted(unique_codes)))
    code_to_id = {code: idx for idx, code in id_to_code.items()}

    if save_to_json:
        # os.path.join('', name) == name, i.e. the current directory when no outdir is given.
        out_prefix = outdir or ''
        with open(os.path.join(out_prefix, 'id_to_code.json'), 'w') as f:
            json.dump(id_to_code, f)
        with open(os.path.join(out_prefix, 'code_to_id.json'), 'w') as f:
            json.dump(code_to_id, f)

    return id_to_code, code_to_id


def get_reads_from_fq(fq_path: str) -> list[str]:
    """Return the sequence lines of a plain-text FASTQ file, with trailing whitespace stripped.

    Assumes the standard 4-line record layout (header, sequence, ``+``, quality), so the
    sequence is line 1, 5, 9, ... (0-based). The file is streamed rather than read whole.
    Trailing blank lines or an incomplete final record contribute nothing; only a
    record that has a sequence line is returned.
    """
    with open(fq_path, 'r') as f:
        return [line.rstrip() for line in itertools.islice(f, 1, None, 4)]


def get_labeled_reads_from_dir_with_samples(indir: str) -> dict:
    """Read every ``.fastq`` file in ``indir`` and attach its integer city label.

    Labels come from ``parse_train_labels`` (with JSON output disabled). The city code is
    the 4th ``_``-separated field of the file name. All other directory entries are
    logged and skipped.

    Returns:
        ``{sample_name: [int_label, reads]}``, where ``sample_name`` is the file name
        without its ``.fastq`` extension and ``reads`` is the list from ``get_reads_from_fq``.
    """
    def now_str():
        return datetime.datetime.now().strftime("%d %h %Y %H:%M:%S")

    _, code_to_id = parse_train_labels(data_path=indir, save_to_json=False)
    labelled = {}
    for entry in os.listdir(indir):
        print(f'{now_str()} processing file {entry}')
        stem, extension = os.path.splitext(entry)
        if extension != '.fastq':
            print(f'{now_str()} skipping {entry}')
            continue
        city_code = entry.split('_')[3]
        int_label = int(code_to_id[city_code])
        print(f'{now_str()} {city_code = } ; {int_label = }')
        print(f'{now_str()} getting reads')
        reads = get_reads_from_fq(os.path.join(indir, entry))
        print(f'{now_str()} saving labelled reads')
        # NOTE(review): the stem keeps the read-pair field (e.g. "_1_"), and only ".fastq"
        # entries reach here. Stems are therefore unique within the directory, so the
        # merge branch never fires. Confirm whether paired "_1"/"_2" files were meant to be merged.
        if stem in labelled:
            labelled[stem][1].extend(reads)
        else:
            labelled[stem] = [int_label, reads]
    return labelled

In [40]:
# Load all .fastq samples as {sample_name: [int_label, reads]} (replace the placeholder path).
genome_sequences = get_labeled_reads_from_dir_with_samples('/path/to/dir/with/fastq/samples')

07 Aug 2024 21:27:44 processing file CAMDA20_MetaSUB_CSD17_HKG_010_1_kneaddata_subsampled_20_percent.fastq
07 Aug 2024 21:27:44 city_code = 'HKG' ; int_label = 0
07 Aug 2024 21:27:44 getting reads
07 Aug 2024 21:27:44 saving labelled reads
07 Aug 2024 21:27:44 processing file dbgs
07 Aug 2024 21:27:44 skipping dbgs


In [41]:
# Directory that build_graph_max writes the pickled graphs into (replace the placeholder path).
outdir = '/path/to/graphs/outdir'
# exist_ok=True avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(outdir, exist_ok=True)

In [42]:
# Graph-construction settings read as globals by build_graph_max.
kmer_len = 4        # length of the k-mers that become graph nodes
subkmer_len = 2     # length of the sub-k-mers counted for node features
sub2mer_to_dict = generate_kmer_to_index_dict(subkmer_len)
num_features = 5**subkmer_len   # size of each node feature vector
graphs = []

In [45]:
def _timestamp():
    """Return the current local time in the log format used by this notebook."""
    return datetime.datetime.now().strftime("%d %h %Y %H:%M:%S")


def _count_kmers_and_transitions(seqs, k):
    """Return the distinct k-mers in ``seqs`` and counts of adjacent k-mer pairs within each read."""
    kmers = set()
    transition_counts = defaultdict(int)
    for idx, seq in enumerate(seqs):
        if idx % 100_000 == 0:
            print(f'processed {idx} reads')
        kmers_in_read = generate_kmers(seq, k)
        # Grow the set in place; ``kmers.union(...)`` would copy the whole set for every read.
        kmers.update(kmers_in_read)
        for transition in zip(kmers_in_read, kmers_in_read[1:]):
            transition_counts[transition] += 1
    return kmers, transition_counts


def build_graph_max(dict_item):
    """Build a directed k-mer transition graph for one sample and pickle it to ``outdir``.

    ``dict_item`` is one ``(sample_name, [label, reads])`` item of the dict produced by
    ``get_labeled_reads_from_dir_with_samples``. A single-argument signature like this
    suits ``map``-style fan-out.

    * Nodes: the distinct k-mers (``k = kmer_len``) of all reads. Each node's ``x`` is the
      float32 count vector of its sub-k-mers (length ``subkmer_len``, indexed by
      ``sub2mer_to_dict``), divided by the number of sub-k-mers in a k-mer so it sums to 1.
    * Edges: consecutive k-mers within a read. ``weight`` is the transition count divided
      by the largest transition count in the sample.

    The graph is converted with ``from_networkx`` and ``y`` is set to the label. The result
    is pickled to ``<outdir>/<sample_name>.labeled_graph_max``.

    Reads module-level globals: ``kmer_len``, ``subkmer_len``, ``sub2mer_to_dict``, ``outdir``.
    """
    sample_name, code_and_reads = dict_item
    # Despite the name, ``city_code`` is the integer label stored by the reads loader.
    city_code, seqs = code_and_reads

    G = nx.DiGraph()
    print(f'{_timestamp()} getting k-mers from {len(seqs)} reads')
    kmers, transition_counts = _count_kmers_and_transitions(seqs, kmer_len)

    print(f'{_timestamp()} adding nodes to graph')
    # Number of overlapping sub-k-mers per k-mer (== kmer_len - 1 for 2-mer features).
    subkmers_per_kmer = kmer_len - subkmer_len + 1
    # Sorted so node insertion order does not depend on string hash randomization.
    G.add_nodes_from(
        (kmer, {"x": torch.as_tensor(
            subkmer_frequencies_in_kmer(kmer, subkmer_len, sub2mer_to_dict) / subkmers_per_kmer,
            dtype=torch.float32)})
        for kmer in sorted(kmers)
    )

    # default=1 avoids ValueError when no read yields two k-mers; no edges are added then.
    max_count = max(transition_counts.values(), default=1)
    print(f'{_timestamp()} adding edges')
    for (src, dst), count in transition_counts.items():
        G.add_edge(src, dst, weight=count / max_count)

    print(f'{_timestamp()} saving as torch graph')
    torch_graph = from_networkx(G)
    torch_graph['y'] = torch.tensor([city_code])

    print(f'{_timestamp()} saving graph for sample {sample_name}')
    outfile_graph_name = os.path.join(outdir, sample_name + '.labeled_graph_max')
    # Pickled output: only load these files from trusted sources.
    with open(outfile_graph_name, 'wb') as f:
        pickle.dump(torch_graph, f)

In [52]:
# Build and pickle one graph per sample, one sample at a time.
for sample_item in genome_sequences.items():
    build_graph_max(sample_item)

07 Aug 2024 21:32:14 getting k-mers from 535034 reads
processed 0 reads
processed 100000 reads
processed 200000 reads
processed 300000 reads
processed 400000 reads
processed 500000 reads
07 Aug 2024 21:32:32 adding nodes to graph
07 Aug 2024 21:32:32 adding edges
07 Aug 2024 21:32:32 saving as torch graph
07 Aug 2024 21:32:32 saving graph for sample CAMDA20_MetaSUB_CSD17_HKG_010_1_kneaddata_subsampled_20_percent
