# Analiza czerniaka za pomocą fraktalnej sieci neuronowej

In [None]:
import datetime
import os

import numpy as np
from scipy.ndimage import measurements

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_hub as hub
from tensorflow.keras import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras.layers import Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator

Sprawdzamy, czy TensorFlow widzi kartę graficzną (GPU).

In [None]:
# Show the GPU devices visible to TensorFlow (an empty list means none is visible).
tf.config.list_physical_devices(device_type='GPU')

Zapisujemy konfigurację do zmiennych.

In [None]:
# Side length (in pixels) of the square images fed to the models, and the
# number of images per training/validation batch.
IMAGE_SIZE, BATCH_SIZE = 224, 32

Tworzymy callbacki do zbierania danych o wydajności modelu do Tensorboard, zapisywania modelu w trakcie jego trenowania i zatrzymania trenowania modelu, jeśli nie ma poprawy w wynikach w ciągu 10 epok. 

In [None]:
# Each run gets its own TensorBoard directory: strftime only embeds the
# current time when the format string contains directives such as %Y or %H,
# so they are spelled out after the model name.
log_dir = '../logs/fit/' + datetime.datetime.now().strftime('fractal_net-%Y%m%d-%H%M%S')
# histogram_freq=1 asks the callback to record histograms every epoch.
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

In [None]:
# Location of the model checkpoint and the directory that contains it.
checkpoint_path = 'checkpoints/fractal_net.ckpt'
checkpoint_dir = os.path.split(checkpoint_path)[0]

# Save the full model (not only the weights) at the end of an epoch, and only
# when the monitored validation loss is the best seen so far.
checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    save_weights_only=False,
    save_freq='epoch',
    verbose=1,
)

In [None]:
# Stop training when the validation loss has not improved by at least 0.01
# for 10 consecutive epochs, then roll back to the best weights seen.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.01,
    patience=10,
    restore_best_weights=True,
)

Definiujemy warstwę, która będzie tworzyła obraz z fraktalnych cech podanych jej obrazów.

In [None]:
class Fractal2D(tf.keras.layers.Layer):
    """Turn a batch of images into a batch of fractal-descriptor images.

    Every image is tiled into ``k x k`` patches for each odd ``k`` in
    ``kernel_size_range``.  Inside a patch a pixel is set to 1 when its colour
    distance to the patch's central pixel is at most ``k``.  From these binary
    patches five descriptors are computed per image (average cluster count,
    share of percolating patches, average largest-cluster area, lacunarity
    and fractal dimension), once for each of three distance functions
    (chessboard, Euclidean, Manhattan).

    With the default range (20 kernel sizes) this gives 100 values per
    distance function, laid out as a 10x10 map; the three maps become the
    channels of one image, which is resized to 224x224.
    """

    # Share of set pixels from which a patch is counted as percolating.
    PERCOLATION_THRESHOLD = 0.59275
    # Number of descriptors produced for every kernel size.
    METRICS_PER_KERNEL = 5

    def __init__(self, **kwargs):
        """Create the layer.

        Keyword arguments are forwarded to ``tf.keras.layers.Layer`` so that
        the layer can be rebuilt from its config (Keras passes ``name``,
        ``trainable`` and ``dtype`` back to the constructor).  The name
        defaults to ``'fractal_layer'``.
        """
        kwargs.setdefault('name', 'fractal_layer')
        super().__init__(**kwargs)
        # Inclusive range of kernel sizes; ``call`` walks it in steps of 2.
        self.kernel_size_range = (3, 41)

    def chessboard_distance(self, patched_inputs, central_pixels, kernel_size):
        """Mark pixels whose largest per-channel gap to the centre is <= kernel_size.

        Returns an int32 tensor of zeros and ones with the channel axis
        (axis 3) reduced away.
        """
        channel_gap = tf.math.abs(tf.math.subtract(patched_inputs, central_pixels))
        distance = tf.math.reduce_max(channel_gap, axis=3)
        return tf.cast(tf.math.less_equal(distance, kernel_size), dtype=tf.int32)

    def euclidean_distance(self, patched_inputs, central_pixels, kernel_size):
        """Mark pixels whose Euclidean colour distance to the centre is <= kernel_size.

        Returns an int32 tensor of zeros and ones with the channel axis
        (axis 3) reduced away.
        """
        squared_gap = tf.math.square(tf.math.subtract(patched_inputs, central_pixels))
        distance = tf.math.sqrt(tf.math.reduce_sum(squared_gap, axis=3))
        return tf.cast(tf.math.less_equal(distance, kernel_size), dtype=tf.int32)

    def manhattan_distance(self, patched_inputs, central_pixels, kernel_size):
        """Mark pixels whose summed per-channel gap to the centre is <= kernel_size.

        Returns an int32 tensor of zeros and ones with the channel axis
        (axis 3) reduced away.
        """
        channel_gap = tf.math.abs(tf.math.subtract(patched_inputs, central_pixels))
        distance = tf.math.reduce_sum(channel_gap, axis=3)
        return tf.cast(tf.math.less_equal(distance, kernel_size), dtype=tf.int32)

    def extract_binary_patches(self, inputs, kernel_size, distance_function):
        """Tile ``inputs`` into patches and binarise each one around its centre.

        Args:
            inputs: 4-D image batch (batch, height, width, channels).
            kernel_size: side length of the square, non-overlapping patches.
            distance_function: one of the ``*_distance`` methods.

        Returns:
            int32 tensor of shape (batch, patches_per_image, kernel_size,
            kernel_size) holding zeros and ones.
        """
        # Use the static channel count when it is known; fall back to RGB.
        channels = inputs.shape[-1] or 3

        patched_inputs = tf.image.extract_patches(
            inputs,
            sizes=(1, kernel_size, kernel_size, 1),
            strides=(1, kernel_size, kernel_size, 1),
            rates=(1, 1, 1, 1),
            padding='SAME')
        _, rows, cols, _ = patched_inputs.shape
        # One row per patch: (batch * rows * cols, k, k, channels).
        patched_inputs = tf.reshape(
            patched_inputs, shape=(-1, kernel_size, kernel_size, channels))

        # Cropping to 1x1 keeps only the central pixel of every patch.
        central_pixels = tf.image.resize_with_crop_or_pad(patched_inputs, 1, 1)

        binary_patches = distance_function(patched_inputs, central_pixels, kernel_size)
        return tf.reshape(binary_patches,
                          shape=(-1, rows * cols, kernel_size, kernel_size))

    def calculate_probability_matrices(self, binary_inputs, kernel_size):
        """Return, per image, the distribution of the number of set pixels per patch.

        Entry ``[b, m]`` is the share of patches of image ``b`` that contain
        exactly ``m`` set pixels.  The histogram always has
        ``kernel_size ** 2 + 1`` bins (masses 0 .. kernel_size ** 2), so its
        length does not depend on the data.  The result is float64.
        """
        mass_per_patch = tf.math.reduce_sum(binary_inputs, axis=[2, 3])
        bin_count = kernel_size ** 2 + 1
        histogram = tf.math.bincount(mass_per_patch,
                                     minlength=bin_count,
                                     maxlength=bin_count,
                                     axis=-1)
        patch_number = tf.cast(tf.shape(mass_per_patch)[1], tf.float64)
        return tf.cast(histogram, tf.float64) / patch_number

    def calculate_fractal_dimensions(self, probability_matrices):
        """Return ``sum_m P(m) / m`` over the masses ``m >= 1`` for every image.

        Column ``m`` of ``probability_matrices`` is the probability of mass
        ``m``; column 0 (empty patches) has no defined contribution and is
        skipped.
        """
        bin_count = tf.shape(probability_matrices)[-1]
        masses = tf.cast(tf.range(1, bin_count), probability_matrices.dtype)
        return tf.math.reduce_sum(probability_matrices[:, 1:] / masses, axis=-1)

    def calculate_lacunarity(self, probability_matrices):
        """Return ``(M2 - M**2) / M**2`` for every image.

        ``M = sum_m m * P(m)`` and ``M2 = sum_m m**2 * P(m)`` are the first
        and second moments of the mass distribution.
        """
        bin_count = tf.shape(probability_matrices)[-1]
        masses = tf.cast(tf.range(bin_count), probability_matrices.dtype)
        first_moment = tf.math.reduce_sum(probability_matrices * masses, axis=-1)
        second_moment = tf.math.reduce_sum(
            probability_matrices * tf.math.square(masses), axis=-1)
        squared_mean = tf.math.square(first_moment)
        # NOTE: patches from extract_binary_patches always contain their own
        # central pixel, so the first moment is at least 1 there.
        return (second_moment - squared_mean) / squared_mean

    def average_cluster_percolation(self, binary_inputs, kernel_size):
        """Return, per image, the share of patches that percolate.

        A patch percolates when its fraction of set pixels reaches
        ``PERCOLATION_THRESHOLD``.  The result is float64.
        """
        mass_per_patch = tf.math.reduce_sum(binary_inputs, axis=[2, 3])
        occupancy = tf.cast(mass_per_patch, tf.float64) / (kernel_size ** 2)
        percolates = tf.math.greater_equal(occupancy, self.PERCOLATION_THRESHOLD)
        # Average in floating point: reduce_mean on integers truncates.
        return tf.math.reduce_mean(tf.cast(percolates, tf.float64), axis=1)

    @staticmethod
    def _map_patches(binary_inputs, patch_fn):
        """Apply ``patch_fn`` to every 2-D patch, giving one value per patch."""
        return tf.map_fn(
            lambda image_patches: tf.map_fn(patch_fn, image_patches),
            binary_inputs)

    def average_cluster_number(self, binary_inputs):
        """Return, per image, the mean number of connected clusters per patch.

        Each patch is labelled on its own, so the largest label in a patch
        equals its number of clusters.  The result is float64.
        """
        cluster_counts = self._map_patches(
            binary_inputs,
            lambda patch: tf.math.reduce_max(tfa.image.connected_components(patch)))
        # Average in floating point: reduce_mean on integers truncates.
        return tf.math.reduce_mean(tf.cast(cluster_counts, tf.float64), axis=1)

    def average_cluster_max_area(self, binary_inputs):
        """Return, per image, the mean area of the largest cluster of each patch.

        The area is the pixel count of the most frequent non-zero label.  The
        result is float64.
        """
        def largest_cluster_area(patch):
            labels = tf.reshape(tfa.image.connected_components(patch), shape=(-1,))
            # Label 0 marks unset pixels; it is background, not a cluster.
            foreground = tf.boolean_mask(labels, tf.math.not_equal(labels, 0))
            _, _, counts = tf.unique_with_counts(foreground)
            # Append a zero so a patch without any cluster yields area 0.
            padded_counts = tf.concat(
                [counts, tf.zeros((1,), dtype=counts.dtype)], axis=0)
            return tf.math.reduce_max(padded_counts)

        largest_areas = self._map_patches(binary_inputs, largest_cluster_area)
        # Average in floating point: reduce_mean on integers truncates.
        return tf.math.reduce_mean(tf.cast(largest_areas, tf.float64), axis=1)

    def calculate_components(self, inputs, kernel_size, distance_function):
        """Compute the five descriptors of ``inputs`` for one kernel size.

        Returns:
            float64 tensor of shape (5, batch) in the order: average cluster
            number, average percolation, average largest-cluster area,
            lacunarity, fractal dimension.
        """
        binary_patches = self.extract_binary_patches(inputs, kernel_size, distance_function)

        probability_matrices = self.calculate_probability_matrices(binary_patches, kernel_size)
        metrics = (
            self.average_cluster_number(binary_patches),
            self.average_cluster_percolation(binary_patches, kernel_size),
            self.average_cluster_max_area(binary_patches),
            self.calculate_lacunarity(probability_matrices),
            self.calculate_fractal_dimensions(probability_matrices),
        )
        # Cast explicitly so that all rows share one dtype before stacking.
        return tf.stack([tf.cast(metric, tf.float64) for metric in metrics], axis=0)

    def rearrage_metrics(self, components):
        """Regroup each row from kernel-major to metric-major order.

        ``components`` has rows laid out as ``[k0m0, k0m1, ..., k0m4, k1m0,
        ...]``; the result has the same shape with rows laid out as
        ``[k0m0, k1m0, ..., k0m1, k1m1, ...]``.
        """
        length = components.shape[-1]
        per_kernel = tf.reshape(
            components,
            shape=(-1, length // self.METRICS_PER_KERNEL, self.METRICS_PER_KERNEL))
        per_metric = tf.transpose(per_kernel, perm=(0, 2, 1))
        return tf.reshape(per_metric, shape=(-1, length))

    def call(self, inputs):
        """Build the 224x224x3 fractal-descriptor image for every input image."""
        kernel_size_start, kernel_size_end = self.kernel_size_range
        distance_functions = (self.chessboard_distance,
                              self.euclidean_distance,
                              self.manhattan_distance)

        feature_maps = []
        for distance_function in distance_functions:
            # One (batch, 5) tensor per kernel size.
            per_kernel = [
                tf.transpose(
                    self.calculate_components(inputs,
                                              kernel_size,
                                              distance_function=distance_function))
                for kernel_size in range(kernel_size_start, kernel_size_end + 1, 2)
            ]
            grouped = self.rearrage_metrics(tf.concat(per_kernel, axis=1))
            # 20 kernel sizes x 5 metrics = 100 values -> one 10x10 map.
            feature_maps.append(tf.reshape(grouped, shape=(-1, 10, 10)))

        # The three distance functions become the three image channels.
        outputs = tf.stack(feature_maps, axis=3)

        return tf.image.resize(outputs, size=(224, 224))

Ładujemy dane do trenowania i walidacji.

In [None]:
# One generator serves both subsets: 20% of the images are held out for
# validation, and rescale=1.0 leaves the pixel values unchanged.
datagen = ImageDataGenerator(rescale=1.0, validation_split=0.2)

# Options shared by the training and the validation iterator.
_flow_options = {
    'directory': '/small-data',
    'target_size': (IMAGE_SIZE, IMAGE_SIZE),
    'batch_size': BATCH_SIZE,
    'class_mode': 'categorical',
}
training_set = datagen.flow_from_directory(subset='training', **_flow_options)
validation_set = datagen.flow_from_directory(subset='validation', **_flow_options)

Zapisujemy liczbę rozpoznawalnych diagnoz.

In [None]:
# The iterator maps every class name to an index; the number of entries is
# the number of diagnoses the models have to tell apart.
diagnosis_to_index = training_set.class_indices
DIAGNOSIS_NUMBER = len(diagnosis_to_index)

Tworzymy model, który wykorzystuje wcześniej zdefiniowaną warstwę.

In [None]:
# Pipeline: raw image -> fractal-descriptor image -> frozen MobileNetV2
# feature extractor -> softmax over the diagnoses.
fractal_model = tf.keras.Sequential()
fractal_model.add(tf.keras.layers.InputLayer(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)))
fractal_model.add(Fractal2D())
fractal_model.add(
    hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4",
                   output_shape=[1280],
                   trainable=False))
fractal_model.add(tf.keras.layers.Dense(DIAGNOSIS_NUMBER, activation='softmax'))

In [None]:
# Categorical cross-entropy matches the one-hot labels produced by
# class_mode='categorical'; accuracy is reported alongside the loss.
fractal_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Train with the callbacks prepared earlier in the notebook so that the run
# is logged to TensorBoard, the best model is checkpointed, and training
# stops early once the validation loss stops improving.
fractal_model.fit(
    training_set,
    validation_data=validation_set,
    epochs=2,
    callbacks=[tensorboard_callback, checkpoint_callback, early_stop_callback],
)

Tworzymy model, który pracuje bezpośrednio z obrazkami.

In [None]:
# Baseline: the same frozen MobileNetV2 feature extractor applied directly to
# the images, followed by a softmax over the diagnoses.
original_model = Sequential([
    # The data generator yields pixel values in 0..255 (rescale=1.0), while
    # the TF Hub module expects colour values in [0, 1].
    tf.keras.layers.Lambda(lambda images: images / 255.0, name='rescale_to_unit_range'),
    hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4", output_shape=[1280],
                   trainable=False),
    Dense(DIAGNOSIS_NUMBER, activation='softmax')
])
# Build for batches of IMAGE_SIZE x IMAGE_SIZE RGB images, matching the generators.
original_model.build([None, IMAGE_SIZE, IMAGE_SIZE, 3])

In [None]:
# Same optimiser, loss and metric as the fractal model, so the two runs are
# directly comparable.
original_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Train the baseline on the same data and for the same number of epochs.
original_model.fit(
    x=training_set,
    epochs=2,
    validation_data=validation_set,
)