# Import libraries

In [1]:
import os
import logging

# Suppress TensorFlow logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 0 = all messages are logged (default), 1 = INFO, 2 = WARNING, 3 = ERROR
logging.getLogger('tensorflow').setLevel(logging.ERROR)

import importlib.util
import shutil
import time  # Import the time module
import warnings

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

from scipy.stats import t, entropy, stats

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import train_test_split

import statsmodels.api as sm

import tensorflow as tf
from tensorflow.keras import regularizers, Input, Model, layers
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping

from k_means_constrained import KMeansConstrained

from helpers import (
    parse_variables, get_risk_level, hi_gauss_blob_risk_fun, blob_risk_fun, 
    NW_risk_fun, square_risk_fun, map_to_color, simulate_quant_trait
)

from models import ols_regression, manhattan_linear, gc
from deep_learning_models import abyss, deep_abyss


In [2]:
# Simulation parameters. Each value is read from geno_simulation.txt unless that name
# is already defined in the notebook namespace (e.g. set by a driver before this
# notebook runs), in which case the existing value is kept.
# Stored as `params` (not `dict`) so the built-in `dict` is not shadowed for later cells.
params = parse_variables('geno_simulation.txt')

# Parameter name -> type its string value from the file is converted to.
_FILE_PARAM_TYPES = {
    'G': int,
    'L': int,
    'c': int,
    'k': int,
    'M': float,
    'HWE': int,
    'bottleneck_nr': int,
    'nr_humans': int,
    'nr_snps': int,
    'very_rare_threshold_L': float,
    'very_rare_threshold_H': float,
    'rare_threshold_L': float,
    'rare_threshold_H': float,
    'common_threshold_L': float,
    'common_threshold_H': float,
}
for _param_name, _param_type in _FILE_PARAM_TYPES.items():
    # Only touch the file value when the name is missing, so pre-set values win.
    if _param_name not in globals():
        globals()[_param_name] = _param_type(params[_param_name])

# Settings that are not read from the file (overridable the same way).
if 'epoch' not in globals():
    epoch = 500
if 'patience' not in globals():
    patience = 100

if 'tools' not in globals():
    tools = ['PCA', 'abyss_counted', 'abyss', 'no_corr']

if 'scenarios' not in globals():
    scenarios = ['snp_effect',
                 'linear_continuous',
                 'non_linear_continuous',
                 'discrete_global',
                 'discrete_localized',
                 'mix_linear_continuous',
                 'mix_non_linear_continuous',
                 'mix_discrete_global',
                 'mix_discrete_localized']

number_of_snps = (G * L) // 2  # one locus per chromosome
number_of_individuals = c * k * k

In [3]:
# Load the three allele-frequency strata of the genotype data; each file name encodes
# the stratum tag and its lower/upper AF threshold.
very_rare, rare, common = (
    pd.read_pickle(
        f"data/G{G}_L{L}_c{c}_k{k}_M{M}_HWE{HWE}/genotype/01_{tag}_genotype_AF_{low}_{high}.pkl"
    )
    for tag, low, high in (
        ("veryrare", very_rare_threshold_L, very_rare_threshold_H),
        ("rare", rare_threshold_L, rare_threshold_H),
        ("common", common_threshold_L, common_threshold_H),
    )
)

In [4]:
# Put the common, rare and very-rare genotype frames side by side (column-wise),
# then map every entry x -> 2x - 1.
complete = pd.concat([common, rare, very_rare], axis=1).mul(2).sub(1)

In [5]:
path_bottle = f"data/G{G}_L{L}_c{c}_k{k}_M{M}_HWE{HWE}/phenotype/abyss_bottleneck"


def _parse_bottleneck_nr(file_name):
    """Return the integer in the 3rd '_'-separated field of ``file_name``.

    Returns None when the name has fewer than four '_'-separated fields (the 4th
    field holds the elapsed time parsed below) or the 3rd field is not an integer,
    so stray files in the directory are skipped instead of crashing the lookup.
    """
    fields = file_name.split("_")
    if len(fields) < 4:
        return None
    try:
        return int(fields[2])
    except ValueError:
        return None


# os.listdir order is arbitrary: sort so the choice is deterministic if several files
# carry the same bottleneck size.
bottle_matches = sorted(
    f for f in os.listdir(path_bottle) if _parse_bottleneck_nr(f) == bottleneck_nr
)
if not bottle_matches:
    raise FileNotFoundError(
        f"No bottleneck file with bottleneck_nr={bottleneck_nr} found in {path_bottle}"
    )
bottle_file = bottle_matches[0]
# The 4th '_' field looks like '<seconds>seconds...'; keep the numeric prefix.
elapsed_time_bottleneck = float(bottle_file.split('_')[3].split('seconds')[0])
bottle = pd.read_pickle(f"{path_bottle}/{bottle_file}")

In [6]:
# Preview the bottleneck table: dim* bottleneck columns plus the 'cluster' label that
# the LD-block loop below uses to split individuals into populations.
bottle

Unnamed: 0,dim1,dim2,dim3,dim4,dim5,dim6,dim7,dim8,dim9,dim10,...,dim56,dim57,dim58,dim59,dim60,dim61,dim62,dim63,dim64,cluster
0,0.081082,0.197128,0.503182,0.440843,0.331699,0.388049,0.279960,0.161503,0.249616,0.192196,...,0.164692,0.063415,-0.122220,0.194483,0.326115,0.210555,0.392330,0.258133,0.315346,0
1,0.051411,0.194343,0.514074,0.433425,0.358165,0.360208,0.276100,0.134745,0.230915,0.193033,...,0.169409,0.046458,-0.113131,0.170390,0.357199,0.197391,0.383531,0.282799,0.295248,0
2,0.046141,0.206636,0.541591,0.446705,0.335784,0.370777,0.286725,0.135338,0.219357,0.176615,...,0.147489,0.031859,-0.122554,0.201595,0.342846,0.178856,0.404147,0.258219,0.290054,0
3,0.060589,0.197433,0.534488,0.448788,0.340554,0.385742,0.278493,0.150117,0.232181,0.183369,...,0.157063,0.045721,-0.118741,0.198521,0.348593,0.192604,0.405227,0.266623,0.301941,0
4,0.055164,0.252698,0.499711,0.418918,0.378986,0.364336,0.339525,0.138378,0.234135,0.153697,...,0.124367,0.048491,-0.138461,0.214008,0.362452,0.191057,0.369223,0.294302,0.299973,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1595,0.582848,0.176673,-0.129426,0.271133,0.172404,0.658223,0.288276,0.627660,0.482198,0.294968,...,0.322009,0.275365,-0.117538,0.452628,-0.086028,0.434468,0.159564,0.126498,0.555196,1
1596,0.602513,0.165645,-0.125714,0.281242,0.156494,0.675679,0.276139,0.641265,0.498672,0.300380,...,0.325722,0.290104,-0.122428,0.454089,-0.097069,0.445591,0.171101,0.113237,0.574625,1
1597,0.597850,0.189113,-0.139056,0.279335,0.132832,0.679685,0.298575,0.650033,0.477526,0.290160,...,0.320565,0.268665,-0.110846,0.494788,-0.128898,0.425239,0.170201,0.087521,0.555398,1
1598,0.607971,0.142840,-0.095925,0.293555,0.192652,0.668437,0.254003,0.626998,0.536524,0.306473,...,0.319868,0.328265,-0.150512,0.405673,-0.046194,0.475758,0.184838,0.152123,0.607392,1


# Run Abyss on each LD block: estimate p², 2pq and q² per population from the bottleneck

In [7]:
def maf_prediction(bottle_in, geno_out, epoch, patience, *, hidden_units=None,
                   l2_regularizer=0.001, batch_size=32):
    """Train a small decoder that maps bottleneck inputs to genotype targets.

    The data are first split 80/20 with ``train_test_split(random_state=42)``. Only
    the 80% part is used for fitting; Keras holds out a further 20% of it
    (``validation_split``) to drive early stopping. The 20% test part is not used.

    Parameters
    ----------
    bottle_in : 2-D array-like
        Decoder inputs; ``bottle_in.shape[1]`` is the input width.
    geno_out : 2-D array-like
        Regression targets, row-aligned with ``bottle_in``; ``geno_out.shape[1]``
        is the number of output units.
    epoch : int
        Maximum number of training epochs.
    patience : int
        Early-stopping patience, in epochs without ``val_loss`` improvement.
    hidden_units : int, optional
        Width of the hidden layer. Defaults to ``int(nr_snps / 2)``, using the
        notebook-level ``nr_snps``.
    l2_regularizer : float, optional
        L2 penalty applied to both Dense kernels (default 0.001).
    batch_size : int, optional
        Mini-batch size passed to ``fit`` (default 32).

    Returns
    -------
    decoder : tf.keras.Sequential
        The fitted model (EarlyStopping with ``restore_best_weights=True``).
    history : tf.keras.callbacks.History
        The object returned by ``fit``.
    """
    if hidden_units is None:
        hidden_units = int(nr_snps / 2)

    # The test split is not evaluated here. It is kept so the training subset, and
    # therefore the results, stays the same as in earlier runs.
    X_train, _X_test, y_train, _y_test = train_test_split(
        bottle_in, geno_out, test_size=0.2, random_state=42)

    decoder = tf.keras.Sequential([
        tf.keras.Input(shape=(bottle_in.shape[1],)),
        # Dense -> BatchNorm -> ELU: the Dense layer itself is linear so the ELU
        # non-linearity is applied exactly once, after normalisation.
        layers.Dense(hidden_units, kernel_regularizer=regularizers.l2(l2_regularizer)),
        layers.BatchNormalization(),
        layers.Activation('elu'),
        # Linear output layer: one unit per target column.
        layers.Dense(geno_out.shape[1], activation='linear',
                     kernel_regularizer=regularizers.l2(l2_regularizer)),
    ])

    decoder.compile(optimizer='adam',
                    loss='mean_squared_error',
                    metrics=['mean_absolute_error'])

    early_stopping = EarlyStopping(monitor='val_loss', patience=patience,
                                   restore_best_weights=True)

    history = decoder.fit(X_train, y_train, epochs=epoch, batch_size=batch_size,
                          validation_split=0.2, callbacks=[early_stopping], verbose=0)

    return decoder, history

In [None]:
def _fit_and_predict(bottle_in, bottle_tensor, geno_out, epoch, patience):
    """Fit a decoder from ``bottle_in`` to ``geno_out`` and predict on ``bottle_tensor``.

    Returns ``(predictions, seconds)``. ``predictions`` is a DataFrame labelled with
    ``geno_out``'s columns and index. ``seconds`` is the training time of
    ``maf_prediction``, rounded to 3 decimals.
    """
    start = time.perf_counter()
    decoder, _history = maf_prediction(bottle_in, geno_out, epoch, patience)
    elapsed = np.round(time.perf_counter() - start, 3)
    predictions = decoder(bottle_tensor, training=False).numpy()
    # NOTE(review): rows are labelled with geno_out's index, i.e. this assumes the
    # bottleneck rows and the genotype rows are in the same order — confirm.
    return pd.DataFrame(predictions, columns=geno_out.columns, index=geno_out.index), elapsed


genotype_dir = f"data/G{G}_L{L}_c{c}_k{k}_M{M}_HWE{HWE}/genotype"

for pop in bottle['cluster'].unique():
    # Bottleneck rows of this population only, without the label column.
    temp_bottle = bottle[bottle['cluster'] == pop].drop('cluster', axis=1)
    temp_bottle_tensor = tf.convert_to_tensor(temp_bottle, dtype=tf.float32)

    path_output = f"{genotype_dir}/LD_blocks_estimated_mafs/{pop}"
    # Start from an empty output directory so files from earlier runs don't mix with
    # the new ones; shutil avoids shelling out with an interpolated path.
    shutil.rmtree(path_output, ignore_errors=True)
    os.makedirs(path_output, exist_ok=True)

    path_one_hot_genotype = f"{genotype_dir}/LD_blocks_one_hot/{pop}"
    path_lds = f"{genotype_dir}/LD_blocks/{pop}"

    # Sorted for a deterministic processing order (os.listdir order is arbitrary).
    for ld_file in sorted(os.listdir(path_lds)):
        stem = ld_file.split('.pkl')[0]
        db_major = pd.read_pickle(f"{path_one_hot_genotype}/{stem}_db_major.pkl")
        db_het = pd.read_pickle(f"{path_one_hot_genotype}/{stem}_db_het.pkl")
        db_minor = pd.read_pickle(f"{path_one_hot_genotype}/{stem}_db_minor.pkl")

        # One independently trained decoder per target: major -> p2, het -> 2pq, minor -> q2.
        p2, elapsed_time_p2 = _fit_and_predict(
            temp_bottle, temp_bottle_tensor, db_major, epoch, patience)
        twopq, elapsed_time_2pq = _fit_and_predict(
            temp_bottle, temp_bottle_tensor, db_het, epoch, patience)
        q2, elapsed_time_q2 = _fit_and_predict(
            temp_bottle, temp_bottle_tensor, db_minor, epoch, patience)

        # Output names keep the existing pattern: the full ld_file name (including its
        # .pkl suffix) followed by the estimate tag and the training time.
        p2.to_pickle(f"{path_output}/{ld_file}_esti_p2_via_esti_pop_{elapsed_time_p2}seconds.pkl")
        twopq.to_pickle(f"{path_output}/{ld_file}_esti_2pq_via_esti_pop_{elapsed_time_2pq}seconds.pkl")
        q2.to_pickle(f"{path_output}/{ld_file}_esti_q2_via_esti_pop_{elapsed_time_q2}seconds.pkl")


In [8]:
# NOTE(review): disabled cell. The code below is parked inside a string literal, so
# running this cell only echoes that string. It would concatenate p2s / q2s / twopqs,
# but the LD-block loop never populates those lists (its `.append` calls are commented
# out). Either restore the collection step or delete this cell.
"""
p2 = pd.concat(p2s, axis=1)
p2 = p2.sort_index()
p2 = p2[list(complete.columns)]

q2 = pd.concat(q2s, axis=1)
q2 = q2.sort_index()
q2 = q2[list(complete.columns)]

twopq = pd.concat(twopqs, axis=1)
twopq = twopq.sort_index()
twopq = twopq[list(complete.columns)]
"""

'\np2 = pd.concat(p2s, axis=1)\np2 = p2.sort_index()\np2 = p2[list(complete.columns)]\n\nq2 = pd.concat(q2s, axis=1)\nq2 = q2.sort_index()\nq2 = q2[list(complete.columns)]\n\ntwopq = pd.concat(twopqs, axis=1)\ntwopq = twopq.sort_index()\ntwopq = twopq[list(complete.columns)]\n'