In [None]:
import matplotlib.pyplot as plt
import numpy as np            
import pandas as pd        

from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.applications import EfficientNetB0
from keras.models import Model
from keras.layers import Dense, Flatten, BatchNormalization, GlobalAveragePooling2D, Dropout
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.optimizers import Adam
from sklearn.utils import resample
from vit_keras import vit

#Đường dẫn thư mục train và test
train_path = r'.\train'
test_path = r'.\test'
val_path = r'.\val'

batch_size = 16 

img_height = 224
img_width = 224

: 

In [8]:
# Data augmentation for training
image_gen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    brightness_range=[0.8,1.2]
)

# Data augmentation for testing/validation
test_data_gen = ImageDataGenerator(rescale=1./255)

In [9]:
# Training data generator
train_generator = image_gen.flow_from_directory(
    train_path,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary',
    color_mode='rgb'
)

# Testing data generator
test = test_data_gen.flow_from_directory(
    test_path,
    target_size=(img_height, img_width),
    color_mode='rgb', 
    shuffle=False,
    class_mode='binary',
    batch_size=batch_size
)

# Validation data generator
validation_generator = image_gen.flow_from_directory(
    val_path,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary',
    color_mode='rgb'
)

Found 1943 images belonging to 2 classes.
Found 611 images belonging to 2 classes.
Found 497 images belonging to 2 classes.


In [15]:
print(train_generator.class_indices)

{'Disease': 0, 'Normal': 1}


In [10]:
# Load the Vision Transformer model
vit_model = vit.vit_b16(
    image_size=img_height,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=False,
    classes=1
)
# Freeze the transformer layers for fine-tuning
for layer in vit_model.layers[:-4]:
    layer.trainable = False

# Compile the model
vit_model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

# Callbacks
early = EarlyStopping(monitor="val_loss", mode="min", patience=5)
learning_rate_reduction = ReduceLROnPlateau(monitor='val_loss', patience=3, verbose=1, factor=0.3, min_lr=0.000001)
callbacks_list = [early, learning_rate_reduction]

# Compute class weights
from sklearn.utils.class_weight import compute_sample_weight

unique_classes = np.unique(train_generator.classes)
weights = compute_sample_weight(class_weight='balanced', y=train_generator.classes)
cw = dict(zip(unique_classes, weights))
print(cw)

# Train the model
vit_model.fit(
    train_generator,
    epochs=50,  # Increase the number of epochs if needed
    validation_data=validation_generator,
    class_weight=cw,
    callbacks=callbacks_list
)

Downloading data from https://github.com/faustomorales/vit-keras/releases/download/dl/ViT-B_16_imagenet21k+imagenet2012.npz




{0: 0.6979166666666666, 1: 0.6979166666666666}
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 5: ReduceLROnPlateau reducing learning rate to 2.9999999242136255e-05.
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 12: ReduceLROnPlateau reducing learning rate to 8.999999772640877e-06.
Epoch 13/50
Epoch 14/50


<keras.src.callbacks.History at 0x1b5c8c8ee80>

In [11]:
vit_model.save('VisionTransformer.h5')

  saving_api.save_model(


In [14]:
from sklearn.metrics import confusion_matrix, classification_report

# Evaluate the model on the test data
test_accu = vit_model.evaluate(test)
print('The testing accuracy is:', test_accu[1] * 100, '%')

# Predict the labels for the test data
y_pred = vit_model.predict(test)
y_pred = np.round(y_pred).astype(int)  # Chuyển đổi dự đoán thành nhãn nhị phân (0 hoặc 1)

# Get the true labels
y_true = test.classes

# Compute the confusion matrix
conf_matrix = confusion_matrix(y_true, y_pred)
print('Confusion Matrix:')
print(conf_matrix)

# Compute classification report
class_report = classification_report(y_true, y_pred, target_names=['Disease', 'Normal'])
print('Classification Report:')
print(class_report)

# Extract Sensitivity, Specificity, and F1-Score from confusion matrix
tn, fp, fn, tp = conf_matrix.ravel()
sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)
precision = tp / (tp + fp)
recall = sensitivity  # Recall là tên khác của Sensitivity
f1_score = 2 * (precision * recall) / (precision + recall)

print(f'Sensitivity: {sensitivity * 100:.2f}%')
print(f'Specificity: {specificity * 100:.2f}%')
print(f'F1-Score: {f1_score:.2f}')

The testing accuracy is: 83.14238786697388 %
Confusion Matrix:
[[340  74]
 [ 29 168]]
Classification Report:
              precision    recall  f1-score   support

     Disease       0.92      0.82      0.87       414
      Normal       0.69      0.85      0.77       197

    accuracy                           0.83       611
   macro avg       0.81      0.84      0.82       611
weighted avg       0.85      0.83      0.84       611

Sensitivity: 85.28%
Specificity: 82.13%
F1-Score: 0.77
