In [None]:
import torch 
import re
import pandas as pd
import os
from tqdm import tqdm
import statsmodels.api as sm
import json
import pickle
import matplotlib.pyplot as plt 
import matplotlib.pyplot as plt

import numpy as np
import seaborn as sns

import panel as pn

import holoviews as hv

from dataclasses import dataclass

import sys
sys.path.append(os.path.abspath(os.path.join('..')))
from src.preprocessing import ParsedChromoData, ParsedGData, ParsedGDataContainer, roman_numerals_inv, ChromoData, ChromoDataContainer



# Notebook-wide plotting configuration: load the plotly backend for both
# HoloViews and Panel, then set the dark theme on each of them.
hv.extension("plotly")
pn.extension("plotly")
pn.config.theme = 'dark'
hv.renderer('plotly').theme = 'dark'

# Torch device string: 'cuda' when torch reports CUDA as available, else 'cpu'.
device='cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Functions for data processing

# Parses chromosome data
def loader(G_number, fold="RD_cache", hawaiian_threshold=0.9):
    """Load one read-depth table and split it into per-chromosome frames.

    Reads ``{fold}/{G_number}.txt`` as a tab-separated file with a header row.
    The file must contain exactly three columns, the third one named
    ``sample1`` and holding two comma-separated integers per row. The first
    integer is stored as ``Alt_Count`` and the second as ``Ref_Count``.
    NOTE(review): confirm this ordering matches the upstream file format.

    Parameters
    ----------
    G_number : str
        File stem (name without ``.txt``) of the table to load.
    fold : str
        Directory that contains the table.
    hawaiian_threshold : float
        ``Gxx_ratio`` value at or above which ``Gxx_ratio_binary`` is 1
        (default 0.9, the value that used to be hard-coded).

    Returns
    -------
    tuple
        ``(xI, xII, xIII, xIV, xV, xX, xALL, xALL_binary, frames)`` where

        * ``xI`` .. ``xX`` are the rows of chromosomes ``chrI`` .. ``chrV`` and
          ``chrX``, sorted by ``POS`` with duplicate positions removed (first
          occurrence kept); they keep the row index of the raw table;
        * ``xALL`` is their concatenation with a fresh 0..n-1 index that is
          also exposed as the ``gPOS`` column;
        * ``xALL_binary`` is an independent copy of ``xALL`` (columns
          reordered) with the extra 0/1 column ``Gxx_ratio_binary``;
        * ``frames`` is the list ``[xI, xII, xIII, xIV, xV, xX]``.
    """
    chrom_names = ('chrI', 'chrII', 'chrIII', 'chrIV', 'chrV', 'chrX')

    # Load the data and split "a,b" in `sample1` into two integer columns.
    Gxx_df = pd.read_csv(f"{fold}/{G_number}.txt", sep="\t", header=0)
    Gxx_df[['sample1', 'new_col']] = Gxx_df['sample1'].str.split(',', expand=True).astype(np.int64)
    Gxx_df.columns = ['CHROM', 'POS', 'Alt_Count', 'Ref_Count']
    Gxx_df['Read_Depth'] = Gxx_df['Alt_Count'] + Gxx_df['Ref_Count']
    # Rows with a read depth of 0 produce NaN here (0 / 0).
    Gxx_df['Gxx_ratio'] = 1 - (Gxx_df['Ref_Count'] / Gxx_df['Read_Depth'])  # CAN revert here to make consistent with plant biologists

    # Separate data by chromosome and clean: sort by position, drop repeated positions.
    frames = [
        Gxx_df.loc[Gxx_df['CHROM'] == chrom_name]
        .sort_values(by=['POS'])
        .drop_duplicates(['POS'], keep='first')
        for chrom_name in chrom_names
    ]
    xI, xII, xIII, xIV, xV, xX = frames

    # Concatenate all chromosomes; ignore_index=True already yields a clean
    # 0..n-1 index, which doubles as the global pan-chromosome position.
    xALL = pd.concat(frames, axis=0, ignore_index=True)
    xALL['gPOS'] = xALL.index
    xALL = xALL[['gPOS', 'CHROM', 'POS', 'Alt_Count', 'Ref_Count', 'Read_Depth', 'Gxx_ratio']]

    # Explicit copy so that adding the binary column neither touches xALL nor
    # triggers pandas' SettingWithCopyWarning.
    xALL_binary = xALL[['gPOS', 'CHROM', 'POS', 'Gxx_ratio', 'Read_Depth', 'Alt_Count', 'Ref_Count']].copy()
    xALL_binary['Gxx_ratio_binary'] = np.where(xALL_binary['Gxx_ratio'] >= hawaiian_threshold, 1, 0)

    return xI, xII, xIII, xIV, xV, xX, xALL, xALL_binary, frames

# grouping data by G number into serialisable data containers
def ParsedData_processing(performReprocessing=False, fold="RD_cache"):
    """Build, or load from cache, the per-G-number container of parsed data.

    When ``../data/preprocessed/processed_data.pkl`` exists and
    ``performReprocessing`` is False, the container is loaded with
    ``ParsedGDataContainer.load_from_pkl()`` and ``fold`` is not read.
    Otherwise every ``*.txt`` file in ``fold`` is parsed with ``loader``, the
    results are ordered by the first integer found in each G number (largest
    first), saved with ``save_to_pkl()`` and returned.
    NOTE(review): assumes save_to_pkl()/load_from_pkl() default to the pickle
    path checked here - confirm in src.preprocessing.

    Parameters
    ----------
    performReprocessing : bool
        Force re-parsing even if the pickle already exists.
    fold : str
        Directory holding the raw ``.txt`` tables (default ``"RD_cache"``).

    Returns
    -------
    ParsedGDataContainer

    Raises
    ------
    ValueError
        If a G number contains no digits and therefore cannot be ordered.
    """
    parsed_data_file_exists = os.path.exists('../data/preprocessed/processed_data.pkl')
    print(f"Parsed data file exists: {parsed_data_file_exists}")
    if not performReprocessing and parsed_data_file_exists:
        return ParsedGDataContainer.load_from_pkl()

    def _numeric_part(parsed_g_data):
        # Sort key: first run of digits in the G number, as an integer.
        match = re.search(r'\d+', parsed_g_data.g_number)
        if match is None:
            raise ValueError(f"Cannot order G number without digits: {parsed_g_data.g_number!r}")
        return int(match.group())

    os.makedirs('../data/preprocessed', exist_ok=True)

    # enumerate all txt files in directory
    txt_files = [f for f in os.listdir(fold) if f.endswith('.txt')]
    data_arr = []

    for txt_file in tqdm(txt_files):
        # splitext removes only the final ".txt", so the stem always maps back
        # to the same file inside loader(), even if the name has other dots.
        G_number_str = os.path.splitext(txt_file)[0]
        # The last element returned by loader() is the list of per-chromosome frames.
        *_, chromo_frames = loader(G_number_str, fold=fold)
        g_number_data = ParsedGData(G_number_str)
        for i_number, frame in enumerate(chromo_frames, start=1):
            # NOTE(review): the sixth frame is chrX; confirm that
            # roman_numerals_inv[6] yields the intended label.
            chrom = roman_numerals_inv[i_number]
            g_number_data.add_parsed_chromo_data(ParsedChromoData(chrom, i_number, frame))
        data_arr.append(g_number_data)

    # order arrays by G number, largest first
    data_arr.sort(key=_numeric_part, reverse=True)
    parsed_data_instance = ParsedGDataContainer()
    parsed_data_instance.data = data_arr
    parsed_data_instance.save_to_pkl()
    return parsed_data_instance

# Process the data: compute the LOWESS curve and the index of its maximum, then split the chromosomes into Gaussian and non-Gaussian groups according to the overrides (defined manually in an external JSON file)
def GaussianData_processing(parsed_data_instance, overrides=None, lowess_iter=3, performReprocessing=False):
    """Compute LOWESS curves per chromosome and group them by the manual overrides.

    When ``../data/preprocessed/gaussian_data.pkl`` exists and
    ``performReprocessing`` is False, the cached container is returned via
    ``ChromoDataContainer.load_from_pkl()``; ``overrides`` and ``lowess_iter``
    are then not used.
    NOTE(review): assumes save_to_pkl()/load_from_pkl() default to the pickle
    path checked here - confirm in src.preprocessing.

    Parameters
    ----------
    parsed_data_instance
        Object whose ``data`` is iterable; each item exposes ``g_number`` and
        ``parsed_chromo_data`` (items with ``chrom_label``, ``i_number`` and a
        DataFrame ``array`` holding ``POS`` and ``Gxx_ratio`` columns).
    overrides : dict or None
        ``{g_number: {chrom_label: {'is_gausian': ..., 'm': ...}}}``. A missing
        ``is_gausian`` is stored as None and the chromosome is filed under
        ``not_gausians``. ``m`` is a position; the index of the nearest ``POS``
        is stored alongside the chromosome. None means "no overrides".
    lowess_iter : int
        Passed as ``it`` to ``statsmodels`` LOWESS.
    performReprocessing : bool
        Force recomputation even if the pickle already exists.

    Returns
    -------
    ChromoDataContainer
        With ``gausians``, ``not_gausians``, ``all`` (gausians followed by
        not_gausians) and ``map`` (g_number -> chromosomes sorted by ``i_number``).
    """
    gausian_data_file_exists = os.path.exists('../data/preprocessed/gaussian_data.pkl')
    print(f"Gaussian data file exists: {gausian_data_file_exists}")
    if not performReprocessing and gausian_data_file_exists:
        return ChromoDataContainer.load_from_pkl()

    if overrides is None:
        overrides = {}

    gausians = []
    not_gausians = []
    by_g_number = {}

    for g_container in tqdm(parsed_data_instance.data):
        # Taken from the container itself so the sort below is correct even
        # when this container has no chromosomes.
        G_number = g_container.g_number
        by_g_number[G_number] = []

        for chromo_data in g_container.parsed_chromo_data:
            chrom = chromo_data.chrom_label
            arr = chromo_data.array

            # LOWESS curve (for visualisations) and the index of its maximum
            # (for prediction measurements). Result columns: x (POS), smoothed y.
            lowess_out = sm.nonparametric.lowess(arr['Gxx_ratio'], arr['POS'], frac=0.6, it=lowess_iter, return_sorted=True)
            max_idx = np.argmax(lowess_out[:, 1])

            chrom_overrides = overrides.get(G_number, {}).get(chrom, {})

            # if in overrides, use the is_gausian value, else - this chromo is not gaussian
            is_gausian_overrided = chrom_overrides.get('is_gausian', None)

            # if m is in overrides it was measured experimentally; keep the
            # index of the closest POS for future evaluation of the model
            m_index = None
            m_overrided = chrom_overrides.get('m', None)
            if m_overrided is not None:
                pos_delta = np.abs(arr['POS'] - m_overrided)
                m_index = np.argmin(pos_delta)
                m_rd = arr['POS'].iloc[m_index]
                print(f'G number: {G_number}, chrom: {chrom}, m_overrided: {m_overrided}, m_rd: {m_rd}, m_delta: {pos_delta.min()}')
                print(f"Len of arr: {len(arr['POS'])}, m_index: {m_index}, Min POS and max POS: {arr['POS'].min()}, {arr['POS'].max()}, m_overrided: {m_overrided}")

            chromoData = ChromoData(G_number, chrom, chromo_data.i_number, arr, lowess_out, max_idx, is_gausian_overrided, m_index)

            if is_gausian_overrided:
                gausians.append(chromoData)
            else:
                not_gausians.append(chromoData)

            by_g_number[G_number].append(chromoData)

        # sort this G number's chromosomes by chromosome number
        by_g_number[G_number].sort(key=lambda x: x.i_number)

    gausian_data = ChromoDataContainer()
    gausian_data.gausians = gausians
    gausian_data.not_gausians = not_gausians
    gausian_data.all = gausians + not_gausians
    gausian_data.map = by_g_number
    gausian_data.save_to_pkl()
    print("Saved gausian data")
    return gausian_data

In [None]:
# Perform processing
# it will save hierarchical unprocessed data structures
# it will save processed data structures to another file

# Loading overrides for gausian data. Overrides are used to manually set whether
# the data is gausian or not, and to set the m index for the measured µ.
print("Loading overrides")
# Context manager so the file handle is closed as soon as the JSON is parsed.
with open('../configs/preprocessing/gausian_overrides.json') as overrides_file:
    gausian_overrides = json.load(overrides_file)

# perform initial data parsing
print("Performing data parsing")
parsed_data_instance = ParsedData_processing()

# perform gausian data processing
# (lowess_iter only takes effect when the data is recomputed; a cached pickle is returned as-is)
print("Performing gausian data processing")
gausian_data_instance = GaussianData_processing(parsed_data_instance, overrides=gausian_overrides, lowess_iter=5)

# record how many G numbers were parsed (both names hold the same value)
parsed_data_arr_len = len(parsed_data_instance.data)
data_arr_len = len(parsed_data_instance.data)

print(f"Number of gausians: {len(gausian_data_instance.gausians)}, Number of not gausians: {len(gausian_data_instance.not_gausians)}")

In [None]:
# This cell visualizes parsed and processed data.
# Expectation: the manually set gausian property matches the panel colour
# (yellow: true, red: false), and the LOWESS curve and its max index look meaningful.

# Per-chromosome summary statistics of POS and Gxx_ratio. One entry is collected
# for every chromosome of every G number in the parsed data (not only the
# chromosomes flagged as gausian).
gausian_means_pos, gausian_stds_pos = [], []
pos_mins_pos, pos_maxs_pos = [], []

gausian_means_gxx, gausian_stds_gxx = [], []
pos_mins_gxx, pos_maxs_gxx = [], []

for g_container in tqdm(parsed_data_instance.data):
    for chromo_data in g_container.parsed_chromo_data:
        G_number = g_container.g_number
        chrom = chromo_data.chrom_label
        arr = chromo_data.array

        pos_column = arr['POS']
        gausian_means_pos.append(pos_column.mean())
        gausian_stds_pos.append(pos_column.std())
        pos_mins_pos.append(pos_column.min())
        pos_maxs_pos.append(pos_column.max())

        ratio_column = arr['Gxx_ratio']
        gausian_means_gxx.append(ratio_column.mean())
        gausian_stds_gxx.append(ratio_column.std())
        pos_mins_gxx.append(ratio_column.min())
        pos_maxs_gxx.append(ratio_column.max())

# Average each statistic over all chromosomes; the POS extremes are truncated to int.
pos_gausian_mean_avg = np.average(gausian_means_pos)
pos_gausian_std_avg = np.average(gausian_stds_pos)
pos_min = int(np.average(pos_mins_pos))
pos_max = int(np.average(pos_maxs_pos))

gxx_gausian_mean_avg = np.average(gausian_means_gxx)
gxx_gausian_std_avg = np.average(gausian_stds_gxx)
gxx_min = np.average(pos_mins_gxx)
gxx_max = np.average(pos_maxs_gxx)

print(f"POS Gausian mean avg: {pos_gausian_mean_avg}, std avg: {pos_gausian_std_avg}, min avg: {pos_min}, max avg: {pos_max}")
print(f"Gxx_ratio Gausian mean avg: {gxx_gausian_mean_avg}, std avg: {gxx_gausian_std_avg}, min avg: {gxx_min}, max avg: {gxx_max}")

# plot graphs
# One row per G number in the processed map, one column per chromosome.
# The row count comes from the map that is actually iterated, and squeeze=False
# keeps axes_root two-dimensional even when there is only a single row.
n_rows = len(gausian_data_instance.map)
fig_root, axes_root = plt.subplots(n_rows, 6, figsize=(40, n_rows * 5), squeeze=False)
fig_root.suptitle('Distributions', fontsize=16)
for i, (key, data) in enumerate(gausian_data_instance.map.items()):
    for k, g in enumerate(data):

        G_number = g.g_number
        chrom = g.chrom_label
        arr = g.array
        lowess_out = g.lowess_out
        max_idx = g.lowess_max_idx
        is_gausian = g.is_gausian
        m_index = g.m_index  # index of the measured µ (None if absent); not drawn here

        ax = axes_root[i, k]

        # raw points; `color` is seaborn's documented keyword for a single colour
        sb = sns.scatterplot(x='POS', y='Gxx_ratio', data=arr, color='grey', ax=ax)
        sb.set_xlabel(f'gausian={is_gausian}')
        sb.set_ylabel('')

        # hide ticks
        ax.set_xticks([])
        ax.set_yticks([])

        # LOWESS curve, drawn against its own x column: statsmodels drops
        # NaN/inf rows, so the curve may be shorter than arr['POS'].
        ax.plot(lowess_out[:, 0], lowess_out[:, 1], 'r-', linewidth=2)

        # vertical line at the position of the LOWESS maximum
        max_pos_rd = lowess_out[max_idx, 0]
        ax.axvline(max_pos_rd, color='b', linestyle='--')

        # vertical line at half of the last POS value
        middle_pos = arr['POS'].iloc[-1] / 2
        ax.axvline(middle_pos, color='r', linestyle='--')

        ax.set_title(chrom)
        # background: pale yellow-green if gausian, pale red if not
        ax.set_facecolor(
            (0.9764705882352941, 0.984313725490196, 0.9058823529411765, 1)
            if is_gausian
            else (0.984313725490196, 0.9137254901960784, 0.9058823529411765, 1))

    # label each row with its G number, once per row and after the first
    # column's y label has been cleared above
    axes_root[i, 0].set_ylabel(key)