## Setup

In [1]:
import pandas as pd
import pickle
from typing import List, Dict

In [2]:
# These pickles are the outputs of the medbert repo: https://github.com/ZhiGroup/Med-BERT/blob/master/Pretraining%20Code/Readme.md

MEDBERT_CODES_DICT_PATH = '/sise/home/benshoho/projects/Med-BERT/Pretraining Code/Data Pre-processing Code/eicu_crd_data/eicu_crd.types'
PRETRAINED_TRAIN_PICKLE_PATH = '/sise/home/benshoho/projects/Med-BERT/Pretraining Code/Data Pre-processing Code/eicu_crd_data/eicu_crd.bencs.train'
PRETRAINED_VALIDATION_PICKLE_PATH = '/sise/home/benshoho/projects/Med-BERT/Pretraining Code/Data Pre-processing Code/eicu_crd_data/eicu_crd.bencs.valid'
PRETRAINED_TEST_PICKLE_PATH = '/sise/home/benshoho/projects/Med-BERT/Pretraining Code/Data Pre-processing Code/eicu_crd_data/eicu_crd.bencs.test'

MEDBERT_OUTPUT_PICKLES_DIR = '/sise/home/benshoho/projects/Med-BERT/Fine-Tunning-Tutorials/data/eicu_crd'
TARGET_DISEASE_IDS = ['51881', '51882', '51883', '51884', '7991'] # Adlt resp fl. icd9 codes.


In [3]:
with open(PRETRAINED_TRAIN_PICKLE_PATH, 'rb') as f:
    medbert_train_data = pickle.load(f)
with open(PRETRAINED_VALIDATION_PICKLE_PATH, 'rb') as f:
    medbert_validation_data = pickle.load(f)
with open(PRETRAINED_TEST_PICKLE_PATH, 'rb') as f:
    medbert_test_data = pickle.load(f)


In [None]:
len(medbert_train_data), len(medbert_validation_data), len(medbert_test_data)

### Convert pickle to df


In [None]:
train_df = pd.DataFrame(medbert_train_data, columns= ['person_id', 'los', 'time_not_used', 'code', 'visits'])
validation_df = pd.DataFrame(medbert_validation_data, columns= ['person_id', 'los', 'time_not_used', 'code', 'visits'])
test_df = pd.DataFrame(medbert_test_data, columns= ['person_id', 'los', 'time_not_used', 'code', 'visits'])
for x in (train_df, validation_df, test_df):
    x.drop(columns=['los', 'time_not_used'], inplace=True)
    x.drop(x[x['code'].apply(lambda x: len(x) <= 1)].index, inplace=True) # remove patients with only one diagnosis
train_df

In [None]:
with open(MEDBERT_CODES_DICT_PATH, 'rb') as f:
          code_to_id_dict = pickle.load(f)
print(code_to_id_dict)

def convert_codes_to_ids(codes: List[str], code_to_id_dict: Dict[str, int]):
    converted_codes = []
    for code in codes: 
        converted_codes.append(code_to_id_dict[str(code)])
    return converted_codes


In [None]:
id_to_code_dict = {v:k for k, v in code_to_id_dict.items()}
file_path = '/sise/home/benshoho/projects/Med-BERT/Pretraining Code/Data Pre-processing Code/eicu_crd_data/eicu_crd_id_to_code.types'
with open(file_path, 'wb') as file:
    pickle.dump(id_to_code_dict, file)


In [8]:
def convert_vocab_to_with_icd_dot():
    id_to_icd9_dot = {v:k for k, v in code_to_id_dict.items()}
    

In [9]:
TARGET_DISEASE_IDS = convert_codes_to_ids(TARGET_DISEASE_IDS, code_to_id_dict)
TARGET_DISEASE_IDS

[14, 10, 131, 239, 122]


## Convert to medbert format

In [None]:
train_df


In [11]:
def add_sep_between_visits(row):
    codes, visits = row.code, row.visits
    new_codes = []
    new_visits = []
    
    for i in range(len(codes)):
        new_codes.append(codes[i])
        new_visits.append(visits[i])
        if i < len(codes) - 1 and visits[i] != visits[i + 1]:
            new_codes.append('SEP')
            new_visits.append('SEP')
    new_codes.append('SEP')
    new_visits.append('SEP')
    assert len(new_codes) == len(new_visits)
    return new_codes, new_visits

for x in (train_df, validation_df, test_df):
    x[['code', 'visits']] = x.apply(add_sep_between_visits, axis=1, result_type='expand')


In [None]:
def has_first_dignosis_target(inner_list):
    first_diag = inner_list[0]
    return first_diag in TARGET_DISEASE_IDS


In [14]:
has_first_dignosis_target([19, 29, 31, 102, 238, 8, 'SEP'])

False

In [None]:
def target_disease_in_first_diagnosis(row):
    codes = row.code
    return codes[0] in TARGET_DISEASE_IDS
    
for x in (train_df, validation_df, test_df):
    mask = x.apply(target_disease_in_first_diagnosis, axis=1)
    x.drop(index=x[mask].index, inplace=True)

train_df.shape

In [None]:
train_df.shape, validation_df.shape, test_df.shape

### To medbert pickle format for fine-tuning

In [17]:
from typing import List
import random

def has_target_disease(target_disease_ids: List[str], codes: List[str]):
    # return True if at least one from target_disease_ids can be found in codes and its index in the codes.
    for index, code in enumerate(codes):
        if code in target_disease_ids:
            if index == 0:
                print('Error!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
                print(codes)
            return True
    return False

def random_negative_index(codes: List[str]):
    # get a random index according the size of codes, and return the list until this index (included this index). 
    if not codes:
        print('Empty list of codes!!!!!!!!!!!!!!!!!!!!!')
        return []
    
    random_index = random.randint(0, len(codes) - 1)
    return codes[:random_index+1]

def get_positive_index(codes: List[str]): 
    if not codes:
        print('Empty list of codes!!!!!!!!!!!!!!!!!!!!!')
        return []
    for index, code in enumerate(codes):
        if code in TARGET_DISEASE_IDS:
            return codes[:index]
    print('Index was not found!')
    return []

def preprocess_patient_records(patient_list_data: List[str], was_target_found: bool):
    # patient_list_data can be codes, ages, years. 
    if was_target_found:
        return get_positive_index(patient_list_data)
    return random_negative_index(patient_list_data)

def preprocess_patient_data(row):
    person_id = row['person_id']
    codes = row['code']
    visits = row['visits']
    assert len(codes) == len(visits)
    classification_binary_label = has_target_disease(TARGET_DISEASE_IDS, codes)
    codes = preprocess_patient_records(codes, classification_binary_label)
    visits = visits[:len(codes)]
    assert len(codes) == len(visits)

    return codes, visits, 1 if classification_binary_label else 0


In [None]:
train_df

In [None]:
for x in (train_df, validation_df, test_df):
    x[['code', 'visits', 'label']] = x.apply(preprocess_patient_data, axis=1, result_type='expand')
train_df

In [None]:
def filter_sep(row):
    codes = row.code
    visits_num = row.visits
    indexes_to_remove = []
    assert len(codes) == len(visits_num)
    for index, code in enumerate(codes):
        if code == 'SEP':
            indexes_to_remove.append(index)
    assert len(codes) == len(visits_num)
    codes = [code for i, code in enumerate(codes) if i not in indexes_to_remove]
    visits_num = [num for i, num in enumerate(visits_num) if i not in indexes_to_remove]
    assert len(codes) == len(visits_num)
    return codes, visits_num

for x in (train_df, validation_df, test_df):
    x[['code', 'visits']] = x.apply(filter_sep, axis=1, result_type='expand')

train_df.head()

### To pickles

In [28]:
def write_df_to_pickle(df: pd.DataFrame, pickle_output_dir: str, df_type: str, disease_name: str):
    # df with columns: person_id, code
    # Create a list to store patient records
    patient_records = []

    # Iterate over each row in the DataFrame
    for index, row in df.iterrows():
        # Extract the necessary information from the row
        pt_id = row['person_id']
        label = row['label']
        seq_list = row['code']
        segment_list = row['visits']
        assert len(seq_list) == len(segment_list)
        
        # Create a patient record as a sublist
        patient_record = [pt_id, label, seq_list, segment_list]
        # Append the patient record to the list of patient records
        patient_records.append(patient_record)

    # Write the list of patient records to a pickle file
    output_pickle_path = f'{pickle_output_dir}/{disease_name}_{df_type}.pickle'
    with open(output_pickle_path, 'wb') as file:
        pickle.dump(patient_records, file)


In [29]:
MEDBERT_OUTPUT_PICKLES_DIR

'/sise/home/benshoho/projects/Med-BERT/Fine-Tunning-Tutorials/data/eicu_crd'

In [30]:
for current_df, current_df_type in zip([train_df, validation_df, test_df], ['train', 'validation', 'test']):
    write_df_to_pickle(current_df, MEDBERT_OUTPUT_PICKLES_DIR, current_df_type, disease_name='eicu_crd_adult_respiratory_failure')

In [31]:
MEDBERT_OUTPUT_PICKLES_DIR

'/sise/home/benshoho/projects/Med-BERT/Fine-Tunning-Tutorials/data/eicu_crd'