<a href="https://colab.research.google.com/github/losirlu1411/project-/blob/main/tuberculosis_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
### 1. Imports and Configuration
import os
import random
from collections import Counter

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from tensorflow.keras.applications import VGG16, DenseNet201, ResNet101
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dense, Flatten, Dropout, Conv2D, MaxPooling2D, Input
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import Recall
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical


In [2]:

# Side length in pixels; later combined into input_shape as
# (image_size, image_size, 3).
image_size = 224

# Class names. load_data() treats each one as a sub-folder of base_path and
# uses its position in this list as the integer class id.
labels = ['Normal', 'Tuberculosis']

# Root directory that holds one sub-folder per class.
base_path = '/content/drive/MyDrive/project /TB_Chest_Radiography_Database'

In [None]:
### 2. Data Loading and Preprocessing
# Function to load and preprocess images with optional sampling
def load_data(base_path, labels, image_size, sample_size=None):
    """Load the images of every class folder into one array.

    Each entry of ``labels`` names a sub-folder of ``base_path``; the position
    of the name inside ``labels`` becomes the integer class id of its images.

    Args:
        base_path: Directory containing one sub-folder per class.
        labels: Class (sub-folder) names, in class-id order.
        image_size: Side length in pixels. Every image is resized to
            ``image_size x image_size`` so that the result stacks into a
            single regular array.
        sample_size: If truthy, at most this many files are drawn at random
            (without replacement) from each class folder; otherwise every
            file in the folder is used.

    Returns:
        A ``(images, class_ids)`` tuple of NumPy arrays of equal length.
        Pixel values and channel order are left as ``cv2.imread`` delivers
        them; scaling is up to the caller.
    """
    data, label_list = [], []
    for label_idx, label in enumerate(labels):
        folder_path = os.path.join(base_path, label)
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # load order (and any seeded sampling) reproducible.
        all_files = sorted(os.listdir(folder_path))
        if sample_size:
            sampled_files = random.sample(all_files, min(len(all_files), sample_size))
        else:
            sampled_files = all_files
        for file_name in sampled_files:
            img_path = os.path.join(folder_path, file_name)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread reports an unreadable / non-image file by
                # returning None instead of raising; skip it, but say so.
                print(f"Skipping unreadable file: {img_path}")
                continue
            data.append(cv2.resize(img, (image_size, image_size)))
            label_list.append(label_idx)
    return np.array(data), np.array(label_list)

# Load dataset with all samples
print("Loading full dataset...")
x_data_full, y_data_full = load_data(base_path, labels, image_size)
# Cast to float32 before scaling to [0, 1]: dividing an integer array
# (cv2.imread normally yields uint8) by 255.0 would produce float64, which
# needs twice the memory of float32.
x_data_full = x_data_full.astype('float32') / 255.0  # Normalize images
y_data_full = to_categorical(y_data_full, num_classes=len(labels))  # One-hot encode labels
print(f"Total images loaded: {len(x_data_full)}")

# Load dataset with 500 samples per class
print("Loading sampled dataset (500 per class)...")
x_data_sampled, y_data_sampled = load_data(base_path, labels, image_size, sample_size=500)
x_data_sampled = x_data_sampled.astype('float32') / 255.0  # Normalize images
y_data_sampled = to_categorical(y_data_sampled, num_classes=len(labels))  # One-hot encode labels
print(f"Total images loaded: {len(x_data_sampled)}")

# Split into train and test sets for both datasets.
# Stratifying on the class id keeps the class proportions of the train and
# test parts equal to those of the whole set; random_state fixes the shuffle.
x_train_full, x_test_full, y_train_full, y_test_full = train_test_split(
    x_data_full,
    y_data_full,
    test_size=0.2,
    random_state=42,
    stratify=np.argmax(y_data_full, axis=1),
)
x_train_sampled, x_test_sampled, y_train_sampled, y_test_sampled = train_test_split(
    x_data_sampled,
    y_data_sampled,
    test_size=0.2,
    random_state=42,
    stratify=np.argmax(y_data_sampled, axis=1),
)

Loading full dataset...


In [None]:
### 3. Class Weights and Data Augmentation
# Compute class weights for both datasets.
# The one-hot training labels are turned back into integer class ids first.
train_ids_full = np.argmax(y_train_full, axis=1)
train_ids_sampled = np.argmax(y_train_sampled, axis=1)
present_classes_full = np.unique(train_ids_full)
present_classes_sampled = np.unique(train_ids_sampled)

class_weights_full = class_weight.compute_class_weight(
    'balanced',
    classes=present_classes_full,
    y=train_ids_full
)
class_weights_sampled = class_weight.compute_class_weight(
    'balanced',
    classes=present_classes_sampled,
    y=train_ids_sampled
)

# Key every weight by the class id it was computed for (not by its position
# in the weight array), so that a class missing from the training labels
# cannot shift the remaining weights onto the wrong class.
class_weights_dict_full = {
    int(cls): float(weight)
    for cls, weight in zip(present_classes_full, class_weights_full)
}
class_weights_dict_sampled = {
    int(cls): float(weight)
    for cls, weight in zip(present_classes_sampled, class_weights_sampled)
}

# Data augmentation for training
train_datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)
# ImageDataGenerator.fit() only computes the statistics needed by
# featurewise_center, featurewise_std_normalization and zca_whitening. None
# of them is enabled above, so no fit() call is made here: it would only copy
# the whole training array without affecting the generator.

In [None]:
### 4. Model Building
# Function to build a custom CNN model
def build_custom_model(input_shape, num_classes):
    """Assemble and compile the small from-scratch CNN classifier.

    The network is three Conv2D(3x3, relu) + MaxPooling2D(2x2) stages with
    32, 64 and 128 filters, followed by Flatten, Dense(256, relu),
    Dropout(0.5) and a softmax layer with ``num_classes`` units.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.
        num_classes: Number of units of the softmax output layer.

    Returns:
        The compiled ``Sequential`` model (Adam optimizer, categorical
        cross-entropy with label smoothing 0.1; metrics: accuracy and the
        recall of class id 1).
    """
    stack = [Input(shape=input_shape)]

    # Convolutional feature extractor: one conv + pool pair per filter count.
    for n_filters in (32, 64, 128):
        stack.append(Conv2D(n_filters, (3, 3), activation='relu'))
        stack.append(MaxPooling2D((2, 2)))

    # Classification head.
    stack.extend([
        Flatten(),
        Dense(256, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax'),
    ])

    model = Sequential(stack)
    smoothed_loss = CategoricalCrossentropy(label_smoothing=0.1)
    tracked_metrics = ['accuracy', Recall(class_id=1, name="recall")]
    model.compile(optimizer='adam', loss=smoothed_loss, metrics=tracked_metrics)
    return model

# Function to build pretrained models
def build_pretrained_model(base_model, input_shape, num_classes):
    """Put a trainable classification head on top of a frozen backbone.

    Args:
        base_model: Keras model used as feature extractor. It is frozen
            in place (``trainable`` is set to ``False`` on the object that
            was passed in).
        input_shape: Shape of one input sample. It is declared through an
            ``Input`` layer, the same way ``build_custom_model`` does, so it
            has to match the shape ``base_model`` accepts.
        num_classes: Number of units of the softmax output layer.

    Returns:
        The compiled ``Sequential`` model (Adam optimizer, categorical
        cross-entropy with label smoothing 0.1; metrics: accuracy and the
        recall of class id 1).
    """
    # NOTE(review): samples reach the backbone exactly as the caller provides
    # them; whether the backbone expects its own preprocess_input scaling
    # should be confirmed against how base_model was trained.
    base_model.trainable = False
    model = Sequential([
        # Declaring the input makes use of the otherwise ignored input_shape
        # argument and fixes the model's input up front.
        Input(shape=input_shape),
        base_model,
        Flatten(),
        Dense(256, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss=CategoricalCrossentropy(label_smoothing=0.1),
        metrics=['accuracy', Recall(class_id=1, name="recall")]
    )
    return model

In [None]:
### 5. Training and Evaluation Framework
# Function to train and evaluate model
def train_and_evaluate_model(model, model_name, x_train, y_train, x_test, y_test, datagen, class_weights):
    """Train ``model`` on augmented batches, then report its test results.

    Training runs for at most 20 epochs on ``datagen.flow(x_train, y_train)``
    with batch size 32 and stops early after 5 epochs without improvement,
    restoring the best weights. Afterwards a classification report is
    printed and a confusion-matrix heatmap is shown for ``x_test``.

    Args:
        model: Compiled Keras model to train.
        model_name: Name used in the printed messages and the plot title.
        x_train, y_train: Training samples and one-hot targets.
        x_test, y_test: Samples and one-hot targets used both as validation
            data during training and for the final report.
        datagen: Object whose ``flow(x, y, batch_size=...)`` yields the
            training batches.
        class_weights: Mapping passed to ``model.fit`` as ``class_weight``.

    Returns:
        The history object returned by ``model.fit``.

    Note:
        Class names are read from the module-level ``labels`` list.
    """
    print(f"Training {model_name}...")
    # NOTE(review): x_test/y_test drive early stopping (and the choice of the
    # restored weights) and are also the data the metrics below are computed
    # on, so those metrics are not from data unseen during model selection.
    early_stopping = EarlyStopping(patience=5, restore_best_weights=True)
    history = model.fit(
        datagen.flow(x_train, y_train, batch_size=32),
        validation_data=(x_test, y_test),
        epochs=20,
        class_weight=class_weights,
        callbacks=[early_stopping],
        verbose=1
    )
    print(f"Evaluating {model_name}...")
    y_pred_probs = model.predict(x_test)
    y_pred = np.argmax(y_pred_probs, axis=1)
    y_true = np.argmax(y_test, axis=1)

    # Passing the full list of class ids keeps the report rows and the matrix
    # shape aligned with `labels` even when a class is absent from both the
    # true and the predicted ids.
    class_ids = list(range(len(labels)))
    print(f"Classification Report for {model_name}:")
    print(classification_report(y_true, y_pred, labels=class_ids, target_names=labels))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    # Start a fresh figure so the heatmap never lands on a still-open one.
    plt.figure()
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.title(f'Confusion Matrix - {model_name}')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()
    return history

In [None]:
### 6. Model Training
# Define input shape and number of classes
num_classes = len(labels)
input_shape = (image_size, image_size, 3)

# Trailing positional arguments of train_and_evaluate_model for each data
# variant: (x_train, y_train, x_test, y_test, datagen, class_weights).
_full_args = (x_train_full, y_train_full, x_test_full, y_test_full, train_datagen, class_weights_dict_full)
_sampled_args = (x_train_sampled, y_train_sampled, x_test_sampled, y_test_sampled, train_datagen, class_weights_dict_sampled)

# Custom Model
custom_model_full = build_custom_model(input_shape, num_classes)
custom_history_full = train_and_evaluate_model(
    custom_model_full, "Custom CNN (Full Dataset)", *_full_args)

custom_model_sampled = build_custom_model(input_shape, num_classes)
custom_history_sampled = train_and_evaluate_model(
    custom_model_sampled, "Custom CNN (Sampled Dataset)", *_sampled_args)

# VGG16 (one ImageNet backbone instance is used by both VGG16 models)
vgg16_base = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)
vgg16_model_full = build_pretrained_model(vgg16_base, input_shape, num_classes)
vgg16_history_full = train_and_evaluate_model(
    vgg16_model_full, "VGG16 (Full Dataset)", *_full_args)

vgg16_model_sampled = build_pretrained_model(vgg16_base, input_shape, num_classes)
vgg16_history_sampled = train_and_evaluate_model(
    vgg16_model_sampled, "VGG16 (Sampled Dataset)", *_sampled_args)

# DenseNet201
densenet_base = DenseNet201(weights='imagenet', include_top=False, input_shape=input_shape)
densenet_model_full = build_pretrained_model(densenet_base, input_shape, num_classes)
densenet_history_full = train_and_evaluate_model(
    densenet_model_full, "DenseNet201 (Full Dataset)", *_full_args)

densenet_model_sampled = build_pretrained_model(densenet_base, input_shape, num_classes)
densenet_history_sampled = train_and_evaluate_model(
    densenet_model_sampled, "DenseNet201 (Sampled Dataset)", *_sampled_args)

# ResNet101
resnet_base = ResNet101(weights='imagenet', include_top=False, input_shape=input_shape)
resnet_model_full = build_pretrained_model(resnet_base, input_shape, num_classes)
resnet_history_full = train_and_evaluate_model(
    resnet_model_full, "ResNet101 (Full Dataset)", *_full_args)

resnet_model_sampled = build_pretrained_model(resnet_base, input_shape, num_classes)
resnet_history_sampled = train_and_evaluate_model(
    resnet_model_sampled, "ResNet101 (Sampled Dataset)", *_sampled_args)

In [None]:
# 7 Plot individual models
# Training histories to plot, keyed by the name shown in the plot titles.
models = {
    "Custom CNN": custom_history_full,
    "Custom CNN_sampled": custom_history_sampled,
    "VGG16": vgg16_history_full,
    "VGG16_sampled": vgg16_history_sampled,
    "DenseNet201": densenet_history_full,
    "DenseNet201_sampled": densenet_history_sampled,
    "ResNet101": resnet_history_full,
    "ResNet101_sampled": resnet_history_sampled
}

# One row per panel:
# (training key, validation key, training legend, validation legend, name)
_panels = (
    ('accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Loss'),
)

# One figure per model: accuracy curves on the left, loss curves on the right.
for model_name, history in models.items():
    plt.figure(figsize=(12, 5))
    _curves = history.history

    for _slot, (_train_key, _val_key, _train_legend, _val_legend, _metric) in enumerate(_panels, start=1):
        plt.subplot(1, 2, _slot)
        plt.plot(_curves[_train_key], label=_train_legend)
        plt.plot(_curves[_val_key], label=_val_legend)
        plt.title(f'{model_name} - {_metric}')
        plt.xlabel('Epochs')
        plt.ylabel(_metric)
        plt.legend()

    plt.tight_layout()
    plt.show()


In [None]:
### 8. Predict and Display Images
# Function to display predictions for 5 random images
def display_predictions(model, x_test, y_test, labels, num_images=5):
    """Show randomly chosen test images with their true and predicted class.

    Args:
        model: Object whose ``predict`` returns one score row per sample.
        x_test: NumPy array of images; indexed with an index array.
        y_test: One-hot targets aligned with ``x_test``.
        labels: Class names, indexed by class id.
        num_images: Number of images to draw (without replacement). It is
            capped at ``len(x_test)``; nothing is drawn when that leaves
            no image.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Sampling without replacement cannot return more items than exist.
    num_images = min(num_images, len(x_test))
    if num_images <= 0:
        return

    indices = np.random.choice(len(x_test), num_images, replace=False)
    # One batched forward pass instead of one predict() call per image.
    pred_ids = np.argmax(model.predict(x_test[indices]), axis=1)

    plt.figure(figsize=(15, 10))
    for i, (idx, pred_id) in enumerate(zip(indices, pred_ids)):
        plt.subplot(1, num_images, i + 1)
        img = x_test[idx]
        true_label = labels[np.argmax(y_test[idx])]
        pred_label = labels[pred_id]
        # Green title for a correct prediction, red for a wrong one.
        color = 'green' if true_label == pred_label else 'red'
        plt.imshow(img)
        plt.title(f"True: {true_label}\nPred: {pred_label}", color=color)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Example usage: five random test images per data variant, predicted by the
# custom CNN trained on that variant.
for _caption, _net, _images, _targets in (
    ("Predictions for Full Dataset", custom_model_full, x_test_full, y_test_full),
    ("Predictions for Sampled Dataset", custom_model_sampled, x_test_sampled, y_test_sampled),
):
    print(_caption)
    display_predictions(_net, _images, _targets, labels, num_images=5)


In [None]:
# Predicted class id (index of the highest score) per test sample, for every
# individual model.
vgg_preds_full = vgg16_model_full.predict(x_test_full).argmax(axis=1)
resnet_preds_full = resnet_model_full.predict(x_test_full).argmax(axis=1)
custom_preds_full = custom_model_full.predict(x_test_full).argmax(axis=1)
densenet_preds_full = densenet_model_full.predict(x_test_full).argmax(axis=1)

vgg_preds_sampled = vgg16_model_sampled.predict(x_test_sampled).argmax(axis=1)
resnet_preds_sampled = resnet_model_sampled.predict(x_test_sampled).argmax(axis=1)
custom_preds_sampled = custom_model_sampled.predict(x_test_sampled).argmax(axis=1)
densenet_preds_sampled = densenet_model_sampled.predict(x_test_sampled).argmax(axis=1)

# True labels, converted from one-hot rows back to class ids
y_true_full = y_test_full.argmax(axis=1)
y_true_sampled = y_test_sampled.argmax(axis=1)

# Ensemble model predictions (majority voting).
# Three voters take part per sample: VGG16, ResNet101 and the custom CNN; the
# class id that occurs most often among their votes is kept.
# NOTE(review): the DenseNet201 predictions are not part of the vote, and the
# ensemble_* lists are not read by any code visible in this notebook excerpt.
ensemble_preds_full = [
    max(set(votes), key=votes.count)
    for votes in zip(vgg_preds_full, resnet_preds_full, custom_preds_full)
]

ensemble_preds_sampled = [
    max(set(votes), key=votes.count)
    for votes in zip(vgg_preds_sampled, resnet_preds_sampled, custom_preds_sampled)
]

# Evaluation function
def evaluate_model(y_true, y_pred, model_name):
    """Summarise one model's predictions as a row of metrics.

    Uses ``accuracy_score``, ``precision_score``, ``recall_score`` and
    ``f1_score`` from ``sklearn.metrics`` (imported at the top of the
    notebook). Precision, recall and F1 are computed with
    ``average='weighted'``.

    Args:
        y_true: True class ids.
        y_pred: Predicted class ids, aligned with ``y_true``.
        model_name: Name stored under the ``"Model"`` key.

    Returns:
        A dict with the keys ``"Model"``, ``"Accuracy"``, ``"Precision"``,
        ``"Recall"`` and ``"F1-Score"``.
    """
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    return {
        "Model": model_name,
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1-Score": f1,
    }

# Evaluate models for the full dataset
results_full = [
    evaluate_model(y_true_full, vgg_preds_full, "VGG16"),
    evaluate_model(y_true_full, resnet_preds_full, "ResNet101"),
    evaluate_model(y_true_full, custom_preds_full, "Custom CNN"),
    evaluate_model(y_true_full, densenet_preds_full, "DenseNet201"),
]

# Evaluate models for the sampled dataset
results_sampled = [
    evaluate_model(y_true_sampled, vgg_preds_sampled, "VGG16"),
    evaluate_model(y_true_sampled, resnet_preds_sampled, "ResNet101"),
    evaluate_model(y_true_sampled, custom_preds_sampled, "Custom CNN"),
    evaluate_model(y_true_sampled, densenet_preds_sampled, "DenseNet201"),
]

# Convert results to DataFrames for comparison (one row per model).
# `pd` is pandas, imported in the first notebook cell.
results_full_df = pd.DataFrame(results_full)
results_sampled_df = pd.DataFrame(results_sampled)

# One grouped bar chart per data variant: models on the x axis, one bar per
# metric column.
for _results_df, _chart_title in (
    (results_full_df, "Model Performance Comparison (Full Dataset)"),
    (results_sampled_df, "Model Performance Comparison (Sampled Dataset)"),
):
    _results_df.set_index("Model").plot(kind='bar', figsize=(10, 6))
    plt.title(_chart_title)
    plt.xlabel("Models")
    plt.ylabel("Performance Metrics")
    plt.xticks(rotation=45)
    plt.legend(loc="best")
    plt.tight_layout()
    plt.show()
