In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math
import warnings
import sklearn
import random

# Silence every warning for the rest of the session, whichever library emits it
warnings.filterwarnings("ignore")

### Do derivation / test split

Remove datapoints measured too close to diagnosis (0–30 days before it)

In [None]:
def deriv_test_split(patient_list, shuffle=False, random_state=42, deriv_fraction=0.85):
    """Split a list of patients into a derivation group and a test group.

    The first ``int(len(patient_list) * deriv_fraction)`` patients (after the
    optional shuffle) form the derivation group; the remaining ones form the
    test group, so every patient ends up in exactly one group.

    Parameters
    ----------
    patient_list : list
        Patient identifiers. The list is copied and never modified.
    shuffle : bool, default False
        Whether to shuffle the patients before splitting.
    random_state : int, default 42
        Seed of the shuffle; the same seed always yields the same split.
    deriv_fraction : float, default 0.85
        Share of the patients assigned to the derivation group.

    Returns
    -------
    tuple of (list, list)
        ``(deriv_list, test_list)``.

    Raises
    ------
    ValueError
        If ``deriv_fraction`` is outside the interval [0, 1].
    """
    if not 0 <= deriv_fraction <= 1:
        raise ValueError(f'deriv_fraction must be between 0 and 1, got {deriv_fraction}')

    # Work on a copy so the caller's list is not reordered as a side effect
    patients = list(patient_list)

    if shuffle:
        # A private generator gives the same permutation as seeding the global
        # `random` module would, without resetting the global random state
        random.Random(random_state).shuffle(patients)

    # Everything past the derivation slice goes to the test group
    deriv_size = int(len(patients) * deriv_fraction)

    return patients[:deriv_size], patients[deriv_size:]

In [None]:
# Project root directory; the input and output CSV files below are addressed relative to it
my_path = '~/mounts/research/husdatalake/disease/scripts/Preleukemia/oona_git'

In [None]:
# Disease label; it is used as the prefix of the input and output file names
disease = 'MDS'

In [None]:
# Load the combined (disease + healthy) modelling table of the selected disease
modelling_csv = my_path + '/data/modelling/' + disease + '_and_healthy_modelling_data.csv'
data = pd.read_csv(modelling_csv)

# Discard columns that are not carried over to the derivation / test sets
# ('Unnamed: 0' is presumably an index column written by an earlier to_csv -- confirm)
data = data.drop(columns=['Unnamed: 0', 'event_1y', 'time'])

In [None]:
def get_nan_percentage(df):
    """Return the percentage of missing values in every column of ``df``.

    The result is a Series indexed by column name, with values in percent
    (number of missing entries divided by the number of rows, times 100).
    """
    missing_per_column = df.isna().sum()
    return missing_per_column.div(len(df)).mul(100)

### Drop columns with too many NaNs (more than 75 % missing)

In [None]:
# Percentage of missing values per column, computed on the full dataset (before the split)
nan_percentages = get_nan_percentage(data)

# Columns whose missing share (in %) exceeds this threshold are discarded
MAX_NAN_PERCENTAGE = 75

# Boolean-mask selection instead of an index loop over two parallel lists
too_many_missing = nan_percentages[nan_percentages > MAX_NAN_PERCENTAGE].index.tolist()

too_many_missing

In [None]:
# Remove the columns listed in `too_many_missing`
# NOTE: re-running this cell alone raises a KeyError, because the columns are already gone
data = data.drop(columns=too_many_missing)

### Create deriv and test sets

In [None]:
# Unique patient identifiers ('henkilotunnus') of each class, in order of first appearance
is_disease = data['disease'] == 1
is_healthy = data['disease'] == 0
disease_patients = list(data.loc[is_disease, 'henkilotunnus'].unique())
healthy_patients = list(data.loc[is_healthy, 'henkilotunnus'].unique())

In [None]:
# Seed handed to deriv_test_split so that the patient shuffle is repeatable
rs=123

In [None]:
# Split on patient level (not row level), separately for the disease and the healthy patients,
# so that all datapoints of one patient end up on the same side of the split
deriv_disease, test_disease = deriv_test_split(disease_patients, shuffle=True, random_state=rs)
deriv_healthy, test_healthy = deriv_test_split(healthy_patients, shuffle=True, random_state=rs)

In [None]:
# Number of disease patients in the derivation and in the test group
len(deriv_disease), len(test_disease)

In [None]:
# Number of healthy patients in the derivation and in the test group
len(deriv_healthy), len(test_healthy)

In [None]:
# Pick the rows (datapoints) that belong to the patients of each group
patient_ids = data['henkilotunnus']

deriv_disease_data = data.loc[patient_ids.isin(deriv_disease)]
test_disease_data = data.loc[patient_ids.isin(test_disease)]

deriv_healthy_data = data.loc[patient_ids.isin(deriv_healthy)]
test_healthy_data = data.loc[patient_ids.isin(test_healthy)]

In [None]:
# Stack the disease and healthy rows into one frame per split; ignore_index gives a fresh 0..n-1 index
deriv_data = pd.concat([deriv_disease_data, deriv_healthy_data], ignore_index=True)
test_data = pd.concat([test_disease_data, test_healthy_data], ignore_index=True)

In [None]:
# Censoring: flip the sign of time_to_dg on the healthy (disease == 0) rows of both sets
# NOTE: the flip is not idempotent -- running this cell a second time restores the original sign
for frame in (deriv_data, test_data):
    healthy_rows = frame['disease'] == 0
    frame.loc[healthy_rows, 'time_to_dg'] = -frame.loc[healthy_rows, 'time_to_dg']

In [None]:
# Drop derivation datapoints measured too close to diagnosis: a row is kept only when its
# time_to_dg is above 30 or negative (negative values mark the sign-flipped healthy rows)
# NOTE(review): a disease == 1 row with a negative time_to_dg would be kept as well -- confirm none exist
disease_dp_before_removal = int((deriv_data['disease'] == 1).sum())

measured_early_enough = deriv_data['time_to_dg'] > 30
negative_time = deriv_data['time_to_dg'] < 0
deriv_data = deriv_data[measured_early_enough | negative_time]

disease_dp_after_removal = int((deriv_data['disease'] == 1).sum())
print(f'{disease_dp_before_removal - disease_dp_after_removal} disease = 1 datapoints were removed, as they were measured less than 30 days before diagnosis.')

In [None]:
# Drop test datapoints measured too close to diagnosis: a row is kept only when its
# time_to_dg is above 30 or negative (negative values mark the sign-flipped healthy rows)
# NOTE(review): a disease == 1 row with a negative time_to_dg would be kept as well -- confirm none exist
disease_dp_before_removal = int((test_data['disease'] == 1).sum())

measured_early_enough = test_data['time_to_dg'] > 30
negative_time = test_data['time_to_dg'] < 0
test_data = test_data[measured_early_enough | negative_time]

disease_dp_after_removal = int((test_data['disease'] == 1).sum())
print(f'{disease_dp_before_removal - disease_dp_after_removal} disease = 1 datapoints were removed, as they were measured less than 30 days before diagnosis.')

In [None]:
# Patient identifiers still present in the derivation set after the filtering
deriv_ht = list(deriv_data['henkilotunnus'].unique())

In [None]:
# Patient identifiers still present in the test set after the filtering
test_ht = list(test_data['henkilotunnus'].unique())

In [None]:
def check_common_elements(list1, list2):
    """Return True when at least one element occurs in both lists, False otherwise."""
    # A non-empty set intersection means the two lists share an element
    shared = set(list1) & set(list2)
    return bool(shared)


In [None]:
# Leakage guard: no patient may appear in both the derivation and the test set.
# The assertion stops the notebook before leaky sets get saved; the value is still displayed.
patients_overlap = check_common_elements(deriv_ht, test_ht)
assert not patients_overlap, 'Some patients appear in both the derivation and the test set'
patients_overlap

### Save deriv and test sets

In [None]:
# Shuffle the rows; seeding with `rs` makes the saved row order reproducible between runs
deriv_data = deriv_data.sample(frac=1, random_state=rs).reset_index(drop=True)

In [None]:
# Shuffle the rows; seeding with `rs` makes the saved row order reproducible between runs
test_data = test_data.sample(frac=1, random_state=rs).reset_index(drop=True)

In [None]:
# Write the derivation set next to the input file; index=False keeps the row index out of the CSV
deriv_data.to_csv(my_path + '/data/modelling/' + disease + '_derivation_data.csv', index=False)

In [None]:
# Write the test set next to the input file; index=False keeps the row index out of the CSV
test_data.to_csv(my_path + '/data/modelling/' + disease + '_test_data.csv', index=False)