In [39]:
from pathlib import Path
import pandas as pd
import numpy as np
from msasim import sailfish as sf
import msastats

from spartaabc.abc_inference import load_data, load_correction_regressors, load_correction_regressor_scores, bias_correction
from spartaabc.utility import get_msa_path, PARAMS_LIST, SUMSTATS_LIST


In [6]:
# Working directory of this sanity check; the cells below read and write their files under it.
MAIN_PATH = Path("sanity_check/")
# Show the absolute location (a relative path, so it depends on the kernel's working directory).
MAIN_PATH.resolve()

PosixPath('/home/elyalab/Dev/failed_syncs/SpartaV2/benchmark/sanity_check')

In [17]:
# Build the sanity-check inputs under MAIN_PATH: a fixed 5-taxon Newick tree,
# a protein MSA simulated along it with known indel parameters, and the
# matching substitution-model file.
tree_path = MAIN_PATH / "msa.tree"
with tree_path.open('w') as tree_file:
    tree_file.write("(1159908.H9BR28_9NIDO:1.12098,(572290.B6VDY1_THCOV:2.1089e-06,(572287.B6VDX2_9NIDO:0.328448,(1159906.H9BR03_9NIDO:0.107594,1586324.A0A0E3Y5V9_9NIDO:0.137885):0.333936):0.12119):1.12098);")

# Ground truth for the check: root length 1000, insertion and deletion rate
# both 0.001, and the same Zipf(2.0, 150) length distribution for both.
sim_protocol = sf.SimProtocol(
    str(tree_path),
    root_seq_size=1000,
    deletion_rate=0.001,
    insertion_rate=0.001,
    insertion_dist=sf.ZipfDistribution(2.0, 150),
    deletion_dist=sf.ZipfDistribution(2.0, 150),
    seed=42,
)

sim = sf.Simulator(
    simProtocol=sim_protocol,
    simulation_type=sf.SIMULATION_TYPE.PROTEIN,
)
sim.set_replacement_model(
    sf.MODEL_CODES.WAG,
    gamma_parameters_alpha=3.0,
    gamma_parameters_categories=4,
)
msa = sim()
msa.write_msa(str(MAIN_PATH / "msa.fasta"))

# The model string repeats the WAG / 4-category / alpha=3.0 settings used above.
with (MAIN_PATH / "msa.bestModel").open('w') as model_file:
    model_file.write("WAG+G4m{3.0},")

In [21]:
# Load the pre-sampled reference tables for the two models ("sim" and "rim")
# so the sampled parameters and summary statistics can be inspected.
sim_df = pd.read_parquet(MAIN_PATH.joinpath("full_data_zipf_sim.parquet.gzip"))
rim_df = pd.read_parquet(MAIN_PATH.joinpath("full_data_zipf_rim.parquet.gzip"))

In [74]:
# Reference rows whose sampled insertion rate lies within 1e-4 of 0.001,
# the insertion rate used to simulate the MSA above.
sim_df[(sim_df["insertion_rate"] - 0.001).abs() < 0.0001]

Unnamed: 0,root_length,insertion_rate,deletion_rate,length_param_insertion,length_param_deletion,SS_0,SS_1,SS_2,SS_3,SS_4,...,SS_17,SS_18,SS_19,SS_20,SS_21,SS_22,SS_23,SS_24,SS_25,SS_26
27,962.0,0.001076,0.001076,1.460422,1.460422,12.142857,965.0,962.0,922.0,14.0,...,0.0,0.0,0.0,1.0,0.0,1.0,918.0,6.0,0.0,41.0
100,1044.0,0.001023,0.001023,1.856858,1.856858,2.000000,1046.0,1039.0,1034.0,20.0,...,0.0,0.0,1.0,2.0,0.0,0.0,1027.0,12.0,0.0,7.0
123,873.0,0.001028,0.001028,1.241992,1.241992,2.600000,892.0,890.0,874.0,10.0,...,1.0,0.0,0.0,1.0,0.0,0.0,872.0,18.0,0.0,2.0
196,1031.0,0.001098,0.001098,2.776364,2.776364,1.000000,1036.0,1032.0,1030.0,26.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1026.0,4.0,1.0,5.0
219,792.0,0.001040,0.001040,1.720163,1.720163,1.333333,793.0,793.0,789.0,6.0,...,1.0,0.0,0.0,0.0,0.0,0.0,788.0,4.0,0.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
499894,1075.0,0.001050,0.001050,2.008938,2.008938,7.166667,1225.0,1221.0,1076.0,24.0,...,0.0,0.0,0.0,1.0,0.0,0.0,1070.0,148.0,2.0,5.0
499918,882.0,0.001002,0.001002,1.746916,1.746916,2.550000,896.0,888.0,883.0,20.0,...,0.0,0.0,1.0,0.0,0.0,1.0,879.0,5.0,0.0,10.0
499927,895.0,0.001032,0.001032,2.778770,2.778770,1.000000,896.0,895.0,891.0,16.0,...,0.0,0.0,0.0,0.0,0.0,0.0,890.0,2.0,1.0,3.0
499992,1013.0,0.001083,0.001083,2.074266,2.074266,1.285714,1019.0,1018.0,1012.0,14.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1010.0,6.0,0.0,3.0


In [78]:
# Settings that get_top_params reads as notebook-level globals.
distance_metric = "mahal"  # "mahal" or "euclid": the two branches get_top_params implements
top_cutoff = 10000  # how many smallest-distance rows get_top_params returns
aligner = "mafft"  # forwarded to load_correction_regressors / load_correction_regressor_scores


In [79]:
def get_top_params(main_path: Path):
    """Rank the corrected reference simulations by distance to the empirical MSA.

    The summary statistics of the MSA found via ``get_msa_path(main_path)`` are
    compared against the per-model tables from ``load_data(main_path)`` after
    each table has been passed through ``bias_correction``. Models without
    correction regressors are skipped entirely (statistics *and* parameters),
    so the two stay row-aligned.

    Reads the notebook-level globals ``aligner``, ``distance_metric`` and
    ``top_cutoff``.

    Args:
        main_path: Directory handed to ``get_msa_path``, ``load_data``,
            ``load_correction_regressors`` and
            ``load_correction_regressor_scores``.

    Returns:
        DataFrame with the ``top_cutoff`` smallest-distance rows: the corrected
        summary statistics, a ``distances`` column and the ``PARAMS_LIST``
        columns.

    Raises:
        ValueError: If ``distance_metric`` is neither ``"mahal"`` nor
            ``"euclid"``, or if no model has correction regressors.
    """
    # Fail fast: an unknown metric used to leave the distances as None and only
    # blow up later inside nsmallest.
    if distance_metric not in ("mahal", "euclid"):
        raise ValueError(
            f"unknown distance_metric {distance_metric!r}; expected 'mahal' or 'euclid'"
        )

    MSA_PATH = get_msa_path(main_path)

    empirical_stats = msastats.calculate_fasta_stats(MSA_PATH)

    stats_data = load_data(main_path)
    regressors = load_correction_regressors(main_path, aligner)
    regressor_scores = load_correction_regressor_scores(main_path, aligner)

    params_data = []
    full_stats_data = []
    kept_statistics = None
    for model in stats_data.keys():
        current_regressors = regressors.get(model, None)
        if current_regressors is None:
            # No correction available: drop the model's parameters as well, so
            # params_data and full_stats_data keep describing the same rows.
            continue
        params_data.append(stats_data[model][PARAMS_LIST])
        # NOTE(review): kept_statistics from the last corrected model is used
        # for every model -- assumes bias_correction keeps the same statistics
        # for all of them; confirm.
        temp_df, kept_statistics = bias_correction(current_regressors, stats_data[model], regressor_scores)
        full_stats_data.append(temp_df)

    if not full_stats_data:
        raise ValueError(
            f"no correction regressors found for any model under {main_path}"
        )

    # Restrict the empirical statistics to the ones bias_correction kept.
    empirical_stats = [empirical_stats[i] for i in kept_statistics]

    params_data = pd.concat(params_data)
    full_stats_data = pd.concat(full_stats_data)

    if distance_metric == "mahal":
        cov = np.cov(full_stats_data.T)
        # Small ridge on the diagonal keeps the covariance matrix invertible.
        cov = cov + np.eye(len(cov))*1e-4
        inv_covmat = np.linalg.inv(cov)
        u_minus_v = empirical_stats-full_stats_data
        left = np.dot(u_minus_v, inv_covmat)
        calculated_distances = np.sqrt(np.sum(u_minus_v*left, axis=1))
    else:  # "euclid"
        # Weighted sum of squared differences (no square root is taken); the
        # 0.001 guards against division by a zero standard deviation.
        weights = 1/(full_stats_data.std(axis=0) + 0.001)
        calculated_distances = np.sum(weights*(full_stats_data - empirical_stats)**2, axis=1)

    full_stats_data["distances"] = calculated_distances
    full_stats_data[PARAMS_LIST] = params_data

    top_stats = full_stats_data.nsmallest(top_cutoff, "distances")
    return top_stats

In [80]:
# Run the rejection step for the simulated sanity-check MSA under MAIN_PATH.
top_params = get_top_params(MAIN_PATH)

In [None]:
# NOTE(review): never-executed cell (In [None]). It repeats the first two
# statements of get_top_params at module level, and no visible cell reads
# MSA_PATH or empirical_stats afterwards -- candidate for removal.
MSA_PATH = get_msa_path(MAIN_PATH)

empirical_stats = msastats.calculate_fasta_stats(MSA_PATH)


In [83]:
# Inspect the rows returned by get_top_params (statistics, distances, parameters).
top_params

Unnamed: 0,SS_0,SS_1,SS_2,SS_3,SS_4,SS_5,SS_6,SS_7,SS_8,SS_9,...,SS_23,SS_24,SS_25,SS_26,distances,root_length,insertion_rate,deletion_rate,length_param_insertion,length_param_deletion
279232,2.551372,1122.747546,1075.728428,1047.886814,91.139998,37.894898,20.356068,8.423993,25.185252,2.659302,...,1003.981634,49.488045,21.219972,41.119968,4.275463,1063.0,0.012249,0.011381,2.529276,1.889769
456798,2.383437,1141.240555,1075.740754,1035.883192,144.997696,68.852202,30.251046,12.958945,31.034503,2.655107,...,963.205689,71.512877,38.925208,62.318564,4.519022,1105.0,0.006024,0.026029,2.532696,2.566448
334180,5.177046,1006.391617,964.923303,896.977847,81.307786,28.562266,16.850400,5.790757,28.979728,5.159926,...,842.276800,61.884845,11.789323,77.632385,4.685884,1005.0,0.003471,0.021462,2.426363,1.983706
149169,2.991218,1118.071992,1084.582857,1013.916042,65.160554,29.485451,15.488803,5.189814,15.446109,3.688437,...,983.458851,77.973933,18.005624,33.455427,4.721594,1091.0,0.003089,0.012319,2.325416,2.260934
58784,3.084539,1103.516356,1053.757843,1005.913861,104.662994,49.198413,19.319990,7.336054,29.743451,3.282394,...,949.492151,69.064559,19.115433,43.513386,4.727738,1002.0,0.015046,0.008834,2.472696,2.036762
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
443789,6.952413,1078.010599,1041.565042,802.118372,82.941602,27.872983,15.564346,5.784109,29.724324,8.835317,...,755.482579,186.981676,64.908293,51.588106,7.181706,844.0,0.006949,0.003986,1.744114,1.663407
480744,7.087038,1109.949980,1088.591007,1021.932830,39.034549,12.573888,7.471378,2.755955,16.565812,6.381774,...,1000.996656,34.249810,1.139499,60.517949,7.181721,1052.0,0.000529,0.004342,1.190192,1.803435
341945,12.828216,1117.612848,1020.175078,825.116273,93.926410,16.812822,15.175226,7.866227,52.091122,11.108244,...,773.760076,57.086261,9.938027,240.073170,7.181725,923.0,0.016381,0.003666,1.658960,1.078781
212184,3.706679,935.009668,910.069955,895.975971,47.582147,18.992108,6.474819,3.364690,18.277703,3.710754,...,864.710881,31.676529,3.628262,27.465262,7.181748,932.0,0.000785,0.004870,2.603404,1.429255


In [81]:
def get_all_model_params(models, abc_params: pd.DataFrame, top_cutoff: int = 100):
    """Append point estimates for the "sim" and "rim" models to ``models``.

    Rows of ``abc_params`` are split by whether ``insertion_rate`` equals
    ``deletion_rate`` (recorded under ``"sim"``) or differs from it (recorded
    under ``"rim"``). Within each group the ``top_cutoff`` rows with the
    smallest ``distances`` are averaged column by column; the mean root length
    is truncated to ``int``.

    Args:
        models: Dict with ``"sim"`` and ``"rim"`` sub-dicts of lists (keys as
            appended below). Mutated in place.
        abc_params: Frame with the columns ``root_length``, ``insertion_rate``,
            ``deletion_rate``, ``length_param_insertion``,
            ``length_param_deletion`` and ``distances``.
        top_cutoff: Number of closest rows averaged per group (default 100).

    Returns:
        The same ``models`` object, with one value appended to every list.

    Raises:
        ValueError: If either group has no rows to average. ``models`` is left
            untouched in that case.
    """
    def _closest_rows(mask, label):
        # Closest rows of one group; an empty group would otherwise surface as
        # an opaque "cannot convert float NaN to integer".
        closest = abc_params[mask].nsmallest(top_cutoff, "distances")
        if closest.empty:
            raise ValueError(
                f"no rows available to estimate the {label!r} model parameters"
            )
        return closest

    sim_top = _closest_rows(
        abc_params["insertion_rate"] == abc_params["deletion_rate"], "sim"
    )
    rim_top = _closest_rows(
        abc_params["insertion_rate"] != abc_params["deletion_rate"], "rim"
    )

    # Compute every estimate before touching `models`, so a failure cannot
    # leave the accumulator half-updated.
    estimates = {
        "sim": {
            "root_lengths": int(sim_top["root_length"].mean()),
            "indel_rates": float(sim_top["insertion_rate"].mean()),
            "indel_length_params": float(sim_top["length_param_insertion"].mean()),
        },
        "rim": {
            "root_lengths": int(rim_top["root_length"].mean()),
            "insertion_rates": float(rim_top["insertion_rate"].mean()),
            "deletion_rates": float(rim_top["deletion_rate"].mean()),
            "insertion_length_params": float(rim_top["length_param_insertion"].mean()),
            "deletion_length_params": float(rim_top["length_param_deletion"].mean()),
        },
    }

    for model_name, model_estimates in estimates.items():
        for key, value in model_estimates.items():
            models[model_name][key].append(value)

    return models


In [82]:
# Empty accumulators: one list per estimated parameter, keyed by model name.
models = {
    "rim": {
        field: []
        for field in (
            "root_lengths",
            "insertion_rates",
            "deletion_rates",
            "insertion_length_params",
            "deletion_length_params",
        )
    },
    "sim": {
        field: []
        for field in ("root_lengths", "indel_rates", "indel_length_params")
    },
}
# Fill them from the accepted rows; the returned dict is displayed as the cell output.
get_all_model_params(models, top_params)

{'rim': {'root_lengths': [1004],
  'insertion_rates': [0.012456799045237488],
  'deletion_rates': [0.021385237865473593],
  'insertion_length_params': [2.044829117701512],
  'deletion_length_params': [2.0784146162350736]},
 'sim': {'root_lengths': [1037],
  'indel_rates': [0.01656914863072278],
  'indel_length_params': [2.061963052405511]}}