In [None]:
# ---- Input / output locations --------------------------------------------
RGB_DIR = './USA_segmentation/RGB_images'    # colour images (read as BGR by OpenCV)
NRG_DIR = './USA_segmentation/NRG_images'    # images whose values feed the NIR term of NDVI
MASK_DIR = './USA_segmentation/masks'        # ground-truth masks, binarised at > 127
OUTPUT_DIR = './output'                      # predicted masks are written here

# ---- Splitting / sampling ------------------------------------------------
TEST_SIZE = 0.2             # fraction of images held out for the final test
VAL_SIZE = 0.2              # fraction of sampled training pixels used for validation
RANDOM_STATE = 42           # seed shared by NumPy shuffling, splits and the forest
MAX_TRAIN_SAMPLES = 60000   # upper bound on the number of training pixels

# ---- File formats / post-processing --------------------------------------
IMAGE_EXT = '.png'          # extension of RGB and NRG images
MASK_EXT = '.png'           # extension of mask images
MORPH_KERNEL = (3, 3)       # size of the kernel for morphological close/open

import os, glob, time, cv2, numpy as np
from tqdm import tqdm
from skimage.feature import local_binary_pattern
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import jaccard_score, log_loss
import matplotlib.pyplot as plt


def ensure_dir(path):
    """Make sure directory *path* exists, creating missing parents as needed.

    Does nothing when the directory is already present. If *path* exists but
    is not a directory, ``os.makedirs`` raises ``FileExistsError``.
    """
    if os.path.isdir(path):
        return
    os.makedirs(path, exist_ok=True)

def load_image(rgb_path, nrg_path):
    """Load the colour image and the NRG image for one sample.

    Args:
        rgb_path: path to the colour image; read with ``cv2.IMREAD_COLOR``,
            so channels come back in OpenCV's B, G, R order.
        nrg_path: path to the NRG image; read as a single grayscale channel.

    Returns:
        ``(rgb, nrg)`` -- a 3-channel BGR array and a 2-D array.

    Raises:
        FileNotFoundError: if either file is missing or cannot be decoded.
    """
    rgb = cv2.imread(rgb_path, cv2.IMREAD_COLOR)
    # NOTE(review): IMREAD_GRAYSCALE collapses a multi-channel file into a
    # weighted mix of its channels. If the NRG files are 3-channel NIR/R/G
    # composites, the value extract_features uses as NIR for NDVI is not pure
    # NIR -- confirm the file format before relying on the NDVI feature.
    nrg = cv2.imread(nrg_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None rather than raising.
    if rgb is None:
        raise FileNotFoundError(f'Could not read RGB image: {rgb_path!r}')
    if nrg is None:
        raise FileNotFoundError(f'Could not read NRG image: {nrg_path!r}')
    return rgb, nrg

def load_mask(mask_path):
    """Load a mask image and binarise it.

    Args:
        mask_path: path to the mask; read as a single grayscale channel.

    Returns:
        2-D ``uint8`` array with 1 where the pixel value is > 127, else 0.

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded.
    """
    m = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None on failure; without this check the comparison
    # below would fail with an unrelated TypeError.
    if m is None:
        raise FileNotFoundError(f'Could not read mask image: {mask_path!r}')
    # Threshold at mid-gray; presumably masks are stored as 0/255 -- TODO confirm.
    return (m > 127).astype(np.uint8)


def extract_features(rgb, nrg):
    """Build a per-pixel feature matrix from one RGB/NRG image pair.

    Feature columns, in order:
        0-2  B, G, R raw channel values
        3    NDVI computed as (N - R) / (N + R + 1e-5)
        4-6  H, S, V from OpenCV's BGR->HSV conversion
        7    Laplacian of the grayscale image
        8    Sobel gradient magnitude of the grayscale image
        9    local binary pattern code (P=16, R=3, method='uniform')

    Args:
        rgb: 3-channel BGR image, as returned by ``cv2.imread(..., IMREAD_COLOR)``.
        nrg: 2-D image with the same height and width; its values are used
            as the N term of NDVI.

    Returns:
        Array of shape ``(H * W, 10)``; row ``i`` describes pixel ``i`` in
        row-major order (the same order as ``mask.flatten()``).

    Raises:
        ValueError: if ``nrg`` is not 2-D or its size differs from ``rgb``.
    """
    if nrg.ndim != 2 or rgb.shape[:2] != nrg.shape[:2]:
        raise ValueError(
            f'NRG image of shape {nrg.shape} is incompatible with RGB image '
            f'of shape {rgb.shape}; expected a 2-D array of size {rgb.shape[:2]}.')

    # Basic colour channels (OpenCV order is B, G, R).
    B = rgb[:, :, 0].astype(np.float32)
    G = rgb[:, :, 1].astype(np.float32)
    R = rgb[:, :, 2].astype(np.float32)

    # NDVI; the small epsilon avoids division by zero where N + R == 0.
    N = nrg.astype(np.float32)
    ndvi = (N - R) / (N + R + 1e-5)

    # HSV colour space.
    hsv = cv2.cvtColor(rgb, cv2.COLOR_BGR2HSV).astype(np.float32)
    H = hsv[:, :, 0]
    S = hsv[:, :, 1]
    V = hsv[:, :, 2]

    # Texture features from the grayscale image.
    gray = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)
    lap = cv2.Laplacian(gray, cv2.CV_32F)
    gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0)
    gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1)
    mag = np.sqrt(gx**2 + gy**2)
    lbp = local_binary_pattern(gray, P=16, R=3, method='uniform')

    # One column per feature, one row per pixel. ravel() avoids the extra
    # copy flatten() makes; column_stack copies into the result anyway.
    planes = (B, G, R, ndvi, H, S, V, lap, mag, lbp)
    return np.column_stack([p.ravel() for p in planes])

if __name__ == '__main__':
    start_time = time.time()

    # 1. Collect the input file lists. RGB, NRG and mask files are paired
    #    purely by sorted order, so the three directories must list matching
    #    samples in the same order.
    rgb_files = sorted(glob.glob(os.path.join(RGB_DIR, f'*{IMAGE_EXT}')))
    nrg_files = sorted(glob.glob(os.path.join(NRG_DIR, f'*{IMAGE_EXT}')))
    mask_files = sorted(glob.glob(os.path.join(MASK_DIR, f'*{MASK_EXT}')))
    # Explicit raises rather than `assert`, which is stripped under `python -O`.
    if not (len(rgb_files) == len(nrg_files) == len(mask_files)):
        raise ValueError(
            f'File counts differ: {len(rgb_files)} RGB, {len(nrg_files)} NRG, '
            f'{len(mask_files)} mask files.')
    if not rgb_files:
        raise FileNotFoundError(f'No *{IMAGE_EXT} images found in {RGB_DIR!r}.')

    # 2. Image-level train/test split, shuffled with a fixed seed.
    samples = list(zip(rgb_files, nrg_files, mask_files))
    np.random.seed(RANDOM_STATE)
    np.random.shuffle(samples)
    split_pt = int(len(samples) * (1 - TEST_SIZE))
    train_samples, test_samples = samples[:split_pt], samples[split_pt:]
    if not train_samples or not test_samples:
        raise ValueError(
            f'{len(samples)} sample(s) with TEST_SIZE={TEST_SIZE} leave an empty '
            f'train or test split.')

    # 3. Extract per-pixel features for the training images. Pixels are
    #    subsampled per image so memory stays near MAX_TRAIN_SAMPLES rows
    #    instead of holding every pixel of every training image at once.
    per_image_quota = -(-MAX_TRAIN_SAMPLES // len(train_samples))  # ceil division
    X_list, y_list = [], []
    for rgb_p, nrg_p, mask_p in tqdm(train_samples, desc='Extract train feats'):
        rgb, nrg = load_image(rgb_p, nrg_p)
        mask = load_mask(mask_p).ravel()
        feats = extract_features(rgb, nrg)
        if mask.size != feats.shape[0]:
            raise ValueError(f'Mask {mask_p!r} does not match the size of image {rgb_p!r}.')
        if feats.shape[0] > per_image_quota:
            keep = np.random.choice(feats.shape[0], per_image_quota, replace=False)
            feats, mask = feats[keep], mask[keep]
        X_list.append(feats)
        y_list.append(mask)
    X = np.vstack(X_list)
    y = np.concatenate(y_list)
    if np.unique(y).size < 2:
        raise ValueError('Training pixels contain a single class; cannot fit a binary classifier.')

    # 4. Enforce the global cap (ceil rounding above may overshoot slightly),
    #    then hold out a stratified validation split.
    if X.shape[0] > MAX_TRAIN_SAMPLES:
        idx = np.random.choice(X.shape[0], MAX_TRAIN_SAMPLES, replace=False)
        X, y = X[idx], y[idx]
    # NOTE(review): validation pixels come from the same images as training
    # pixels, so the validation IoU (and the threshold picked from it) is
    # likely optimistic; an image-level validation split would be stricter.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=VAL_SIZE, random_state=RANDOM_STATE, stratify=y)

    # 5. Random forest with a small grid search over size and depth.
    param_grid = {'n_estimators': [100, 200], 'max_depth': [10, 20, None]}
    rf = RandomForestClassifier(class_weight='balanced', random_state=RANDOM_STATE)
    grid = GridSearchCV(rf, param_grid, cv=3, scoring='f1', n_jobs=1)
    grid.fit(X_train, y_train)
    best_rf = grid.best_estimator_
    print(f'Best params: {grid.best_params_}')

    # 6. Threshold search: pick the probability cut-off maximising validation IoU.
    val_probs = best_rf.predict_proba(X_val)[:, 1]
    best_thr, best_iou = 0.5, 0
    for thr in np.linspace(0.2, 0.8, 13):
        iou = jaccard_score(y_val, (val_probs > thr).astype(int))
        if iou > best_iou:
            best_iou, best_thr = iou, thr
    print(f'Val IoU: {best_iou:.4f} @ Thr: {best_thr:.2f}')

    # 7. Predict each test image, post-process, score and save.
    ensure_dir(OUTPUT_DIR)
    kernel = np.ones(MORPH_KERNEL, np.uint8)
    ious, losses = [], []
    previews = []  # (mask, pred, iou) of the first test images, reused for plotting
    for rgb_p, nrg_p, mask_p in tqdm(test_samples, desc='Test eval'):
        rgb, nrg = load_image(rgb_p, nrg_p)
        mask = load_mask(mask_p)
        feats = extract_features(rgb, nrg)
        probs = best_rf.predict_proba(feats)[:, 1]
        pred = (probs > best_thr).astype(np.uint8).reshape(mask.shape)
        # Closing fills small holes, opening then removes small specks.
        pred = cv2.morphologyEx(pred, cv2.MORPH_CLOSE, kernel)
        pred = cv2.morphologyEx(pred, cv2.MORPH_OPEN, kernel)
        # zero_division=1.0: an empty prediction on an empty mask is a perfect
        # match; sklearn's default scores it 0 and drags the mean IoU down.
        iou = jaccard_score(mask.ravel(), pred.ravel(), zero_division=1.0)
        ious.append(iou)
        # labels=[0, 1]: without it log_loss raises ValueError for any test
        # mask that contains only one class (e.g. no foreground at all).
        losses.append(log_loss(mask.ravel(), probs, labels=[0, 1]))
        # splitext only strips the real extension; str.replace could also hit
        # an occurrence of IMAGE_EXT in the middle of the filename.
        stem = os.path.splitext(os.path.basename(rgb_p))[0]
        out_path = os.path.join(OUTPUT_DIR, f'{stem}_pred.png')
        if not cv2.imwrite(out_path, pred * 255):
            raise OSError(f'Failed to write prediction to {out_path!r}.')
        if len(previews) < 5:
            previews.append((mask, pred, iou))

    # 8. Final result.
    print(f'Mean IoU: {np.mean(ious):.4f}, Mean Loss: {np.mean(losses):.4f}')
    print(f'Total runtime: {time.time()-start_time:.1f}s')

    # 9. Show prediction (left) next to ground truth (right) for the first
    #    test images, using the in-memory results instead of re-reading disk.
    for i, (mask, pred, iou) in enumerate(previews, start=1):
        pair = np.concatenate([pred * 255, mask * 255], axis=1).astype(np.uint8)
        plt.figure(figsize=(8, 4))
        plt.imshow(pair, cmap='gray')
        plt.title(f'Sample {i} IoU={iou:.3f}  (Left=Pred, Right=GT)')
        plt.axis('off')
    plt.show()