In [None]:
import sys
import os

import pandas as pd

import ipywidgets as widgets
from IPython.display import display

# Make the parent directory importable so the local `ideal_genom` package resolves
library_path = os.path.abspath(os.pardir)
if library_path not in sys.path:
    sys.path.append(library_path)

from ideal_genom.zoom_heatmap import filter_sumstats, snp_annotations, get_gene_information
from ideal_genom.zoom_heatmap import  get_ld_matrix

In [None]:
# Create interactive widgets for input.
# Each text box holds one path or file name read later through `get_params()`.
input_path = widgets.Text(
    value='/home/luis/data/gwasResult/',
    description='Path to project folder:',
    style={'description_width': 'initial'}
)

input_name = widgets.Text(
    value='annotated_normalized_combined_R2_0.3.dose_step2_sex_pheno-glm.PHENO1.glm.logistic.hybrid',
    description='Name of GWAS summary file:',
    style={'description_width': 'initial'}
)

top_snp = widgets.Text(
    value='table_lead_SNPS_GWAS_glm_logistic_final_paper',
    description='Name of file with SNPs to highlight:',
    style={'description_width': 'initial'}
)

# The two bfile widgets previously reused the 'Path to project folder:' label,
# which made the three boxes indistinguishable in the UI.
bfile_path = widgets.Text(
    value='/home/luis/data/LuxGiantimputed/inputData/',
    description='Path to folder with binary files (bfile):',
    style={'description_width': 'initial'}
)

bfile_name = widgets.Text(
    value='luxgiant_imputed_noprobID',
    description='Name (prefix) of binary files (bfile):',
    style={'description_width': 'initial'}
)

# Display the widgets
display(input_path, input_name, top_snp, bfile_path, bfile_name)

# Collect the current text of every path/name widget
def get_params():
    """Return the widget values as a 5-tuple.

    Order: `input_path`, `input_name`, `top_snp`, `bfile_path`, `bfile_name`.
    """
    text_boxes = (input_path, input_name, top_snp, bfile_path, bfile_name)
    return tuple(box.value for box in text_boxes)

In [None]:
# Snapshot the widget values and echo them so the reader can check what will be used
path_params = get_params()
for position, current_value in enumerate(path_params, start=1):
    print(f"Parameter {position}: {current_value}")

In [None]:
# Free-text box listing the summary-statistics columns to load (comma-separated)
cols_touse = widgets.Textarea(
    value="#CHROM, POS, ID, P",
    description='Columns to use on the Manhattan plot (comma-separated):',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)
display(cols_touse)

def get_cols():
    """Return the raw comma-separated text currently held by `cols_touse`."""
    return cols_touse.value

In [None]:
cols = get_cols()

# Split the comma-separated text into column names; blank entries (e.g. from a
# trailing or doubled comma) are dropped so they are not passed on as column names.
usecols = [name.strip() for name in cols.split(',') if name.strip()]
if not usecols:
    raise ValueError('No columns were provided: fill in the comma-separated column list.')

# Load only the requested columns of the tab-separated GWAS summary statistics
df_gwas = pd.read_csv(
    os.path.join(path_params[0], path_params[1]),
    sep='\t',
    usecols=usecols,
)
df_gwas.head(5)

In [None]:
# Text box for the name of the rsID column in the file with SNPs to highlight
rsID = widgets.Text(
    description='Column name with the rsID:',
    value='SNP',
    style=dict(description_width='initial'),
)

display(rsID)

def get_rsID():
    """Return the column name currently typed in the `rsID` text box."""
    column_name = rsID.value
    return column_name

In [None]:
gwas_rsID = get_rsID()

# Name of the file (inside the project folder) listing the SNPs to highlight.
# The previous test `is not None or != ''` was always true, so an empty name
# fell through and the project folder itself was handed to `read_csv`.
highlight_file = (path_params[2] or '').strip()

to_highlight = []
if not highlight_file:
    print('No file with SNPs to highlight')
else:
    highlit_path = os.path.join(path_params[0], highlight_file)
    # `isfile` (rather than `exists`) also rejects a directory with that name
    if os.path.isfile(highlit_path):
        df_high = pd.read_csv(
            highlit_path, sep='\t'
        )
        to_highlight = df_high[gwas_rsID].to_list()
        print(to_highlight[:10])
        del df_high
    else:
        print('Path to file with SNPs to highlight does not exist')

In [None]:
# One text box per summary-statistics column, built from (default name, label) pairs
SNP_col, CHR_col, POS_col, P_col = (
    widgets.Text(
        value=default_name,
        description=box_label,
        style={'description_width': 'initial'},
    )
    for default_name, box_label in (
        ('ID', 'Column with rsID:'),
        ('#CHROM', 'Column with chromosome:'),
        ('POS', 'Column with base-pair position:'),
        ('P', 'Column with p-values:'),
    )
)
display(SNP_col, CHR_col, POS_col, P_col)

def get_col_names():
    """Return the current (SNP, chromosome, position, p-value) column names as a tuple."""
    boxes = (SNP_col, CHR_col, POS_col, P_col)
    return tuple(box.value for box in boxes)

In [None]:
# Unpack the column names chosen in the widgets above; the following cells
# pass them to the `ideal_genom` helpers and use them to index the data frames.
snp_col, chr_col, pos_col, p_col = get_col_names()

In [None]:
# The first highlighted SNP is used as the lead SNP; fail with a clear message
# instead of a bare IndexError when no SNPs were loaded.
if not to_highlight:
    raise ValueError(
        'No SNPs to highlight were loaded, so there is no lead SNP to centre the region on.'
    )

filtered = filter_sumstats(
    data_df       =df_gwas,
    lead_snp      =to_highlight[0],
    snp_col       =snp_col,
    p_col         =p_col,
    pos_col       =pos_col,
    chr_col       =chr_col,
    pval_threshold=5e-6,
    radius        =1e6
)
filtered.head()

In [None]:
annotated = snp_annotations(
    data_df=filtered,
    snp_col=snp_col,
    chr_col=chr_col,
    pos_col=pos_col
)

# Position rescaled by 1e6 (used as the x axis of the plots below).
# Uses the user-selected position column instead of a hard-coded 'POS'.
annotated['Mbp'] = annotated[pos_col] / 1e6
annotated.head()

In [None]:
# Count how many SNPs fall in each 'consequence' category, including missing values
annotated['consequence'].value_counts(dropna=False)

In [None]:
# Unique gene names among the annotated SNPs; missing values are dropped so
# they are not passed on as gene names to look up.
gene_names = annotated['GENENAME'].dropna().unique().tolist()

genes = get_gene_information(
    genes=gene_names,
    gtf_path=None,
    build='38'
)

# Rescale gene coordinates by 1e6 so they share the 'Mbp' axis of the plots.
# (The former 'lenght_esc1' column — misspelled, rescaled a second time and
# never read — has been removed.)
genes['start_esc'] = genes['start'] / 1e6
genes['end_esc'] = genes['end'] / 1e6
genes['length_esc'] = genes['end_esc'] - genes['start_esc']
genes


In [None]:
# Compute the LD matrix for the SNPs in the region.
# `output_path` now follows the project-folder widget (path_params[0]) instead
# of a hard-coded copy of its default value.
get_ld_matrix(
    data_df=annotated,
    snp_col=snp_col,
    pos_col=pos_col,
    bfile_folder=path_params[3],
    bfile_name=path_params[4],
    output_path=path_params[0],
)

In [None]:
# Read the whitespace-separated LD matrix (no header row, no index column).
# The location is derived from the bfile-folder widget (path_params[3]); with
# the default widget value this is the same file as the former hard-coded path.
# NOTE(review): confirm that `get_ld_matrix` writes 'matrix-ld.ld' to this folder.
ld_matrix_path = os.path.join(path_params[3], 'matrix-ld.ld')

df_LD = pd.read_csv(
    ld_matrix_path,
    sep=r'\s+',
    header=None,
    index_col=None,
    engine='python'
)
ld = df_LD.values

# Preview goes last so the notebook actually renders it
df_LD.head()

In [None]:
# Show the dimensions of the LD array loaded in the previous cell
ld.shape

In [None]:
import numpy as np

# Number of rows of the LD matrix; used later to set the heatmap limits
N = ld.shape[0]

# Keep the lower triangle (diagonal included) and zero out the rest
ld = np.tril(ld, k=0)

# Mask the discarded upper triangle by position. Masking by `ld == 0` would
# also hide genuine zero LD values that sit inside the lower triangle.
upper_triangle = np.triu(np.ones(ld.shape, dtype=bool), k=1)
ldm = np.ma.masked_array(ld, mask=upper_triangle)

In [None]:
import matplotlib.pyplot as plt
import matplotlib
from itertools import cycle
from matplotlib.patches import FancyArrow
from matplotlib import transforms

# x-axis window shared by the two upper panels: SNP range padded by 0.05 on each side
region = (annotated['Mbp'].min() - 0.05, annotated['Mbp'].max() + 0.05)

lead_snp = to_highlight[0]

plt.figure(figsize=(10, 10))

# Define the overall grid size (9 rows, 1 column)
ax1 = plt.subplot2grid((9, 1), (0, 0), rowspan=4)  # Top plot (4 rows)
ax2 = plt.subplot2grid((9, 1), (4, 0), rowspan=1)  # Middle plot (1 row)
ax3 = plt.subplot2grid((9, 1), (5, 0), rowspan=4)  # Bottom plot (4 rows)

# ---- ax1: association scatter -------------------------------------------------
lead = annotated[annotated[snp_col] == lead_snp]

# (consequence value, marker colour, legend label) drawn on top of the grey background
consequence_styles = (
    ('intron_variant', 'orange', 'Intron variant'),
    ('3_prime_UTR_variant', 'blue', "3'-UTR variant"),
    ('5_prime_UTR_variant', 'green', "5'-UTR variant"),
    ('upstream_gene_variant', 'black', "Upstream variant"),
)

ax1.scatter(annotated['Mbp'], annotated[p_col], s=15, color='grey', label='')
for consequence, dot_color, legend_label in consequence_styles:
    subset = annotated[annotated['consequence'] == consequence]
    ax1.scatter(subset['Mbp'], subset[p_col], s=30, color=dot_color, label=legend_label)
# Lead SNP is drawn last so it stays visible on top of the other markers
ax1.scatter(lead['Mbp'], lead[p_col], s=30, color='red', label='Lead SNP')

ax1.set_xlim(region)
ax1.xaxis.set_ticks_position('top')
ax1.legend(loc='best')

# Title names the chromosome of the lead SNP (the old title was the constant
# placeholder "BTA1"); falls back to a generic title if it cannot be determined.
if chr_col in lead.columns and not lead.empty:
    ax1.set_title(f"Million basepairs on chromosome {lead[chr_col].iloc[0]}", fontsize=12)
else:
    ax1.set_title("Million basepairs", fontsize=12)
# NOTE(review): the y values are `annotated[p_col]` plotted as-is — confirm that
# this column already holds log10-transformed p-values, as the label states.
ax1.set_ylabel('log10(P)', fontsize=12)

# ---- ax2: gene track ----------------------------------------------------------
# Gene of the lead SNP (highlighted in red); None when the lead SNP is not in `annotated`
lead_gene = lead['GENENAME'].iloc[0] if not lead.empty else None

# Cycle through four heights so neighbouring gene labels do not overlap
ys = cycle([0.1, 0.4, 0.7, 1])

gene_rows = zip(
    genes['gene'], genes['strand'],
    genes['start_esc'], genes['end_esc'], genes['length_esc'],
    ys,
)
for symbol, strand, start, end, length, y in gene_rows:
    arrow_color = 'red' if symbol == lead_gene else 'black'

    # Arrow points from start to end on '+' and from end to start on '-'
    if strand == '+':
        tail, dx = start, length
    elif strand == '-':
        tail, dx = end, -length
    else:
        continue  # genes with any other strand value are not drawn

    ax2.add_patch(
        FancyArrow(tail, y, dx, 0, width=0.001, head_width=0.03, head_length=0.01, color=arrow_color)
    )
    ax2.text(start + 0.5 * length, y + 0.05, symbol, ha='center', size=9)

ax2.set_ylim(0, 1.2)
ax2.set_xlim(region)
ax2.axis('off')

# ---- ax3: LD heatmap ----------------------------------------------------------
base = ax3.transData # to rotate triangle
rotation = transforms.Affine2D().rotate_deg(180+90+45)
cmap=matplotlib.cm.Reds
im=ax3.imshow(ldm, cmap=cmap,transform=rotation+base,aspect='auto')
# Limits frame the rotated triangle; 1.41 is presumably ~sqrt(2), the diagonal
# of the N x N matrix — TODO confirm
ax3.set_xlim([2, 1.41*N])
ax3.set_ylim([1*N, 2])
ax3.axis('off')

# Add a colorbar as the legend
cbar = plt.colorbar(im, ax=ax3, orientation='horizontal', fraction=0.05, pad=0.2)
cbar.set_label('LD Value', fontsize=10)  # Adjust label as needed

plt.tight_layout()
plt.show()
