
# Predict Local Spatial Frequency with a TensorFlow/Keras Model (Multi-Output)

This notebook reproduces the MATLAB script **Copy_of_testMultipleResponsesPredictUsingML.asv** using a
**TensorFlow/Keras** multi-output regression model instead. It reads the model paths from one row of an Excel DB,
builds features on sliding patches, predicts the `[w_φ, φ_x, φ_y, θ]` maps, and draws quick diagnostics.


In [None]:

# --- Setup ---
# %pip install tensorflow==2.16.* scikit-learn==1.4.* pandas openpyxl matplotlib scipy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from ml_spatialfreq_utils import (
    TrainedModelTF, DBInfo, calc_spatial_freqs_supervised_regression_batch, synth_fringe
)

plt.rcParams.update({"figure.figsize": (5.5, 4.2)})  # default figure size (inches) unless a cell overrides it


In [None]:

# --- Resolve model paths from the DB Excel ---
# Configuration: where the training-set DB lives and which row (1-based, as in MATLAB) to use.
rootFolderDB = Path('./ML_Models')              # Adjust if needed
db_name = 'DB-trainingSets-OM4M007.xlsx'        # Same as MATLAB
db_sheet = 'Sheet1'
trainingSet_Idx = 7                              # 1-based row index

db_path = rootFolderDB / db_name
print("Reading DB:", db_path.resolve())
db_tb = pd.read_excel(db_path, sheet_name=db_sheet)

row_idx = trainingSet_Idx - 1                    # MATLAB 1-based -> pandas 0-based
assert 0 <= row_idx < len(db_tb), "trainingSet_Idx out of range"

row = db_tb.iloc[row_idx].to_dict()


def _db_value(record, key, default=None):
    """Look up ``key`` in a DB row, treating blank cells the same as missing columns.

    ``pd.read_excel`` keeps the key for an empty cell and stores NaN in it. On its own,
    ``record.get(key, default)`` would therefore return NaN, which makes ``int(...)`` raise
    and turns ``str(...)`` into ``'nan'``.

    Parameters
    ----------
    record : dict
        One DB row, as produced by ``Series.to_dict()``.
    key : str
        Column name to read.
    default : object, optional
        Returned when the column is absent or the cell is None/NaN/an all-blank string.

    Returns
    -------
    object
        The cell value, or ``default``.
    """
    value = record.get(key)
    if value is None:
        return default
    if isinstance(value, str):
        return value if value.strip() else default
    if pd.isna(value):          # blank numeric/any cell read back as NaN
        return default
    return value


# Expected columns for Python artifacts (preferred). If missing, you can fill them below.
kerasModelPath = _db_value(row, 'kerasModelPath')
scalerPath     = _db_value(row, 'scalerPath')
metaPath       = _db_value(row, 'metaPath')
featureName    = _db_value(row, 'featureName', 'feature_projected_DFT')
patch_NR       = int(_db_value(row, 'patch_NR', 21))
patch_NC       = int(_db_value(row, 'patch_NC', 21))

# If your DB only has a MATLAB trainedModel .mat file, manually set Keras paths here.
# All three artifacts are replaced together so the scaler/metadata match the fallback model.
if not isinstance(kerasModelPath, str) or not kerasModelPath:
    # Fallback example (edit to your actual locations)
    kerasModelPath = str(Path('models') / 'trainedModel.keras')
    scalerPath     = str(Path('models') / 'scaler.pkl')
    metaPath       = str(Path('models') / 'feature_metadata.json')

print("kerasModelPath:", kerasModelPath)
print("scalerPath    :", scalerPath)
print("metaPath      :", metaPath)
print("featureName   :", featureName, " patch:", patch_NR, "x", patch_NC)


In [None]:

# --- Build or load a test fringe pattern (like the MATLAB script) ---
# Synthetic test image: carrier frequencies w0_x = w0_y = π/4 with a small amount of noise.
NR, NC = 511, 512
w0_x, w0_y = np.pi / 4, np.pi / 4
g = synth_fringe(NR, NC, w0_x, w0_y, modulation=1.0, background=0.0, noise_std=0.01)

fig, ax = plt.subplots()
ax.imshow(g, cmap='gray')
ax.set_title('Fringe pattern g')
ax.axis('off')
plt.show()

# Region of interest covering every pixel of g.
M_ROI = np.ones_like(g, dtype=bool)


In [None]:

# --- Predict spatial-frequency maps ---
dbi = DBInfo(featureName=str(featureName), patch_NR=int(patch_NR), patch_NC=int(patch_NC))
tm = TrainedModelTF(model_path=kerasModelPath, scaler_path=scalerPath, meta_path=metaPath, DB_info=dbi)

w_phi, phi_x, phi_y, theta, QM, M_proc = calc_spatial_freqs_supervised_regression_batch(
    g, tm, feature_name=str(featureName), M_ROI=M_ROI
)

# Display mask: 1.0 where a pixel was processed AND passed the quality mask, NaN everywhere else.
# Excluded pixels must be NaN (not 0). Multiplying by 0 would draw them as a legitimate value
# (e.g. θ = 0) and stretch the colour scale. The bool coercion also makes `~`/`&` act as logical
# operators even if a mask arrives as 0/1 integers.
valid = np.asarray(M_proc, dtype=bool) & np.asarray(QM, dtype=bool)
MNan = np.where(valid, 1.0, np.nan)

# Display predictions (where valid): one panel per output, in row-major subplot order.
panels = [
    (w_phi, 'w_φ (rad/px)', 'viridis'),
    (phi_x, 'φ_x (rad/px)', 'viridis'),
    (phi_y, 'φ_y (rad/px)', 'viridis'),
    (theta, 'θ (rad)', 'twilight'),   # cyclic colormap for an angle
]
fig, axs = plt.subplots(2, 2, figsize=(9, 7))
for ax, (data, title, cmap) in zip(axs.ravel(), panels):
    im = ax.imshow(data * MNan, cmap=cmap)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
    ax.axis('off')
plt.tight_layout(); plt.show()


In [None]:

# --- Histograms on valid pixels (as in MATLAB) ---
# Coerce to bool so an integer 0/1 quality map acts as a mask rather than as fancy indices.
# Also restrict to processed pixels so the histograms describe the same pixels as the maps above.
mask = np.asarray(QM, dtype=bool) & np.asarray(M_proc, dtype=bool)
edges1 = np.linspace(-np.pi, np.pi, 100)   # 100 edges -> 99 bins; values outside [-π, π] are not counted

hist_specs = [
    (w_phi, 'hist(w_φ)', 'rad/px'),
    (theta, 'hist(θ)', 'rad'),
    (phi_x, 'hist(φ_x)', 'rad/px'),
    (phi_y, 'hist(φ_y)', 'rad/px'),
]
for data, title, xlabel in hist_specs:
    values = np.asarray(data)[mask]
    values = values[np.isfinite(values)]   # drop NaN/inf predictions before binning
    fig, ax = plt.subplots()
    ax.hist(values, bins=edges1)
    ax.set(title=title, xlabel=xlabel, ylabel='pixel count')
    plt.show()
