# Notebook to parse fasta files and merge sequences with layout

## Import libraries

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['svg.fonttype'] = 'none'
import numpy as np

## Specify paths

In [None]:
# Every file location comes from the `snakemake` object, which this notebook expects to exist already
_smk_in = snakemake.input
_smk_out = snakemake.output

### Input
fasta_files = _smk_in.fasta_files        # iterated over, one FASTA file per element, in the parsing cell
layout_path = _smk_in.expected_mutants   # read with pd.read_csv as the layout of expected variants
stats_path = _smk_in.read_stats          # read with pd.read_csv for the read filtering overview

### Output
# Dataframes
rcdf_path = _smk_out.read_counts
# Plots (the SVG paths; other formats are derived from them when saving)
unexp_rc_plot = _smk_out.unexp_rc_plot
rc_filter_plot = _smk_out.rc_filter_plot

## Specify parameters

In [None]:
# Config entries are read directly from `snakemake.config` instead of going through the
# `params` statement of the rule: a bug left the `params` attribute unavailable even though
# the snakemake object itself was imported.
rc_aims_cfg = snakemake.config['rc_aims']
exp_rc_per_sample = float(rc_aims_cfg['exp_rc_per_sample'])

# The plotting cells always write an SVG explicitly, so 'svg' is dropped from the extra formats
plot_formats = [fmt for fmt in snakemake.config['plots']['format'] if fmt != 'svg']

## Parse fasta files (with read count)

In [None]:
# Parse every aggregated FASTA file into one long table: one row per (sample, sequence),
# with the read count taken from the `size=` annotation of the record header.
seq_l = []

for f in fasta_files:

    # Note: the double split extracts the file (base) name regardless of the platform's separator
    # Should still work when run with snakemake (no paths)
    sample_name = f.split('/')[-1].split('\\')[-1].split('_aggregated.fasta')[0]

    with open(f, 'r') as file:
        # Whatever precedes the first '>' is not a record, hence [1:]
        entries = file.read().split('>')[1:]

    readcount = []
    seqs = []
    for entry in entries:
        # First line is the header, the remaining lines are the (possibly wrapped) sequence
        lines = entry.split('\n')
        header = lines[0]
        # Stop at ';' so that headers such as "id;size=12;" or "id;size=12;ee=0.1"
        # parse as well as a bare trailing "size=12"
        readcount.append(int(header.split('size=')[1].split(';')[0]))
        seqs.append(''.join(lines[1:]))

    fasta_df = pd.DataFrame({'nt_seq': seqs, 'readcount': readcount})
    fasta_df['Sample_name'] = sample_name
    seq_l.append(fasta_df)

master_seq = pd.concat(seq_l, ignore_index=True)
master_seq

## Compare with expected variants and annotate

In [None]:
# Column types that must not be left to inference:
#   - WT: nullable 'boolean' type, which supports missing data
#   - pos / aa_pos: forced to str because they contain mixed types (because of the nucleotide WT),
#     which is very important when pivoting later on
layout_dtypes = {'WT': 'boolean', 'pos': str, 'aa_pos': str}
master_layout = pd.read_csv(layout_path, index_col=0, dtype=layout_dtypes)
master_layout

In [None]:
# Outer join of the expected variants (left) with the observed sequences (right), matched on
# sample and nucleotide sequence. The indicator column 'Location' tells where each row was found:
# 'left_only' (layout only), 'right_only' (reads only) or 'both'.
comparedf = master_layout.merge(master_seq, how='outer', on=['Sample_name', 'nt_seq'], indicator='Location')

# Rows present on both sides are the ones written out as the read count table
found_in_both = comparedf['Location'] == 'both'
covered_df = comparedf.loc[found_in_both]
covered_df.to_csv(rcdf_path)

## Coverage of expected variants (ratio of unique expected sequences) at T0

In [None]:
# Number of covered rows (expected sequences that were also found in the reads) per T0 sample
is_t0 = covered_df['Timepoint'] == 'T0'
expected_df = covered_df.loc[is_t0].groupby('Sample_name').size().reset_index(name='unique_seq_variants')

# Number of rows listed in the layout for those same samples.
# Assigned positionally: both groupby results are sorted by Sample_name, so the rows line up.
in_t0_samples = master_layout['Sample_name'].isin(expected_df['Sample_name'].unique())
expected_df['unique_expected_variants'] = master_layout.loc[in_t0_samples].groupby('Sample_name').size().to_numpy()

# Stored as a fraction (no multiplication by 100), despite the '%' in the column name
expected_df['unique_variants_%'] = expected_df['unique_seq_variants'].div(expected_df['unique_expected_variants'])
expected_df

## Read count of unexpected variants

In [None]:
# Density of the read counts of sequences found only in the reads ('right_only'),
# one curve per sample, each normalised on its own (common_norm=False), on a log-scaled x axis
sns.kdeplot(data=comparedf[comparedf.Location == 'right_only'], x='readcount',
            hue='Sample_name', common_norm=False, log_scale=True,
            legend=False
           )
plt.xlabel('Read count of unexpected variants')
plt.savefig(unexp_rc_plot, format='svg', dpi=300)
# Save the same figure in every additional format. rsplit(..., 1) swaps only the last '.svg'
# of the path, and a plain loop avoids echoing a list of None as the cell output.
for fmt in plot_formats:
    plt.savefig(f"{unexp_rc_plot.rsplit('.svg', 1)[0]}.{fmt}", format=fmt, dpi=300)
plt.show()

In [None]:
# Reads carried by sequences that are absent from the layout, summed per sample
is_unexpected = comparedf['Location'] == 'right_only'
unexpected_df = comparedf.loc[is_unexpected].groupby('Sample_name')[['readcount']].sum()

# All parsed reads per sample (aligned on the Sample_name index when assigned)
total_rc = master_seq.groupby('Sample_name')['readcount'].sum()
unexpected_df['total_rc'] = total_rc

# Stored as a fraction of the sample's parsed reads, despite the '%' in the column name
unexpected_df['%rc_unexp'] = unexpected_df['readcount'].div(unexpected_df['total_rc'])

# Show the three samples with the largest share of unexpected reads
unexpected_df.sort_values(by=['%rc_unexp','Sample_name'], ascending=[False,True]).head(3)

The following command retrieves the first unexpected sequence of a given sample (here `CN_a_r2_F2_T0`), which can then be inspected to understand why it is unexpected:
```
comparedf[(comparedf.Location == 'right_only') & (comparedf.Sample_name == 'CN_a_r2_F2_T0')].iloc[0].nt_seq
```

## Show overall read filtering steps

In [None]:
# Read totals at each processing step, one row per sample
read_cols = ['Total_raw_reads','Total_trimmed_reads','Total_merged_reads','Nb_singletons']
stats = pd.read_csv(stats_path, index_col=0)[read_cols]
# In the following steps, the column names refer to the total number of reads lost at the specified step
stats['Trimming'] = stats['Total_raw_reads'] - stats['Total_trimmed_reads']
stats['Merging'] =  stats['Total_trimmed_reads'] - stats['Total_merged_reads']
stats['Aggregating'] = stats['Nb_singletons']
stacked_data = pd.concat([stats,
                          unexpected_df[['readcount']].rename(columns={'readcount':'Unexpected'}),
                         ], axis=1)
# A sample without any unexpected variant has no row in unexpected_df and would end up as NaN
# after the concat: it lost 0 reads at this step
stacked_data['Unexpected'] = stacked_data['Unexpected'].fillna(0)
# Reads remaining once every loss has been subtracted from the raw total
stacked_data['OK'] = stacked_data['Total_raw_reads'] - stacked_data[['Trimming','Merging','Aggregating','Unexpected']].sum(axis=1)
# Keep only the per-step columns, ordered by sample (reassignment instead of in-place mutation)
stacked_data = stacked_data.drop(columns=read_cols).sort_index()
stacked_data

In [None]:
# Stacked bar chart: for every sample, the reads that passed ('OK') and the reads lost at each step
samples = stacked_data.index.to_list()
width = .5
# Insertion order = stacking order (bottom to top); values = bar colours
color_dict = {'OK':'green',
              'Trimming':'gold',
              'Merging':'orange',
              'Aggregating':'red',
              'Unexpected':'grey'
             }

fig, ax = plt.subplots(figsize=(20,5))
bottom = np.zeros(len(stacked_data))

for category, color in color_dict.items():
    heights = stacked_data[category].values
    ax.bar(samples, heights, width, label=category, bottom=bottom, color=color)
    # Next category is stacked on top of what has been drawn so far
    bottom += heights

ax.set_yscale('log', base=10)
ax.set(ylim=(1e4,1e7), ylabel='Read count')

# Dashed line at the targeted read count per sample
ax.axhline(y=exp_rc_per_sample, linestyle='--', color='.8')
# x is given as an axes fraction (left edge) and y in data units: an annotation anchored at a
# data x outside the axes is silently not drawn, which made the label depend on the sample count
ax.annotate('Aim', xy=(0.005, 1.1*exp_rc_per_sample), xycoords=('axes fraction', 'data'), color='.5')

ax.xaxis.set_ticks(samples)
ax.set_xticklabels(samples, rotation=45, ha='right')
ax.legend(framealpha=.9)

plt.tight_layout()
plt.savefig(rc_filter_plot, format='svg', dpi=300)
# Save the same figure in every additional format. rsplit(..., 1) swaps only the last '.svg'
# of the path, and a plain loop avoids building a throwaway list of None.
for fmt in plot_formats:
    plt.savefig(f"{rc_filter_plot.rsplit('.svg', 1)[0]}.{fmt}", format=fmt, dpi=300)
plt.show()