In [41]:
from pathlib import Path
import pandas as pd
import numpy as np
from msasim import sailfish as sf
import msastats

from spartaabc.abc_inference import load_data, load_correction_regressors, load_correction_regressor_scores, bias_correction
from spartaabc.utility import get_msa_path, PARAMS_LIST, SUMSTATS_LIST
from spartaabc.aligner_interface import Aligner
from spartaabc.getting_priors.zipf import calc_zip_mom


In [52]:
MAIN_PATH = Path("sanity_check/")
MAIN_PATH.resolve()

distance_metric = "mahal"
top_cutoff = 10000
aligner = "mafft"
correction=True

In [53]:
def get_all_simulated_stats(main_path: Path, correct_alignment_bias=True):

    stats_data = load_data(main_path)
    regressors = load_correction_regressors(main_path, aligner)
    regressor_scores = load_correction_regressor_scores(main_path, aligner)

    params_data = []
    full_stats_data = []
    kept_statistics = []

    for model in  stats_data.keys():
        current_regressors = regressors.get(model, None)
        params_data.append(stats_data[model][PARAMS_LIST])
        if not correct_alignment_bias:
            full_stats_data.append(stats_data[model][SUMSTATS_LIST])
            kept_statistics = range(len(SUMSTATS_LIST))
            continue

        if current_regressors is not None:
            temp_df, kept_statistics = bias_correction(current_regressors, stats_data[model], regressor_scores, 0.85)
            full_stats_data.append(temp_df)

    params_data = pd.concat(params_data)
    full_stats_data = pd.concat(full_stats_data)

    return params_data, full_stats_data, kept_statistics


In [54]:
params_df, stats_df, kept_statistics = get_all_simulated_stats(MAIN_PATH, correct_alignment_bias=correction)

In [55]:
#setting up tree and simulated msa + model file
with open(MAIN_PATH / "msa.tree",'w') as f:
    f.write("(1159908.H9BR28_9NIDO:1.12098,(572290.B6VDY1_THCOV:2.1089e-06,(572287.B6VDX2_9NIDO:0.328448,(1159906.H9BR03_9NIDO:0.107594,1586324.A0A0E3Y5V9_9NIDO:0.137885):0.333936):0.12119):1.12098);")

indel_rate = 0.01
length_parama = 1.5

sim_protocol = sf.SimProtocol(str(MAIN_PATH / "msa.tree"), root_seq_size=1000,
                                deletion_rate=indel_rate, insertion_rate=indel_rate,
                                insertion_dist=sf.ZipfDistribution(length_parama, 150),
                                deletion_dist=sf.ZipfDistribution(length_parama, 150), seed=42)

print(calc_zip_mom(a=length_parama, truncation=150))

sim = sf.Simulator(simProtocol=sim_protocol, simulation_type=sf.SIMULATION_TYPE.PROTEIN)
sim.set_replacement_model(sf.MODEL_CODES.WAG, gamma_parameters_alpha=3.0, gamma_parameters_categories=4)
msa = sim()
if correction:
    print("hey")
    msa_str = msa.get_msa().replace("-","")
    (MAIN_PATH / "msa.fasta").write_text(msa_str)
    aligner = Aligner(aligner="MAFFT")
    aligner.set_input_file(str(MAIN_PATH / "msa.fasta"))
    msa_str_realigned = aligner.get_realigned_msa()
    (MAIN_PATH / "msa.fasta").write_text(msa_str_realigned)
else:
    msa.write_msa(str((MAIN_PATH / "msa.fasta")))

with open(MAIN_PATH / "msa.bestModel", 'w') as f:
    f.write("WAG+G4m{3.0},")

9.421016317602982
hey


In [56]:

def calc_distances(params_data, full_stats_data, kept_statistics, main_path):
    MSA_PATH = get_msa_path(main_path)
    empirical_stats = msastats.calculate_fasta_stats(MSA_PATH)
    empirical_stats = [empirical_stats[i] for i in kept_statistics]

    calculated_distances = None

    if distance_metric == "mahal":
        cov = np.cov(full_stats_data.T)
        cov = cov + np.eye(len(cov))*1e-4
        inv_covmat = np.linalg.inv(cov)
        u_minus_v = empirical_stats-full_stats_data
        left = np.dot(u_minus_v, inv_covmat)
        calculated_distances = np.sqrt(np.sum(u_minus_v*left, axis=1))
    if distance_metric == "euclid":
        weights = 1/(full_stats_data.std(axis=0) + 0.001)
        calculated_distances = np.sum(weights*(full_stats_data - empirical_stats)**2, axis=1)

    top_stats = pd.DataFrame()
    top_stats["distances"] = calculated_distances
    top_stats[PARAMS_LIST] = params_data

    # full_stats_data["distances"] = calculated_distances
    # full_stats_data[PARAMS_LIST] = params_data

    top_stats = top_stats.nsmallest(top_cutoff, "distances")
    return top_stats

In [57]:
top_params = calc_distances(params_df, stats_df, kept_statistics, MAIN_PATH)

In [58]:
vectorized = np.vectorize(calc_zip_mom)
top_params["mean_insertion_length"] = vectorized(top_params["length_param_insertion"], truncation=150)
top_params["mean_deletion_length"] = vectorized(top_params["length_param_deletion"], truncation=150)
top_params

Unnamed: 0,distances,root_length,insertion_rate,deletion_rate,length_param_insertion,length_param_deletion,mean_insertion_length,mean_deletion_length
39348,10.454603,799.0,0.056940,0.056940,1.550405,1.550405,8.419105,8.419105
17038,10.817194,1162.0,0.055491,0.055491,1.604294,1.604294,7.473293,7.473293
210387,10.906463,909.0,0.043639,0.043639,1.705613,1.705613,6.004703,6.004703
254009,11.024341,893.0,0.051285,0.051285,1.487963,1.487963,9.678124,9.678124
488464,11.035365,864.0,0.027683,0.027683,1.543027,1.543027,8.558441,8.558441
...,...,...,...,...,...,...,...,...
168063,13.761338,819.0,0.042426,0.042426,1.424696,1.424696,11.149574,11.149574
38506,13.761423,1031.0,0.037311,0.037311,1.810350,1.810350,4.841885,4.841885
310104,13.761434,1129.0,0.007108,0.007108,1.446616,1.446616,10.616384,10.616384
37121,13.761523,831.0,0.021288,0.021288,1.065208,1.065208,23.778785,23.778785


In [59]:
MSA_PATH = get_msa_path(MAIN_PATH)

empirical_stats = msastats.calculate_fasta_stats(MSA_PATH)
empirical_stats

[9.731707317073171,
 1329.0,
 1137.0,
 939.0,
 164.0,
 48.0,
 31.0,
 21.0,
 64.0,
 10.016129032258064,
 62.0,
 7.0,
 3.0,
 5.0,
 3.0,
 0.0,
 7.0,
 1.0,
 2.0,
 4.0,
 9.0,
 2.0,
 9.0,
 857.0,
 55.0,
 41.0,
 331.0]

In [60]:
def get_all_model_params(models, abc_params: pd.DataFrame):
    top_cutoff = 100

    filtered_top_params = abc_params[abc_params["insertion_rate"] == abc_params["deletion_rate"]]
    filtered_top_params = filtered_top_params.nsmallest(top_cutoff, "distances")

    root_length = int(filtered_top_params["root_length"].mean())
    R_ID = float(filtered_top_params["insertion_rate"].mean())
    A_ID = float(filtered_top_params["mean_insertion_length"].mean())
    models["sim"]["root_lengths"].append(root_length)
    models["sim"]["indel_rates"].append( R_ID)
    models["sim"]["mean_indel_length"].append(A_ID)
    
    filtered_top_params = abc_params[abc_params["insertion_rate"] != abc_params["deletion_rate"]]
    filtered_top_params = filtered_top_params.nsmallest(top_cutoff, "distances")
    root_length = int(filtered_top_params["root_length"].mean())
    R_I = float(filtered_top_params["insertion_rate"].mean())
    R_D = float(filtered_top_params["deletion_rate"].mean())
    A_I = float(filtered_top_params["mean_insertion_length"].mean())
    A_D = float(filtered_top_params["mean_deletion_length"].mean())
    models["rim"]["root_lengths"].append(( root_length))
    models["rim"]["insertion_rates"].append(R_I)
    models["rim"]["deletion_rates"].append(R_D)
    models["rim"]["mean_insertion_length"].append(A_I)
    models["rim"]["mean_deletion_length"].append(A_D)

    return models


In [61]:
models = {"rim":
        {
            "root_lengths": [],
            "insertion_rates": [],
            "deletion_rates": [],
            "mean_insertion_length": [],
            "mean_deletion_length": []
        },
        "sim":
        {
            "root_lengths": [],
            "indel_rates": [],
            "mean_indel_length": [],
        }
    }

get_all_model_params(models, top_params)

{'rim': {'root_lengths': [1117],
  'insertion_rates': [0.027474091404982377],
  'deletion_rates': [0.009028413651280433],
  'mean_insertion_length': [7.893971814542328],
  'mean_deletion_length': [4.7808333538271786]},
 'sim': {'root_lengths': [918],
  'indel_rates': [0.029769739000078482],
  'mean_indel_length': [11.821691903353464]}}