# Process plate counts to get ratios of variants and determine pooling and MOI
This notebook is designed to be run interactively to determine the relative concentrations of strains in the equal-volume and repooled samples.

## Setup
Import Python modules:

In [None]:
import pickle
import sys

import altair as alt

import matplotlib.pyplot as plt

import numpy

import pandas as pd
from os.path import join
import os
import ruamel.yaml as yaml

# Disable Altair's max-rows limit so the charts below can embed the full data frames;
# assigning the return value to `_` keeps it from being shown as cell output.
_ = alt.data_transformers.disable_max_rows()

Get the variables passed by `snakemake`:

In [None]:
# Inputs handed to this notebook by Snakemake: barcode/strain tables, the
# initial-pool sample sheet, the per-well count and fate CSVs, and the counts
# CSV of the repooled library.
sm_in = snakemake.input
viral_library_csv = sm_in.viral_library
neut_standard_set_csv = sm_in.neut_standard_set
initialpool_metadata = sm_in.initialpool_metadata
initial_pool_counts = sm_in.initial_pool_counts
initial_pool_fates = sm_in.initial_pool_fates
repooledlibraryfile = sm_in.repooledlibraryfile

# Outputs: the strain repooling-volume CSV and the two strain-balance plots.
sm_out = snakemake.output
strainrepooling_volumes = sm_out.strainpooling
equalvolume_strainbalance_plot = sm_out.equalvolume_plot
repool_strainbalance_plot = sm_out.repool_plot

In [None]:
# Load the initial-pool sample sheet (minus the FASTQ paths) and label every
# row with a "sample" string made by joining all remaining fields with "-".
barcode_runs_df = pd.read_csv(initialpool_metadata).drop(columns=["fastq"])
barcode_runs_df["sample"] = barcode_runs_df.apply(
    lambda row: "-".join(row.astype(str)), axis=1
)

samples = list(barcode_runs_df["sample"].unique())
print(f"There are {len(samples)} barcode runs.")

# Aliases used by the parsing and QC cells below.
count_csvs, fate_csvs = initial_pool_counts, initial_pool_fates
samples_df = barcode_runs_df

## Statistics on barcode-parsing for each sample
Make an interactive chart of the "fates" of the sequencing reads parsed for each sample on the plate.

If most sequencing reads are not "valid barcodes", this could potentially indicate some problem in the sequencing or barcode set you are parsing.

Potential fates are:
 - *valid barcode*: barcode that matches a known virus or neutralization standard, we hope most reads are this.
 - *invalid barcode*: a barcode with proper flanking sequences that does not match a known virus or neutralization standard. If you have a lot of reads of this type, it is probably a good idea to look at the invalid barcode CSVs (in the `./results/barcode_invalid/` subdirectory created by the pipeline) to see what these invalid barcodes are.
 - *unparseable barcode*: could not parse a barcode from this read as there was not a sequence of the correct length with the appropriate flanking sequence.
 - *low quality barcode*: low-quality or `N` nucleotides in barcode, could indicate problem with sequencing.
 - *failed chastity filter*: reads that failed the Illumina chastity filter, if these are reported in the FASTQ (they may not be).

Also, if the number of reads per sample is very uneven, that could indicate that you did not do a good job of balancing the different samples in the Illumina sequencing.

In [None]:
# Read each well's fate CSV and tag it with its well, which is the file's
# basename with the "_fates.csv" suffix removed. Do not use
# `str.strip("_fates.csv")` here: strip treats its argument as a *set of
# characters* and can also remove letters that belong to the well name.
fates_suffix = "_fates.csv"
fate_frames = []
for fate_csv in fate_csvs:
    fate_basename = os.path.basename(fate_csv)
    fate_well = (
        fate_basename[: -len(fates_suffix)]
        if fate_basename.endswith(fates_suffix)
        else fate_basename
    )
    fate_frames.append(pd.read_csv(fate_csv).assign(well=fate_well))

fates = (
    pd.concat(fate_frames)
    .merge(samples_df, validate="many_to_one", on="well")
    .assign(
        fate_counts=lambda x: x.groupby("fate")["count"].transform("sum"),
        sample_well=lambda x: x["sample"] + " (" + x["well"] + ")",
    )
    .query("fate_counts > 0")[  # only keep fates with at least one count
        ["fate", "count", "well", "sample_well", "dilution_factor"]
    ]
)

assert len(fates) == len(fates.drop_duplicates())

# Sample labels ordered by dilution factor; reused as the y-axis order of the
# per-well charts. `fates` has one row per (well, fate), so de-duplicate.
sample_wells = (
    fates.sort_values(["dilution_factor"], kind="stable")["sample_well"]
    .drop_duplicates()
    .tolist()
)

fates_chart = (
    alt.Chart(fates)
    .encode(
        alt.X("count", scale=alt.Scale(nice=False, padding=3)),
        alt.Y(
            "sample_well",
            title=None,
            sort=sample_wells,
        ),
        alt.Color("fate", sort=sorted(fates["fate"].unique(), reverse=True)),
        alt.Order("fate", sort="descending"),
        tooltip=fates.columns.tolist(),
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=200,
        title="Barcode parsing for initial titering plate",
    )
    .configure_axis(grid=False)
)

fates_chart

## Read barcode counts
Read the counts per barcode:

In [None]:
# Read each well's barcode-count CSV and tag it with its well, which is the
# file's basename with the "_counts.csv" suffix removed. Do not use
# `str.strip("_counts.csv")`: it strips any of those *characters* from both
# ends and can eat letters belonging to the well name.
counts_suffix = "_counts.csv"
count_frames = []
for count_csv in count_csvs:
    count_basename = os.path.basename(count_csv)
    count_well = (
        count_basename[: -len(counts_suffix)]
        if count_basename.endswith(counts_suffix)
        else count_basename
    )
    count_frames.append(pd.read_csv(count_csv).assign(well=count_well))

counts = (
    pd.concat(count_frames)
    .merge(samples_df, validate="many_to_one", on="well")
    .drop(columns=["replicate"])
    .assign(sample_well=lambda x: x["sample"] + " (" + x["well"] + ")")
)

# Classify barcodes as viral (with a strain) or neutralization standard (strain = NA).
barcode_class = pd.concat(
    [
        pd.read_csv(viral_library_csv)[["barcode", "strain"]].assign(
            neut_standard=False,
        ),
        pd.read_csv(neut_standard_set_csv)[["barcode"]].assign(
            neut_standard=True,
            strain=pd.NA,
        ),
    ],
    ignore_index=True,
)

# The counted barcodes must be exactly the known barcodes.
assert set(counts["barcode"]) == set(barcode_class["barcode"])
counts = counts.merge(barcode_class, on="barcode", validate="many_to_one")
# The count files must cover exactly the same sample/well labels as the fate files.
assert set(sample_wells) == set(counts["sample_well"])

No manually specified data drops are applied in this notebook.

## Average counts per barcode in each well

Plot average counts per barcode.
If a sample has inadequate barcode counts, it may not have good enough statistics for accurate analysis, so a QC threshold is applied (wells must average at least 500 counts per barcode):

In [None]:
# QC threshold: wells whose mean count per barcode is below this value fail QC.
min_avg_barcode_count = 500

avg_barcode_counts = (
    counts.groupby(
        ["well", "sample_well"],
        dropna=False,
        as_index=False,
    )
    .aggregate(avg_count=pd.NamedAgg("count", "mean"))
    .assign(fails_qc=lambda x: x["avg_count"] < min_avg_barcode_count)
)

avg_barcode_counts_chart = (
    alt.Chart(avg_barcode_counts)
    .encode(
        alt.X(
            "avg_count",
            title="average barcode counts per well",
            scale=alt.Scale(nice=False, padding=3),
        ),
        alt.Y("sample_well", sort=sample_wells),
        alt.Color(
            "fails_qc",
            # renders as e.g. "fails min_avg_barcode_count=500"
            title=f"fails {min_avg_barcode_count=}",
            legend=alt.Legend(titleLimit=500),
        ),
        tooltip=[
            alt.Tooltip(c, format=".3g") if avg_barcode_counts[c].dtype == float else c
            for c in avg_barcode_counts.columns
        ],
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=250,
        title="Average barcode counts per well for titering plate",
    )
    .configure_axis(grid=False)
)

display(avg_barcode_counts_chart)

# Wells failing the average-count QC.
# NOTE(review): no later cell shown here removes these wells from `counts`.
avg_barcode_counts_per_well_drops = list(avg_barcode_counts.query("fails_qc")["well"])

## Fraction of counts from neutralization standard
Determine the fraction of counts from the neutralization standard in each sample, and make sure this fraction passes the QC threshold (at least 0.001 of the counts in each well).

In [None]:
# QC threshold: wells whose neutralization-standard fraction is below this fail QC.
min_neut_standard_frac_per_well = 0.001

neut_standard_fracs = (
    counts.assign(
        # count contributed by neutralization-standard barcodes (0 for viral barcodes)
        neut_standard_count=lambda x: x["count"] * x["neut_standard"].astype(int)
    )
    .groupby(
        ["well", "sample_well"],
        dropna=False,
        as_index=False,
    )
    .aggregate(
        total_count=pd.NamedAgg("count", "sum"),
        neut_standard_count=pd.NamedAgg("neut_standard_count", "sum"),
    )
    .assign(
        neut_standard_frac=lambda x: x["neut_standard_count"] / x["total_count"],
        fails_qc=lambda x: x["neut_standard_frac"] < min_neut_standard_frac_per_well,
    )
)

neut_standard_fracs_chart = (
    alt.Chart(neut_standard_fracs)
    .encode(
        alt.X(
            "neut_standard_frac",
            title="frac counts from neutralization standard per well",
            scale=alt.Scale(nice=False, padding=3),
        ),
        alt.Y("sample_well", sort=sample_wells),
        alt.Color(
            "fails_qc",
            # renders as e.g. "fails min_neut_standard_frac_per_well=0.001"
            title=f"fails {min_neut_standard_frac_per_well=}",
            legend=alt.Legend(titleLimit=500),
        ),
        tooltip=[
            alt.Tooltip(c, format=".3g") if neut_standard_fracs[c].dtype == float else c
            for c in neut_standard_fracs.columns
        ],
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=250,
        title="Neutralization-standard fracs per well for titering plate",
    )
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
)

display(neut_standard_fracs_chart)


In [None]:
# One row per barcode per sample/well. `strain` is a grouping key and groupby
# drops NA keys by default, so neutralization-standard barcodes (strain = NA)
# are excluded here. Non-key columns are summed; this presumably leaves them
# unchanged because each group is a single row — TODO confirm.
barcode_keys = ['sample', 'strain', 'dilution_factor', 'serum', 'well', 'barcode']
counts_balancedbarcode = (
    counts
    .groupby(barcode_keys)
    .sum()
    .reset_index()
    .drop(columns=['neut_standard'])
)
counts_balancedbarcode

## Fraction of counts from each barcode
Determine the fraction of counts from each barcode in each well of the plate.

In [None]:
# First total the barcode counts within each well: group by the well-level keys,
# sum, and drop the summed-over columns that are not meaningful totals.
well_keys = ['sample', 'sample_well', 'dilution_factor', 'serum', 'well']
sumperwell = (
    counts_balancedbarcode
    .groupby(well_keys)
    .sum()
    .drop(columns=['strain', 'barcode', 'date'])
    .reset_index()
)
sumperwell

In [None]:
# Attach each well's total barcode count to every barcode row in that well,
# then compute the fraction of the well's counts contributed by each barcode.
sumperwell = sumperwell.rename(columns={'count': 'counts_perwell'})
per_well_keys = ['sample', 'sample_well', 'dilution_factor', 'serum', 'well']
counts_balancedbarcode = (
    counts_balancedbarcode
    # drop columns added by an earlier run of this cell so re-running is idempotent
    .drop(columns=['counts_perwell', 'fraction_barcode'], errors='ignore')
    # merge only the per-well total so other summed columns are not duplicated (_x/_y)
    .merge(sumperwell[per_well_keys + ['counts_perwell']], on=per_well_keys)
)
counts_balancedbarcode['fraction_barcode'] = (
    counts_balancedbarcode['count'] / counts_balancedbarcode['counts_perwell']
)
counts_balancedbarcode

In [None]:
# Pick the reference well(s) where, based on prior experiments, vRNA counts are
# expected to scale linearly with titer (samples whose label contains "-A6-",
# where the neutralization-standard fraction is reasonable). Then total the
# barcode fractions per strain to get each strain's share of the well.
in_reference_well = counts_balancedbarcode['sample'].str.contains('-A6-')
selected_well = counts_balancedbarcode.loc[in_reference_well]
sum_barcodes_bystrain = (
    selected_well
    .groupby(['strain', 'well'])['fraction_barcode']
    .sum()
    .rename('fraction_strain_perwell')
    .reset_index()
)

In [None]:
# Merge each strain's share of the reference well back onto its barcode rows.
mean_single_well = selected_well.merge(sum_barcodes_bystrain, on=['strain', 'well'], how='left')

# Relative amount of each strain to add so that every strain reaches an equal
# share (1 / num_strains) of the pool. A strain with zero counts yields an
# infinite ratio.
num_strains = 36
n_strains_observed = mean_single_well['strain'].nunique()
assert n_strains_observed == num_strains, (
    f"expected {num_strains} strains in the reference well, found {n_strains_observed}"
)
mean_single_well['ratio_to_add'] = (1 / num_strains) / mean_single_well['fraction_strain_perwell']

mean_single_well

In [None]:
# Horizontal bars: each strain's share of the barcode counts in the reference well.
strain_fractions = mean_single_well[['strain', 'fraction_strain_perwell']].drop_duplicates()
strain_fractions.plot.barh(
    x="strain",
    y="fraction_strain_perwell",
    figsize=(6, 10),
    log=False,
    xlim=(0, 0.1),
)

In [None]:
# One row per strain: its observed share and the relative amount to add when repooling.
ratio_columns = ['strain', 'fraction_strain_perwell', 'ratio_to_add']
initial_pool_ratios = mean_single_well[ratio_columns].drop_duplicates()
initial_pool_ratios.plot.barh(x="strain", y="ratio_to_add", figsize=(6, 14))

In [None]:
# Volume of each strain to add when repooling: its relative ratio scaled by a
# base volume. NOTE(review): the unit of the base volume is not stated — confirm.
repool_base_volume = 200
initial_pool_ratios['vol_to_add'] = initial_pool_ratios['ratio_to_add'] * repool_base_volume

In [None]:
# Write each strain's repooling volume to the output CSV (index included, as before).
repool_volumes_table = initial_pool_ratios[['strain', 'vol_to_add']]
repool_volumes_table.to_csv(strainrepooling_volumes)

In [None]:
# Per-barcode representation in the reference well; bars are labeled with the
# strain plus the first 16 characters of the barcode.
barcode_prefix = mean_single_well['barcode'].str[:16]
mean_single_well['strain_barcode'] = mean_single_well['strain'] + "_" + barcode_prefix
mean_single_well.plot.barh(
    x="strain_barcode",
    y="fraction_barcode",
    figsize=(6, 40),
    log=False,
)

In [None]:
# For each strain, the fraction of that strain's counts contributed by each of its barcodes.
assess_barcodebalancing = mean_single_well[
    ['strain', 'barcode', 'strain_barcode', 'fraction_barcode', 'count', 'counts_perwell']
]
assess_barcodebalancing_bystrain = (
    assess_barcodebalancing
    .groupby(['strain'])['count']
    .sum()
    .rename('count_perstrain_perwell')
    .reset_index()
)
assess_barcodebalancing = assess_barcodebalancing.merge(assess_barcodebalancing_bystrain, on="strain")
assess_barcodebalancing['fraction_of_strain_barcode'] = (
    assess_barcodebalancing['count'] / assess_barcodebalancing['count_perstrain_perwell']
)
assess_barcodebalancing

In [None]:
# Stacked bars: for each strain, the share of its counts coming from each barcode.
# Strains are listed alphabetically. Barcodes are ordered by (strain, barcode), so
# barcodes of one strain are adjacent in the color domain and the 4-color palette
# gives them distinct colors (assuming, presumably, at most 4 barcodes per strain).
strain_order = sorted(assess_barcodebalancing["strain"].unique())
barcode_order = (
    assess_barcodebalancing.sort_values(["strain", "barcode"])["barcode"]
    .unique()
    .tolist()
)

barcode_balance_chart = (
    alt.Chart(assess_barcodebalancing)
    .encode(
        alt.X("fraction_of_strain_barcode", scale=alt.Scale(nice=False, padding=3)),
        alt.Y(
            "strain",
            title=None,
            sort=strain_order,
        ),
        alt.Color("barcode", sort=barcode_order).scale(
            range=['steelblue', 'goldenrod', 'firebrick', 'rebeccapurple']
        ),
    )
    .mark_bar(height={"band": 0.75})
    .properties(
        height=alt.Step(20),
        width=250,
    )
    .configure_axis(grid=False, labelFontSize=15, titleFontSize=18, labelLimit=300)
)

barcode_balance_chart


# Confirm that the ratio of the strains represented is similar after repooling

In [None]:
# Load the barcode counts of the repooled library, to check how well repooling
# (based on the balancing from plate 1) evened out the strains.
counts_repooled = pd.read_csv(repooledlibraryfile)

In [None]:
# Attach strain names and keep only viral (non-neutralization-standard) barcodes.
counts_repooled_df = counts_repooled.merge(barcode_class, on="barcode", validate="many_to_one")
is_viral_barcode = counts_repooled_df['neut_standard'].eq(False)
counts_repooled_df = counts_repooled_df.loc[is_viral_barcode].drop(columns=['neut_standard'])
counts_repooled_df

In [None]:
# Total the numeric columns (e.g. counts) per strain in the repooled sample.
counts_repooled_df_bystrain = (
    counts_repooled_df
    .groupby(['strain'])
    .sum(numeric_only=True)
    .reset_index()
)

In [None]:
# Each strain's fraction of all viral-barcode counts in the repooled sample.
total_repooled_counts = counts_repooled_df_bystrain['count'].sum()
counts_repooled_df_bystrain['counts_well'] = total_repooled_counts
counts_repooled_df_bystrain['fraction'] = (
    counts_repooled_df_bystrain['count'] / counts_repooled_df_bystrain['counts_well']
)

In [None]:
# Each strain's fraction of the repooled sample's viral-barcode counts.
counts_repooled_df_bystrain.plot.barh(
    x="strain",
    y="fraction",
    figsize=(6, 10),
    xlim=(0, 0.1),
)

In [None]:
# Express representation relative to the desired equal share (1 / num_strains)
# rather than as a raw fraction; 1.0 means perfectly balanced. Reuses
# `num_strains` from the ratio-to-add cell instead of repeating the constant.
counts_repooled_df_bystrain['ratio_to_desired_fraction'] = (
    counts_repooled_df_bystrain['fraction'] / (1 / num_strains)
)

In [None]:
# Same ratio to the desired equal share (1 / num_strains) for the original
# equal-volume pool; reuses `num_strains` instead of repeating the constant.
initial_pool_ratios['ratio_to_desired_fraction'] = (
    initial_pool_ratios['fraction_strain_perwell'] / (1 / num_strains)
)

In [None]:
# Representation relative to an equal share, plotted as two figures: first the
# repooled sample, then the original equal-volume pool.
ratio_plot_kwargs = dict(x="strain", y="ratio_to_desired_fraction", xlim=(0, 3), figsize=(4, 10))
counts_repooled_df_bystrain.plot.barh(**ratio_plot_kwargs)
initial_pool_ratios.plot.barh(**ratio_plot_kwargs)

In [None]:
# Plot and save each strain's count fraction: first the repooled sample, then
# the original equal-volume pool. Each figure is saved from its own Axes handle.
repool_ax = counts_repooled_df_bystrain.plot.barh(
    x="strain", y="fraction", xlim=(0, 0.1), figsize=(4, 10)
)
repool_ax.figure.savefig(repool_strainbalance_plot, dpi='figure', bbox_inches='tight')

equalvolume_ax = initial_pool_ratios.plot.barh(
    x="strain", y="fraction_strain_perwell", xlim=(0, 0.1), figsize=(4, 10)
)
equalvolume_ax.figure.savefig(equalvolume_strainbalance_plot, dpi='figure', bbox_inches='tight')

# Determine concentration of library that should be used for neutralization assay

In [None]:
# Determine where the vRNA signal is in the linear range with respect to sample
# dilution: take the samples whose label contains "230417" and plot the
# neutralization-standard fraction against the TCID50 parsed from the label.
moi_test_mask = neut_standard_fracs['sample_well'].str.contains('230417')
# .copy() so adding a column below modifies an independent frame, not a view
MOItestsamples = neut_standard_fracs.loc[moi_test_mask].copy()
# Assumes the 4th "-"-separated field looks like "<prefix>_<TCID50><one trailing char>"
# — TODO confirm. Convert to numbers so the x-axis is numeric, not categorical.
MOItestsamples['TCID50'] = pd.to_numeric(
    MOItestsamples['sample_well'].str.split('-').str[3].str.split("_").str[1].str[:-1]
)
MOItestsamples.plot.scatter(x='TCID50', y='neut_standard_frac')