# This notebook reads in experimentally determined fraction infectivity curves, plots them, and then correlates the fitted IC50 values with DMS escape data

In [None]:
# this cell is tagged as parameters for `papermill` parameterization
# Every value below is a placeholder; the real values are expected to be supplied at run time.
# Path to a Python file that is exec'd to set up the Altair theme (skipped when falsy).
altair_config=None
# Path to the YAML configuration file loaded into `config`.
nipah_config=None

# Path to the CSV of neutralization measurements (read with pandas).
neut = None
# Path to the CSV of DMS escape values (read with pandas).
escape_file = None

# Output path for the neutralization-curve chart.
nah1_validation_neut_curves = None
# Output path for the IC50-vs-escape correlation chart.
IC50_validation_plot = None
# Output path for the side-by-side combined chart.
combined_ic50_neut_curve_plot = None

In [None]:
import math
import os
import re

import altair as alt

import numpy as np

import pandas as pd

import scipy.stats

import yaml

import pickle

import neutcurve
from neutcurve.colorschemes import CBPALETTE
from neutcurve.colorschemes import CBMARKERS
import scipy.stats
# Record which `neutcurve` version produced the fits shown in this notebook's output.
print(f"Using `neutcurve` version {neutcurve.__version__}")

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Project root that the notebook switches into before doing any work.
# NOTE(review): this is a machine-specific absolute path; consider making it a
# papermill parameter so the notebook can run elsewhere.
PROJECT_DIR = "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/"

# os.getcwd() does not include a trailing slash (except at the filesystem root),
# so a raw string comparison against PROJECT_DIR could never match. Compare the
# resolved paths instead so the "already there" branch is actually reachable.
if os.path.realpath(os.getcwd()) == os.path.realpath(PROJECT_DIR):
    print("Already in correct directory")
else:
    os.chdir(PROJECT_DIR)
    print("Setup in correct directory")
#with open("config.yaml") as f:
#    config = yaml.safe_load(f)

In [None]:
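# Presumably example parameter values for running this notebook interactively
# (i.e. without papermill); uncomment and adjust as needed.
#altair_config = 'data/custom_analyses_data/theme.py'
#nipah_config = 'nipah_config.yaml'
#escape_file = 'results/antibody_escape/averages/nAH1.3_mut_effect.csv'
#neut = 'data/custom_analyses_data/experimental_data/nAH1_3_mab_validation_neuts.csv'
#IC50_validation_plot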

### Read in config files

In [None]:
# Apply the optional Altair theme: when a theme file path was supplied, run its
# contents as Python code.
if altair_config:
    with open(altair_config, 'r') as file:
        # SECURITY NOTE(review): exec() executes whatever the file contains inside
        # this notebook's namespace; only point `altair_config` at a trusted file.
        exec(file.read())

# Load the project configuration from YAML (safe_load does not construct arbitrary objects).
# NOTE(review): `config` is not referenced by any of the cells visible here — confirm it is still needed.
with open(nipah_config) as f:
    config = yaml.safe_load(f)

### Read in raw data

In [None]:
# DMS escape table; later merged on its `mutation` column and read for `escape_median`.
escape = pd.read_csv(escape_file)
# Neutralization measurements; later filtered on `virus` and passed to neutcurve for fitting.
neuts = pd.read_csv(neut)

### Get neutralization curves

In [None]:
# Get rid of Y455M because it's not present in the escape data
neuts = neuts[neuts['virus'] != 'Y455M']

In [None]:
def get_neutcurve(df,serum,virus,replicate='average'):
    """Fit neutralization curves and return the fit table plus one curve's points.

    Parameters
    ----------
    df : pandas.DataFrame
        Neutralization data handed unchanged to `neutcurve.curvefits.CurveFits`.
    serum : str
        Serum (antibody) whose curve is extracted.
    virus : str
        Virus (mutant) whose curve is extracted.
    replicate : str, optional
        Replicate to extract; defaults to 'average'.

    Returns
    -------
    tuple
        `(fit_params, points)` where `fit_params` is the `fitParams` table for
        everything in `df` (it is not restricted to the requested serum/virus)
        and `points` is the selected curve's dataframe with extra `antibody`,
        `virus`, `upper` and `lower` columns.
    """
    # The fit is built with the `fixbottom=0` option.
    curve_fits = neutcurve.curvefits.CurveFits(data=df, fixbottom=0)

    # Inhibitory-concentration levels to report in the fit table.
    ic_levels = [50, 90, 95, 97, 98, 99]
    fit_params = curve_fits.fitParams(ics=ic_levels)

    selected = curve_fits.getCurve(serum=serum, virus=virus, replicate=replicate)
    points = selected.dataframe()

    # Label the rows so curves from different pairs can be concatenated later.
    points['antibody'] = serum
    points['virus'] = virus

    # Error-bar bounds: measurement plus/minus its standard error.
    spread = points['stderr']
    points['upper'] = points['measurement'] + spread
    points['lower'] = points['measurement'] - spread

    return fit_params, points

serum_list = list(neuts['serum'].unique())
virus_list = list(neuts['virus'].unique())

# Only request curves for serum/virus pairs that were actually measured. Walking
# the full serum x virus grid would ask neutcurve for curves that do not exist
# whenever a virus was not tested against every serum.
measured_pairs = set(zip(neuts['serum'], neuts['virus']))

fit_df = None  # fit table for the whole dataset; identical on every call below
curve_frames = []  # one per-curve dataframe for each measured pair
for serum in serum_list:
    for virus in virus_list:
        if (serum, virus) not in measured_pairs:
            continue
        fit_df, neut_df = get_neutcurve(neuts, serum, virus)
        curve_frames.append(neut_df)

# Fail with a clear message instead of a NameError / empty-concat error further down.
if not curve_frames:
    raise ValueError("No serum/virus pairs found in the neutralization data; nothing to fit")

#fit_df is the calculated fit for IC50,IC90, etc
display(fit_df.head(8))

# neut_curve is full fitting dataframe
neut_curve = pd.concat(curve_frames, axis=0)
display(neut_curve.head(3))

### Make neut curve plot

In [None]:
# Sorting function to put 'WT' on top of the legend, followed by numerical order
def custom_sort_order(array):
    """Return the names ordered by their first embedded number, with 'WT' first.

    Names are sorted on the first run of digits they contain (e.g. 530 for
    'Q530F'); names without digits sort as 0. The first 'WT' entry, if any, is
    then moved to the front. The input is not modified; a new list is returned.
    """
    digit_run = re.compile(r'\d+')

    def site_number(name):
        found = digit_run.search(name)
        if found:
            return int(found.group())
        return 0

    ordered = sorted(array, key=site_number)

    # Pull the first 'WT' out of its sorted position and put it at the head.
    if 'WT' in ordered:
        ordered.insert(0, ordered.pop(ordered.index('WT')))
    return ordered

def plot_validation_curves(df,name):
    """Build a layered Altair chart of neutralization curves.

    Three layers are drawn from `df` and added together: the fitted curve
    (`fit` vs `concentration`) as lines, the measured points (`measurement`)
    as circles, and error bars spanning `lower` to `upper`. Colour encodes
    `virus`, ordered by `custom_sort_order`, with black as the first colour.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide `concentration`, `fit`, `measurement`, `lower`, `upper`
        and `virus` columns.
    name : str
        Antibody name used in the x-axis title.

    Returns
    -------
    The layered Altair chart.
    """
    # Define the category10 colors manually
    palette = ['#4E79A5', '#F18F3B', '#E0585B', '#77B7B2', '#5AA155', '#EDC958', '#AF7AA0', '#FE9EA8', '#9C7561', '#BAB0AC']

    # Black for the first legend entry, then one palette colour per remaining virus.
    viruses = df['virus'].unique()
    colors = ['black'] + palette[:len(viruses) - 1]

    def concentration_axis(field):
        # Log-scaled concentration axis shared by the line and circle layers.
        return alt.X(
            field,
            scale=alt.Scale(type='log'),
            axis=alt.Axis(format='.0e', tickCount=3),
            title=f'{name} conc. (μg/mL)',
        )

    def infectivity_axis(field):
        return alt.Y(field, title='Fraction Infectivity', axis=alt.Axis(tickCount=3))

    def virus_color():
        return alt.Color(
            'virus',
            title='Mutant',
            scale=alt.Scale(domain=custom_sort_order(viruses), range=colors),
        )

    fitted_lines = (
        alt.Chart(df)
        .mark_line(size=1, opacity=1)
        .encode(
            x=concentration_axis('concentration:Q'),
            y=infectivity_axis('fit:Q'),
            color=virus_color(),
        )
        .properties(height=200, width=300)
    )

    measured_points = (
        alt.Chart(df)
        .mark_circle(size=50, opacity=1)
        .encode(
            x=concentration_axis('concentration'),
            y=infectivity_axis('measurement:Q'),
            color=virus_color(),
        )
        .properties(height=200, width=300)
    )

    error_bars = (
        alt.Chart(df)
        .mark_errorbar(opacity=1)
        .encode(
            x='concentration',
            y=alt.Y('lower', title='Fraction Infectivity'),
            y2='upper',
            color='virus',
        )
    )

    return fitted_lines + measured_points + error_bars

# Build the nAH1.3 neutralization-curve chart, show it inline, and write it to
# the output path supplied through the `nah1_validation_neut_curves` parameter.
nah1_neut_curves = plot_validation_curves(neut_curve,'nAH1.3')
nah1_neut_curves.display()
nah1_neut_curves.save(nah1_validation_neut_curves)

### Now calculate the correlation coefficient (r) between DMS escape and IC50, and plot

In [None]:
def plot_ic50_correlations(df, escape_df=None):
    """Scatter-plot measured IC50 against DMS escape and annotate the correlation.

    Parameters
    ----------
    df : pandas.DataFrame
        Fit table with at least `virus`, `ic50` and `ic50_bound` columns.
        It is not modified; the function works on a copy.
    escape_df : pandas.DataFrame, optional
        DMS escape table with `mutation` and `escape_median` columns. When
        omitted, the notebook-level `escape` dataframe is used.

    Returns
    -------
    A layered Altair chart: one point per mutant (colour = mutant, shape =
    whether the IC50 is a lower bound) plus an "r = ..." text label.
    """
    escape_scores = escape if escape_df is None else escape_df

    # Work on a copy so the caller's fit table does not silently gain columns.
    fits = df.copy()
    fits['lower_bound'] = fits['ic50_bound'] == 'lower'
    fits['mutation'] = fits['virus']

    # Merge with DMS escape data. WT has no entry in the escape table, so it is
    # appended separately with an escape score of 0.
    merged = fits.merge(escape_scores, on=['mutation'])
    wt_rows = fits[fits['mutation'] == 'WT'].copy()
    wt_rows['escape_median'] = 0
    merged = pd.concat([merged, wt_rows], ignore_index=True)

    # Correlation coefficient between escape score and IC50.
    # NOTE(review): r is computed on untransformed IC50 values although the
    # chart shows IC50 on a log axis — confirm this is the intended statistic.
    r_value = scipy.stats.linregress(merged['escape_median'], merged['ic50']).rvalue

    # Define the category10 colors manually
    category10_colors = ['#4E79A5', '#F18F3B', '#E0585B', '#77B7B2', '#5AA155', '#EDC958', '#AF7AA0', '#FE9EA8', '#9C7561', '#BAB0AC']

    # Black for the first legend entry, then one palette colour per remaining mutant.
    mutations = merged['mutation'].unique()
    colors = ['black'] + category10_colors[:len(mutations) - 1]

    corr_chart = (
        alt.Chart(merged)
        .encode(
            x=alt.X(
                "escape_median",
                title="DMS escape score",
                axis=alt.Axis(grid=True)
            ),
            y=alt.Y(
                "ic50",
                title="nAH1.3 IC₅₀ (μg/ml)",
                scale=alt.Scale(type="log"),
                axis=alt.Axis(grid=True),
            ),
            # Reuse the notebook-level custom_sort_order (WT first, then by site number).
            color=alt.Color('mutation', title='Mutant', scale=alt.Scale(domain=custom_sort_order(mutations), range=colors)),
            shape=alt.Shape('lower_bound',title='Lower Bound'),
        )
        .mark_point(filled=True, size=150, opacity=1)
        .properties(width=200, height=200)
    )

    # Anchor the label at the top-left of the data: smallest escape score on x,
    # largest IC50 on y. The layers share axes with the scatter, so the label
    # coordinates must be expressed in the same units as each axis.
    label_position = {
        'x': float(merged['escape_median'].min()),
        'y': float(merged['ic50'].max()),
        'text': f'r = {r_value:.2f}',
    }
    text = alt.Chart({'values': [label_position]}).mark_text(
        align='left',
        baseline='top',
        dx=5  # Adjust this for position
    ).encode(
        x=alt.X('x:Q'),
        y=alt.Y('y:Q'),
        text='text:N'
    )
    final_chart = corr_chart + text
    return final_chart

# Build the IC50-vs-escape correlation chart from the fit table, show it inline,
# and write it to the output path supplied through the `IC50_validation_plot` parameter.
ic50_validations = plot_ic50_correlations(fit_df)
ic50_validations.display()
ic50_validations.save(IC50_validation_plot)

In [None]:
# Combine the neutralization-curve chart and the correlation chart with `|` and
# save the result to the `combined_ic50_neut_curve_plot` output path.
(nah1_neut_curves | ic50_validations).save(combined_ic50_neut_curve_plot)