In [None]:
import os
import pandas as pd
import zipfile, os, numpy as np, tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.models import load_model

# Zip archive holding the dataset, and the directory it is unpacked into.
ZIP_PATH, EXTRACT_ROOT = (
    '/content/drive/MyDrive/archive1.zip',
    '/content/xray_dataset',
)

# Unpack every archive member under EXTRACT_ROOT; the context manager closes
# the zip handle once extraction is done.
with zipfile.ZipFile(ZIP_PATH) as archive:
    archive.extractall(path=EXTRACT_ROOT)
print("Dataset extracted to", EXTRACT_ROOT)

Dataset extracted to /content/xray_dataset


In [None]:
# Image folders sit two levels deep (same folder name twice) under EXTRACT_ROOT.
BASE_DIR = os.path.join(
    EXTRACT_ROOT,
    'Coronahack-Chest-XRay-Dataset',
    'Coronahack-Chest-XRay-Dataset'
)
print("Using BASE_DIR =", BASE_DIR)

# Per-image metadata. Split names are lower-cased because they are used both
# as a path component below and as the train/test filter value later on.
META_CSV = os.path.join(EXTRACT_ROOT, 'Chest_xray_Corona_Metadata.csv')
df = pd.read_csv(META_CSV)
df['Dataset_type'] = df['Dataset_type'].str.lower()
print("Splits:", df['Dataset_type'].unique())

# <BASE_DIR>/<split>/<file name> for every row; iterating the two columns
# directly avoids the much slower row-wise DataFrame.apply(axis=1).
df['image_path'] = [
    os.path.join(BASE_DIR, split, file_name)
    for split, file_name in zip(df['Dataset_type'], df['X_ray_image_name'])
]
# Keep only rows whose image is actually present on disk.
df = df[df['image_path'].apply(os.path.exists)].reset_index(drop=True)

# The fine-grained category is only filled in for some rows. Filling the gaps
# with a constant 'Normal' would label every un-annotated image -- including
# pneumonia cases -- as healthy, so fall back to the coarse per-image label.
# NOTE(review): assumes the metadata CSV provides a coarse `Label` column
# (Normal / Pnemonia) -- confirm against the CSV header.
df['detailed_label'] = (
    df['Label_2_Virus_category']
    .fillna(df['Label'])
    .astype('category')
)
class_names = list(df['detailed_label'].cat.categories)
print("Classes:", class_names)

Using BASE_DIR = /content/xray_dataset/Coronahack-Chest-XRay-Dataset/Coronahack-Chest-XRay-Dataset
Splits: ['train' 'test']
Classes: ['ARDS', 'COVID-19', 'Normal', 'SARS', 'Streptococcus']


In [None]:
# Partition the metadata rows by the split recorded in `Dataset_type`.
is_train_row = df['Dataset_type'].eq('train')
is_test_row = df['Dataset_type'].eq('test')

train_df = df.loc[is_train_row]
test_df = df.loc[is_test_row]
print("Train samples:", len(train_df), "Test samples:", len(test_df))

Train samples: 5286 Test samples: 624


In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMAGE_SIZE = (224, 224)
BATCH_SIZE = 16

# Random augmentations applied to training images only.
augmentation_options = dict(
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
)
# Both generators multiply pixel values by 1/255.
train_datagen = ImageDataGenerator(rescale=1./255, **augmentation_options)
val_datagen = ImageDataGenerator(rescale=1./255)

# Arguments common to both flows: where to find paths and labels, the fixed
# class order, and the image/batch geometry.
shared_flow_options = dict(
    x_col='image_path',
    y_col='detailed_label',
    classes=class_names,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

# Training batches are shuffled; validation batches are drawn from the test
# split in a fixed order so predictions line up with `val_gen.classes`.
train_gen = train_datagen.flow_from_dataframe(
    train_df, shuffle=True, **shared_flow_options
)
val_gen = val_datagen.flow_from_dataframe(
    test_df, shuffle=False, **shared_flow_options
)


Found 5286 validated image filenames belonging to 5 classes.
Found 624 validated image filenames belonging to 5 classes.


In [None]:
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.metrics import TopKCategoricalAccuracy

def build_detailed_xray_model(input_shape=(224,224,3), n_classes=None, *,
                              dropout_rate=0.3, n_trainable_layers=20,
                              learning_rate=1e-4):
    """Build and compile a MobileNet-based softmax classifier.

    The network is an ImageNet-initialised MobileNet without its top, followed
    by global average pooling, dropout and a softmax `Dense` layer. All but the
    last `n_trainable_layers` layers of the backbone are frozen before the
    model is compiled.

    Args:
        input_shape: Shape passed to `MobileNet(input_shape=...)`.
        n_classes: Number of output units. When `None` (default) it is taken
            from the notebook-level `class_names` at call time, so the model
            always matches the current label set rather than whatever
            `class_names` held when this function was defined.
        dropout_rate: Rate of the dropout layer in front of the classifier.
        n_trainable_layers: How many trailing backbone layers stay trainable;
            `0` freezes the whole backbone.
        learning_rate: Learning rate given to the Adam optimizer.

    Returns:
        A compiled Keras `Model` using categorical cross-entropy and reporting
        accuracy, top-2 accuracy and top-3 accuracy.

    Raises:
        ValueError: If `n_classes` is smaller than 1 or `n_trainable_layers`
            is negative.
    """
    if n_classes is None:
        n_classes = len(class_names)
    if n_classes < 1:
        raise ValueError(f"n_classes must be >= 1, got {n_classes}")
    if n_trainable_layers < 0:
        raise ValueError(
            f"n_trainable_layers must be >= 0, got {n_trainable_layers}"
        )

    base = MobileNet(include_top=False, weights='imagenet', input_shape=input_shape)
    x = GlobalAveragePooling2D()(base.output)
    x = Dropout(dropout_rate)(x)
    out = Dense(n_classes, activation='softmax')(x)
    model = Model(inputs=base.input, outputs=out)

    # Freeze everything except the trailing `n_trainable_layers` backbone
    # layers. A count is used instead of a negative slice because
    # `layers[:-0]` would be empty and freeze nothing.
    n_frozen = max(len(base.layers) - n_trainable_layers, 0)
    for layer in base.layers[:n_frozen]:
        layer.trainable = False

    # Compile after the trainable flags are set so they take effect.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss='categorical_crossentropy',
        metrics=[
            'accuracy',
            TopKCategoricalAccuracy(2, name='top_2_acc'),
            TopKCategoricalAccuracy(3, name='top_3_acc')
        ]
    )
    return model

# Build the classifier with the function's default arguments and print its summary.
model = build_detailed_xray_model()
model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet/mobilenet_1_0_224_tf_no_top.h5
[1m17225924/17225924[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [None]:
# Write the model to disk only on epochs where validation accuracy improves.
checkpoint = ModelCheckpoint(
    filepath='best_xray_detailed.h5',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)
# Halve the learning rate after 3 epochs without a validation-accuracy gain,
# never going below 1e-6.
rlrop = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.5,
    patience=3,
    min_lr=1e-6,
    verbose=1,
)

training_callbacks = [checkpoint, rlrop]
history = model.fit(
    x=train_gen,
    validation_data=val_gen,
    epochs=20,
    callbacks=training_callbacks,
)
print("Done. Best detailed model saved as best_xray_detailed.h5")

  self._warn_if_super_not_called()


Epoch 1/20
[1m331/331[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.7590 - loss: 0.7252 - top_2_acc: 0.8285 - top_3_acc: 0.8953
Epoch 1: val_accuracy improved from -inf to 1.00000, saving model to best_xray_detailed.h5




[1m331/331[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m550s[0m 2s/step - accuracy: 0.7595 - loss: 0.7238 - top_2_acc: 0.8289 - top_3_acc: 0.8955 - val_accuracy: 1.0000 - val_loss: 3.6561e-04 - val_top_2_acc: 1.0000 - val_top_3_acc: 1.0000 - learning_rate: 1.0000e-04
Epoch 2/20
[1m196/331[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m3:18[0m 1s/step - accuracy: 0.9954 - loss: 0.0170 - top_2_acc: 0.9989 - top_3_acc: 0.9999

KeyboardInterrupt: 

In [None]:
from tensorflow.keras.models import load_model
from sklearn.metrics import classification_report

# Load the checkpoint written by the ModelCheckpoint callback in this notebook
# ('best_xray_detailed.h5'); the previous file name was never produced here.
# Inference only, so the model is loaded uncompiled and left that way.
m = load_model('best_xray_detailed.h5', compile=False)

# Rewind the (unshuffled) validation generator so predictions are produced in
# the same order as the integer labels in `val_gen.classes`.
val_gen.reset()
y_true = val_gen.classes
preds = m.predict(val_gen, verbose=1)
y_pred = np.argmax(preds, axis=1)

# `labels` pins the report to one row per entry of `class_names`. Without it,
# a class that occurs in neither y_true nor y_pred makes the label count differ
# from len(target_names) and classification_report raises ValueError.
# `zero_division=0` reports 0 for such empty classes instead of warning.
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(class_names))),
    target_names=class_names,
    zero_division=0,
))

[1m39/39[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 106ms/step
              precision    recall  f1-score   support

      Normal       0.88      0.92      0.90       234
    Pnemonia       0.95      0.93      0.94       390

    accuracy                           0.92       624
   macro avg       0.92      0.92      0.92       624
weighted avg       0.92      0.92      0.92       624

