Pairwise sequence alignment of each clade, with 19A as the baseline. We select the most seen sequence of each clade to represent the clade. 

In [1]:
import pandas as pd
import os

project_folder = os.path.join("..", "..")

In [111]:
df = pd.read_csv(os.path.join(project_folder, 'data', 'final', "all.csv"))
df.head(1)

Unnamed: 0,state,p_sequence,p_accession,date,count,n_accession,n_sequence,clade,timespan
0,MA,"""MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QTP71261,2020,2,MW885877,GGTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTA...,20A,1


In [112]:
temp = df.groupby('clade', as_index = False).max('count')
df = df.merge(temp, on=['clade', 'count']).sort_values('clade', ignore_index = True) # reindex so order of clade matches index
df.head(20)

Unnamed: 0,state,p_sequence,p_accession,date,count,n_accession,n_sequence,clade,timespan_x,timespan_y
0,CA,"""MFVFFVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QOF09449,2020-03-22,19,MW064459,AGATCTGTTCTTTAAACGAACTTTAAAATCTGTGTGGCTGTCACTC...,19A,2,5
1,WA,"""MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QLJ57227,2020,1367,MT252714,CTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGT...,19B,1,11
2,MD,"""MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QVR40114,2/1/21,450,MZ267382,ATTAAAGGTTTATACCTTCCCAGGTAACAAACCAACCAACTTTCGA...,20A,6,11
3,CA,"""MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QSW59156,2/1/21,659,MW739404,AACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTA...,20B,6,11
4,VA,"""MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QTN63407,2021-01,3558,MW868904,ACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATC...,20C,5,11
5,PA,"""MFVFFVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QWU02970,2021-06-03,11,MZ412281,CTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCT...,20D,10,11
6,NY,"""MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QQX01070,2021-01-02,23,MW518168,GTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAA...,20E (EU1),5,8
7,TX,"""MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QVR40162,2/1/21,3755,MZ267386,ATTAAAGGTTTATACCTTCCCAGGTAACAAACCAACCAACTTTCGA...,20G,6,11
8,FL,"""MFVFLVLLPLVSSQCVNFTTRTQLPPAYTNSFTRGVYYPDKVFRS...",UAR25116,2021-07-11,3,OK129036,AGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTCACTC...,"20H (Beta, V2)",11,11
9,VA,"""MFVFFVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRS...",QXP14639,2021-06,2493,MZ571287,TTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTG...,"20I (Alpha, V1)",10,11


In [147]:
from Bio import Align
from Bio.Align import substitution_matrices
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [166]:
def make_aligner():
    aligner = Align.PairwiseAligner(match = 4,
         mismatch = -1, 
         target_open_gap_score = -1000, 
         target_extend_gap_score  = -1000,
         query_open_gap_score = -20,
         query_extend_gap_score = -4) # parameters from this paper https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7314508/

    substitution_matrices.load() 
    matrix = substitution_matrices.load("BLOSUM62")
    aligner.substitution_matrix = matrix
    return aligner

def get_relevent_seqs(df,baseline_index, clade_index):
    seq1 = df.loc[baseline_index]
    seq2 = df.loc[clade_index]
    aligner = make_aligner()
    alignments = aligner.align(seq1['p_sequence'][1:-1], seq2['p_sequence'][1:-1])
    target_row = df.loc[df['p_accession'] == seq1['p_accession']].to_dict('records')[0]
    query_row =df.loc[df['p_accession'] == seq2['p_accession']].to_dict('records')[0]
    return alignments, target_row, query_row
    
def calc_traces_labels_and_subtitle(df,baseline_index, clade_index, debug = False):
    alignments, target_row, query_row = get_relevent_seqs(df,baseline_index, clade_index)
    string_alignments = alignments[0].format().split('\n')
    target = string_alignments[0]
    match_type = string_alignments[1]
    query = string_alignments[2]
    y_height = 1
    length = len(target)
    if debug:
        print(length)
        
    x_insert =[None]
    x_delete =[None]
    x_sub_low =[None]
    x_sub_high =[None] # none means it will show in legend even when there is no data

    labels = []
    yshift = 18
    level = 0
    for i in range(0, length):
        if match_type[i] == '.':
            # substitution, low conservation
            x_sub_low.append(i)
            labels.append({"text":target[i] + str(i) + query[i], 'x':i, 'y':y_height , 
                           'font' :{'size': 12}, 'showarrow':False, 'yshift': yshift * (level % 2 + 1), 'yanchor':'middle', 'xanchor':'center'})
           
            if yshift < 0:
                level = level + 1
            yshift = -yshift
        elif match_type[i] == ':':
            # subsitution, high conservation
            x_sub_high.append(i)
            labels.append({"text":target[i] + str(i) + query[i], 'x' :i,  'y':y_height , 
                           'font' :{'size': 12}, 'showarrow':False, 'yshift': yshift * (level % 2 + 1), 'yanchor':'middle', 'xanchor':'center'})
            if yshift < 0:
                level = level + 1
            yshift = -yshift
        elif match_type[i] == '-':
            # deletion or insertion
            if target[i] =='-':
                # insertion
                x_insert.append(i)
                labels.append({"text":str(i) + query[i], 'x' :i,  'y':y_height , 
                           'font' :{'size': 12}, 'showarrow':False, 'yshift': yshift * (level % 2 + 1), 'yanchor':'middle', 'xanchor':'center'})
                if yshift < 0:
                    level = level + 1
                yshift = -yshift
            else: 
                # deletion (alg not allowed to match gap to gap)
                x_delete.append(i)
                labels.append({"text": target[i] + str(i), 'x' :i, 'y':y_height , 
                           'font' :{'size': 12}, 'showarrow':False, 'yshift': yshift * (level % 2 + 1), 'yanchor':'middle', 'xanchor':'center'})
                if yshift < 0:
                    level = level + 1
                yshift = -yshift
                
    total = len(x_insert)+ len(x_delete) +len(x_sub_low) + len(x_sub_high) - 4 # total number of changes
    if total == 0:
        # same sequence /exact match
        labels.append({'x':680, 'y':1, 'text':'Same Sequence', 'yanchor':'bottom','showarrow':False, 'textangle':0, 'font':{'size':16}})
        total = 1 # exact match, won't affect bars because each is 0, do want to avoid division by zero
    counts = [len(x_insert)-1, len(x_delete) - 1, len(x_sub_low) -1 , len(x_sub_high) - 1 ]
    names = ['Insertions', 'Deletions', 'Substitution, semi-conservative', 'Substitution, conservative']
    
    if debug:
        print(total)

    score_percentage = (alignments[0].score / 6722)# best score is when all positions match, match contributes score of 4
    
    traces =  [go.Scatter( # base gray line
            mode = "lines",
            x =  [0, length],
            y = [y_height, y_height], 
            name = "sequence line",
            line = {'color':  'gray'},
            showlegend = False,
            hoverinfo = 'skip',
            visible = False
        ),
         go.Scatter( # S1  line
            mode = "lines",
            x =  [14, 685],
            y = [y_height, y_height], 
            name = "S1",
            line = {'color':  'plum', 'width' : 3},
            legendgroup = 'Subunit',
            showlegend = True,
            hoverinfo = 'skip',
            visible = False
        ),
         go.Scatter( # S2  line
            mode = "lines",
            x =  [686, 1273],
            y = [y_height, y_height], 
            name = "S2",
            line = {'color':  'pink', 'width':3},
            legendgroup = 'Subunit',
            showlegend = True,
            hoverinfo = 'skip',
            visible = False
        ),
        go.Scatter( # insert markers
            mode = "markers",
            x = x_insert ,
            y = [y_height for y in x_insert], 
            name = "insertion",
            marker=dict(size=8, color = 'orange', symbol = 'square'),
            legendgroup = 'changes',
            showlegend = True,
            hoverinfo = 'x',
            visible = False
        ),
          go.Scatter( # delete markers
            mode = "markers",
            x = x_delete,
            y = [y_height for y in x_delete], 
            name = "deletion",
            marker=dict(size=8, color = 'red', symbol = 'square'),
            legendgroup = 'changes',
            showlegend = True,
            hoverinfo = 'x',
            visible = False
        ),
         go.Scatter( # sub semi markers
            mode = "markers",
            x = x_sub_low ,
            y = [y_height for y in x_sub_low], 
            name = "substitution, semi-conservative",
            marker=dict(size=8, color = 'blue', symbol = 'circle'),
            legendgroup = 'changes',
             showlegend = True,
            hoverinfo = 'x',
            visible = False
        ),
          go.Scatter( # sub conservative markers
            mode = "markers",
            x = x_sub_high ,
            y = [y_height for y in x_sub_high], 
            name = "substitution, conservative",
            marker=dict(size=8, color = 'lightblue', symbol = 'circle'),
            legendgroup = 'changes',
            showlegend = True,
            hoverinfo = 'x',
            visible = False
        ),
         # go.Bar( # score bar
          #   x = [length,  score_percentage * length], 
           #  y = [0.08, 0.08],
            # orientation = 'h',
             #showlegend = False,
          #   width = 0.16,
           #  text = [None, int(alignments[0].score)], # score should always be int, we don't need decimal precision and don't want .0 cluttering
            # textposition = ['inside', "outside"],
            # marker_color =['lightgray', 'lightgray'],
             #marker = dict(pattern = {'bgcolor':['lightgray', 'gray'], 'shape':['', '/'], 'size':5}),
             #hoverinfo = 'skip',
            #visible = False
         #) ,
        go.Bar( # counts bars
             name = "counts",
             x = [1540, 1590, 1640,1700], 
             y = [(c/total) *0.85 for c in counts ],
             showlegend = False,
             width = 30,
             text =  counts, 
             marker_color =['orange', 'red', 'blue', 'lightblue'],
             marker = {'line' :dict(width = 2, color = ['orange','red','blue','lightblue'])},
             textposition = 'outside',
             hoverinfo = 'skip', # can't get the names without them in x,y,text
            visible = False
         )

    ]
    
    labels.append({'x': 0.1, 'y':0 ,'text': "Alignment Score: {:.2f}%".format(score_percentage * 100) ,'font': { 'size': 14},'showarrow':False, 'xanchor':'left', 'yanchor':'bottom' })
    labels.append({'x':1450, 'y':0, 'text':'counts', 'yanchor':'bottom','showarrow':False, 'textangle':0})
    subtitle = "{2} ({3}) difference from {0} ({1})".format(target_row['clade'], target_row['p_accession'], query_row['clade'], query_row['p_accession'])
    title = "Spike Protein Comparison<br><sup>of Most Frequent Sequence of Selected Clade to Baseline <br>{0}<sup>".format(subtitle)

    return [dict(data = traces), dict(annotations = labels, title = dict(text = title, x = 0.1))]
        

def plot_alignment_plotly(df, save = False, debug = False):
    pairwise = dict()
 
    
    
    def get_state(pairwise, query):
        states = pairwise[query]

        visible = []
        for i in range(1,20):
            if i == query: 
                visible = visible + [True, True, True, True, True, True, True, True]
            else:
                visible = visible + [False, False,False,False,False,False,False,False]
        return [dict(visible = visible),states[1]]
    
    
    
   
    traces = []
    for i  in range(1,20):
        pairwise[i] = calc_traces_labels_and_subtitle(df, 0, i)
        new_traces = pairwise[i][0]['data']
        
        if i == 1:
            for t in new_traces:
                t.visible = True
                
        traces = traces + new_traces
    fig = go.Figure(traces)        
    states = pairwise[1]

    fig.update_xaxes(
            range = [-15,1850],
            showgrid=False, # thin lines in the background
            zeroline= False, # thick line at x=0
            visible= False  # numbers below
            )

    fig.update_yaxes(
            range =[0,1.5],
            showgrid= False, # thin lines in the background
            zeroline= False, # thick line at x=0
            visible= False  # numbers below
            )
    
    button_layer_1_height = 0.7
    fig.update_layout(plot_bgcolor='white',
                      height=380,
                      width = 1100,
                      barmode = 'overlay',
                      title = states[1]['title'],
                      legend=dict(
                                yanchor="middle",
                                y=1,
                                xanchor="right",
                                x=1
                    ),
                      annotations= states[1]['annotations'],
                      
                      updatemenus=[
                               
                                  dict(
                                    buttons=list([
                                        dict(
                                            label="19B",
                                             method="update",
                                             args = get_state(pairwise,query = 1)
                                        ),
                                        dict(
                                            label="20A",
                                            method="update",
                                            args = get_state(pairwise,query = 2)
                                        ),
                                        dict(
                                            label="20B",
                                            method="update",
                                            args = get_state(pairwise,query = 3)
                                        ),
                                        dict(
                                            label="20C",
                                            method="update",
                                            args = get_state(pairwise,query = 4)
                                        ),
                                        dict(
                                            label="20D",
                                            method="update",
                                            args = get_state(pairwise,query = 5)
                                        ),
                                        dict(
                                            label="20E (EU1)",
                                            method="update",
                                            args = get_state(pairwise,query = 6)
                                        ),
                                        dict(
                                            label="20G",
                                            method="update",
                                            args = get_state(pairwise,query = 7)
                                        ) ,
                                        dict(
                                            label="20H (Beta, V2)",
                                            method="update",
                                            args = get_state(pairwise,query = 8)
                                        )  ,
                                        dict(
                                            label="20I (Alpha, V1)",
                                            method="update",
                                            args = get_state(pairwise,query = 9)
                                        ),
                                         dict(
                                            label="20J (Gamma, V3)",
                                             method="update",
                                             args = get_state(pairwise,query = 10)
                                        ),
                                        dict(
                                            label="21A (Delta)",
                                            method="update",
                                            args = get_state(pairwise,query = 11)
                                        ),
                                        dict(
                                            label="21B (Kappa)",
                                            method="update",
                                            args = get_state(pairwise,query = 12)
                                        ),
                                        dict(
                                            label="21C (Epsilon)",
                                            method="update",
                                            args = get_state(pairwise,query = 13)
                                        ),
                                        dict(
                                            label="21D (Eta)",
                                            method="update",
                                            args = get_state(pairwise,query = 14)
                                        ),
                                        dict(
                                            label="21F (Iota)",
                                            method="update",
                                            args = get_state(pairwise,query = 15)
                                        ),
                                        dict(
                                            label="21G (Lambda)	",
                                            method="update",
                                            args = get_state(pairwise,query = 16)
                                        ) ,
                                        dict(
                                            label="21H (Mu)",
                                            method="update",
                                            args = get_state(pairwise,query = 17)
                                        )  ,
                                        dict(
                                            label="21I (Delta)",
                                            method="update",
                                            args = get_state(pairwise,query = 18)
                                        ),
                                        dict(
                                            label="21J (Delta)",
                                            method="update",
                                            args = get_state(pairwise,query = 18)
                                        )  
                                        
                                    ]),
                                    direction = 'down',
                                    pad={"r": 10, "t": 10},
                                    showactive=True,
                                    x=0.5,
                                    xanchor="left",
                                    y=1.4,
                                    yanchor="top"
                                    
                                )]
    )
    fig.show()
    if save:
        fig.write_html('../../visualizations/compare_most_frequent_seq_for_each_clade_to_19A.html')

In [167]:
plot_alignment_plotly(df, save = True, debug = False)