## Make whole-protein figures for NS2B-NS3 dataset
Currently, the pipeline performs per-tile analysis. Here, I'll concatenate the amino acid preferences and mutational tolerance datasets from across the three tiles. I'll use these concatenated datasets to make plots for the paper. 

In [49]:
import os
import altair as alt

import matplotlib.pyplot as plt

import numpy as np

import pandas as pd


# Turn off Altair's max-rows check so charts can be built from large
# DataFrames without raising a MaxRowsError.
alt.data_transformers.disable_max_rows()


DataTransformerRegistry.enable('default')

In [5]:
# Input/output locations, relative to the notebook's directory.
datadir = '../data'
resultsdir = '../results'

# Names of the per-tile result subdirectories, in tile order.
tiles = [f'tile_{tile_number}' for tile_number in (1, 2, 3)]


In [52]:

# Load the per-tile amino-acid preferences and mutational effects, then
# combine the tiles into one table of each kind.
#
# The frames are collected in lists and concatenated once at the end. Growing
# a DataFrame with pd.concat inside the loop re-copies all previous rows on
# every iteration, and concatenating onto an empty pd.DataFrame() triggers a
# FutureWarning in recent pandas versions.
prefs_frames = []
muteffects_frames = []

for t in tiles:
    prefsfile = os.path.join(resultsdir, t, 'prefs', 'prefs_virus.csv')
    muteffectsfile = os.path.join(resultsdir, t, 'muteffects', 'virus_muteffects.csv')

    # load prefs
    prefs_df = pd.read_csv(prefsfile)
    prefs_frames.append(prefs_df)

    # load muteffects
    muteffects_df = pd.read_csv(muteffectsfile)
    muteffects_frames.append(muteffects_df)

# The `site` column is used exactly as read from the CSVs (no per-tile
# renumbering), and each tile keeps its own row index, as before.
prefs_concat = pd.concat(prefs_frames)
muteffects_concat = pd.concat(muteffects_frames)

In [83]:
# Wildtype map: the distinct (site, wildtype) pairs found in the
# mutational-effects table, with a fresh 0-based index.
(
    muteffects_concat
    .loc[:, ['site', 'wildtype']]
    .drop_duplicates()
    .reset_index(drop=True)
)

Unnamed: 0,site,wildtype
0,1,R
1,2,S
2,3,W
3,4,P
4,5,P
...,...,...
303,304,E
304,305,T
305,306,P
306,307,V


## Make prefs heatmap

In [53]:
# Reshape prefs from wide to long format: every non-`site` column becomes a
# row, labelled in `mutation`, with its value stored in `pref`.
prefs_concat_stacked = (
    prefs_concat
    .set_index('site')
    .stack(future_stack=True)
    # the stacked Series is unnamed, so name its values column here
    .reset_index(name='pref')
    # the stacked column labels come back as the unnamed index level
    .rename(columns={'level_1': 'mutation'})
)

prefs_concat_stacked

Unnamed: 0,site,mutation,pref
0,1,A,0.00265
1,1,C,0.00669
2,1,D,0.00694
3,1,E,0.00377
4,1,F,0.00770
...,...,...,...
6155,308,S,0.05766
6156,308,T,0.05049
6157,308,V,0.03815
6158,308,W,0.00945


In [93]:

## Make plot
# Draw the preferences as one heatmap panel per block of consecutive sites and
# stack the panels vertically in a single column.
charts = []

# site ranges shown in each panel (the upper bound of range() is exclusive)
ranges = [list(range(1,104)), list(range(104, 207)), list(range(207, 309))]

for r in ranges:

    # keep only the rows whose site belongs to this panel
    # (the former `data = prefs_concat_stacked` line was dead: it was
    # overwritten immediately by this assignment)
    data = prefs_concat_stacked[prefs_concat_stacked['site'].isin(r)]

    heatmap = alt.Chart(data).mark_rect(stroke='black').encode(
        x=alt.X('site:O',
                axis=alt.Axis(labelFontSize=12, title="site")),
        y=alt.Y('mutation:O',
                axis=alt.Axis(labelFontSize=12, title="mutation")),
        color=alt.Color('pref:Q',
                        scale=alt.Scale(scheme='lightgreyteal')),
    )

    charts.append(heatmap)

# Last expression of the cell: the combined figure is rendered as its output.
(alt.concat(*charts, title = '', columns = 1)
 .configure_title(fontSize=18)
 .configure_legend(titleFontSize=20,
                   labelFontSize = 18,
                   strokeColor='gray',
                   padding=10,
                   cornerRadius=10,
                   labelLimit = 500)
)

## Make mutational effects heatmap

In [84]:
# no need to reshape data: muteffects_concat already has one row per
# site/mutant pair, with the value to plot in the `effect` column
muteffects_concat

Unnamed: 0,site,wildtype,mutant,mutation,effect,log2effect
0,1,R,A,R1A,0.003142,-8.31430
1,1,R,C,R1C,0.007931,-6.97830
2,1,R,D,R1D,0.008227,-6.92530
3,1,R,E,R1E,0.004469,-7.80570
4,1,R,F,R1F,0.009128,-6.77540
...,...,...,...,...,...,...
2035,308,E,S,E308S,1.259500,0.33285
2036,308,E,T,E308T,1.102900,0.14128
2037,308,E,V,E308V,0.833330,-0.26303
2038,308,E,W,E308W,0.206420,-2.27630


In [95]:
# Make plot: mutational effects drawn as one heatmap panel per block of
# consecutive sites, with the panels stacked in a single column.

# define site ranges for heatmaps (the upper bound of range() is exclusive)
ranges = [list(range(1,104)), list(range(104, 207)), list(range(207, 309))]

# Build one heatmap per site range.
# NOTE(review): 'redblue' is a diverging scheme but the log scale has no
# explicit midpoint, so the neutral colour may not fall at effect = 1 --
# confirm whether that is intended for the paper figure.
charts = [
    alt.Chart(muteffects_concat[muteffects_concat['site'].isin(site_block)])
    .mark_rect(stroke='black')
    .encode(
        x=alt.X('site:O', axis=alt.Axis(labelFontSize=12, title="site")),
        y=alt.Y('mutant:O', axis=alt.Axis(labelFontSize=12, title="mutation")),
        color=alt.Color('effect:Q',
                        scale=alt.Scale(type='log', scheme='redblue')),
    )
    for site_block in ranges
]

# Last expression of the cell: the combined figure is rendered as its output.
(
    alt.concat(*charts, title='', columns=1)
    .configure_title(fontSize=18)
    .configure_legend(
        titleFontSize=20,
        labelFontSize=18,
        strokeColor='gray',
        padding=10,
        cornerRadius=10,
        labelLimit=500,
    )
)