In [1]:
import tensorflow as tf
from glob import glob
import numpy as np
import pickle
import json
import os
import random
import matplotlib.pyplot as plt

In [2]:
# Hide every CUDA device so that TensorFlow falls back to the CPU.
# NOTE(review): this only has an effect if TensorFlow has not yet initialised its GPU devices in this kernel.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})

In [3]:
def save_data(output_path, mydata):
    """Serialise `mydata` with pickle into `output_path` (the file is created or truncated)."""
    with open(output_path, mode="wb") as handle:
        pickle.dump(mydata, handle)
        
def load_data(data_path):
    """Load and return the pickled object stored at `data_path`.

    The file is closed as soon as loading finishes. Only use this on trusted files:
    unpickling can execute arbitrary code.
    """
    with open(data_path, "rb") as handle:
        return pickle.load(handle)

def convert_binary_score(severe_score):
    """Binarise a severity score: return 1 when it is at least 10, otherwise 0.

    NaN compares false and therefore maps to 0.
    NOTE(review): 10 appears to be the late-AMD cut-off of the severity scale; confirm against the data dictionary.
    """
    return 1 if severe_score >= 10 else 0
    
def lower_filepath(filepath):
    """Return `filepath` with only its file extension lower-cased.

    Uses `os.path.splitext`, so only the last extension is affected. Paths whose name
    contains extra dots, or that have no extension at all, are handled instead of
    raising ValueError as a two-way `split(".")` would.
    """
    root, extension = os.path.splitext(filepath)
    return root + extension.lower()
    
def shuffle_data(mydata):
    """Return `mydata` as a NumPy array in a random order.

    The permutation comes from Python's `random` module, so `random.seed` makes it reproducible.
    """
    as_array = np.array(mydata)
    permutation = np.arange(len(as_array))
    random.shuffle(permutation)
    shuffled = as_array[permutation]
    return shuffled
    
def filter_available_visits(amd_record_visit):
    """Keep only the visits that have both eyes' severity scores and images.

    A visit is kept when "AMDSEVRE" and "AMDSEVLE" are present and not NaN, and
    "RE_IMG" and "LE_IMG" are present and non-empty. Visits where any of these
    lookups or checks raises an ordinary error (missing key, None, wrong type) are
    skipped rather than aborting the whole record. Interrupts still propagate.

    Args:
        amd_record_visit: iterable of visit dicts.

    Returns:
        list of the visits that passed all four checks, in their original order.
    """
    filtered_visits = []

    for visit in amd_record_visit:
        try:
            is_complete = (
                not np.isnan(visit["AMDSEVRE"])
                and not np.isnan(visit["AMDSEVLE"])
                and len(visit["RE_IMG"]) > 0
                and len(visit["LE_IMG"]) > 0
            )
        except Exception:
            # Missing or malformed fields make the visit unusable; skip it (best effort).
            continue

        if is_complete:
            filtered_visits.append(visit)

    return filtered_visits

In [4]:
def _build_eye_record(visits, side):
    """Collect one eye's per-visit years, image paths, severity scores and binary late-AMD flags.

    `side` is "re" or "le". It selects the visit fields "AMDSEV<SIDE>" / "<SIDE>_IMG"
    and prefixes the returned keys (e.g. "re_year").
    """
    side_upper = side.upper()
    score_key = "AMDSEV" + side_upper
    img_key = side_upper + "_IMG"
    scores = [visit[score_key] for visit in visits]
    return {
        # NOTE(review): VISNO / 2 presumably converts half-year visit numbers to years; confirm.
        side + "_year": [int(visit["VISNO"]) / 2 for visit in visits],
        side + "_img": [lower_filepath(visit[img_key]) for visit in visits],
        side + "_severe_score": scores,
        side + "_late_amd": np.array([convert_binary_score(score) for score in scores]),
    }


def build_patient_dict(data_dir, remove_recurrent=True):
    """Load patient records from a JSON file and index them by patient id ("ID2").

    Each record's "VISITS" are filtered with `filter_available_visits`. Patients left
    with fewer than two usable visits are skipped. Each kept patient maps to
    {"re": {...}, "le": {...}, "late_amd_label": 0 or 1}. The label is 1 when any
    visit of either eye has a binarised score of 1.

    Args:
        data_dir: path to the JSON file holding the list of patient records.
        remove_recurrent: accepted for API compatibility; not used by this function.

    Returns:
        dict mapping patient id to the per-patient record described above.
    """
    with open(data_dir) as json_file:
        amd_data = json.load(json_file)

    patient_dict = dict()
    count_patient = 0
    count_removed_patient = 0

    for idx, record in enumerate(amd_data):
        count_patient += 1
        if idx % 100 == 0:
            print("{} patients processed...".format(idx))

        this_id = record["ID2"]
        filtered_visits = filter_available_visits(record["VISITS"])

        # A longitudinal sequence needs at least two usable visits.
        if len(filtered_visits) <= 1:
            count_removed_patient += 1
            continue

        this_re = _build_eye_record(filtered_visits, "re")
        this_le = _build_eye_record(filtered_visits, "le")
        has_late_amd = np.sum(this_re["re_late_amd"]) + np.sum(this_le["le_late_amd"]) > 0

        patient_dict[this_id] = {"re": this_re, "le": this_le, "late_amd_label": int(has_late_amd)}

    print("the number of patients: {}".format(count_patient))
    print("{} patients with fewer than two valid visits were excluded".format(count_removed_patient))

    return patient_dict

In [5]:
# Build the per-patient dictionary from the AREDS JSON export.
patient_dict = build_patient_dict(
    data_dir="/home/jl5307/current_research/AMD_prediction/data/AREDS_participants_amd3.json",
    remove_recurrent=True,
)

0 patients processed...
100 patients processed...
200 patients processed...
300 patients processed...
400 patients processed...
500 patients processed...
600 patients processed...
700 patients processed...
800 patients processed...
900 patients processed...
1000 patients processed...
1100 patients processed...
1200 patients processed...
1300 patients processed...
1400 patients processed...
1500 patients processed...
1600 patients processed...
1700 patients processed...
1800 patients processed...
1900 patients processed...
2000 patients processed...
2100 patients processed...
2200 patients processed...
2300 patients processed...
2400 patients processed...
2500 patients processed...
2600 patients processed...
2700 patients processed...
2800 patients processed...
2900 patients processed...
3000 patients processed...
3100 patients processed...
3200 patients processed...
3300 patients processed...
3400 patients processed...
3500 patients processed...
3600 patients processed...
3700 patients

In [6]:
# Sanity check: display the processed record of one patient (id "1001").
patient_dict["1001"]

{'re': {'re_year': [0.0, 2.0, 3.0, 4.0, 5.0, 7.0, 8.0, 10.0],
  're_img': ['51685 QUA F2 RE LS.jpg',
   '51685 04 F2 RE LS.jpg',
   '51685 06 F2 RE LS.jpg',
   '51685 08 F2 RE LS.jpg',
   '51685 10 F2 RE LS.jpg',
   '51685 14 F2 RE LS.jpg',
   '51685 16 F2 RE LS.jpg',
   '51685 20 F2 RE LS.jpg'],
  're_severe_score': [4.0, 3.0, 4.0, 4.0, 5.0, 6.0, 8.0, 7.0],
  're_late_amd': array([0, 0, 0, 0, 0, 0, 0, 0])},
 'le': {'le_year': [0.0, 2.0, 3.0, 4.0, 5.0, 7.0, 8.0, 10.0],
  'le_img': ['51685 QUA F2 LE LS.jpg',
   '51685 04 F2 LE LS.jpg',
   '51685 06 F2 LE LS.jpg',
   '51685 08 F2 LE LS.jpg',
   '51685 10 F2 LE LS.jpg',
   '51685 14 F2 LE LS.jpg',
   '51685 16 F2 LE LS.jpg',
   '51685 20 F2 LE LS.jpg'],
  'le_severe_score': [1.0, 3.0, 3.0, 2.0, 4.0, 6.0, 8.0, 6.0],
  'le_late_amd': array([0, 0, 0, 0, 0, 0, 0, 0])},
 'late_amd_label': 0}

In [7]:
# split patients into train / validation / test sets, stratified by their late-AMD label

In [8]:
def _split_stratum(shuffled_pids, train_size, test_fraction):
    """Slice one shuffled stratum of patient ids into (train, test, validation) parts.

    The first floor(n * train_size) ids go to train. Of the remaining m ids, the first
    floor(m * test_fraction) go to test and the rest go to validation.
    """
    n_train = int(np.floor(len(shuffled_pids) * train_size))
    train_part = shuffled_pids[:n_train]
    holdout = shuffled_pids[n_train:]
    n_test = int(np.floor(len(holdout) * test_fraction))
    return train_part, holdout[:n_test], holdout[n_test:]


def split_patient_dict(patient_dict, train_size, validation_size, test_size):
    """Split patients into train/validation/test sets, stratified by "late_amd_label".

    Each stratum (late-AMD vs. not) is shuffled independently with `shuffle_data`,
    which uses Python's `random` module; seed `random` beforehand for a reproducible
    split. floor(n * train_size) ids per stratum go to train. The remainder is divided
    so that test gets floor(m * test_size / (validation_size + test_size)) and
    validation gets the rest.

    Args:
        patient_dict: output of `build_patient_dict`.
        train_size, validation_size, test_size: split proportions.

    Returns:
        {"train_set": ..., "validation_set": ..., "test_set": ...}, each mapping
        patient id to its record.

    Raises:
        ZeroDivisionError: if validation_size + test_size == 0.
    """
    late_amd_patient_list = []
    non_late_amd_patient_list = []
    for pid, value in patient_dict.items():
        if value["late_amd_label"] == 1:
            late_amd_patient_list.append(pid)
        else:
            non_late_amd_patient_list.append(pid)

    late_amd_patient_list = shuffle_data(late_amd_patient_list)
    non_late_amd_patient_list = shuffle_data(non_late_amd_patient_list)

    print("the number of patients have late-AMD at least one eye: {}".format(len(late_amd_patient_list)))
    print("the number of patients without late-AMD in either eye: {}".format(len(non_late_amd_patient_list)))

    test_fraction = test_size / (validation_size + test_size)
    train_late, test_late, validation_late = _split_stratum(late_amd_patient_list, train_size, test_fraction)
    train_non_late, test_non_late, validation_non_late = _split_stratum(non_late_amd_patient_list, train_size, test_fraction)

    # Re-shuffle each merged split so the two strata are interleaved (same call order as before: train, validation, test).
    train_pid_list = shuffle_data(list(train_late) + list(train_non_late))
    validation_pid_list = shuffle_data(list(validation_late) + list(validation_non_late))
    test_pid_list = shuffle_data(list(test_late) + list(test_non_late))

    train_set = {pid: patient_dict[pid] for pid in train_pid_list}
    validation_set = {pid: patient_dict[pid] for pid in validation_pid_list}
    test_set = {pid: patient_dict[pid] for pid in test_pid_list}

    print(len(train_set))
    print(len(validation_set))
    print(len(test_set))

    return {"train_set": train_set, "validation_set": validation_set, "test_set": test_set}

In [9]:
# Seed Python's `random` module (used by shuffle_data) so re-running this cell reproduces the same split.
random.seed(0)
splitted_patient_dict = split_patient_dict(patient_dict, train_size=0.7, validation_size=0.15, test_size=0.15)

the number of patients have late-AMD at least one eye: 1223
the number of patients do have late-AMD at least one eye: 3092
3020
648
647


In [10]:
# Persist the patient-level split.
save_data(
    output_path="/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/splitted_patient_dict.pkl",
    mydata=splitted_patient_dict,
)

In [4]:
# Reload the persisted patient-level split.
splitted_patient_dict = load_data(
    data_path="/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/splitted_patient_dict.pkl"
)

In [5]:
def filter_by_length(eye_list, label_list, min_len):
    """Keep the (eye, label) sequence pairs where both sequences have at least `min_len` items.

    Pairs are matched positionally (zip semantics). Returns two new lists in the original order.
    """
    kept_pairs = [
        (eyes, labels)
        for eyes, labels in zip(eye_list, label_list)
        if len(eyes) >= min_len and len(labels) >= min_len
    ]
    kept_eyes = [eyes for eyes, _ in kept_pairs]
    kept_labels = [labels for _, labels in kept_pairs]
    return kept_eyes, kept_labels

def _truncate_after_first_late_amd(year, img_list, late_amd):
    """Drop every visit after an eye's first late-AMD visit; that visit itself is kept.

    Eyes without any late-AMD visit are returned unchanged.
    """
    if np.sum(late_amd) > 0:
        cut = np.where(late_amd == 1)[0][0] + 1
        return year[:cut], img_list[:cut], late_amd[:cut]
    return year, img_list, late_amd


def _label_eye_sequence(year, img_list, late_amd, timedelta):
    """Build (input images, labels) for one eye, or return None if its follow-up is too short.

    The eye is used only when year[-1] - year[0] >= timedelta. Visit i (every visit
    but the last) is labelled with the max late-AMD flag among visits whose year lies
    in (year[i], year[i] + timedelta]. If no visit falls in that window, the flag of
    visit i + 1 is used.
    """
    if not (year[-1] - year[0] >= timedelta):
        return None

    labels = []
    for idx in range(len(year) - 1):
        in_window = (year > year[idx]) & (year <= year[idx] + timedelta)
        if np.sum(in_window) > 0:
            labels.append(int(np.max(late_amd[in_window])))
        else:
            # NOTE(review): no visit inside the horizon, so the next visit's status is used even though
            # it presumably lies beyond `timedelta` (assuming years increase) — confirm this is intended.
            labels.append(int(late_amd[idx + 1]))
    return img_list[:-1], labels


def _build_split_sequences(split_set, timedelta, remove_recurrent):
    """Convert one patient split into per-eye image/label sequences.

    Returns (eye_list, label_list, exclusion_count). For each patient the right eye is
    handled before the left eye. Eyes with too short a follow-up are counted as
    excluded. Eyes that yield no labels are silently dropped.
    """
    eye_list = []
    label_list = []
    exclusion_count = 0

    for value in split_set.values():
        # Validate both eyes before processing either one.
        eyes = []
        for side in ("re", "le"):
            year = np.array(value[side][side + "_year"])
            img_list = value[side][side + "_img"]
            late_amd = np.array(value[side][side + "_late_amd"])
            if len(img_list) != len(late_amd):
                raise ValueError("the length of img_list and label_list must be the same")
            eyes.append((year, img_list, late_amd))

        for year, img_list, late_amd in eyes:
            if remove_recurrent:
                year, img_list, late_amd = _truncate_after_first_late_amd(year, img_list, late_amd)

            labelled = _label_eye_sequence(year, img_list, late_amd, timedelta)
            if labelled is None:
                exclusion_count += 1
                continue

            inputs, labels = labelled
            if len(labels) > 0:
                assert len(inputs) == len(labels), "length of the label list and img list must be the same"
                eye_list.append(inputs)
                label_list.append(labels)

    return eye_list, label_list, exclusion_count


def build_longitudinal_sequential_prediction_data_dict(splitted_patient_dict, timedelta, min_length, remove_recurrent=True, generate_per_len_test_set=True):
    """Turn a patient-level split into per-eye image sequences with horizon labels.

    For every patient in "train_set", "validation_set" and "test_set", the right eye
    and then the left eye become one (image sequence, label sequence) pair:

    * if `remove_recurrent`, visits after an eye's first late-AMD visit are dropped;
    * eyes whose follow-up span (last year - first year) is below `timedelta` are
      excluded and counted;
    * every visit but the last is an input. Its label is the max late-AMD flag of
      visits within (year, year + timedelta], or the next visit's flag if none fall
      in that window.

    Args:
        splitted_patient_dict: output of `split_patient_dict`.
        timedelta: prediction horizon, in the units of the "*_year" lists.
        min_length: test pairs shorter than this are left out of the per-length test set;
            lengths 1..min_length are generated.
        remove_recurrent: truncate each eye at its first late-AMD visit.
        generate_per_len_test_set: also build "per_length_test_set", mapping each length L
            in 1..min_length to the last L inputs/labels of every sufficiently long test pair.

    Returns:
        dict with "train_set", "validation_set", "test_set" (each {"eye_list", "label_list"})
        and, if requested, "per_length_test_set".

    Raises:
        ValueError: if an eye's image list and label list differ in length.
    """
    split_names = ("train_set", "validation_set", "test_set")
    eye_dicts = dict()
    exclusion_counts = dict()
    for split_name in split_names:
        eye_list, label_list, excluded = _build_split_sequences(splitted_patient_dict[split_name], timedelta, remove_recurrent)
        eye_dicts[split_name] = {"eye_list": eye_list, "label_list": label_list}
        exclusion_counts[split_name] = excluded

    print("{} eyes excluded from train set".format(exclusion_counts["train_set"]))
    print("{} eyes excluded from validation set".format(exclusion_counts["validation_set"]))
    print("{} eyes excluded from test set".format(exclusion_counts["test_set"]))

    result = {
        "train_set": eye_dicts["train_set"],
        "validation_set": eye_dicts["validation_set"],
        "test_set": eye_dicts["test_set"],
    }
    if not generate_per_len_test_set:
        return result

    per_len_test_eye_dict = {length: {"eye_list": [], "label_list": []} for length in range(1, min_length + 1)}

    print("filter the test data based on the minimum length ({})".format(min_length))
    filtered_test_eye_list, filtered_test_label_list = filter_by_length(
        eye_dicts["test_set"]["eye_list"], eye_dicts["test_set"]["label_list"], min_length
    )
    assert len(filtered_test_eye_list) == len(filtered_test_label_list), "the number of eyes and labels must be the same"
    print("the number of eye-label pair: {}".format(len(filtered_test_eye_list)))

    for eye_img_list, eye_label_list in zip(filtered_test_eye_list, filtered_test_label_list):
        for length in range(1, min_length + 1):
            # Keep the most recent `length` inputs/labels of each test eye.
            per_len_test_eye_dict[length]["eye_list"].append(eye_img_list[-length:])
            per_len_test_eye_dict[length]["label_list"].append(eye_label_list[-length:])

    result["per_length_test_set"] = per_len_test_eye_dict
    return result

In [6]:
# 2-year prediction horizon; per-length test sets for lengths 1..5.
longitudinal_sequential_prediction_td2_min5_data_dict = build_longitudinal_sequential_prediction_data_dict(
    splitted_patient_dict=splitted_patient_dict,
    timedelta=2,
    min_length=5,
)

691 eyes excluded from train set
136 eyes excluded from validation set
142 eyes excluded from test set
filter the test data based on the minimum length (5)
the number of eye-label pair: 766


In [7]:
# 5-year prediction horizon; per-length test sets for lengths 1..5.
longitudinal_sequential_prediction_td5_min5_data_dict = build_longitudinal_sequential_prediction_data_dict(
    splitted_patient_dict=splitted_patient_dict,
    timedelta=5,
    min_length=5,
)

1670 eyes excluded from train set
360 eyes excluded from validation set
342 eyes excluded from test set
filter the test data based on the minimum length (5)
the number of eye-label pair: 763


In [8]:
# Persist the 2-year-horizon data dictionary.
save_data(
    output_path="/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_td2_min5_data_dict.pkl",
    mydata=longitudinal_sequential_prediction_td2_min5_data_dict,
)

In [9]:
# Persist the 5-year-horizon data dictionary.
save_data(
    output_path="/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_td5_min5_data_dict.pkl",
    mydata=longitudinal_sequential_prediction_td5_min5_data_dict,
)

In [None]:
# unroll each eye's visit sequence into all of its prefixes (one instance per prefix)

In [10]:
def unroll_sequence(sequence):
    """Return every non-empty prefix of `sequence`, shortest first: [s[:1], s[:2], ..., s[:n]]."""
    return [sequence[:end] for end in range(1, len(sequence) + 1)]

def build_unrolled_longitudinal_sequential_prediction_data_dict(longitudinal_sequential_prediction_data_dict):
    """Unroll every eye sequence of the train, validation and test sets into all of its prefixes.

    Each (eye_list, label_list) pair of length n contributes n instances: its prefixes
    of length 1..n. "per_length_test_set" is passed through unchanged when the input
    has one; it is absent when the input was built with generate_per_len_test_set=False.

    Returns:
        dict with "train_set", "validation_set", "test_set" (each {"eye_list", "label_list"})
        and, when present in the input, "per_length_test_set".
    """
    split_names = ("train_set", "validation_set", "test_set")
    unrolled_data_dict = dict()

    for split_name in split_names:
        split = longitudinal_sequential_prediction_data_dict[split_name]
        unrolled_eyes = []
        unrolled_labels = []
        for eye_list, label_list in zip(split["eye_list"], split["label_list"]):
            unrolled_eyes.extend(unroll_sequence(eye_list))
            unrolled_labels.extend(unroll_sequence(label_list))
        unrolled_data_dict[split_name] = {"eye_list": unrolled_eyes, "label_list": unrolled_labels}

    # Only present when the source dict was built with generate_per_len_test_set=True.
    if "per_length_test_set" in longitudinal_sequential_prediction_data_dict:
        unrolled_data_dict["per_length_test_set"] = longitudinal_sequential_prediction_data_dict["per_length_test_set"]

    for split_name, set_label in zip(split_names, ("training", "validation", "test")):
        split = unrolled_data_dict[split_name]
        assert len(split["eye_list"]) == len(split["label_list"]), \
            "the length of eye list and label list in {} set must be the same".format(set_label)

    for split_name, set_label in zip(split_names, ("train", "validation", "test")):
        print("the number of instances in {} set: {}".format(set_label, len(unrolled_data_dict[split_name]["eye_list"])))

    return unrolled_data_dict

In [12]:
# Unroll the 2-year-horizon sequences into prefix instances.
unrolled_longitudinal_sequential_prediction_td2_min5_data_dict = build_unrolled_longitudinal_sequential_prediction_data_dict(
    longitudinal_sequential_prediction_data_dict=longitudinal_sequential_prediction_td2_min5_data_dict
)

the number of instances in train set: 34536
the number of instances in validation set: 7396
the number of instances in test set: 7429


In [13]:
# Unroll the 5-year-horizon sequences into prefix instances.
unrolled_longitudinal_sequential_prediction_td5_min5_data_dict = build_unrolled_longitudinal_sequential_prediction_data_dict(
    longitudinal_sequential_prediction_data_dict=longitudinal_sequential_prediction_td5_min5_data_dict
)

the number of instances in train set: 32627
the number of instances in validation set: 6935
the number of instances in test set: 6996


In [14]:
# Persist the unrolled 2-year-horizon data dictionary.
save_data(
    output_path="/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_td2_min5_data_dict_unrolled.pkl",
    mydata=unrolled_longitudinal_sequential_prediction_td2_min5_data_dict,
)

In [15]:
# Persist the unrolled 5-year-horizon data dictionary.
save_data(
    output_path="/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_td5_min5_data_dict_unrolled.pkl",
    mydata=unrolled_longitudinal_sequential_prediction_td5_min5_data_dict,
)

In [None]:
# stratified batching: separate sequences by their final late-AMD label before forming batches

In [57]:
def shuffle_data_pair(mydata1, mydata2):
    """Shuffle two parallel sequences with one shared random permutation.

    The i-th element of ``mydata1`` stays paired with the i-th element of
    ``mydata2`` after shuffling.

    Args:
        mydata1: indexable sequence (e.g. a list of per-eye image path lists).
        mydata2: indexable sequence of the same length (e.g. the matching label lists).

    Returns:
        A tuple ``(shuffled1, shuffled2)`` of plain lists containing the original
        element objects, reordered by the same permutation.

    Raises:
        AssertionError: if the two sequences differ in length.
    """
    assert len(mydata1) == len(mydata2), "the length of each data in the data pair must be the same"

    # Permute indices with the stdlib ``random`` module, so ``random.seed`` still
    # controls reproducibility (same sequence of RNG calls as shuffling an index array).
    # Index the original sequences directly instead of converting them with
    # ``np.array``: ragged nested lists make ``np.array`` raise ValueError on
    # NumPy >= 1.24, and equal-length inner lists would turn into numpy rows.
    idx = list(range(len(mydata1)))
    random.shuffle(idx)

    return [mydata1[i] for i in idx], [mydata2[i] for i in idx]

def _repeat_to_length(items, target_length):
    """Return ``items`` as a new list, cycled from its start until it holds ``target_length`` elements.

    An empty input, or one already at least ``target_length`` long, is returned as an
    unchanged copy (nothing to cycle through / nothing to add).
    """
    if not items or len(items) >= target_length:
        return list(items)
    return [items[k % len(items)] for k in range(target_length)]


def build_stratified_batch(eye_list, label_list, batch_size):
    """Reorder samples so each chunk of ``batch_size`` holds (batch_size - 1) non late AMD samples plus one late AMD sample.

    A sample counts as late AMD when the last entry of its label sequence equals 1.
    Both groups are shuffled with ``shuffle_data_pair`` before interleaving. The
    number of batches is the larger of the number of late AMD samples and the number
    of non late AMD batches; the smaller group is oversampled by cycling through its
    shuffled members until it can fill that many batches.

    Args:
        eye_list: per-sample eye inputs (e.g. image path sequences).
        label_list: per-sample label sequences aligned with ``eye_list``.
        batch_size: intended batch size; must be at least 2.

    Returns:
        ``(stratified_eye_list, stratified_label_list)``: flat lists meant to be consumed
        sequentially in chunks of ``batch_size``. If one group is empty, the batches simply
        contain none of that group.

    Raises:
        ValueError: if ``batch_size`` is smaller than 2.
    """
    if batch_size < 2:
        raise ValueError("batch_size must be at least 2 (one late AMD slot plus at least one non late AMD slot)")

    # the last slot of every batch is reserved for a late AMD sample
    non_late_per_batch = batch_size - 1

    late_amd_eye_list = []
    non_late_amd_eye_list = []
    late_amd_label_list = []
    non_late_amd_label_list = []

    for eyes, labels in zip(eye_list, label_list):
        # the final label of the sequence decides the stratum
        if labels[-1] == 1:
            late_amd_eye_list.append(eyes)
            late_amd_label_list.append(labels)
        else:
            non_late_amd_eye_list.append(eyes)
            non_late_amd_label_list.append(labels)

    late_amd_eye_list, late_amd_label_list = shuffle_data_pair(late_amd_eye_list, late_amd_label_list)
    non_late_amd_eye_list, non_late_amd_label_list = shuffle_data_pair(non_late_amd_eye_list, non_late_amd_label_list)

    non_late_amd_batch = int(np.ceil(len(non_late_amd_eye_list) / non_late_per_batch))
    total_batch = max(len(late_amd_eye_list), non_late_amd_batch)

    # Oversample late AMD cyclically so every batch gets one, even when more than
    # len(late_amd_eye_list) extra copies are needed (a single extend of the list's
    # own prefix cannot provide that many).
    late_amd_eye_list = _repeat_to_length(late_amd_eye_list, total_batch)
    late_amd_label_list = _repeat_to_length(late_amd_label_list, total_batch)

    # When late AMD samples outnumber the non late AMD batches, add whole batches'
    # worth of (cycled) non late AMD samples; this is a no-op otherwise.
    non_late_target = len(non_late_amd_eye_list) + (total_batch - non_late_amd_batch) * non_late_per_batch
    non_late_amd_eye_list = _repeat_to_length(non_late_amd_eye_list, non_late_target)
    non_late_amd_label_list = _repeat_to_length(non_late_amd_label_list, non_late_target)

    stratified_eye_list = []
    stratified_label_list = []

    for i in range(total_batch):
        start = non_late_per_batch * i
        end = start + non_late_per_batch

        stratified_eye_list.extend(non_late_amd_eye_list[start:end])
        stratified_label_list.extend(non_late_amd_label_list[start:end])
        stratified_eye_list.extend(late_amd_eye_list[i:i + 1])
        stratified_label_list.extend(late_amd_label_list[i:i + 1])

    return stratified_eye_list, stratified_label_list

In [31]:
def count_total_label(dataset):
    """Print total, late AMD, and non late AMD label counts over every label sequence.

    Each element of each sequence in ``dataset["label_list"]`` counts as one label.
    The late AMD count is the sum of those values, which assumes binary 0/1 labels.

    Args:
        dataset: dict with a ``"label_list"`` entry (an iterable of label sequences).

    Returns:
        None; the three counts are printed.
    """
    n_labels, n_late = 0, 0

    for seq in dataset["label_list"]:
        n_labels, n_late = n_labels + len(seq), n_late + np.sum(seq)

    n_non_late = n_labels - n_late
    print("total label: {}".format(n_labels))
    print("total late amd: {}".format(n_late))
    print("total non late amd: {}".format(n_non_late))
    
def count_total_label_unrolled(dataset):
    """Print instance counts for an unrolled dataset (one label sequence per instance).

    Every sequence in ``dataset["label_list"]`` is one instance. Its contribution to the
    late AMD count is the maximum of its labels, which assumes binary 0/1 labels.

    Args:
        dataset: dict with a ``"label_list"`` entry (an iterable of label sequences).

    Returns:
        None; the three counts are printed.
    """
    n_instances = 0
    n_late = 0

    # enumerate from 1 so n_instances ends up equal to the number of sequences seen
    for n_instances, seq in enumerate(dataset["label_list"], start=1):
        n_late += np.max(seq)

    print("total label: {}".format(n_instances))
    print("total late amd: {}".format(n_late))
    print("total non late amd: {}".format(n_instances - n_late))

In [34]:
# Count unrolled training instances; count_total_label_unrolled adds np.max of each label sequence to the late AMD count (assumes 0/1 labels).
# NOTE(review): `unrolled_longitudinal_sequential_prediction_timedelta2_min7_data_dict` is not created in this part of the notebook
# (the unroll cells create `..._td2_min5_...` / `..._td5_min5_...`), and the printed total (34536) equals the td2_min5 train-set size —
# confirm the intended dictionary; as written this cell may raise NameError on Restart & Run All.
count_total_label_unrolled(unrolled_longitudinal_sequential_prediction_timedelta2_min7_data_dict["train_set"])

total label: 34536
total late amd: 851
total non late amd: 33685


In [33]:
# NOTE(review): same naming concern as the timedelta2 cell — the printed total (32627) equals the td5_min5 train-set size.
count_total_label_unrolled(unrolled_longitudinal_sequential_prediction_timedelta5_min7_data_dict["train_set"])

total label: 32627
total late amd: 1057
total non late amd: 31570


In [6]:
def calculate_per_length_statistics(longitudinal_sequential_prediction_data_dict):
    """Count late / non late AMD entries for each sequence length of the per-length test set.

    An entry counts as late AMD when the last element of its label sequence equals 1.

    Args:
        longitudinal_sequential_prediction_data_dict: dict whose ``"per_length_test_set"``
            maps a sequence length to a dict holding a ``"label_list"`` entry.

    Returns:
        dict mapping each length to ``{"late_amd_count", "non_late_amd_count"}``, followed by a
        ``"total"`` entry with ``{"total_late_amd_count", "total_non_late_amd_count"}``.
    """
    stats = {}
    grand_late = 0
    grand_size = 0

    for seq_len, subset in longitudinal_sequential_prediction_data_dict["per_length_test_set"].items():
        labels_for_len = subset["label_list"]
        n_entries = len(labels_for_len)
        n_late = sum(1 for labels in labels_for_len if labels[-1] == 1)

        stats[seq_len] = {"late_amd_count" : n_late,
                          "non_late_amd_count" : (n_entries - n_late)}

        grand_size += n_entries
        grand_late += n_late

    stats["total"] = {"total_late_amd_count" : grand_late,
                      "total_non_late_amd_count" : grand_size - grand_late}

    return stats

def calculate_per_length_statistics_unrolled(longitudinal_sequential_prediction_data_dict):
    """Per-length late / non late AMD counts for an unrolled data dictionary.

    The counting rule is identical to ``calculate_per_length_statistics``: an entry of
    ``per_length_test_set[length]["label_list"]`` is late AMD when its last label equals 1.
    This function used to be a line-for-line copy of that function (differing only in a
    local variable name), so it now delegates to it to keep the two from drifting apart.

    Args:
        longitudinal_sequential_prediction_data_dict: dict whose ``"per_length_test_set"``
            maps a sequence length to a dict holding a ``"label_list"`` entry.

    Returns:
        dict mapping each length to ``{"late_amd_count", "non_late_amd_count"}``, plus a
        ``"total"`` entry with ``{"total_late_amd_count", "total_non_late_amd_count"}``.
    """
    return calculate_per_length_statistics(longitudinal_sequential_prediction_data_dict)

In [None]:
# NOTE(review): bare name in a never-executed cell (In [None]); `data_dict` is not assigned in this part of the notebook and
# would raise NameError on Restart & Run All unless defined earlier — consider removing this leftover.
data_dict

In [7]:
# Load the min7 and min5 timedelta2 dictionaries (file names suggest the unrolled variants) with pickle via load_data.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
# NOTE(review): the min5 file name here ("..._timedelta2_data_dict_min5_unrolled.pkl") differs from the one the unroll
# cells save ("..._td2_min5_data_dict_unrolled.pkl") — confirm it refers to the intended file.
data_dict_min7 = load_data("/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_timedelta2_min7_data_dict_unrolled.pkl")
data_dict_min5 = load_data("/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_timedelta2_data_dict_min5_unrolled.pkl")

In [17]:
# Sanity check: number of eye entries in the length-2 subset of the in-memory per-length test set.
len(longitudinal_sequential_prediction_timedelta2_min5_data_dict["per_length_test_set"][2]["eye_list"])

766

In [20]:
# Sanity check: the loaded min5 dictionary's length-2 eye list equals the in-memory one.
data_dict_min5["per_length_test_set"][2]["eye_list"] == longitudinal_sequential_prediction_timedelta2_min5_data_dict["per_length_test_set"][2]["eye_list"]

True

In [8]:
# Per-length late / non late AMD counts for the loaded min7 dictionary.
calculate_per_length_statistics_unrolled(data_dict_min7)

{1: {'late_amd_count': 20, 'non_late_amd_count': 567},
 2: {'late_amd_count': 20, 'non_late_amd_count': 567},
 3: {'late_amd_count': 20, 'non_late_amd_count': 567},
 4: {'late_amd_count': 20, 'non_late_amd_count': 567},
 5: {'late_amd_count': 20, 'non_late_amd_count': 567},
 6: {'late_amd_count': 20, 'non_late_amd_count': 567},
 7: {'late_amd_count': 20, 'non_late_amd_count': 567},
 'total': {'total_late_amd_count': 140, 'total_non_late_amd_count': 3969}}

In [9]:
# Same statistics for the loaded min5 dictionary.
calculate_per_length_statistics_unrolled(data_dict_min5)

{1: {'late_amd_count': 38, 'non_late_amd_count': 728},
 2: {'late_amd_count': 38, 'non_late_amd_count': 728},
 3: {'late_amd_count': 38, 'non_late_amd_count': 728},
 4: {'late_amd_count': 38, 'non_late_amd_count': 728},
 5: {'late_amd_count': 38, 'non_late_amd_count': 728},
 'total': {'total_late_amd_count': 190, 'total_non_late_amd_count': 3640}}

In [44]:
# Inspect the label sequences of the length-1 test subset.
# NOTE(review): `unrolled_longitudinal_sequential_prediction_timedelta2_min7_data_dict` is not created in this part of the
# notebook; `data_dict_min7` (loaded above) may be what was meant — confirm. Prefer displaying a slice (e.g. `[:10]`).
unrolled_longitudinal_sequential_prediction_timedelta2_min7_data_dict["per_length_test_set"][1]["label_list"]

[[1],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [1],
 [1],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [1],
 [1],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0]

In [9]:
# Per-length late / non late AMD counts for the timedelta2 dictionary (presumably the non-unrolled one, given the function used).
# NOTE(review): this dictionary is loaded from disk only further down (the load_data cells after the "convert ... ResNet" comment),
# so on a top-to-bottom run it must already exist from earlier cells — confirm.
calculate_per_length_statistics(longitudinal_sequential_prediction_timedelta2_data_dict)

{1: {'late_amd_count': 34, 'non_late_amd_count': 1118},
 2: {'late_amd_count': 42, 'non_late_amd_count': 1047},
 3: {'late_amd_count': 32, 'non_late_amd_count': 962},
 4: {'late_amd_count': 23, 'non_late_amd_count': 856},
 5: {'late_amd_count': 18, 'non_late_amd_count': 748},
 6: {'late_amd_count': 13, 'non_late_amd_count': 694},
 7: {'late_amd_count': 10, 'non_late_amd_count': 577},
 8: {'late_amd_count': 7, 'non_late_amd_count': 507},
 9: {'late_amd_count': 6, 'non_late_amd_count': 416},
 10: {'late_amd_count': 4, 'non_late_amd_count': 228},
 11: {'late_amd_count': 3, 'non_late_amd_count': 76},
 12: {'late_amd_count': 0, 'non_late_amd_count': 8},
 'total': {'total_late_amd_count': 192, 'total_non_late_amd_count': 7237}}

In [22]:
# Pickle the timedelta2 dictionary under a "_min5" file name.
# NOTE(review): the load cell further down reads "..._timedelta2_data_dict.pkl" (no "_min5" suffix), so it does not
# read the file written here — confirm which file is intended.
save_data("/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_timedelta2_data_dict_min5.pkl", longitudinal_sequential_prediction_timedelta2_data_dict)

In [28]:
# Pickle the timedelta5 dictionary.
save_data("/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_timedelta5_data_dict.pkl", longitudinal_sequential_prediction_timedelta5_data_dict)

In [None]:
# convert longitudinal prediction dictionary to ResNet test dictionary

In [20]:
# Load the timedelta2 dictionary from pickle (only load trusted files: pickle can execute code).
longitudinal_sequential_prediction_timedelta2_data_dict = load_data("/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_timedelta2_data_dict.pkl")

In [21]:
# Load the timedelta5 dictionary from pickle.
longitudinal_sequential_prediction_timedelta5_data_dict = load_data("/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_timedelta5_data_dict.pkl")

In [22]:
def _map_last_eye_to_last_label(eye_lists, label_lists):
    """Map the final eye entry of each sequence to the final label of the paired label sequence.

    If several sequences end on the same eye entry, the later pair wins (plain dict
    assignment semantics).
    """
    return {eyes[-1]: labels[-1] for eyes, labels in zip(eye_lists, label_lists)}


def convert_longitudinal_sequential_prediction_data_dict_for_resnet_test(longitudinal_sequential_prediction_data_dict, include_per_length_test_set=True):
    """Convert a longitudinal sequence dictionary into {last eye entry: last label} lookups.

    Each split keeps only the final eye entry and the final label of every sequence, which
    is the per-image form consumed for the ResNet test (per the cell heading).

    Args:
        longitudinal_sequential_prediction_data_dict: dict with ``"validation_set"`` and
            ``"test_set"`` entries, each holding aligned ``"eye_list"`` / ``"label_list"``
            sequences, and (when requested) a ``"per_length_test_set"`` mapping
            length -> the same kind of dict.
        include_per_length_test_set: also convert ``"per_length_test_set"`` when True.

    Returns:
        ``{"validation_set": {...}, "test_set": {...}}``, plus ``"per_length_test_set":
        {length: {...}}`` when ``include_per_length_test_set`` is True.
    """
    converted = {
        split: _map_last_eye_to_last_label(
            longitudinal_sequential_prediction_data_dict[split]["eye_list"],
            longitudinal_sequential_prediction_data_dict[split]["label_list"],
        )
        for split in ("validation_set", "test_set")
    }

    if include_per_length_test_set:
        per_length_test_set = longitudinal_sequential_prediction_data_dict["per_length_test_set"]
        converted["per_length_test_set"] = {
            length: _map_last_eye_to_last_label(length_dict["eye_list"], length_dict["label_list"])
            for length, length_dict in per_length_test_set.items()
        }

    return converted

In [23]:
# Build {last eye entry: last label} lookups for each split of the timedelta2 dictionary.
longitudinal_sequential_prediction_timedelta2_data_dict_resnet_test = convert_longitudinal_sequential_prediction_data_dict_for_resnet_test(longitudinal_sequential_prediction_timedelta2_data_dict)

In [24]:
# Same conversion for the timedelta5 dictionary.
longitudinal_sequential_prediction_timedelta5_data_dict_resnet_test = convert_longitudinal_sequential_prediction_data_dict_for_resnet_test(longitudinal_sequential_prediction_timedelta5_data_dict)

In [15]:
# Pickle the converted timedelta2 lookups.
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR constant.
save_data("/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_timedelta2_data_dict_resnet_test.pkl", longitudinal_sequential_prediction_timedelta2_data_dict_resnet_test)

In [16]:
# Pickle the converted timedelta5 lookups.
save_data("/home/jl5307/current_research/AMD_prediction/img_data/data_dictionary/longitudinal_sequential_prediction_timedelta5_data_dict_resnet_test.pkl", longitudinal_sequential_prediction_timedelta5_data_dict_resnet_test)

In [None]:
# 