In [None]:
"""
AA Enrich Predicted Frequencies (Jupyter version)

- Configure paths in the config section.
- Set SELECT_FAMILIES to a list of virus-family basenames (no extension) to process only those,
  or to None to process all FASTA files in FASTA_DIR.
- Prints show progress.
"""

import os
import random
from collections import Counter
import pandas as pd
import numpy as np
from Bio import SeqIO
from Bio.Seq import Seq

# -----------------------------
# Config (edit these paths)
# -----------------------------
FASTA_DIR = "/Users/ishaharris/Projects/ribolings/data/virus/rna_cds_fasta"            # folder containing one FASTA per virus family
SWAP_TABLE_CSV = "/Users/ishaharris/Projects/ribolings/data/CodonSwapTable_withStop.csv"
OUTPUT_CSV = "/Users/ishaharris/Projects/ribolings/data/freqs/monoaa_freqs.csv"`
RANDOM_SEED = 42
N_DRAWS_PER_AA = 100

# Set SELECT_FAMILIES to a list of basenames (filename without extension) to process only those.
# Example: SELECT_FAMILIES = ['coronavirus', 'orthomyxovirus']
# Or set to None to process every fasta in FASTA_DIR.
SELECT_FAMILIES = ['Arena', 'Calici', 'Corona', 'Filo', 'Flavi', 'Hanta', 'Picorna', 'Orthomyxo', 'Paramyxo', 'Peribunya','Rhabdo']  

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# -----------------------------
# Helpers
# -----------------------------

def load_swap_table(path):
    print(f"Loading swap table from: {path}")
    df = pd.read_csv(path)
    cols = {c.lower(): c for c in df.columns}
    required = ['codon', 'number', 'aa']
    for r in required:
        if r not in cols:
            raise ValueError(f"Swap table missing required column: {r}")
    df = df.rename(columns={cols['codon']:'Codon', cols['number']:'Number', cols['aa']:'AA'})
    df['Codon'] = df['Codon'].astype(str).str.lower()
    df['AA'] = df['AA'].astype(str)
    print(f"Swap table loaded: {len(df)} rows, {df['AA'].nunique()} amino acids found")
    return df


def build_aa_freqs_normalised(swap_df, draws=N_DRAWS_PER_AA):
    print(f"Building AA-normalised nucleotide frequencies (draws per AA = {draws})")
    aa_groups = swap_df.groupby('AA')
    aa_rows = {}
    for aa, group in aa_groups:
        codons = group['Codon'].tolist()
        if not codons:
            continue
        drawn = [random.choice(codons) for _ in range(draws)]
        concat = ''.join(drawn)
        counts = Counter(list(concat))
        total = len(concat)
        freqs = {
            'A': counts.get('a', 0) / total,
            'C': counts.get('c', 0) / total,
            'G': counts.get('g', 0) / total,
            'T': counts.get('t', 0) / total,
        }
        aa_rows[aa] = {k: v * 3.0 for k, v in freqs.items()}
    aa_freqs_norm = pd.DataFrame.from_dict(aa_rows, orient='index')
    aa_freqs_norm = aa_freqs_norm[['A','C','G','T']]
    print(f"Built AA freq table for {len(aa_freqs_norm)} AAs")
    return aa_freqs_norm


def predict_nt_freq_for_seq(record_seq, aa_freqs_norm):
    seq = Seq(str(record_seq))
    translation = seq.translate()
    aa_list = list(str(translation))
    if not aa_list:
        return None
    nt_sums = {'A': 0.0, 'C': 0.0, 'G': 0.0, 'T': 0.0}
    used_aas = 0
    for aa in aa_list:
        aa_key = 'X' if aa == '*' else aa
        if aa_key not in aa_freqs_norm.index:
            continue
        row = aa_freqs_norm.loc[aa_key]
        for nt in ['A','C','G','T']:
            nt_sums[nt] += float(row[nt])
        used_aas += 1
    if used_aas == 0:
        return None
    total_nt = used_aas * 3.0
    return {nt: nt_sums[nt] / total_nt for nt in ['A','C','G','T']}


def seq_is_standard_dna(seq_str):
    return set(seq_str.lower()).issubset({'a','c','g','t'})


def average_predicted_freqs_for_fasta(fasta_path, aa_freqs_norm):
    predicted_list = []
    n_total, n_skipped_nonstandard, n_skipped_no_prediction = 0, 0, 0
    for rec in SeqIO.parse(fasta_path, 'fasta'):
        n_total += 1
        seq_str = str(rec.seq)
        if not seq_is_standard_dna(seq_str):
            n_skipped_nonstandard += 1
            continue
        pred = predict_nt_freq_for_seq(seq_str, aa_freqs_norm)
        if pred is None:
            n_skipped_no_prediction += 1
            continue
        predicted_list.append([pred['A'], pred['C'], pred['G'], pred['T']])
    if not predicted_list:
        return {'A_pred': np.nan,'C_pred': np.nan,'G_pred': np.nan,'T_pred': np.nan,
                'n_seqs': n_total,'n_skipped_nonstandard': n_skipped_nonstandard,
                'n_skipped_no_prediction': n_skipped_no_prediction}
    arr = np.array(predicted_list)
    mean_vals = np.nanmean(arr, axis=0)
    return {
        'A_pred': float(mean_vals[0]),
        'C_pred': float(mean_vals[1]),
        'G_pred': float(mean_vals[2]),
        'T_pred': float(mean_vals[3]),
        'n_seqs': n_total,
        'n_skipped_nonstandard': n_skipped_nonstandard,
        'n_skipped_no_prediction': n_skipped_no_prediction,
    }

# -----------------------------
# Run in notebook cell
# -----------------------------

# Load swap table and build AA freq table
swap_df = load_swap_table(SWAP_TABLE_CSV)
aa_freqs_norm = build_aa_freqs_normalised(swap_df, draws=N_DRAWS_PER_AA)
if '*' in aa_freqs_norm.index and 'X' not in aa_freqs_norm.index:
    aa_freqs_norm = aa_freqs_norm.rename(index={'*':'X'})

# List FASTA files
all_fnames = sorted([f for f in os.listdir(FASTA_DIR) if f.lower().endswith(('.fa','.fasta','.fas'))])
print(f"Found {len(all_fnames)} FASTA files in {FASTA_DIR}")

# Prepare list of files to process based on SELECT_FAMILIES
if SELECT_FAMILIES is None:
    to_process = all_fnames
    print("SELECT_FAMILIES is None -> processing all FASTA files")
else:
    sel_set = set(SELECT_FAMILIES)
    to_process = []
    for f in all_fnames:
        base = os.path.splitext(f)[0]
        if base in sel_set:
            to_process.append(f)
    missing = sel_set - set(os.path.splitext(f)[0] for f in all_fnames)
    if missing:
        print(f"Warning: the following selected families were not found in {FASTA_DIR}: {sorted(missing)}")
    print(f"Processing {len(to_process)} selected FASTA files")

results = []
if not to_process:
    print("No FASTA files to process. Exiting.")
else:
    for idx, fname in enumerate(to_process, start=1):
        path = os.path.join(FASTA_DIR, fname)
        family = os.path.splitext(fname)[0]
        print(f"[{idx}/{len(to_process)}] Processing '{family}' -> {path}")
        try:
            stats = average_predicted_freqs_for_fasta(path, aa_freqs_norm)
        except Exception as e:
            print(f"  Error processing {fname}: {e}")
            stats = {'A_pred': np.nan,'C_pred': np.nan,'G_pred': np.nan,'T_pred': np.nan,
                     'n_seqs': 0,'n_skipped_nonstandard': 0,'n_skipped_no_prediction': 0}
        print(f"  Seqs: {stats['n_seqs']}, skipped_nonstandard: {stats['n_skipped_nonstandard']}, skipped_no_prediction: {stats['n_skipped_no_prediction']}")
        results.append({
            'VirusFamily': family,
            'A_pred': stats['A_pred'],
            'C_pred': stats['C_pred'],
            'G_pred': stats['G_pred'],
            'T_pred': stats['T_pred'],
            'n_seqs': stats['n_seqs'],
            'n_skipped_nonstandard': stats['n_skipped_nonstandard'],
            'n_skipped_no_prediction': stats['n_skipped_no_prediction'],
        })

out_df = pd.DataFrame(results)
out_df.to_csv(OUTPUT_CSV, index=False)
print(f"Saved predicted nucleotide frequencies for {len(out_df)} FASTA files to {OUTPUT_CSV}")
out_df


Loading swap table from: /Users/ishaharris/Projects/ribolings/data/CodonSwapTable_withStop.csv
Swap table loaded: 64 rows, 21 amino acids found
Building AA-normalised nucleotide frequencies (draws per AA = 100)
Built AA freq table for 21 AAs
Found 25 FASTA files in /Users/ishaharris/Projects/ribolings/data/virus/rna_cds_fasta
Processing 11 selected FASTA files
[1/11] Processing 'Arena' -> /Users/ishaharris/Projects/ribolings/data/virus/rna_cds_fasta/Arena.fasta




  Seqs: 249, skipped_nonstandard: 12, skipped_no_prediction: 0
[2/11] Processing 'Calici' -> /Users/ishaharris/Projects/ribolings/data/virus/rna_cds_fasta/Calici.fasta
  Seqs: 72, skipped_nonstandard: 3, skipped_no_prediction: 0
[3/11] Processing 'Corona' -> /Users/ishaharris/Projects/ribolings/data/virus/rna_cds_fasta/Corona.fasta
  Seqs: 458, skipped_nonstandard: 11, skipped_no_prediction: 0
[4/11] Processing 'Filo' -> /Users/ishaharris/Projects/ribolings/data/virus/rna_cds_fasta/Filo.fasta
  Seqs: 125, skipped_nonstandard: 2, skipped_no_prediction: 0
[5/11] Processing 'Flavi' -> /Users/ishaharris/Projects/ribolings/data/virus/rna_cds_fasta/Flavi.fasta
  Seqs: 60, skipped_nonstandard: 14, skipped_no_prediction: 0
[6/11] Processing 'Hanta' -> /Users/ishaharris/Projects/ribolings/data/virus/rna_cds_fasta/Hanta.fasta
  Seqs: 131, skipped_nonstandard: 5, skipped_no_prediction: 0
[7/11] Processing 'Orthomyxo' -> /Users/ishaharris/Projects/ribolings/data/virus/rna_cds_fasta/Orthomyxo.fasta

IsADirectoryError: [Errno 21] Is a directory: '/Users/ishaharris/Projects/ribolings/data/freqs'

In [7]:
OUTPUT_CSV = "/Users/ishaharris/Projects/ribolings/data/freqs/monoaa_freqs.csv"

out_df = pd.DataFrame(results)
out_df.to_csv(OUTPUT_CSV, index=False)
print(f"Saved predicted nucleotide frequencies for {len(out_df)} FASTA files to {OUTPUT_CSV}")
out_df

Saved predicted nucleotide frequencies for 11 FASTA files to /Users/ishaharris/Projects/ribolings/data/freqs/monoaa_freqs.csv


Unnamed: 0,VirusFamily,A_pred,C_pred,G_pred,T_pred,n_seqs,n_skipped_nonstandard,n_skipped_no_prediction
0,Arena,0.28664,0.227732,0.234334,0.251294,249,12,0
1,Calici,0.251726,0.257963,0.259127,0.231184,72,3,0
2,Corona,0.261607,0.232161,0.23095,0.275283,458,11,0
3,Filo,0.27084,0.250647,0.242735,0.235778,125,2,0
4,Flavi,0.240566,0.244986,0.274156,0.240293,60,14,0
5,Hanta,0.282991,0.224665,0.246442,0.245902,131,5,0
6,Orthomyxo,0.291676,0.212279,0.265402,0.230643,67,1,0
7,Paramyxo,0.287148,0.236739,0.24244,0.233673,647,5,0
8,Peribunya,0.285585,0.197754,0.287368,0.229294,3,0,0
9,Picorna,0.273023,0.241535,0.241724,0.243717,648,94,0
