# In [None]:
# Make Google Drive available inside the Colab VM; it is mounted at
# /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

# Required imports
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import json

# Load the run configuration (paths, data augmentation, model architecture,
# training and callback settings). JSON text is UTF-8 by specification, so
# decode it explicitly instead of relying on the platform's locale default.
# NOTE(review): 'config.json' is resolved against the current working
# directory — confirm that is where the notebook expects it.
with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

def _gpu_supports_mixed_precision(min_compute_capability=(7, 0)):
    """Return True if any visible GPU reports a CUDA compute capability >= the minimum.

    Replaces the deprecated ``tf.test.is_gpu_available``. The capability is
    read from ``tf.config.experimental.get_device_details``, which only
    reports ``compute_capability`` for NVIDIA (CUDA) GPUs. Devices without
    it are skipped, mirroring the original ``cuda_only=True``.

    Args:
        min_compute_capability: (major, minor) lower bound, inclusive.
    """
    for gpu in tf.config.list_physical_devices('GPU'):
        capability = tf.config.experimental.get_device_details(gpu).get('compute_capability')
        if capability is not None and tuple(capability) >= tuple(min_compute_capability):
            return True
    return False


# Enable mixed precision only when a GPU with compute capability >= 7.0
# (Tensor Cores) is present. Other hardware gains little or nothing from
# float16 compute.
if _gpu_supports_mixed_precision():
    from tensorflow.keras import mixed_precision
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)
else:
    print("Mixed precision not enabled: Compatible GPU not detected.")

# Unpack the filesystem locations from the "paths" section of the config.
base_dir, csv_path, img_dir, model_save_path = (
    config["paths"][key]
    for key in ("base_dir", "csv_path", "img_dir", "model_save_path")
)

# Patch metadata: one row per patch image. The code below reads the
# Patch_id, label and image_id columns.
data = pd.read_csv(csv_path)

# Function to extract patch number from Patch_id
def get_patch_folder(patch_id):
    parts = patch_id.split('_')
    patch_num = parts[-1].split('.')[0]  # Extract patch number
    return f'Patch_{patch_num}'

def get_full_path(row):
    """Return the expected on-disk path of the patch image described by ``row``.

    Rows with ``label == 1`` are looked up under ``mel_patches``; every other
    label goes to ``bkl_patches``. Inside that class folder the image sits
    in the sub-folder given by ``get_patch_folder(row['Patch_id'])``.

    Debugging behaviour:
    - A path that does not exist is printed but still returned.
    - Any exception raised while processing the row is printed, and the
      function returns ``None`` instead of propagating it.
    """
    try:
        folder = get_patch_folder(row['Patch_id'])
        class_dir = 'mel_patches' if row['label'] == 1 else 'bkl_patches'
        candidate = os.path.join(img_dir, class_dir, folder, row['Patch_id'])

        if not os.path.exists(candidate):
            print(f"Invalid path: {candidate}")

        return candidate
    except Exception as err:
        # Best effort: report the offending row and let the caller drop it.
        print(f"Error processing row: {row}")
        print(f"Error message: {str(err)}")
        return None

# Resolve each CSV row to its patch image path. get_full_path returns None
# for rows it could not process.
data['path'] = data.apply(get_full_path, axis=1)

# Drop rows whose path is missing or does not exist on disk. The isinstance
# guard is needed because os.path.exists(None) raises TypeError instead of
# returning False. Without it, a single unprocessable row would abort the run.
data['path_exists'] = data['path'].apply(
    lambda p: isinstance(p, str) and os.path.exists(p)
)
invalid_paths = data[~data['path_exists']]
print(f"Number of invalid paths: {len(invalid_paths)}")
if not invalid_paths.empty:
    print(invalid_paths.head())
# .copy() makes `data` an independent frame, so the column assignment below
# does not trigger pandas' SettingWithCopyWarning.
data = data[data['path_exists']].copy()

# Keep labels as strings: flow_from_dataframe with class_mode='binary'
# requires string class names in y_col.
data['label'] = data['label'].astype(str)

# Split by *image*, not by patch. Each image contributes several patches
# (the pooling step assumes 16). A patch-level split would put patches of
# the same lesion into both sets, leaking training images into validation
# and leaving the image-level evaluation with incomplete patch sets.
# NOTE(review): assumes the label is constant across an image's patches —
# the image-level evaluation also takes the first patch's label.
image_labels = data.groupby('image_id')['label'].first()
train_ids, val_ids = train_test_split(
    image_labels.index.to_numpy(),
    test_size=config["training"]["validation_split"],
    stratify=image_labels.to_numpy(),
    random_state=config["training"]["random_state"],
)

# Use only a fraction of the images for quick experiments. Sampling whole
# images keeps every sampled image's patch set intact.
train_ids = pd.Series(train_ids).sample(
    frac=config["training"]["fraction"], random_state=config["training"]["random_state"]
)
val_ids = pd.Series(val_ids).sample(
    frac=config["training"]["fraction"], random_state=config["training"]["random_state"]
)

train_data = data[data['image_id'].isin(train_ids)]
val_data = data[data['image_id'].isin(val_ids)]

# Spatial size fed to the network. It is derived from the configured model
# input shape so the generators cannot drift out of sync with the Input layer.
# NOTE(review): assumes a channels-last input_shape such as [64, 64, 3] —
# confirm against config.json.
target_size = tuple(config["model_architecture"]["input_shape"][:2])
aug_cfg = config["data_augmentation"]

# Training images get random augmentation in addition to rescaling.
train_datagen = ImageDataGenerator(
    rescale=aug_cfg["rescale"],
    rotation_range=aug_cfg["rotation_range"],
    width_shift_range=aug_cfg["width_shift_range"],
    height_shift_range=aug_cfg["height_shift_range"],
    shear_range=aug_cfg["shear_range"],
    zoom_range=aug_cfg["zoom_range"],
    horizontal_flip=aug_cfg["horizontal_flip"],
    vertical_flip=aug_cfg["vertical_flip"],
    brightness_range=aug_cfg["brightness_range"],
)

# Validation images are only rescaled, so metrics reflect unaltered inputs.
val_datagen = ImageDataGenerator(rescale=aug_cfg["rescale"])

# Stream (image, binary label) batches straight from the DataFrames'
# 'path' and 'label' columns.
train_generator = train_datagen.flow_from_dataframe(
    train_data,
    x_col='path',
    y_col='label',
    target_size=target_size,
    batch_size=config["training"]["batch_size"],
    class_mode='binary',
)

validation_generator = val_datagen.flow_from_dataframe(
    val_data,
    x_col='path',
    y_col='label',
    target_size=target_size,
    batch_size=config["training"]["batch_size"],
    class_mode='binary',
)

# Build the CNN from the "model_architecture" section of the config.
model = Sequential()
model.add(Input(shape=config["model_architecture"]["input_shape"]))
for layer in config["model_architecture"]["conv_layers"]:
    # Each convolutional stage is a Conv2D followed by max pooling.
    model.add(Conv2D(filters=layer["filters"], kernel_size=layer["kernel_size"], activation=layer["activation"]))
    model.add(MaxPooling2D(pool_size=config["model_architecture"]["pool_size"]))
model.add(Flatten())
for layer in config["model_architecture"]["dense_layers"]:
    model.add(Dense(units=layer["units"], activation=layer["activation"]))
model.add(Dropout(config["model_architecture"]["dropout_rate"]))
# Keep the output layer in float32. When the mixed_float16 policy is active,
# a float16 sigmoid output can be numerically unstable (TF mixed-precision
# guide). Under the default float32 policy this changes nothing.
model.add(Dense(1, activation='sigmoid', dtype='float32'))

# Compile with optimizer, loss and metrics taken from the config.
model.compile(optimizer=config["training"]["optimizer"],
              loss=config["training"]["loss"],
              metrics=config["training"]["metrics"])

# Checkpoint the best model seen so far, and stop early once validation
# stops improving.
checkpoint_cfg = config["callbacks"]["ModelCheckpoint"]
checkpoint_cb = ModelCheckpoint(
    model_save_path,
    save_best_only=checkpoint_cfg["save_best_only"],
    monitor=checkpoint_cfg["monitor"],
)
early_stop_cfg = config["callbacks"]["EarlyStopping"]
# NOTE(review): EarlyStopping keeps Keras' default monitor ('val_loss'),
# which may differ from the checkpoint's configured monitor — confirm intended.
early_stop_cb = EarlyStopping(
    patience=early_stop_cfg["patience"],
    restore_best_weights=early_stop_cfg["restore_best_weights"],
)
callbacks = [checkpoint_cb, early_stop_cb]

# Train on augmented patches, validating after each epoch.
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=config["training"]["epochs"],
    callbacks=callbacks,
    verbose=1,
)

def predict_full_image(image_id, model, data, target_size=(64, 64), rescale=1.0 / 255.0):
    """Predict a whole-image probability by mean-pooling its patch predictions.

    Every row of ``data`` whose ``image_id`` matches is loaded, resized and
    rescaled. All patches are then scored by ``model`` in a single batched
    ``predict`` call; ``Model.predict`` is not meant to be called once per
    sample in a loop.

    Args:
        image_id: Identifier matched against ``data['image_id']``.
        model: Trained Keras model with a single sigmoid output unit.
        data: DataFrame with ``image_id`` and ``path`` columns.
        target_size: (height, width) each patch is resized to. Should match
            the size used in training; the default keeps the original 64x64.
        rescale: Multiplicative pixel scaling. The default 1/255 keeps the
            original preprocessing.

    Returns:
        The mean predicted probability over the image's patches, or NaN if
        ``data`` contains no patches for ``image_id``.
    """
    patch_paths = data.loc[data['image_id'] == image_id, 'path'].tolist()
    if not patch_paths:
        # np.mean([]) would also give NaN, but with a RuntimeWarning. Return
        # it explicitly so callers can test the result with np.isnan.
        return np.nan

    batch = np.stack([
        img_to_array(load_img(path, target_size=target_size))
        for path in patch_paths
    ])
    batch = batch * rescale
    predictions = model.predict(batch, verbose=0)
    return np.mean(predictions[:, 0])

# Image-level evaluation on the validation set. Each image's patch
# predictions are mean-pooled by predict_full_image; the image label is
# taken from its first patch row.
val_predictions = []
val_labels = []

for image_id in tqdm(val_data['image_id'].unique(), desc='Evaluating images'):
    label = val_data.loc[val_data['image_id'] == image_id, 'label'].iloc[0]
    val_labels.append(int(label))
    final_prediction = predict_full_image(image_id, model, val_data)
    val_predictions.append(final_prediction)

# Threshold pooled probabilities at 0.5 to get hard 0/1 predictions.
# A NaN probability compares False and therefore maps to 0.
val_predictions_binary = [int(pred >= 0.5) for pred in val_predictions]

# Image-level accuracy: fraction of images whose hard prediction matches.
accuracy = (np.asarray(val_predictions_binary) == np.asarray(val_labels)).mean()
print('Validation accuracy:', accuracy)

# Example: pooled prediction for one specific image, looked up in the full
# dataset (training and validation patches alike).
image_id = 'ISIC_0028965'
final_prediction = predict_full_image(image_id, model, data)
print(f'Final prediction for image {image_id}:', final_prediction)