In [None]:
# RUN
import sys
sys.path.append("/opt/src")
import mip_functions as mip
import probe_summary_generator
import pickle
import json
import copy
import math
import os
import numpy as np
import subprocess
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from matplotlib.lines import Line2D
plt.rcParams['svg.fonttype'] = 'none'
import pandas as pd
import seaborn as sns
import plotly.express as px
import warnings
warnings.filterwarnings('ignore')
import allel
# directory that the output tables of this notebook are written to
wdir = "/opt/analysis/"
# NOTE(review): data_dir is not used in the visible cells — presumably the
# location of the input data; confirm
data_dir = "/opt/data/"

# Calling genotypes, prevalences and filtering data
The original vcf file created by freebayes contains the genotypes determined by the program itself. In addition, genotype values for aggregated and non-aggregated nucleotides and amino acids are also available as *_genotypes_table.csv files as described above.  

However, the default parameters generating the vcf file are not very strict. In this part of the analysis we will apply various filters to the count tables and generate genotype calls based on those filters.

### Choose which tables to analyse
Select the type of data to analyse. Make sure the count file matches the coverage file, e.g. for alternate_XX_table and coverage_XX_table, XX must be the same value (AA, AN or nothing).

#### Example cell
```python
mutation_count_file = "/opt/analysis/alternate_AA_table.csv"
mutation_coverage_file = "/opt/analysis/coverage_AA_table.csv"
```

In [None]:
# USER INPUT

# alternate (mutation) count table to analyse
mutation_count_file = "/opt/analysis/alternate_AA_table.csv"
# coverage table; must be of the same kind (AA, AN or plain) as the
# count table above
mutation_coverage_file = "/opt/analysis/coverage_AA_table.csv"

In [None]:
# RUN
import pandas as pd

# The count table is read with its first six rows as a multi-level column
# header and its first column as the row index.
mutation_counts = pd.read_csv(
    mutation_count_file,
    index_col=0,
    header=[0, 1, 2, 3, 4, 5],
)
mutation_counts.head()

In [None]:
# RUN
# The coverage table is read the same way as the count table: six header
# rows forming the column levels, first column used as the row index.
mutation_coverage = pd.read_csv(
    mutation_coverage_file,
    header=[0, 1, 2, 3, 4, 5],
    index_col=0,
)
mutation_coverage.head()

### Set your filters   
1.  **min_coverage**: how many UMIs are needed at a genomic position in a sample to reliably call genotypes. If we set min_coverage = 10, any locus within a sample that is covered below this threshold will have an NA genotype.
2.  **min_count**: if a genomic position has enough coverage, how many UMIs supporting an ALT allele call are needed for a reliable call. If we set min_count = 2, any locus with an ALT call that has fewer than 2 UMIs supporting the ALT call will revert to REF.
3.  **min_freq**: a minimum within-sample allele frequency (AF) threshold to consider a variant valid. If set to 0.01, for example, a variant locus whose ALT allele is at 0.005 frequency within a sample would be called REF; if the within-sample AF is between 0.01 and 0.99, it would be called HET; and if it is > 0.99, it would be called homozygous ALT.

#### Example cell
```python
# filter mutation counts for minimum count parameter
# by setting counts to zero if it is below threshold
min_count = 2
# filter loci without enough coverage by setting
# coverage to zero if it is below threshold
min_coverage = 10
# call genotypes using the minimum within sample
# allele frequency
min_freq = 0
```

In [None]:
# USER INPUT

# minimum number of UMIs supporting an ALT call; mutation counts below
# this threshold are set to zero
min_count = 2
# minimum UMI coverage of a locus within a sample; coverage below this
# threshold is set to zero
min_coverage = 10
# minimum within-sample allele frequency used when calling genotypes
min_freq = 0

In [None]:
# RUN

# PCA is the module providing the genotype calling and filtering functions
import PCA

# Call genotypes from the count and coverage tables using the three
# thresholds chosen above; the keys of the result are shown below.
gt_calls = PCA.call_genotypes(
    mutation_counts,
    mutation_coverage,
    min_count,
    min_coverage,
    min_freq,
)
gt_calls.keys()

### What are the dataframes generated by the call_genotypes function and how are they generated?

**filtered_mutation_counts**: take the mutation_counts table, if a cell's value is below *min_count*, reset that cell's value to zero, otherwise leave as is.  

In [None]:
# RUN
# Mutation counts with every cell below min_count reset to zero.
filtered_mutation_counts = gt_calls["filtered_mutation_counts"]
# Name the output after the selected input table so that analysing e.g. the
# AN table does not produce a file mislabelled as "AA". With the default
# input this is still "filtered_alternate_AA_table.csv".
filtered_mutation_counts.to_csv(os.path.join(
    wdir, "filtered_" + os.path.basename(mutation_count_file)))
filtered_mutation_counts.head()

**filtered_mutation_coverage**: take the mutation_coverage table, if a cell's value is below *min_coverage*, reset that cell's value to zero, otherwise leave as is.

In [None]:
# RUN
# Coverage with every cell below min_coverage reset to zero.
filtered_mutation_coverage = gt_calls["filtered_mutation_coverage"]
# Name the output after the selected input table so that analysing e.g. the
# AN table does not produce a file mislabelled as "AA". With the default
# input this is still "filtered_coverage_AA_table.csv".
filtered_mutation_coverage.to_csv(os.path.join(
    wdir, "filtered_" + os.path.basename(mutation_coverage_file)))
filtered_mutation_coverage.head()

**wsaf**: divide *filtered_mutation_counts* table by *filtered_mutation_coverage* table, yielding within sample allele frequencies.  

In [None]:
# RUN
# Within-sample allele frequencies: filtered counts / filtered coverage.
freq = gt_calls["wsaf"]
wsaf_csv = os.path.join(wdir, "within_sample_allele_frequencies.csv")
freq.to_csv(wsaf_csv)
freq.head()

**genotypes**: take the *wsaf* table, if a cell's value is less than *min_freq* set the genotype value to 0 (homozygous wild type); if the cell's value is more than (*1 - min_freq*) set the genotype value to 2 (homozygous mutant), if the cell's value is between *min_freq* and (*1 - min_freq*) set the genotype value to 1 (heterozygous/mixed).  

In [None]:
# RUN
# Genotype calls (0 / 1 / 2) derived from the wsaf table.
genotypes = gt_calls["genotypes"]
genotypes_csv = os.path.join(wdir, "filtered_genotypes_table.csv")
genotypes.to_csv(genotypes_csv)
genotypes.head()

**prevalences**: take the *genotypes* table, if a cell's value is 2, reset its value to 1; otherwise leave as is.

In [None]:
# RUN
# Genotypes with 2 collapsed to 1; this table is the input of the
# prevalence calculations.
prevalences = gt_calls["prevalences"]
prevalences_csv = os.path.join(wdir, "prevalences_input_table.csv")
prevalences.to_csv(prevalences_csv)
prevalences.head()

## Filter genotypes / prevalences
It is generally a good idea to do some basic noise removal once the genotypes are created. Some suggestions are provided here.

### Filter variants that are always at low WSAF
If a variant is only seen at a low frequency within samples, it is a good indication that it could be just noise. Here we will set a number of samples and minimum WSAF threshold to remove such noise.

```python
num_samples_wsaf = 2
min_wsaf = 0.5
wsaf_filter = ((freq > min_wsaf).sum()) >= num_samples_wsaf
```

The above options will keep the variants that are at > 0.5 WSAF in at least 2 samples.

In [None]:
# USER INPUT
# minimum number of samples in which a variant must exceed min_wsaf
num_samples_wsaf = 2
# within-sample allele frequency a variant must exceed in those samples
min_wsaf = 0.5

In [None]:
# For every variant (column) count the samples whose WSAF exceeds min_wsaf,
# then keep the variants reaching the required number of samples.
wsaf_filter = freq.gt(min_wsaf).sum().ge(num_samples_wsaf)
print(f"{wsaf_filter.sum()} of {freq.shape[1]} variants will remain "
      "after the wsaf filter")

### Filter variants that are observed with low UMI counts
If a variant is only supported by a low number of UMIs across the entire sample set, it is another indication of noise.

```python
num_samples_umi = 2
min_umi = 3
umi_filter = ((filtered_mutation_counts >= min_umi).sum()) >= num_samples_umi
```

The above options will keep the variants that are supported by at least 3 UMIs in at least 2 samples.

In [None]:
# USER INPUT
# sample-count threshold compared against the number of samples that
# support a variant with at least min_umi UMIs
num_samples_umi = 2
# minimum UMI count for a sample to count as supporting the variant
min_umi = 3

In [None]:
# RUN
# Keep the variants supported by at least min_umi UMIs in AT LEAST
# num_samples_umi samples. The comparison is inclusive (>=) so that it
# matches the documented behaviour and the WSAF filter; a strict ">" would
# silently require one extra sample.
umi_filter = ((filtered_mutation_counts >= min_umi).sum()) >= num_samples_umi
print(("{} of {} variants will remain after the UMI filter").format(
    umi_filter.sum(), freq.shape[1]))

### Keep variants that were targeted
In most projects there are a number of variants that we would like to report, even if they are not seen in the sample set. We would like to stop those variants from being removed by the above filters.

In [None]:
# RUN
# Boolean flag per variant column: True where the "Targeted" column level
# is "Yes".
targeted_level = freq.columns.get_level_values("Targeted")
targ = targeted_level == "Yes"

### Combine filters
Keep the variants that are either targeted or passing filters

In [None]:
# Keep a variant if it was targeted, or if it passes BOTH noise filters.
variant_mask = targ | (wsaf_filter & umi_filter)
# The summary used to report the number of variants PASSING the filters as
# the number "removed"; report passing and removed counts separately.
print(("{} variants will remain in the final call set.\n"
       "{} variants were targeted and will be kept; {} variants pass "
       "the combined UMI and WSAF filters; and {} will be removed.").format(
    variant_mask.sum(), targ.sum(), (wsaf_filter & umi_filter).sum(),
    len(variant_mask) - variant_mask.sum()))

## Filter data tables with the combined filters

In [None]:
# Keep only the variant columns selected by the combined mask and save them.
filtered_genotypes = genotypes.loc[:, variant_mask]
final_genotypes_csv = os.path.join(wdir, "final_filtered_genotypes.csv")
filtered_genotypes.to_csv(final_genotypes_csv)
filtered_genotypes.head()

In [None]:
# Apply the same column mask to the prevalence table and save the result.
filtered_prevalences = prevalences.loc[:, variant_mask]
final_prevalences_csv = os.path.join(
    wdir, "final_filtered_prevalences_input_table.csv")
filtered_prevalences.to_csv(final_prevalences_csv)
filtered_prevalences.head()

In [None]:
# Inputs of the prevalence summary: the prevalence table written earlier,
# the sample metadata, and the suffix appended to metadata sample names to
# obtain the sample IDs used in the table.
prevalences_input_table = '/opt/analysis/prevalences_input_table.csv'
metadata_file = '/opt/prevalence_metadata/PRX-00_metadata.csv'
UMI_suffix = '-PRX-00-1'
# Opened for writing here; the summary rows are written further below.
output_file = open('/opt/analysis/test_new_prevalence.csv','w')

# Mutations to report in the summary, grouped by gene.
mutations = [
    # crt
    "crt-Cys72Ser",
    "crt-Val73Leu",
    "crt-Met74Ile",
    "crt-Asn75Glu",
    "crt-Lys76Thr",
    "crt-Asn326Ser",
    "crt-Ile356Thr",
    # dhfr-ts
    "dhfr-ts-Asn51Ile",
    "dhfr-ts-Cys59Arg",
    "dhfr-ts-Ser108Asn",
    "dhfr-ts-Ser108Thr",
    "dhfr-ts-Ile164Leu",
    # dhps
    "dhps-Ala437Gly",
    "dhps-Ala581Gly",
    "dhps-Ala613Ser",
    "dhps-Ala613Thr",
    "dhps-Ile431Val",
    "dhps-Lys540Glu",
    "dhps-Ser436Ala",
    "dhps-Ser436Phe",
    # k13
    "k13-Ala675Val",
    "k13-Arg539Thr",
    "k13-Arg561His",
    "k13-Arg622Ile",
    "k13-Cys469Phe",
    "k13-Cys469Tyr",
    "k13-Cys580Tyr",
    "k13-Pro441Leu",
    "k13-Tyr493His",
    "k13-Val568Gly",
    # mdr1
    "mdr1-Asn1042Asp",
    "mdr1-Asn86Phe",
    "mdr1-Asn86Tyr",
    "mdr1-Asp1246Tyr",
    "mdr1-Ser1034Cys",
    "mdr1-Tyr184Phe",
]
def create_site_dict(metadata_file, suffix=None):
    """Map every sample ID listed in the metadata file to its site.

    The metadata file is read as comma-separated text whose first field is
    the site and whose second field is the sample name. Any line containing
    the text 'Sites' is treated as a header line and skipped.

    Args:
        metadata_file: path of the metadata CSV file.
        suffix: text appended to each sample name to build the sample ID.
            Defaults to the module-level ``UMI_suffix``.

    Returns:
        dict mapping ``sample name + suffix`` to the site of that sample.
    """
    if suffix is None:
        suffix = UMI_suffix
    site_dict = {}
    # the context manager guarantees that the file handle is released
    with open(metadata_file) as metadata:
        for line in metadata:
            if 'Sites' in line:
                continue
            fields = line.strip().split(',')
            if len(fields) < 2:
                # blank (e.g. trailing) or incomplete line: no sample on it
                continue
            site_dict[fields[1] + suffix] = fields[0]
    return site_dict

# {site:mutation_name}

def get_counts(prevalences_input_table, site_dict):
    """Tally mutant and covered samples per site for every table column.

    The table is read as comma-separated text. The line containing
    "Mutation Name" supplies the column names; lines 0-5 are header lines
    and every later line is a sample row whose first field is the sample
    ID. Rows whose sample is not in ``site_dict`` are ignored.

    Each cell contributes to its column as follows:
        ''    -> no coverage: the column is registered, nothing is added
        '0.0' -> covered, not mutant: coverage + 1
        '1.0' -> covered and mutant: coverage + 1 and mutant count + 1
    Any other value is ignored.

    Args:
        prevalences_input_table: path of the prevalence input table.
        site_dict: dict mapping sample ID to site.

    Returns:
        (count_dict, mutation_dict) where ``count_dict`` maps each site,
        plus the extra key 'overall', to ``{column index: [mutant count,
        coverage count]}`` and ``mutation_dict`` maps column index to
        mutation name.

    Raises:
        ValueError: if the table has no "Mutation Name" header line.
    """
    count_dict = {'overall': {}}
    mutation_dict = None
    with open(prevalences_input_table, 'r') as table:
        for line_number, line in enumerate(table):
            if "Mutation Name" in line:
                # TODO: report an error if two mutations share the same name
                mutation_dict = dict(enumerate(line.strip().split(',')[1:]))
            if line_number < 6:
                continue
            fields = line.strip().split(',')
            sample = fields[0]
            if sample not in site_dict:
                continue
            site_counts = count_dict.setdefault(site_dict[sample], {})
            for tally_number, tally in enumerate(fields[1:]):
                if tally not in ('', '0.0', '1.0'):
                    continue
                alt = 1 if tally == '1.0' else 0
                cov = 0 if tally == '' else 1
                for counts in (site_counts, count_dict['overall']):
                    # Start every column at [0, 0]; starting at the first
                    # observation AND incrementing counted it twice.
                    pair = counts.setdefault(tally_number, [0, 0])
                    pair[0] += alt
                    pair[1] += cov
    if mutation_dict is None:
        raise ValueError(
            'no "Mutation Name" header line found in '
            + str(prevalences_input_table))
    return count_dict, mutation_dict

site_dict = create_site_dict(metadata_file)
count_dict, mutation_dict = get_counts(prevalences_input_table, site_dict)

# Columns to report: the mutations of interest, in table column order. The
# same column list is used for the header and for every row so that the
# cells always line up with their header.
reported_columns = [column_number for column_number in mutation_dict
                    if mutation_dict[column_number] in mutations]
sites = sorted(site for site in count_dict if site != 'overall')

# Closing the file (done by the with block, even on error) flushes the
# buffered rows to disk; without it the read below could see a truncated
# table.
with output_file:
    output_file.write('Sites')
    for column_number in reported_columns:
        output_file.write('\t'+mutation_dict[column_number])
    # one row per site, followed by the summary row over all sites
    for row_label in sites + ['overall']:
        output_file.write('\n'+row_label)
        row_counts = count_dict[row_label]
        for column_number in reported_columns:
            alt, cov = row_counts.get(column_number, [0, 0])
            # avoid dividing by zero when nothing was observed
            prevalence = alt/cov if alt else 0
            output_file.write(f"\t{prevalence} ({alt}/{cov})")

prevalences_output = pd.read_csv(output_file.name, sep='\t', index_col=0)
prevalences_output.head()

In [None]:
import json
import pandas as pd
import plotly.express as px
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

########### USER INPUT #################
prevalence_summary_file_name='/opt/analysis/test_new_prevalence.csv'
region_of_interest='uganda'

# load the inputs
# region name -> [geojson file name, map zoom level, map centre]
# (the 'tanzania' entry used to be wrapped in an extra list, which broke
# the zoom and centre lookups for that region)
region_dict = {
    'uganda': ['uganda-with-regions.geojson', 5,
               {"lat": 1.3733, "lon": 32.2903}],
    'tanzania': ['tanzania-with-regions.geojson', 4,
                 {"lat": -6.3690, "lon": 34.8888}],
}
region_list=[
    ('uganda',),
    ('tanzania',)
]
prevalence_summary_df = pd.read_csv(prevalence_summary_file_name, sep='\t')

# identify columns that are variants not headers
variant_columns=prevalence_summary_df.columns.difference(['Sites'])

# every cell looks like "<prevalence> (<alt>/<cov>)"; keep only the leading
# prevalence value, converted to float
for column in variant_columns:
    prevalence_summary_df[column]=[float(x.split()[0]) for x in prevalence_summary_df[column]]
# list of every variant, offered as the choices of the map widget
variant_list = list(variant_columns)

def display_choropleth(variant):
    """Draw a choropleth map of one variant's prevalence per site.

    Uses the module-level ``prevalence_summary_df`` as data and the entry of
    ``region_dict`` selected by ``region_of_interest`` for the geojson file,
    zoom level and map centre.

    Args:
        variant: name of the ``prevalence_summary_df`` column used to colour
            the map.

    Returns:
        The plotly figure.
    """
    geojson_name, zoom, center = region_dict[region_of_interest]
    # read the boundaries with a context manager so the handle is closed;
    # JSON text is UTF-8, so do not depend on the locale encoding
    with open('/opt/prevalence_metadata/'+geojson_name,
              encoding='utf-8') as geojson_handle:
        json_file = json.load(geojson_handle)
    fig = px.choropleth_mapbox(prevalence_summary_df,
                                geojson=json_file,
                                locations='Sites',
                                color=variant,
                                color_continuous_scale="reds",
                                mapbox_style="carto-positron",
                                # sites are matched to geojson features by name
                                featureidkey="properties.name",
                                zoom=zoom,
                                center=center,
                                labels={'Sites':'Site'},
                                # fixed colour scale, so maps of different
                                # variants are directly comparable
                                range_color=(0,1),
    )
    fig.update_layout(margin={"r":0,"t":40,"l":0,"b":0})
    fig.update_layout(height=600)
    return fig

# Build an interactive widget (ipywidgets' interact) that calls
# display_choropleth with the variant picked from variant_list.
interact(display_choropleth, variant=variant_list);