In [39]:
from pathlib import Path
import pandas as pd
import numpy as np
from msasim import sailfish as sf
import msastats

from spartaabc.abc_inference import load_data, load_correction_regressors, load_correction_regressor_scores, bias_correction
from spartaabc.utility import get_msa_path, PARAMS_LIST, SUMSTATS_LIST


In [6]:
# Directory holding every artifact of this sanity check (tree, simulated MSA, model file, sampled tables).
MAIN_PATH = Path("sanity_check")
MAIN_PATH.resolve()

PosixPath('/home/elyalab/Dev/failed_syncs/SpartaV2/benchmark/sanity_check')

In [17]:
# Ground truth for the sanity check: simulate a protein MSA on a fixed tree with known indel
# parameters, then write the tree, the MSA and a matching substitution-model file to MAIN_PATH.
# The ABC estimates at the end of the notebook should recover these TRUE_* values.
TRUE_ROOT_LENGTH = 1000
TRUE_INDEL_RATE = 0.001   # used for both insertion_rate and deletion_rate
TRUE_ZIPF_PARAM = 2.0     # first ZipfDistribution argument, same for insertions and deletions
ZIPF_MAX = 150            # NOTE(review): presumably the maximal indel length — confirm in msasim docs
GAMMA_ALPHA = 3.0
GAMMA_CATEGORIES = 4
SIM_SEED = 42

TREE_NEWICK = (
    "(1159908.H9BR28_9NIDO:1.12098,(572290.B6VDY1_THCOV:2.1089e-06,"
    "(572287.B6VDX2_9NIDO:0.328448,(1159906.H9BR03_9NIDO:0.107594,"
    "1586324.A0A0E3Y5V9_9NIDO:0.137885):0.333936):0.12119):1.12098);"
)

# On a fresh checkout the output directory may not exist yet; the writes below would then fail.
MAIN_PATH.mkdir(parents=True, exist_ok=True)
(MAIN_PATH / "msa.tree").write_text(TREE_NEWICK)

sim_protocol = sf.SimProtocol(str(MAIN_PATH / "msa.tree"), root_seq_size=TRUE_ROOT_LENGTH,
                              deletion_rate=TRUE_INDEL_RATE, insertion_rate=TRUE_INDEL_RATE,
                              insertion_dist=sf.ZipfDistribution(TRUE_ZIPF_PARAM, ZIPF_MAX),
                              deletion_dist=sf.ZipfDistribution(TRUE_ZIPF_PARAM, ZIPF_MAX),
                              seed=SIM_SEED)

sim = sf.Simulator(simProtocol=sim_protocol, simulation_type=sf.SIMULATION_TYPE.PROTEIN)
sim.set_replacement_model(sf.MODEL_CODES.WAG, gamma_parameters_alpha=GAMMA_ALPHA,
                          gamma_parameters_categories=GAMMA_CATEGORIES)
msa = sim()
msa.write_msa(str(MAIN_PATH / "msa.fasta"))

# Model file (presumably read by spartaabc) built from the same gamma settings as the simulation,
# so the two cannot drift apart. Renders as "WAG+G4m{3.0},"; the trailing comma is kept as in the
# original file — TODO confirm the format requires it.
(MAIN_PATH / "msa.bestModel").write_text(f"WAG+G{GAMMA_CATEGORIES}m{{{GAMMA_ALPHA}}},")

In [21]:
# Load the sampled tables (simulated parameters + summary statistics) for both indel models.
rim_df, sim_df = (
    pd.read_parquet(MAIN_PATH / f"full_data_zipf_{model}.parquet.gzip")
    for model in ("rim", "sim")
)

In [29]:
# SIM simulations whose insertion rate lies within 1e-4 of the simulated value 0.001.
sim_df[sim_df["insertion_rate"].sub(0.001).abs().lt(0.0001)]

Unnamed: 0,root_length,insertion_rate,deletion_rate,length_param_insertion,length_param_deletion,SS_0,SS_1,SS_2,SS_3,SS_4,...,SS_17,SS_18,SS_19,SS_20,SS_21,SS_22,SS_23,SS_24,SS_25,SS_26
27,962.0,0.001076,0.001076,1.460422,1.460422,12.142857,965.0,962.0,922.0,14.0,...,0.0,0.0,0.0,1.0,0.0,1.0,918.0,6.0,0.0,41.0
100,1044.0,0.001023,0.001023,1.856858,1.856858,2.000000,1046.0,1039.0,1034.0,20.0,...,0.0,0.0,1.0,2.0,0.0,0.0,1027.0,12.0,0.0,7.0
123,873.0,0.001028,0.001028,1.241992,1.241992,2.600000,892.0,890.0,874.0,10.0,...,1.0,0.0,0.0,1.0,0.0,0.0,872.0,18.0,0.0,2.0
196,1031.0,0.001098,0.001098,2.776364,2.776364,1.000000,1036.0,1032.0,1030.0,26.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1026.0,4.0,1.0,5.0
219,792.0,0.001040,0.001040,1.720163,1.720163,1.333333,793.0,793.0,789.0,6.0,...,1.0,0.0,0.0,0.0,0.0,0.0,788.0,4.0,0.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
499894,1075.0,0.001050,0.001050,2.008938,2.008938,7.166667,1225.0,1221.0,1076.0,24.0,...,0.0,0.0,0.0,1.0,0.0,0.0,1070.0,148.0,2.0,5.0
499918,882.0,0.001002,0.001002,1.746916,1.746916,2.550000,896.0,888.0,883.0,20.0,...,0.0,0.0,1.0,0.0,0.0,1.0,879.0,5.0,0.0,10.0
499927,895.0,0.001032,0.001032,2.778770,2.778770,1.000000,896.0,895.0,891.0,16.0,...,0.0,0.0,0.0,0.0,0.0,0.0,890.0,2.0,1.0,3.0
499992,1013.0,0.001083,0.001083,2.074266,2.074266,1.285714,1019.0,1018.0,1012.0,14.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1010.0,6.0,0.0,3.0


In [69]:
# Settings read as globals by get_top_params.
distance_metric = "euclid"   # "euclid" or "mahal"
top_cutoff = 10_000          # number of closest simulations kept
aligner = "mafft"            # selects which bias-correction regressors are loaded


In [70]:
def _stat_distances(stats: pd.DataFrame, empirical_stats: list, metric: str) -> pd.Series:
    """Distance of every simulated row in `stats` to the empirical statistics vector.

    "euclid": sum over statistics of (sim - emp)**2 / (std + 0.001), a squared distance weighted
    by the inverse per-statistic standard deviation (no square root; it is only used for ranking).
    "mahal": Mahalanobis distance using the sample covariance of `stats`, with 1e-4 added to the
    diagonal so the matrix stays invertible.

    Raises ValueError for any other metric name.
    """
    if metric == "mahal":
        cov = np.cov(stats.T)
        cov = cov + np.eye(len(cov)) * 1e-4
        inv_covmat = np.linalg.inv(cov)
        u_minus_v = empirical_stats - stats
        left = np.dot(u_minus_v, inv_covmat)
        return np.sqrt(np.sum(u_minus_v * left, axis=1))
    if metric == "euclid":
        weights = 1 / (stats.std(axis=0) + 0.001)
        return np.sum(weights * (stats - empirical_stats) ** 2, axis=1)
    raise ValueError(f"Unknown distance_metric {metric!r}; expected 'euclid' or 'mahal'")


def get_top_params(main_path: Path):
    """Rank the simulations stored under `main_path` by distance to the empirical MSA.

    Reads the notebook-level settings `aligner`, `distance_metric` and `top_cutoff`.
    Each model's simulated summary statistics are bias-corrected with that model's regressors.
    Models without regressors are skipped entirely (parameters and statistics together), so the
    two frames stay row-aligned.

    Returns the `top_cutoff` rows with the smallest distance: the corrected statistics, a
    "distances" column and the PARAMS_LIST parameter columns.

    Raises ValueError if no model has correction regressors or `distance_metric` is unknown.
    """
    msa_path = get_msa_path(main_path)
    empirical_stats = msastats.calculate_fasta_stats(msa_path)

    stats_data = load_data(main_path)
    regressors = load_correction_regressors(main_path, aligner)
    regressor_scores = load_correction_regressor_scores(main_path, aligner)

    params_data = []
    full_stats_data = []
    kept_statistics = None
    for model in stats_data.keys():
        current_regressors = regressors.get(model)
        if current_regressors is None:
            # Previously the params were still collected here, misaligning them with the stats.
            continue
        params_data.append(stats_data[model][PARAMS_LIST])
        corrected, kept_statistics = bias_correction(current_regressors, stats_data[model], regressor_scores)
        full_stats_data.append(corrected)

    if not full_stats_data:
        raise ValueError(f"No correction regressors found for any model in {main_path} (aligner={aligner!r})")

    # bias_correction may drop statistics; restrict the empirical vector to the same ones.
    # NOTE(review): assumes every model keeps the same subset (the last model's selection is used).
    empirical_stats = [empirical_stats[i] for i in kept_statistics]

    params_data = pd.concat(params_data)
    full_stats_data = pd.concat(full_stats_data)

    full_stats_data["distances"] = _stat_distances(full_stats_data, empirical_stats, distance_metric)
    full_stats_data[PARAMS_LIST] = params_data

    return full_stats_data.nsmallest(top_cutoff, "distances")

In [71]:
# Rank all simulations against the empirical MSA in MAIN_PATH; keeps the `top_cutoff` closest rows
# together with their "distances" and simulated parameters.
top_params = get_top_params(MAIN_PATH)

In [72]:
def _closest_rows(rows: pd.DataFrame, top_cutoff: int, label: str) -> pd.DataFrame:
    """Return the `top_cutoff` rows of `rows` with the smallest "distances"; reject an empty selection."""
    if rows.empty:
        # Without this check the later int(mean) fails with "cannot convert float NaN to integer".
        raise ValueError(f"abc_params contains no {label} rows; cannot estimate {label} parameters")
    return rows.nsmallest(top_cutoff, "distances")


def get_all_model_params(models, abc_params: pd.DataFrame, top_cutoff: int = 100):
    """Append the ABC point estimates for the SIM and RIM indel models to `models`.

    Rows of `abc_params` with insertion_rate == deletion_rate are treated as SIM simulations and
    all other rows as RIM simulations. For each group, the `top_cutoff` rows with the smallest
    "distances" are averaged, and the means are appended to:

    - models["sim"]: "root_lengths", "indel_rates", "indel_length_params"
    - models["rim"]: "root_lengths", "insertion_rates", "deletion_rates",
      "insertion_length_params", "deletion_length_params"

    Root lengths are truncated to int. `models` is mutated in place and also returned.

    Raises ValueError if `top_cutoff` < 1 or if either group has no rows. In that case `models`
    is left unchanged.
    """
    if top_cutoff < 1:
        raise ValueError(f"top_cutoff must be >= 1, got {top_cutoff}")

    is_sim = abc_params["insertion_rate"] == abc_params["deletion_rate"]
    # Select both groups before touching `models`, so a failure cannot leave it half-updated.
    sim_top = _closest_rows(abc_params[is_sim], top_cutoff, "SIM")
    rim_top = _closest_rows(abc_params[~is_sim], top_cutoff, "RIM")

    models["sim"]["root_lengths"].append(int(sim_top["root_length"].mean()))
    models["sim"]["indel_rates"].append(float(sim_top["insertion_rate"].mean()))
    models["sim"]["indel_length_params"].append(float(sim_top["length_param_insertion"].mean()))

    models["rim"]["root_lengths"].append(int(rim_top["root_length"].mean()))
    models["rim"]["insertion_rates"].append(float(rim_top["insertion_rate"].mean()))
    models["rim"]["deletion_rates"].append(float(rim_top["deletion_rate"].mean()))
    models["rim"]["insertion_length_params"].append(float(rim_top["length_param_insertion"].mean()))
    models["rim"]["deletion_length_params"].append(float(rim_top["length_param_deletion"].mean()))

    return models


In [73]:
# Accumulators for the per-model ABC estimates; get_all_model_params appends one entry per list.
models = {
    "rim": dict(root_lengths=[], insertion_rates=[], deletion_rates=[],
                insertion_length_params=[], deletion_length_params=[]),
    "sim": dict(root_lengths=[], indel_rates=[], indel_length_params=[]),
}
get_all_model_params(models, top_params)

{'rim': {'root_lengths': [1000],
  'insertion_rates': [0.0008963493047142717],
  'deletion_rates': [0.001102148885355063],
  'insertion_length_params': [1.8863688542841845],
  'deletion_length_params': [1.625021389814165]},
 'sim': {'root_lengths': [992],
  'indel_rates': [0.001348259291674701],
  'indel_length_params': [1.425386731894891]}}