# Try to implement bayes based variant calling

- For Parent first
- Check for more mutatios later

In [1]:
# Import
import sys
sys.path.append("/home/emre/github_repo/MinION")
from minION.util import IO_processor
from minION import analyser
from minION import consensus

import importlib
importlib.reload(analyser)
importlib.reload(consensus)
importlib.reload(IO_processor)
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import glob
import numpy as np
from Bio import SeqIO
import matplotlib.pyplot as plt
import gzip
import math
import re
import pickle
import itertools
import pysam
import subprocess

In [13]:
def get_bases_from_pileup(bam_file, chrom, positions):
    bases_dict = {position: {} for position in positions}
    qualities_dict = {position: {} for position in positions}
    
    with pysam.AlignmentFile(bam_file, 'rb') as bam:
        for pileup_column in bam.pileup(chrom, min(positions) - 1, max(positions) + 1,
                                        min_base_quality=0, 
                                        min_mapping_quality=0, 
                                        truncate=True):
            pos = pileup_column.pos + 1
            if pos in positions:
                for pileup_read in pileup_column.pileups:
                    read_name = pileup_read.alignment.query_name

                    # Handle deletions
                    if pileup_read.is_del:
                        base = '-'  # or any symbol you prefer to represent a deletion
                        quality = 0  # Assign a default quality for deletions
                    elif not pileup_read.is_refskip:
                        base = pileup_read.alignment.query_sequence[pileup_read.query_position]
                        quality = pileup_read.alignment.query_qualities[pileup_read.query_position]
                    else:
                        continue

                    # Add base and quality to the dictionaries
                    if read_name not in bases_dict[pos]:
                        bases_dict[pos][read_name] = base
                        qualities_dict[pos][read_name] = quality

    # Get unique read names and sort them
    read_names = sorted(set().union(*[bases_dict[pos].keys() for pos in bases_dict]))

    # Create DataFrames
    df_bases = pd.DataFrame(index=read_names, columns=positions)
    df_qualities = pd.DataFrame(index=read_names, columns=positions)
    
    # Populate DataFrames
    for pos in positions:
        for read_name in bases_dict[pos]:
            df_bases.at[read_name, pos] = bases_dict[pos][read_name]
            df_qualities.at[read_name, pos] = qualities_dict[pos][read_name]
    
    # Fill NaN with "-" for bases and 0 for qualities
    df_bases = df_bases.fillna("-")
    df_qualities = df_qualities.fillna(10) # 10 is the lowest quality filter we used for filtering

    return df_bases, df_qualities


def get_soft_pop_frequency(bam_file, template, reference, nb_positions, min_depth = 5):

    # Min depth based on the alphabet size
    
    # Check also for variant by sampling random positions

    bases_df = get_bases_from_pileup(bam_file, reference, nb_positions)
    
    frequency_df = bases_df.apply(get_variant_name, axis=1, args=(template, nb_positions)).value_counts().reset_index()

    frequency_df.columns = ['Population', 'N_reads']
    
    frequency_df["Frequency"] = frequency_df["N_reads"] / frequency_df["N_reads"].sum()


    # Filter for frequency > 0.4 and depth > 15
    frequency_df = frequency_df[(frequency_df["Frequency"] > min_freq) & (frequency_df["N_reads"] > min_depth)]

    return frequency_df


def add_neighbouring_positions(positions, nb_neighbours, max_index):
    if positions is None:
        return None
    
    elif isinstance(positions, int):
        positions = [positions]

    new_positions = []
    for position in positions:
        for new_pos in range(position - nb_neighbours, position + nb_neighbours + 1):
            if 1 <= new_pos <= max_index:  # Check if the new position is within valid index range
                new_positions.append(new_pos)
    return sorted(set(new_positions))

def calculate_mean_quality_for_reads(bases_df, qual_df, nb_positions, nb_neighbours):
    if isinstance(nb_positions, int):
        nb_positions = [nb_positions]  # Convert single integer to a list

    read_mean_qualities = {}

    for nb_position in nb_positions:
        if nb_position not in bases_df.columns or nb_position not in qual_df.columns:
            continue  # Skip positions that are not present in either DataFrame

        neighbor_positions = range(nb_position - nb_neighbours, nb_position + nb_neighbours + 1)

        for read_name in qual_df.index:
            total_qual = 0
            valid_count = 0

            for position in neighbor_positions:
                if position not in bases_df.columns:
                    continue  # Skip positions that are outside the DataFrame's columns

                base = bases_df.at[read_name, position]
                quality = qual_df.at[read_name, position]

                if base != "-" and not pd.isna(quality):
                    total_qual += quality
                    valid_count += 1

            if valid_count == 0:
                continue  # Skip if no valid qualities were found
            else:
                if read_name not in read_mean_qualities:
                    read_mean_qualities[read_name] = {}
                read_mean_qualities[read_name][nb_position] = total_qual / valid_count

    # Convert the dictionary into a DataFrame
    mean_quality_df = pd.DataFrame.from_dict(read_mean_qualities, orient='index')
    updated_base_df = bases_df[nb_positions] 

    return updated_base_df, mean_quality_df

def add_neighbouring_positions(positions, nb_neighbours, max_index):
    if positions is None:
        return None
    
    elif isinstance(positions, int):
        positions = [positions]

    new_positions = []
    for position in positions:
        for new_pos in range(position - nb_neighbours, position + nb_neighbours + 1):
            if 1 <= new_pos <= max_index:  # Check if the new position is within valid index range
                new_positions.append(new_pos)
    return sorted(set(new_positions))

def calculate_mean_quality_for_reads(bases_df, qual_df, nb_positions, nb_neighbours):
    if isinstance(nb_positions, int):
        nb_positions = [nb_positions]  # Convert single integer to a list

    read_mean_qualities = {}

    for nb_position in nb_positions:
        if nb_position not in bases_df.columns or nb_position not in qual_df.columns:
            continue  # Skip positions that are not present in either DataFrame

        neighbor_positions = range(nb_position - nb_neighbours, nb_position + nb_neighbours + 1)

        for read_name in qual_df.index:
            total_qual = 0
            valid_count = 0

            for position in neighbor_positions:
                if position not in bases_df.columns:
                    continue  # Skip positions that are outside the DataFrame's columns

                base = bases_df.at[read_name, position]
                quality = qual_df.at[read_name, position]

                if base != "-" and not pd.isna(quality):
                    total_qual += quality
                    valid_count += 1

            if valid_count == 0:
                continue  # Skip if no valid qualities were found
            else:
                if read_name not in read_mean_qualities:
                    read_mean_qualities[read_name] = {}
                read_mean_qualities[read_name][nb_position] = total_qual / valid_count

    # Convert the dictionary into a DataFrame
    mean_quality_df = pd.DataFrame.from_dict(read_mean_qualities, orient='index')
    updated_base_df = bases_df[nb_positions] 

    return updated_base_df, mean_quality_df

def get_non_error_prop(quality_score):
    """Convert quality score to non-error probability."""
    return 1 - 10 ** (-quality_score / 10)

def get_softmax_count_df(bases_df, qual_df, nb_positions):

    alphabet = "ACTG-"
    softmax_counts = {position: [] for position in nb_positions}
    
    for position in nb_positions:
        for base in alphabet:
            base_mask = bases_df[position] == base
            base_counts = base_mask.sum()
            # Calculate the non-error probability for each base and sum them up
            soft_count = sum(base_mask * qual_df[position].apply(get_non_error_prop))
            softmax_counts[position].append(soft_count)

    softmax_count_df = pd.DataFrame(softmax_counts, columns=nb_positions, index=list(alphabet))

    # Apply softmax to each column (position)
    softmax_count_df = softmax_count_df.apply(lambda x: x / x.sum(), axis=0)

    return softmax_count_df

def get_softmax(soft_count):
    """Calculate the softmax of a dictionary of soft counts."""
    # Calculate the sum of the non-error probabilities
    total = sum(soft_count.values())
    # Calculate the softmax for each base
    return {base: count / total for base, count in soft_count.items()}

def call_potential_populations(softmax_df, ref_seq):
    positions = softmax_df.columns
    top_combinations = []
    
    # Get the top 2 variants for each position
    for position in positions:
        top_variants = softmax_df[position].nlargest(2)

        if top_variants.iloc[1] < 0.1:
            top_combinations.append([top_variants.index[0]])
        
        else:
            top_combinations.append(top_variants.index.tolist())

        potential_combinations = list(itertools.product(*top_combinations))

    
    variants = {"Variant" : [], "Probability" : []}
    
    for combination in potential_combinations:
        final_variant = []
        for i, pos in enumerate(positions):

            if combination[i] == ref_seq[pos - 1]:
                continue

            elif combination[i] == "-":
                var = f"{ref_seq[pos - 1]}{pos}DEL"
                final_variant.append(var)
            else:
                var = f"{ref_seq[pos - 1]}{pos}{combination[i]}"
                final_variant.append(var)

        final_variant = '_'.join(final_variant)
        if final_variant == "":
            final_variant = "#PARENT#"

        joint_prob = np.prod([softmax_df.at[combination[i], positions[i]] for i in range(len(positions))])
    
        variants["Variant"].append(final_variant)
        variants["Probability"].append(joint_prob)

    return variants

def get_variant_soft(bam_file, template_seq, ref_name, padding = 50):

    variants = {"Variant" : [], "Position" : [], "Alignment Probability" : [], "Alignment Count" : []}

    alignment_count = int(subprocess.run(f"samtools view -c {bam_file}", shell=True, capture_output=True).stdout.decode("utf-8").strip())

    if alignment_count < 5:
        print("Not enough alignments")
        return None


    template = analyser.get_template_sequence(template_seq)

    padding_start, padding_end = padding, padding
    range_positions = range(padding_start + 1, len(template) - padding_end + 1) 

    freq_dist = pd.DataFrame(analyser.get_highest_non_ref_base_freq_2(bam_file, ref_name, range_positions, template, qualities=False)[0]).T.rename(columns={0:"Base", 1:"Frequency"})

    nb_positions = analyser.get_nb_positions(freq_dist, 0.3)

    available_positions = [pos for pos in range_positions if pos not in nb_positions]

    if len(nb_positions) == 0:
        # Select random 3 positions
        nb_positions = np.random.choice(available_positions, 3, replace=False)

    elif len(nb_positions) == 1:
        add_pos  = np.random.choice(available_positions, 2, replace=False)
        nb_positions = np.append(nb_positions, add_pos)

    elif len(nb_positions) > 15:
        print("Too many positions, either contaminated or sequencing error")
        #nb_positions = np.random.choice(range_positions, 3, replace=False)

    bases_df, qual_df = get_bases_from_pileup(bam_file, ref_name, add_neighbouring_positions(nb_positions, 2, len(template)))
    bases_df, qual_df = calculate_mean_quality_for_reads(bases_df, qual_df, nb_positions, 2)
    softmax_df = get_softmax_count_df(bases_df, qual_df, nb_positions)
    print(softmax_df)
    variant_df = pd.DataFrame(call_potential_populations(softmax_df, template)).sort_values(by="Probability", ascending=False)

    # Take top variant
    variants["Variant"] = variant_df["Variant"].iloc[0]
    variants["Position"] = nb_positions
    variants["Alignment Probability"] = variant_df["Probability"].iloc[0]
    variants["Alignment Count"] = alignment_count

    return variants


def get_variant_df_soft(demultiplex_folder: Path, ref_seq : Path, ref_name : str, barcode_dicts : dict = None, merge = True, min_depth= 5, padding=50):


    if barcode_dicts is None:
        barcode_dicts = get_barcode_dict(demultiplex_folder)
    
    variant_template_df = analyser.template_df(barcode_dicts, rowwise=False)

    variants = {"RBC": [], "FBC": [], "Position": [], "Variant": [], "Alignment Probability": [], "Alignment Count": []}

    template = analyser.get_template_sequence(ref_seq) # Reference sequence

    summary = analyser.read_summary_file(demultiplex_folder)
    n_counts = summary.groupby(["RBC","FBC"])["FBC"].value_counts().reset_index() 



    for barcode_id, barcode_dict in barcode_dicts.items():

        rbc = os.path.basename(barcode_id)

        for front_barcode in barcode_dict:

            fbc = os.path.basename(front_barcode)
            print("Processing", rbc, fbc)
            count = n_counts[(n_counts["RBC"] == rbc) & (n_counts["FBC"] == fbc)]["count"].values[0]

            # # If alignment file exist continue
            if not os.path.exists(os.path.join(front_barcode, "alignment_minimap.bam")):
                print(f"Alignment file in {front_barcode} does not exist, running alignment and indexing")
                analyser.run_alignment_and_indexing(ref_seq, front_barcode)
            
            else: 
                print("Alignment file already exists, skipping alignment and indexing")


            bam_file = front_barcode / "alignment_minimap.bam"


            if not bam_file.exists() or count < min_depth:
                print(f"{bam_file} does not exist.")
                variants["RBC"].append(rbc)
                variants["FBC"].append(fbc)
                variants["Position"].append(["NA"])
                variants["Variant"].append(["NA"])
                variants["Alignment Count"].append(["NA"])
                variants["Alignment Probability"].append(["NA"])
                print(f"Skipping Variant: {fbc}/{rbc}")
                continue

            # try:
            if padding == 0:
                print("Padding is 0. Implementing soft alignment")
            
            else: 
                nn_variants = get_variant_soft(bam_file, ref_seq, ref_name, padding = padding)
                print(nn_variants)

            if nn_variants["Variant"] is None:
                print("Empty variant list")
                variants["RBC"].append(rbc)
                variants["FBC"].append(fbc)
                variants["Position"].append(["NA"])
                variants["Variant"].append(["NA"])
                variants["Alignment Count"].append(["NA"])
                variants["Alignment Probability"].append(["NA"])
                print(f"Skipping Variant: {fbc}/{rbc}")
                continue

            
            variants["RBC"].append(rbc)
            variants["FBC"].append(fbc)
            variants["Position"].append(nn_variants["Position"])
            variants["Variant"].append(nn_variants["Variant"])
            #variants["Reads"].append(count)
            # Check if Alignment count is a number
            if isinstance(nn_variants["Alignment Count"], int) & (isinstance(nn_variants["Alignment Probability"], float) or nn_variants["Alignment Probability"] == "-"):
                variants["Alignment Count"].append(nn_variants["Alignment Count"])
                variants["Alignment Probability"].append(nn_variants["Alignment Probability"])


            else:
                print(f"Skipping {rbc}/{fbc} due incomplete data")
                variants["Alignment Count"].append("NA")
                variants["Alignment Probability"].append("NA")

            print(f"Variant: {fbc}/{rbc} {nn_variants['Alignment Count']} {nn_variants['Alignment Probability']}")
        
        # except Exception as e:
        #     # Append 'NA' in case of an exception
        #     print(f"Error processing {rbc}/{fbc}: {e}")
        #     variants["RBC"].append(rbc)
        #     variants["FBC"].append(fbc)
        #     variants["Position"].append("NA")
        #     variants["Variant"].append("NA")
        #     variants["Alignment Count"].append("NA")
        #     variants["Alignment Frequency"].append("NA")



    if merge:
        variant_df = analyser.rename_barcode(pd.DataFrame(variants).merge(n_counts, on=["RBC","FBC"] , how="left"))
        variant_df["Variant"] = variant_df["Variant"].apply(analyser.format_variant_list)
        variant_df["Variant"] = variant_df["Variant"].apply(lambda x: analyser.adjust_variant(x, padding))

        return variant_df.merge(variant_template_df, on=["Plate", "Well"], how="right")
    else:
        return variants

In [5]:
# Get alignment which were difficult to align 
variant_df = pd.read_pickle('/home/emre/github_repo/MinION/results/2_hetcpiii_minion_errorprone/local/variants_SW_40k.pkl')
variant_df_guppy = pd.read_pickle('/home/emre/github_repo/MinION/results/2_hetcpiii_minion_errorprone/local/variants_SW_BF_40k.pkl')
variant_df[variant_df["Variant"] == "NA"]

Unnamed: 0,Plate,Well,Position,Variant,Alignment Count,Alignment Frequency,count
4,1,A5,[NA],,[NA],[NA],135.0
7,1,A8,-,,2,-,2.0
8,1,A9,-,,2,-,2.0
20,1,B9,-,,3,-,3.0
23,1,B12,-,,2,-,2.0
33,1,C10,-,,4,-,4.0
35,1,C12,-,,7,-,7.0
58,1,E11,-,,1,-,1.0
70,1,F11,-,,2,-,2.0
71,1,F12,-,,8,-,8.0


In [10]:
bam_file = Path("/home/emre/minION_results/MinION_RBC_0902723_sup/Demultiplex_cpp_70_40k_reads/RB02/NB04/alignment_minimap.bam")
template_seq = Path("/home/emre/github_repo/MinION/minION/refseq/hetcpiii_padded.fasta")
ref_name = "HetCPIII"



get_variant_soft(bam_file, template_seq, ref_name, padding = 50)




        80   146  647
A  0.914990  0.0  0.0
C  0.000000  0.0  1.0
T  0.000000  0.0  0.0
G  0.029121  1.0  0.0
-  0.055889  0.0  0.0


{'Variant': 'T80A',
 'Position': array([ 80, 146, 647]),
 'Alignment Probability': 0.9149902666562933,
 'Alignment Count': 34}

In [14]:
demultiplex_folder = Path("/home/emre/minION_results/MinION_RBC_0902723_sup/Demultiplex_cpp_70_40k_reads")
template_seq = Path("/home/emre/github_repo/MinION/minION/refseq/hetcpiii_padded.fasta")
ref_name = "HetCPIII"
barcode_dicts = analyser.get_barcode_dict(demultiplex_folder, "NB", "RB")


variant_df_soft = get_variant_df_soft(demultiplex_folder, template_seq, ref_name, barcode_dicts, merge = True, min_depth= 5, padding=50)

Processing RB03 NB87
Alignment file already exists, skipping alignment and indexing
       84        164       178       211
A  0.00000  0.412497  0.627552  0.413806
C  0.00000  0.000000  0.000000  0.017949
T  0.00000  0.000000  0.000000  0.000000
G  0.58565  0.550829  0.372448  0.546233
-  0.41435  0.036674  0.000000  0.022012
{'Variant': 'A164G_A211G', 'Position': [84, 164, 178, 211], 'Alignment Probability': 0.11058156794895055, 'Alignment Count': 46}
Variant: NB87/RB03 46 0.11058156794895055
Processing RB03 NB03
Alignment file already exists, skipping alignment and indexing
        87        116
A  0.000000  0.010879
C  0.016721  0.937923
T  0.916892  0.040256
G  0.040408  0.000000
-  0.025979  0.010943
{'Variant': 'C87T_T116C', 'Position': [87, 116], 'Alignment Probability': 0.8599735671700774, 'Alignment Count': 94}
Variant: NB03/RB03 94 0.8599735671700774
Processing RB03 NB20
Alignment file already exists, skipping alignment and indexing
        191  520       630
A  0.000000  N

In [17]:
variant_df_soft.to_pickle("/home/emre/github_repo/MinION/results/2_hetcpiii_minion_errorprone/local/variants_SW_soft_40k.pkl")

In [16]:
variant_df_soft

Unnamed: 0,Plate,Well,Position,Variant,Alignment Probability,Alignment Count,count
0,1,A1,,,,,
1,1,A2,"[193, 293, 271]",#PARENT#,0.916895,115,115.0
2,1,A3,"[174, 521, 197]",#PARENT#,0.509708,172,172.0
3,1,A4,"[130, 566, 401]",#PARENT#,0.591522,87,87.0
4,1,A5,"[157, 176, 206, 642, 643, 644]",A107G_C592A_C594A,,135,135.0
...,...,...,...,...,...,...,...
283,3,H8,"[225, 56, 126]",T175C,0.920401,33,33.0
284,3,H9,"[633, 296, 268]",#PARENT#,0.935064,86,86.0
285,3,H10,"[77, 185, 200]",T27A_T135A_T150C,0.800266,102,102.0
286,3,H11,"[157, 233, 259]",A107G_T183C_T209C,0.882437,69,69.0
