# Notebook to process variants and their read count across conditions

## Import libraries

In [None]:
import warnings
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['svg.fonttype'] = 'none'
from upsetplot import from_indicators
from upsetplot import UpSet

## Specify paths

In [None]:
### Input
# First Snakemake input: CSV table of read counts per variant and sample
rcdf_path = snakemake.input[0]

### Output
# Dataframes: table of selection coefficients written near the end of the notebook
selcoeffs_df_path = snakemake.output.selcoeffs
# Plots: each figure below is saved as SVG to one of these paths
(hist_plot,
 upset_plot,
 rc_var_plot,
 timepoints_plot,
 scoeff_violin_plot,
 s_through_time_plot,
 replicates_plot) = (snakemake.output.hist_plot,
                     snakemake.output.upset_plot,
                     snakemake.output.rc_var_plot,
                     snakemake.output.timepoints_plot,
                     snakemake.output.scoeff_violin_plot,
                     snakemake.output.s_through_time_plot,
                     snakemake.output.replicates_plot)

## Specify parameters

In [None]:
# Mutation-related columns of the sequence layout (tables of expected variants).
# These column names are not expected to change from one project to another.
mutation_attributes = [
    'pos', 'alt_codons', 'alt_aa', 'aa_pos',
    'nt_seq', 'aa_seq',
    'Nham_nt', 'Nham_aa', 'Nham_codons',
]

# Project-specific settings passed in through the Snakemake rule's params
sample_attributes = snakemake.params.sample_attributes  # sample-layout columns used to group datasets
exp_rc_per_var = snakemake.params.exp_rc_per_var        # target read count per variant
rc_threshold = snakemake.params.rc_threshold            # read-count cutoff used for the confidence score
nbgen_path = snakemake.params.nb_gen                    # Excel file: number of mitotic generations per condition

## Import data

In [None]:
# dtype notes:
# - 'WT' uses pandas' nullable 'boolean' dtype, which allows missing values.
# - 'pos' and 'aa_pos' are read as strings because the nucleotide WT rows hold
#   non-numeric values; a single dtype per column matters for the pivots below.
covered_df = pd.read_csv(
    rcdf_path,
    index_col=0,
    dtype={
        'WT': 'boolean',
        'pos': str,
        'aa_pos': str,
    },
)
covered_df

## Add rows corresponding to variants not present in all replicates/timepoints

In [None]:
# Combined '<Timepoint>_<Replicate>' label, used as column name in the wide tables below
covered_df['TR'] = covered_df['Timepoint'] + '_' + covered_df['Replicate']
conditions = covered_df.TR.unique()
# Input (T0) conditions, selected on the Timepoint column itself. A substring test
# on the combined label ('T0' in x) would also match e.g. a 'T01' timepoint or a
# replicate name containing 'T0'. Order follows `conditions`.
T0_conditions = list(covered_df.loc[covered_df['Timepoint'] == 'T0', 'TR'].unique())

In [None]:
# Wide read-count table: one row per (sample attributes, variant), one column per
# Timepoint_Replicate condition. Variant/condition pairs absent from the data get 0.
# The mutation columns are moved back out of the index, so the index holds only
# the sample attributes.
upset = (
    covered_df
    .pivot_table(
        index=sample_attributes + mutation_attributes,
        columns='TR',
        values='readcount',
        fill_value=0,
    )
    .reset_index(level=mutation_attributes)
)
upset

## Label variants based on read count in input replicates

In [None]:
def get_confidence_score(g, threshold):
    """Return a confidence label for one variant from its input read counts.

    Parameters
    ----------
    g : pandas.Series
        Read counts of the variant, one value per input (T0) replicate.
    threshold : number
        Minimum read count for a replicate to be considered well covered.

    Returns
    -------
    int
        1 if every replicate reaches ``threshold`` (best confidence),
        2 if at least one replicate does (medium confidence),
        3 if none does (low confidence).
    """
    passing = g >= threshold
    if passing.all():
        return 1
    return 2 if passing.any() else 3

In [None]:
# Confidence label of each variant, from its read counts in the input (T0) replicates
upset['confidence_score'] = upset[T0_conditions].apply(lambda row: get_confidence_score(row, rc_threshold), axis=1)
# Guard: re-running this cell must not append the column name a second time, since
# mutation_attributes is used as id_vars / index keys in the melts and pivots below
if 'confidence_score' not in mutation_attributes:
    mutation_attributes += ['confidence_score']
upset

In [None]:
# Number of distinct nucleotide sequences per sample-attribute combination and
# confidence score; confidence_score is moved from the index into a column
gby_score = (
    upset.reset_index()
    .groupby(sample_attributes + ['confidence_score'])[['nt_seq']]
    .nunique()
    .reset_index('confidence_score')
)
gby_score

In [None]:
# Total number of distinct nucleotide sequences per sample-attribute combination
ntseq_tot = (
    upset.reset_index()
    .groupby(sample_attributes)[['nt_seq']]
    .nunique()
)
ntseq_tot

In [None]:
# Warn when low-confidence variants (score 3: below rc_threshold in every input
# replicate) exceed a given share of the distinct sequences of any
# sample-attribute combination. Combinations without any score-3 variant give NaN,
# which compares as False.
low_conf_max_fraction = .25
low_conf_fraction = (gby_score.loc[gby_score['confidence_score'] == 3, 'nt_seq']
                     / ntseq_tot['nt_seq'])

cscore_statement = ""

if (low_conf_fraction > low_conf_max_fraction).any():
    cscore_statement = (f"Warning: at least one combination of sample attributes has more than "
                        f"{low_conf_max_fraction:.0%} of its variants labeled with low confidence. "
                        f"This means many of your variants were sequenced fewer than {rc_threshold} times "
                        f"in every input (T0) replicate. Make sure you review the config file and "
                        f"adjust the rc_threshold parameter if necessary.")
    warnings.warn(cscore_statement)

## Calculate frequencies

In [None]:
freq = upset.copy()
freq_conditions = [f'{x}_freq' for x in conditions]
# T0 frequency columns derived from T0_conditions, rather than by re-matching 'T0'
# in the column names, so both lists always describe the same input conditions
T0_freq = [f'{x}_freq' for x in T0_conditions]

In [None]:
# Per-condition frequency of each variant within its sample-attribute combination.
# A pseudocount of 1 is added to the read count (numerator only), presumably so that
# variants absent from a condition keep a non-zero frequency on the log scales below.
# transform('sum') returns the group totals already row-aligned with `freq`. The
# division therefore does not depend on index alignment between the repeated
# sample-attribute labels of `freq` and a one-row-per-group totals table.
condition_totals = freq.groupby(sample_attributes)[conditions].transform('sum')
freq[freq_conditions] = freq[conditions].add(1) / condition_totals
freq

In [None]:
# Retrieve overall mean frequency corresponding to the expected read count per variant.
# For every condition and sample-attribute combination, take log10 of the frequency a
# variant would have with exactly exp_rc_per_var reads (same +1 pseudocount as the
# frequencies). mean_exp_freq is the mean of these log10 values, hence 10**mean_exp_freq in plots.
rc_totals = freq.groupby(sample_attributes)[conditions].sum()
mean_exp_freq = np.log10((exp_rc_per_var + 1) / rc_totals).mean(axis=None)

## Plot example of distribution of raw read count (per variant)

In [None]:
dataset1 = freq.index[0] # Just plotting for the first combination of sample attributes
# Title text. With a single sample attribute the index is flat, so dataset1 is a
# scalar that " | ".join would split character by character; str() handles numbers.
dataset1_label = " | ".join(map(str, dataset1 if isinstance(dataset1, tuple) else (dataset1,)))
graph1df = freq[conditions].loc[dataset1]
graph2df = freq[freq_conditions].loc[dataset1]

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

# Left: raw read counts per condition; dashed line = expected read count per variant
sns.histplot(graph1df, element='step', bins=50, common_norm=False, log_scale=10, ax=ax1)
ax1.axvline(x=exp_rc_per_var, linestyle='--', color='.8')
ax1.set(xlabel='Raw read count')

# Right: frequencies; dashed line = mean frequency matching the expected read count
sns.histplot(graph2df, element='step', bins=50, log_scale=10, common_norm=False, ax=ax2)
ax2.axvline(x=10**mean_exp_freq, linestyle='--', color='.8')
ax2.set(xlabel='Frequency')

plt.subplots_adjust(top=.9)
plt.suptitle(f'Samples attributes: {dataset1_label}')
plt.savefig(hist_plot, format='svg', dpi=300)

## Plot overlap across timepoints and replicates

In [None]:
# Mean frequency of each variant across the input (T0) replicates
freq['mean_input'] = freq.loc[:, T0_freq].mean(axis='columns')
# One presence indicator per condition: True wherever the read count is non-zero
bool_conditions = [f'{cond}_indicator' for cond in conditions]
freq[bool_conditions] = freq[conditions].ne(0)
freq

In [None]:
# Sub-table for the UpSet plot: first sample-attribute combination (can be replaced
# with any other). The raw read counts are dropped and the indicator columns take
# the plain condition names, as expected by from_indicators below.
upset_sub = (freq.loc[dataset1]
             .drop(columns=conditions)
             .rename(columns={ind: cond for ind, cond in zip(bool_conditions, conditions)}))
upset_sub

In [None]:
cscores = [1,2,3]
cscore_colors = ['green','orange','red']

# Title text. With a single sample attribute the index is flat, so dataset1 is a
# scalar that " | ".join would split character by character; str() handles numbers.
dataset1_label = " | ".join(map(str, dataset1 if isinstance(dataset1, tuple) else (dataset1,)))

fig = plt.figure(figsize=(6, 6))
upset_obj = UpSet(from_indicators(conditions, data=upset_sub),
                  show_counts=True,
                  min_subset_size="1%",
                  sort_by='cardinality',
                  element_size=None,
                  intersection_plot_elements=0, # height of intersection barplot in matrix elements (0: the stacked bars below show intersection sizes instead)
                  totals_plot_elements=2 # width of totals barplot in matrix elements
                 )

# Intersection sizes split by confidence score
upset_obj.add_stacked_bars(by='confidence_score',
                           colors=dict(zip(cscores, cscore_colors)),
                           elements=3
                          )

# Distribution of the mean T0 frequency of the variants in each intersection
upset_obj.add_catplot(value='mean_input', kind='violin', cut=0, density_norm='count',
                      log_scale=10, linewidth=0.5,
                      elements=3 # height in number of matrix elements
                      )

d = upset_obj.plot(fig=fig) # Assigns all plots to a dictionary containing axes subplots - same keys as gridspec returned by upset_obj.make_grid()
ax0 = d['extra0'] # Key corresponding to 1st stacked barplot - confidence score ('intersections' = intersection barplot)
ax1 = d['extra1'] # Key corresponding to 1st catplot - read count for input samples

ax0.set_ylabel('# Variants')
ax0.legend(title='Confidence score')

ax1.set_ylabel('Mean\nT0 freq.')

plt.subplots_adjust(top=.95)
plt.suptitle(f'Samples attributes: {dataset1_label}')

plt.savefig(upset_plot, format='svg', dpi=300)

## Distribution of allele frequencies

In [None]:
# Long format: one row per (variant, condition) with its frequency; the
# sample-attribute index is restored as ordinary columns
longfreq = (
    freq
    .melt(
        id_vars=mutation_attributes,
        value_vars=freq_conditions,
        var_name='TR_freq',
        value_name='frequency',
        ignore_index=False,
    )
    .reset_index()
)
longfreq

In [None]:
# Recover Timepoint and Replicate from the '<Timepoint>_<Replicate>_freq' labels.
# The '_freq' suffix is stripped, then the label is split on the first '_' only, so
# replicate names that contain '_' stay whole (split('_')[1] would truncate them).
# Assumes Timepoint names contain no '_'.
tr_labels = longfreq['TR_freq'].str.slice(stop=-len('_freq'))
tr_parts = tr_labels.str.split('_', n=1, expand=True)
longfreq['Timepoint'] = tr_parts[0]
longfreq['Replicate'] = tr_parts[1]
longfreq

In [None]:
graphdf = longfreq.copy()
# astype(str): ' | '.join only accepts strings, and sample attributes may be numeric
graphdf['Sample attributes'] = graphdf[sample_attributes].astype(str).agg(' | '.join, axis=1)
labels = graphdf['Sample attributes'].unique()
g = sns.catplot(graphdf, x='Sample attributes', y='frequency', row='Timepoint',
            hue='Replicate', palette='hls', split=True,
            log_scale=10,
            kind='violin', cut=0, linewidth=1, inner='quart',
            height=2, aspect=.8*len(labels)
           )
# Dashed reference line: mean frequency matching the expected read count per variant
g.map(plt.axhline, y=10**mean_exp_freq, linestyle='--', color='.8')

g.set_axis_labels('','Frequency')
g.set_titles(row_template='{row_name}')
g.set_xticklabels(labels, rotation=45, ha='right')
g.tight_layout()
plt.savefig(rc_var_plot, format='svg', dpi=300)

## Get mutation type

In [None]:
def get_mutation_type(Nham_aa, alt_aa):
    """Classify a variant by its effect on the protein sequence.

    Parameters
    ----------
    Nham_aa : number
        Presumably the number of amino-acid differences from WT; 0 means the
        protein sequence is unchanged.
    alt_aa : str
        Mutant amino acid; '*' marks a stop codon.

    Returns
    -------
    str
        'synonymous' when Nham_aa == 0; otherwise 'nonsense' when alt_aa is '*'
        and 'missense' for any other residue.
    """
    if Nham_aa != 0:
        return 'nonsense' if alt_aa == '*' else 'missense'
    return 'synonymous'

In [None]:
# Classify every row. map() over the two columns calls the helper directly instead of
# building one Series per row as DataFrame.apply(axis=1) does, which is slow here.
longfreq['mutation_type'] = list(map(get_mutation_type, longfreq['Nham_aa'], longfreq['alt_aa']))
# Guard: re-running this cell must not append the column name a second time
if 'mutation_type' not in mutation_attributes:
    mutation_attributes += ['mutation_type']
longfreq.head(4)

## Calculate Log2(fold-change) for every timepoint relative to T0

In [None]:
# Wide table: one row per (sample attributes, variant, replicate), one frequency
# column per timepoint
freq_wide = longfreq.pivot(
    index=sample_attributes + mutation_attributes + ['Replicate'],
    columns='Timepoint',
    values='frequency',
)
freq_wide

In [None]:
timepoints = [x for x in freq_wide.columns]
lfc_combinations = [(x,'T0') for x in timepoints[1:]]
lfc_combinations
lfc_cols = [f'Lfc_{"_".join(x)}' for x in lfc_combinations]
lfc_cols

In [None]:
for i,v in enumerate(lfc_cols):
    freq_wide[v] = freq_wide.apply(lambda row: np.log2(row[lfc_combinations[i][0]] / row[lfc_combinations[i][1]]), axis=1)

freq_wide

## Normalize with number of mitotic generations

In [None]:
nbgen_df = pd.read_excel(nbgen_path)
nbgen_wide = nbgen_df.pivot(index=sample_attributes+['Replicate'],
                            columns='Timepoint',
                            values='Nb_gen'
                           )
nbgen_wide.columns = [f'{x}_gen' for x in nbgen_wide.columns]
for i,x in enumerate(timepoints):
    if i in [0,1]:
        pass
    else:
        nbgen_wide[f'cumul_{x}_gen'] = nbgen_wide[[f'{t}_gen' for t in timepoints[1:i]+[x]]].sum(axis=1)
for x in nbgen_wide.columns:
    if 'cumul_' in x:
        nbgen_wide[x.split('cumul_')[1]] = nbgen_wide[x]
        nbgen_wide.drop(x, axis=1, inplace=True)
nbgen_wide

In [None]:
lfc_wide = freq_wide.reset_index().merge(right=nbgen_wide.reset_index(), on=sample_attributes+['Replicate'])
gen_cols = nbgen_wide.columns
lfc_wide[gen_cols]

In [None]:
for x in list(zip(lfc_cols, gen_cols)):
    lfc_wide[x[0]] /= lfc_wide[x[1]]
lfc_wide

## Normalize with median of synonymous codons

In [None]:
syn = lfc_wide[(lfc_wide.Nham_nt >0) & (lfc_wide.Nham_aa == 0)][sample_attributes+['Replicate']+lfc_cols]
syn

In [None]:
syn.describe()

In [None]:
mediansyn = syn.groupby(sample_attributes+['Replicate'])[lfc_cols].median()
mediansyn.columns = [x.replace('Lfc','med') for x in mediansyn.columns]
med_cols = mediansyn.columns
mediansyn

## Calculate selection coefficients

In [None]:
selcoeff_cols = [x.replace('Lfc','s') for x in lfc_cols]
s_wide = lfc_wide.merge(right=mediansyn.reset_index(), on=sample_attributes+['Replicate'])

In [None]:
for i,s in enumerate(selcoeff_cols):
    s_wide[s] = s_wide[lfc_cols[i]] - s_wide[med_cols[i]]
s_wide

## Repeat WT at every position

### Repeat row for every position in the protein sequence

In [None]:
# Select WT nucleotide sequence(s). .copy() makes WTdf an independent frame, so the
# column assignments below do not write into a slice of s_wide (SettingWithCopyWarning).
WTdf = s_wide[s_wide.Nham_nt == 0].copy()

# Get length of protein sequence
WTdf['len_aa'] = WTdf['aa_seq'].str.len()

# Create list of positions (0 .. len_aa - 1) for every sequence
WTdf['pos'] = WTdf['len_aa'].apply(lambda n: np.arange(n))

# Same with WT codons (list of codons at every matching position)
WTdf['alt_codons'] = WTdf['nt_seq'].apply(lambda seq: [seq[i:i+3] for i in range(0, len(seq), 3)])

# Same with WT amino acid (one character per residue)
WTdf['alt_aa'] = WTdf['aa_seq'].apply(list)

# Then we use explode to turn horizontal lists into rows with matching values for all 3 columns
WTdf = WTdf.explode(['pos','alt_codons','alt_aa'])
WTdf

### Get non-WT

In [None]:
# In this step we need to cast the dtype of pos and aa_pos
# which we could not do before because the WT rows feature string values ("non-applicable")
# .copy() detaches nonWT from s_wide before it is modified (avoids SettingWithCopyWarning)
nonWT = s_wide[s_wide.Nham_nt > 0].copy()
nonWT[['pos','aa_pos']] = nonWT[['pos','aa_pos']].astype(int)

### Retrieve position offset (position in the full protein sequence)

In [None]:
# Per sample-attribute combination: smallest 'pos' and smallest 'aa_pos' among mutants
offpos = nonWT.groupby(sample_attributes)[['pos','aa_pos']].min()

# min(aa_pos) is only the start position in the full protein when a mutant exists at
# pos 0 in every combination. Otherwise the WT rows would keep their raw (non-numeric)
# aa_pos, and the addition below would fail or yield invalid positions, so stop
# here with an explicit message.
if not (offpos['pos'] == 0).all():
    raise ValueError('There is at least one case where there is no sequenced mutant at position 0, which prevents from retrieving the start position in the full protein sequence (min aa_pos)')

WTdf = pd.merge(left=WTdf.drop(columns='aa_pos'), right=offpos[['aa_pos']].reset_index(), on=sample_attributes)
# Position in the full protein = start offset + position within the sequenced region
WTdf['aa_pos'] += WTdf['pos']
WTdf

In [None]:
# WT rows (one per protein position) followed by the mutant rows, with a fresh index
allpos_df = pd.concat((WTdf, nonWT), axis=0, ignore_index=True)
allpos_df

## Calculate median selection coefficient per amino-acid variant (over its synonymous codons)

In [None]:
median_df = allpos_df.groupby(sample_attributes+['Replicate','aa_pos','alt_aa']
                          )[selcoeff_cols + ['confidence_score', 'Nham_aa', 'mutation_type']
                           ].agg(dict(zip(selcoeff_cols + ['confidence_score', 'Nham_aa', 'mutation_type'],
                                          ['median']*len(selcoeff_cols) + ['min', 'first', 'first']))
                                ).reset_index(level=['aa_pos','alt_aa'])
median_df

In [None]:
median_df.to_csv(selcoeffs_df_path)

In [None]:
median_df.index[0]

## Plot example of correlation between compared timepoints

In [None]:
# Pairwise scatter plots of the selection coefficients of the compared timepoints,
# for the first (sample attributes, Replicate) combination, coloured by confidence score
dataset1_r1 = median_df.index[0]
graphdf = median_df.loc[dataset1_r1].reset_index()
score_palette = dict(zip(cscores, cscore_colors))
g = sns.pairplot(
    graphdf,
    vars=selcoeff_cols,
    hue='confidence_score',
    hue_order=cscores,
    palette=score_palette,
    plot_kws={'s': 8, 'alpha': .2},
    height=1.5,
    corner=True,
)
g.tight_layout()
plt.savefig(timepoints_plot, format='svg', dpi=300)

## Plot overall distribution of selection coefficients

In [None]:
# Long format: one row per amino-acid variant, replicate and pair of compared timepoints
mutation_attributes_aa = ['aa_pos','alt_aa','Nham_aa','mutation_type']
median_long = (
    median_df
    .melt(
        id_vars=mutation_attributes_aa,
        value_vars=selcoeff_cols,
        var_name='Compared timepoints',
        value_name='s',
        ignore_index=False,
    )
    .reset_index()
)
median_long

In [None]:
# astype(str): ' | '.join only accepts strings, and sample attributes may be numeric
median_long['Sample attributes'] = median_long[sample_attributes].astype(str).agg(' | '.join, axis=1)
labels = median_long['Sample attributes'].unique()
# 's_<timepoint>_T0' -> '<timepoint>'
median_long['Compared timepoints'] = median_long['Compared timepoints'].str.split('_').str[1]

In [None]:
# Violin plots of s per dataset: one row per compared timepoint, replicates shown
# as the two halves of each violin
g = sns.catplot(
    median_long,
    x='Sample attributes',
    y='s',
    row='Compared timepoints',
    row_order=timepoints[1:],
    hue='Replicate',
    palette='hls',
    split=True,
    kind='violin',
    cut=0,
    linewidth=1,
    inner='quart',
    height=2,
    aspect=.8*len(labels),
)

g.set_axis_labels('', 's')
g.set_titles(row_template='{row_name}')
g.set_xticklabels(labels, rotation=45, ha='right')
g.tight_layout()
plt.savefig(scoeff_violin_plot, format='svg', dpi=300)

## Plot selection through time

In [None]:
median_long.groupby(['Sample attributes','Replicate','mutation_type','Compared timepoints'])[['s']].describe()

In [None]:
g = sns.relplot(data=median_long, x='Compared timepoints', y='s',
                col='Sample attributes', col_wrap=3,
                hue='mutation_type', palette='hls',
                style='Replicate',
                kind='line', markers=True, errorbar='sd',
                height=1.5)
g.set(xlabel='')
g.set_titles(col_template='{col_name}')
plt.savefig(s_through_time_plot, format='svg', dpi=300)

## Show correlation between first two replicates

In [None]:
# One s column per replicate, to compare replicates variant by variant
graphdf = median_long.pivot(index=mutation_attributes_aa+['Sample attributes','Compared timepoints'],
                            columns='Replicate',
                            values='s').reset_index()
firstTwoReplicates = median_long.Replicate.unique()[:2]

if len(firstTwoReplicates) < 2:
    # Nothing to correlate, but the Snakemake output file must still be written
    fig, ax = plt.subplots(figsize=(4, 2))
    ax.axis('off')
    ax.text(0.5, 0.5, 'Fewer than two replicates: no replicate correlation to plot',
            ha='center', va='center')
    plt.savefig(replicates_plot, format='svg', dpi=300)
else:
    g = sns.lmplot(graphdf, x=firstTwoReplicates[0], y=firstTwoReplicates[1],
                   col='Sample attributes', col_wrap=3,
                   hue='Compared timepoints', palette='mako',
                   height=1.5, scatter_kws={'s':8,'alpha':.2})
    g.set_titles(col_template='{col_name}')
    plt.savefig(replicates_plot, format='svg', dpi=300)