In [1]:
# Label whole-genome sequences by virus only (no subclass/variant labels):

import glob
import os
from multiprocessing import Pool, cpu_count

import pandas as pd

def get_label_from_filename(filename):
    """Map a FASTA file name to its virus label by substring match.

    Only the final path component is inspected, so directory names can never
    influence the label (callers pass full paths). Patterns are checked in
    order and the first match wins:
    ``hcov-19`` -> ``sars_cov_2``, ``influenzaA`` -> ``influenza_a``,
    ``influenzaB`` -> ``influenza_b``, ``rsv`` -> ``rsv``.

    Parameters
    ----------
    filename : str
        File name or path of a FASTA file.

    Returns
    -------
    str or None
        The virus label, or None when no pattern matches.
    """
    basename = os.path.basename(filename)
    # Order matters: the first matching pattern decides the label.
    patterns = (
        ("hcov-19", "sars_cov_2"),
        ("influenzaA", "influenza_a"),
        ("influenzaB", "influenza_b"),
        ("rsv", "rsv"),
    )
    for pattern, label in patterns:
        if pattern in basename:
            return label
    return None

def process_fasta_file(fasta_file):
    """Parse one FASTA file into ``(sequence, EPI_ID, label)`` tuples.

    The label comes from ``get_label_from_filename``; if it returns None the
    file is skipped and an empty list is returned. For each record:

    - the EPI_ID is the second ``|``-separated field of the header line
      (``''`` if the header contains no ``|``);
    - sequence lines are stripped, concatenated and upper-cased;
    - the record is kept only if the sequence is non-empty and consists solely
      of ``A``/``T``/``G``/``C`` (any other character, e.g. ``N``, drops it).

    Parameters
    ----------
    fasta_file : str
        Path to the FASTA file.

    Returns
    -------
    list of tuple
        ``(upper-cased sequence, epi_id, label)`` for each kept record, in
        file order.
    """
    valid_nucleotides = {'A', 'T', 'G', 'C'}
    label = get_label_from_filename(fasta_file)
    if label is None:
        return []

    sequences = []

    def flush(chunks, epi_id):
        # Emit the finished record if it is a non-empty, pure-ACGT sequence.
        sequence = ''.join(chunks).upper()
        if sequence and set(sequence).issubset(valid_nucleotides):
            sequences.append((sequence, epi_id, label))

    with open(fasta_file, 'r') as file:
        # Collect sequence lines in a list and join once per record; repeated
        # string `+=` can degrade to quadratic copying on long genomes.
        chunks = []
        epi_id = ''
        for line in file:
            if line.startswith('>'):
                flush(chunks, epi_id)
                header_parts = line.strip().split('|')
                epi_id = header_parts[1] if len(header_parts) > 1 else ''
                chunks = []
            else:
                chunks.append(line.strip())
        # The last record has no following header to trigger a flush.
        flush(chunks, epi_id)

    return sequences

def label_to_number(label):
    """Return the 1-based class id of a virus label, or 0 for anything unknown.

    Ids: sars_cov_2 -> 1, influenza_a -> 2, influenza_b -> 3, rsv -> 4.
    """
    ordered_labels = ("sars_cov_2", "influenza_a", "influenza_b", "rsv")
    numbering = {name: position for position, name in enumerate(ordered_labels, start=1)}
    return numbering.get(label, 0)

def process_fasta_files(directory, output_csv):
    """Label every ``*.fasta`` file in *directory* and write one CSV row per record.

    Files are parsed in parallel with ``process_fasta_file`` (a process pool
    sized to ``cpu_count()``). Each kept record becomes a row with columns
    ``sequence, EPI_ID, label_name, label_number``. Afterwards the number of
    records per label is printed.

    Parameters
    ----------
    directory : str
        Directory scanned (non-recursively) for files ending in ``.fasta``.
    output_csv : str
        Destination CSV path; overwritten if it exists.
    """
    # glob instead of shelling out to `ls`. This avoids quoting and injection
    # issues with the path, "Argument list too long" with many files, and
    # splitting of names that contain spaces. Sorting gives a deterministic
    # row order.
    fasta_files = sorted(glob.glob(os.path.join(directory, "*.fasta")))

    with Pool(cpu_count()) as pool:
        results = pool.map(process_fasta_file, fasta_files)

    fieldnames = ['sequence', 'EPI_ID', 'label_name', 'label_number']
    rows = [
        (sequence, epi_id, label, label_to_number(label))
        for result in results
        for sequence, epi_id, label in result
    ]
    # Build the frame once and write it in a single call. Constructing and
    # appending a one-row DataFrame per record is very slow.
    df = pd.DataFrame(rows, columns=fieldnames)
    df.to_csv(output_csv, index=False)

    # Count from the in-memory frame (same rows as the CSV) instead of
    # re-reading the file from disk.
    label_counts = df['label_name'].value_counts()

    print("Number of sequences for each label:")
    for label, count in label_counts.items():
        print(f"Label {label}: {count} sequences")

# Directory containing the raw *.fasta files to label (hard-coded rather than
# read with input()).
main_dir = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files"

# NOTE(review): the output name is relative, so the CSV is written to the
# kernel's current working directory. The fragmenting cells read
# WGS_by_virus_finetune.csv from .../genome_files/unfiltered_multiple_genomes/
# — confirm the CWD matches before a fresh Run All.
process_fasta_files(directory=main_dir, output_csv='WGS_by_virus_finetune.csv')

Number of sequences for each label:
Label influenza_a: 76821 sequences
Label influenza_b: 75303 sequences
Label sars_cov_2: 18636 sequences
Label rsv: 918 sequences


In [2]:
## Path1: split each genome from the CSV above into non-overlapping 250 bp fragments, keeping its labels

import pandas as pd
import csv
from tqdm import tqdm

# Split a sequence into consecutive non-overlapping fragments of fragment_len bp (default 250)
def generate_non_overlapping_fragments(sequence, fragment_len=250):
    """Cut *sequence* into back-to-back pieces of exactly *fragment_len* characters.

    Pieces start at 0, fragment_len, 2*fragment_len, ... . A trailing piece
    shorter than fragment_len is discarded, so a sequence shorter than
    fragment_len yields an empty list.
    """
    last_start = len(sequence) - fragment_len
    # Stopping the range at last_start guarantees every slice is full length.
    return [
        sequence[offset:offset + fragment_len]
        for offset in range(0, last_start + 1, fragment_len)
    ]


import os  # needed for os.path.join below; keeps this cell self-contained

# Directory holding the per-genome CSV and receiving the fragment table
# (hard-coded; nothing is prompted).
main_dir = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes"

# Path to the input and output CSV files
input_csv_path = os.path.join(main_dir, "WGS_by_virus_finetune.csv")
output_csv_path = os.path.join(main_dir, "WGS_by_virus_finetune1_250bp_fragments.csv")

# Read the input in chunks to bound memory use.
chunksize = 1000  # rows per chunk; adjust as needed
input_columns = ["EPI_ID", "label_name", "label_number", "sequence"]

with open(output_csv_path, mode='w', newline='') as outfile:
    writer = csv.writer(outfile)
    writer.writerow(input_columns)  # header, in input_columns order

    for chunk in tqdm(pd.read_csv(input_csv_path, usecols=input_columns, chunksize=chunksize)):
        # usecols keeps the file's column order (cell 1 writes `sequence`
        # first), so select explicitly to make the unpacking below line up.
        # itertuples also avoids building a Series per row as iterrows does.
        rows = chunk[input_columns].itertuples(index=False, name=None)
        for epi_id, label_name, label_number, sequence in rows:
            for fragment in generate_non_overlapping_fragments(sequence):
                writer.writerow([epi_id, label_name, label_number, fragment])


172it [00:27,  6.25it/s]


In [4]:
# Load the non-overlapping 250 bp fragment table for inspection.
input_csv_path = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/WGS_by_virus_finetune1_250bp_fragments.csv"
data = pd.read_csv(filepath_or_buffer=input_csv_path, low_memory=False)

In [5]:
# NOTE(review): the saved output shows an "Unnamed: 0" column, but the Path1
# cell writes no index column — the output is likely from an earlier version
# of the file; re-run to confirm.
data.head(n=5)

Unnamed: 0,EPI_ID,label_name,label_number,sequence
0,EPI_ISL_736996,sars_cov_2,1,GGTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTA...
1,EPI_ISL_736996,sars_cov_2,1,GCCTTGTCCCTGGTTTCAACGAGAAAACACACGTCCAACTCAGTTT...
2,EPI_ISL_736996,sars_cov_2,1,TGAGCTGGTAGCAGAACTCGAAGGCATTCAGTACGGTCGTAGTGGT...
3,EPI_ISL_736996,sars_cov_2,1,ACCCGTGAACTCATGCGTGAGCTTAACGGAGGGGCATACACTCGCT...
4,EPI_ISL_736996,sars_cov_2,1,CACCTTTTGAAATTAAATTGGCAAAGAAATTTGACACCTTCAATGG...


In [10]:
## Path2: split each genome into overlapping 250 bp fragments (default overlap 50 bp, i.e. a new window every 200 bp)

import os
import pandas as pd
import csv
from tqdm import tqdm

# Function to generate overlapping 250 bp fragments
def generate_overlapping_fragments(sequence, fragment_len=250, overlap=50):
    """Slide a window of *fragment_len* over *sequence*, sharing *overlap* bases.

    Windows start every ``fragment_len - overlap`` positions (200 with the
    defaults). Only full-length windows are returned, so bases after the last
    full window are not covered and a sequence shorter than fragment_len
    yields an empty list. A negative overlap leaves gaps between windows.

    Parameters
    ----------
    sequence : str
        Sequence to fragment.
    fragment_len : int, default 250
        Window length; must be positive.
    overlap : int, default 50
        Number of bases shared by consecutive windows; must be < fragment_len.

    Returns
    -------
    list of str
        The windows, in positional order.

    Raises
    ------
    ValueError
        If fragment_len <= 0 or overlap >= fragment_len. A non-positive step
        would otherwise make range() raise an obscure error or silently
        return no fragments.
    """
    if fragment_len <= 0:
        raise ValueError(f"fragment_len must be positive, got {fragment_len}")
    if overlap >= fragment_len:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than fragment_len ({fragment_len})"
        )
    step = fragment_len - overlap
    # The range bound already guarantees every slice is full length.
    return [
        sequence[start:start + fragment_len]
        for start in range(0, len(sequence) - fragment_len + 1, step)
    ]

# Directory holding the per-genome CSV and receiving the fragment table
# (hard-coded; nothing is prompted).
main_dir = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/"

# Fragment geometry: 250 bp windows starting every 200 bp, so neighbouring
# windows share 50 bp. NOTE: the "_200overlap" suffix in the output name
# reflects the 200 bp stride, not the overlap. The name is kept because later
# cells read this file by name.
FRAGMENT_LEN = 250
OVERLAP = 50

# Path to the input and output CSV files
input_csv_path = os.path.join(main_dir, "WGS_by_virus_finetune.csv")
# Alternative input dataset (kept for reference):
#input_csv_path = os.path.join(main_dir, "WGS_by_VOC_IAV_finetune_5k_epi.csv")

output_csv_path = os.path.join(main_dir, "WGS_by_virus_finetune_250bp_200overlap.csv")
# Alternative output for the dataset above:
#output_csv_path = os.path.join(main_dir, "WGS_by_VOC_IAV_finetune_5k_epi_250bp_fragments.csv")

# Read the input in chunks to bound memory use.
chunksize = 1000  # rows per chunk; adjust as needed
input_columns = ["EPI_ID", "label_name", "label_number", "sequence"]

with open(output_csv_path, mode='w', newline='') as outfile:
    writer = csv.writer(outfile)
    writer.writerow(input_columns)  # header, in input_columns order

    for chunk in tqdm(pd.read_csv(input_csv_path, usecols=input_columns, chunksize=chunksize)):
        # usecols keeps the file's column order, so select explicitly to make
        # the unpacking below line up. itertuples also avoids building a
        # Series per row as iterrows does.
        rows = chunk[input_columns].itertuples(index=False, name=None)
        for epi_id, label_name, label_number, sequence in rows:
            fragments = generate_overlapping_fragments(
                sequence, fragment_len=FRAGMENT_LEN, overlap=OVERLAP
            )
            for fragment in fragments:
                writer.writerow([epi_id, label_name, label_number, fragment])


172it [00:32,  5.32it/s]


In [11]:
# Count how many distinct EPI_IDs contribute fragments to each virus label.
import pandas as pd

# Overlapping-fragment table written by the Path2 cell.
csv_file_path = '/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/WGS_by_virus_finetune_250bp_200overlap.csv'

data = pd.read_csv(csv_file_path)

# nunique() ignores missing EPI_IDs; the result is indexed by label_name.
label_counts = (
    data
    .groupby('label_name')['EPI_ID']
    .nunique()
)

for label, count in zip(label_counts.index, label_counts.tolist()):
    print(f"Label: {label}, Unique EPI_ID Count: {count}")


Label: influenza_a, Unique EPI_ID Count: 9849
Label: influenza_b, Unique EPI_ID Count: 9977
Label: rsv, Unique EPI_ID Count: 918
Label: sars_cov_2, Unique EPI_ID Count: 18636


In [12]:
# Subsample the overlapping 250 bp fragment table to 918 unique EPI_IDs per virus label (labels with fewer EPI_IDs are dropped)

import os
import pandas as pd

def filter_sequences(input_csv, output_csv, n_epi=918, random_state=1):
    """Balance a fragment table to ``n_epi`` unique EPI_IDs per label.

    Labels with fewer than ``n_epi`` distinct EPI_IDs are dropped. Labels with
    more are down-sampled to ``n_epi`` randomly chosen EPI_IDs, and every row
    (fragment) of a chosen EPI_ID is kept. The result is written to
    ``output_csv`` without the index, and a per-label summary of row counts
    and unique EPI_IDs is printed.

    Parameters
    ----------
    input_csv : str
        CSV with at least ``EPI_ID``, ``label_name`` and ``sequence`` columns.
    output_csv : str
        Destination path; overwritten if it exists.
    n_epi : int, default 918
        Target number of unique EPI_IDs per label. 918 matched the RSV EPI_ID
        count (the smallest class) when this notebook was run.
    random_state : int, default 1
        Seed passed to ``Series.sample`` so the selection is reproducible.
    """
    data = pd.read_csv(input_csv, low_memory=False)

    # Labels that cannot reach n_epi distinct EPI_IDs are excluded entirely.
    unique_epi_counts = data.groupby('label_name')['EPI_ID'].nunique()
    valid_labels = unique_epi_counts[unique_epi_counts >= n_epi].index

    selected_frames = []
    for label in valid_labels:
        label_data = data[data['label_name'] == label]
        unique_epi_ids = label_data['EPI_ID'].unique()

        # Down-sample only when there is a surplus of EPI_IDs.
        if len(unique_epi_ids) > n_epi:
            selected_epi_ids = pd.Series(unique_epi_ids).sample(
                n=n_epi, random_state=random_state
            ).tolist()
        else:
            selected_epi_ids = unique_epi_ids

        # Keep every fragment belonging to the selected EPI_IDs.
        selected_frames.append(label_data[label_data['EPI_ID'].isin(selected_epi_ids)])

    # pd.concat([]) raises ValueError, so fall back to an empty frame with the
    # input's columns when no label reaches n_epi.
    result_df = pd.concat(selected_frames) if selected_frames else data.iloc[0:0]

    result_df.to_csv(output_csv, index=False)

    print("Number of sequences and unique EPI ID values for each label:")
    if result_df.empty:
        return

    filtered_label_counts = result_df.groupby('label_name').agg(
        num_sequences=('sequence', 'count'),
        num_unique_epi_ids=('EPI_ID', 'nunique'),
    )
    for label, row in filtered_label_counts.iterrows():
        print(f"Label {label}: {row['num_sequences']} sequences, {row['num_unique_epi_ids']} unique EPI IDs")


# Directory holding the fragment tables (hard-coded; nothing is prompted).
main_dir = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes"

# Input: overlapping 250 bp fragment table; output: the same rows restricted
# to a balanced set of EPI_IDs per label.
input_csv, output_csv = (
    os.path.join(main_dir, file_name)
    for file_name in (
        "WGS_by_virus_finetune_250bp_200overlap.csv",
        "WGS_by_virus_finetune_250bp_200overlap_918epi.csv",
    )
)

filter_sequences(input_csv, output_csv)


Number of sequences and unique EPI ID values for each label:
Label influenza_a: 53926 sequences, 918 unique EPI IDs
Label influenza_b: 56420 sequences, 918 unique EPI IDs
Label rsv: 68961 sequences, 918 unique EPI IDs
Label sars_cov_2: 135912 sequences, 918 unique EPI IDs
