In [None]:
import os
import glob
import random
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import pywt

# =========================
# PATHS
# =========================
# INPUT_ROOT holds one .mat file per bearing condition; images are written
# to OUTPUT_ROOT/<class name>/.
INPUT_ROOT  = r"F:\NeuTech\CWRU"
OUTPUT_ROOT = r"F:\NeuTech\Signals\with axis\cwru"

# Class label -> path of the .mat file that is processed for that class.
CLASS_FILES = {
    "healthy": os.path.join(INPUT_ROOT, "healthy.mat"),
    "inner":   os.path.join(INPUT_ROOT, "inner.mat"),
    "ball":    os.path.join(INPUT_ROOT, "ball.mat"),
    "outer":   os.path.join(INPUT_ROOT, "outer.mat"),
}

# Number of scalogram images generated per class.
IMAGES_PER_CLASS = 2

# =========================
# SAMPLING / CWT SETTINGS
# =========================
DEFAULT_FS = 48_000         # set to 12_000 if your signals are 12kHz
CAP_FMAX_HZ = 20_000        # cap frequency axis (bearing faults region)
FREQ_MIN = 10               # Hz
WAVELET = "cmor1.5-1.0"     # amor-like
N_SCALES = 256              # number of log-spaced analysis frequencies

# =========================
# PLOT STYLE
# =========================
LABEL_FONTSIZE = 22
TICK_FONTSIZE  = 18
DPI = 300

# Seed both generators: pick_1d_signal draws its random row/column from
# NumPy's global RNG, which random.seed() alone does not affect.
random.seed(42)
np.random.seed(42)

# =========================
# HELPERS
# =========================
def get_fs_from_mat(mat_dict, default_fs=DEFAULT_FS):
    """
    Look up the sampling rate stored in a loaded .mat dictionary.

    A fixed list of common variable names is probed in order. The first
    entry that converts to a single float inside [1 kHz, 5 MHz] is returned.
    When no entry qualifies, ``default_fs`` is returned as a float.
    """
    candidate_names = ("fs", "Fs", "FS", "sampling_rate", "sampling_frequency", "sr", "SR")

    for name in candidate_names:
        if name not in mat_dict:
            continue
        try:
            rate = float(np.array(mat_dict[name]).squeeze())
        except Exception:
            # Best-effort probe: an entry that is not a single number is
            # simply ignored and the next candidate name is tried.
            continue
        if 1_000 <= rate <= 5_000_000:
            return rate

    return float(default_fs)


def extract_largest_numeric(mat_dict, file_label="mat"):
    """
    Recursively find the largest non-empty numeric ndarray in a loaded .mat tree.

    The search descends through:
      * dicts (keys starting with "__" are skipped),
      * object ndarrays (each element is visited),
      * structured ndarrays and numpy.void records (each named field),
      * objects exposing ``_fieldnames`` -- the struct objects scipy returns
        when a file is loaded with ``struct_as_record=False``.

    Parameters
    ----------
    mat_dict : the object tree to search (normally the dict from loadmat).
    file_label : prefix used for the returned path string.

    Returns
    -------
    (path, array) : a dotted/indexed description of where the array was
        found, and the array itself. On equal sizes the first one met wins.

    Raises
    ------
    ValueError : if no non-empty numeric ndarray exists anywhere in the tree.
    """
    best_path, best_arr, best_size = None, None, 0

    def walk(obj, path):
        nonlocal best_path, best_arr, best_size

        # dict
        if isinstance(obj, dict):
            for key, value in obj.items():
                if str(key).startswith("__"):
                    continue
                walk(value, f"{path}.{key}")
            return

        # ndarray: numeric leaf, container of objects, or structured array
        if isinstance(obj, np.ndarray):
            if np.issubdtype(obj.dtype, np.number):
                if obj.size > best_size:
                    best_path, best_arr, best_size = path, obj, obj.size
            elif obj.dtype == object:
                for idx, item in np.ndenumerate(obj):
                    walk(item, f"{path}[{idx}]")
            elif obj.dtype.names:
                for name in obj.dtype.names:
                    walk(obj[name], f"{path}.{name}")
            return

        # matlab struct as numpy.void
        if isinstance(obj, np.void):
            for name in obj.dtype.names or ():
                walk(obj[name], f"{path}.{name}")
            return

        # matlab struct as an attribute-style object (struct_as_record=False);
        # detected by duck typing so no scipy import is needed here.
        field_names = getattr(obj, "_fieldnames", None)
        if field_names:
            for name in field_names:
                walk(getattr(obj, name, None), f"{path}.{name}")

    walk(mat_dict, file_label)

    if best_arr is None:
        raise ValueError("No numeric ndarray found inside this .mat file.")
    return best_path, best_arr


def pick_1d_signal(arr):
    """
    Reduce ``arr`` to a single 1-D float64 signal.

    After squeezing singleton dimensions, a genuine 2-D array contributes
    one randomly chosen line along its longer axis (a row when rows <= cols,
    otherwise a column), drawn with NumPy's global RNG. Any other shape is
    flattened. Non-finite samples are dropped.

    Raises ValueError when fewer than 256 finite samples remain.
    """
    data = np.squeeze(np.array(arr))

    if data.ndim == 2 and 1 not in data.shape:
        n_rows, n_cols = data.shape
        if n_rows <= n_cols:
            chosen = data[np.random.randint(0, n_rows), :]
        else:
            chosen = data[:, np.random.randint(0, n_cols)]
    else:
        chosen = data.flatten()

    chosen = np.asarray(chosen, dtype=np.float64)
    finite_only = chosen[np.isfinite(chosen)]

    if finite_only.size < 256:
        raise ValueError(f"Signal too short: {finite_only.size}")
    return finite_only


def normalize_signal(x):
    """
    Return ``x`` as float64 with its mean removed and divided by its
    standard deviation (plus 1e-12, so a constant signal maps to zeros
    instead of dividing by zero).
    """
    samples = np.asarray(x, dtype=np.float64)
    centered = samples - np.mean(samples)
    spread = np.std(centered) + 1e-12
    return centered / spread


def compute_cwt_db(x, fs):
    """
    Compute a continuous-wavelet power map of ``x`` in decibels.

    N_SCALES analysis frequencies are spaced geometrically from FREQ_MIN up
    to min(CAP_FMAX_HZ, fs / 2) and converted to wavelet scales for WAVELET.

    Returns (power_db, freqs, t): 10*log10(|coef|^2 + 1e-12) of the
    coefficients returned by pywt.cwt, the analysis frequencies in Hz
    (ascending), and the sample times in seconds.
    """
    sample_period = 1.0 / fs

    # Frequency grid: the upper limit never exceeds fs / 2.
    f_low = max(FREQ_MIN, 1e-6)
    f_high = min(CAP_FMAX_HZ, fs / 2.0)
    freqs = np.geomspace(f_low, f_high, N_SCALES)

    # scale = central_frequency / (f * dt) maps each target frequency to a scale.
    wavelet = pywt.ContinuousWavelet(WAVELET)
    scales = pywt.central_frequency(wavelet) / (freqs * sample_period)

    coefficients = pywt.cwt(x, scales, wavelet, sampling_period=sample_period)[0]

    # The small offset keeps log10 finite where the power is exactly zero.
    power_db = 10.0 * np.log10(np.abs(coefficients) ** 2 + 1e-12)

    t = np.arange(len(x)) * sample_period
    return power_db, freqs, t


def _frequency_ticks(freqs, max_ticks=6):
    """Return (row_positions, labels) placing y-ticks at the true frequency of each image row."""
    freqs = np.asarray(freqs, dtype=np.float64).ravel()
    # With extent y-range [0, n_rows], row i is centred at i + 0.5.
    row_centers = np.arange(freqs.size) + 0.5
    order = np.argsort(freqs)
    f_sorted, c_sorted = freqs[order], row_centers[order]
    f_lo, f_hi = f_sorted[0], f_sorted[-1]

    # Prefer powers of ten inside the range; otherwise label evenly spaced rows.
    ticks = np.empty(0)
    if f_lo > 0:
        decades = 10.0 ** np.arange(np.ceil(np.log10(f_lo)), np.floor(np.log10(f_hi)) + 1)
        if decades.size >= 2:
            ticks = decades
    if ticks.size == 0:
        picks = np.linspace(0, freqs.size - 1, min(max_ticks, freqs.size))
        ticks = f_sorted[np.unique(picks.round().astype(int))]

    # Interpolate in log-frequency when possible (exact for geometric grids).
    if f_lo > 0:
        positions = np.interp(np.log(ticks), np.log(f_sorted), c_sorted)
    else:
        positions = np.interp(ticks, f_sorted, c_sorted)

    labels = [f"{f:.0f}" if f >= 100 else f"{f:.3g}" for f in ticks]
    return positions, labels


def plot_and_save(power_db, freqs, t, save_path):
    """
    Render a scalogram image and write it to ``save_path``.

    Parameters
    ----------
    power_db : 2-D array; row i is drawn as the band for ``freqs[i]``
        (assumes len(freqs) == power_db.shape[0]).
    freqs : frequency in Hz of each row; need not be linearly spaced.
    t : sample times in seconds; only the first and last value are used.
    save_path : output image path; its directory is created when missing.

    The image is drawn in row coordinates and the y-ticks are then placed at
    the rows matching the labelled frequencies. Passing freqs[0]..freqs[-1]
    as a linear imshow extent would mislabel a non-linear frequency grid.
    """
    out_dir = os.path.dirname(save_path)
    if out_dir:  # a bare filename has no directory to create
        os.makedirs(out_dir, exist_ok=True)

    # robust contrast
    vmin = np.percentile(power_db, 5)
    vmax = np.percentile(power_db, 99)
    n_rows = np.shape(power_db)[0]

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.imshow(
            power_db,
            aspect="auto",
            origin="lower",
            extent=[t[0], t[-1], 0, n_rows],
            interpolation="nearest",
            vmin=vmin,
            vmax=vmax,
        )

        tick_positions, tick_labels = _frequency_ticks(freqs)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels)

        ax.set_xlabel("Time (s)", fontsize=LABEL_FONTSIZE, fontweight="bold")
        ax.set_ylabel("Frequency (Hz)", fontsize=LABEL_FONTSIZE, fontweight="bold")

        ax.tick_params(axis="both", labelsize=TICK_FONTSIZE, width=2)
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontweight("bold")

        fig.tight_layout()
        fig.savefig(save_path, dpi=DPI, bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)


# =========================
# MAIN
# =========================
def main():
    """
    Generate IMAGES_PER_CLASS scalogram PNGs for every entry of CLASS_FILES.

    For each class the .mat file is loaded, its sampling rate and largest
    numeric array are located, and each image is produced from a signal
    picked out of that array. Missing files are skipped. A file that cannot
    be loaded, or an image that cannot be produced, is reported with a
    [FAIL] line and processing continues with the next item.
    """
    for cls, mat_path in CLASS_FILES.items():
        if not os.path.isfile(mat_path):
            print(f"[SKIP] Missing file: {mat_path}")
            continue

        file_name = os.path.basename(mat_path)
        base = os.path.splitext(file_name)[0]
        out_dir = os.path.join(OUTPUT_ROOT, cls)

        # A broken or unsupported file must not abort the remaining classes.
        try:
            os.makedirs(out_dir, exist_ok=True)
            mat = sio.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
            fs = get_fs_from_mat(mat, default_fs=DEFAULT_FS)
            src_path, arr = extract_largest_numeric(mat, file_label=file_name)
        except Exception as e:
            print(f"[FAIL] {cls} | {file_name} -> {e}")
            continue

        for i in range(1, IMAGES_PER_CLASS + 1):
            save_path = os.path.join(out_dir, f"{base}_CWT_{i:02d}.png")
            try:
                sig = normalize_signal(pick_1d_signal(arr))
                power_db, freqs, t = compute_cwt_db(sig, fs=fs)
                plot_and_save(power_db, freqs, t, save_path)
            except Exception as e:
                print(f"[FAIL] {cls} | {file_name} -> {e}")
            else:
                print(f"[OK] {cls} | Fs={fs:.0f} Hz | from={src_path} -> {save_path}")


if __name__ == "__main__":
    main()


[OK] healthy | Fs=48000 Hz | from=healthy.mat.Healthy_CW[(0, 0)][(1,)] -> F:\NeuTech\Signals\with axis\cwru\healthy\healthy_CWT_01.png
[OK] healthy | Fs=48000 Hz | from=healthy.mat.Healthy_CW[(0, 0)][(1,)] -> F:\NeuTech\Signals\with axis\cwru\healthy\healthy_CWT_02.png
