In [1]:
import os
import cv2
import sys
import random
import warnings
import numpy as np 
import pandas as pd
from time import time
from itertools import chain
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt 
from skimage.transform import resize
from skimage.morphology import label
from skimage.io import imread, imshow, imread_collection, concatenate_images
import tensorflow as tf
import tensorflow.keras as keras 

from vit_keras import  vit, utils 

In [14]:
datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1/255.,
    validation_split=.3,
    zoom_range=.2,
    rotation_range=.3,
    horizontal_flip=True,
    vertical_flip=True,
    # brightness_range=(.8, 1.2,),
    # fill_mode='constant',
    # cval=0,
)


train_dr =  'D:/Projects/Papers/xray-pneumonia/chest_xray/train'
valid_dr = 'D:/Projects/Papers/xray-pneumonia/chest_xray/val'
test_dr = 'D:/Projects/Papers/xray-pneumonia/chest_xray/test'

batch_size=8

train_gen = datagen.flow_from_directory(directory=train_dr, batch_size=batch_size, class_mode='categorical', target_size=(224, 224), shuffle=True, seed=42)

valid_gen = datagen.flow_from_directory(directory=valid_dr, batch_size=batch_size//4, class_mode='categorical', target_size=(224, 224), shuffle=True, seed=42)

test_gen = datagen.flow_from_directory(directory=test_dr, batch_size=batch_size//4, class_mode='categorical', target_size=(224, 224), shuffle=True, seed=42)

Found 5216 images belonging to 2 classes.
Found 16 images belonging to 2 classes.
Found 624 images belonging to 2 classes.


In [3]:
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=.25, patience=5, verbose=1)
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=9, verbose=1, mode='auto',
    baseline=None, restore_best_weights=True
)


ckpt = keras.callbacks.ModelCheckpoint(
    filepath = './saved_model/checkpoint/',
    save_weights_only = True,
    monitor = 'val_loss',
    mode = 'min',
    save_best_only = True
)

callbacks = [reduce_lr, early_stopping, ckpt]

In [12]:
model = vit.vit_b16(
                    image_size=(224, 224), 
                    classes=2,
                    activation='softmax', 
                    include_top=True, 
                    pretrained=True,
                    pretrained_top = False
                    )



In [13]:
model.summary()

Model: "vit-b16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 embedding (Conv2D)          (None, 14, 14, 768)       590592    
                                                                 
 reshape_1 (Reshape)         (None, 196, 768)          0         
                                                                 
 class_token (ClassToken)    (None, 197, 768)          768       
                                                                 
 Transformer/posembed_input   (None, 197, 768)         151296    
 (AddPositionEmbs)                                               
                                                                 
 Transformer/encoderblock_0   ((None, 197, 768),       7087872   
 (TransformerBlock)           (None, 12, None, None))      

In [15]:
model.compile(optimizer=keras.optimizers.Nadam(lr=0.0001, decay=1e-6), loss='binary_crossentropy', metrics=['accuracy'])
model.fit(train_gen, epochs=50, validation_data=valid_gen, callbacks=callbacks, verbose=1)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 00006: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 00012: ReduceLROnPlateau reducing learning rate to 6.24999984211172e-06.
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 00016: early stopping


<keras.callbacks.History at 0x26806c53fa0>

In [17]:
y_pred = model.predict(test_gen, verbose=1)
y_pred = np.argmax(y_pred, axis = 1)



In [18]:
from sklearn.metrics import classification_report
def create_df (dataset, label):
    filenames = []  
    labels = []
    for file in os.listdir('./chest_xray/' + f'{dataset}/{label}'):
        filenames.append(file)
        labels.append(label)
    return pd.DataFrame({'filename':filenames, 'label':labels})

test_NORMAL = create_df('test', 'NORMAL')
test_PNEUMONIA = create_df('test', 'PNEUMONIA')
test_ori = test_NORMAL.append(test_PNEUMONIA, ignore_index=True)
test_ori['label'] = test_ori['label'].apply(lambda x: 0 if x=='NORMAL' else 1)
y_true = test_ori['label'].values

print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.42      0.31      0.36       234
           1       0.64      0.75      0.69       390

    accuracy                           0.58       624
   macro avg       0.53      0.53      0.52       624
weighted avg       0.56      0.58      0.57       624

