In [2]:
using DelimitedFiles
using StatsBase 
using Random
using LinearAlgebra
using Plots
using Statistics
using LaTeXStrings
using Printf
using Measures
using Distributions: Uniform
# Seeded Mersenne Twister RNG (seed 1234) so that any random draws taken from `rng`
# are reproducible across runs. NOTE(review): `rng` is not referenced in any of the
# cells visible here — confirm whether a later cell uses it.
rng = Random.MersenneTwister(1234);

In [3]:
include("../src/common/common_analysis.jl");
include("../src/HIV/MPL_Bezier_categorical.jl");
include("../src/HIV/MPL_categorical.jl");
include("../src/HIV/correction_for_Bezier_interpolation.jl");

In [4]:
# Keys for the input data: each key is the file-name prefix of one dataset, and each
# entry of `n_polys` selects the "<key>-<n>-poly-seq2state.dat" variant to read.
name_keys = String[
    "700010470", "700010077", "700010058", "700010040", "700010607",
    "706010164", "705010198", "705010185", "705010162", "704010042",
    "703010256", "703010159", "703010131",
];
n_polys = [string(k) for k in (3, 5)];

In [6]:
#---------------------------#
q = 5 # Number of maximum states (reassigned from each input file inside the loop below).
gamma, mu = 10, 1 # Regularization and mutation parameter.
n_poly_max = 2 # Number of entries of `n_polys` processed for every key.
pseudo_count = 0 # This is better to use zero to infer selection coefficient accurately.
#---------------------------#

#---------------------------#
diff_freq_thresh = 0.7        # Threshold passed to both midpoint covariance corrections.
time_interval_threshold = 50  # Threshold passed to the virtual-insertion Bezier routine.
#---------------------------#

#------------------------------------------------------#
fname_mutation_file = "../data/HIV/Zanini-extended.dat"
fname_dir = "../data/HIV/"
fout_dir = "../out/HIV/"
muMat = readdlm(fname_mutation_file);
# When true, additionally dump covariance entries, frequency changes (dx) and drift terms
# of both the Bezier and the linear interpolation for every dataset.
output_Cov_dirft = false
#------------------------------------------------------#

# Pre-declared globals. The loop below assigns to them (this relies on the interactive
# soft scope of a notebook/REPL), so after the loop they hold the values of the most
# recent dataset that set them; the midpoint-related ones are only set for datasets
# with more than three distinct sampling times.
C_diag_with_mid_point = []
C_Offdiag_with_mid_point = []
C_diag_original = []
C_tot_B = []
C_tot_B_original = []
C_tot_positive, C_tot_negative = [], []
C_Offdiag_with_mid_point_positive, C_Offdiag_with_mid_point_negative = [], []

mkpath(fout_dir) # Make sure the output directory exists before any file is written.

for ID_seq in eachindex(name_keys)
    for ID_poly in 1:n_poly_max
        # Dataset label such as "700010470-3"; shared by the input and output file names.
        data_label = name_keys[ID_seq] * "-" * string(n_polys[ID_poly])
        fname_in = fname_dir * data_label * "-poly-seq2state.dat"
        @printf("\nfname = %s\n", data_label)
        (q, L, N, MPLseq_raw, specific_times) = get_MPLseq(fname_in)
        unique_time = unique(specific_times)

        #----------------------------- Set polymorphic sites ---------------------------------#
        # Columns 1-3 of MPLseq_raw are always kept; each of the q*L columns 4:(q*L+3) is
        # kept only when it takes more than one distinct value (i.e. it is polymorphic).
        number_of_states = [length(unique(@view MPLseq_raw[:, n])) for n in 4:(q*L+3)]
        polymorphic_sites = vcat(trues(3), number_of_states .> 1)
        MPLseq_raw_poly = MPLseq_raw[:, polymorphic_sites]
        poly_seq_len = count(polymorphic_sites[4:end])
        (polymorphic_positions, polymorphic_states_set, original_index_to_polymorophic_index) =
            get_polymorophic_index_set(q, L, polymorphic_sites)
        scaling_psc = get_polymorphic_scaling_for_pseudo_count(polymorphic_states_set)
        @show unique_time

        #------------- Integrated covariance matrix and drift vector (Bezier) ----------------#
        if length(unique_time) <= 3
            # At most three sampling times: plain Bezier interpolation, no midpoint correction.
            (C_tot_B, drift_B, x1_mean_B, x1_init_B, x1_fini_B, point_set_x1_a_B,
                point_set_x1_b_B, point_set_x2_a_B, point_set_x2_b_B) =
                get_Bezier_Ctot_drift_x1mean_categorical_simple_poly(
                    poly_seq_len, q, L, MPLseq_raw_poly, specific_times, muMat, pseudo_count,
                    polymorphic_positions,
                    polymorphic_states_set,
                    original_index_to_polymorophic_index,
                    scaling_psc)
        else
            # More than three sampling times: Bezier interpolation with virtual time
            # insertion, followed by midpoint corrections of the covariance matrix.
            (C_tot_B, drift_B, x1_mean_B, x1_init_B, x1_fini_B, point_set_x1_a_B,
                point_set_x1_b_B, point_set_x2_a_B, point_set_x2_b_B,
                x1_traject, x2_traject, time_list_with_insertion,
                C_tot_positive, C_tot_negative) =
                get_Bezier_Ctot_drift_x1mean_categorical_simple_poly_virtual_insertion(
                    poly_seq_len, q, L, MPLseq_raw_poly, specific_times, muMat, pseudo_count,
                    polymorphic_positions,
                    polymorphic_states_set,
                    original_index_to_polymorophic_index,
                    scaling_psc,
                    time_interval_threshold)

            # Midpoint correction for the on-diagonal covariance.
            C_diag_with_mid_point = correction_of_covariance_diagonal_with_midpoint(
                C_tot_B,
                poly_seq_len,
                x1_traject,
                x2_traject,
                time_list_with_insertion,
                diff_freq_thresh)

            # Midpoint correction for the off-diagonal covariance.
            (C_Offdiag_with_mid_point, C_Offdiag_with_mid_point_positive,
                C_Offdiag_with_mid_point_negative, Matrix_modified_element) =
                correction_of_covariance_diagonal_with_midpoint_on_off_diagonal(
                    C_tot_B,
                    C_tot_positive,
                    C_tot_negative,
                    poly_seq_len,
                    x1_traject,
                    x2_traject,
                    time_list_with_insertion,
                    diff_freq_thresh)

            # Keep the uncorrected covariance, then replace C_tot_B by the off-diagonal-
            # corrected matrix whose diagonal is overwritten with the corrected diagonal.
            C_diag_original = [C_tot_B[i, i] for i in 1:poly_seq_len]
            C_tot_B_original = copy(C_tot_B)
            C_tot_B = copy(C_Offdiag_with_mid_point)
            C_tot_B[diagind(C_tot_B)] = C_diag_with_mid_point
        end

        #------------- Integrated covariance matrix and drift vector (linear) ----------------#
        (C_tot_L, drift_L, x1_mean_L, x1_init_L, x1_fini_L) =
            get_integrated_Ctot_drift_x1mean_simple_poly(
                poly_seq_len, q, L, MPLseq_raw_poly, specific_times, muMat, pseudo_count,
                polymorphic_positions,
                polymorphic_states_set,
                original_index_to_polymorophic_index,
                scaling_psc)

        #------------------------------- Selection coefficients -------------------------------#
        # Frequency change minus integrated drift; combined with the integrated covariance
        # (and regularization gamma) inside get_selection / get_selection_SL.
        num_L = x1_fini_L - x1_init_L - drift_L
        num_B = x1_fini_B - x1_init_B - drift_B
        selec_coeff_L_poly = get_selection(gamma, num_L, C_tot_L)
        selec_coeff_B_poly = get_selection(gamma, num_B, C_tot_B)
        selec_coeff_L_poly_SL = get_selection_SL(gamma, num_L, C_tot_L)
        selec_coeff_B_poly_SL = get_selection_SL(gamma, num_B, C_tot_B)

        #------- Inflate selection to all q*L states; monomorphic states stay at zero ---------#
        poly_mask = polymorphic_sites[4:end]
        selec_coeff_L = zeros(q*L)
        selec_coeff_B = zeros(q*L)
        selec_coeff_L_SL = zeros(q*L)
        selec_coeff_B_SL = zeros(q*L)
        selec_coeff_L[poly_mask] = selec_coeff_L_poly
        selec_coeff_B[poly_mask] = selec_coeff_B_poly
        selec_coeff_L_SL[poly_mask] = selec_coeff_L_poly_SL
        selec_coeff_B_SL[poly_mask] = selec_coeff_B_poly_SL

        #---------------------------- Output selection coefficients ---------------------------#
        # One line per state index: L, B, L_SL, B_SL. The do-block closes the file even on error.
        open(fout_dir * data_label * "_selections.dat", "w") do fout_selections
            for i in eachindex(selec_coeff_L)
                println(fout_selections, selec_coeff_L[i], " ", selec_coeff_B[i], " ",
                        selec_coeff_L_SL[i], " ", selec_coeff_B_SL[i])
            end
        end

        #------------------------------- Output for the covariance ----------------------------#
        if output_Cov_dirft
            fname_covnum_Bez = fout_dir * "covariance-numerator" * data_label * "-poly-Bezier.txt"
            fname_covnum_Lin = fout_dir * "covariance-numerator" * data_label * "-poly-Linear.txt"
            len_poly_positions = length(polymorphic_positions)
            open(fname_covnum_Bez, "w") do fout_covnum_Bez
                open(fname_covnum_Lin, "w") do fout_covnum_Lin
                    # Covariance entries with |C| > 1e-4, for site pairs (n, m) with m >= n.
                    # Bezier entries go to the Bezier file, linear entries to the Linear file.
                    for n in 1:len_poly_positions
                        k = polymorphic_positions[n]
                        for a in polymorphic_states_set[n]
                            u = original_index_to_polymorophic_index[km(k, a, q)]
                            for m in n:len_poly_positions
                                j = polymorphic_positions[m]
                                for b in polymorphic_states_set[m]
                                    v = original_index_to_polymorophic_index[km(j, b, q)]
                                    if abs(C_tot_B[u, v]) > 1e-4
                                        println(fout_covnum_Bez, "C $k $j $a $b ", C_tot_B[u, v])
                                    end
                                    if abs(C_tot_L[u, v]) > 1e-4
                                        println(fout_covnum_Lin, "C $k $j $a $b ", C_tot_L[u, v])
                                    end
                                end
                            end
                        end
                    end

                    # Frequency change of each polymorphic state. The x1_* vectors are
                    # indexed by the polymorphic state index u, not by the site counter n.
                    for n in 1:len_poly_positions
                        k = polymorphic_positions[n]
                        for a in polymorphic_states_set[n]
                            u = original_index_to_polymorophic_index[km(k, a, q)]
                            println(fout_covnum_Bez, "dx $k $a ", x1_fini_B[u] - x1_init_B[u])
                            println(fout_covnum_Lin, "dx $k $a ", x1_fini_L[u] - x1_init_L[u])
                        end
                    end

                    # Integrated drift of each polymorphic state (same indexing as above).
                    for n in 1:len_poly_positions
                        k = polymorphic_positions[n]
                        for a in polymorphic_states_set[n]
                            u = original_index_to_polymorophic_index[km(k, a, q)]
                            println(fout_covnum_Bez, "drift $k $a ", drift_B[u])
                            println(fout_covnum_Lin, "drift $k $a ", drift_L[u])
                        end
                    end
                end
            end
        end
        #---------------------------------------------------------------------------------------#
    end
end




fname = 700010470-3
L=367, q=5, N=113
unique_time = [0, 13, 41, 69, 174, 420]

fname = 700010470-5
L=193, q=5, N=104
unique_time = [0, 13, 41, 69, 174, 420, 454]

fname = 700010077-3
L=203, q=5, N=44
unique_time = [0, 14, 32, 102, 159]

fname = 700010077-5
L=48, q=5, N=32
unique_time = [0, 14, 32, 159]

fname = 700010058-3
L=90, q=5, N=25
unique_time = [0, 8, 45, 85]

fname = 700010058-5
L=96, q=5, N=52
unique_time = [0, 8, 45, 85, 154, 239, 252, 350]

fname = 700010040-3
L=303, q=5, N=82
unique_time = [0, 16, 45, 111, 181, 283, 412, 552]

fname = 700010040-5
L=146, q=5, N=74
unique_time = [0, 16, 45, 111, 181, 283, 412, 552]

fname = 700010607-3
L=239, q=5, N=73
unique_time = [0, 9, 14, 21]

fname = 700010607-5
L=78, q=5, N=76
unique_time = [0, 9, 14, 21]

fname = 706010164-3
L=485, q=5, N=102
unique_time = [0, 14, 28, 70, 183, 434]

fname = 706010164-5
L=204, q=5, N=98
unique_time = [0, 14, 28, 70, 183, 434]

fname = 705010198-3
L=204, q=5, N=48
unique_time = [0, 11, 60]

fname = 70