In [21]:
# run this with conda env: conda/env-plot
import os
import shlex
import subprocess
from pathlib import Path

import pandas as pd
import plotly.express as px
import vcf
from freyja_plot import FreyjaPlotter

# All paths are anchored on the (symlink-resolved) directory the notebook runs from.
heatmap_dir = Path.cwd().resolve()
# The benchmark root is two directory levels above this notebook; ONT outputs live beneath it.
benchmark_dir = heatmap_dir.parents[1]
ont_dir = benchmark_dir.joinpath("ont")
# Generated tables and figures are written next to the notebook.
outdir = heatmap_dir.joinpath("decon_heatmap_out")
outdir.mkdir(exist_ok=True)


In [22]:
# Sample names Mixture01 ... Mixture42 (numbers zero-padded to two digits).
mixtures = ["Mixture" + str(i).zfill(2) for i in range(1, 43)]
# Batch label -> plate identifier that appears in the ONT output folder names.
batches = {
    "Control": "05-05-23-A41",
    "Neg Spike-in": "05-16-23-A41",
    "Pos Spike-in": "06-26-23-A41",
}

## Compile per-sample read counts, percent coverage, mean depth, and variant-call quality

In [23]:
def count_reads(bam):
    """Return column 3 of the first ``samtools idxstats`` line for ``bam``.

    Per the samtools documentation, that column is the number of mapped read
    segments for the first reference sequence. The value is returned as the
    string samtools printed.

    NOTE(review): idxstats expects a BAM index next to the file — confirm.

    Raises:
        ValueError: if samtools produced no parsable line. The trailing
            ``exit 0`` hides samtools' own exit status, so this is where a
            failure (missing module, unreadable BAM) surfaces.
    """
    # Quote the path so filenames with spaces or shell metacharacters reach samtools intact.
    cmd = f"module load samtools; samtools idxstats {shlex.quote(str(bam))}; exit 0"
    stats = subprocess.check_output(cmd, shell=True).decode()
    fields = stats.split("\n")[0].split("\t")
    if len(fields) < 3:
        raise ValueError(f"could not parse `samtools idxstats` output for {bam!r}: {stats!r}")
    return fields[2]

In [24]:
def variant_call_quality(vcf_file):
    """Return the mean QUAL score over the records of a VCF file.

    Records whose QUAL is ``None`` (how PyVCF represents a missing ``.``
    value) are skipped. Missing scores therefore no longer crash ``sum()``.
    If no record carries a QUAL value, ``float("nan")`` is returned instead
    of raising ZeroDivisionError, so the results table gets an empty cell.

    Args:
        vcf_file: path of the VCF to read with ``vcf.Reader``.
    """
    reader = vcf.Reader(filename=vcf_file)
    quals = [record.QUAL for record in reader if record.QUAL is not None]
    if not quals:
        return float("nan")
    return sum(quals) / len(quals)

In [25]:
def coverage_details(bam):
    """Summarise ``samtools coverage`` output for the first reference in ``bam``.

    Example ``samtools coverage`` output::

        #rname startpos endpos numreads covbases coverage meandepth meanbaseq meanmapq
        MN908947.3 1 29903 75218 29241 97.7862 874.309 24.2 60

    (columns are tab-separated). Columns are looked up by header name rather
    than by position.

    Returns:
        tuple[str, str, str]: ``(coverage, numreads, meandepth)`` from the first
        data row, as the strings samtools printed.

    Raises:
        ValueError: if the output lacks a header plus a data row. The trailing
            ``exit 0`` masks samtools' exit status, so this is where a failure
            (missing module, unreadable BAM) surfaces.
    """
    # Quote the path so filenames with spaces or shell metacharacters reach samtools intact.
    cmd = f"module load samtools; samtools coverage {shlex.quote(str(bam))}; exit 0"
    output = subprocess.check_output(cmd, shell=True).decode()
    lines = [line for line in output.split("\n") if line]
    if len(lines) < 2:
        raise ValueError(f"could not parse `samtools coverage` output for {bam!r}: {output!r}")
    header = lines[0].lstrip("#").split("\t")
    row = dict(zip(header, lines[1].split("\t")))
    return row["coverage"], row["numreads"], row["meandepth"]

In [26]:
def extract_data():
    """Collect per-sample alignment statistics for every batch x mixture pair.

    For each batch in ``batches`` (label -> plate id) and each name in
    ``mixtures``, this locates the sample's BAM under
    ``ont_dir/MixedControl-<plate>-fastqs/output/alignments/``. It then records
    the percent coverage, read count and mean depth reported by
    ``coverage_details``. The ``quality`` column is left blank because
    variant-call quality extraction is currently disabled.

    The table is written to ``outdir/quality_stats.csv`` and also returned.

    Returns:
        pandas.DataFrame with columns batch, mixture, read_counts, coverage,
        mean_depth, quality (one row per batch x mixture).

    Raises:
        FileNotFoundError: if no BAM matching ``<mixture>*.bam`` exists for a sample.
    """
    columns = ["batch", "mixture", "read_counts", "coverage", "mean_depth", "quality"]
    records = []
    for batch, plate in batches.items():
        alignment_dir = ont_dir / f"MixedControl-{plate}-fastqs" / "output" / "alignments"
        for mix in mixtures:
            # sorted() makes the pick deterministic if several files match the pattern.
            candidates = sorted(alignment_dir.glob(f"{mix}*.bam"))
            if not candidates:
                raise FileNotFoundError(f"no BAM matching {mix}*.bam in {alignment_dir}")
            bam = candidates[0]

            coverage, counts, mean_depth = coverage_details(bam)
            records.append({
                "batch": batch,
                "mixture": mix,
                "read_counts": counts,
                "coverage": coverage,
                "mean_depth": mean_depth,
                # Variant-call quality is not extracted yet (variant_call_quality is unused).
                "quality": "",
            })

    df = pd.DataFrame(records, columns=columns)
    df.to_csv(outdir / "quality_stats.csv", index=False)
    return df

In [29]:
### Acquire data: per-sample stats are cached in quality_stats.csv, so running
### samtools over every BAM only has to happen the first time.
stats_csv = outdir / "quality_stats.csv"
if not stats_csv.exists():
    extract_data()
# Always read back from the CSV so column dtypes are identical on a fresh
# extraction and on a cached run (extract_data() returns samtools' raw strings).
df = pd.read_csv(stats_csv)

In [44]:
# finalize df: swap the generic MixtureNN ids for the short (lower-cased) mixture labels
mixture_renames = {'Mixture01': '0ADGIO1O2O3O4O5', 'Mixture02': '0ADGIO1', 'Mixture03': 'O2O3O4O5','Mixture04': '0AGIO1O2', 'Mixture05': '0O5O3O4', 'Mixture06': 'ADGIO1O2O3','Mixture07': 'AGIO3O4O5', 'Mixture08': 'O1O2O3O4O5', 'Mixture09': '0','Mixture10': 'O1O2', 'Mixture11': 'O3', 'Mixture12': 'O5','Mixture13': 'O4', 'Mixture14': '0-2', 'Mixture15': 'A', 'Mixture16': 'G', 'Mixture17': 'I', 'Mixture18': 'D', 'Mixture19': 'O1', 'Mixture20': 'O2', 'Mixture21': '0-3','Mixture22': 'O3-2', 'Mixture23': 'O3-3', 'Mixture24': 'O5-2', 'Mixture25': 'O5-3', 'Mixture26': 'O4-2', 'Mixture27': 'O4-3', 'Mixture28': 'O2-2', 'Mixture29': 'O2O3O4O5-2', 'Mixture30': 'O2O3O4O5-3', 'Mixture31': '0ADGIO1-2', 'Mixture32': '0AIO1O2O3O4O5', 'Mixture33': '0-4', 'Mixture34': 'A-2','Mixture35': 'G-2', 'Mixture36': 'I-2', 'Mixture37': 'D-2', 'Mixture38': 'O1-2', 'Mixture39': 'O2-3', 'Mixture40': 'O3-4','Mixture41': 'O5-4', 'Mixture42': 'O4-4'}
mixture_renames = {m: n.lower() for m, n in mixture_renames.items()}
# .replace leaves values that are already renamed untouched, so re-running this cell
# is harmless. Mapping with dict.get would turn every label into None on a second run.
df["mixture"] = df["mixture"].replace(mixture_renames)
assert df["mixture"].isin(list(mixture_renames.values())).all(), "unmapped mixture names in df"
df["mean_depth"] = df["mean_depth"].astype(float)

In [31]:
def getHeatmap(df, field, title=None, labels=None, title_y=0.7):
    """Build a plotly heatmap of one per-sample metric, batches x mixtures.

    Args:
        df: table with ``batch``, ``mixture`` and ``field`` columns. Each
            (batch, mixture) pair must appear at most once, otherwise
            ``DataFrame.pivot`` raises.
        field: column whose values colour the cells.
        title: figure title.
        labels: label overrides forwarded to ``px.imshow``
            (e.g. ``{"y": "Batch", "x": "Mixture"}``).
        title_y: vertical position of the title, applied via ``update_layout``.

    Returns:
        The plotly figure (rows = batches, columns = mixtures).
    """
    fig_df = df.pivot(index="batch", columns="mixture", values=field)
    fig = px.imshow(fig_df, title=title, labels=labels)
    # Use the caller's title_y rather than a hard-coded position.
    fig.update_layout(title_y=title_y)
    return fig

In [37]:
# Percent of the reference covered, per batch (rows) and mixture (columns).
fig = getHeatmap(
    df,
    "coverage",
    title="Percent coverage heatmap",
    labels={"y": "Batch", "x": "Mixture"},
)
fig

In [38]:
# getHeatmap already returns a finished figure (title, labels, title_y applied).
# Wrapping it in px.imshow again would hand a Figure to a function that expects
# array-like image data.
fig2 = getHeatmap(df, "read_counts", title="Read counts heatmap", labels={"y": "Batch", "x": "Mixture"})
fig2

In [39]:
# fig3 = px.imshow(getHeatmap(df,"quality"), title="Average variant call quality heatmap", labels={"y":"Batch","x":"Mixture"})
# fig3.update_layout(title_y=0.7)

In [45]:
# batch -> average of the per-sample mean depths in that batch.
# sort=False keeps batches in order of first appearance, like df["batch"].unique().
mean_mean_depth = df.groupby("batch", sort=False)["mean_depth"].mean().to_dict()
mean_mean_depth

{'Control': 10845.81880952381,
 'Neg Spike-in': 5878.616666666667,
 'Pos Spike-in': 11502.592857142856}

In [46]:
# Each sample's mean depth divided by its batch average (1.0 == batch average); vectorized instead of a row-wise apply.
df["normalized_mean_depth"] = df["mean_depth"] / df["batch"].map(mean_mean_depth)

In [40]:
# getHeatmap returns a finished figure; it must not be passed through px.imshow
# again, which expects array-like image data rather than a Figure.
fig4 = getHeatmap(df, "mean_depth", title="Mean depth heatmap", labels={"y": "Batch", "x": "Mixture"})
fig4

In [47]:
# getHeatmap returns a finished figure; it must not be passed through px.imshow
# again, which expects array-like image data rather than a Figure.
fig5 = getHeatmap(
    df,
    "normalized_mean_depth",
    title="Normalized mean depth heatmap",
    labels={"y": "Batch", "x": "Mixture"},
)
fig5

In [41]:
# data_for_plots = {
#     "coverage":fig, "read_counts":fig2, 
#     # "vc_quality":fig3
#     }
# for t,f in data_for_plots.items():
#     f.write_html(outdir/f"{t}.html")
#     f.write_image(outdir/f"{t}.png")