- This notebook uses the GSE251676 dataset as an example.
- The dataset can be downloaded here: https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE251676&format=file
- The reference genome assembly is GCF_000011065.1.

### calculate the normalized gene expression

In [12]:
import os
import gzip
import pandas as pd
import numpy as np
from scipy.stats import zscore
import os
import glob
import gc

import torch
import esm
from Bio import SeqIO
from Bio.Seq import Seq
from tqdm import tqdm

import re
from collections import Counter, OrderedDict
from Bio.SeqRecord import SeqRecord

In [1]:
# Folder containing the gzipped count tables; every '*.gz' file found here
# is read below and becomes one column of the combined expression matrix.
data_folder = '../example_data/GSE251676/' 

# Function to calculate TPM using gene lengths
def calculate_tpm(counts, lengths):
    """Convert raw read counts to transcripts per million (TPM).

    Args:
        counts: Array of raw read counts, one entry per gene.
        lengths: Array of gene lengths in base pairs, aligned with ``counts``.

    Returns:
        An array of TPM values with the same shape as ``counts``. For a
        sample with at least one read the values sum to 1e6. If the
        length-normalised total is zero (a sample without any reads), zeros
        are returned instead of the NaNs a 0/0 division would produce.
    """
    rpk = counts / (lengths / 1e3)  # Reads per kilobase of gene
    total_rpk = np.sum(rpk)
    if total_rpk == 0:
        # Nothing to scale: every gene has TPM 0. Multiplying keeps the
        # container type and shape of ``rpk``.
        return rpk * 0.0
    return (rpk / total_rpk) * 1e6  # Normalize to per million

# One log-TPM Series per input file, keyed by the column name derived from
# the file name.
data_dict = {}

# Process each gzipped TSV file in the folder. The listing is sorted because
# os.listdir returns entries in arbitrary order, which would otherwise make
# the column order of the output file irreproducible.
for file_name in sorted(os.listdir(data_folder)):
    if not file_name.endswith('.gz'):
        continue
    file_path = os.path.join(data_folder, file_name)

    # Use the part of the file name before the first '.' as the column name
    column_name = file_name.split('.')[0]

    # Read the gzipped table; lines starting with '#' are ignored
    with gzip.open(file_path, 'rt') as f:
        df = pd.read_csv(f, sep='\t', comment='#')

    # Identify the column containing counts dynamically
    count_columns = [col for col in df.columns if col.endswith('_sorted.sam')]
    if not count_columns:
        raise ValueError(
            f"No count column ending in '_sorted.sam' found in {file_path}"
        )

    # Keep only the relevant columns under standardized names. The explicit
    # copy makes the column assignments below act on an independent frame.
    df = df[['Geneid', 'Length', count_columns[0]]].copy()
    df.columns = ['gene_id', 'length', 'count']

    # Calculate TPM values for the current file
    df['tpm'] = calculate_tpm(df['count'].values, df['length'].values)

    # Log-transform the TPM values
    df['log_tpm'] = np.log1p(df['tpm'])  # log1p handles log(0) by log(1 + x)

    # Store the log-transformed values in the dictionary
    data_dict[column_name] = df.set_index('gene_id')['log_tpm']

if not data_dict:
    raise FileNotFoundError(f"No '.gz' count files found in {data_folder}")

# Combine all columns into a single DataFrame based on gene_id
combined_df = pd.DataFrame(data_dict)

# Z-score normalization for each column: DataFrame.apply with axis=0 passes
# every column to zscore, so each input file is standardized across its genes.
combined_df = combined_df.apply(zscore, axis=0)

# Save the DataFrame to a CSV file next to the input data
combined_df.to_csv(os.path.join(data_folder, 'log_tpm.csv'))

### calculate the ESM embeddings

In [9]:
# This cell manipulates sys.path below, so import sys here rather than
# relying on another cell having been executed first.
import sys

# ----------------------------------------------------------------------
# 1) Load the ESM-2 model and set up device
# ----------------------------------------------------------------------
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # Disables dropout (no randomness in output)

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")

model = model.to(device)

# Make the project's helper scripts importable
scripts_path = os.path.join("..", "scripts")
if scripts_path not in sys.path:
    sys.path.append(scripts_path)

from utils import get_protein_sequences

# Proteins with at least this many residues are not embedded
MAX_PROTEIN_LENGTH = 1500

# ----------------------------------------------------------------------
# 2) Pair each FASTA file with its annotation file
# ----------------------------------------------------------------------
# Directory containing the paired FASTA/GFF files
input_dir = "../example_data/GSE251676/"  

# Retrieve lists of .fna and .gff files (sorted so pairing is reproducible)
fasta_files = sorted(glob.glob(os.path.join(input_dir, "*.fna")))
gff_files = sorted(glob.glob(os.path.join(input_dir, "*.gff")))

# Pair up files based on the base filename
paired_files = {}
for fasta_file in fasta_files:
    base_name = os.path.basename(fasta_file).rsplit(".", 1)[0]
    matching_gff = [
        gff for gff in gff_files
        if os.path.basename(gff).startswith(base_name)
    ]
    if matching_gff:
        # Prefer "<base_name>.gff" exactly: a bare prefix match could pick
        # e.g. "X.10.gff" for the base name "X.1".
        exact_gff = [
            gff for gff in matching_gff
            if os.path.basename(gff) == f"{base_name}.gff"
        ]
        paired_files[base_name] = {
            "fasta": fasta_file,
            "gff": (exact_gff or matching_gff)[0]
        }

print(f"Found {len(paired_files)} paired FASTA/GFF files.")

# ----------------------------------------------------------------------
# 3) Embed the proteins of every FASTA/GFF pair
# ----------------------------------------------------------------------
files_processed = 0  # number of embedding files written in this run
for base_name, file_dict in paired_files.items():
    genome_file = file_dict["fasta"]
    annotation_file = file_dict["gff"]
    output_file = os.path.join(input_dir, f"{base_name}_embeddings.txt")

    print(f"\nProcessing: {base_name}")
    print(f"  FASTA: {genome_file}")
    print(f"  GFF  : {annotation_file}")
    print(f"  Output embeddings: {output_file}")

    # If embeddings file already exists, skip
    if os.path.exists(output_file):
        print("  Embedding file already exists. Skipping...")
        continue

    try:
        # 3.1) Extract protein sequences
        protein_seqs, gene_names, gene_lengths, amino_acid_proportions = \
            get_protein_sequences(genome_file, annotation_file)
        print(f"  Extracted {len(protein_seqs)} protein sequences.")

        # 3.2) Generate embeddings with a progress bar.
        # The list holds exactly one entry per protein, in input order, so
        # that row i of the saved matrix always belongs to protein i. The
        # file is later indexed by gene position, so dropping a row would
        # silently shift every following gene. None marks a skipped protein.
        seq_repr_list = []
        for idx, seq in tqdm(
            enumerate(protein_seqs, start=1),
            total=len(protein_seqs),
            desc="Proteins"
        ):
            if len(seq) >= MAX_PROTEIN_LENGTH:
                # Skip extremely long proteins but keep their row position
                seq_repr_list.append(None)
                continue

            # Prepare data for ESM
            protein_name = f"protein{idx}"
            data = [(protein_name, str(seq))]

            # Convert text to tokens
            batch_labels, batch_strs, batch_tokens = batch_converter(data)
            batch_tokens = batch_tokens.to(device)

            with torch.no_grad():
                results = model(batch_tokens, repr_layers=[33], return_contacts=False)
            token_reprs = results["representations"][33]

            # Count non-padding tokens, then average over the positions
            # between the first and last token (start and end tokens excluded)
            seq_len = (batch_tokens != alphabet.padding_idx).sum(1)[0]
            seq_repr = token_reprs[0, 1 : seq_len - 1].mean(0).cpu().numpy()
            seq_repr_list.append(seq_repr)

            # Free memory
            del batch_tokens, token_reprs, results
            torch.cuda.empty_cache()
            gc.collect()

        # 3.3) Save the embeddings to a text file
        embedded = [rep for rep in seq_repr_list if rep is not None]
        if embedded:
            # Skipped proteins become all-NaN rows of the right width
            nan_row = np.full_like(embedded[0], np.nan)
            embeddings_matrix = np.vstack(
                [nan_row if rep is None else rep for rep in seq_repr_list]
            )
            np.savetxt(output_file, embeddings_matrix)
            print(f"  Saved {embeddings_matrix.shape[0]} embeddings to {output_file}")
            n_skipped = len(seq_repr_list) - len(embedded)
            if n_skipped:
                print(
                    f"  {n_skipped} proteins with >= {MAX_PROTEIN_LENGTH} "
                    "residues were written as NaN rows."
                )
            files_processed += 1
        else:
            print(f"  No embeddings generated for {base_name}.")

    except Exception as e:
        # Best effort per genome: report the failure and move on to the next
        print(f"  Error processing {base_name}: {e}")

print("\nProcessing complete.")

Using CUDA: NVIDIA GeForce RTX 3090 Ti
Found 1 paired FASTA/GFF files.

Processing: GCF_000011065.1
  FASTA: ../example_data/GSE251676\GCF_000011065.1.fna
  GFF  : ../example_data/GSE251676\GCF_000011065.1.gff
  Output embeddings: ../example_data/GSE251676/GCF_000011065.1_embeddings.txt
  Extracted 4792 protein sequences.


Proteins: 100%|████████████████████████████████████████████████████████████████████| 4792/4792 [16:10<00:00,  4.94it/s]


  Saved 4792 embeddings to ../example_data/GSE251676/GCF_000011065.1_embeddings.txt

Processing complete.


### generate training data

In [10]:
import sys
import os

# Make the project's helper scripts importable
scripts_path = os.path.join("..", "scripts")
if scripts_path not in sys.path:
    sys.path.append(scripts_path)

from utils import get_protein_sequences

# Adjust these values as needed
strain_id = 'GCF_000011065.1'
folder_name = '../example_data/GSE251676/'

# Input files
genome_file = os.path.join(folder_name, f'{strain_id}.fna')
annotation_file = os.path.join(folder_name, f'{strain_id}.gff')
embeddings_file = os.path.join(folder_name, f'{strain_id}_embeddings.txt')
tpm_file = os.path.join(folder_name, 'log_tpm.csv')

# Output files
filtered_tpm_file = os.path.join(folder_name, f'{strain_id}_filtered_log_tpm.csv')
filtered_embeddings_file = os.path.join(folder_name, f'{strain_id}_filtered_embeddings.txt')
filtered_meta_file = os.path.join(folder_name, f'{strain_id}_filtered_meta.txt')

# 1) Extract protein sequences and gene metadata
protein_sequences, gene_names, gene_lengths, amino_acid_proportions = \
    get_protein_sequences(genome_file, annotation_file)

# 2) Prepare 'gene_meta': the length scaled by the maximum length, followed
#    by the amino-acid proportions returned by get_protein_sequences
gene_lengths_array = np.array(gene_lengths, dtype=float)
max_length = gene_lengths_array.max() if len(gene_lengths_array) > 0 else 1.0
gene_lengths_norm = (gene_lengths_array / max_length).reshape(-1, 1)

aa_props_array = np.array(amino_acid_proportions)
gene_meta = np.hstack([gene_lengths_norm, aa_props_array])

# (Optional) If you need to modify gene names (commented out):
# gene_names = [x[:6] + "_RS0" + x[6:] for x in gene_names]
# or other transformations depending on your naming scheme...

# 3) Load existing ESM embeddings. ndmin=2 keeps a single-row file 2-D.
embeddings = np.loadtxt(embeddings_file, ndmin=2)
print("Embeddings shape:", embeddings.shape)

# Rows are selected by position in 'gene_names' below, so the two must have
# the same length; otherwise genes would be paired with the wrong embedding.
if embeddings.shape[0] != len(gene_names):
    raise ValueError(
        f"{embeddings_file} has {embeddings.shape[0]} rows but "
        f"{len(gene_names)} genes were extracted; embeddings and gene names "
        "are not aligned."
    )

# 4) Load log_tpm data
tpm = pd.read_csv(tpm_file)
print("TPM shape:", tpm.shape)
print("TPM columns:", tpm.columns)

# Check your gene ID column name in 'tpm'. Suppose it's 'gene_id'.
# Filter for overlap with 'gene_names':
overlap_indices = tpm['gene_id'].isin(gene_names)

# 5) Create a new tpm dataframe for overlapping entries
new_tpm = tpm[overlap_indices].copy()

# 6) Find corresponding indices in 'gene_names' for the overlapping genes.
#    A dict gives constant-time lookups; setdefault keeps the first position
#    of a repeated name, which is what list.index would return.
gene_position = {}
for position, name in enumerate(gene_names):
    gene_position.setdefault(name, position)
indices_in_x = [gene_position[g] for g in new_tpm['gene_id']]

# 7) Slice embeddings and meta to match those overlapping indices
new_embeddings = embeddings[indices_in_x, :]
new_meta = gene_meta[indices_in_x, :]

# Drop genes whose embedding contains NaN values, if there are any, so the
# three output files stay row-aligned and fully numeric.
valid_rows = ~np.isnan(new_embeddings).any(axis=1)
if not valid_rows.all():
    print(f"Dropping {int((~valid_rows).sum())} genes without a valid embedding.")
    new_tpm = new_tpm[valid_rows]
    new_embeddings = new_embeddings[valid_rows]
    new_meta = new_meta[valid_rows]

# 8) Save results
new_tpm.to_csv(filtered_tpm_file, index=False)
np.savetxt(filtered_embeddings_file, new_embeddings)
np.savetxt(filtered_meta_file, new_meta)

print("Data has been filtered and saved.")
print("Filtered embeddings shape:", new_embeddings.shape)

Embeddings shape: (4792, 1280)
TPM shape: (4774, 65)
TPM columns: Index(['gene_id', 'GSM7985888_RN1', 'GSM7985889_RN2', 'GSM7985890_RN3',
       'GSM7985891_RN4', 'GSM7985892_RN5', 'GSM7985893_RN6', 'GSM7985894_RN7',
       'GSM7985895_RN8', 'GSM7985896_RN9', 'GSM7985897_RN10',
       'GSM7985898_RN11', 'GSM7985899_RN12', 'GSM7985900_RN13',
       'GSM7985901_RN14', 'GSM7985902_RN15', 'GSM7985903_RN16',
       'GSM7985904_RN17', 'GSM7985905_RN18', 'GSM7985906_RN19',
       'GSM7985907_RN20', 'GSM7985908_RN21', 'GSM7985909_RN22',
       'GSM7985910_RN23', 'GSM7985911_RN24', 'GSM7985912_RN25',
       'GSM7985913_RN26', 'GSM7985914_RN27', 'GSM7985915_RN28',
       'GSM7985916_RN29', 'GSM7985917_RN30', 'GSM7985918_RN31',
       'GSM7985919_RN32', 'GSM7985920_RN33', 'GSM7985921_RN34',
       'GSM7985922_RN35', 'GSM7985923_RN36', 'GSM7985924_RN37',
       'GSM7985925_RN38', 'GSM7985926_RN39', 'GSM7985927_RN40',
       'GSM7985928_RN41', 'GSM7985929_RN42', 'GSM7985930_RN43',
       'GSM798593