In [1]:
!pip install vit_keras -q 

## Setup

In [10]:
import os
import cv2
import sys
import random
import warnings
import numpy as np 
import pandas as pd
from time import time
from itertools import chain
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt 
from skimage.transform import resize
from skimage.morphology import label
from skimage.io import imread, imshow, imread_collection, concatenate_images

import tensorflow as tf
from vit_keras import  vit, utils 
from tensorflow.keras import backend as K
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.layers import (
    Dense, Input, Dropout, Lambda, Conv2D, Conv2DTranspose, MaxPooling2D, Concatenate, 
    Activation, Add, multiply, add, concatenate, LeakyReLU, ZeroPadding2D, UpSampling2D, 
    BatchNormalization, SeparableConv2D, Flatten )

from sklearn.metrics import classification_report
%matplotlib inline

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]


In [3]:
MAIN_PATH = './chest_xray/'

## Data Augmentation

In [4]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rescale=1./255,
                             validation_split=0.25,
                             zoom_range=0.1,
                             rotation_range=0.2,
                             horizontal_flip=True,
                             vertical_flip=True,
                             fill_mode='nearest')

test_datagen = ImageDataGenerator(rescale=1./255)

def get_transforms(data):
    
    if data == 'train':
        IMG_TRAIN = MAIN_PATH +'train/'
        train_generator = datagen.flow_from_directory(
            # dataframe = train,
            directory = IMG_TRAIN,
            # x_col = 'filename',
            # y_col = 'label',
            batch_size  = 8,
            shuffle=True,
            class_mode = 'categorical',
            target_size = (224, 224)
        )

        return train_generator

    elif data == 'valid':
        IMG_VAL = MAIN_PATH + 'val/'
        valid_generator = datagen.flow_from_directory(
            # dataframe = valid,
            directory = IMG_VAL,
            # x_col = 'filename',
            # y_col = 'label',
            batch_size = 8,
            shuffle = True,
            class_mode = 'categorical',
            target_size = (224, 224)
        )

        return valid_generator

    else :
        IMG_TEST = MAIN_PATH + 'test/'
        test_generator = test_datagen.flow_from_directory(
            # dataframe = test,
            directory = IMG_TEST,
            # x_col = 'filename',
            # y_col = None,
            batch_size = 8,
            shuffle = False,
            class_mode = None,
            target_size = (224, 224)
        )

        return test_generator

In [5]:
train = get_transforms('train')
valid = get_transforms('valid')
test = get_transforms('test')

Found 5216 images belonging to 2 classes.
Found 16 images belonging to 2 classes.
Found 624 images belonging to 2 classes.


## Callbacks

In [6]:
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint

reduce_learning_rate = ReduceLROnPlateau(
    monitor='val_loss', factor=0.25, patience=5, verbose=1, mode='auto',
    min_delta=1e-10, cooldown=0, min_lr=0
)

early_stopping = EarlyStopping(
    monitor='val_loss', min_delta=0, patience=9, verbose=1, mode='auto',
    baseline=None, restore_best_weights=True
)

ckpt = ModelCheckpoint(
    filepath = './saved_model/checkpoint/',
    save_weights_only = True,
    monitor = 'val_loss',
    mode = 'min',
    save_best_only = True
)

callbacks = [reduce_learning_rate, early_stopping, ckpt]

## Model

In [7]:
image_size = 224
model = vit.vit_b16(
    image_size = image_size,
    activation = 'softmax',
    pretrained = True,
    include_top = True,
    pretrained_top = False,
    classes = 2
)



In [8]:
model.compile(optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.0001, decay=1e-6), loss='binary_crossentropy', metrics=['accuracy'])

## Train

In [9]:
with tf.device("/DML:0"):
    history = model.fit(train, epochs=50, validation_data=valid, callbacks=callbacks, verbose=1)

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]


ValueError: Unknown attribute 'DML' is encountered while parsing the device spec: '/DML:0'.

In [None]:
model.evaluate(valid, verbose=1)

In [None]:
y_pred = model.predict(test, verbose=1)
y_pred = np.argmax(y_pred, axis = 1)

In [None]:
def create_df (dataset, label):
    filenames = []  
    labels = []
    for file in os.listdir(MAIN_PATH + f'{dataset}/{label}'):
        filenames.append(file)
        labels.append(label)
    return pd.DataFrame({'filename':filenames, 'label':labels})

test_NORMAL = create_df('test', 'NORMAL')
test_PNEUMONIA = create_df('test', 'PNEUMONIA')
test_ori = test_NORMAL.append(test_PNEUMONIA, ignore_index=True)
test_ori['label'] = test_ori['label'].apply(lambda x: 0 if x=='NORMAL' else 1)
y_true = test_ori['label'].values

In [None]:
print(classification_report(y_true, y_pred))