In [3]:
!pip install scikit-learn

Collecting scikit-learn
  Using cached scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl.metadata (11 kB)
Collecting scipy>=1.5.0 (from scikit-learn)
  Using cached scipy-1.10.1-cp38-cp38-macosx_12_0_arm64.whl.metadata (53 kB)
Collecting joblib>=1.1.1 (from scikit-learn)
  Using cached joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting threadpoolctl>=2.0.0 (from scikit-learn)
  Using cached threadpoolctl-3.5.0-py3-none-any.whl.metadata (13 kB)
Using cached scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl (9.4 MB)
Using cached joblib-1.4.2-py3-none-any.whl (301 kB)
Using cached scipy-1.10.1-cp38-cp38-macosx_12_0_arm64.whl (28.8 MB)
Using cached threadpoolctl-3.5.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn
Successfully installed joblib-1.4.2 scikit-learn-1.3.2 scipy-1.10.1 threadpoolctl-3.5.0


In [1]:
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras import mixed_precision

# Enable mixed precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Define the paths
base_dir = '/Users/andreshofmann/Desktop/Studies/Uol/7t/FP/stage_2'
csv_path = os.path.join(base_dir, 'csv_files', 'patches_with_labels.csv')
img_dir = os.path.join(base_dir, 'Images', 'img_patches')
model_save_path = os.path.join(base_dir, 'models', 'skin_lesion_model.keras')

# Load the CSV file
data = pd.read_csv(csv_path)

# Function to extract patch number from Patch_id
def get_patch_folder(patch_id):
    parts = patch_id.split('_')
    patch_num = parts[-1].split('.')[0]  # Extract patch number
    return f'Patch_{patch_num}'

# Add a full path to each patch
def get_full_path(row):
    try:
        patch_folder = get_patch_folder(row['Patch_id'])
        if row['label'] == 1:
            return os.path.join(img_dir, 'mel_patches', patch_folder, row['Patch_id'])
        else:
            return os.path.join(img_dir, 'bkl_patches', patch_folder, row['Patch_id'])
    except Exception as e:
        print(f"Error processing row: {row}")
        print(f"Error message: {str(e)}")
        return None

data['path'] = data.apply(get_full_path, axis=1)

# Check for invalid paths and remove them
invalid_paths = data[~data['path'].apply(os.path.exists)]
print(f"Number of invalid paths: {len(invalid_paths)}")
if not invalid_paths.empty:
    print(invalid_paths.head())
    data = data[data['path'].apply(os.path.exists)]

# Convert label column to string
data['label'] = data['label'].astype(str)

# Split the data into training and validation sets
train_data, val_data = train_test_split(data, test_size=0.2, stratify=data['label'], random_state=42)

# Use only 10% of the data for testing
train_data = train_data.sample(frac=0.1, random_state=42)
val_data = val_data.sample(frac=0.1, random_state=42)

# Image data generator for augmentation and preprocessing
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=True
)

val_datagen = ImageDataGenerator(rescale=1./255)

# Create training and validation generators
train_generator = train_datagen.flow_from_dataframe(
    train_data,
    x_col='path',
    y_col='label',
    target_size=(64, 64),
    batch_size=16,
    class_mode='binary'
)

validation_generator = val_datagen.flow_from_dataframe(
    val_data,
    x_col='path',
    y_col='label',
    target_size=(64, 64),
    batch_size=16,
    class_mode='binary'
)

# Define the CNN model
model = Sequential([
    Input(shape=(64, 64, 3)),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(1, activation='sigmoid')
])

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Define callbacks
callbacks = [
    ModelCheckpoint(model_save_path, save_best_only=True, monitor='val_accuracy'),
    EarlyStopping(patience=5, restore_best_weights=True)
]

# Train the model
history = model.fit(
    train_generator, 
    epochs=50,
    validation_data=validation_generator,
    callbacks=callbacks,
    verbose=1
)

# Define a function to pool predictions from the 16 patches to produce a final prediction for the full image
def predict_full_image(image_id, model, data):
    patches = data[data['image_id'] == image_id]['path'].values
    predictions = []

    for patch in patches:
        img = load_img(patch, target_size=(64, 64))
        img_array = img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0) / 255.0
        prediction = model.predict(img_array, verbose=0)
        predictions.append(prediction[0][0])

    return np.mean(predictions)

# Evaluate the model on the validation set
val_predictions = []
val_labels = []

for image_id in tqdm(val_data['image_id'].unique(), desc='Evaluating images'):
    label = val_data[val_data['image_id'] == image_id]['label'].values[0]
    val_labels.append(int(label))
    final_prediction = predict_full_image(image_id, model, val_data)
    val_predictions.append(final_prediction)

# Convert predictions to binary class (0 or 1)
val_predictions_binary = [1 if pred >= 0.5 else 0 for pred in val_predictions]

# Calculate accuracy
accuracy = np.mean(np.array(val_predictions_binary) == np.array(val_labels))
print('Validation accuracy:', accuracy)

# Example usage for a single image prediction
image_id = 'ISIC_0028965'
final_prediction = predict_full_image(image_id, model, data)
print(f'Final prediction for image {image_id}:', final_prediction)

Number of invalid paths: 0
Found 2831 validated image filenames belonging to 2 classes.
Found 708 validated image filenames belonging to 2 classes.
Epoch 1/50
