In [None]:
import pandas as pd
from matplotlib import pyplot as pl

import altair as alt
_ = alt.data_transformers.disable_max_rows()

In [None]:
# Load the escape-prediction table for 17C7.
rabishield_escape_predictions = pd.read_csv(
    '17C7_filtered_mut_effects.csv_escape_prediction.csv'
)

# The raw strain name is awkward to work with, so add a shorter 'accession'
# label: the second underscore-delimited token of the strain name.
rabishield_escape_predictions['accession'] = (
    rabishield_escape_predictions['strain'].apply(lambda strain: strain.split('_')[1])
)



def format_mut_numbers(mut_list_input, offset=19):
    """Renumber the measured substitutions found in the first row of a table.

    The substitution positions in the data are shifted relative to the
    numbering wanted for display (by 19 by default), so each position is
    reduced by ``offset`` to make mutants easier to identify, e.g. when
    mousing over points in an altair chart.

    Parameters
    ----------
    mut_list_input
        Table-like object (e.g. a DataFrame) with a
        'measured_aa_substitutions' column whose entries are
        whitespace-separated substitutions such as ``'A120G T350K'``.
        Only the FIRST row of that column is processed.
    offset : int, default 19
        Amount subtracted from every position.

    Returns
    -------
    list of str
        The renumbered substitutions, in their original order. An empty list
        is returned when there are no rows or the first entry is not a string
        (e.g. NaN from an empty CSV field).

    Raises
    ------
    ValueError
        If a substitution does not have an integer between its first and
        last characters.
    """
    # Step 1: fetch the raw substitution text of the first row.
    raw_entries = mut_list_input['measured_aa_substitutions'].tolist()
    if not raw_entries:
        return []
    raw_muts = raw_entries[0]

    # A missing entry (NaN/None) has no substitutions to renumber.
    if not isinstance(raw_muts, str):
        return []

    # Step 2: renumber every substitution. split() with no argument also
    # tolerates repeated spaces and yields nothing for an empty string.
    mut_list_renumbered = []
    for mut in raw_muts.split():
        # Everything between the first and last character is the position.
        position = int(mut[1:-1])
        position_corrected = str(position - offset)

        # Re-assemble: original residue + corrected position + new residue.
        mut_list_renumbered.append(mut[0] + position_corrected + mut[-1])

    # Step 3: hand the renumbered list back to the caller.
    return mut_list_renumbered


#input
#df = dataframe with escape predictions of circulating strains
#antibody = string value of antibody studied


def format_df(df, antibody):
    """Return a copy of ``df`` prepared for pairwise antibody plotting.

    Two changes are made to the copy; the input frame is left untouched:

    * an 'accession' column is added, holding the second
      underscore-delimited token of each value in 'strain';
    * the 'pred_phenotype' column is renamed to
      'pred_phenotype_<antibody>' so tables for different antibodies can
      sit side by side.

    Parameters
    ----------
    df : DataFrame with escape predictions of circulating strains.
    antibody : str, name of the antibody studied.
    """
    def _second_token(strain_name):
        # e.g. 'a_b_c' -> 'b'
        return strain_name.split('_')[1]

    suffixed_column = 'pred_phenotype_' + antibody

    formatted = df.copy()
    formatted['accession'] = formatted['strain'].apply(_second_token)
    formatted = formatted.rename(columns={"pred_phenotype": suffixed_column})
    return formatted





In [None]:
# The formatting step appeared to work, so repeat it for every antibody.
antibody_list = [
    '17C7', 'CR4098', 'CR57', 'CTB012',
    'RVA122', 'RVC20', 'RVC58', 'RVC68',
]

# Maps antibody name -> formatted escape-prediction DataFrame.
dataframes = {}

for antibody in antibody_list:
    # Each antibody's predictions are read from a CSV named after it.
    file_name = '{}_filtered_mut_effects.csv_escape_prediction.csv'.format(antibody)

    # The antibody name doubles as the dictionary key.
    name = antibody

    # Read the raw table, then format it (accession label + suffixed column).
    df = format_df(pd.read_csv(file_name), antibody)
    dataframes[name] = df



In [None]:


# Give each antibody's formatted table its own name for the merge below.
df_17C7 = dataframes['17C7']
df_RVC20, df_RVC58 = dataframes['RVC20'], dataframes['RVC58']
df_CR4098, df_CR57 = dataframes['CR4098'], dataframes['CR57']
df_RVC68, df_CTB012 = dataframes['RVC68'], dataframes['CTB012']
df_RVA122 = dataframes['RVA122']


In [None]:
# Merge the per-antibody tables into one table, one join at a time.
dfs = [
    df_17C7, df_RVC20, df_RVC58, df_CR4098,
    df_CR57, df_RVC68, df_CTB012, df_RVA122,
]

# Columns that must agree between two tables for their rows to be joined.
merge_keys = [
    'strain',
    'all_aa_substitutions',
    'measured_aa_substitutions',
    'unmeasured_aa_substitutions',
    'disallowed_aa_substitutions',
    'n_disallowed_aa_substitutions',
    'accession',
]

# Start from the first table and inner-join each remaining one onto it, so
# only rows whose key columns match in every table survive.
merged_df = dfs[0]
for df in dfs[1:]:
    merged_df = merged_df.merge(df, on=merge_keys, how='inner')



In [None]:
# Display the merged table as the cell output.
merged_df

In [None]:
# NOTE(review): this repeats the altair import and max-rows setting from the
# notebook's first cell; it is redundant when the notebook runs top to bottom.
import altair as alt
_ = alt.data_transformers.disable_max_rows()

In [None]:
# Scatter plot of predicted phenotype: RVC20 on x, CR57 on y.
# Hovering a point shows its accession, strain and measured substitutions.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_RVC20',
        y='pred_phenotype_CR57',
        tooltip=['accession', 'strain', 'measured_aa_substitutions'],
    )
    .properties(title='RVC20 versus CR57')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Show the merged rows for a handful of accessions.
subset = ['2718', 'JQ685952', 'KJ174664', 'OR500222']

merged_df.loc[merged_df['accession'].isin(subset)]

In [None]:
# Scatter plot of predicted phenotype: RVC20 on x, RVC58 on y.
# Hovering a point shows its accession, strain and measured substitutions.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_RVC20',
        y='pred_phenotype_RVC58',
        tooltip=['accession', 'strain', 'measured_aa_substitutions'],
    )
    .properties(title='RVC20 versus RVC58')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Show the merged rows whose accession label is '2008-XX-XX'.
merged_df.loc[merged_df['accession'] == '2008-XX-XX']

In [None]:
# Scatter plot of predicted phenotype: 17C7 on x, RVC58 on y.
# Hovering a point shows its accession, strain and measured substitutions.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_17C7',
        y='pred_phenotype_RVC58',
        tooltip=['accession', 'strain', 'measured_aa_substitutions'],
    )
    .properties(title='RVC58 versus 17C7')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Scatter plot of predicted phenotype: 17C7 on x, CR57 on y.
# Hovering a point shows its accession, strain and measured substitutions.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_17C7',
        y='pred_phenotype_CR57',
        tooltip=['accession', 'strain', 'measured_aa_substitutions'],
    )
    .properties(title='CR57 versus 17C7')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Show the merged rows whose accession label is '2008-XX-XX'.
merged_df.loc[merged_df['accession'] == '2008-XX-XX']


In [None]:
# Scatter plot of predicted phenotype: 17C7 on x, RVA122 on y.
# Hovering a point shows its accession and strain.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_17C7',
        y='pred_phenotype_RVA122',
        tooltip=['accession', 'strain'],
    )
    .properties(title='RVA122 versus 17C7')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Scatter plot of predicted phenotype: CTB012 on x, RVA122 on y.
# Hovering a point shows its accession and strain.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_CTB012',
        y='pred_phenotype_RVA122',
        tooltip=['accession', 'strain'],
    )
    .properties(title='CTB012 versus RVA122')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Scatter plot of predicted phenotype: CR4098 on x, 17C7 on y.
# Hovering a point shows its accession and strain.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_CR4098',
        y='pred_phenotype_17C7',
        tooltip=['accession', 'strain'],
    )
    .properties(title='CR4098 vs 17C7')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Scatter plot of predicted phenotype: CR4098 on x, CR57 on y.
# Hovering a point shows its accession and strain.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_CR4098',
        y='pred_phenotype_CR57',
        tooltip=['accession', 'strain'],
    )
    .properties(title='CR4098 vs CR57')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Accessions to inspect for CR4098.
CR4098_subset = ['4125', 'JQ685952', '0000069', 'KJ174664', '0625', '2718']

# isin() already yields a boolean mask, so it can select the rows directly.
merged_df.loc[merged_df['accession'].isin(CR4098_subset)]

In [None]:
# Scatter plot of predicted phenotype: RVC58 on x, RVC20 on y.
# Hovering a point shows its accession and strain.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_RVC58',
        y='pred_phenotype_RVC20',
        tooltip=['accession', 'strain'],
    )
    .properties(title='RVC58 vs RVC20')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Scatter plot of predicted phenotype: CR57 on x, RVC20 on y.
# Hovering a point shows its accession and strain.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_CR57',
        y='pred_phenotype_RVC20',
        tooltip=['accession', 'strain'],
    )
    .properties(title='CR57 vs RVC20')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Show the merged rows whose accession label is '0001663'.
merged_df.loc[merged_df['accession'] == '0001663']

In [None]:
# Scatter plot of predicted phenotype: CR4098 on x, 17C7 on y.
# Hovering a point shows its accession and strain.
# NOTE(review): an earlier cell in this notebook draws the same chart.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_CR4098',
        y='pred_phenotype_17C7',
        tooltip=['accession', 'strain'],
    )
    .properties(title='CR4098 vs 17C7')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Show the merged rows whose accession label is '0625'.
merged_df.loc[merged_df['accession'] == '0625']

In [None]:
# Scatter plot of predicted phenotype: CTB012 on x, RVC68 on y.
# Hovering a point shows its accession and strain.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_CTB012',
        y='pred_phenotype_RVC68',
        tooltip=['accession', 'strain'],
    )
    .properties(title='CTB012 vs RVC68')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Pull out the merged rows for the RVC68 strains of interest.
candidates = ['DQ875051', 'KM492765', 'LC071944']
merged_df.loc[merged_df['accession'].isin(candidates)]



In [None]:
# Pull out the merged rows for the CTB012 strains of interest.
# The accessions must be string literals: as bare identifiers they are
# undefined names and the cell raises NameError.
ctb012_candidates = ['KF620489', 'KC792208']
merged_df[merged_df['accession'].isin(ctb012_candidates)]

In [None]:
# Scatter plot of predicted phenotype: RVC20 on x, RVA122 on y.
# Hovering a point shows its accession and strain.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_RVC20',
        y='pred_phenotype_RVA122',
        tooltip=['accession', 'strain'],
    )
    .properties(title='RVC20 vs RVA122')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Scatter plot of predicted phenotype: RVC20 on x, RVC58 on y.
# Hovering a point shows its accession and strain.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_RVC20',
        y='pred_phenotype_RVC58',
        tooltip=['accession', 'strain'],
    )
    .properties(title='RVC20 vs RVC58')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
# Scatter plot of predicted phenotype: CR57 on x, RVA122 on y.
# Hovering a point shows its accession and strain.
scatter_plot = (
    alt.Chart(merged_df)
    .mark_circle(size=60)
    .encode(
        x='pred_phenotype_CR57',
        y='pred_phenotype_RVA122',
        tooltip=['accession', 'strain'],
    )
    .properties(title='CR57 vs RVA122')
)

# Render the chart as the cell output.
scatter_plot

In [None]:
#collect strain information

accession = ['