In [33]:
import os
import pandas as pd
import numpy as np
import cv2
from sklearn.model_selection import KFold
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
from tensorflow.keras.models import Model
from sklearn.metrics import precision_score, recall_score, f1_score
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
%matplotlib inline
def focal_loss(gamma=2., alpha=0.25):
    def focal_loss_fixed(y_true, y_pred):
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
        y_true = K.cast(y_true, tf.float32)
        alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)
        p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (1 - y_pred)
        fl = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
        return K.mean(fl)
    return focal_loss_fixed

def load_images_and_masks(data_dir):
    images = []
    masks = []
    
    for patient_folder in os.listdir(data_dir):
        if not os.path.isdir(os.path.join(data_dir, patient_folder)):
            continue
        patient_path = os.path.join(data_dir, patient_folder)
        
        for file in os.listdir(patient_path):
            if file.endswith('.bmp'):
                # Load image
               
                image_path = os.path.join(patient_path, file)
                image = cv2.imread(image_path)
                images.append(image)
                
                # Create corresponding mask
                mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
                csv_file = file.replace('.bmp', '.csv')
                csv_path = os.path.join(patient_path, csv_file)
                
                if os.path.exists(csv_path):
                    try:
                        with open(csv_path, 'r') as f:
                            lines = f.readlines()
                            for line in lines:
                                values = line.strip().split(',')
                                # Ensure even number of values (pairs of coordinates)
                                if len(values) % 2 != 0:
                                    print(f"Warning: Skipping invalid row in {csv_path}: {line}")
                                    continue
                                # Convert pairs of coordinates to integers and set mask
                                for i in range(0, len(values), 2):
                                    try:
                                        col_idx = int(values[i])
                                        row_idx = int(values[i+1])
                                        mask[row_idx, col_idx] = 1
                                    except ValueError:
                                        print(f"Warning: Skipping invalid coordinate pair in {csv_path}: ({values[i]}, {values[i+1]})")
                    
                    except Exception as e:
                        print(f"Error reading {csv_path}: {e}")
                
                masks.append(mask)
    
    return np.array(images), np.array(masks)

def extract_patches(image, patch_size=64, stride=16):
    patches = []
    coordinates = []
    if len(image.shape) == 3:
        h, w, _ = image.shape
    else:
        h, w = image.shape
    
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            patch = image[y:y + patch_size, x:x + patch_size]
            patches.append(patch)
            coordinates.append((y, x))
    return np.array(patches), coordinates

def patch_generator(images, masks, patch_size=64, stride=16, batch_size=64):
    while True:
        batch_images = []
        batch_masks = []
        for image, mask in zip(images, masks):
            for y in range(0, image.shape[0] - patch_size + 1, stride):
                for x in range(0, image.shape[1] - patch_size + 1, stride):
                    image_patch = image[y:y + patch_size, x:x + patch_size]
                    mask_patch = mask[y:y + patch_size, x:x + patch_size]
                    batch_images.append(image_patch)
                    batch_masks.append(mask_patch)
                    if len(batch_images) == batch_size:
                        yield np.array(batch_images), np.expand_dims(np.array(batch_masks), axis=-1)
                        batch_images, batch_masks = [], []
        if batch_images:
            yield np.array(batch_images), np.expand_dims(np.array(batch_masks), axis=-1)

def reconstruct_mask(pred_patches, coordinates, image_shape, patch_size=64, stride=8):
    mask = np.zeros((image_shape[0], image_shape[1]))
    count = np.zeros((image_shape[0], image_shape[1]))

    for i, (y, x) in enumerate(coordinates):
        mask[y:y + patch_size, x:x + patch_size] += pred_patches[i].squeeze()
        count[y:y + patch_size, x:x + patch_size] += 1
    
    mask = mask / count
    return mask


def cross_validation(images, masks, n_splits=5, patch_size=16, stride=8, batch_size=64):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_results = []

    for train_index, val_index in kf.split(images):
        X_train, X_val = images[train_index], images[val_index]
        y_train, y_val = masks[train_index], masks[val_index]

        train_generator = patch_generator(X_train, y_train, patch_size, stride, batch_size)
        val_generator = patch_generator(X_val, y_val, patch_size, stride, batch_size)
        
        num_train_patches = sum(
            (image.shape[0] - patch_size + 1) * (image.shape[1] - patch_size + 1) // (stride * stride)
            for image in X_train
        )
        num_val_patches = sum(
            (image.shape[0] - patch_size + 1) * (image.shape[1] - patch_size + 1) // (stride * stride)
            for image in X_val
        )

        fold_results.append((train_generator, val_generator, num_train_patches // batch_size, num_val_patches // batch_size, X_val, y_val))
    
    return fold_results


def unet_model(input_size=(64, 64, 3)):
    inputs = Input(input_size)
    
    # Encoder
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, 3, activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, 3, activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, 3, activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, 3, activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, 3, activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, 3, activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, 3, activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, 3, activation='relu', padding='same')(conv5)

    # Decoder
    up6 = UpSampling2D(size=(2, 2))(conv5)
    up6 = Concatenate()([up6, conv4])
    conv6 = Conv2D(256, 3, activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, 3, activation='relu', padding='same')(conv6)

    up7 = UpSampling2D(size=(2, 2))(conv6)
    up7 = Concatenate()([up7, conv3])
    conv7 = Conv2D(128, 3, activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, 3, activation='relu', padding='same')(conv7)

    up8 = UpSampling2D(size=(2, 2))(conv7)
    up8 = Concatenate()([up8, conv2])
    conv8 = Conv2D(64, 3, activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, 3, activation='relu', padding='same')(conv8)

    up9 = UpSampling2D(size=(2, 2))(conv8)
    up9 = Concatenate()([up9, conv1])
    conv9 = Conv2D(64, 3, activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, 3, activation='relu', padding='same')(conv9)

    outputs = Conv2D(1, 1, activation='sigmoid')(conv9)

    model = Model(inputs=[inputs], outputs=[outputs])
    
    model.compile(optimizer='adam', loss=focal_loss(), metrics=['accuracy', 'Precision', 'Recall'])
    
    return model

def calculate_metrics(y_true, y_pred):
    # Flatten the arrays to calculate the metrics at the pixel level
    y_true = y_true.flatten()
    y_pred = y_pred.flatten()
    
    # Calculate precision, recall, and F1-score
    precision = precision_score(y_true, y_pred, average='binary')
    recall = recall_score(y_true, y_pred, average='binary')
    f1 = f1_score(y_true, y_pred, average='binary')
    
    return precision, recall, f1

In [34]:
data_dir = 'scanner_A'  # Your data directory
input_size = (64, 64, 3)  # Define the input size for the model

# Load images and masks
images, masks = load_images_and_masks(data_dir)
print(f"Loaded {len(images)} images and {len(masks)} masks.")

Loaded 35 images and 35 masks.


In [35]:

# Check class distribution
total_pixels = np.prod(masks.shape)
num_mitosis_pixels = np.sum(masks)
num_non_mitosis_pixels = total_pixels - num_mitosis_pixels
print(f"Mitosis pixels: {num_mitosis_pixels}, Non-mitosis pixels: {num_non_mitosis_pixels}")

# Perform 5-fold cross-validation
fold_results = cross_validation(images, masks, patch_size=64, stride=16, batch_size=32)

Mitosis pixels: 135376, Non-mitosis pixels: 151871584.0


In [37]:
for fold_num, (train_gen, val_gen, train_steps, val_steps, X_val, y_val) in enumerate(fold_results):
    print(f"Fold {fold_num + 1}")

    # Define and compile the model
    model = unet_model(input_size=input_size)

    # Train the model using the training generator and validate it on the validation data
    model.fit(train_gen,
                validation_data=val_gen,
                epochs=5,  # Adjust the number of epochs as needed
                steps_per_epoch=train_steps,
                validation_steps=val_steps)

    # Predict on validation patches
    val_patch_preds = model.predict(val_gen, steps=val_steps)
    val_patch_preds = (val_patch_preds > 0.5).astype(np.uint8)  # Convert predictions to binary masks

    # Reconstruct the full validation mask
    val_mask_pred = reconstruct_mask(val_patch_preds, [c for _, c in extract_patches(X_val[0], patch_size=64, stride=16)], X_val[0].shape, patch_size=64, stride=16)
    
    precision, recall, f1 = calculate_metrics(y_val.flatten(), val_mask_pred.flatten())
    print(f"Fold {fold_num + 1} - Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")

Fold 1
Epoch 1/5
[1m13959/13959[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11142s[0m 798ms/step - Precision: 0.0000e+00 - Recall: 0.0000e+00 - accuracy: 0.9985 - loss: 0.0053 - val_Precision: 0.0000e+00 - val_Recall: 0.0000e+00 - val_accuracy: 0.9988 - val_loss: 0.0050
Epoch 2/5
[1m13959/13959[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11112s[0m 796ms/step - Precision: 0.0000e+00 - Recall: 0.0000e+00 - accuracy: 0.9989 - loss: 0.0043 - val_Precision: 0.0000e+00 - val_Recall: 0.0000e+00 - val_accuracy: 0.9988 - val_loss: 0.0050
Epoch 3/5
[1m 3350/13959[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m18:10:17[0m 6s/step - Precision: 0.0000e+00 - Recall: 0.0000e+00 - accuracy: 0.9985 - loss: 0.0061

KeyboardInterrupt: 

In [None]:
print("Training and evaluation completed.")
model.save('model.h5')