# Mind model

## 0. Imports

In [None]:
import os
import random
from typing import Any
from matplotlib import pyplot as plt
import numpy as np
from numpy import ndarray
import tensorflow as tf
from keras.callbacks import TensorBoard
from keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, concatenate, Dropout, BatchNormalization
from keras.metrics import BinaryCrossentropy
from keras.models import Model, load_model
from keras.optimizers import Adam
from keras.utils import Sequence
from PIL import Image as Img, ImageOps
from PIL.Image import Image

## 1. Set Up

### **Hardware Requirements**:
- At least 16 GB of RAM


### 1.1 Set constants

In [None]:
# Directory holding the images and their "_mask" counterparts.
DATASET_DIR = "../data/dataset/"
# Directory the trained model is saved to and loaded from.
MODEL_DIR = "../data/model/"
# File name of the saved model inside MODEL_DIR.
MODEL_NAME = "mind.keras"
# Intended GPU memory cap for TensorFlow (see section 1.2) — presumably in MB; TODO confirm.
VRAM = 2048

### 1.2 Configure hardware

Attempts to allocate only as much GPU memory as the runtime needs: TensorFlow starts out allocating very little memory and extends its GPU memory region as the program runs and more memory is needed. The GPU memory available to TensorFlow is capped by the constant `VRAM`.

In [None]:
# Configure every visible GPU: enable on-demand memory growth and cap the
# memory TensorFlow may use at VRAM.
GPU_LIST = tf.config.list_physical_devices("GPU")
if GPU_LIST:
    try:
        for gpu in GPU_LIST:
            tf.config.experimental.set_memory_growth(gpu, True)
            # Use the VRAM constant rather than a duplicated literal so the
            # limit only has to be changed in one place.
            # NOTE(review): TensorFlow's GPU guide presents memory growth and a
            # hard memory limit as alternative strategies — confirm that
            # combining them behaves as intended.
            tf.config.set_logical_device_configuration(
                gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=VRAM)]
            )
        logical_gpu_list = tf.config.list_logical_devices("GPU")
        print(f"Available GPUs: {len(GPU_LIST)} Physical GPUs, {len(logical_gpu_list)} Logical GPUs")
    except RuntimeError as e:
        # Best effort: report the problem and continue with TensorFlow's defaults.
        print(f"Error: {e}")

## 2. Import data

### 2.1 Load dataset

#### 2.1.1 Helper function to load data

In [None]:
def _load_image(path: str) -> Image:
    """Read the image at ``path`` fully into memory and release its file handle."""
    # Img.open() is lazy and keeps the file open; copying inside the context
    # manager forces the pixel data to be read before the file is closed.
    with Img.open(path) as image_file:
        return image_file.copy()


def load_data(directory: str) -> tuple[list[Image], list[Image]]:
    """Load every image in ``directory`` together with its segmentation mask.

    A file is treated as an image when its name does not contain ``"_mask"``;
    its mask is expected next to it as ``<stem>_mask<ext>``
    (``a.tif`` -> ``a_mask.tif``).

    Args:
        directory: Folder holding the image and mask files.

    Returns:
        ``(images, masks)``: two lists of equal length where ``masks[i]``
        belongs to ``images[i]``, ordered by image file name.

    Raises:
        FileNotFoundError: If ``directory`` or an expected mask file is missing.
    """
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # dataset order (and therefore any seeded shuffle of it) reproducible.
    image_names = sorted(name for name in os.listdir(directory) if "_mask" not in name)

    images: list[Image] = []
    masks: list[Image] = []
    for image_name in image_names:
        # Derive the mask name from the file name only, so directory names
        # containing the extension can never be rewritten by accident.
        stem, extension = os.path.splitext(image_name)
        mask_name = f"{stem}_mask{extension}"
        images.append(_load_image(os.path.join(directory, image_name)))
        masks.append(_load_image(os.path.join(directory, mask_name)))
    return images, masks

#### 2.1.2 Load data


In [None]:
# Load every image and its matching mask from the dataset directory.
images, masks = load_data(DATASET_DIR)

### 2.3 Split the data

#### 2.3.1 Helper function to split data

In [None]:
def split_data(images: list[Image], masks: list[Image], second_split_percentage: float) -> tuple[list[Image], list[Image], list[Image], list[Image]]:
    """Shuffle image/mask pairs reproducibly and cut them into two parts.

    Args:
        images: Input images.
        masks: Masks, where ``masks[i]`` belongs to ``images[i]``.
        second_split_percentage: Fraction (0 to 1) of the pairs that goes
            into the second part; the rest goes into the first part.

    Returns:
        ``(split_1_images, split_1_masks, split_2_images, split_2_masks)`` as
        lists. Either part may be empty.

    Raises:
        ValueError: If ``images`` and ``masks`` differ in length, or if
            ``second_split_percentage`` is outside ``[0, 1]``.
    """
    if len(images) != len(masks):
        raise ValueError("Length of images and masks must be the same.")
    if not 0.0 <= second_split_percentage <= 1.0:
        raise ValueError("second_split_percentage must be between 0 and 1.")

    # Keep every image attached to its mask while shuffling.
    data_pairs = list(zip(images, masks))

    # A private generator with a fixed seed yields the same permutation as
    # seeding the global one, without resetting `random` for other code.
    random.Random(42).shuffle(data_pairs)

    # Index of the first pair that belongs to the second part.
    split_num = int(len(data_pairs) * (1 - second_split_percentage))
    first_part = data_pairs[:split_num]
    second_part = data_pairs[split_num:]

    # Comprehensions (instead of unpacking zip(*part)) also work when a part
    # is empty, and return lists as the signature promises.
    return (
        [image for image, _ in first_part],
        [mask for _, mask in first_part],
        [image for image, _ in second_part],
        [mask for _, mask in second_part],
    )

#### 2.3.2 Split the data into Train, Validate and Test

In [None]:
# Hold out 20% of the pairs, then halve that hold-out: roughly 80% train,
# 10% validation and 10% test.
train_images, train_masks, tmp_images, tmp_masks = split_data(images, masks, 0.2)
validation_images, validation_masks, test_images, test_masks = split_data(tmp_images, tmp_masks, 0.5)


### 2.4 Augment data

Artificially increase the size of a dataset by applying various transformations to the existing data samples. The goal is to diversify the dataset and improve the generalization and robustness of a ML model. Data augmentation is commonly used in computer vision tasks, such as image classification and object detection, to create variations of the input data without collecting new samples.

The validation and test sets are used to estimate how the method works on real-world data, so they should only contain real-world data. Adding augmented data will not improve the accuracy of the validation. At best it says something about how well the method responds to the data augmentation, and at worst it ruins the validation results and their interpretability.

In [None]:
# TODO(eñaut): Figure out how to do this

## 4. Preprocess data

### 4.1 Clip mask values

Each pixel in the mask has a value of either 0 or 255. If the pixel value is 0, the same pixel position in the associated image is not part of a tumor; if it is 255, it is part of a tumor. To better represent that, the mask is converted to an array of boolean values by replacing anything greater than 1 with 1 and then changing the datatype of the elements of the ndarray to boolean.

#### 4.1.1 Helper function to clip masks

In [None]:
def clip(mask: Image) -> ndarray:
    """Turn a mask into a boolean array.

    Values above 1 are capped at 1 before the cast, so a 0 pixel becomes
    ``False`` and a 255 pixel becomes ``True``.
    """
    pixels = np.array(mask)
    capped = np.clip(pixels, None, 1)
    return capped.astype(bool)

#### 4.1.2 Clip masks

In [None]:
# Convert the training and validation masks to boolean arrays.
train_masks = list(map(clip, train_masks))
validation_masks = list(map(clip, validation_masks))

### 4.2 Normalize images

Scale all pixel values to the range [0, 1]. The primary reasons for normalizing images in CNNs are:

1. **Improved Convergence**: Normalizing images helps the optimization algorithm converge faster during training. Neural networks often perform better when the input data has zero mean and a small standard deviation. Normalization brings the pixel values to a common scale, preventing large input values from dominating the learning process. This can lead to faster convergence and more stable training.

2. **Gradient Descent Stability**: During backpropagation, the optimization algorithm adjusts the weights of the neural network based on the gradients of the loss with respect to the weights. Normalizing the input data helps ensure that the gradients are within a reasonable range. This can prevent issues like exploding or vanishing gradients, which can hinder the training process.

3. **Model Robustness**: Normalization can make the model more robust to variations in illumination and contrast. By bringing the pixel values to a standard scale, the network becomes less sensitive to changes in lighting conditions or differences in pixel intensity across different images.

4. **Generalization**: Normalization aids in generalization by making the model less dependent on the specific characteristics of the training data. It allows the model to learn patterns and features that are more transferable across different datasets.

5. **Compatibility with Activation Functions**: Some activation functions, such as the sigmoid and tanh functions, perform better when the input values are within a certain range. Normalizing the data helps ensure that the inputs to these functions fall within the regions where they exhibit desirable properties.

#### 4.2.1 Helper function to normalize images

In [None]:
def normalize(image: Image) -> ndarray:
    """Return the image as a float array with every value divided by 255."""
    pixel_values = np.array(image)
    return np.true_divide(pixel_values, 255.0)


# TODO: Ask about the different kinds of normalization, and why models prefer zero-centered inputs

# def normalize(image) -> ndarray:
#     # Convert PIL Image to a NumPy array
#     image_array = np.array(image)

#     # Normalize the image array (subtract mean and divide by standard deviation)
#     normalized_image_array = (image_array - np.mean(image_array)) / np.std(image_array)

#     # Convert the normalized array back to a PIL Image
#     normalized_image = Image.fromarray((normalized_image_array * 255).astype(np.uint8))

#     return normalized_image

#### 4.2.2 Normalize images

In [None]:
# Scale the training and validation images to the [0, 1] range.
train_images = list(map(normalize, train_images))
validation_images = list(map(normalize, validation_images))

## 5. Create Model

### 5.1 Create U-Net Model

TODO: Explain why U-Net was chosen.

#### 5.1.1 Helper function to create U-Net model

In [None]:
def _double_conv(tensor, filters: int):
    """Apply two 3x3 same-padded ReLU convolutions with ``filters`` channels."""
    for _ in range(2):
        tensor = Conv2D(filters, 3, activation="relu", padding="same", kernel_initializer="he_normal")(tensor)
    return tensor


def _up_block(tensor, skip, filters: int):
    """Upsample ``tensor`` 2x, concatenate the encoder ``skip`` and convolve twice."""
    upsampled = Conv2DTranspose(filters, 2, strides=(2, 2), padding="same")(tensor)
    merged = concatenate([skip, upsampled], axis=3)
    return _double_conv(merged, filters)


def unet_builder(input_size=(256, 256, 3)):
    """Build, compile and summarize a U-Net for binary segmentation.

    The encoder has four double-convolution stages (64 to 512 filters), each
    followed by 2x2 max pooling. A 1024-filter bottleneck follows, then four
    decoder stages that upsample and concatenate the matching encoder output.
    A final 1x1 sigmoid convolution produces one probability per pixel.

    Args:
        input_size: ``(height, width, channels)`` of the input images. Height
            and width should be multiples of 16 (four 2x2 poolings) so the
            upsampled tensors match the skip connections.

    Returns:
        The compiled Keras ``Model`` (Adam, binary cross-entropy loss, with
        accuracy and binary cross-entropy as metrics).
    """
    inputs = Input(input_size)

    # Contracting path
    conv1 = _double_conv(inputs, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = _double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = _double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = _double_conv(pool3, 512)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck
    conv5 = _double_conv(pool4, 1024)
    drop5 = Dropout(0.5)(conv5)

    # Expanding path: each stage reuses the encoder output of the same resolution.
    conv6 = _up_block(drop5, drop4, 512)
    conv7 = _up_block(conv6, conv3, 256)
    conv8 = _up_block(conv7, conv2, 128)
    conv9 = _up_block(conv8, conv1, 64)

    # Output head: reduce to 2 channels, then to a single sigmoid probability map.
    conv9 = Conv2D(2, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv9)
    conv10 = Conv2D(1, 1, activation="sigmoid")(conv9)

    model = Model(inputs, conv10)

    # `learning_rate` is the supported argument name; `lr` is deprecated.
    model.compile(optimizer=Adam(learning_rate=1e-4), loss="binary_crossentropy", metrics=["accuracy", BinaryCrossentropy()])

    model.summary()
    return model


#### 5.1.2 Create U-Net model

In [None]:
# Build and compile the U-Net with its default input size.
unet_model = unet_builder()

### 5.2 Train the model

#### 5.2.1 Set constants

In [None]:
# Number of samples in each batch served by the data generators.
BATCH_SIZE = 4
# Number of epochs passed to fit().
EPOCHS = 20

#### 5.2.2 Generator class

If the whole training set is handed to the model at once, all the images are loaded into the GPU's VRAM, which fills up very quickly. To avoid that, the following generator feeds the model batch by batch from RAM.

In [None]:
# Feed the model little by little, to not overload the GPU
class DataGenerator(Sequence):
    """Sequence that serves ``(x, y)`` mini-batches from in-memory lists."""

    def __init__(self, x_set: list, y_set: list, batch_size: int):
        """Store the samples and the batch size.

        Args:
            x_set: Input samples.
            y_set: Targets, where ``y_set[i]`` belongs to ``x_set[i]``.
            batch_size: Maximum number of samples per batch.

        Raises:
            ValueError: If the two sets differ in length or ``batch_size``
                is not positive.
        """
        # Let the base class initialize whatever state it needs.
        super().__init__()
        if len(x_set) != len(y_set):
            raise ValueError("x_set and y_set must have the same length.")
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self) -> int:
        """Return the number of batches; the last one may be smaller."""
        # Integer ceiling division.
        return (len(self.x) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, i) -> tuple[ndarray, ndarray]:
        """Return batch number ``i`` as a pair of arrays ``(batch_x, batch_y)``."""
        window = slice(i * self.batch_size, (i + 1) * self.batch_size)
        batch_x = np.asarray(self.x[window])
        batch_y = np.asarray(self.y[window])
        return batch_x, batch_y

#### 5.2.3 Set generators

In [None]:
# Wrap the preprocessed train/validation data so it is served in batches of BATCH_SIZE.
train_generator = DataGenerator(train_images, train_masks, BATCH_SIZE)
validation_generator = DataGenerator(validation_images, validation_masks, BATCH_SIZE)

#### 5.2.4 Start training

In [None]:
# Train for EPOCHS epochs with the validation generator as validation data;
# steps_per_epoch equals the number of training batches, so each epoch goes
# through every batch once.
unet_model.fit(train_generator, validation_data=validation_generator, epochs=EPOCHS, steps_per_epoch=len(train_generator))

### 5.3 Save the model

In [None]:
# Create the target directory first so a finished training run is not lost
# to a missing folder, then save the model under MODEL_DIR.
os.makedirs(MODEL_DIR, exist_ok=True)
unet_model.save(MODEL_DIR + MODEL_NAME, save_format="keras")

## 6. Test Model

### 6.1 Select model

Select the saved model, or the one already loaded.

In [None]:
# When True, (re)load the model stored on disk; when False, keep whatever
# model is already bound to `unet_model` in this session.
LOAD_MODEL = True
# Bare annotation: declares the name for type checkers without assigning it.
unet_model: Any
if LOAD_MODEL:
    unet_model = load_model(MODEL_DIR + MODEL_NAME)

### 6.2 Preprocess test data

In this part, the same preprocessing that was applied to the training and validation data is applied to the test data.

#### 6.2.1 Clip masks

In [None]:
# Convert the test masks to boolean arrays, as was done for train/validation.
test_masks = list(map(clip, test_masks))

#### 6.2.2 Normalize images

In [None]:
# Scale the test images to [0, 1], as was done for train/validation.
test_images = list(map(normalize, test_images))

### 6.3 Evaluate model

Evaluate the model on unseen data using the accuracy and binary cross-entropy metrics defined when the model was created.

In [None]:
# Evaluate on the whole test set, stacked into single arrays.
evaluation = unet_model.evaluate(np.asarray(test_images), np.asarray(test_masks))
# Index 0 is printed as the loss and index 1 as the accuracy; any further
# metric values returned by evaluate() are not printed.
print(f"Loss: {evaluation[0]}, Accuracy: {evaluation[1]}")

### 6.4 Visual evaluation

#### 6.4.1 Predict

In [None]:
# Stack the images into one array first: Keras reads a Python list of arrays
# as one array per model input, not as a batch of samples for a single input.
predictions = unet_model.predict(np.asarray(test_images))

#### 6.4.2 Plot images

In [None]:
# Show at most five test samples side by side (input, ground truth, prediction).
# Bounding the loop by the number of predictions avoids an IndexError when the
# test set holds fewer than five images.
for i in range(min(5, len(predictions))):
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 3, 1)
    plt.imshow(test_images[i])
    plt.title("Input image")

    plt.subplot(1, 3, 2)
    plt.imshow(test_masks[i])
    plt.title("Real mask")

    plt.subplot(1, 3, 3)
    # The model outputs one channel per pixel; drop the channel axis to plot it.
    plt.imshow(predictions[i, :, :, 0], cmap="gray")
    plt.title("Predicted mask")

    plt.show()
    # plt.savefig(f"fig_{i}.jpg")