In [None]:
import pandas as pd
from pathlib import Path
import numpy as np
import altair as alt
from scipy.stats import pearsonr
import re

from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats

In [None]:
# Input locations, relative to the notebook's working directory:
#   counts_path     -> directory holding the per-sample *tsv count tables
#   x1a_annotations -> Excel sheet with per-variant annotations (joined on pos_id)
counts_path, x1a_annotations = (
    '../Data/ATG_lib_data/counts',
    '../Data/ATG_lib_data/X1A_annotations.xlsx',
)

In [None]:
def read_counts(path):
    """Load every ``*tsv`` count table in a directory, keyed by sample.

    Args:
        path: Directory containing the count tables (str or ``Path``).

    Returns:
        dict mapping a sample key (``'lib'``, ``'NC'``, ``'D05_rep1'`` ...
        ``'D13_rep3'``) to the DataFrame read from the matching file. Files
        whose names match none of the known patterns are skipped. If several
        files match the same key, the one that sorts last by path wins.

    Raises:
        FileNotFoundError: if ``path`` is not an existing directory.
    """
    # (file-name fragment, sample key), checked in order; the first match wins.
    sample_patterns = [
        ('lib_counts', 'lib'),
        ('NC', 'NC'),
        ('R1R2R3_D05', 'D05_rep1'),
        ('R4R5R6_D05', 'D05_rep2'),
        ('R7R8R9_D05', 'D05_rep3'),
        ('R3_D13', 'D13_rep1'),
        ('R6_D13', 'D13_rep2'),
        ('R9_D13', 'D13_rep3'),
    ]

    data_dir = Path(path)
    if not data_dir.is_dir():
        raise FileNotFoundError(f'Counts directory not found: {data_dir}')

    counts_dict = {}
    # Sorted so the dict's insertion order (which drives downstream column
    # order) does not depend on the filesystem's listing order.
    for file_path in sorted(data_dir.glob('*tsv')):
        # Match on the file name only: fragments such as 'NC' could otherwise
        # match a parent directory and misclassify every file.
        key = next((k for frag, k in sample_patterns if frag in file_path.name), None)
        if key is None:
            continue
        counts_dict[key] = pd.read_csv(file_path, sep = '\t')

    return counts_dict

In [None]:
def add_freq(dict):
    """Attach a per-sample relative-frequency column to each count table.

    Args:
        dict: Mapping of sample key -> DataFrame with a ``Count`` column (as
            produced by ``read_counts``). The ``'NC'`` entry (presumably the
            negative control — confirm) is excluded from the output.

    Returns:
        ``(freq_dicts, keys)``:
        - ``freq_dicts`` is a new dict of sample key -> copy of the input frame
          with an added ``freq`` column equal to ``Count / Count.sum()``.
        - ``keys`` is the list of ``freq_dicts`` keys in insertion order.
        The input DataFrames are not modified.
    """
    # assign() returns a new frame, so the caller's tables stay untouched and
    # re-running the cell is idempotent.
    freq_dicts = {
        key: df.assign(freq = df['Count'] / df['Count'].sum())
        for key, df in dict.items()
        if key != 'NC'
    }
    return freq_dicts, list(freq_dicts)

In [None]:
def mutate_snvs(dna_sequence):
    """Return every single-nucleotide substitution of ``dna_sequence``.

    Positions are visited left to right and three variants are emitted per
    position, in this replacement order:
        A -> T, C, G
        T -> A, C, G
        C -> A, T, G
        anything else (G, but also N, lowercase, ...) -> A, T, C
    """
    replacements = {"A": "TCG", "T": "ACG", "C": "ATG"}
    return [
        dna_sequence[:idx] + new_base + dna_sequence[idx + 1 :]
        for idx, base in enumerate(dna_sequence)
        for new_base in replacements.get(base, "ATC")
    ]

In [None]:
def reverse_complement_string(seq_string):
    """Return the reverse complement of ``seq_string`` as a new string.

    A -> T, G -> C, C -> G. Every other character (T, but also N, lowercase
    letters, ...) is mapped to A.
    """
    complement = {"A": "T", "G": "C", "C": "G"}
    return "".join(complement.get(base, "A") for base in reversed(seq_string))

In [None]:
def compare_strings(str1, str2, coord_offset):
    """Locate the first position at which two equal-length sequences differ.

    Args:
        str1: Reference sequence.
        str2: Sequence to compare against ``str1``; must have the same length.
        coord_offset: Value added to the 0-based mismatch index.

    Returns:
        ``"<index + coord_offset>:<base of str2 at that index>"`` for the first
        mismatch, or ``None`` when the sequences are identical.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if len(str1) != len(str2):
        raise ValueError(
            f'Sequences must have equal length (got {len(str1)} and {len(str2)})'
        )

    for i, (ref_base, alt_base) in enumerate(zip(str1, str2)):
        if ref_base != alt_base:
            return f'{i + coord_offset}:{alt_base}'
    return None

In [None]:
def annotate_snv_lib(seq = 'CCATGGA', coord_offset = 214809490):
    """Map every single-nucleotide variant of ``seq`` to a ``"<coord>:<base>"`` id.

    Args:
        seq: Wild-type sequence of the mutagenised window, as it appears in the
            reads (``annotate_vars`` slices reads at ``[88:95]`` to obtain it).
        coord_offset: Coordinate assigned to index 0 of the reverse complement
            of ``seq``.

    Returns:
        dict of SNV sequence (same orientation as ``seq``) -> id string of the
        form ``"<coord_offset + index>:<alt base>"``. Both the index and the
        base are taken on the reverse-complement strand. This is presumably
        because the target is on the genomic minus strand (TODO confirm).
    """
    # Derived rather than hard-coded so it can never drift out of sync with
    # `seq`; for the default this is 'TCCATGG'.
    rev_seq = reverse_complement_string(seq)

    return {
        snv: compare_strings(rev_seq, reverse_complement_string(snv), coord_offset)
        for snv in mutate_snvs(seq)
    }

In [None]:
def annotate_vars(dict, keys, snv_map, annotations):
    """Annotate each sample's reads with start codons, SNV ids and variant annotations.

    For every key in ``keys``, the sample's ``Sequence`` column is sliced at
    fixed offsets:
    - ``[15:18]`` gives ``canonical_start``.
    - ``[90:93]`` gives ``second_start``.
    - ``[88:95]`` gives the SNV window, which is looked up in ``snv_map`` to
      get ``pos_id``.
    ``start_pos_id`` is ``"<canonical_start>:<pos_id>"``. Rows are then
    left-joined with the annotation table on ``pos_id``.

    Args:
        dict: Mapping of sample key -> DataFrame with ``Sequence``, ``Count``
            and ``freq`` columns (as returned by ``add_freq``). Not modified.
        keys: Sample keys to process.
        snv_map: Mapping of SNV-window sequence -> ``pos_id`` string.
        annotations: Path or buffer of an Excel annotation table, or an
            already-loaded DataFrame. It must provide ``pos_id``, ``pos``,
            ``allele``, ``AAsub`` and ``Consequence`` columns.

    Returns:
        dict of sample key -> DataFrame with the columns ``pos, allele,
        pos_id, Consequence, AAsub, canonical_start, second_start,
        <key>_count, <key>_freq, start_pos_id``.

    Raises:
        ValueError: if any SNV window of a sample is missing from ``snv_map``.
    """
    if isinstance(annotations, pd.DataFrame):
        annotation_df = annotations
    else:
        annotation_df = pd.read_excel(annotations)

    output_cols = ['pos', 'allele', 'pos_id', 'Consequence', 'AAsub',
                   'canonical_start', 'second_start', 'Count', 'freq', 'start_pos_id']

    annotated_dfs = {}
    for key in keys:
        sample = dict[key]
        seqs = sample['Sequence']

        snv_window = seqs.str[88:95]
        pos_id = snv_window.map(snv_map)

        # An unmapped window would otherwise yield a NaN id and corrupt
        # start_pos_id, so report it explicitly.
        unmapped = snv_window[pos_id.isna()].unique()
        if len(unmapped) > 0:
            raise ValueError(
                f'{key}: {len(unmapped)} SNV window(s) not found in snv_map, '
                f'e.g. {list(unmapped[:5])}'
            )

        canonical_start = seqs.str[15:18]
        # assign() builds a new frame so the caller's DataFrame is left untouched.
        df = sample.assign(
            canonical_start = canonical_start,
            second_start = seqs.str[90:93],
            pos_id = pos_id,
            start_pos_id = canonical_start + ':' + pos_id,
        )
        df = pd.merge(df, annotation_df, how = 'left', on = 'pos_id')
        df = df[output_cols]

        annotated_dfs[key] = df.rename(columns = {'Count': key + '_count', 'freq': key + '_freq'})

    return annotated_dfs

In [None]:
def combine_dfs(dict, keys):
    """Join every sample's count/freq columns onto the library's variant annotations.

    The variant annotation columns are taken from ``dict['lib']``. Then, for
    each key in ``keys`` (``'lib'`` included), that sample's ``<key>_count``
    and ``<key>_freq`` columns are left-joined on ``start_pos_id``. Library
    variants absent from a sample therefore get NaN for that sample.
    """
    id_cols = ['pos', 'allele', 'pos_id', 'Consequence', 'AAsub',
               'canonical_start', 'second_start', 'start_pos_id']
    combined = dict['lib'][id_cols]

    for sample in keys:
        sample_cols = ['start_pos_id', f'{sample}_count', f'{sample}_freq']
        combined = combined.merge(dict[sample][sample_cols], how = 'left', on = 'start_pos_id')

    return combined

In [None]:
def _replicate_corr_chart(df, x_col, y_col):
    """Scatter of ``y_col`` vs ``x_col`` with a red regression line; Pearson r in the title.

    r is computed only over rows where both columns are non-null, because the
    left merges that build ``df`` can leave a variant missing from one
    replicate. It is NaN when fewer than two complete rows exist.
    """
    paired = df[[x_col, y_col]].dropna()
    if len(paired) >= 2:
        r, _ = pearsonr(paired[x_col], paired[y_col])
    else:
        r = float('nan')

    scatter = alt.Chart(df).mark_point().encode(
        x = x_col,
        y = y_col,
        tooltip = [alt.Tooltip('canonical_start', title = 'Canonical Start: '),
                   alt.Tooltip('second_start', title = 'Second Start: '),
                   alt.Tooltip('AAsub', title = 'AA Substitution')
                  ]
    ).properties(
        title = x_col + ' vs. ' + y_col + ' (r = ' + str(round(r, 3)) + ')'
    ).interactive()

    trendline = scatter.transform_regression(x_col, y_col).mark_line(color = 'red')
    return trendline + scatter


def qc_plots(df):
    """Display replicate-correlation and library-composition QC charts.

    First chart: one row per time point (D05, then D13). Each row holds three
    pairwise replicate scatters of the ``<day>_rep<n>_freq`` columns (rep1 vs
    rep2, rep1 vs rep3, rep2 vs rep3).
    Second chart: a bar chart of ``lib_freq`` by ``pos``, coloured by
    ``canonical_start``.

    Args:
        df: Combined frame as produced by ``combine_dfs``.
    """
    corr_rows = []
    for day in ('D05', 'D13'):
        reps = [day + '_rep' + str(n) + '_freq' for n in (1, 2, 3)]
        pairs = [(reps[a], reps[b]) for a in range(len(reps)) for b in range(a + 1, len(reps))]
        corr_rows.append(alt.hconcat(*[_replicate_corr_chart(df, x, y) for x, y in pairs]))

    full_corr_plot = alt.vconcat(*corr_rows)
    full_corr_plot.display()

    lib_plot = alt.Chart(df).mark_bar().encode(
        x = 'pos:O',
        y = 'lib_freq',
        color = 'canonical_start'
    ).properties(
        width = 600,
        height = 400
    )

    lib_plot.display()

In [None]:
def _build_count_matrix(df, sample_cols):
    """Pivot per-variant count columns into a (samples x start_pos_id) matrix.

    Rows follow ``sample_cols`` order and columns are the sorted
    ``start_pos_id`` values. Missing counts (NaN left by the upstream left
    merges, i.e. a variant not seen in that sample) are filled with 0, since
    DESeq2 does not accept NaN counts.
    """
    long_counts = pd.melt(df, id_vars = ['start_pos_id'], value_vars = sample_cols)
    matrix = pd.pivot(long_counts, index = 'variable', columns = ['start_pos_id'], values = 'value')
    matrix = matrix.rename_axis(None, axis = 1).rename_axis(None, axis = 0)
    # pivot() sorts rows alphabetically; pin them to the metadata order explicitly.
    return matrix.reindex(sample_cols).fillna(0)


def _classify_start_codons(res):
    """Return a copy of ``res`` with ``comb_starts`` and ``start_consequence`` columns.

    ``comb_starts`` is ``canonical_start + second_start``. ``start_consequence``
    is assigned as follows:
    - ``'WT'`` when it equals ``'ATGATG'``.
    - ``'second_lost'`` when only the first three bases are ATG.
    - ``'canonical_lost'`` when only the last three are ATG.
    - ``'both_lost'`` otherwise.
    """
    res = res.copy()
    res['comb_starts'] = res['canonical_start'] + res['second_start']
    canonical_kept = res['comb_starts'].str.match(r'^ATG.{3}$')
    second_kept = res['comb_starts'].str.match(r'^.{3}ATG$')
    res['start_consequence'] = np.select(
        [res['comb_starts'] == 'ATGATG', canonical_kept, second_kept],
        ['WT', 'second_lost', 'canonical_lost'],
        default = 'both_lost',
    )
    return res


def score(df, out_path = '/Users/ivan/Desktop/ATG_QC.xlsx', target = 'BARD1_X1A_ATG'):
    """Score variants with PyDESeq2 and classify their start-codon impact.

    Counts from the library (time 0), day 5 and day 13 replicates are fitted
    with design ``~time``. The ``log2FoldChange`` and ``lfcSE`` for contrast
    ``["time", 1, 0]`` are reported as ``score`` and ``standard_error``. How
    PyDESeq2 interprets that contrast for a numeric ``time`` depends on the
    library version (TODO confirm).

    Args:
        df: Combined frame from ``combine_dfs``. It needs ``start_pos_id``,
            the annotation columns, ``lib_count`` and
            ``D05_rep{1,2,3}_count`` / ``D13_rep{1,2,3}_count``.
        out_path: Where to write the Excel report; ``None`` skips writing. The
            default is the original author's local path and must be changed
            on other machines.
        target: Value written to the ``target`` column.

    Returns:
        DataFrame with the columns ``target, pos, allele, start_pos_id,
        Consequence, AAsub, canonical_start, second_start, score,
        standard_error, comb_starts, start_consequence``.

    Side effects:
        Prints the PyDESeq2 summary and (unless ``out_path`` is None) writes
        the returned frame to Excel.
    """
    sample_renames = {'D05_rep1_count': 'D05_R1',
                      'D05_rep2_count': 'D05_R2',
                      'D05_rep3_count': 'D05_R3',
                      'D13_rep1_count': 'D13_R1',
                      'D13_rep2_count': 'D13_R2',
                      'D13_rep3_count': 'D13_R3'}
    metanames = ['D05_R1', 'D05_R2', 'D05_R3', 'D13_R1', 'D13_R2', 'D13_R3', 'lib_count']
    metadays = [5, 5, 5, 13, 13, 13, 0]

    counts = _build_count_matrix(df.rename(columns = sample_renames), metanames)
    metadf = pd.DataFrame({'time': metadays}, index = metanames)

    # Use the same inference object for fitting and testing.
    inference = DefaultInference(n_cpus = 1)
    dds = DeseqDataSet(
        counts = counts,
        metadata = metadf,
        design = "~time",
        refit_cooks = True,
        inference = inference,
    )
    dds.deseq2()

    stat_res = DeseqStats(dds, contrast = ["time", 1, 0], inference = inference)
    stat_res.summary()

    annotation_cols = ["pos", "allele", "start_pos_id", 'AAsub', 'Consequence',
                       'canonical_start', 'second_start']
    resdf = (
        stat_res.results_df
        .reset_index(names = 'start_pos_id')
        .merge(df[annotation_cols], on = 'start_pos_id')
        .rename(columns = {'log2FoldChange': 'score', 'lfcSE': 'standard_error'})
        .assign(target = target)
    )
    resdf = resdf[['target', 'pos', 'allele', 'start_pos_id', 'Consequence', 'AAsub',
                   'canonical_start', 'second_start', 'score', 'standard_error']]

    resdf = _classify_start_codons(resdf)

    if out_path is not None:
        resdf.to_excel(out_path)
    return resdf

In [None]:
def visualize_scores(df):
    """Display three score charts, in this order.

    1. Histograms of ``score`` coloured by ``Consequence``, faceted by
       ``start_consequence``.
    2. Per-position scatters of ``score``, with the same colour and facets.
    3. A scatter of ``standard_error`` by position, coloured by
       ``start_consequence``.

    Args:
        df: Scored frame as returned by ``score``.
    """
    consequence_order = ['synonymous_variant', 'missense_variant', 'stop_gained']
    impact_order = ['WT', 'canonical_lost', 'second_lost', 'both_lost']

    def impact_facet():
        # Fresh Facet per chart: same spec as before for both faceted plots.
        return alt.Facet('start_consequence', title = 'Start Codon Impact', sort = impact_order)

    def consequence_color():
        return alt.Color('Consequence', sort = consequence_order)

    score_hist = (
        alt.Chart(df)
        .mark_bar()
        .encode(
            x = alt.X('score', title = 'SGE Score', bin = alt.Bin(maxbins = 40)),
            y = alt.Y('count()', title = 'Number of Variants'),
            color = consequence_color(),
        )
        .properties(width = 600, height = 400)
        .interactive()
    )
    score_hist.facet(columns = 2, facet = impact_facet()).display()

    score_by_pos = (
        alt.Chart(df)
        .mark_point()
        .encode(
            x = alt.X('pos:O', title = 'Genomic Coordinate', axis = alt.Axis(labels = False)),
            y = alt.Y('score', title = 'SGE Score'),
            color = consequence_color(),
            tooltip = [alt.Tooltip('AAsub', title = 'AA Substitution: ')],
        )
        .properties(width = 600, height = 400)
        .facet(columns = 2, facet = impact_facet())
    )
    score_by_pos.display()

    se_by_pos = alt.Chart(df).mark_point().encode(
        x = 'pos:O',
        y = 'standard_error',
        color = 'start_consequence',
    )
    se_by_pos.display()

In [None]:
def main():
    """Run the pipeline end to end.

    The stages, in order: load counts, add frequencies, annotate SNVs,
    combine samples, show QC charts, score with DESeq2, then plot the scores.
    """
    raw_counts = read_counts(counts_path)
    with_freqs, sample_keys = add_freq(raw_counts)
    snv_lookup = annotate_snv_lib()
    annotated = annotate_vars(with_freqs, sample_keys, snv_lookup, x1a_annotations)
    merged = combine_dfs(annotated, sample_keys)

    qc_plots(merged)
    visualize_scores(score(merged))

In [None]:
# Run the whole pipeline: reads the count tables and annotation sheet, displays
# the QC and score charts, and writes the Excel report produced inside score().
main()