<a href="https://colab.research.google.com/github/jogong2718/COVID-19-Radiography-Models/blob/main/Github_classification_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# import
import os
import numpy as np
from tqdm import tqdm
import imageio as iio
import cv2 as cv
import pickle
from natsort import natsorted
import matplotlib.pyplot as plt
import random
import tensorflow as tf
from PIL import Image

In [None]:
# Mount Google Drive so the pickled datasets stored there become reachable from the runtime.
from google.colab import drive

DRIVE_MOUNTPOINT = '/content/drive'
drive.mount(DRIVE_MOUNTPOINT)

In [None]:
# Switch into the Drive folder that holds the *.pkl dataset files (replace the placeholder
# with your own path); the directory listing is shown as the cell output.
DATA_DIR = 'your data link'
os.chdir(DATA_DIR)
os.listdir()

In [None]:
# Load the pickled test split from the current directory.
# NOTE: pickle.load can execute arbitrary code — only load files you created or trust.
with open('test_everything_final_2.pkl', mode='rb') as test_pkl:
    test_everything = pickle.load(test_pkl)

In [None]:
# Load the pickled training split from the current directory.
# NOTE: pickle.load can execute arbitrary code — only load files you created or trust.
with open('train_everything_final_2.pkl', mode='rb') as train_pkl:
    train_everything = pickle.load(train_pkl)

In [None]:
# Element 0 of each loaded container holds the image array used as model input.
test_data, train_data = test_everything[0], train_everything[0]

In [None]:
# Scale pixel values into [0, 1] (assumes the raw images are 0-255 intensities).
# Reading from the loaded *_everything containers instead of dividing train_data/test_data
# by themselves makes this cell idempotent: re-running `x = x/255.` would divide again.
# float32 is what the model consumes anyway and halves memory versus numpy's float64.
train_data = np.asarray(train_everything[0], dtype=np.float32) / 255.
test_data = np.asarray(test_everything[0], dtype=np.float32) / 255.

In [None]:
# Show both shapes: a cell only displays its last bare expression, so listing them on
# separate lines would hide train_data.shape.
train_data.shape, test_data.shape

In [None]:
# Element 1 of each loaded container holds the class labels for the classifier head.
test_labels, train_labels = test_everything[1], train_everything[1]

In [None]:
# Give each 128x128 image a channel axis and replicate that single channel three times,
# matching the (128, 128, 3) model inputs. `-1` infers the sample count instead of
# hard-coding it, and the guard keeps the cell safe to re-run on already-expanded data.
if not (train_data.ndim == 4 and train_data.shape[-1] == 3):
    train_data = np.repeat(train_data.reshape((-1, 128, 128, 1)), 3, axis=-1)
assert len(train_data) == len(train_labels), "image/label count mismatch in training split"

In [None]:
# Same channel expansion for the test split: (N, 128, 128) -> (N, 128, 128, 3) with the
# sample count inferred rather than hard-coded; guarded so re-running is harmless.
if not (test_data.ndim == 4 and test_data.shape[-1] == 3):
    test_data = np.repeat(test_data.reshape((-1, 128, 128, 1)), 3, axis=-1)
assert len(test_data) == len(test_labels), "image/label count mismatch in test split"

In [None]:
# ImageNet-pretrained DenseNet201 without its classification head; pooling='max' applies
# global max pooling, so calling it yields one feature vector per image.
conv_base = tf.keras.applications.DenseNet201(
    weights='imagenet',
    include_top=False,
    pooling='max',
)

In [None]:
from tensorflow.keras import layers
from tensorflow.keras import regularizers

In [None]:
# Two-input / two-output functional model:
#   input_layer     -> conv encoder -> conv-transpose decoder -> deconv_final
#                      (1-channel sigmoid image; fit() trains it against the input's first channel)
#   input_layer_aux -> DenseNet201 (conv_base) pooled features
#   encoder bottleneck (conv4) -> small conv head -> aux_output, concatenated with the
#                      DenseNet features -> dense stack -> aux_output_dense (3-way softmax)
# NOTE(review): every layer below is auto-named in creation order. The reporting cells look up
# history keys like 'dense_3' / 'conv2d_6', which presumably resolve to aux_output_dense and
# deconv_final only in a fresh session; re-running this cell or inserting Dense/Conv2D layers
# shifts those names.
# input layer
input_layer = tf.keras.Input(shape=(128, 128, 3))
input_layer_aux = tf.keras.layers.Input(shape=(128, 128, 3))
# NOTE(review): conv_base is not frozen here (no trainable=False), so DenseNet201 weights are
# fine-tuned as well. Its inputs are the [0, 1]-scaled images without
# densenet.preprocess_input — confirm this normalization mismatch is intended.
conv1_aux = conv_base(input_layer_aux)
# conv_base was built with pooling='max', so its output is already a flat feature vector;
# this Flatten leaves the shape unchanged.
conv3_aux = tf.keras.layers.Flatten()(conv1_aux)

# encode: six 2x2 Conv2D layers (default 'valid' padding, stride 1, no activation), each
# followed by BatchNormalization. Each conv trims one pixel per spatial dim:
# 128x128 -> 122x122x32. The names conv4 / conv3_2 are rebound repeatedly; only the last
# binding of conv4 (the normalized bottleneck) is used below.
conv1 = tf.keras.layers.Conv2D(32, kernel_size=(2,2))(input_layer)
conv1_2 = tf.keras.layers.BatchNormalization()(conv1)
conv2 = tf.keras.layers.Conv2D(32, kernel_size=(2,2))(conv1_2)
conv2_2 = tf.keras.layers.BatchNormalization()(conv2)
conv4 = tf.keras.layers.Conv2D(32, kernel_size=(2,2))(conv2_2)
conv3_2 = tf.keras.layers.BatchNormalization()(conv4)
conv4 = tf.keras.layers.Conv2D(32, kernel_size=(2,2))(conv3_2)
conv3_2 = tf.keras.layers.BatchNormalization()(conv4)
conv4 = tf.keras.layers.Conv2D(32, kernel_size=(2,2))(conv3_2)
conv3_2 = tf.keras.layers.BatchNormalization()(conv4)
conv4 = tf.keras.layers.Conv2D(32, kernel_size=(2,2))(conv3_2)

conv4 = tf.keras.layers.BatchNormalization()(conv4)

# decode: six 2x2 Conv2DTranspose layers (stride 1, 'valid'), each adding one pixel per
# spatial dim: 122x122 -> 128x128x32. No activation or normalization between them.
deconv1 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(2,2))(conv4)
deconv2 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(2,2))(deconv1)
deconv4 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(2,2))(deconv2)
deconv4 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(2,2))(deconv4)
deconv4 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(2,2))(deconv4)
deconv4 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(2,2))(deconv4)
# output: 1x1 sigmoid conv -> (128, 128, 1) reconstruction with values in [0, 1].
deconv_final = tf.keras.layers.Conv2D(1, 1, padding="same", activation="sigmoid")(deconv4)
# Classification conv head on the encoder bottleneck (conv4, 122x122x32):
# two valid 2x2 convs -> 120, BN, 2x2 max-pool -> 60.
conv = tf.keras.layers.Conv2D(32, kernel_size=(2,2))(conv4)
conv = tf.keras.layers.Conv2D(32, kernel_size=(2,2))(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
conv = tf.keras.layers.MaxPool2D()(conv)

# Two more valid 2x2 convs -> 58, BN, 2x2 max-pool -> 29.
conv = tf.keras.layers.Conv2D(32, kernel_size=(2,2))(conv)
conv = tf.keras.layers.Conv2D(32, kernel_size=(2,2))(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
conv = tf.keras.layers.MaxPool2D()(conv)

# Collapse to a single 29x29 sigmoid map and flatten it (29 * 29 = 841 features).
aux_output = tf.keras.layers.Conv2D(1, 1, padding="same", activation="sigmoid")(conv)
aux_output = tf.keras.layers.Flatten()(aux_output)

# concatenate: conv-head features + DenseNet features, then a 128-64-32 ReLU MLP with
# dropout, ending in a 3-way softmax classifier.
concat_layer = tf.keras.layers.Concatenate()([aux_output, conv3_aux])
clf_dense_1 = tf.keras.layers.Dense(128, activation = "relu")(concat_layer)
clf_dropout_1 = tf.keras.layers.Dropout(0.2)(clf_dense_1)
clf_dense_2 = tf.keras.layers.Dense(64, activation = "relu")(clf_dropout_1)
clf_dropout_2 = tf.keras.layers.Dropout(0.3)(clf_dense_2)
clf_dense_3 = tf.keras.layers.Dense(32, activation = "relu")(clf_dropout_2)
clf_dropout_3 = tf.keras.layers.Dropout(0.2)(clf_dense_3)
aux_output_dense = tf.keras.layers.Dense(3, activation = "softmax")(clf_dropout_3)

# final model: output order [reconstruction, classification] must match the order of the
# losses in compile() and of the targets passed to fit().
final_model = tf.keras.models.Model(inputs = [input_layer, input_layer_aux], outputs = [deconv_final, aux_output_dense])

In [None]:
# Layer-by-layer summary; the DenseNet201 base appears as one collapsed nested layer.
final_model.summary(expand_nested=False)

In [None]:
# Render the model graph (with tensor shapes) inline; also written to model.png.
tf.keras.utils.plot_model(model=final_model, to_file='model.png', show_shapes=True)

In [None]:
# Compile with one loss per output, in final_model's output order [reconstruction, classification].
# - Reconstruction head (1-channel sigmoid, target = image pixels scaled to [0, 1]): binary
#   cross-entropy. A shared "categorical_crossentropy" would normalize the single-channel
#   prediction to exactly 1, making that loss constant and giving the decoder zero gradient.
# - Classification head (3-way softmax; labels presumably one-hot, as categorical
#   cross-entropy requires): categorical cross-entropy.
# Metrics are per output too. "accuracy" stays on the reconstruction head so the
# "<output>_accuracy" history keys still exist, and "mae" is the informative reconstruction
# metric. OneHotIoU argmaxes labels and scores; IoU(num_classes=2) would truncate softmax
# probabilities to integer 0. Use loss_weights= if the two losses need rebalancing.
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
final_model.compile(
    optimizer=opt,
    loss=["binary_crossentropy", "categorical_crossentropy"],
    metrics=[
        ["accuracy", "mae"],
        ["accuracy", tf.keras.metrics.OneHotIoU(num_classes=3, target_class_ids=[0])],
    ],
)

In [None]:
# Train both heads jointly. input_layer and input_layer_aux both receive the same 3-channel
# images; the reconstruction target is channel 0 of those (already /255-scaled) images and
# the classification target is the label array. Validation uses the test split the same way.
with tf.device('/device:GPU:0'):
  history = final_model.fit(
      x=[train_data, train_data],
      y=[train_data[..., 0:1], train_labels],
      batch_size=1,
      epochs=100,
      validation_data=([test_data, test_data], [test_data[..., 0:1], test_labels]),
  )

In [None]:
# Best per-epoch values for both heads. History keys are built from the model's own
# output-layer names rather than hard-coded auto-generated names ('dense_3', 'conv2d_6'),
# which change whenever the model cell is re-run. Outputs are [reconstruction, classification];
# keys prefixed with "val_" hold the validation (test) values.
recon_name, clf_name = final_model.output_names
hist = history.history
print("best train loss: " + str(min(hist[f'{clf_name}_loss'])) + "\n" + "best train acc: " + str(max(hist[f'{clf_name}_accuracy'])))
print("best test loss: " + str(min(hist[f'val_{clf_name}_loss'])) + "\n" + "best test acc: " + str(max(hist[f'val_{clf_name}_accuracy'])))
print("best conv test acc: " + str(max(hist[f'val_{recon_name}_accuracy'])) + "\n" + "best conv train acc: " + str(max(hist[f'{recon_name}_accuracy'])))

In [None]:
# Training/validation accuracy and loss curves for the classification head.
# The epoch axis is derived from the recorded history, so it always matches the number of
# epochs actually trained; a hard-coded count that disagrees with fit(epochs=...) makes
# matplotlib raise on mismatched x/y lengths.
clf_name = final_model.output_names[1]  # outputs are [reconstruction, classification]
acc = history.history[f'{clf_name}_accuracy']
val_acc = history.history[f'val_{clf_name}_accuracy']

loss = history.history[f'{clf_name}_loss']
val_loss = history.history[f'val_{clf_name}_loss']

epochs = len(acc)
epochs_range = range(1, epochs + 1)

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(20, 10))
ax_acc.plot(epochs_range, acc, label='Training Accuracy')
ax_acc.plot(epochs_range, val_acc, label='Validation Accuracy')
ax_acc.set(title='Training and Validation Accuracy', xlabel='Epoch', ylabel='Accuracy')
ax_acc.legend(loc='upper left')

ax_loss.plot(epochs_range, loss, label='Training Loss')
ax_loss.plot(epochs_range, val_loss, label='Validation Loss')
ax_loss.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')
ax_loss.legend(loc='upper right')
plt.show()