In [None]:
# RUN
# Import every library used in this notebook and set plotting / warning options.
import sys
# Make the project-local modules (mip_functions, calculate_prevalences, PCA, ...)
# in /opt/src importable.
sys.path.append("/opt/src")
import mip_functions as mip
import calculate_prevalences as cap
import probe_summary_generator
import pickle
import json
import copy
import math
import os
import numpy as np
import subprocess
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from matplotlib.lines import Line2D
# Store text as text (not converted to paths) in exported SVG figures.
plt.rcParams['svg.fonttype'] = 'none'
import pandas as pd
import seaborn as sns
import plotly.express as px
import warnings
# NOTE(review): this silences every warning for the whole notebook, including
# pandas/numpy warnings that could point at data problems.
warnings.filterwarnings('ignore')
import allel
# Working directory into which the filtered/derived tables are written.
wdir = "/opt/user/stats_and_variant_calling/"
# NOTE(review): duplicate of the plotly.express import above; harmless.
import plotly.express as px
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from urllib.request import urlopen

# Calling genotypes, prevalences and filtering data
The original VCF file created by freebayes contains the genotypes determined by the program itself. In addition, genotype values for aggregated and non-aggregated nucleotides and amino acids are also available as `*_genotypes_table.csv` files, as described above.  

However, the default parameters used to generate the VCF file are not very strict. In this part of the analysis we will apply various filters to the count tables and generate genotype calls based on those filters.

### Choose which tables to analyse
Select the type of data to analyse. Make sure the count file matches the coverage file: e.g. for `alternate_XX_table` and `coverage_XX_table`, XX must be the same value (AA, AN, or nothing).

#### Example cell
```python
mutation_count_file = "alternate_AA_table.csv"
mutation_coverage_file = "coverage_AA_table.csv"
```

In [None]:
# USER INPUT

# The count and coverage tables must share the same suffix (AA, AN or none).
mutation_count_file, mutation_coverage_file = (
    "alternate_AA_table.csv",
    "coverage_AA_table.csv",
)

In [None]:
# RUN
import pandas as pd

# Both tables use a 6-row column header and the first column as the row index.
mutation_counts = pd.read_csv(mutation_count_file,
                              header=list(range(6)),
                              index_col=0)
# Only the last expression of a cell is rendered automatically, so the counts
# preview must be displayed explicitly or it is silently discarded.
display(mutation_counts.head())

mutation_coverage = pd.read_csv(mutation_coverage_file,
                                index_col=0,
                                header=list(range(6)))
mutation_coverage.head()

### Set your filters   
1.  **min_coverage**: how many UMIs must cover a genomic position in a sample to reliably call genotypes. If we set min_coverage = 10, any locus within a sample that is covered by fewer UMIs than this threshold will have an NA genotype.
2.  **min_count**: if a genomic position has enough coverage, how many UMIs supporting the ALT allele are needed for a reliable call. If we set min_count = 2, any mutation call that has fewer than 2 UMIs supporting the ALT allele will revert to REF.
3.  **min_freq**: a minimum within-sample allele frequency (WSAF) threshold for considering a variant valid. For example, if min_freq = 0.01, a locus whose ALT allele is at 0.005 frequency within a sample is called REF; if the WSAF is between 0.01 and 0.99, it is considered HET; and if it is > 0.99, it is called homozygous ALT.

#### Example cell
```python
# filter mutation counts for minimum count parameter
# by setting counts to zero if it is below threshold
min_count = 2
# filter loci without enough coverage by setting
# coverage to zero if it is below threshold
min_coverage = 10
# call genotypes using the minimum within sample
# allele frequency
min_freq = 0
```

In [None]:
# USER INPUT 

# min_count:    ALT UMI counts below this threshold are set to zero.
# min_coverage: coverage values below this threshold are set to zero.
# min_freq:     minimum within-sample allele frequency used to call genotypes.
min_count, min_coverage, min_freq = 2, 10, 0

In [None]:
# RUN

# The project-local PCA module provides the genotype calling and
# filtering functions.
import PCA

gt_calls = PCA.call_genotypes(
    mutation_counts,
    mutation_coverage,
    min_count,
    min_coverage,
    min_freq,
)
# Names of the result tables returned by call_genotypes.
gt_calls.keys()

### What are the dataframes generated by the call_genotypes function and how are they generated?

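**filtered_mutation_counts**: take the mutation_counts table; if a cell's value is below *min_count*, reset that cell's value to zero, otherwise leave it as is. This is the table shown and saved below.

**filtered_mutation_coverage**: take the mutation_coverage table; if a cell's value is below *min_coverage*, reset that cell's value to zero, otherwise leave it as is.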

In [None]:
# RUN
# Save the ALT UMI counts after the min_count filter. The output name is
# derived from the selected count table, so AN / nucleotide selections are not
# mislabelled as "AA" (for the default table this is still
# "filtered_alternate_AA_table.csv").
filtered_mutation_counts = gt_calls["filtered_mutation_counts"]
filtered_mutation_counts.to_csv(os.path.join(
        wdir, "filtered_" + os.path.basename(mutation_count_file)))
filtered_mutation_counts.head()

**wsaf**: divide *filtered_mutation_counts* table by *filtered_mutation_coverage* table, yielding within sample allele frequencies.  

In [None]:
# RUN
# Within-sample allele frequencies: filtered counts / filtered coverage.
freq = gt_calls["wsaf"]
freq.to_csv(
    os.path.join(wdir, "within_sample_allele_frequencies.csv")
)
freq.head()

**genotypes**: take the *wsaf* table, if a cell's value is less than *min_freq* set the genotype value to 0 (homozygous wild type); if the cell's value is more than (*1 - min_freq*) set the genotype value to 2 (homozygous mutant), if the cell's value is between *min_freq* and (*1 - min_freq*) set the genotype value to 1 (heterozygous/mixed).  

In [None]:
# RUN
# Genotype calls: 0 = homozygous wild type, 1 = heterozygous/mixed,
# 2 = homozygous mutant.
genotypes = gt_calls["genotypes"]
genotypes.to_csv(
    os.path.join(wdir, "filtered_genotypes_table.csv")
)
genotypes.head()

**prevalences**: take the *genotypes* table, if a cell's value is 2, reset its value to 1; otherwise leave as is.

In [None]:
# RUN
# Prevalence input: the genotype table with every 2 reset to 1.
prevalences = gt_calls["prevalences"]
prevalences.to_csv(
    os.path.join(wdir, "prevalences_input_table.csv")
)
prevalences.head()

# Calculate Prevalence in Each Region

In [None]:
#### User Input #######

# Mutations whose prevalence is summarised per site, grouped by gene.
mutations_of_interest = [
    # crt
    "crt-Cys72Ser",
    "crt-Val73Leu",
    "crt-Met74Ile",
    "crt-Asn75Glu",
    "crt-Lys76Thr",
    "crt-Asn326Ser",
    "crt-Ile356Thr",
    # dhfr-ts
    "dhfr-ts-Asn51Ile",
    "dhfr-ts-Cys59Arg",
    "dhfr-ts-Ser108Asn",
    "dhfr-ts-Ser108Thr",
    "dhfr-ts-Ile164Leu",
    # dhps
    "dhps-Ala437Gly",
    "dhps-Ala581Gly",
    "dhps-Ala613Ser",
    "dhps-Ala613Thr",
    "dhps-Ile431Val",
    "dhps-Lys540Glu",
    "dhps-Ser436Ala",
    "dhps-Ser436Phe",
    # k13
    "k13-Ala675Val",
    "k13-Arg539Thr",
    "k13-Arg561His",
    "k13-Arg622Ile",
    "k13-Cys469Phe",
    "k13-Cys469Tyr",
    "k13-Cys580Tyr",
    "k13-Pro441Leu",
    "k13-Tyr493His",
    "k13-Val568Gly",
    # mdr1
    "mdr1-Asn1042Asp",
    "mdr1-Asn86Phe",
    "mdr1-Asn86Tyr",
    "mdr1-Asp1246Tyr",
    "mdr1-Ser1034Cys",
    "mdr1-Tyr184Phe",
]


In [None]:
#### RUN ####

# The prevalence input table was written to wdir, so read it from there
# rather than relying on the current working directory.
prevalences_input_table = os.path.join(wdir, 'prevalences_input_table.csv')
metadata_dir = '/opt/user/prevalence_metadata'
# Every metadata file except the Jupyter checkpoint folder, in sorted order.
metadata_files = sorted(
    name for name in os.listdir(metadata_dir)
    if name != '.ipynb_checkpoints')

# Write one <sample_set>_prevalence_summary.tsv per metadata file.
for metadata_file in metadata_files:
    output_summary_table = metadata_file.split('_metadata')[0] + '_prevalence_summary.tsv'
    cap.calculate_prevalences(os.path.join(metadata_dir, metadata_file),
                              prevalences_input_table,
                              mutations_of_interest,
                              output_summary_table)

# Only the header rows are needed. The third header row is the one matched
# against mutations_of_interest (presumably the mutation-name level); its first
# field (index column) is skipped. The context manager closes the file.
with open(prevalences_input_table) as prevalence_file:
    header_rows = [prevalence_file.readline() for row in range(3)]
variant_set = set(header_rows[2].strip().split(',')[1:])
mutations_of_interest_set = set(mutations_of_interest)
variants_of_interest = sorted(variant_set & mutations_of_interest_set)

# Plot Choropleth

In [None]:
### RUN ###
def load_json(region_of_interest, geojson_dir='/opt/resources/geojson_files/'):
    """Load the GeoJSON boundaries for a region.

    Args:
        region_of_interest: Region name; the file read is
            ``<geojson_dir>/<region_of_interest>.geojson``.
        geojson_dir: Directory containing the ``.geojson`` files.

    Returns:
        The parsed GeoJSON object, as returned by ``json.load``.
    """
    geojson_path = os.path.join(geojson_dir, region_of_interest + '.geojson')
    # Close the file deterministically; GeoJSON text is UTF-8 (RFC 7946), so do
    # not depend on the platform's default encoding.
    with open(geojson_path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)

def read_prevalence_table(sample_set, variants_of_interest):
    """Read a sample set's prevalence summary and make variant columns numeric.

    Args:
        sample_set: Prefix of the summary file; ``<sample_set>_prevalence_summary.tsv``
            is read (relative to the current working directory unless the
            prefix contains a directory).
        variants_of_interest: Names of the columns to convert to floats.

    Returns:
        The tab-separated summary as a DataFrame, where each variant column holds
        the float value of the first whitespace-separated token of every cell.
    """
    summary_table = sample_set + '_prevalence_summary.tsv'
    prevalence_summary_df = pd.read_csv(summary_table, sep='\t')
    for variant in variants_of_interest:
        # Cells may carry extra text after the prevalence value, so keep only
        # the first token. pandas already parses columns without such text as
        # numbers (and empty cells as NaN), so cast to str before splitting.
        prevalence_summary_df[variant] = (
            prevalence_summary_df[variant]
            .astype(str)
            .str.split()
            .str[0]
            .astype(float)
        )
    return prevalence_summary_df

def display_choropleth(region_of_interest, variant, sample_set, variants_of_interest):
    """Build a plotly choropleth of one variant's prevalence for one sample set.

    Args:
        region_of_interest: 'uganda' or 'tanzania'; selects the GeoJSON file and
            the map zoom/centre.
        variant: Column of the prevalence summary used to colour the sites.
        sample_set: Prefix of the ``<sample_set>_prevalence_summary.tsv`` file;
            also used as the figure title.
        variants_of_interest: Columns converted to floats when reading the summary.

    Returns:
        The plotly figure (not yet shown). Sites are matched to GeoJSON features
        via ``properties.name`` and coloured on a fixed 0-1 scale.
    """
    map_settings = {
        'uganda': {'zoom': 5.8, 'coordinates': {"lat": 1.3733, "lon": 32.2903}},
        'tanzania': {'zoom': 4, 'coordinates': {"lat": -6.3690, "lon": 34.8888}},
    }
    geojson = load_json(region_of_interest)
    summary_df = read_prevalence_table(sample_set, variants_of_interest)
    view = map_settings[region_of_interest]
    figure = px.choropleth_mapbox(
        summary_df,
        geojson=geojson,
        locations='Sites',
        color=variant,
        color_continuous_scale="reds",
        mapbox_style="open-street-map",
        featureidkey="properties.name",
        zoom=view['zoom'],
        center=view['coordinates'],
        labels={'Sites': 'Site'},
        range_color=(0, 1),
    )
    figure.update_layout(margin={"r": 0, "t": 40, "l": 0, "b": 0},
                         height=500, title=sample_set)
    return figure

In [None]:
########### RUN #################
# Region selector; its options correspond to the regions configured in
# display_choropleth.
region = widgets.Dropdown(options=['uganda', 'tanzania'], value='uganda',
                          description='region:', disabled=False)
display(region)

# Variant selector: the variants of interest found in the prevalence table.
variant = widgets.Dropdown(options=variants_of_interest,
                           description='variant:', disabled=False)
display(variant)

In [None]:
##### re-run after making a new selection with the dropdown menus above ####
sample_sets = [metadata_name.split('_metadata')[0] for metadata_name in metadata_files]
if variant.value is None:
    # The variant dropdown has no selection when variants_of_interest is
    # empty, so there is no prevalence column to colour the maps by.
    print("No variants of interest were found in the prevalence table; "
          "there is nothing to plot.")
else:
    # One map per sample set, for the currently selected region and variant.
    for sample_set in sample_sets:
        fig = display_choropleth(region.value, variant.value, sample_set, variants_of_interest)
        fig.show()

## Filter genotypes / prevalences
It is generally a good idea to do some basic noise removal once the genotypes are created. Some suggestions are provided here.

### Filter variants that are always at low WSAF
If a variant is only seen at a low frequency within samples, it is a good indication that it could be just noise. Here we will set a number of samples and minimum WSAF threshold to remove such noise.

```python
num_samples_wsaf = 2
min_wsaf = 0.5
wsaf_filter = ((freq > min_wsaf).sum()) >= num_samples_wsaf
```

The above options will keep the variants that are at > 0.5 WSAF in at least 2 samples.

In [None]:
# USER INPUT
# Keep variants seen at > min_wsaf WSAF in at least num_samples_wsaf samples.
num_samples_wsaf, min_wsaf = 2, 0.5

In [None]:
# Per variant: does the number of samples with WSAF > min_wsaf reach
# num_samples_wsaf?
wsaf_filter = (freq > min_wsaf).sum().ge(num_samples_wsaf)
print("{} of {} variants will remain after the wsaf filter".format(
    wsaf_filter.sum(), freq.shape[1]))

### Filter variants that are observed with low UMI counts
If a variant is only supported by a low number of UMIs across the entire sample set, it is another indication of noise.

```python
num_samples_umi = 2
min_umi = 3
umi_filter = ((filtered_mutation_counts >= min_umi).sum()) >= num_samples_umi
```

The above options will keep the variants that are supported by at least 3 UMIs in at least 2 samples.

In [None]:
# USER INPUT
# Keep variants supported by >= min_umi UMIs in at least num_samples_umi samples.
num_samples_umi, min_umi = 2, 3

In [None]:
# RUN
# Keep variants supported by at least min_umi UMIs in *at least*
# num_samples_umi samples. The threshold is inclusive (>=), matching the
# description above and the WSAF filter; `>` silently required one extra sample.
umi_filter = ((filtered_mutation_counts >= min_umi).sum()) >= num_samples_umi
print(("{} of {} variants will remain after the UMI filter").format(
    umi_filter.sum(), freq.shape[1]))

### Keep variants that were targeted
In most projects there are a number of variants that we would like to report, even if they are not seen in the sample set. We would like to stop those variants from being removed by the above filters.

In [None]:
# RUN
# Boolean array: True for variant columns whose "Targeted" level is "Yes".
targ = freq.columns.get_level_values("Targeted").isin(["Yes"])

### Combine filters
Keep the variants that are either targeted or pass both the WSAF and UMI filters.

In [None]:
# RUN
# A variant is kept if it was targeted, or if it passes both noise filters.
variant_mask = targ | (wsaf_filter & umi_filter)
# The removed count is everything not kept. (Previously the number of variants
# *passing* both filters was reported as "removed".)
print(("{} variants will remain in the final call set.\n"
       "{} variants were targeted and will be kept; and {} will be removed by "
       "the combined UMI and WSAF filters.").format(
    variant_mask.sum(), targ.sum(), len(variant_mask) - variant_mask.sum()))

## Filter data tables with the combined filters

In [None]:
# Keep only the variant columns selected by the combined filters, then save.
filtered_genotypes = genotypes.loc[:, variant_mask]
filtered_genotypes.to_csv(
    os.path.join(wdir, "final_filtered_genotypes.csv")
)
filtered_genotypes.head()

In [None]:
# Apply the same column selection to the prevalence table, then save.
filtered_prevalences = prevalences.loc[:, variant_mask]
filtered_prevalences.to_csv(
    os.path.join(wdir, "final_filtered_prevalences_input_table.csv")
)
filtered_prevalences.head()