# This notebook analyzes henipavirus and Nipah virus sequence conservation and compares it to deep mutational scanning (DMS) data

In [None]:
# This cell is tagged "parameters" for `papermill`: when the notebook is run
# through papermill, an injected cell placed after this one overrides these
# values. Every value defaults to None; a later cell fills in repository
# default paths when `nipah_alignment` is still None (interactive runs).
# Keep this cell as plain one-name-per-line assignments so papermill can parse it.
# Altair theme file (Python source) and the analysis YAML config
altair_config = None
nipah_config = None

# EFNB2 receptor binding / cell entry score CSVs
e2_binding = None
e2_entry = None

# EFNB3 receptor binding / cell entry score CSVs
e3_binding = None
e3_entry = None

# Nipah RBP amino-acid alignment (FASTA)
nipah_alignment = None

# Outputs: entropy table (CSV) and two chart files (HTML)
entropy_output = None
entry_scores_niv_poly = None
binding_scores_niv_poly = None

In [None]:
import math
import os

import altair as alt

import numpy as np

import pandas as pd

import scipy.stats
from scipy import stats

import subprocess
import tempfile
import yaml
from Bio import Entrez
from Bio import SeqIO
from Bio import AlignIO
from Bio.Align import MultipleSeqAlignment
from Bio.Align.Applications import MuscleCommandline
from Bio.Align.Applications import MafftCommandline
from Bio.Seq import Seq
from Bio.Align import PairwiseAligner

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# All relative paths in this notebook are resolved from the project root.
PROJECT_DIR = '/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/'

# os.getcwd() never ends with a path separator, so a raw string comparison
# against PROJECT_DIR (which does) can never match. Compare resolved paths instead.
if os.path.realpath(os.getcwd()) == os.path.realpath(PROJECT_DIR):
    print("Already in correct directory")
else:
    os.chdir(PROJECT_DIR)
    print("Setup in correct directory")

In [None]:
# Interactive-run defaults: when papermill has not injected parameters
# (nipah_alignment is still None), use the repository's standard paths.
if nipah_alignment is None:
    # Inputs: plotting theme and analysis configuration
    altair_config = 'data/custom_analyses_data/theme.py'
    nipah_config = 'nipah_config.yaml'

    # Inputs: EFNB2 scores (cell entry, then binding)
    e2_entry = 'results/func_effects/averages/CHO_EFNB2_low_func_effects.csv'
    e2_binding = 'results/receptor_affinity/averages/EFNB2_monomeric_mut_effect.csv'

    # Inputs: EFNB3 scores (cell entry, then binding)
    e3_entry = 'results/func_effects/averages/CHO_EFNB3_low_func_effects.csv'
    e3_binding = 'results/receptor_affinity/averages/EFNB3_dimeric_mut_effect.csv'

    # Input: Nipah RBP amino-acid alignment
    nipah_alignment = 'data/custom_analyses_data/alignments/Nipah_RBP_AA_align.fasta'

    # Outputs
    entropy_output = 'results/entropy/entropy.csv'
    entry_scores_niv_poly = 'results/images/niv_polymorphic_entry.html'
    binding_scores_niv_poly = 'results/images/niv_polymorphic_binding.html'

In [None]:
# Optional Altair theme. The theme file is Python source executed in this
# namespace. NOTE(review): exec runs arbitrary code, so only point
# `altair_config` at trusted files.
if altair_config:
    with open(altair_config, 'r') as theme_file:
        exec(theme_file.read())

# Analysis settings (e.g. 'func_times_seen_cutoff', 'func_std_cutoff') read below.
with open(nipah_config) as config_file:
    config = yaml.safe_load(config_file)

### Pull in cell entry and binding scores and pre-filter

In [None]:
def _merge_entry_and_binding(func_df, binding_df):
    """Outer-join a cell-entry table with a binding table on (site, mutant, wildtype).

    Overlapping columns get '_cell_entry' / '_affinity' suffixes, and the
    binding summary columns 'Ephrin binding_{mean,std,median}' are renamed to
    'binding_{mean,std,median}'. validate='one_to_one' makes pandas raise if a
    (site, mutant, wildtype) key is duplicated in either table.
    """
    merged = func_df.merge(
        binding_df,
        how='outer',
        on=['site', 'mutant', 'wildtype'],
        suffixes=['_cell_entry', '_affinity'],
        validate='one_to_one',
    )
    return merged.rename(columns={
        'Ephrin binding_mean': 'binding_mean',
        'Ephrin binding_std': 'binding_std',
        'Ephrin binding_median': 'binding_median',
    })


# EFNB2: monomeric binding scores + cell entry
e2 = pd.read_csv(e2_binding)
e2_func = pd.read_csv(e2_entry)
df_E2 = _merge_entry_and_binding(e2_func, e2)

# EFNB3: dimeric binding scores + cell entry
e3 = pd.read_csv(e3_binding)
e3_func = pd.read_csv(e3_entry)
df_E3 = _merge_entry_and_binding(e3_func, e3)

# Quality-filter on cell-entry measurements only; binding columns are not filtered.
def filter_df(df, cfg=None):
    """
    Keep well-measured substitutions for downstream analysis.

    A row is kept when all of the following hold:
    - 'mutant' is neither '*' nor '-' (presumably stop codons and deletions);
    - 'site' is not 603 (hard-coded exclusion);
    - 'times_seen_cell_entry' >= cfg['func_times_seen_cutoff'];
    - 'effect_std' <= cfg['func_std_cutoff'].

    Parameters:
    - df: merged entry/binding table containing the columns named above.
    - cfg: mapping holding the two cutoffs. Defaults to the notebook-level
      `config` loaded from the YAML file.

    Returns:
    - A filtered DataFrame; `df` itself is not modified.
    """
    if cfg is None:
        cfg = config
    # Comparisons against NaN are False, so rows missing entry stats are dropped.
    keep = (
        ~df['mutant'].isin(['*', '-'])
        & (df['site'] != 603)
        & (df['times_seen_cell_entry'] >= cfg['func_times_seen_cutoff'])
        & (df['effect_std'] <= cfg['func_std_cutoff'])
    )
    return df[keep]


# Quality-filtered tables for each receptor
df_E2_filter = filter_df(df_E2)
df_E3_filter = filter_df(df_E3)

# Outer join keeps substitutions that passed the filter for only one receptor.
df_affinity_filter_merge = df_E2_filter.merge(
    df_E3_filter,
    how='outer',
    on=['site', 'wildtype', 'mutant'],
    suffixes=['_E2', '_E3'],
)

# Absolute EFNB2-vs-EFNB3 differences in cell entry and mean binding.
df_affinity_filter_merge = df_affinity_filter_merge.assign(
    func_effect_diff=lambda d: (d['effect_E2'] - d['effect_E3']).abs(),
    binding_effect_diff=lambda d: (d['binding_mean_E2'] - d['binding_mean_E3']).abs(),
)

assign_columns = [
    'site', 'wildtype', 'mutant',
    'effect_E2', 'binding_median_E2', 'binding_std_E2',
    'effect_E3', 'binding_median_E3', 'binding_std_E3',
]
df_assign = df_affinity_filter_merge[assign_columns]
display(df_assign)

### Pull representative henipavirus RBP amino acid sequences from GenBank, align them, calculate per-site entropy, and convert to a DataFrame

In [None]:
def shannon_entropy(column):
    """Return the Shannon entropy, in bits, of one alignment column.

    Every distinct symbol, gap characters included, is counted as its own state.
    """
    tally = {}
    for residue in column:
        tally[residue] = tally.get(residue, 0) + 1

    n_total = len(column)
    return -sum(
        ((n / n_total) * math.log2(n / n_total) for n in tally.values()),
        0.0,
    )

def fetch_and_align(accession_numbers, email, output_folder=".",
                    muscle_exe="/fh/fast/bloom_j/software/miniconda3/envs/BloomLab/bin/muscle"):
    """
    Fetch protein sequences from GenBank, align them with MUSCLE, and tabulate the alignment.

    Parameters:
    - accession_numbers: iterable of NCBI protein accessions (one FASTA record each).
    - email: email address sent to NCBI Entrez with each request.
    - output_folder: directory for 'temp_sequences.fasta' and 'aligned.fasta';
      created if missing.
    - muscle_exe: path to the MUSCLE executable. The '-align'/'-output' flags
      used below follow MUSCLE v5 syntax.

    Returns:
    - DataFrame with one row per alignment column and one column per sequence.
      Known accessions are renamed to short virus labels. A 'henipavirus_entropy'
      column holds the Shannon entropy of each alignment position, with gaps
      counted as a state.

    Raises:
    - subprocess.CalledProcessError if MUSCLE exits with a non-zero status.
    """
    os.makedirs(output_folder, exist_ok=True)

    # Fetch each record, closing the Entrez handle even if parsing fails.
    Entrez.email = email
    sequences = []
    for acc in accession_numbers:
        handle = Entrez.efetch(db="protein", id=acc, rettype="fasta", retmode="text")
        try:
            sequences.append(SeqIO.read(handle, "fasta"))
        finally:
            handle.close()

    temp_sequences_path = os.path.join(output_folder, "temp_sequences.fasta")
    aligned_path = os.path.join(output_folder, "aligned.fasta")
    SeqIO.write(sequences, temp_sequences_path, "fasta")

    # check_output raises on a non-zero exit; MUSCLE's stdout is not needed.
    subprocess.check_output([muscle_exe, "-align", temp_sequences_path, "-output", aligned_path])

    alignment = AlignIO.read(aligned_path, "fasta")

    # One DataFrame column per record id, one row per alignment position.
    df_alignment = pd.DataFrame({record.id: list(record.seq) for record in alignment})
    accession_labels = {
        'YP_009094086.1': 'cedar',
        'AFH96011.1': 'ghana',
        'NP_112027.1': 'nipah',
        'NP_047112.2': 'hendra',
        'UCY33670.1': 'hendra_G2',
        'QDJ04463.1': 'nipah_cambodia',
        'QKV44014.1': 'nipah_india',
        'YP_009094095.1': 'Mojiang',
        'UUV47206.1': 'Langya',
        'AJP33320.1': 'cedar_2',
    }
    df_alignment = df_alignment.rename(columns=accession_labels)

    # Per-position entropy across all aligned sequences (computed before the
    # entropy column itself is added, so only residues are counted).
    df_alignment['henipavirus_entropy'] = [
        shannon_entropy(row) for _, row in df_alignment.iterrows()
    ]

    return df_alignment

# Representative henipavirus RBP protein accessions to pull from GenBank.
cedar = 'YP_009094086.1'
cedar2 = 'AJP33320.1'
ghana = 'AFH96011.1'
# No trailing comma: with one, `nipah` became a one-element tuple, not a string.
nipah = 'NP_112027.1'
nipah_cambodia = 'QDJ04463.1'
nipah_india = 'QKV44014.1'
hendra = 'NP_047112.2'
hendra_G2 = 'UCY33670.1'

seqs = [cedar, cedar2, ghana, nipah, nipah_cambodia, nipah_india, hendra, hendra_G2]
output_folder = "results/alignments/"
# Network call to NCBI plus a local MUSCLE run.
df = fetch_and_align(seqs, "blarsen@fredhutch.org", output_folder)
display(df.head(3))

### Make site numbering relative to Nipah reference sequence

In [None]:
# Number alignment positions relative to the Nipah reference. Each non-gap
# Nipah residue increments the running count; positions where Nipah has a
# gap get a missing site number.
mask = df['nipah'] != '-'
# Nullable Int64 keeps integer site numbers while allowing <NA> for gap
# positions. Writing the string 'NaN' into the int column (as before) made the
# column a mixed int/str object dtype. It also triggers pandas'
# incompatible-dtype FutureWarning on .loc assignment.
df['site'] = mask.cumsum().where(mask).astype('Int64')

In [None]:
#Save file for other notebooks use
# Persist the alignment/entropy table (index included) for other notebooks.
df.to_csv(entropy_output, index=True)

### Calculate which sites are 100% conserved across representative henipavirus sequences

In [None]:
# Compare only the aligned sequence columns. 'conserved' is excluded as well, so
# re-running this cell does not compare residues against the boolean column
# added on a previous run (which made every site non-conserved).
relevant_columns = df.drop(columns=['henipavirus_entropy', 'site', 'conserved'], errors='ignore')
# A position is conserved when every sequence carries the same character.
df['conserved'] = relevant_columns.apply(lambda row: len(set(row)) == 1, axis=1)
conserved_sites = df[df['conserved']]['site'].sort_values().tolist()
print(f" These sites are completely conserved among represantative Henipaviruses: {conserved_sites}")
print(f" The number of sites conserved across all Henipaviruses are: {len(conserved_sites)}")
# Attach alignment columns to every DMS mutant via its Nipah site number.
df_merged = pd.merge(df_assign, df, on='site', how='left')

### Plot median cell entry at sites conserved in henipaviruses

In [None]:
def plot_conserved_sites(df, sites=None):
    """
    Display the per-site median cell entry (EFNB2 and EFNB3) at conserved sites.

    Parameters:
    - df: mutant-level table with 'site', 'wildtype', 'mutant', 'effect_E2', 'effect_E3'.
    - sites: sites to include. Defaults to the notebook-level `conserved_sites`.

    Returns:
    - The value of `chart.display()` (the chart is shown as a side effect).
    """
    if sites is None:
        sites = conserved_sites
    df_subset = df.loc[df['site'].isin(sites)]
    df_melted = df_subset.melt(
        id_vars=['site', 'wildtype', 'mutant'],
        value_vars=['effect_E2', 'effect_E3'],
        var_name='type',
        value_name='effect',
    )
    df_melted['type'] = df_melted['type'].replace({'effect_E2': 'EFNB2', 'effect_E3': 'EFNB3'})

    # One point per (site, receptor): the MEDIAN effect across mutants at that site.
    df_melted = df_melted.groupby(['site', 'type'])['effect'].median().reset_index()
    chart = alt.Chart(df_melted).mark_point(size=100, filled=True, opacity=1).encode(
        x=alt.X('site:N', title='Site', axis=alt.Axis(grid=True, labelAngle=-45)),
        # Axis title matches the aggregation above (previously mislabelled "Mean").
        y=alt.Y('effect', title='Median Cell Entry by RBP Mutants'),
        color=alt.Color('type', legend=alt.Legend(title='Cell Type')),
    ).properties(
        width=alt.Step(20),
        height=alt.Step(20),
    )

    return chart.display()

# Show median cell entry per receptor at the fully conserved henipavirus sites.
plot_conserved_sites(df_assign)

### Calculate entropy from Nipah sequence alignment of RBP

In [None]:
# Load the Nipah RBP amino-acid alignment into a DataFrame with one column per
# FASTA record id and one row per alignment position, then show it.
alignment_path = nipah_alignment
alignment = AlignIO.read(alignment_path, "fasta")
alignment_dict = {rec.id: list(rec.seq) for rec in alignment}
df_alignment = pd.DataFrame(alignment_dict)
display(df_alignment)

In [None]:
def shannon_entropy_and_mutant_aa(column, wildtype_aa, min_mutant_count=3):
    """
    Compute an alignment column's Shannon entropy and its most common non-wildtype residue.

    Gap ('-') and unknown ('X') characters are ignored in both calculations.

    Parameters:
    - column: iterable of single-character residues at one alignment position.
    - wildtype_aa: reference residue at this position. It is excluded when
      picking the mutant.
    - min_mutant_count: the top non-wildtype residue is reported only if it
      occurs at least this many times. The default of 3 matches the previous
      hard-coded `count > 2` rule.

    Returns:
    - entropy: Shannon entropy in bits (>= 0) over the non-gap, non-X residues;
      0.0 if none remain.
    - mutant_aa: the most frequent non-wildtype residue if it meets
      `min_mutant_count`, else None. Ties go to the residue seen first in `column`.
    """
    counts = {}
    for aa in column:
        if aa not in ("-", "X"):
            counts[aa] = counts.get(aa, 0) + 1

    if not counts:
        return 0.0, None

    # Shannon entropy H = -sum(p * log2 p). Each term is <= 0, so H >= 0.
    total = sum(counts.values())  # hoisted: constant across the loop
    entropy = 0.0
    for n in counts.values():
        freq = n / total
        entropy -= freq * math.log2(freq)

    # max() returns the first maximal item, i.e. ties resolve by first occurrence.
    candidates = {aa: n for aa, n in counts.items() if aa != wildtype_aa}
    if not candidates:
        return entropy, None
    mutant_aa, mutant_count = max(candidates.items(), key=lambda item: item[1])
    if mutant_count < min_mutant_count:
        mutant_aa = None

    return entropy, mutant_aa

# Path to the Nipah RBP alignment (set by the parameters / default-paths cells)
alignment_path = nipah_alignment
# Read the alignment file using BioPython's AlignIO
alignment = AlignIO.read(alignment_path, "fasta")

# Convert the alignment to a DataFrame: one column per record id, one row per position
alignment_dict = {record.id: list(record.seq) for record in alignment}
df_alignment = pd.DataFrame(alignment_dict)

# Reference (wildtype) residues, taken from the Nipah reference record
wildtype_series = df_alignment['NC_002728.1_Nipah_virus_complete_genome']

# For each alignment position: entropy over non-gap residues and the recurrent
# non-reference residue (None if no residue passes the count threshold)
values = [shannon_entropy_and_mutant_aa(df_alignment.iloc[i], wildtype_series[i]) for i in range(df_alignment.shape[0])]
# Unpack the (entropy, mutant) pairs into two tuples
entropies, mutants = zip(*values)

# Per-site table. Site = alignment position + 1.
# NOTE(review): this numbering assumes the reference record has no gaps in this
# alignment; otherwise alignment positions and reference site numbers diverge.
df_final = pd.DataFrame({
    'site': range(1, len(mutants) + 1),
    'entropy': entropies,
    'wildtype': wildtype_series,
    'mutant': mutants
})

# Drop alignment positions numbered 603 and beyond (treated as an extra site at the end)
df_final = df_final[df_final['site'] < 603]
# df_mutations: DMS scores for exactly the recurrent non-reference residue at each
# polymorphic site (inner join on site, wildtype and mutant)
df_mutations = pd.merge(df_final,df_assign,on=['site','wildtype','mutant'],how='inner')

# df_total: every DMS mutant with henipavirus alignment columns plus NiV entropy.
# The overlapping 'wildtype'/'mutant' columns get _x (DMS) and _y (NiV alignment)
# suffixes; the _y versions are renamed to nipahM_con / nipahM_minor.
df_total = pd.merge(df_merged,df_final,on=['site'])
df_total = df_total.rename(columns={'wildtype_y':'nipahM_con','mutant_y':'nipahM_minor','wildtype_x':'wildtype','mutant_x':'mutant'})


In [None]:
# Full per-mutant table: DMS scores, henipavirus alignment columns, NiV entropy and minor residue.
display(df_total)

### Find all mutagenized sites that carry a polymorphism in NiV

In [None]:
# Sites where the NiV alignment shows a recurrent non-reference residue.
sites_with_mutants = df_final.loc[df_final['mutant'].notnull(), 'site'].tolist()
polymorphisms = list(sites_with_mutants)
print(polymorphisms)

# Only sites above this number are kept.
# NOTE(review): presumably sites <= 70 are outside the mutagenized region — confirm.
MIN_POLYMORPHIC_SITE = 70
# Filter in plain Python so `polymorphisms` holds built-in ints. A pandas Series
# round-trip yields numpy.int64, which NumPy >= 2 prints as "np.int64(...)".
# These ints are also used as an Altair axis domain later.
polymorphisms = [site for site in polymorphisms if site > MIN_POLYMORPHIC_SITE]
polymorphism_length = len(polymorphisms)
print(polymorphism_length)
print(f'These are a list of polymorphic sites in NiV RBP sequences: {polymorphisms}')

In [None]:
def _percentile_limits(data, column):
    """Return the (5th, 50th, 95th) percentiles of data[column], ignoring NaN.

    Also prints the 5th/95th percentile limits labelled with `column`, so the
    printouts for different score columns can be told apart.
    """
    lower, median, upper = data[column].dropna().quantile([0.05, 0.50, 0.95]).tolist()
    print(f"{column} lower limit (5th percentile): {lower:.2f}")
    print(f"{column} upper limit (95th percentile): {upper:.2f}")
    return lower, median, upper


# 5th / 50th / 95th percentiles of each score across all mutants in df_total.
lower_limit_E3_binding, median_E3_binding, upper_limit_E3_binding = _percentile_limits(df_total, 'binding_median_E3')
lower_limit_E2_binding, median_E2_binding, upper_limit_E2_binding = _percentile_limits(df_total, 'binding_median_E2')
lower_limit_E2_entry, median_E2_entry, upper_limit_E2_entry = _percentile_limits(df_total, 'effect_E2')
lower_limit_E3_entry, median_E3_entry, upper_limit_E3_entry = _percentile_limits(df_total, 'effect_E3')

In [None]:
# If I draw random samples of size n, how often does a sample contain a value
# higher than the threshold (e.g. the maximum observed among natural NiV mutants)?
def random_draws(df, column, threshold, n=None, num_trials=1000, random_state=None):
    """
    Estimate how often a resampled set of `column` values exceeds `threshold`.

    Parameters:
    - df: DataFrame holding the values. NaNs in `column` are dropped first.
    - column: name of the column to resample (with replacement).
    - threshold: a trial counts as a hit if any sampled value is strictly greater.
    - n: sample size per trial. Defaults to the notebook-level `polymorphism_length`.
    - num_trials: number of resampling trials (default 1000); must be positive.
    - random_state: optional seed for numpy.random.default_rng, for reproducible
      results. None keeps pandas' default unseeded sampling.

    Returns:
    - The fraction of trials with at least one hit. It is also printed.

    Raises:
    - ValueError if num_trials is not positive.
    """
    if num_trials <= 0:
        raise ValueError("num_trials must be positive")
    if n is None:
        n = polymorphism_length
    values = df[column].dropna()  # hoisted: identical for every trial
    rng = None if random_state is None else np.random.default_rng(random_state)

    count = 0
    for _ in range(num_trials):
        sample = values.sample(n=n, replace=True, random_state=rng)
        if (sample > threshold).any():
            count += 1

    # Computed once, after all trials have run.
    fraction = count / num_trials
    print(f'The fraction of times a random draw of {column} included a value greater than {threshold:.2f} was {fraction:.2f}')
    return fraction
# All DMS mutants at NiV polymorphic sites form the pool that is resampled.
df_total_polymorphic = df_total[df_total['site'].isin(polymorphisms)]

columns = ['effect_E2','effect_E3','binding_median_E2','binding_median_E3']
# Each column is compared against the highest score for that same column among
# the naturally observed NiV mutants. The old per-column branches tested
# 'binding_mean_*' names that never appear in `columns`, so the binding
# columns were silently skipped.
for column_name in columns:
    random_draws(df_total_polymorphic, column_name, df_mutations[column_name].max())

### Summary statistics for cell entry and binding scores, with t-tests comparing NiV polymorphisms to all mutants

In [None]:
def get_stats(df, column, sample_df=None):
    """
    Print summary statistics for `column` and a one-sample t-test against its mean.

    Parameters:
    - df: background table. The mean, median and std of df[column] are reported
      (NaNs dropped; std uses pandas' default ddof=1).
    - column: score column to summarize.
    - sample_df: table whose `column` values are tested against that background
      mean with scipy.stats.ttest_1samp. Defaults to the notebook-level
      `df_mutations` (DMS scores of residues seen as NiV polymorphisms).

    Returns:
    - (mean, std) of df[column].
    """
    if sample_df is None:
        sample_df = df_mutations
    background = df[column].dropna()
    sample = sample_df[column].dropna()

    mean = background.mean()
    median = background.median()
    std = background.std()
    print(f'The mean for {column} is: {mean:.2f}')
    print(f'The median for {column} is: {median:.2f}')
    print(f'The std for {column} is: {std:.2f}')

    _, p_value = stats.ttest_1samp(sample, mean)
    print(f'The p_value for {column} is: {p_value:.3f}')
    print('')
    return mean, std

columns = ['effect_E2','effect_E3','binding_median_E2','binding_median_E3']

# Mean and std of each score across all mutants (also prints stats and t-tests).
effect_E2_mean, effect_E2_std = get_stats(df_total, 'effect_E2')
effect_E3_mean, effect_E3_std = get_stats(df_total, 'effect_E3')
binding_E2_mean, binding_E2_std = get_stats(df_total, 'binding_median_E2')
binding_E3_mean, binding_E3_std = get_stats(df_total, 'binding_median_E3')

# mean ± 2 SD bands; each printout is labelled with its own column.
effect_E2_min = effect_E2_mean - 2 * effect_E2_std
effect_E2_max = effect_E2_mean + 2 * effect_E2_std
print(f'The min and max of effect_E2 are: {effect_E2_min:.2f} and {effect_E2_max:.2f}')

effect_E3_min = effect_E3_mean - 2 * effect_E3_std
effect_E3_max = effect_E3_mean + 2 * effect_E3_std
print(f'The min and max of effect_E3 are: {effect_E3_min:.2f} and {effect_E3_max:.2f}')

binding_E2_min = binding_E2_mean - 2 * binding_E2_std
binding_E2_max = binding_E2_mean + 2 * binding_E2_std
print(f'The min and max of binding_median_E2 are: {binding_E2_min:.2f} and {binding_E2_max:.2f}')

binding_E3_min = binding_E3_mean - 2 * binding_E3_std
binding_E3_max = binding_E3_mean + 2 * binding_E3_std
print(f'The min and max of binding_median_E3 are: {binding_E3_min:.2f} and {binding_E3_max:.2f}')

### Plot cell entry of all naturally occurring Nipah mutants

In [None]:
def plot_functional_effect_polymorphism(df, sites=None):
    """
    Build a two-row chart of cell-entry scores at NiV polymorphic sites.

    Parameters:
    - df: table with 'site', 'wildtype', 'mutant', 'effect_E2', 'effect_E3'.
    - sites: x-axis site domain. Defaults to the notebook-level `polymorphisms`.

    Returns:
    - An Altair vconcat chart: EFNB2 on top and EFNB3 below, with independent y scales.
    """
    if sites is None:
        sites = polymorphisms
    df_long = df.rename(columns={'effect_E2': 'EFNB2', 'effect_E3': 'EFNB3'}).melt(
        id_vars=['site', 'wildtype', 'mutant'],
        value_vars=['EFNB2', 'EFNB3'],
        var_name='effect',
        value_name='value',
    )

    charts = []
    for effect in ['EFNB2', 'EFNB3']:
        base = alt.Chart(df_long, title=f'{effect}').encode(
            x=alt.X('site:N', title='Site', axis=alt.Axis(labelAngle=-90),
                    scale=alt.Scale(domain=list(sites))),
            y=alt.Y('value:Q', title='Cell Entry'),
            tooltip=['effect', 'value', 'site', 'mutant', 'wildtype'],
        ).transform_filter(
            alt.datum.effect == effect
        ).properties(
            width=500,
            height=100,
        )
        # Percentile reference rules were previously built here but never layered
        # into the chart (and used the invalid color '#black'); they are omitted.
        chart_effect = base.mark_circle(size=100, opacity=1, color='black')
        charts.append(alt.layer(chart_effect).resolve_scale(color='independent'))

    return alt.vconcat(*charts).resolve_scale(y='independent', color='independent')

# Cell entry of naturally occurring NiV polymorphisms: show inline and save as HTML.
entry_nipah = plot_functional_effect_polymorphism(df_mutations)
entry_nipah.display()
entry_nipah.save(entry_scores_niv_poly)

### Plot receptor binding at each NiV polymorphism (same layout as above)

In [None]:
def plot_functional_effect_polymorphism_E2(df, sites=None):
    """
    Build a two-row chart of receptor-binding scores at NiV polymorphic sites.

    Despite the name, both receptors are plotted: EFNB2 on top and EFNB3 below.

    Parameters:
    - df: table with 'site', 'wildtype', 'mutant', 'binding_median_E2', 'binding_median_E3'.
    - sites: x-axis site domain. Defaults to the notebook-level `polymorphisms`.

    Returns:
    - An Altair vconcat chart with independent x, y and color scales.
    """
    if sites is None:
        sites = polymorphisms
    df_long = df.rename(columns={'binding_median_E2': 'EFNB2', 'binding_median_E3': 'EFNB3'}).melt(
        id_vars=['site', 'wildtype', 'mutant'],
        value_vars=['EFNB2', 'EFNB3'],
        var_name='effect',
        value_name='value',
    )

    charts = []
    for effect in ['EFNB2', 'EFNB3']:
        base = alt.Chart(df_long, title=f'{effect}').encode(
            x=alt.X('site:N', title='Site', axis=alt.Axis(labelAngle=-90),
                    scale=alt.Scale(domain=list(sites))),
            y=alt.Y('value:Q', title='Binding score'),
            tooltip=['effect', 'value', 'site', 'mutant', 'wildtype'],
        ).transform_filter(
            alt.datum.effect == effect
        ).properties(
            width=500,
            height=100,
        )
        # Percentile reference rules were previously built here but never layered
        # into the chart; they are omitted.
        chart_effect = base.mark_circle(size=100, opacity=1, color='black')
        charts.append(alt.layer(chart_effect).resolve_scale(color='independent'))

    return alt.vconcat(*charts).resolve_scale(y='independent', x='independent', color='independent')

# Receptor binding of naturally occurring NiV polymorphisms: show inline and save as HTML.
niv_poly = plot_functional_effect_polymorphism_E2(df_mutations)
niv_poly.display()
niv_poly.save(binding_scores_niv_poly)

### Plot henipavirus entropy vs. per-site summed cell entry

In [None]:
def entropy_scatter_chart(df, metric, effect):
    """
    Scatter a per-site aggregated DMS score against henipavirus entropy, with a linear fit.

    Parameters:
    - df: mutant-level table with 'site', 'entropy', 'henipavirus_entropy' and the
      score columns 'effect_E2', 'effect_E3', 'binding_median_E2', 'binding_median_E3'.
    - metric: name of a pandas GroupBy aggregation applied per site (e.g. 'sum', 'mean').
    - effect: the score column plotted on the y axis.

    Prints the regression r and p-value. Returns the layered Altair chart
    (points plus a red regression line).
    """
    # Y-axis label derived from the arguments. For ('sum', 'effect_E2'/'effect_E3')
    # this yields "Summed Cell Entry for EFNB2"/"... EFNB3".
    receptor = 'EFNB2' if effect.endswith('_E2') else 'EFNB3'
    measure = 'Cell Entry' if effect.startswith('effect') else 'Binding'
    metric_label = {'sum': 'Summed', 'mean': 'Mean', 'median': 'Median'}.get(metric, metric.capitalize())

    score_columns = ['effect_E2', 'effect_E3', 'binding_median_E2', 'binding_median_E3']
    # Aggregate each score per site with the requested GroupBy method.
    per_site = getattr(df.groupby('site')[score_columns], metric)().reset_index()
    # Entropy columns are per site, so one row per site is enough to attach them.
    site_entropy = df.drop_duplicates(subset='site')[['entropy', 'site', 'henipavirus_entropy']]
    df_mean = pd.merge(per_site, site_entropy, on='site', how='left')

    # Fit on complete rows only; NaNs would otherwise propagate into the fit.
    fit_data = df_mean[['henipavirus_entropy', effect]].dropna()
    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(
        fit_data['henipavirus_entropy'], fit_data[effect]
    )

    print(f'The r value is: {r_value:.2f}')
    # .2g keeps small p-values visible instead of rounding them to 0.00.
    print(f'The p_value is: {p_value:.2g}')

    scatter_chart = alt.Chart(df_mean).mark_point(color='black', size=70, filled=True, opacity=0.5).encode(
        x=alt.X('henipavirus_entropy', title="Henipavirus Entropy",
                axis=alt.Axis(grid=True, tickCount=3), scale=alt.Scale(domain=[-0.2, 2.5])),
        y=alt.Y(effect, title=f"{metric_label} {measure} for {receptor}",
                axis=alt.Axis(grid=True, tickCount=3)),
        tooltip=['site'],
    ).properties(
        width=alt.Step(10),
        height=alt.Step(10),
    )

    # Fitted regression line evaluated at the observed entropies.
    reg_df = pd.DataFrame({
        'henipavirus_entropy': fit_data['henipavirus_entropy'],
        'predicted': intercept + slope * fit_data['henipavirus_entropy'],
    })
    line_chart = alt.Chart(reg_df).mark_line(color='red', opacity=0.5).encode(
        x='henipavirus_entropy',
        y='predicted',
    )

    return scatter_chart + line_chart

# Henipavirus entropy vs. per-site summed cell entry, for each receptor.
e2_entry_vs_entropy = entropy_scatter_chart(df_total,'sum','effect_E2')
e2_entry_vs_entropy.display()
e3_entry_vs_entropy = entropy_scatter_chart(df_total,'sum','effect_E3')
e3_entry_vs_entropy.display()

### Plot receptor binding of Nipah mutations that match the Hendra residue

In [None]:
def _hendra_binding_points(df, binding_column):
    """Return DMS rows whose mutant equals the Hendra residue at the same site.

    `df` is the henipavirus alignment table (needs 'site' and 'hendra'). Its
    Hendra residues are joined to the notebook-level `df_merged` on
    (site, mutant). Only 'site', 'wildtype', 'hendra' and the chosen binding
    score (renamed 'value') are kept. Subsetting before the join avoids
    _x/_y clutter from every other alignment column.
    """
    hendra_residues = df[['site', 'hendra']].rename(columns={'hendra': 'mutant'})
    df_hendra = pd.merge(hendra_residues, df_merged, on=['site', 'mutant'], how='inner')
    return df_hendra[['site', 'wildtype', 'hendra', binding_column]].rename(
        columns={binding_column: 'value'}
    )


def plot_hendra_mutations_E2(df):
    """Display EFNB2 binding for Nipah mutants that match Hendra; returns chart.display()."""
    points = _hendra_binding_points(df, 'binding_median_E2')
    # Point (not line) plot, one mark per matching mutant.
    chart = alt.Chart(points).mark_point(color='black', filled=True, size=70).encode(
        x=alt.X('site:N', title='Site', axis=alt.Axis(labelAngle=-90)),
        y=alt.Y('value:Q', title='Receptor Binding', axis=alt.Axis(grid=True, tickCount=4)),
        tooltip=['site', 'wildtype', 'hendra', 'value'],
    )
    return chart.display()


def plot_hendra_mutations_E3(df):
    """Display EFNB3 binding for Nipah mutants that match Hendra; returns chart.display()."""
    points = _hendra_binding_points(df, 'binding_median_E3')
    chart = alt.Chart(points).mark_point(color='black', filled=True, size=70).encode(
        x=alt.X('site:N', title='Site', axis=alt.Axis(labelAngle=-45)),
        y=alt.Y('value:Q', title='Receptor Binding'),
        tooltip=['site', 'wildtype', 'hendra', 'value'],
    )
    return chart.display()

# Display EFNB2 and EFNB3 binding for DMS mutants that match the Hendra residue.
plot_hendra_mutations_E2(df)
plot_hendra_mutations_E3(df)

In [None]:
def find_hendra_mutants(df, virus):
    """Print the DMS sites where a tested mutant equals `virus`'s residue.

    `virus` names an alignment column of `df` (e.g. 'hendra', 'cedar'). Its
    residues are joined to the notebook-level `df_merged` on (site, mutant).
    Prints the unique matching sites and how many there are; returns None.
    """
    virus_residues = df[[virus, 'site']].rename(columns={virus: 'mutant'})
    matches = virus_residues.merge(df_merged, on=['site', 'mutant'], how='inner')
    matching_sites = list(matches['site'].unique())
    print(matching_sites)
    print(len(matching_sites))

# Sites where a DMS mutant reproduces the Hendra / Cedar residue.
find_hendra_mutants(df,'hendra')
find_hendra_mutants(df,'cedar')