In [1]:
import os
from keras.models import * 
from keras.layers import *
from keras.applications import *
from keras.preprocessing.image import *
from keras.utils.training_utils import multi_gpu_model
from multiprocessing import cpu_count
import tensorflow as tf

Using TensorFlow backend.


In [2]:
nb_classes = 21
nb_cpus = cpu_count()//2
nb_gpus = 4
os.environ["CUDA_VISIBLE_DEVICES"] = '4, 5, 6, 7'

image_size = (299, 299)
input_shape= (299, 299, 3)

In [3]:
with tf.device('/cpu:0'):
    input_tensor = Input(input_shape)
    x = Lambda(xception.preprocess_input)(input_tensor)

    base_model = Xception(input_tensor=x, weights=None, include_top=False)
    m_out = base_model.output
    p_out = GlobalAveragePooling2D()(m_out)
    p_out = Dropout(1.0)(p_out)
    predictions = Dense(nb_classes, activation='softmax')(p_out)
    model = Model(inputs=base_model.input, outputs=predictions)
    
model = multi_gpu_model(model, gpus=nb_gpus)
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 299, 299, 3)  0                                            
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 299, 299, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 299, 299, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, 299, 299, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
lambda_5 (

In [4]:
model.load_weights('weights_008_0.0881.hdf5')

In [5]:
batch_size = 64
test_path = "/home/hdd_array0/batch6.2_xcp/batch6.2-cells-half299/valid"

gen = ImageDataGenerator()
test_generator = gen.flow_from_directory(test_path, 
                                         target_size=image_size, 
                                         shuffle=False,
                                         batch_size=batch_size)
test_img_nums = test_generator.samples
all_test_results = model.predict_generator(test_generator, 
                                           len(test_generator), 
                                           workers=nb_cpus, 
                                           use_multiprocessing=True,
                                           verbose=1)
all_labels = test_generator.classes

Found 44457 images belonging to 21 classes.


In [6]:
class_label_dict = test_generator.class_indices
print(test_generator.class_indices)
def get_key(dict_, value):
    return [k for k, v in dict_.items() if v == value]

# create class num lens dict, every dict store current class predict num
total_predictions_dict = {}
for row_class_name, label in class_label_dict.items():
    total_predictions_dict[row_class_name] = {}
    for column_class_name, label in class_label_dict.items():
        total_predictions_dict[row_class_name][column_class_name] = 0

{'GEC': 7, 'MC': 13, 'TRI': 19, 'LSIL_F': 12, 'PH': 14, 'SCC_G': 17, 'HSIL_M': 9, 'EC': 5, 'VIRUS': 20, 'ACTINO': 0, 'ASCUS': 3, 'RC': 15, 'HSIL_S': 10, 'AGC_A': 1, 'FUNGI': 6, 'LSIL_E': 11, 'AGC_B': 2, 'HSIL_B': 8, 'SCC_R': 18, 'SC': 16, 'CC': 4}


In [16]:
thresh = 0.95

for i, label in enumerate(all_labels):
    predict_index = np.argmax(all_test_results[i])
    predict_det = all_test_results[i][predict_index]
    if (predict_det > thresh):
        # get the first result
        label_class_name = get_key(class_label_dict, label)[0]
        test_class_name = get_key(class_label_dict, np.argmax(all_test_results[i]))[0]
#     print(label_class_name)
#     print(total_predictions_dict[label_class_name].keys())
        total_predictions_dict[label_class_name][test_class_name] += 1

In [17]:
print(total_predictions_dict['RC'])

{'HSIL_S': 0, 'GEC': 22, 'MC': 0, 'LSIL_E': 0, 'PH': 0, 'SCC_G': 0, 'HSIL_M': 9, 'EC': 0, 'VIRUS': 0, 'TRI': 0, 'ASCUS': 0, 'RC': 1130, 'ACTINO': 0, 'AGC_A': 0, 'FUNGI': 0, 'LSIL_F': 0, 'AGC_B': 0, 'HSIL_B': 0, 'SCC_R': 0, 'SC': 0, 'CC': 0}


In [18]:
import csv
out = open('confusion_matrix.csv','a', newline='')
csv_write = csv.writer(out,dialect='excel')

# write the title
line = [class_name for class_name, label in class_label_dict.items()]
line = [" "] + line + ["TOTAL"] + ["ACC"]
csv_write.writerow(line)

# write rows

true_num = 0
all_num = 0

for row_class_name, label in class_label_dict.items():
    one_class_total_predict = 0
    line = [row_class_name]
    for column_class_name, label in class_label_dict.items():
        one_class_total_predict += total_predictions_dict[row_class_name][column_class_name]
    
    for column_class_name, label in class_label_dict.items():
        one_class_cur_predict = total_predictions_dict[row_class_name][column_class_name]
        # acc
        #acc = round((one_class_cur_predict / one_class_total_predict), 4)
        #line.append(str(acc))
        # num
        line.append(one_class_cur_predict)
    print(one_class_total_predict)
    print(str(one_class_total_predict))
    line.append(str(one_class_total_predict))
    line.append(round((total_predictions_dict[row_class_name][row_class_name] / one_class_total_predict), 4))       
    print(line)
    csv_write.writerow(line)
    
    true_num += total_predictions_dict[row_class_name][row_class_name]
    all_num += one_class_total_predict
    
csv_write.writerow(["ALL_ACC"] + [round((true_num / all_num), 4)])
    
out.close()

6189
6189
['GEC', 6083, 0, 0, 0, 0, 0, 1, 5, 10, 4, 9, 30, 5, 0, 0, 0, 13, 20, 0, 9, 0, '6189', 0.9829]
6736
6736
['MC', 0, 6644, 0, 0, 89, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, '6736', 0.9863]
33609
33609
['TRI', 0, 0, 33598, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 0, '33609', 0.9997]
1335
1335
['LSIL_F', 0, 0, 0, 1230, 0, 0, 0, 0, 0, 0, 93, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, '1335', 0.9213]
6442
6442
['PH', 0, 87, 0, 0, 6312, 0, 0, 0, 0, 0, 22, 0, 0, 0, 8, 0, 0, 0, 0, 0, 13, '6442', 0.9798]
3324
3324
['SCC_G', 4, 0, 0, 0, 0, 3065, 5, 0, 0, 0, 4, 0, 231, 12, 0, 0, 0, 0, 3, 0, 0, '3324', 0.9221]
6119
6119
['HSIL_M', 4, 0, 0, 0, 0, 1, 6070, 0, 0, 0, 0, 0, 15, 0, 0, 0, 3, 26, 0, 0, 0, '6119', 0.992]
747
747
['EC', 0, 0, 0, 0, 0, 0, 0, 736, 0, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, '747', 0.9853]
5502
5502
['VIRUS', 0, 0, 0, 0, 0, 0, 1, 0, 5452, 0, 0, 0, 43, 0, 0, 0, 6, 0, 0, 0, 0, '5502', 0.9909]
10982
10982
['ACTINO', 0, 0, 0, 0, 14, 0, 0, 0, 0, 10906, 0, 0, 4, 0, 4, 0, 