In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Dropout, Flatten, Dense, GlobalAveragePooling2D, Input, UpSampling2D, concatenate
from tensorflow.keras.applications import EfficientNetB0, MobileNetV3Large, ResNet50
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import tensorflow_hub as hub
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import time
import cv2

# Constants
IMAGE_SIZE = 224  # side length (pixels) of the square model input
N_CLASSES = 3  # Normal, Adenocarcinoma, Squamous Cell Carcinoma
BATCH_SIZE = 32
# Minimum per-region confidence (in percent) for a prediction to count as reliable
# in detect_and_predict; example value, adjust to the best model's measured accuracy.
MODEL_ACCURACY = 80
# Class index -> display name.
# NOTE(review): flow_from_directory assigns indices alphabetically by sub-directory
# name; confirm this order matches the dataset folders.
class_labels = dict(enumerate(('Normal', 'Adenocarcinoma', 'Squamous Cell Carcinoma')))

# 1. Preprocessing Technique
def preprocess_image(image):
    """Resize an image to IMAGE_SIZE x IMAGE_SIZE and divide pixel values by 255.

    Used as the ``preprocessing_function`` of the ImageDataGenerators built in
    ``create_data_generator``. Presumably receives a single (height, width,
    channels) image with values in [0, 255] — TODO confirm for other callers.

    Returns:
        The resized image tensor scaled by 1/255.
    """
    target_hw = (IMAGE_SIZE, IMAGE_SIZE)
    resized = tf.image.resize(image, target_hw)
    return resized / 255.0

def preprocess_for_prediction(image):
    """Turn one OpenCV image (e.g. a cropped CT region) into a model-ready batch.

    ``detect_and_predict`` passes 2-D grayscale crops, while every classifier in
    this file is built for (batch, IMAGE_SIZE, IMAGE_SIZE, 3) input. So a channel
    axis is added and a single gray channel is replicated to three channels.
    Pixel values are divided by 255 to match ``preprocess_image`` used in training.

    Args:
        image: numpy array of shape (H, W), (H, W, 1) or (H, W, C).

    Returns:
        float32 numpy array of shape (1, IMAGE_SIZE, IMAGE_SIZE, C'), where
        C' is 3 for grayscale input and C otherwise.
    """
    img = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))  # dsize is (width, height)
    # cv2.resize returns a 2-D array for single-channel input, so restore the
    # channel axis explicitly before expanding to RGB.
    if img.ndim == 2:
        img = img[..., np.newaxis]
    if img.shape[-1] == 1:
        # Replicate the gray channel, mirroring how flow_from_directory (default
        # color_mode='rgb') presents grayscale files to the models during training.
        img = np.repeat(img, 3, axis=-1)
    img = img.astype(np.float32) / 255.0
    return np.expand_dims(img, axis=0)

def convert_to_rgb(images):
    """Expand a single trailing channel to three via ``tf.image.grayscale_to_rgb``.

    If the last axis does not have size 1, ``images`` is returned unchanged (the
    same object). Inputs with no channel axis at all are not expanded.
    """
    if images.shape[-1] != 1:
        return images
    return tf.image.grayscale_to_rgb(images)

def create_data_generator(train_dir='path/to/train', valid_dir='path/to/val'):
    """Build the training and validation directory iterators.

    Args:
        train_dir: directory containing one sub-directory of images per class.
        valid_dir: same layout, for validation images.

    Returns:
        (train_generator, valid_generator): Keras directory iterators yielding
        (images, sparse integer label) batches at IMAGE_SIZE x IMAGE_SIZE.
        Training batches are augmented (rotation, shifts, horizontal flips).
    """
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        preprocessing_function=preprocess_image,
        rotation_range=20, width_shift_range=0.2, height_shift_range=0.2, horizontal_flip=True
    )
    valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(preprocessing_function=preprocess_image)
    # NOTE(review): class indices are assigned alphabetically by sub-directory name;
    # confirm they match `class_labels` and that there are exactly N_CLASSES folders.
    train_generator = train_datagen.flow_from_directory(
        train_dir, target_size=(IMAGE_SIZE, IMAGE_SIZE), batch_size=BATCH_SIZE, class_mode='sparse'
    )
    # flow_from_directory shuffles by default; keep validation order fixed so
    # predictions and ground-truth labels can be lined up during evaluation.
    valid_generator = valid_datagen.flow_from_directory(
        valid_dir, target_size=(IMAGE_SIZE, IMAGE_SIZE), batch_size=BATCH_SIZE,
        class_mode='sparse', shuffle=False
    )
    return train_generator, valid_generator

# 2. Model Definitions
def build_cnn_model(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), n_classes=N_CLASSES):
    """Build a from-scratch convolutional classifier.

    Three stages, with 32, 64 and 128 filters, each of:
    [Conv3x3-ReLU, BatchNorm] x 2 -> MaxPool 2x2 -> Dropout(0.25).
    These are followed by Flatten -> Dense(512, ReLU) -> BatchNorm ->
    Dropout(0.5) -> Dense(n_classes, softmax).

    Args:
        input_shape: (height, width, channels) of the input images.
        n_classes: number of output classes.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    stack = [tf.keras.layers.Input(shape=input_shape)]
    for filters in (32, 64, 128):
        for _ in range(2):
            stack.append(Conv2D(filters, (3, 3), padding='same', activation='relu'))
            stack.append(BatchNormalization())
        stack.extend([MaxPooling2D((2, 2)), Dropout(0.25)])
    stack.extend([
        Flatten(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(n_classes, activation='softmax'),
    ])
    return Sequential(stack)

def build_efficientnet_model(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), n_classes=N_CLASSES):
    """Build a frozen EfficientNetB0 backbone with a small trainable classifier head.

    The data pipeline in this file scales pixels to [0, 1] (``preprocess_image``,
    ``preprocess_for_prediction``). Keras' EfficientNet does its own input
    normalization and expects raw [0, 255] pixels, so inputs are scaled back up
    before reaching the ImageNet weights.

    Args:
        input_shape: (height, width, channels) of the input images.
        n_classes: number of output classes.

    Returns:
        An uncompiled functional ``Model``.
    """
    base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=input_shape)
    base_model.trainable = False
    inputs = tf.keras.Input(shape=input_shape)
    x = layers.Rescaling(255.0)(inputs)  # [0, 1] -> [0, 255], as EfficientNet expects
    # training=False keeps the frozen backbone's BatchNorm layers in inference mode.
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(n_classes, activation='softmax')(x)
    return Model(inputs, outputs)

def build_mobilenet_model(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), n_classes=N_CLASSES):
    """Build a frozen MobileNetV3Large backbone with a small trainable classifier head.

    Keras' MobileNetV3 includes its own preprocessing and expects raw [0, 255]
    pixels. The data pipeline in this file produces [0, 1], so inputs are
    rescaled first. The structure mirrors the other transfer-learning builders
    (functional API, backbone called with training=False).

    Args:
        input_shape: (height, width, channels) of the input images.
        n_classes: number of output classes.

    Returns:
        An uncompiled functional ``Model``.
    """
    base_model = MobileNetV3Large(include_top=False, weights='imagenet', input_shape=input_shape)
    base_model.trainable = False
    inputs = tf.keras.Input(shape=input_shape)
    x = layers.Rescaling(255.0)(inputs)  # [0, 1] -> [0, 255], as MobileNetV3 expects
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(n_classes, activation='softmax')(x)
    return Model(inputs, outputs)

def build_resnet_model(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), n_classes=N_CLASSES):
    """Build a frozen ResNet50 backbone with a small trainable classifier head.

    ResNet50's ImageNet weights were trained on inputs processed by
    ``resnet50.preprocess_input``: RGB->BGR and per-channel ImageNet mean
    subtraction on [0, 255] pixels. The data pipeline in this file yields
    [0, 1] pixels, so they are rescaled and then preprocessed inside the model.

    Args:
        input_shape: (height, width, channels) of the input images.
        n_classes: number of output classes.

    Returns:
        An uncompiled functional ``Model``.
    """
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
    base_model.trainable = False
    inputs = tf.keras.Input(shape=input_shape)
    x = layers.Rescaling(255.0)(inputs)  # [0, 1] -> [0, 255]
    x = tf.keras.applications.resnet50.preprocess_input(x)
    # training=False keeps the frozen backbone's BatchNorm layers in inference mode.
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(n_classes, activation='softmax')(x)
    return Model(inputs, outputs)

class SwinTransformerLayer(layers.Layer):
    """Keras layer wrapping a frozen TF-Hub model loaded from ``hub_url``.

    Args:
        hub_url: TF-Hub handle/URL of the model to load.
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, hub_url, **kwargs):
        super().__init__(**kwargs)
        self.hub_url = hub_url  # kept so get_config() can rebuild this layer
        self.hub_layer = hub.KerasLayer(hub_url, trainable=False)

    def call(self, inputs, training=None):
        # `training` is not forwarded; the hub layer is created with trainable=False.
        return self.hub_layer(inputs)

    def get_config(self):
        # Needed to save models containing this layer (e.g. ModelCheckpoint to
        # `.keras`): the base implementation cannot serialize the extra `hub_url`
        # constructor argument.
        config = super().get_config()
        config.update({'hub_url': self.hub_url})
        return config

def build_swin_transformer_model(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), n_classes=N_CLASSES):
    """Build a classifier head on top of a frozen TF-Hub Swin-Tiny model.

    Head: Dense(128, ReLU) -> Dropout(0.5) -> Dense(n_classes, softmax).

    NOTE(review): this handle has no `_fe` suffix, so it may be the ImageNet
    classification model rather than a feature extractor — confirm on TF-Hub.
    Also confirm which input normalization the hub model expects.

    Args:
        input_shape: (height, width, channels) of the input images.
        n_classes: number of output classes.

    Returns:
        An uncompiled functional ``Model``.
    """
    swin_transformer_url = "https://tfhub.dev/sayakpaul/swin_tiny_patch4_window7_224/1"
    inputs = Input(shape=input_shape)
    features = SwinTransformerLayer(swin_transformer_url)(inputs)
    hidden = Dense(128, activation='relu')(features)
    hidden = Dropout(0.5)(hidden)
    outputs = Dense(n_classes, activation='softmax')(hidden)
    return Model(inputs=inputs, outputs=outputs)

def _unet_double_conv(x, filters):
    """Apply two 3x3, same-padded, ReLU convolutions with `filters` output channels."""
    x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    return Conv2D(filters, (3, 3), activation='relu', padding='same')(x)


# Build U-Net model for classification
def build_unet_model(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), n_classes=N_CLASSES):
    """Build a U-Net-shaped encoder/decoder used as an image classifier.

    Encoder: four [double-conv, 2x2 max-pool] stages with 64/128/256/512 filters,
    then a 1024-filter double-conv bottleneck. Decoder: four [2x upsample,
    concatenate with the matching encoder output, double-conv] stages with
    512/256/128/64 filters. The final feature map is globally average-pooled
    into a Dense(128, ReLU) -> Dropout(0.5) -> Dense(n_classes, softmax) head.

    Input height and width must be divisible by 16 (four 2x2 poolings) so the
    skip concatenations line up; IMAGE_SIZE=224 satisfies this.

    Args:
        input_shape: (height, width, channels) of the input images.
        n_classes: number of output classes.

    Returns:
        An uncompiled functional ``Model``.
    """
    inputs = Input(input_shape)
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = _unet_double_conv(x, filters)
        skips.append(x)  # saved for the decoder's skip connection
        x = MaxPooling2D((2, 2))(x)
    x = _unet_double_conv(x, 1024)
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = UpSampling2D((2, 2))(x)
        x = concatenate([x, skip])
        x = _unet_double_conv(x, filters)
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(n_classes, activation='softmax')(x)
    return Model(inputs, outputs)

# 3. Training and Evaluation
def _to_rgb_input(data):
    """Map `convert_to_rgb` over a tf.data.Dataset; return any other input unchanged.

    Keras directory iterators have no `.map`. flow_from_directory already yields
    3-channel batches with its default color_mode='rgb'.
    """
    if isinstance(data, tf.data.Dataset):
        return data.map(lambda x, y: (convert_to_rgb(x), y))
    return data


def _predict_with_labels(model, data):
    """Run `model` over exactly one pass of `data`; return (probabilities, int labels).

    Predictions and labels come from the same batches, so they stay aligned even
    if the source shuffles. Indexable batch sources (Keras Sequence-style
    iterators such as DirectoryIterator) are read by index, because iterating a
    Keras directory iterator directly cycles forever. Anything else (e.g. a
    finite tf.data.Dataset) is iterated once.
    """
    if hasattr(data, '__getitem__') and hasattr(data, '__len__'):
        batches = (data[i] for i in range(len(data)))
    else:
        batches = iter(data)
    probs, labels = [], []
    for x, y in batches:
        probs.append(np.asarray(model.predict_on_batch(x)))
        labels.append(np.asarray(y))
    if not probs:
        raise ValueError("validation data yielded no batches")
    return np.concatenate(probs, axis=0), np.concatenate(labels, axis=0).astype(int)


def _macro_specificity(cm):
    """Mean one-vs-rest specificity TN / (TN + FP) over classes that have negatives."""
    cm = np.asarray(cm)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - tp - fp - fn
    denom = tn + fp
    valid = denom > 0
    if not valid.any():
        return float('nan')
    return float(np.mean(tn[valid] / denom[valid]))


def train_and_evaluate_model(model, model_name, train_generator, valid_generator,
                             epochs=25, patience=10):
    """Compile, train and evaluate one classifier.

    The best weights by val_loss are checkpointed to ``f'{model_name}.keras'``.
    Training stops after `patience` epochs without val_loss improvement.

    Args:
        model: uncompiled Keras model with an N_CLASSES-way softmax output.
        model_name: label used for the checkpoint file and in the results.
        train_generator: training batches of (images, sparse integer labels).
        valid_generator: validation batches in the same format.
        epochs: maximum number of training epochs.
        patience: EarlyStopping patience, in epochs.

    Returns:
        dict with 'model_name', 'history', 'training_time' (seconds), 'accuracy',
        'precision' (weighted), 'recall' (macro, i.e. mean per-class sensitivity),
        'f1' (weighted), 'specificity' (macro one-vs-rest), per-class 'fpr', 'tpr'
        and 'roc_auc' dicts, and the trained 'model'.
    """
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    checkpointer = ModelCheckpoint(f'{model_name}.keras', verbose=1, save_best_only=True)
    # restore_best_weights: evaluate the best-val_loss weights (what the checkpoint
    # keeps) rather than the last epoch's. Whether a full, non-stopped run also
    # restores depends on the Keras version.
    early_stopping = EarlyStopping(monitor='val_loss', patience=patience, restore_best_weights=True)

    start_time = time.time()
    history = model.fit(
        _to_rgb_input(train_generator),
        epochs=epochs, validation_data=_to_rgb_input(valid_generator),
        callbacks=[checkpointer, early_stopping], verbose=1
    )
    training_time = time.time() - start_time

    # Evaluate on validation set (one pass; labels taken from the same batches)
    val_preds, val_labels = _predict_with_labels(model, _to_rgb_input(valid_generator))
    val_preds_classes = np.argmax(val_preds, axis=1)

    # Metrics
    accuracy = accuracy_score(val_labels, val_preds_classes)
    precision = precision_score(val_labels, val_preds_classes, average='weighted')
    # Macro recall = mean per-class sensitivity (weighted recall always equals accuracy).
    recall = recall_score(val_labels, val_preds_classes, average='macro')
    f1 = f1_score(val_labels, val_preds_classes, average='weighted')
    # Fix the label set so the matrix is N_CLASSES x N_CLASSES even if a class is absent.
    cm = confusion_matrix(val_labels, val_preds_classes, labels=np.arange(N_CLASSES))
    specificity = _macro_specificity(cm)

    # ROC Curve (one-vs-rest per class)
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(N_CLASSES):
        fpr[i], tpr[i], _ = roc_curve(val_labels == i, val_preds[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    return {
        'model_name': model_name, 'history': history.history, 'training_time': training_time,
        'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'specificity': specificity,
        'fpr': fpr, 'tpr': tpr, 'roc_auc': roc_auc, 'model': model
    }

# 4. Detection and Prediction Logic
def _show_detections(image, output_image):
    """Show the grayscale scan next to the annotated (BGR) detection image."""
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title('Original CT Scan')
    plt.imshow(image, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.title('Detected Areas with Predictions')
    # OpenCV draws in BGR channel order; matplotlib displays RGB.
    plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.show()


def _report_final_prediction(predictions):
    """Print the per-region results and the overall verdict for one scan.

    Args:
        predictions: list of (class_name, confidence_percent) tuples.
    """
    for i, (pred_class, conf) in enumerate(predictions):
        print(f'Detected area {i + 1}: {pred_class} with confidence {conf}%')

    if not predictions:
        print("No significant areas detected.")
        print("The image type cannot be determined due to no detectable areas.")
        return

    reliable_predictions = [(p, c) for p, c in predictions if c >= MODEL_ACCURACY]
    if not reliable_predictions:
        print(f"No reliable predictions above model accuracy threshold ({MODEL_ACCURACY}%)")
        print("The image type cannot be confidently determined due to low confidence scores.")
        return

    final_pred_class, final_conf = max(reliable_predictions, key=lambda x: x[1])
    print(f'Final prediction result (considering {MODEL_ACCURACY}% model accuracy): '
          f'{final_pred_class} with confidence {final_conf}%')
    cancer_types = ['adenocarcinoma', 'large_cell_carcinoma', 'squamous_cell_carcinoma']
    # Display names contain spaces ("Squamous Cell Carcinoma"); normalize to the
    # underscore form used in cancer_types so every carcinoma label matches.
    if final_pred_class.lower().replace(' ', '_') in cancer_types:
        print(f"Cancer detected: {final_pred_class} with confidence {final_conf}% "
              f"(model accuracy: {MODEL_ACCURACY}%)")
    else:
        print(f"No cancer detected: {final_pred_class} with confidence {final_conf}% "
              f"(model accuracy: {MODEL_ACCURACY}%)")
    print(f"The image is classified as: {final_pred_class} with confidence {final_conf}%")


def detect_and_predict(image_path, classifier):
    """Find bright regions in a scan, classify each one, then display and report results.

    The image is read as grayscale, Gaussian-blurred (5x5) and binarized at
    intensity 50. Each external contour with area > 100 is boxed, cropped and
    classified.

    Args:
        image_path: path of the image file to analyse.
        classifier: Keras model whose ``predict`` on a ``preprocess_for_prediction``
            batch returns per-class scores ordered as in ``class_labels``.

    Returns:
        (predictions, output_image): a list of (class_name, confidence_percent)
        tuples, one per region, and the annotated BGR image. Returns (None, None)
        if the image cannot be loaded.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        print("Error: Unable to load the image. Please check the file format and path.")
        return None, None

    blurred = cv2.GaussianBlur(image, (5, 5), 0)
    _, binary = cv2.threshold(blurred, 50, 255, cv2.THRESH_BINARY)
    # [-2] selects the contours in both OpenCV 3 (3 return values) and 4 (2 values).
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    output_image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    predictions = []

    for contour in contours:
        if cv2.contourArea(contour) <= 100:
            continue  # skip small regions
        x, y, w, h = cv2.boundingRect(contour)
        cv2.rectangle(output_image, (x, y), (x + w, y + h), (0, 255, 0), 2)
        cropped_image = image[y:y + h, x:x + w]
        confidence_scores = classifier.predict(preprocess_for_prediction(cropped_image))[0]
        predicted_class_index = int(np.argmax(confidence_scores))
        predicted_class = class_labels[predicted_class_index]
        confidence = round(100 * float(confidence_scores[predicted_class_index]), 2)
        predictions.append((predicted_class, confidence))
        # Keep the label inside the image when the box touches the top edge.
        cv2.putText(output_image, f"{predicted_class} ({confidence}%)", (x, max(y - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    _show_detections(image, output_image)
    _report_final_prediction(predictions)
    return predictions, output_image

# Main Execution
train_generator, valid_generator = create_data_generator()
# (model, display name) pairs. Every model is instantiated up front, in this
# order, before any training starts.
models = [
    (builder(), label)
    for builder, label in (
        (build_cnn_model, 'CNN'),
        (build_efficientnet_model, 'EfficientNetB0'),
        (build_mobilenet_model, 'MobileNetV3'),
        (build_resnet_model, 'ResNet50'),
        (build_swin_transformer_model, 'SwinTransformer'),
        (build_unet_model, 'UNet'),
    )
]

# Train and evaluate each candidate in turn, collecting one metrics dict per model.
results = []
for model, name in models:
    results.append(train_and_evaluate_model(model, name, train_generator, valid_generator))

# Comparison of Models
print("\nModel Comparison:")
header = (f"{'Model':<20} {'Accuracy':<10} {'Sensitivity':<12} "
          f"{'Specificity':<12} {'Training Time (s)':<18}")
print(header)
for entry in results:
    # The "Sensitivity" column shows the 'recall' entry of each result.
    print(f"{entry['model_name']:<20} {entry['accuracy']:.4f}    "
          f"{entry['recall']:.4f}      "
          f"{entry['specificity']:.4f}      "
          f"{entry['training_time']:.2f}")

# Plot ROC Curves (one curve per model and class, all on a single figure)
plt.figure(figsize=(10, 8))
for entry in results:
    for class_idx in range(N_CLASSES):
        auc_value = entry['roc_auc'][class_idx]
        plt.plot(
            entry['fpr'][class_idx],
            entry['tpr'][class_idx],
            label=f"{entry['model_name']} Class {class_idx} (AUC = {auc_value:.2f})",
        )
# Dashed diagonal: the curve of a chance-level classifier.
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves')
plt.legend(loc='best')
plt.show()

# Find Best Model: highest validation accuracy (first one wins on ties)
best_model_result = max(results, key=lambda res: res['accuracy'])
best_model_name, best_model = best_model_result['model_name'], best_model_result['model']
print(f"\nBest Model: {best_model_name} with Accuracy: {best_model_result['accuracy']:.4f}")

# Detection and Prediction on CT Scan: run the best model on one sample test image
image_path = '/content/drive/MyDrive/Data/test/adenocarcinoma/000109 (2).png'
predictions, output_image = detect_and_predict(image_path=image_path, classifier=best_model)