# This notebook analyzes features of henipavirus and nipah sequence conservation and compares to DMS data

In [None]:
# This cell is tagged as `parameters` for `papermill` parameterization.
# Each value defaults to None; when `nipah_alignment` is still None the
# interactive-defaults cell further down fills in concrete paths.
nipah_config = None  # path of the YAML config file that is loaded with yaml.safe_load
nipah_alignment = None  # path of the Nipah FASTA alignment read with AlignIO
entropy_output = None  # path the henipavirus entropy CSV is written to

In [None]:
import math
import os

import altair as alt

import numpy as np

import pandas as pd

import scipy.stats
from scipy import stats

import subprocess
import tempfile
import yaml
from Bio import Entrez
from Bio import SeqIO
from Bio import AlignIO
from Bio.Align import MultipleSeqAlignment
from Bio.Align.Applications import MuscleCommandline
from Bio.Align.Applications import MafftCommandline
from Bio.Seq import Seq
from Bio.Align import PairwiseAligner

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Run from the project root so the relative paths used in this notebook resolve.
# os.getcwd() does not return a trailing slash, so the two paths are normalized
# before comparing; a raw string comparison against a path ending in "/" can
# never be equal.
project_dir = "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/"
if os.path.normpath(os.getcwd()) == os.path.normpath(project_dir):
    print("Already in correct directory")
else:
    os.chdir(project_dir)
    print("Setup in correct directory")

### For running interactively:

In [None]:
# When the notebook is not parameterized (nipah_alignment is still None),
# fall back to these paths, which are relative to the working directory.
if nipah_alignment is None:
    altair_config = 'data/custom_analyses_data/theme.py'
    nipah_config = 'nipah_config.yaml'
    #e2_binding = 'results/receptor_affinity/averages/EFNB2_monomeric_mut_effect.csv'
    #e2_entry = 'results/func_effects/averages/CHO_EFNB2_low_func_effects.csv'
    #e3_binding = 'results/receptor_affinity/averages/EFNB3_dimeric_mut_effect.csv'
    #e3_entry = 'results/func_effects/averages/CHO_EFNB3_low_func_effects.csv'
    # FASTA alignment of Nipah RBP amino-acid sequences
    nipah_alignment = 'data/custom_analyses_data/alignments/Nipah_RBP_AA_align.fasta'
    # CSV that receives the henipavirus alignment with entropy and site columns
    entropy_output = 'results/entropy/entropy.csv'
    #entry_scores_niv_poly = 'results/images/niv_polymorphic_entry.html'
    #binding_scores_niv_poly = 'results/images/niv_polymorphic_binding.html'

In [None]:
# Parse the YAML file at `nipah_config` into `config`.
with open(nipah_config) as f:
    config = yaml.safe_load(f)

### Pull representative henipavirus RBP amino acid sequences from GenBank, align, calculate entropy, and convert to a dataframe

In [None]:
def shannon_entropy(column):
    """Return the Shannon entropy (base 2) of one alignment column.

    Every symbol found in ``column`` is counted, gap characters included.
    """
    tally = {}
    for symbol in column:
        tally[symbol] = tally.get(symbol, 0) + 1

    # Accumulate sum(p * log2(p)); the entropy is the negation of that sum.
    weighted_log_sum = 0.0
    for n_seen in tally.values():
        p = n_seen / len(column)
        weighted_log_sum += p * math.log2(p)
    return -weighted_log_sum

def fetch_and_align(
    accession_numbers,
    email,
    output_folder=".",
    muscle_exe="/fh/fast/bloom_j/software/miniconda3/envs/BloomLab/bin/muscle",
):
    """
    Fetch sequences from GenBank based on accession numbers, align them,
    and return the alignment as a pandas DataFrame.

    Parameters:
    - accession_numbers: List of protein accession numbers.
    - email: Email address to be used with NCBI's Entrez.
    - output_folder: The directory where output files will be saved
      ("temp_sequences.fasta" and "aligned.fasta"); created if missing.
    - muscle_exe: Path of the MUSCLE executable, invoked as
      `<muscle_exe> -align <input> -output <output>`.

    Returns:
    - DataFrame with one column per aligned sequence (known accessions are
      renamed to short virus names), one row per alignment position, plus a
      'henipavirus_entropy' column holding the Shannon entropy of each row.
      Gap characters are counted as a symbol in that entropy.

    Raises:
    - subprocess.CalledProcessError if MUSCLE exits with a non-zero status.
    """
    # Ensure the output directory exists; exist_ok avoids the check-then-create race.
    os.makedirs(output_folder, exist_ok=True)

    # Fetch sequences from GenBank
    Entrez.email = email
    sequences = []
    for acc in accession_numbers:
        handle = Entrez.efetch(db="protein", id=acc, rettype="fasta", retmode="text")
        try:
            sequences.append(SeqIO.read(handle, "fasta"))
        finally:
            # Close the handle even when parsing the record fails
            handle.close()

    # Define file paths
    temp_sequences_path = os.path.join(output_folder, "temp_sequences.fasta")
    aligned_path = os.path.join(output_folder, "aligned.fasta")

    # Write sequences to a temporary fasta file
    SeqIO.write(sequences, temp_sequences_path, "fasta")

    # Align using MUSCLE; the captured stdout is not needed, only the exit status check
    subprocess.check_output([muscle_exe, "-align", temp_sequences_path, "-output", aligned_path])

    # Read the aligned sequences
    alignment = AlignIO.read(aligned_path, "fasta")

    # Convert alignment to DataFrame: one column per record, one row per position
    df_alignment = pd.DataFrame({record.id: list(record.seq) for record in alignment})
    df_alignment = df_alignment.rename(columns={'YP_009094086.1':'cedar','AFH96011.1':'ghana','NP_112027.1':'nipah','NP_047112.2':'hendra','UCY33670.1':'hendra_G2','QDJ04463.1':'nipah_cambodia','QKV44014.1':'nipah_india','YP_009094095.1':'Mojiang','UUV47206.1':'Langya','AJP33320.1':'cedar_2'})

    # Calculate the Shannon entropy of every row before the entropy column
    # itself is attached, so only sequence columns take part.
    site_entropies = [shannon_entropy(row) for _, row in df_alignment.iterrows()]
    df_alignment['henipavirus_entropy'] = site_entropies

    return df_alignment

# GenBank protein accessions of the henipavirus RBP sequences to pull.
# Each accession must be a plain string: a trailing comma after the literal
# would silently turn it into a 1-tuple.
cedar = 'YP_009094086.1'
cedar2 = 'AJP33320.1'
ghana = 'AFH96011.1'
nipah = 'NP_112027.1'
nipah_cambodia = 'QDJ04463.1'
nipah_india = 'QKV44014.1'
hendra = 'NP_047112.2'
hendra_G2 = 'UCY33670.1'

seqs = [cedar, cedar2, ghana, nipah, nipah_cambodia, nipah_india, hendra, hendra_G2]
output_folder = "results/alignments/"
# Fetch, align, and compute per-position entropy
df = fetch_and_align(seqs, "blarsen@fredhutch.org", output_folder)
display(df.head(3))

### Make site numbering relative to Nipah reference sequence

In [None]:
# True for alignment rows where the Nipah reference has a residue (not a gap)
mask = df['nipah'] != '-'
# Number rows by the running count of Nipah residues. Rows where 'nipah' is a
# gap get a missing value rather than a site number. The nullable 'Int64'
# dtype keeps the column integer-typed; writing the string 'NaN' into an
# integer column is an incompatible-dtype assignment that pandas deprecates
# and that leaves a column of mixed int/str objects.
df['site'] = mask.cumsum().where(mask).astype('Int64')

# Save entropy file for other notebooks to use. na_rep='NaN' writes the
# missing site numbers as the literal text NaN in the CSV.
df.to_csv(entropy_output, na_rep='NaN')
display(df.head(5))

### Calculate which sites are 100% conserved across representative henipavirus sequences

In [None]:
# Only the sequence columns may take part in the comparison. 'conserved' is
# excluded as well, so re-running this cell after the column has been added
# gives the same answer instead of marking every site as not conserved.
derived_columns = ['henipavirus_entropy', 'site', 'conserved']
relevant_columns = df.drop(columns=derived_columns, errors='ignore')
# A site is conserved when every sequence carries the same character there
df['conserved'] = relevant_columns.nunique(axis=1, dropna=False) == 1
conserved_sites = df[df['conserved']]['site'].sort_values().tolist()
print(f" These sites are completely conserved among representative Henipaviruses: {conserved_sites}")
print(f" The number of sites conserved across representative Henipaviruses are: {len(conserved_sites)}")

### Calculate entropy from Nipah sequence alignment of RBP

In [None]:
# Read the Nipah alignment (FASTA) from the path in `nipah_alignment`.
# NOTE(review): the same file is read into the same variables again in the
# next cell, so this cell looks redundant - confirm before removing.
alignment_path = nipah_alignment
alignment = AlignIO.read(alignment_path, "fasta")

# Convert alignment to DataFrame: one column per record id, one row per alignment position
alignment_dict = {record.id: list(record.seq) for record in alignment}
df_alignment = pd.DataFrame(alignment_dict)

In [None]:
def shannon_entropy_and_mutant_aa(column, wildtype_aa):
    """
    Summarize one alignment site: its Shannon entropy and its most common non-wildtype residue.

    Gap ('-') and unknown ('X') characters are left out of all counts.

    Parameters:
    - column: The residues at one site across the aligned sequences.
    - wildtype_aa: The residue of the reference sequence at this site.

    Returns:
    - The Shannon entropy (base 2) of the counted residues; 0.0 when nothing was counted.
    - The most frequent residue other than the wildtype, provided it occurs at
      least twice (ties go to the residue seen first); otherwise None.
    """
    skipped = ("-", "X")
    tally = {}
    for residue in column:
        if residue in skipped:
            continue
        tally[residue] = tally.get(residue, 0) + 1

    # Nothing but gaps/unknowns at this site
    if len(tally) == 0:
        return 0.0, None

    # Accumulate sum(p * log2(p)) over the counted residues
    n_counted = sum(tally.values())
    weighted_log_sum = 0.0
    for n_seen in tally.values():
        p = n_seen / n_counted
        weighted_log_sum += p * math.log2(p)

    # Drop the wildtype, then look at the most frequent remaining residue
    tally.pop(wildtype_aa, None)
    mutant_aa = None
    if tally:
        top_residue, top_count = max(tally.items(), key=lambda item: item[1])
        if top_count >= 2:
            mutant_aa = top_residue

    return -weighted_log_sum, mutant_aa

# Location of the Nipah alignment (FASTA)
alignment_path = nipah_alignment

# Parse the alignment with BioPython's AlignIO
alignment = AlignIO.read(alignment_path, "fasta")

# One DataFrame column per record id, one row per alignment position
alignment_dict = {record.id: list(record.seq) for record in alignment}
df_alignment = pd.DataFrame(alignment_dict)

# Column holding the wildtype (reference) sequence
wildtype_series = df_alignment['NC_002728.1_Nipah_virus_complete_genome']

# For every alignment row, pair it with the wildtype residue of that row and
# compute (entropy, mutant amino acid)
values = [
    shannon_entropy_and_mutant_aa(site_column, wildtype_aa)
    for (_, site_column), wildtype_aa in zip(df_alignment.iterrows(), wildtype_series)
]
# Split the (entropy, mutant) pairs into two parallel tuples
entropies, mutants = zip(*values)

# Assemble the per-site table; sites are numbered from 1 in alignment order
df_final = pd.DataFrame({
    'site': range(1, len(mutants) + 1),
    'entropy': entropies,
    'wildtype': wildtype_series,
    'mutant': mutants
})

# Filter to get rid of extra site at end
keep_site = df_final['site'] < 603
df_final = df_final[keep_site]
display(df_final)

### Find polymorphic Nipah sites

In [None]:
# Sites for which a non-wildtype ("mutant") amino acid was reported
has_mutant = df_final['mutant'].notna()
sites_with_mutants = df_final.loc[has_mutant, 'site'].tolist()
polymorphisms = list(sites_with_mutants)
data_series = pd.Series(polymorphisms)
# filter out sites that are outside mutagenized region
filtered_series = data_series[data_series.gt(71)]
polymorphisms = filtered_series.tolist()
polymorphism_length = len(polymorphisms)
print(f'There are {polymorphism_length} polymorphic sites in NiV RBP sequences: {polymorphisms}')

In [None]:
def find_other_henipavirus_mutants(df, virus, min_site=71):
    """
    Print the sites at which ``virus`` differs from Nipah.

    Parameters:
    - df: Alignment DataFrame with a 'nipah' column, a column named ``virus``
      (both holding one residue or '-' per row) and a 'site' column.
    - virus: Name of the column to compare against 'nipah'.
    - min_site: Only sites strictly greater than this value are reported.
      The default of 71 matches the cutoff used elsewhere in this notebook
      for sites outside the mutagenized region.

    Returns:
    - None; the list of differing sites is printed.
    """
    other = df[virus]
    nipah = df['nipah']

    # Compare only rows where neither sequence has a gap; '-' is matched
    # literally rather than as a regular expression.
    both_present = ~other.str.contains('-', regex=False) & ~nipah.str.contains('-', regex=False)
    differs = both_present & (other != nipah)

    # Unique differing sites, in alignment order
    sites = list(df.loc[differs, 'site'].unique())
    data_series = pd.Series(sites)
    series = list(data_series[data_series > min_site])
    print(f'{virus} is different from nipah at these sites:\n {series}\n')

# Print the sites where the 'hendra' and 'cedar' columns of the henipavirus
# alignment differ from the 'nipah' column.
find_other_henipavirus_mutants(df,'hendra')
find_other_henipavirus_mutants(df,'cedar')