# Transcription-Progeny Correlation
This notebook plots the correlation between viral transcription and progeny production for each valid viral barcode in each infected cell.

Import Python modules:

In [None]:
from IPython.display import display

from dms_variants.constants import CBPALETTE

import pandas as pd

import plotnine as p9

Get input paths, output paths, parameters, and wildcards from the `snakemake` object:

In [None]:
# Input CSV paths, read from the `snakemake.input` namespace.
cell_annotations_csv = snakemake.input.cell_annotations
viral_barcodes_valid_csv = snakemake.input.viral_barcodes_valid_csv
filtered_progeny_viral_bc_csv = snakemake.input.filtered_progeny_viral_bc_csv

# Rule parameters and wildcards.
barcoded_viral_genes = snakemake.params.barcoded_viral_genes
expt = snakemake.wildcards.expt

# Output paths, read from the `snakemake.output` namespace.
plot = snakemake.output.plot
transcription_progeny_csv = snakemake.output.transcription_progeny_csv

Style parameters. *N.b.* `CBPALETTE` is defined in imports above.

In [None]:
# Make plotnine's classic theme the default for the plots drawn below.
p9.theme_set(p9.theme_classic())

### Load data
The following data will be loaded:
* list of cell barcodes and viral tags
* viral barcode transcription measurements
* viral progeny measurements

Load cell barcodes and viral tags:

In [None]:
# Read the per-cell annotations and keep only the columns listed here,
# in this order.
cell_barcodes = pd.read_csv(cell_annotations_csv).loc[
    :,
    [
        'cell_barcode',
        'infected',
        'infecting_viral_tag',
        'total_UMIs',
        'viral_UMIs',
        'frac_viral_UMIs',
    ],
]
display(cell_barcodes)

Load viral barcode transcription measurements:

In [None]:
# Read the viral barcode table and discard its `valid_viral_bc` column.
transcriptome_viral_barcodes = (
    pd.read_csv(viral_barcodes_valid_csv)
    .drop(columns=['valid_viral_bc'])
)

# The set of genes present must equal the set of genes expected to carry
# barcodes.
assert (
    set(transcriptome_viral_barcodes['gene']) == set(barcoded_viral_genes)
), "Barcoded genes in barcode counts do not match expectation."

display(transcriptome_viral_barcodes)

Load viral progeny measurements:

In [None]:
# Read the progeny barcode table, drop the `Unnamed: 0` column
# (presumably a row index written when the CSV was saved -- confirm), and
# rename columns so they line up with the names used for the merge below.
progeny_viral_barcodes = (
    pd.read_csv(filtered_progeny_viral_bc_csv)
    .drop(columns=['Unnamed: 0'])
    .rename(
        columns={
            'tag': 'infecting_viral_tag',
            'average_freq': 'progeny_freq',
        }
    )
)

# The set of genes present must equal the set of genes expected to carry
# barcodes.
assert (
    set(progeny_viral_barcodes['gene']) == set(barcoded_viral_genes)
), "Barcoded genes in barcode counts do not match expectation."

# Distinct values of the `source` column, in order of first appearance.
progeny_sources = list(progeny_viral_barcodes['source'].unique())

display(progeny_viral_barcodes)

## Integrate data sources.

First, make a copy of each cell barcode for each barcoded viral gene and each progeny source:

In [None]:
# Stack one labelled copy of the cell table per (source, gene) pair.
# Sources vary slowest and genes vary within each source; each copy keeps
# the original row index of `cell_barcodes`.
viral_barcode_freqs = pd.concat(
    [
        cell_barcodes.assign(gene=gene, source=source)
        for source in progeny_sources
        for gene in barcoded_viral_genes
    ]
)

expected_rows = (
    len(cell_barcodes) * len(barcoded_viral_genes) * len(progeny_sources)
)
assert len(viral_barcode_freqs) == expected_rows, \
    "Need one copy of each cell barcode for each data source"

display(viral_barcode_freqs)

Merge viral barcode frequencies from the transcriptome:

In [None]:
# Left join: every (cell, gene, source) row is kept, and the transcriptome
# columns are attached wherever `cell_barcode` and `gene` match.
viral_barcode_freqs = viral_barcode_freqs.merge(
    transcriptome_viral_barcodes,
    how='left',
    on=['cell_barcode', 'gene'],
)
display(viral_barcode_freqs)

Merge viral barcode frequencies from the progeny:

In [None]:
# Left join: rows without a matching progeny record are kept and get
# missing values in the progeny columns.
progeny_merge_keys = ['source', 'infecting_viral_tag', 'gene', 'viral_barcode']
viral_barcode_freqs = viral_barcode_freqs.merge(
    progeny_viral_barcodes,
    how='left',
    on=progeny_merge_keys,
)
display(viral_barcode_freqs)

## Plot correlation between viral transcription and progeny production:
Only look at infected cells that are not doublets:

In [None]:
# Keep rows whose `infected` value is "infected" and whose viral tag is
# anything other than "both".
is_infected = viral_barcode_freqs['infected'] == "infected"
is_not_both_tags = viral_barcode_freqs['infecting_viral_tag'] != "both"
infected_cells = viral_barcode_freqs[is_infected & is_not_both_tags]
display(infected_cells)

Plot the correlation between the frequency of individual viral barcodes in the transcriptome and in the progeny:

In [None]:
# Scatter plot of `frac_viral_bc_UMIs` (x) against `progeny_freq` (y) on
# log-log axes, faceted with one row per progeny source and one column per
# barcoded gene.
viral_barcode_correlation = (
    p9.ggplot(infected_cells,
              p9.aes(x='frac_viral_bc_UMIs',
                     y='progeny_freq')) +
    p9.geom_point(alpha=0.3) +
    # Dashed horizontal reference line at a progeny frequency of 1e-5.
    p9.geom_hline(yintercept=1e-5, linetype='dashed', color=CBPALETTE[2]) +
    p9.facet_grid('source~gene') +
    p9.ggtitle('viral barcode transcription and progeny production\n'
               f'{expt}') +
    p9.scale_x_log10() +
    p9.scale_y_log10() +
    p9.labs(x='viral barcode fraction of total UMIs in cell',
            y='fraction of total reads in progeny') +
    # Only x and y are mapped, so no legend is drawn and no legend
    # themeables are set.
    p9.theme(figure_size=(6, 6),
             plot_title=p9.element_text(size=9),
             axis_title=p9.element_text(size=9)))
display(viral_barcode_correlation)

Plot the correlation between a cell's total viral burden and progeny production of each infecting viral barcode:

In [None]:
# Scatter plot of `frac_viral_UMIs` (x) against `progeny_freq` (y) on
# log-log axes, faceted with one row per progeny source and one column per
# barcoded gene.
viral_burden_correlation = (
    p9.ggplot(infected_cells,
              p9.aes(x='frac_viral_UMIs',
                     y='progeny_freq')) +
    p9.geom_point(alpha=0.3) +
    # Dashed horizontal reference line at a progeny frequency of 1e-5.
    p9.geom_hline(yintercept=1e-5, linetype='dashed', color=CBPALETTE[2]) +
    p9.facet_grid('source~gene') +
    p9.ggtitle('viral burden and progeny production\n'
               f'{expt}') +
    p9.scale_x_log10() +
    p9.scale_y_log10() +
    p9.labs(x='fraction of viral UMIs in cell',
            y='fraction of viral barcode reads in progeny') +
    # Only x and y are mapped, so no legend is drawn and no legend
    # themeables are set.
    p9.theme(figure_size=(6, 6),
             plot_title=p9.element_text(size=9),
             axis_title=p9.element_text(size=9)))
display(viral_burden_correlation)

## Output
Save the correlation plot and the CSV containing info about viral progeny from infected cells:

In [None]:
# Write the viral-barcode correlation figure to the output path; the
# viral-burden figure is displayed above but not saved here.
print(f"Saving plot to {plot}")
viral_barcode_correlation.save(filename=plot, verbose=False)

In [None]:
# Write the filtered per-barcode table to CSV without the row index.
infected_cells.to_csv(path_or_buf=transcription_progeny_csv, index=False)