In [40]:
from pathlib import Path
import pandas as pd
import numpy as np
from msasim import sailfish as sf
import msastats

from spartaabc.abc_inference import load_data, load_correction_regressors, load_correction_regressor_scores, bias_correction
from spartaabc.utility import get_msa_path, PARAMS_LIST, SUMSTATS_LIST
from spartaabc.aligner_interface import Aligner
from spartaabc.getting_priors.zipf import calc_zip_mom


In [91]:
# Sanity-check configuration: everything later cells read as notebook-level settings.
MAIN_PATH = Path("sanity_check/")  # directory holding simulations, msa.tree, msa.fasta, msa.bestModel

distance_metric = "mahal"  # "mahal" (Mahalanobis) or "euclid" (inverse-std-weighted sum of squared diffs)
top_cutoff = 10000         # number of closest simulations kept by calc_distances
aligner = "mafft"          # aligner whose correction regressors are loaded; must stay a string
correction = False         # True: bias-correct simulated stats and realign the test MSA

In [92]:
def get_all_simulated_stats(main_path: Path, correct_alignment_bias=True, aligner_name=None):
    """Collect simulated parameters and summary statistics for every indel model.

    Parameters
    ----------
    main_path : Path
        Directory passed to ``load_data`` (and, when correcting, to the
        correction-regressor loaders).
    correct_alignment_bias : bool
        If True, each model's statistics go through ``bias_correction``; models
        without a regressor are skipped entirely (both parameters and stats).
        If False, the raw ``SUMSTATS_LIST`` columns are used and no regressor
        files are read.
    aligner_name : optional
        Aligner key passed to the regressor loaders. Defaults to the
        notebook-level ``aligner`` setting.

    Returns
    -------
    (params_data, full_stats_data, kept_statistics)
        Parameter and statistic frames built from the same models (so their
        rows correspond), and the indices of the summary statistics kept.

    Raises
    ------
    ValueError
        If no model contributed any rows.
    """
    if aligner_name is None:
        aligner_name = aligner  # notebook-level config string

    stats_data = load_data(main_path)

    params_frames = []
    stats_frames = []
    kept_statistics = list(range(len(SUMSTATS_LIST)))

    # Only touch the regressor files when they are actually needed.
    if correct_alignment_bias:
        regressors = load_correction_regressors(main_path, aligner_name)
        regressor_scores = load_correction_regressor_scores(main_path, aligner_name)

    for model, model_data in stats_data.items():
        if not correct_alignment_bias:
            params_frames.append(model_data[PARAMS_LIST])
            stats_frames.append(model_data[SUMSTATS_LIST])
            continue

        current_regressors = regressors.get(model)
        if current_regressors is None:
            # Skip the model's parameters too, otherwise params and stats get out of step.
            continue

        # NOTE(review): kept_statistics is overwritten per model; this assumes every
        # model keeps the same statistics — confirm against bias_correction.
        corrected_stats, kept_statistics = bias_correction(
            current_regressors, model_data, regressor_scores, 0.85
        )
        params_frames.append(model_data[PARAMS_LIST])
        stats_frames.append(corrected_stats)

    if not stats_frames:
        raise ValueError(
            f"No simulated statistics collected from {main_path} "
            f"(correct_alignment_bias={correct_alignment_bias}, aligner={aligner_name!r})"
        )

    return pd.concat(params_frames), pd.concat(stats_frames), kept_statistics


In [93]:
# Load simulated parameters / summary statistics (optionally bias-corrected).
params_df, stats_df, kept_statistics = get_all_simulated_stats(
    main_path=MAIN_PATH,
    correct_alignment_bias=correction,
)

In [95]:
# Set up the sanity-check inputs: the tree, an MSA simulated with known indel
# parameters, and the substitution-model file (msa.bestModel).
tree_newick = "(1159908.H9BR28_9NIDO:1.12098,(572290.B6VDY1_THCOV:2.1089e-06,(572287.B6VDX2_9NIDO:0.328448,(1159906.H9BR03_9NIDO:0.107594,1586324.A0A0E3Y5V9_9NIDO:0.137885):0.333936):0.12119):1.12098);"
(MAIN_PATH / "msa.tree").write_text(tree_newick)

# Generating parameters the inference should recover.
indel_rate = 0.001       # shared insertion and deletion rate
length_param = 1.5       # Zipf parameter shared by insertion and deletion lengths
zipf_truncation = 150    # Zipf truncation for indel lengths
root_seq_size = 1000
gamma_alpha = 3.0        # WAG+Gamma shape; also written into msa.bestModel below

sim_protocol = sf.SimProtocol(
    str(MAIN_PATH / "msa.tree"),
    root_seq_size=root_seq_size,
    deletion_rate=indel_rate,
    insertion_rate=indel_rate,
    insertion_dist=sf.ZipfDistribution(length_param, zipf_truncation),
    deletion_dist=sf.ZipfDistribution(length_param, zipf_truncation),
    seed=42,
)

print(calc_zip_mom(a=length_param, truncation=zipf_truncation))

sim = sf.Simulator(simProtocol=sim_protocol, simulation_type=sf.SIMULATION_TYPE.PROTEIN)
sim.set_replacement_model(sf.MODEL_CODES.WAG, gamma_parameters_alpha=gamma_alpha, gamma_parameters_categories=4)
msa = sim()

msa_path = MAIN_PATH / "msa.fasta"
if correction:
    # Strip gaps and realign with MAFFT so the MSA is an aligner's output rather
    # than the true simulated alignment. Gaps are removed from sequence lines only,
    # so FASTA headers are kept verbatim.
    # NOTE(review): assumes get_msa() returns FASTA text — confirm.
    unaligned = "".join(
        line if line.startswith(">") else line.replace("-", "")
        for line in msa.get_msa().splitlines(keepends=True)
    )
    msa_path.write_text(unaligned)
    # Distinct name on purpose: rebinding the global `aligner` (the config string)
    # would break get_all_simulated_stats on the next run.
    realigner = Aligner(aligner="MAFFT")
    realigner.set_input_file(str(msa_path))
    msa_path.write_text(realigner.get_realigned_msa())
else:
    msa.write_msa(str(msa_path))

# Keep the model file in sync with the simulator's gamma shape.
(MAIN_PATH / "msa.bestModel").write_text(f"WAG+G4m{{{gamma_alpha}}},")

9.421016317602982


In [96]:

def calc_distances(params_data, full_stats_data, kept_statistics, main_path):
    """Rank simulations by distance to the MSA's summary statistics.

    Computes the summary statistics of the MSA at ``get_msa_path(main_path)``,
    keeps the entries listed in ``kept_statistics``, and measures the distance
    from that vector to every row of ``full_stats_data`` using the
    notebook-level ``distance_metric``:

    * ``"mahal"``: Mahalanobis distance using the covariance of the simulated
      statistics (plus a 1e-4 ridge on the diagonal).
    * ``"euclid"``: sum of squared differences weighted by 1/(std + 0.001)
      (no square root; used only for ranking).

    Parameters
    ----------
    params_data : pd.DataFrame
        Generating parameters (``PARAMS_LIST`` columns), index-aligned with
        ``full_stats_data``.
    full_stats_data : pd.DataFrame
        Simulated summary statistics, one column per kept statistic. Not modified.
    kept_statistics : iterable of int
        Indices into the empirical statistics vector to compare against.
    main_path : Path
        Directory whose MSA is the comparison target.

    Returns
    -------
    pd.DataFrame
        The notebook-level ``top_cutoff`` rows with the smallest distance,
        sorted ascending: statistic columns, a ``"distances"`` column, and the
        ``PARAMS_LIST`` columns.

    Raises
    ------
    ValueError
        If ``distance_metric`` is unknown, or the number of kept empirical
        statistics differs from the number of simulated statistic columns.
    """
    msa_path = get_msa_path(main_path)
    all_empirical_stats = msastats.calculate_fasta_stats(msa_path)
    empirical_stats = np.array([all_empirical_stats[i] for i in kept_statistics], dtype=float)
    print(len(empirical_stats))

    if len(empirical_stats) != full_stats_data.shape[1]:
        raise ValueError(
            f"{len(empirical_stats)} empirical statistics but "
            f"{full_stats_data.shape[1]} simulated statistic columns"
        )

    if distance_metric == "mahal":
        sim_stats = full_stats_data.to_numpy(dtype=float)
        # atleast_2d: np.cov returns a 0-d array when there is a single statistic.
        cov = np.atleast_2d(np.cov(sim_stats, rowvar=False))
        # Small ridge keeps the matrix invertible when a statistic is (near-)constant.
        cov = cov + np.eye(len(cov)) * 1e-4
        inv_covmat = np.linalg.inv(cov)
        u_minus_v = empirical_stats - sim_stats
        distances = np.sqrt(np.sum(u_minus_v * (u_minus_v @ inv_covmat), axis=1))
    elif distance_metric == "euclid":
        weights = 1 / (full_stats_data.std(axis=0) + 0.001)
        distances = np.sum(weights * (full_stats_data - empirical_stats) ** 2, axis=1)
    else:
        raise ValueError(f"Unknown distance_metric {distance_metric!r}; expected 'mahal' or 'euclid'")

    # Work on a copy: mutating the caller's frame would leak 'distances' and the
    # parameter columns into the statistics on the next run of this cell.
    ranked = full_stats_data.copy()
    ranked["distances"] = distances
    ranked[PARAMS_LIST] = params_data

    return ranked.nsmallest(top_cutoff, "distances")

In [97]:
# Keep the simulations closest to the test MSA.
top_params = calc_distances(
    params_data=params_df,
    full_stats_data=stats_df,
    kept_statistics=kept_statistics,
    main_path=MAIN_PATH,
)

27


In [98]:
# Full (unfiltered) summary-statistic vector of the MSA on disk, for eyeballing
# against the simulated statistics in `top_params`.
MSA_PATH = get_msa_path(MAIN_PATH)
empirical_stats = msastats.calculate_fasta_stats(MSA_PATH)
empirical_stats

[6.157894736842105,
 1025.0,
 1020.0,
 975.0,
 19.0,
 5.0,
 10.0,
 1.0,
 3.0,
 8.2,
 10.0,
 1.0,
 0.0,
 1.0,
 2.0,
 2.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 947.0,
 47.0,
 26.0,
 3.0]

In [99]:
# Closest `top_cutoff` simulations (statistics, distance, generating parameters), smallest distance first.
top_params

Unnamed: 0,SS_0,SS_1,SS_2,SS_3,SS_4,SS_5,SS_6,SS_7,SS_8,SS_9,...,SS_23,SS_24,SS_25,SS_26,distances,root_length,insertion_rate,deletion_rate,length_param_insertion,length_param_deletion
317644,5.880000,1067.0,1052.0,997.0,25.0,8.0,11.0,1.0,5.0,7.461538,...,970.0,77.0,5.0,15.0,0.986407,1068.0,0.000416,0.003198,1.731750,1.398057
431004,2.650000,1002.0,998.0,974.0,20.0,10.0,9.0,0.0,1.0,4.111111,...,965.0,29.0,4.0,4.0,1.110994,1000.0,0.000208,0.001729,1.184355,2.322123
270072,6.115385,1062.0,1038.0,1014.0,26.0,13.0,7.0,1.0,5.0,7.600000,...,986.0,47.0,2.0,27.0,1.150933,1033.0,0.002569,0.001025,2.162232,1.191474
494386,5.866667,1145.0,1144.0,1097.0,15.0,5.0,5.0,1.0,4.0,7.444444,...,1078.0,48.0,18.0,1.0,1.152396,1105.0,0.000617,0.001906,1.060157,1.862203
66396,3.428571,954.0,949.0,933.0,21.0,4.0,9.0,1.0,7.0,3.777778,...,920.0,16.0,8.0,10.0,1.153018,945.0,0.001117,0.001117,1.230053,1.230053
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
227725,4.615385,1032.0,1015.0,964.0,39.0,15.0,11.0,0.0,13.0,4.687500,...,957.0,38.0,3.0,34.0,1.978495,1014.0,0.001983,0.001983,1.872772,1.872772
140283,6.052632,1011.0,983.0,930.0,57.0,22.0,16.0,1.0,18.0,5.190476,...,902.0,16.0,13.0,63.0,1.978495,934.0,0.003480,0.003480,1.859856,1.859856
340241,1.962963,933.0,927.0,918.0,27.0,17.0,3.0,1.0,6.0,2.142857,...,903.0,20.0,3.0,6.0,1.978512,922.0,0.000879,0.001991,1.559805,2.171172
343401,1.555556,1123.0,1121.0,1120.0,9.0,4.0,5.0,0.0,0.0,1.666667,...,1118.0,2.0,0.0,3.0,1.978521,1122.0,0.000442,0.001331,2.728296,2.510324


In [106]:
def get_all_model_params(models, abc_params: pd.DataFrame, top_cutoff: int = 2, truncation: int = 150):
    """Append averaged SIM and RIM parameter estimates to ``models``.

    Rows of ``abc_params`` whose insertion and deletion rates are equal are
    treated as SIM simulations; all other rows as RIM simulations. For each
    group, the ``top_cutoff`` rows with the smallest ``"distances"`` are
    averaged and appended to the matching lists in ``models``. Zipf length
    parameters are stored after conversion with
    ``calc_zip_mom(a=..., truncation=truncation)``.

    Parameters
    ----------
    models : dict
        ``{"sim": {"root_lengths", "indel_rates", "indel_length_params"},
        "rim": {"root_lengths", "insertion_rates", "deletion_rates",
        "insertion_length_params", "deletion_length_params"}}`` of lists;
        mutated in place.
    abc_params : pd.DataFrame
        Ranked simulations with ``distances``, ``root_length``,
        ``insertion_rate``, ``deletion_rate``, ``length_param_insertion`` and
        ``length_param_deletion`` columns.
    top_cutoff : int
        Number of closest simulations averaged per model (default 2).
    truncation : int
        Zipf truncation passed to ``calc_zip_mom`` (default 150, the value the
        notebook's simulation uses).

    Returns
    -------
    dict
        The same ``models`` object.

    Raises
    ------
    ValueError
        If the SIM or RIM group has no rows to average; ``models`` is then
        left unchanged.
    """
    # Presumably SIM simulations share one indel rate, so their rates compare exactly equal.
    equal_rates = abc_params["insertion_rate"] == abc_params["deletion_rate"]
    sim_top = abc_params[equal_rates].nsmallest(top_cutoff, "distances")
    rim_top = abc_params[~equal_rates].nsmallest(top_cutoff, "distances")

    # Validate both groups before mutating `models` so a failure never half-updates it
    # (an empty group would otherwise crash later on int(NaN)).
    if sim_top.empty:
        raise ValueError("No SIM rows (insertion_rate == deletion_rate) to average")
    if rim_top.empty:
        raise ValueError("No RIM rows (insertion_rate != deletion_rate) to average")

    sim = models["sim"]
    sim["root_lengths"].append(int(sim_top["root_length"].mean()))
    sim["indel_rates"].append(float(sim_top["insertion_rate"].mean()))
    sim["indel_length_params"].append(
        calc_zip_mom(a=float(sim_top["length_param_insertion"].mean()), truncation=truncation)
    )

    rim = models["rim"]
    rim["root_lengths"].append(int(rim_top["root_length"].mean()))
    rim["insertion_rates"].append(float(rim_top["insertion_rate"].mean()))
    rim["deletion_rates"].append(float(rim_top["deletion_rate"].mean()))
    rim["insertion_length_params"].append(
        calc_zip_mom(a=float(rim_top["length_param_insertion"].mean()), truncation=truncation)
    )
    rim["deletion_length_params"].append(
        calc_zip_mom(a=float(rim_top["length_param_deletion"].mean()), truncation=truncation)
    )

    return models


In [107]:
# One empty list per estimate, filled by get_all_model_params (one entry per call).
models = {
    "rim": {
        field: []
        for field in (
            "root_lengths",
            "insertion_rates",
            "deletion_rates",
            "insertion_length_params",
            "deletion_length_params",
        )
    },
    "sim": {field: [] for field in ("root_lengths", "indel_rates", "indel_length_params")},
}
get_all_model_params(models, top_params)

{'rim': {'root_lengths': [1034],
  'insertion_rates': [0.0003123228186426856],
  'deletion_rates': [0.002463769653182689],
  'insertion_length_params': [10.348165530081221],
  'deletion_length_params': [4.393464454253741]},
 'sim': {'root_lengths': [961],
  'indel_rates': [0.0017406547941462337],
  'indel_length_params': [8.261644188752326]}}