# This notebook reads in experimentally determined luciferase entry measurements for individual mutants and plots their correlation with DMS entry scores

In [None]:
# this cell is tagged as parameters for `papermill` parameterization
# Every value defaults to None here; the notebook expects real values to be supplied for them.

# Inputs
# altair_config: path to a Python file that is read and exec()'d to set up the Altair theme
#                (skipped when falsy).
# nipah_config: path to a YAML file loaded into `config`.
# validation_file_E2 / validation_file_E3: CSV files with a `mutation` column and a
#                `mean_luciferase` column (renamed per receptor after loading).
altair_config = None
nipah_config = None
validation_file_E2 = None
validation_file_E3 = None

# CSV files of functional scores; they must provide `wildtype`, `site` and `mutant` columns,
# which are concatenated into a `mutation` key.
func_scores_E2_file = None
func_scores_E3_file = None

# Outputs: destinations passed to Altair's `Chart.save()` for the E2 chart, the E3 chart and
# the side-by-side combination of both.
func_score_E2_plot = None
func_score_E3_plot = None
corr_plots_combined = None

In [None]:
import math
import os
import re

import altair as alt

import numpy as np

import pandas as pd

import scipy.stats

import yaml

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Directory that the relative paths used in this notebook (e.g. "config.yaml") are resolved against.
# NOTE(review): this is a hard-coded absolute path; consider making it a notebook parameter.
project_dir = "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/"

# os.getcwd() does not return a trailing slash, so comparing it directly against the
# slash-terminated literal could never succeed; compare normalized paths instead.
if os.path.normpath(os.getcwd()) == os.path.normpath(project_dir):
    print("Already in correct directory")
else:
    os.chdir(project_dir)
    print("Setup in correct directory")

# Load the pipeline configuration from the project directory.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

In [None]:
#altair_config = 'data/custom_analyses_data/theme.py'
#nipah_config = 'nipah_config.yaml'
#validation_file_E2 = 'data/custom_analyses_data/experimental_data/functional_validations_EFNB2.csv'
#validation_file_E3 = 'data/custom_analyses_data/experimental_data/functional_validations_EFNB3.csv'
#
#func_scores_E2_file = "results/func_effects/averages/CHO_EFNB2_low_func_effects.csv"
#func_scores_E3_file = "results/func_effects/averages/CHO_EFNB3_low_func_effects.csv"
#
#func_score_E2_plot 
#func_score_E3_plot

In [None]:
# Optionally apply the Altair theme: the file at `altair_config` is read and run with exec(),
# i.e. it executes as arbitrary Python code. Skipped entirely when `altair_config` is falsy.
# NOTE(review): exec() on a path supplied as a parameter -- only point this at trusted files.
if altair_config:
    with open(altair_config, 'r') as file:
        exec(file.read())

# Parse the YAML file at `nipah_config` and bind the result to `config`
# (this rebinds `config` if it was already defined).
with open(nipah_config) as f:
    config = yaml.safe_load(f)

### Import luciferase (RLUs/uL) readings for each mutant

In [None]:
# Luciferase validation measurements on EFNB2; the signal column gets a receptor-specific name.
func_validations_EFNB2 = (
    pd.read_csv(validation_file_E2, na_filter=None)
    .rename(columns={'mean_luciferase': 'mean_luciferase_E2'})
)

# Same for EFNB3; its `mutation` column is dropped so it is not duplicated after concatenation.
func_validations_EFNB3 = (
    pd.read_csv(validation_file_E3, na_filter=None)
    .rename(columns={'mean_luciferase': 'mean_luciferase_E3'})
    .drop('mutation', axis=1)
)

# Column-wise concatenation aligns the two tables on their row index.
# NOTE(review): this assumes both CSVs list the mutants in the same row order -- confirm.
concat = pd.concat([func_validations_EFNB2, func_validations_EFNB3], axis=1)
display(concat)

### Now import functional scores and add a `mutation` column to merge with the data frame above

In [None]:
# Functional scores for EFNB2, with a `mutation` key built as <wildtype><site><mutant>.
func_scores = pd.read_csv(func_scores_E2_file).assign(
    mutation=lambda scores: scores['wildtype'] + scores['site'].astype(str) + scores['mutant']
)

# Functional scores for EFNB3, keyed the same way.
func_scores_E3 = pd.read_csv(func_scores_E3_file).assign(
    mutation=lambda scores: scores['wildtype'] + scores['site'].astype(str) + scores['mutant']
)

# Combine both receptors (shared column names get `_E2` / `_E3` suffixes), then attach the
# scores to the validation measurements; left joins keep every validation row.
func_scores_merged = func_scores.merge(
    func_scores_E3,
    on=['mutation'],
    how='left',
    suffixes=['_E2', '_E3'],
)
merged = concat.merge(func_scores_merged, on=['mutation'], how='left')

# Overwrite the WT rows' effects with a tiny non-zero value
# (chosen by the author so that WT can be drawn on a log scale).
is_wt_row = merged['mutation'] == 'WT'
merged.loc[is_wt_row, ['effect_E2', 'effect_E3']] = 0.0000001

#for column in merged.select_dtypes(include=['int64']).columns:
#    merged[column] = merged[column].astype(int)

### Now Plot Correlations

### E2 Correlations

In [None]:
# Linear regression of luciferase signal on DMS effect for EFNB2; only the correlation
# coefficient is used for the chart annotation.
# NOTE(review): r is computed on the untransformed luciferase values -- confirm that this is
# intended rather than a correlation with log10(luciferase).
regression_E2 = scipy.stats.linregress(
    x=merged['effect_E2'],
    y=merged['mean_luciferase_E2'],
)
slope, intercept, r_value, p_value, std_err = regression_E2
r_value = float(r_value)

# Sorting function to put 'WT' on top of the legend, followed by numerical order
def custom_sort_order(array):
    """Return the mutation labels ordered for the legend.

    Labels are sorted by the first run of digits they contain (e.g. 530 for
    'Q530F'); labels without digits sort as 0. If 'WT' is present, it is moved
    to the front of the result.

    Parameters
    ----------
    array : iterable of str
        Mutation labels.

    Returns
    -------
    list
        A new list with the labels in legend order.
    """
    def _site_number(label):
        # First run of digits in the label, or 0 when there is none
        match = re.search(r'\d+', label)
        if match is None:
            return 0
        return int(match.group())

    ordered = sorted(array, key=_site_number)

    if 'WT' not in ordered:
        return ordered

    # Pull 'WT' out of its sorted position and place it first
    ordered.remove('WT')
    return ['WT'] + ordered

# Ten-colour palette, written out by hand
category10_colors = [
    '#4E79A5', '#F18F3B', '#E0585B', '#77B7B2', '#5AA155',
    '#EDC958', '#AF7AA0', '#FE9EA8', '#9C7561', '#BAB0AC',
]

# One colour per unique mutation: black for the first legend entry, palette colours for the rest
mutation_labels_E2 = merged['mutation'].unique()
colors = ['black'] + category10_colors[:len(mutation_labels_E2) - 1]

# Encodings for the EFNB2 scatter plot
x_encoding_E2 = alt.X(
    "effect_E2:Q",
    title="DMS entry score",
    scale=alt.Scale(domain=[-4, 1]),
    axis=alt.Axis(values=[-4, -3, -2, -1, 0, 1], tickCount=6),
)
y_encoding_E2 = alt.Y(
    "mean_luciferase_E2",
    title="RLU/μL",
    scale=alt.Scale(type="log", base=10),
    axis=alt.Axis(format=".0e", grid=True, tickCount=4),  # scientific-notation tick labels
)
color_encoding_E2 = alt.Color(
    'mutation',
    title='Mutant',
    scale=alt.Scale(domain=custom_sort_order(mutation_labels_E2), range=colors),
)

# Scatter of luciferase signal against DMS entry score, one point per row of `merged`
corr_chart = (
    alt.Chart(merged, title=alt.Title('CHO-EFNB2', anchor='middle', fontSize=16))
    .mark_point(filled=True, size=80, opacity=0.8)
    .encode(
        x=x_encoding_E2,
        y=y_encoding_E2,
        color=color_encoding_E2,
        tooltip=['mutation', 'effect_E2', 'mean_luciferase_E2'],
    )
    .properties(width=alt.Step(20), height=alt.Step(20))
)

# Anchor point for the "r = ..." label: integer-truncated minimum effect and maximum signal
min_effect_E2 = int(merged['effect_E2'].min())
max_mean_luciferase_E2 = int(merged['mean_luciferase_E2'].max())

annotation_E2 = {
    'x': min_effect_E2,
    'y': max_mean_luciferase_E2,
    'text': f'r = {r_value:.2f}',
}
text = (
    alt.Chart({'values': [annotation_E2]})
    .mark_text(
        align='left',
        baseline='top',
        dx=-30,  # horizontal pixel offset of the label
        dy=-10,  # vertical pixel offset of the label
    )
    .encode(
        x=alt.X('x:Q'),
        y=alt.Y('y:Q'),
        text='text:N',
    )
)

# Layer the label on top of the scatter plot, then show and save the result
final_chart = corr_chart + text
func_score_E2_chart = final_chart
func_score_E2_chart.display()
func_score_E2_chart.save(func_score_E2_plot)

In [None]:
# Linear regression of luciferase signal on DMS effect for EFNB3; only the correlation
# coefficient is used for the chart annotation.
# NOTE(review): r is computed on the untransformed luciferase values -- confirm that this is
# intended rather than a correlation with log10(luciferase).
regression_E3 = scipy.stats.linregress(
    x=merged['effect_E3'],
    y=merged['mean_luciferase_E3'],
)
slope, intercept, r_value, p_value, std_err = regression_E3
r_value = float(r_value)

# Sorting function to put 'WT' on top of the legend, followed by numerical order
def custom_sort_order(array):
    """Return the mutation labels ordered for the legend.

    Labels are sorted by the first run of digits they contain (e.g. 530 for
    'Q530F'); labels without digits sort as 0. If 'WT' is present, it is moved
    to the front of the result.

    Parameters
    ----------
    array : iterable of str
        Mutation labels.

    Returns
    -------
    list
        A new list with the labels in legend order.
    """
    def _site_number(label):
        # First run of digits in the label, or 0 when there is none
        match = re.search(r'\d+', label)
        if match is None:
            return 0
        return int(match.group())

    ordered = sorted(array, key=_site_number)

    if 'WT' not in ordered:
        return ordered

    # Pull 'WT' out of its sorted position and place it first
    ordered.remove('WT')
    return ['WT'] + ordered

# Ten-colour palette, written out by hand
category10_colors = [
    '#4E79A5', '#F18F3B', '#E0585B', '#77B7B2', '#5AA155',
    '#EDC958', '#AF7AA0', '#FE9EA8', '#9C7561', '#BAB0AC',
]

# One colour per unique mutation: black for the first legend entry, palette colours for the rest
mutation_labels_E3 = merged['mutation'].unique()
colors = ['black'] + category10_colors[:len(mutation_labels_E3) - 1]

# Encodings for the EFNB3 scatter plot
x_encoding_E3 = alt.X(
    "effect_E3:Q",
    title="DMS entry score",
    scale=alt.Scale(domain=[-4, 1]),
    axis=alt.Axis(values=[-4, -3, -2, -1, 0, 1], tickCount=6),
)
y_encoding_E3 = alt.Y(
    "mean_luciferase_E3",
    title="RLU/μL",
    scale=alt.Scale(type="log", base=10),
    axis=alt.Axis(format=".0e", grid=True, tickCount=4),  # scientific-notation tick labels
)
color_encoding_E3 = alt.Color(
    'mutation',
    title='Mutant',
    scale=alt.Scale(domain=custom_sort_order(mutation_labels_E3), range=colors),
)

# Scatter of luciferase signal against DMS entry score, one point per row of `merged`
corr_chart = (
    alt.Chart(merged, title=alt.Title('CHO-EFNB3', anchor='middle', fontSize=16))
    .mark_point(filled=True, size=80, opacity=0.6)
    .encode(
        x=x_encoding_E3,
        y=y_encoding_E3,
        color=color_encoding_E3,
        tooltip=['mutation', 'effect_E3', 'mean_luciferase_E3'],
    )
    .properties(width=alt.Step(20), height=alt.Step(20))
)

# Anchor point for the "r = ..." label: integer-truncated minimum effect and maximum signal
min_effect_E3 = int(merged['effect_E3'].min())
max_mean_luciferase_E3 = int(merged['mean_luciferase_E3'].max())

annotation_E3 = {
    'x': min_effect_E3,
    'y': max_mean_luciferase_E3,
    'text': f'r = {r_value:.2f}',
}
text = (
    alt.Chart({'values': [annotation_E3]})
    .mark_text(
        align='left',
        baseline='top',
        dx=-30,  # horizontal pixel offset of the label
        dy=-10,  # vertical pixel offset of the label
    )
    .encode(
        x=alt.X('x:Q'),
        y=alt.Y('y:Q'),
        text='text:N',
    )
)

# Layer the label on top of the scatter plot, then show and save the result
final_chart = corr_chart + text
func_score_E3_chart = final_chart
func_score_E3_chart.display()
func_score_E3_chart.save(func_score_E3_plot)

### Now calculate entry of each mutant relative to WT

In [None]:
# Mean luciferase signal of the WT rows for each receptor. Grouping by `mutation` after
# filtering to WT yields a Series with a single entry labelled 'WT'.
mean_luc_E2 = merged.query('mutation == "WT"').groupby('mutation')['mean_luciferase_E2'].mean()
mean_luc_E3 = merged.query('mutation == "WT"').groupby('mutation')['mean_luciferase_E3'].mean()

# Express every measurement relative to the WT mean. The Series above is indexed by the
# label 'WT', so take its value by position with .iloc[0]; plain `[0]` relies on the
# integer-position fallback of Series.__getitem__, which pandas has deprecated.
merged['E2_relative'] = merged['mean_luciferase_E2'] / mean_luc_E2.iloc[0]
merged['E3_relative'] = merged['mean_luciferase_E3'] / mean_luc_E3.iloc[0]

# Copy of the table without the WT rows, for plotting the mutants only.
merged_wt_drop = merged[merged['mutation'] != 'WT']

In [None]:
def plot_functional_validations(df):
    """Display a per-mutant dot plot of entry relative to WT for both receptors.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'mutation', 'E2_relative' and 'E3_relative'.

    Returns
    -------
    The return value of ``chart.display()``.
    """
    # Long format: one row per (mutation, receptor column) with the value in 'effect'
    long_df = df.melt(
        id_vars='mutation',
        value_vars=['E2_relative', 'E3_relative'],
        var_name='type',
        value_name='effect',
    )

    # Receptor on an unlabelled x axis, relative entry on a log10 y axis, one facet per mutant
    receptor_x = alt.X(
        'type:N',
        title=None,
        axis=alt.Axis(labels=False, ticks=False, domain=True),
    )
    entry_y = alt.Y(
        'effect',
        title='Cell Entry of RBP Mutants Relative to WT',
        scale=alt.Scale(type="log", base=10),
        axis=alt.Axis(grid=True, tickCount=6),
    )
    mutant_facet = alt.Column(
        'mutation:N',
        title=None,
        header=alt.Header(labelFontSize=16, labelFont="Helvetica Light", labelOrient='bottom'),
    )

    chart = (
        alt.Chart(long_df)
        .mark_circle(filled=True, opacity=0.7, size=100)
        .encode(
            x=receptor_x,
            y=entry_y,
            color=alt.Color('type', title='Receptor'),
            column=mutant_facet,
        )
        .properties(height=300, width=100)
    )

    return chart.display()

# Plot relative entry for the validated mutants only (`merged_wt_drop` excludes the WT rows)
plot_functional_validations(merged_wt_drop)

In [None]:
# Place the E2 and E3 correlation charts side by side and write them to `corr_plots_combined`
combined_corr_chart = func_score_E2_chart | func_score_E3_chart
combined_corr_chart.save(corr_plots_combined)