In [13]:
import os
import numpy as np
import cv2
from skimage import io
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from joblib import dump
from sklearn.model_selection import GridSearchCV

In [14]:
def _to_rgb_pixels(image, source):
    """Return ``image`` as an ``(n_pixels, 3)`` array of per-pixel colour features."""
    image = np.asarray(image)
    if image.ndim == 2:
        # Single-channel image: replicate the intensity so the feature width stays 3.
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        # RGBA: drop alpha. A bare reshape(-1, 3) would smear the four channels
        # across neighbouring rows and desynchronise features from labels.
        image = image[..., :3]
    elif image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(
            f"Unsupported image shape {image.shape} for {source}; "
            "expected grayscale, RGB or RGBA"
        )
    return image.reshape(-1, 3)


def _to_binary_labels(mask, threshold):
    """Flatten ``mask`` and binarise it into 0/1 integer labels."""
    flat = np.asarray(mask).reshape(-1)
    if flat.dtype == np.bool_:
        # A boolean mask is already binary; comparing it to 128 would zero it out.
        return flat.astype(int)
    if np.issubdtype(flat.dtype, np.floating) and flat.size and flat.max() <= 1.0:
        # imread(as_gray=True) converts colour masks to floats in [0, 1], so the
        # 0-255 threshold has to be rescaled or every pixel becomes background.
        threshold = threshold / 255
    return (flat > threshold).astype(int)


def load_images_and_masks(data_dir, mask_dir, mask_threshold=128):
    """Load paired PNG images and masks as per-pixel features and binary labels.

    Images and masks are paired by the sorted order of the ``.png`` file names
    in each directory.

    Args:
        data_dir: Directory containing the input images.
        mask_dir: Directory containing the corresponding mask images.
        mask_threshold: Intensity on a 0-255 scale above which a mask pixel is
            labelled 1. It is rescaled automatically for float masks in [0, 1].

    Returns:
        A tuple ``(features, labels)``: ``features`` has shape ``(n_pixels, 3)``
        and ``labels`` has shape ``(n_pixels,)`` with values 0 or 1, both
        concatenated over all image/mask pairs.

    Raises:
        ValueError: If no PNG files are found, the two directories hold a
            different number of PNG files, an image has an unsupported channel
            layout, or an image and its mask differ in pixel count.
    """
    image_files = [os.path.join(data_dir, f) for f in sorted(os.listdir(data_dir)) if f.endswith('.png')]
    mask_files = [os.path.join(mask_dir, f) for f in sorted(os.listdir(mask_dir)) if f.endswith('.png')]

    # zip() would silently truncate to the shorter list and hide a mispaired dataset.
    if len(image_files) != len(mask_files):
        raise ValueError(
            f"Found {len(image_files)} images in {data_dir!r} but "
            f"{len(mask_files)} masks in {mask_dir!r}"
        )
    if not image_files:
        raise ValueError(f"No .png files found in {data_dir!r}")

    all_features = []
    all_labels = []

    for image_file, mask_file in zip(image_files, mask_files):
        features = _to_rgb_pixels(io.imread(image_file), image_file)
        labels = _to_binary_labels(io.imread(mask_file, as_gray=True), mask_threshold)

        # Every feature row needs exactly one label; fail here rather than deep
        # inside the training code.
        if features.shape[0] != labels.shape[0]:
            raise ValueError(
                f"{image_file} has {features.shape[0]} pixels but "
                f"{mask_file} has {labels.shape[0]}"
            )

        all_features.append(features)
        all_labels.append(labels)

    # Stack all features and labels into single arrays
    return np.vstack(all_features), np.concatenate(all_labels)


In [15]:
def train_random_forest(features, labels, test_size=0.20, random_state=42, **rf_params):
    """Fit a random forest on a train split and report accuracy on the held-out split.

    Args:
        features: Feature rows, one per sample.
        labels: Target labels, one per row of ``features``.
        test_size: Fraction of the samples held out for the accuracy report.
        random_state: Seed used for both the split and the forest.
        **rf_params: Overrides for the default ``RandomForestClassifier``
            keyword arguments listed below.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    forest_kwargs = dict(
        n_estimators=100,
        max_depth=20,
        max_features=None,
        min_samples_leaf=5,
        min_samples_split=10,
        n_jobs=-1,
        random_state=random_state,
    )
    # Caller-supplied overrides win over the defaults above.
    forest_kwargs.update(rf_params)

    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=test_size, random_state=random_state
    )
    clf = RandomForestClassifier(**forest_kwargs)
    clf.fit(X_train, y_train)
    print("Accuracy on test set:", clf.score(X_test, y_test))
    return clf
    

In [16]:
def save_model(model, filename):
    """Persist ``model`` to ``filename`` using joblib's ``dump``."""
    dump(value=model, filename=filename)

In [18]:
def main(
    train_img_dir='Dataset/images/train',
    train_mask_dir='Dataset/masks/train',
    val_img_dir='Dataset/images/val',
    val_mask_dir='Dataset/masks/val',
    model_path='random_forest_model5.joblib',
):
    """Train the pixel classifier, report validation accuracy and save the model.

    Args:
        train_img_dir: Directory of training images.
        train_mask_dir: Directory of training masks.
        val_img_dir: Directory of validation images.
        val_mask_dir: Directory of validation masks.
        model_path: File the fitted model is written to.
    """
    # Load training data
    train_features, train_labels = load_images_and_masks(train_img_dir, train_mask_dir)

    # Train model (also prints accuracy on an internal held-out split)
    clf = train_random_forest(train_features, train_labels)

    # Evaluate on the separate validation set
    val_features, val_labels = load_images_and_masks(val_img_dir, val_mask_dir)
    print("Validation accuracy:", clf.score(val_features, val_labels))

    # Save the trained model
    save_model(clf, model_path)

# Run the train / evaluate / save pipeline only when this code is executed
# directly (script or notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

Accuracy on test set: 0.821901003519694
Validation accuracy: 0.8172175089518229
