In [2]:
# Extracting WGS or HA/NA segments from fasta files belonging to IAV strains
# Removing sequences with ambiguous (non-ACGT) nucleotides
# Mapping IAV strain labels to numerical labels
# Write out sequences, respective labels, segment name and EPI ID to output csv

import os
import pandas as pd
from multiprocessing import Pool, cpu_count
from glob import glob

def get_label_from_filename(filename):
    """Map a FASTA file name to its IAV subtype label.

    The subtype tags are tested as substrings of `filename` in the order
    listed below; the first tag that occurs decides the label.

    Returns:
        The matching "influenza_a_<subtype>" label, or None if no tag occurs.
    """
    tag_to_label = (
        ("H1N1", "influenza_a_H1N1"),
        ("H1N2", "influenza_a_H1N2"),
        ("H3N2", "influenza_a_H3N2"),
        ("H3N8", "influenza_a_H3N8"),
        ("H4N6", "influenza_a_H4N6"),
        ("H5N1", "influenza_a_H5N1"),
        ("H5N2", "influenza_a_H5N2"),
        ("H5N6", "influenza_a_H5N6"),
        ("H5N8", "influenza_a_H5N8"),
        ("H7N9", "influenza_a_H7N9"),
        ("H9N2", "influenza_a_H9N2"),
    )
    for tag, subtype_label in tag_to_label:
        if tag in filename:
            return subtype_label
    return None

def process_fasta_file(fasta_file):
    """Parse one FASTA file into (sequence, EPI_ID, segment, label) tuples.

    The label comes from get_label_from_filename(fasta_file); when it is None
    the file is not opened and an empty list is returned.

    Header lines (starting with '>') are split on '|': field 1 becomes the
    EPI ID and field 6 the segment ('' when the header has too few fields).
    Stripped non-header lines are concatenated into the record's sequence.
    A record is kept only if its upper-cased sequence is non-empty and made
    of A/T/G/C alone; anything else (e.g. ambiguity codes) drops the record.

    Returns:
        list of (upper-cased sequence, epi_id, segment, label) in file order.
    """
    label = get_label_from_filename(fasta_file)
    if label is None:
        return []

    allowed_bases = {'A', 'T', 'G', 'C'}
    kept = []

    def keep_if_clean(seq_parts, record_id, record_segment):
        # Join the buffered lines; keep the record only if non-empty and pure ACGT.
        seq = ''.join(seq_parts).upper()
        if seq and set(seq) <= allowed_bases:
            kept.append((seq, record_id, record_segment, label))

    with open(fasta_file, 'r') as handle:
        seq_parts, record_id, record_segment = [], '', ''
        for raw in handle:
            if not raw.startswith('>'):
                seq_parts.append(raw.strip())
                continue
            # A new header closes the previous record.
            keep_if_clean(seq_parts, record_id, record_segment)
            fields = raw.strip().split('|')
            record_id = fields[1] if len(fields) > 1 else ''
            record_segment = fields[6] if len(fields) > 6 else ''
            seq_parts = []
        # Flush the last record in the file.
        keep_if_clean(seq_parts, record_id, record_segment)

    return kept

def label_to_number(label):
    """Return the integer class id for an IAV subtype label.

    Ids run from 1 to 11 following the order of `ordered_labels`; any label
    not in that list maps to 0.
    """
    ordered_labels = (
        "influenza_a_H1N1",
        "influenza_a_H1N2",
        "influenza_a_H3N2",
        "influenza_a_H3N8",
        "influenza_a_H4N6",
        "influenza_a_H5N1",
        "influenza_a_H5N2",
        "influenza_a_H5N6",
        "influenza_a_H5N8",
        "influenza_a_H7N9",
        "influenza_a_H9N2",
    )
    lookup = {name: class_id for class_id, name in enumerate(ordered_labels, start=1)}
    return lookup.get(label, 0)

def process_fasta_files(directory, output_csv):
    """Parse every *.fasta file in `directory` in parallel and write one CSV.

    Each file is parsed by process_fasta_file() in a worker pool; files whose
    name carries no recognised subtype contribute no rows. The output CSV has
    the columns sequence, EPI_ID, segment, label_name, label_number, where
    label_number comes from label_to_number(label_name). Afterwards the
    number of sequences per label is printed.

    Args:
        directory: folder searched (non-recursively) for "*.fasta" files.
        output_csv: path of the CSV to (over)write.
    """
    fieldnames = ['sequence', 'EPI_ID', 'segment', 'label_name', 'label_number']

    # glob order depends on the filesystem; sort so the CSV row order is
    # reproducible from run to run.
    fasta_files = sorted(glob(os.path.join(directory, "*.fasta")))

    with Pool(cpu_count()) as pool:
        results = pool.map(process_fasta_file, fasta_files)

    # Collect every row first and write the CSV in a single call, instead of
    # creating a one-row DataFrame and calling to_csv once per sequence.
    records = [
        (sequence, epi_id, segment, label, label_to_number(label))
        for result in results
        for sequence, epi_id, segment, label in result
    ]
    data = pd.DataFrame(records, columns=fieldnames)
    data.to_csv(output_csv, index=False)

    # Count from the in-memory frame rather than re-reading the CSV from disk.
    label_counts = data['label_name'].value_counts()

    print("Number of sequences for each label:")
    for label, count in label_counts.items():
        print(f"Label {label}: {count} sequences")

# Hard-coded directory of per-subtype FASTA files to parse (no user prompt).
# The commented-out lines are presumably the alternative HA/NA run
# (IAV_strains -> HA_NA_IAV_strains.csv) — confirm before switching.
main_dir = '/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/IAV_WGS'
#main_dir = '/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/IAV_strains'

process_fasta_files(main_dir, 'WGS_IAV_strains.csv')
#process_fasta_files(main_dir, 'HA_NA_IAV_strains.csv')

Number of sequences for each label:
Label influenza_a_H3N2: 37081 sequences
Label influenza_a_H1N1: 36966 sequences
Label influenza_a_H5N1: 30288 sequences
Label influenza_a_H1N2: 11402 sequences
Label influenza_a_H9N2: 9766 sequences
Label influenza_a_H5N8: 5809 sequences
Label influenza_a_H7N9: 4648 sequences
Label influenza_a_H3N8: 4622 sequences
Label influenza_a_H5N6: 3846 sequences
Label influenza_a_H4N6: 3190 sequences
Label influenza_a_H5N2: 2719 sequences


In [3]:
# Creating 250 bp fragments with a 50 bp step (i.e. 200 bp overlap between consecutive fragments) from the WGS or HA/NA segments generated above
# Generating the reverse complement sequence for each generated fragment 

import os
import pandas as pd
import csv
from tqdm import tqdm
from Bio.Seq import Seq

# Function to generate overlapping fixed-length fragments (250 bp, 50 bp step by default)
def generate_overlapping_fragments(sequence, fragment_len=250, overlap=200):
    """Split `sequence` into fixed-length, overlapping windows.

    Windows start at 0, step, 2*step, ... with step = fragment_len - overlap
    (50 bp with the defaults, so consecutive windows share 200 bp). Only
    full-length windows are returned: trailing bases that do not fill a final
    window are dropped, and a sequence shorter than `fragment_len` yields [].
    A negative `overlap` is allowed and leaves gaps between windows.

    Args:
        sequence: string to fragment.
        fragment_len: window length in bases; must be positive.
        overlap: bases shared by consecutive windows; must be < fragment_len.

    Returns:
        list of substrings, each exactly `fragment_len` long.

    Raises:
        ValueError: if fragment_len <= 0 or overlap >= fragment_len. Previously
            overlap == fragment_len raised an obscure range() error, and
            overlap > fragment_len silently returned no fragments.
    """
    if fragment_len <= 0:
        raise ValueError(f"fragment_len must be positive, got {fragment_len}")
    if overlap >= fragment_len:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than fragment_len ({fragment_len})"
        )
    step = fragment_len - overlap
    # The range stop guarantees every slice is full length, so no length check is needed.
    return [sequence[start:start + fragment_len]
            for start in range(0, len(sequence) - fragment_len + 1, step)]

# Hard-coded paths (no user prompt). main_dir is not used in this cell;
# intermediate_dir receives the fragment CSV.
main_dir = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/"
intermediate_dir = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/intermediate_csvs"

# Input: CSV written by the FASTA-parsing cell (relative to the working directory)
#input_csv_path = os.path.join("HA_NA_IAV_strains.csv")
input_csv_path = os.path.join("WGS_IAV_strains.csv")

#output_csv_path = os.path.join(intermediate_dir, "HA_NA_IAV_strains_250bp_50overlap_complementary.csv")
output_csv_path = os.path.join(intermediate_dir, "WGS_IAV_strains_250bp_50overlap_complementary.csv")

# Read the input CSV in chunks so the whole file is never loaded at once
chunksize = 1000  # Adjust the chunk size as needed
input_columns = ["EPI_ID", "segment",
                 "label_name", "label_number", 
                 "sequence"]
# Header of the output CSV; each row is metadata followed by one fragment
output_columns = ["EPI_ID", "segment", "label_name", "label_number", "sequence"]

# open() fails if the folder is missing, so create it first.
os.makedirs(intermediate_dir, exist_ok=True)

with open(output_csv_path, mode='w', newline='') as outfile:
    writer = csv.writer(outfile)
    writer.writerow(output_columns)  # Write header

    for chunk in tqdm(pd.read_csv(input_csv_path, usecols=input_columns, chunksize=chunksize)):
        # itertuples is much faster than iterrows and gives attribute access by column name.
        for row in chunk.itertuples(index=False):
            metadata = [row.EPI_ID, row.segment, row.label_name, row.label_number]
            for fragment in generate_overlapping_fragments(row.sequence):
                # Forward fragment first, then its reverse complement, both with the same metadata.
                reverse_complement_fragment = str(Seq(fragment).reverse_complement())
                writer.writerow(metadata + [fragment])
                writer.writerow(metadata + [reverse_complement_fragment])

print(f"Fragments and their reverse complements have been written to {output_csv_path}")

151it [00:58,  2.59it/s]

Fragments and their reverse complements have been written to /mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/intermediate_csvs/HA_NA_IAV_strains_250bp_50overlap_complementary.csv





In [1]:
# Calculating the number of unique EPI IDs for each unique value of label_name (IAV strain)

import pandas as pd

# Fragment CSV written by the fragmentation cell (WGS run; HA/NA alternative commented out)
iav_csv_file_path = '/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/intermediate_csvs/WGS_IAV_strains_250bp_50overlap_complementary.csv'
#iav_csv_file_path = '/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/intermediate_csvs/HA_NA_IAV_strains_250bp_50overlap_complementary.csv'


# Load only the two columns needed for counting; skipping the `sequence`
# column avoids holding millions of fragment strings in memory.
data_iav = pd.read_csv(iav_csv_file_path, usecols=['label_name', 'EPI_ID'])

# Group by label_name and count unique EPI_ID values
label_counts_iav = data_iav.groupby('label_name')['EPI_ID'].nunique()

# Print out the counts
for label, count in label_counts_iav.items():
    print(f"IAV labels: {label}, Unique EPI_ID Count: {count}")


IAV labels: influenza_a_H1N1, Unique EPI_ID Count: 4280
IAV labels: influenza_a_H1N2, Unique EPI_ID Count: 3441
IAV labels: influenza_a_H3N2, Unique EPI_ID Count: 7184
IAV labels: influenza_a_H3N8, Unique EPI_ID Count: 2366
IAV labels: influenza_a_H4N6, Unique EPI_ID Count: 1772
IAV labels: influenza_a_H5N1, Unique EPI_ID Count: 7016
IAV labels: influenza_a_H7N9, Unique EPI_ID Count: 2498
IAV labels: influenza_a_H9N2, Unique EPI_ID Count: 5538


In [1]:
# Creating subset of above output csv with desired number of unique EPI_IDs for each IAV strain label
import os
import pandas as pd

def filter_sequences(input_csv, output_csv, min_epi_ids=3000, max_epi_ids=3000, random_state=1):
    """Subsample a fragment CSV to a bounded number of EPI IDs per label.

    Labels (label_name) with fewer than `min_epi_ids` distinct EPI_ID values
    are dropped entirely. For each remaining label, if it has more than
    `max_epi_ids` distinct EPI IDs, `max_epi_ids` of them are drawn at random
    (reproducibly, via `random_state`); otherwise all are kept. Every row of
    a kept EPI ID is written to `output_csv`, then per-label row counts and
    distinct EPI-ID counts are printed.

    Defaults reproduce the original hard-coded behaviour (3000 / 3000 / seed 1).

    Args:
        input_csv: CSV with at least label_name, EPI_ID and sequence columns.
        output_csv: destination CSV (overwritten).
        min_epi_ids: minimum distinct EPI IDs required to keep a label.
        max_epi_ids: cap on distinct EPI IDs retained per label.
        random_state: seed passed to pandas `Series.sample`.
    """
    # Load the data from the CSV file
    data = pd.read_csv(input_csv, low_memory=False)

    # Keep only labels with at least `min_epi_ids` unique EPI IDs
    unique_epi_counts = data.groupby('label_name')['EPI_ID'].nunique()
    valid_labels = unique_epi_counts[unique_epi_counts >= min_epi_ids].index
    filtered_data = data[data['label_name'].isin(valid_labels)]

    filtered_sequences = []
    for label in valid_labels:
        label_data = filtered_data[filtered_data['label_name'] == label]
        unique_ep_ids = label_data['EPI_ID'].unique()

        # Randomly select `max_epi_ids` EPI IDs when the label has more than that
        if len(unique_ep_ids) > max_epi_ids:
            selected_ep_ids = pd.Series(unique_ep_ids).sample(
                n=max_epi_ids, random_state=random_state).tolist()
        else:
            selected_ep_ids = unique_ep_ids

        # Keep every row (fragment) belonging to a selected EPI ID
        filtered_sequences.append(label_data[label_data['EPI_ID'].isin(selected_ep_ids)])

    # pd.concat([]) raises, so fall back to an empty frame with the input's columns
    result_df = pd.concat(filtered_sequences) if filtered_sequences else data.iloc[0:0]

    # Save the filtered data to a new CSV file
    result_df.to_csv(output_csv, index=False)

    print("Number of sequences and unique EPI ID values for each label:")
    if result_df.empty:
        return

    filtered_label_counts = result_df.groupby('label_name').agg(
        num_sequences=('sequence', 'count'),
        num_unique_epi_ids=('EPI_ID', 'nunique')
    )
    for label, row in filtered_label_counts.iterrows():
        print(f"Label {label}: {row['num_sequences']} sequences, {row['num_unique_epi_ids']} unique EPI IDs")


# Hard-coded paths (no user prompt): the subset is written to main_dir,
# the unfiltered fragment CSV is read from intermediate_dir.
main_dir = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/"
intermediate_dir = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/intermediate_csvs"

# Input: fragment CSV from the fragmentation cell; output: per-label EPI-ID subset.
input_csv = os.path.join(intermediate_dir, "WGS_IAV_strains_250bp_50overlap_complementary.csv")
#input_csv = os.path.join(intermediate_dir, "HA_NA_IAV_strains_250bp_50overlap_complementary.csv")

# NOTE(review): filter_sequences keeps 3000 EPI IDs per label by default; the
# commented-out HA/NA name says "500_epi" — confirm the cap before switching.
output_csv = os.path.join(main_dir, "WGS_IAV_strains_250bp_50overlap_complementary_3k_epi.csv")
#output_csv = os.path.join(main_dir, "HA_NA_IAV_strains_250bp_50overlap_complementary_500_epi.csv")

# Execute the filtering and reporting
filter_sequences(input_csv, output_csv)


Number of sequences and unique EPI ID values for each label:
Label influenza_a_H1N1: 1360680 sequences, 3000 unique EPI IDs
Label influenza_a_H1N2: 1297556 sequences, 3000 unique EPI IDs
Label influenza_a_H3N2: 1324690 sequences, 3000 unique EPI IDs
Label influenza_a_H5N1: 1270086 sequences, 3000 unique EPI IDs
Label influenza_a_H9N2: 1247340 sequences, 3000 unique EPI IDs


In [2]:
# Checking to see if label_names are correctly mapped to label_numbers
import pandas as pd

def check_label_numbers_and_count_epi_ids(input_csv):
    """
    Check the label_number values for each unique label_name value and print the number of EPI_ID values for each label_name.

    For every label_name (in sorted order) this prints its distinct
    label_number values and its number of distinct EPI_IDs. A correct mapping
    shows exactly one label_number per label_name; if more than one appears,
    an extra WARNING line is printed for that label.

    Args:
        input_csv: CSV containing label_name, label_number and EPI_ID columns.
    """
    # Only three columns are needed; skipping `sequence` avoids loading every fragment string.
    df = pd.read_csv(input_csv, usecols=['label_name', 'label_number', 'EPI_ID'])

    for label_name, group in df.groupby('label_name'):
        unique_label_numbers = group['label_number'].unique()
        epi_id_count = group['EPI_ID'].nunique()  # Count unique EPI_ID values

        print(f"Label Name: {label_name}")
        print(f"  - Unique Label Numbers: {unique_label_numbers}")
        print(f"  - Number of Unique EPI_IDs: {epi_id_count}")
        if len(unique_label_numbers) > 1:
            print(f"  - WARNING: {label_name} maps to more than one label_number")
        print("-" * 40)

# Subset CSV produced by the filter_sequences cell (WGS run; HA/NA alternative commented out)
input_csv = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/WGS_IAV_strains_250bp_50overlap_complementary_3k_epi.csv"
#input_csv = "/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/HA_NA_IAV_strains_250bp_50overlap_complementary_500_epi.csv"

# Report the label_number(s) and unique EPI_ID count for every label_name
check_label_numbers_and_count_epi_ids(input_csv=input_csv)

Label Name: influenza_a_H1N1
  - Unique Label Numbers: [1]
  - Number of Unique EPI_IDs: 3000
----------------------------------------
Label Name: influenza_a_H1N2
  - Unique Label Numbers: [2]
  - Number of Unique EPI_IDs: 3000
----------------------------------------
Label Name: influenza_a_H3N2
  - Unique Label Numbers: [3]
  - Number of Unique EPI_IDs: 3000
----------------------------------------
Label Name: influenza_a_H5N1
  - Unique Label Numbers: [6]
  - Number of Unique EPI_IDs: 3000
----------------------------------------
Label Name: influenza_a_H9N2
  - Unique Label Numbers: [11]
  - Number of Unique EPI_IDs: 3000
----------------------------------------


In [5]:
# Verifying the number of unique EPI IDs for each IAV strain label
import pandas as pd

# Subset CSV written by the filter_sequences cell
iav_csv_file_path = '/mmfs1/projects/changhui.yan/DeewanB/DNABert2_rnaseq/genome_files/unfiltered_multiple_genomes/WGS_IAV_strains_250bp_50overlap_complementary_3k_epi.csv'
# Load only the two columns needed for counting; skipping `sequence` keeps memory low
data_iav = pd.read_csv(iav_csv_file_path, usecols=['label_name', 'EPI_ID'])

# Group by label_name and count unique EPI_ID values
label_counts_iav = data_iav.groupby('label_name')['EPI_ID'].nunique()

# Print out the counts
for label, count in label_counts_iav.items():
    print(f"IAV labels: {label}, Unique EPI_ID Count: {count}")


IAV labels: influenza_a_H1N1, Unique EPI_ID Count: 3000
IAV labels: influenza_a_H1N2, Unique EPI_ID Count: 3000
IAV labels: influenza_a_H3N2, Unique EPI_ID Count: 3000
IAV labels: influenza_a_H5N1, Unique EPI_ID Count: 3000
IAV labels: influenza_a_H9N2, Unique EPI_ID Count: 3000
