In [14]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

from joblib import Parallel, delayed
import pickle
import time
from tqdm import tqdm
import os

In [3]:
data = pickle.load(open('../hierarchical/train_frontal_Bit_m-r101x1_with_labels.p', 'rb'))

In [4]:
X_all = np.array([vec for vec in data['vector']])

category_indices = np.array([6,8,10,11,12,13,14,15,16,17,18])
category_indices = np.array([6,8,10,11,12,13,14,15,16,17,18])

conditions = np.array(list(data.iloc[0, category_indices].keys()))

competition_conditions = ['No Finding', 'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
competition_conditions_indices = [category_indices[np.where(conditions == c)[0][0]] for c in competition_conditions]

##- Comment out if we want to use entire dataset
# category_indices = competition_conditions_indices.copy()
# conditions = competition_conditions.copy()

labels = np.arange(len(conditions))

n, d = X_all.shape
y_all = np.zeros(n)

fly_list = []
for i in range(n):
    temp_conditions = data.iloc[i, category_indices]
    positive_conditions = np.where(temp_conditions.values.astype(int) > 0)[0]
    
    if len(positive_conditions) > 1:
        temp_competition_condition_indices = []
        for pc in positive_conditions:
            if pc in competition_conditions_indices:
                temp_competition_condition_indices.append(pc)
        if len(temp_competition_condition_indices) == 1:
            y_all[i] = temp_competition_condition_indices[0]
            fly_list.append(i)
    elif len(positive_conditions) == 1:
        y_all[i] = positive_conditions[0]
        fly_list.append(i)
        
fly_list = np.array(fly_list)
X = X_all[fly_list]
y = y_all[fly_list]

In [6]:
idx_by_label = [np.where(y == c)[0] for c in labels]
print("total:", len(y))

for i, ibl in enumerate(idx_by_label):
    print(conditions[i], len(ibl))

total: 132490
No Finding 16974
Cardiomegaly 5507
Lung Lesion 2442
Edema 13744
Consolidation 3294
Pneumonia 1583
Atelectasis 15131
Pneumothorax 7598
Pleural Effusion 62310
Pleural Other 922
Fracture 2985


In [9]:
# from graspologic.cluster import AutoGMMCluster as GMM
from graspologic.cluster import GaussianCluster as GMM

n_iter=1
master_seed = 42
np.random.seed(master_seed)
seeds = np.random.randint(10000, size=n_iter)

cluster_dics = []
class_cond_clusters = []


for iteration in tqdm(range(n_iter)):
    start = time.time()
    seed =  seeds[iteration]
    train_inds, _, _, _ = train_test_split(np.arange(len(fly_list)), y, test_size=0.5, random_state=seed)
    
    X_train, y_train = X[train_inds], y[train_inds]
        
    idx_by_label = [np.where(y_train == c)[0] for c in labels]
    
    clusters = [np.zeros(len(ibl)) for ibl in idx_by_label]
    for i, ibl in enumerate(tqdm(idx_by_label)):
        clusters[i] = GMM(min_components=5, max_components=5, reg_covar=1e-3).fit_predict(X_train[ibl])
        
    idx_by_induced_label = []
    for i,c in enumerate(clusters):
        for j in np.unique(c):
            idx_by_induced_label.append(idx_by_label[i][np.where(c == j)[0]])
            
    y_induced = np.zeros(X_train.shape[0], dtype='int')
    for i, c in enumerate(idx_by_induced_label):
        y_induced[c] = i
        
        
    class_clusters_dic = {}
    
    for i, file_name in enumerate(list(data['Path'])):
        if i in fly_list[train_inds]:
            ind = np.where(fly_list[train_inds] == i)[0][0]
            class_clusters_dic[file_name] = y_induced[ind]
            
    cluster_dics.append(class_clusters_dic)
    
    data_dimension=128
        
    pca = PCA(n_components=data_dimension)
    pca.fit(X_train)
    X_train = pca.transform(X_train)

    unique_y = np.unique(y_induced)

    conditional_means = np.array([np.mean(X_train[np.where(y_induced == c)[0]], axis=0) for c in unique_y])

    gmm = GMM(min_components=5, max_components=5, reg_covar=1e-3)

    class_cond_clusters.append(gmm.fit_predict(conditional_means))

  0%|          | 0/1 [00:00<?, ?it/s]

  9%|▉         | 1/11 [01:48<18:02, 108.23s/it][A
 18%|█▊        | 2/11 [02:14<12:33, 83.69s/it] [A
 27%|██▋       | 3/11 [02:26<08:17, 62.23s/it][A
 36%|███▋      | 4/11 [03:45<07:50, 67.14s/it][A
 45%|████▌     | 5/11 [03:59<05:06, 51.10s/it][A
 55%|█████▍    | 6/11 [04:09<03:13, 38.75s/it][A
 64%|██████▎   | 7/11 [05:48<03:47, 56.87s/it][A
 73%|███████▎  | 8/11 [06:33<02:40, 53.36s/it][A
 82%|████████▏ | 9/11 [18:08<08:11, 245.98s/it][A
 91%|█████████ | 10/11 [18:16<02:54, 174.51s/it][A
100%|██████████| 11/11 [18:28<00:00, 100.75s/it][A
100%|██████████| 1/1 [19:10<00:00, 1150.95s/it]


In [13]:
pickle.dump(cluster_dics, open('mini_classes.pkl', 'wb'))
pickle.dump(class_cond_clusters, open('class_conditional_clusters.pkl', 'wb'))

In [139]:
conditions_to_induced_label_dic= {}
for i, cond in enumerate(conditions):
    for j in range(10):
        conditions_to_induced_label_dic[i*5 + j] = cond

In [141]:
pickle.dump(conditions_to_induced_label_dic, open('conditions_to_induced_label_dic.pkl', 'wb'))