Compare delta

In [1]:
import pandas as pd
import os
import plotly.graph_objects as go
import plotly.express as px
import numpy as np

# Relative path two directory levels above the notebook's working directory.
project_folder = os.path.join(os.pardir, os.pardir)

In [2]:
# Load the combined table from <project_folder>/data/final/all.csv and preview its first row.
all_csv_path = os.path.join(project_folder, 'data', 'final', 'all.csv')
df = pd.read_csv(all_csv_path)
df.head(1)

Unnamed: 0,state,p_sequence,p_accession,date,count,n_accession,n_sequence,clade,timespan
0,MA,"""MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QTP71261,2020,2,MW885877,GGTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTA...,20A,1


In [3]:
def get_base(df, clade = '19A'):
    """Return the row of ``df`` with the largest ``count`` within one clade.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``'clade'`` and ``'count'``.
    clade : str, optional
        Clade to pick the baseline from; defaults to ``'19A'``.

    Returns
    -------
    pandas.Series
        The first row (in ``df`` order) holding the maximum ``count`` for
        ``clade``, with the original column names of ``df``.

    Raises
    ------
    ValueError
        If ``df`` has no rows for ``clade``.
    """
    candidates = df[df['clade'] == clade].reset_index(drop = True)
    if candidates.empty:
        raise ValueError("no rows found for clade {!r}".format(clade))
    # idxmax picks the row directly; no groupby/merge round trip, so no
    # columns come back renamed with _x/_y suffixes.
    return candidates.loc[candidates['count'].idxmax()]
    
# Reference row taken from clade 19A; it is later passed as the baseline of the sequence comparison.
top_19A = get_base(df)

In [4]:
def get_deltas(df, delta):
    """Select the rows of ``df`` that belong to Delta clades.

    With ``delta == 'all'`` the rows of the three Delta clades (21A, 21I,
    21J) are returned in their existing order.  Any other value is treated
    as a single clade name, and the matching rows are returned sorted by
    ``count`` in descending order.
    """
    if delta != 'all':
        single_clade = df[df['clade'] == delta]
        return single_clade.sort_values(by = 'count', ascending = False)
    delta_clades = ['21A (Delta)', '21I (Delta)', '21J (Delta)']
    return df[df['clade'].isin(delta_clades)]

In [5]:
# Every row whose clade is one of the three Delta clades (21A, 21I, 21J).
df_all_delta = get_deltas(df, 'all')

In [6]:
# Marker colour used for each Delta clade.
colors = {'21A (Delta)': px.colors.qualitative.Safe[0],
          '21I (Delta)': px.colors.qualitative.Safe[1],
          '21J (Delta)': px.colors.qualitative.Safe[2]}

# State abbreviations grouped by region.
# Fixed: Tennessee was listed as "TH", so rows with state "TN" got no region.
# NOTE(review): "DC" and US territories are not listed and will map to NaN -
# confirm whether the data contains them.
US_regions = {"West": ["WA", "OR", "ID", "MT", "WY", "CA", "NV", "UT", "CO", "AZ", "NM"],
              "Midwest": ["ND", "MN", "WI", "MI", "OH", "IN", "IL", "MO", "IA", "SD", "NE", "KS"],
              "Northeast": ["ME", "VT", "NH", "MA", "CT", "RI", "NY", "NJ", "PA"],
              "South": ["TX", "OK", "AR", "LA", "MS", "TN", "KY", "WV", "MD", "DE", "VA", "NC", "SC", "GA", "AL", "FL"]}

# Inverted lookup: state abbreviation -> region name.
mapping = {state: region for region, states in US_regions.items() for state in states}

# assign() returns a new frame, so re-running this cell is idempotent and the
# frame returned by get_deltas is not modified in place.
df_all_delta = df_all_delta.assign(
    region = df_all_delta['state'].map(mapping),
    color = df_all_delta['clade'].map(colors),
)

In [7]:
# One frame per Delta clade; get_deltas sorts each by count, largest first.
df_21A, df_21I, df_21J = (
    get_deltas(df_all_delta, clade_name)
    for clade_name in ('21A (Delta)', '21I (Delta)', '21J (Delta)')
)

bubble map
x axis - date emerged in US (timeline)

y axis - state
size -  count
color - which delta clade

In [17]:
def add_clade_trace(df, fig, name):
    """Add one bubble-scatter trace for a single clade to ``fig``.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to plot.  Reads the columns ``'date'`` (x), ``'region'`` and
        ``'state'`` (two-level y category), ``'count'`` (marker size),
        ``'color'`` (marker colour) and ``'p_accession'`` (hover title).
    fig : plotly.graph_objects.Figure
        Figure that receives the trace; modified in place.
    name : str
        Legend name of the trace, also shown in the hover box.
    """
    # Per-point hover payload: column 0 is the count, column 1 the accession.
    hover_data = np.stack((df['count'], df['p_accession']), axis = -1)
    hover_text = (
        "<b>%{customdata[1]}</b><br>"
        "%{customdata[0]} reports hereafter<br>"  # typo fix: was "hear after"
        "%{meta[0]}"
        "<extra>%{y[0]}<br>%{y[1]}<br>%{x}<br></extra>"
    )
    fig.add_trace(
        go.Scatter(
            mode = 'markers',
            x = df["date"],
            y = [df['region'], df["state"]],
            showlegend = True,
            name = name,
            meta = [name],
            customdata = hover_data,
            hovertemplate = hover_text,
            marker = dict(
                size = df["count"],
                sizeref = 2,
                sizemin = 3,
                color = df["color"],
                opacity = 0.75),
        ))

fig = go.Figure()

# One trace per clade, added in the order 21J, 21I, 21A
# (the legend below uses traceorder "reversed").
for clade_frame, clade_label in ((df_21J, "21J"), (df_21I, "21I"), (df_21A, "21A")):
    add_clade_trace(clade_frame, fig, clade_label)

fig.update_layout(
    title = {'text': "Delta Emergence in the USA", 'x': 0.5},
    plot_bgcolor = 'white',
    legend = {'itemsizing': 'constant', 'traceorder': "reversed"},
    height = 800,
)

# Two-level categorical y axis: region on the outer level, state on the inner one.
fig.update_yaxes(
    title = {'text': 'State'},
    type = 'multicategory',
    categoryorder = 'array',
    categoryarray = ['West', 'Midwest', 'Northeast', 'South'],
    tickson = 'labels',
    showdividers = True,
    showline = True,
    gridcolor = 'lightgrey',
)
fig.update_xaxes(title = {'text': 'Date'})

# Mode-bar buttons removed from the rendered figure.
hidden_buttons = ['zoom2d', 'zoomIn2d', 'zoomOut2d', 'autoScale2d', 'lasso2d', 'select2d']
config = {
    'scrollZoom': False,
    'doubleClick': 'reset',
    'displayModeBar': True,
    'modeBarButtonsToRemove': hidden_buttons,
}
fig.show(config = config)
fig.write_html('../../visualizations/delta_emergence.html', config = config)

Dates shown as June 1st may actually fall later in the month (the day is missing from the record). All three Delta clades emerged in the US around the same time (beginning of July). 21A seems to have the most cases, and the two sequences that are seen most often were first recorded in NJ and Florida. There looks to be one 21J sequence with more counts.

Want to compare the actual sequence next

In [9]:
# Preview the ten Delta rows with the highest count.
df_all_delta.sort_values(by = 'count', ascending = False).head(10)

Unnamed: 0,state,p_sequence,p_accession,date,count,n_accession,n_sequence,clade,timespan,region,color
7531,NJ,"""MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRS...",QZD09591,2021-07-01,750,MZ849750,CTGCATGCTTAGTGCACTCACGCAGTATAATTAATAACTAATTACT...,21A (Delta),11,Northeast,"rgb(136, 204, 238)"
7608,FL,"""MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRS...",QYJ42034,2021-07-02,393,MZ705184,CCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAA...,21A (Delta),11,South,"rgb(136, 204, 238)"
7702,LA,"""MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRS...",QXO37508,2021-07-05,129,MZ568101,AGGTTTATACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCT...,21J (Delta),11,South,"rgb(221, 204, 119)"
6841,VA,"""MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRS...",QXL91697,2021-06,95,MZ542744,CTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGT...,21A (Delta),10,South,"rgb(136, 204, 238)"
7622,TX,"""MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRS...",QYB96653,2021-07-02,58,MZ655453,AGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTCACTC...,21A (Delta),11,South,"rgb(136, 204, 238)"
7607,FL,"""MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRS...",QYJ41843,2021-07-02,57,MZ705167,ACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACT...,21A (Delta),11,South,"rgb(136, 204, 238)"
7542,TX,"""MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRS...",QYP52957,2021-07-01,33,MZ742858,ACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACT...,21A (Delta),11,South,"rgb(136, 204, 238)"
7704,AL,"""MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRS...",QZD01238,2021-07-06,29,MZ848991,ATAACTAATTACTGTCGTTGACAGGACACGAGTAACTCGTCTATCT...,21A (Delta),11,South,"rgb(136, 204, 238)"
7749,LA,"""MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRS...",QXQ22426,2021-07-06,29,MZ577498,GGTTTATACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCTT...,21J (Delta),11,South,"rgb(221, 204, 119)"
7908,OH,"""MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRS...",QYL32654,2021-07-11,28,MZ714377,ATTAAAGGTTTATACCTTCCCAGGTAACAAACCAACCAACTTTCGA...,21A (Delta),11,Midwest,"rgb(136, 204, 238)"


In [10]:
from Bio import Align
from Bio.Align import substitution_matrices
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

In [11]:
def make_aligner():
    """Create the pairwise aligner used to compare spike protein sequences.

    Gap scores follow https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7314508/ :
    gaps in the target are made prohibitively expensive (-1000), while gaps
    in the query cost -20 to open and -4 to extend.  Substitutions are scored
    with the BLOSUM62 matrix.

    Returns
    -------
    Bio.Align.PairwiseAligner
    """
    aligner = Align.PairwiseAligner(
        match = 4,
        mismatch = -1,
        target_open_gap_score = -1000,
        target_extend_gap_score = -1000,
        query_open_gap_score = -20,
        query_extend_gap_score = -4,
    )
    # NOTE(review): assigning a substitution matrix presumably supersedes the
    # match/mismatch scores given above - confirm against the Biopython docs.
    # The earlier argument-less substitution_matrices.load() call was dropped:
    # its return value was discarded.
    aligner.substitution_matrix = substitution_matrices.load("BLOSUM62")
    return aligner

def get_relevent_seqs(baseline, other):
    """Align the protein sequence of ``other`` (query) against ``baseline`` (target).

    Both arguments are indexed with ``'p_sequence'``; the first and last
    character of each value are dropped before aligning (presumably the
    quote characters stored around the sequence - confirm against the data).
    A fresh aligner is built on every call.  Returns whatever the aligner's
    ``align`` method returns.
    """
    pairwise = make_aligner()
    target_seq = baseline['p_sequence'][1:-1]
    query_seq = other['p_sequence'][1:-1]
    return pairwise.align(target_seq, query_seq)

In [12]:
   
def calc_traces_labels_and_subtitle(baseline, other, y_height, showlegend = False, clades_in_legend = [], debug = False):
    """Build the Plotly traces that draw one aligned sequence as a horizontal row.

    The first alignment returned by ``get_relevent_seqs(baseline, other)`` is
    rendered with ``format()``; its first three lines are read as the target
    string, the match string and the query string.  Every position whose match
    character is ``'.'``, ``':'`` or ``'-'`` is recorded as a change and
    plotted as a marker on the row.

    Parameters
    ----------
    baseline, other : mapping-like rows
        Passed to ``get_relevent_seqs``.  ``other`` must also provide
        ``'clade'``, ``'count'`` and ``'color'``.  ``'color'`` is sliced with
        ``[4:-1]`` and wrapped in ``rgba(...)``, so it is presumably an
        ``"rgb(r, g, b)"`` string.
    y_height : number
        Vertical position of the row.
    showlegend : bool, optional
        Whether the subunit lines and change markers get legend entries.
    clades_in_legend : container of str, optional
        Clades that already have a legend entry; the clade bar only gets one
        when ``other['clade']`` is not in it.  Only read, never modified.
    debug : bool, optional
        Print the alignment length and the number of changes.

    Returns
    -------
    dict
        ``data``: list of eight traces (base line, clade bar, S1 line, S2
        line, insertion, deletion, semi-conservative and conservative
        substitution markers); ``score_label``: annotation dict holding the
        alignment score as a percentage; ``changes``: set of change labels.
    """
    alignment = get_relevent_seqs(baseline, other)[0]
    rows = alignment.format().split('\n')
    target, match_type, query = rows[0], rows[1], rows[2]
    length = len(target)
    if debug:
        print(length)

    # Each list starts with None so the trace still shows up in the legend
    # when there is no data for it.
    x_insert = [None]
    x_delete = [None]
    x_sub_low = [None]
    x_sub_high = [None]
    labels = {'insert': [None], 'delete': [None], 'low': [None], 'high': [None]}

    for i in range(length):
        position = i + 1  # reported positions are 1-based
        symbol = match_type[i]
        if symbol == '.':
            # substitution, low conservation
            x_sub_low.append(position)
            labels['low'].append(target[i] + str(position) + query[i])
        elif symbol == ':':
            # substitution, high conservation
            x_sub_high.append(position)
            labels['high'].append(target[i] + str(position) + query[i])
        elif symbol == '-':
            if target[i] == '-':
                # insertion: gap in the target
                x_insert.append(position)
                labels['insert'].append(str(position) + query[i])
            else:
                # deletion (the algorithm is not allowed to match gap to gap)
                x_delete.append(position)
                labels['delete'].append(target[i] + str(position))

    flat_labels = set(labels['insert'] + labels['delete'] + labels['low'] + labels['high'])
    flat_labels.discard(None)

    # The leading None placeholder is not a change, hence the -1.
    n_insert = len(x_insert) - 1
    n_delete = len(x_delete) - 1
    n_sub_low = len(x_sub_low) - 1
    n_sub_high = len(x_sub_high) - 1

    if debug:
        # Same output as before: an exact match is reported as 1, not 0.
        print((n_insert + n_delete + n_sub_low + n_sub_high) or 1)

    # NOTE(review): 6722 is presumably the best achievable score (the baseline
    # aligned against itself) - confirm how this constant was derived.
    score_percentage = alignment.score / 6722

    # The clade bar only gets a legend entry the first time its clade appears.
    show_clade_in_legend = other["clade"] not in clades_in_legend

    def change_markers(xs, texts, name, color, symbol, **extra):
        """Marker trace for one kind of sequence change on the secondary y axis."""
        return go.Scatter(
            mode = "markers",
            x = xs,
            y = [y_height] * len(xs),
            text = texts,
            name = name,
            marker = dict(size = 6, color = color, symbol = symbol),
            legendgroup = 'changes',
            showlegend = showlegend,
            hoverinfo = 'text',
            visible = True,
            yaxis = 'y2',
            **extra
        )

    traces = [
        go.Scatter(  # base gray line
            mode = "lines",
            x = [1, length],
            y = [y_height, y_height],
            name = "sequence line",
            line = {'color': 'gray'},
            showlegend = False,
            hoverinfo = 'skip',
            visible = True,
            yaxis = 'y'
        ),
        go.Bar(  # translucent clade-coloured band behind the row
            name = other["clade"],
            y = [y_height],
            x = [1278],
            orientation = 'h',
            yaxis = 'y',
            marker_color = "rgba(" + other['color'][4:-1] + ", 0.15)",
            showlegend = show_clade_in_legend,
            legendgroup = "Clade",
            legendgrouptitle_text = "Clade",
            legendgrouptitle_font_size = 14,
            visible = True,
            customdata = pd.DataFrame(
                data = [[other['clade'], other['count'], n_insert, n_delete, n_sub_low, n_sub_high]],
                columns = ['clade', 'count', 'insertions', 'deletions', 'sublow', 'subhigh']),
            hovertemplate = "<b>%{y}</b><br>Clade: %{customdata[0]}<br>Count: %{customdata[1]}<br><br>%{customdata[2]} inserts<br>%{customdata[3]} deletions<br>%{customdata[4]} semi-conservative substitutions<br>%{customdata[5]} conservative substitutions<extra></extra>"
        ),
        go.Scatter(  # S1 line
            mode = "lines",
            x = [14, 685],
            y = [y_height, y_height],
            name = "S1",
            line = {'color': 'plum', 'width': 3},
            legendgroup = 'Subunit',
            showlegend = showlegend,
            legendgrouptitle_text = "Subunit",
            legendgrouptitle_font_size = 14,
            hoverinfo = 'skip',
            visible = True,
            yaxis = 'y2'
        ),
        go.Scatter(  # S2 line
            mode = "lines",
            x = [686, 1273],
            y = [y_height, y_height],
            name = "S2",
            line = {'color': 'pink', 'width': 3},
            legendgroup = 'Subunit',
            showlegend = showlegend,
            hoverinfo = 'skip',
            visible = True,
            yaxis = 'y2'
        ),
        change_markers(x_insert, labels['insert'], "insertion", 'orange', 'square',
                       legendgrouptitle_text = "Sequence Difference",
                       legendgrouptitle_font_size = 14),
        change_markers(x_delete, labels['delete'], "deletion", 'red', 'square'),
        change_markers(x_sub_low, labels['low'], "substitution, semi-conservative", 'blue', 'circle'),
        change_markers(x_sub_high, labels['high'], "substitution, conservative", 'green', 'circle'),
    ]

    score_label = {'x': 1300, 'y': y_height, 'text': "{:.2f}%".format(score_percentage * 100),
                   'font': {'size': 14}, 'showarrow': False, 'xanchor': 'left', 'yanchor': 'middle'}
    return dict(data = traces, score_label = score_label, changes = flat_labels)
        


In [13]:

def plot_alignment_plotly(df_all_delta, top_19A, save = False, debug = False):
    """Plot the highest-count Delta sequences aligned against a baseline row.

    The (up to) ten rows of ``df_all_delta`` with the largest ``count`` are
    each passed to ``calc_traces_labels_and_subtitle`` together with
    ``top_19A``.  Rows are stacked top-down in order of decreasing count.
    Changes shared by every plotted sequence are listed in an annotation and
    marked with a vertical orange line.

    Parameters
    ----------
    df_all_delta : pandas.DataFrame
        Needs the columns ``'count'``, ``'clade'`` and ``'p_accession'`` plus
        whatever ``calc_traces_labels_and_subtitle`` reads from a row.
    top_19A : mapping-like row
        Baseline the sequences are compared with.
    save : bool, optional
        Also write ``compare_delta.html`` (wrapped in ``<center>``) and
        ``compare_delta.svg`` to ``../../visualizations``.
    debug : bool, optional
        Forwarded to ``calc_traces_labels_and_subtitle``.
    """
    df = df_all_delta.sort_values('count', ascending = False)[0:10]
    n_rows = len(df)  # may be fewer than 10

    # build traces and tick labels
    traces = []
    accessions = []
    score_texts = []
    clades = []
    y_positions = []
    shared_changes = None
    for i in range(n_rows):
        row = df.iloc[i]
        y_height = 9.5 - i  # first (largest count) row is drawn at the top
        r = calc_traces_labels_and_subtitle(top_19A, row, y_height, showlegend = (i == 0),
                                            clades_in_legend = clades, debug = debug)
        clades.append(row['clade'])
        traces.extend(r['data'])
        accessions.append(row['p_accession'])
        score_texts.append(r['score_label']['text'])
        y_positions.append(y_height)
        if shared_changes is None:
            shared_changes = set(r['changes'])
        else:
            shared_changes = shared_changes.intersection(r['changes'])

    # Tick values are listed bottom-up, so BOTH label lists are reversed with
    # them.  Previously only the accessions were reversed, which put the
    # alignment scores next to the wrong rows.
    tickvals = y_positions[::-1]
    accession_ticks = accessions[::-1]
    score_ticks = score_texts[::-1]

    def get_index_from_label(label):
        """Return the position embedded in a change label such as 'T19R', 'E156' or '157A'."""
        return int(''.join(ch for ch in label if ch.isdigit()))

    # sort the changes by position
    shared_changes = sorted(shared_changes or (), key = get_index_from_label)
    if len(shared_changes) < 4:
        shared_changes_str = ', '.join(shared_changes)
    else:
        # three labels per line, every line indented under the heading
        line_prefix = '<br>' + "        "
        groups = [', '.join(shared_changes[k:k + 3]) for k in range(0, len(shared_changes), 3)]
        shared_changes_str = line_prefix + (',' + line_prefix).join(groups)

    fig = go.Figure(traces)

    for change in shared_changes:
        x_pos = get_index_from_label(change)
        fig.add_shape(
            type = "line",
            x0 = x_pos, y0 = 0.1, x1 = x_pos, y1 = 9.9,
            line = dict(color = "orange", width = 1),
            layer = 'below',
        )

    fig.update_xaxes(
        range = [-15, 1288],
        showgrid = False,  # thin lines in the background
        showspikes = True,  # spike line draws a vertical line when hovering on a substitution/insertion/deletion
        spikemode = 'across',
        spikethickness = 1,
        title = "Index"
    )

    fig.update_yaxes(
        range = [0, 10],
        showgrid = False,  # thin lines in the background
    )

    fig.update_layout(
        plot_bgcolor = 'white',
        height = 700,
        width = 1200,
        barmode = 'overlay',
        title = "Top 10 Delta Sequences: Spike Protein Comparison to Top 19A Sequence",
        legend = dict(
            yanchor = "top",
            y = 1,
            xanchor = "left",
            x = 1.2,
            font_size = 14
        ),
        bargap = 0,
        annotations = [dict(text = "Shared Changes: {0}".format(shared_changes_str),
                            xref = 'paper', yref = 'paper', x = 1.21, y = 0, xanchor = 'left', yanchor = 'bottom',
                            bordercolor = 'orange', borderwidth = 1, showarrow = False,
                            align = 'left', borderpad = 4,
                            font_size = 14)],
        yaxis = dict(title = 'Accession', range = (0, 10),
                     title_font_size = 14,
                     tickmode = 'array',
                     tickvals = tickvals,
                     ticktext = accession_ticks,
                     tickfont = dict(size = 14)),
        yaxis2 = dict(title = 'Alignment Score', range = (0, 10), overlaying = "y",
                      title_font_size = 14,
                      side = 'right',
                      tickmode = 'array',
                      tickvals = tickvals,
                      ticktext = score_ticks,
                      tickfont = dict(size = 14))
    )
    config = dict(scrollZoom = False, doubleClick = 'reset', displayModeBar = True,
                  modeBarButtonsToRemove = ['zoom2d', 'zoomIn2d', 'zoomOut2d', 'autoScale2d', 'lasso2d', 'select2d'])
    fig.show(config = config)
    if save:
        html_path = '../../visualizations/compare_delta.html'
        fig.write_html(html_path, config = config)
        # center the figure by wrapping the written page in <center> tags
        with open(html_path, 'r') as f:
            lines = f.readlines()
        lines = ["<center>\n"] + lines + ["</center>\n"]
        with open(html_path, 'w') as f:
            f.writelines(lines)
        # save svg
        fig.write_image('../../visualizations/compare_delta.svg')

In [14]:
plot_alignment_plotly(df_all_delta, top_19A, save = True)