Load the fractal-aggregate dataset

In [None]:
import os
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras import layers, models

def _read_correlation_dimension(folder_path):
    """Return the first 'Correlation dimension' value in the folder's correlation_dimension.txt.

    Raises:
    OSError: if the file cannot be opened.
    ValueError: if no matching line exists or its value is not a float.
    """
    path = os.path.join(folder_path, 'correlation_dimension.txt')
    with open(path, 'r') as f:
        for line in f:
            if "Correlation dimension" in line:
                # float() tolerates the surrounding whitespace / trailing newline.
                return float(line.split(':')[-1])
    raise ValueError(f"no 'Correlation dimension' line in {path}")


def _load_projections(folder_path, target_shape):
    """Load the xy/xz/yz projection BMPs, convert to grayscale, resize, scale to [0, 1] and stack as channels."""
    images = []
    for name in ('xy_projection.bmp', 'xz_projection.bmp', 'yz_projection.bmp'):
        # The context manager releases the underlying file handle promptly;
        # convert() returns an independent image, so it outlives the `with`.
        with Image.open(os.path.join(folder_path, name)) as img:
            img_resized = img.convert('L').resize(target_shape)
        images.append(np.asarray(img_resized) / 255.0)  # Normalize to [0, 1]
    # Stack the three projections as channels -> (H, W, 3)
    return np.stack(images, axis=-1)


def load_dataset(base_path, target_shape=(128, 128)):
    """
    Load the dataset from the specified path, preprocess images, and extract fractal dimensions.

    Every sub-folder of ``base_path`` is treated as one sample. A folder
    contributes a sample only if all three projection images AND its
    correlation dimension can be read. Otherwise it is skipped with a printed
    message, so the returned inputs and targets always stay index-aligned.
    If the label file has several "Correlation dimension" lines, only the
    first one is used.

    Parameters:
    base_path (str): Path to the dataset folders.
    target_shape (tuple): Size passed to ``PIL.Image.resize``, i.e.
        ``(width, height)``; each resulting array is ``(height, width)``.

    Returns:
    tuple: (input_tensors, target_values), where ``input_tensors`` stacks the
        per-sample ``(height, width, 3)`` arrays with values in [0, 1] and
        ``target_values`` holds one float per sample.
    """
    input_tensors = []
    target_values = []

    for folder in sorted(os.listdir(base_path)):
        folder_path = os.path.join(base_path, folder)
        if not os.path.isdir(folder_path):
            continue

        try:
            # Parse the label first: it is cheap, and a missing/garbled label
            # should not cost three image decodes.
            target = _read_correlation_dimension(folder_path)
            stacked_images = _load_projections(folder_path, target_shape)
        except (OSError, ValueError) as e:
            # OSError covers missing files and unreadable images
            # (PIL's UnidentifiedImageError subclasses OSError).
            print(f"Error processing folder {folder_path}: {e}")
            continue

        # Append only after both parts succeeded so X and y never drift apart.
        input_tensors.append(stacked_images)
        target_values.append(target)

    return np.array(input_tensors), np.array(target_values)

# Example usage: load every sample folder found under the dataset root
# and report how many samples were successfully read.
base_path = "/Volumes/PortableSSD/flocfractal/data1blurry/"
X, y = load_dataset(base_path)
print(f"Loaded dataset with {len(X)} samples.")

Define the CNN model

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

def build_cnn_model(input_shape):
    """
    Builds a CNN model for predicting the 3D fractal dimension.

    Three Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with 32, 64 and 128
    filters feed a 128-unit ReLU dense layer. A single linear output unit
    then produces one scalar regression value per sample.

    Parameters:
    input_shape (tuple): Shape of one input sample, excluding the batch axis
        (e.g., (128, 128, 3)).

    Returns:
    tf.keras.Model: Model compiled with the Adam optimizer, MSE loss and an
        MAE metric.
    """
    model = models.Sequential([
        # Declare the input explicitly. Passing `input_shape=` to the first
        # layer is deprecated in Keras 3 and emits a warning; an Input entry
        # works in both tf.keras 2.x and Keras 3.
        layers.Input(shape=input_shape),

        # Convolutional feature extractor
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),

        # Flatten and dense regression head
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(1)  # No activation: unbounded scalar regression output
    ])

    # Compile the model
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model

Train, evaluate, and save the model (.h5)

In [None]:
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Hold out 20% of the samples, then split that hold-out in half for
# validation and test (80/10/10 overall); the fixed seed keeps it reproducible.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.2, random_state=42
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

# Per-sample shape without the batch axis, e.g. (128, 128, 3).
input_shape = X_train.shape[1:]
model = build_cnn_model(input_shape)

# Fit on the training split, reporting validation metrics after each epoch.
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=32,
    epochs=7,
    validation_data=(X_val, y_val),
)

# Two return values: the compiled loss (MSE) followed by the MAE metric.
test_loss, test_mae = model.evaluate(x=X_test, y=y_test)
print(f"Test Loss (MSE): {test_loss:.4f}")
print(f"Test MAE: {test_mae:.4f}")

# Persist the trained network to an .h5 file.
model.save("/Volumes/PortableSSD/flocfractal/cnn_model_grayscale_blurry.h5")

# Learning curves: training vs. validation MSE per epoch.
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss (MSE)')
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()