In [None]:
import copy
%load_ext autoreload
%autoreload 1
# %matplotlib widget
import matplotlib
import librosa

matplotlib.use('QT5Agg')
import numpy as np
import pandas as pd
import os, pickle, re
from scipy import signal
import matplotlib.pyplot as plt
from magtach.op_codes.preprocess_functions import read_files, min_max_norm

In [None]:
def generate_resample_factors(start: float, end: float, step: float = 0.001):
    """Generate resample factors from ``start + step`` upward in increments of ``step``.

    Starting at ``start``, ``step`` is added repeatedly and every new value is
    recorded, rounded to as many decimal places as ``step`` has (this removes
    floating-point drift from the repeated additions). Generation stops once the
    running value is no longer below ``end``, so the last factor may equal or
    slightly exceed ``end``.

    Args:
        start: value the sequence starts from (not itself included).
        end: generation stops once the running value reaches this bound.
        step: increment; must be positive.

    Returns:
        List of rounded factors (numpy scalars).

    Raises:
        ValueError: if ``step`` is not positive (the loop would never end).
    """
    from decimal import Decimal

    if step <= 0:
        raise ValueError("step must be positive")

    # Decimal parses '0.001', '1e-05' and integer steps alike; splitting str(step)
    # on '.' fails for the latter two.
    exponent = Decimal(str(step)).as_tuple().exponent
    rounding_precision = max(0, -exponent)

    resample_factors = []
    current = start
    while current < end:
        current = np.round(current + step, rounding_precision)
        resample_factors.append(current)
    return resample_factors


def generate_resample_factors_v2(cur_fund_freq, start, end, step, rounding_precision=6, excluded_freq=60.0):
    """Compute resample factors that move ``cur_fund_freq`` onto a grid of target frequencies.

    Target frequencies are ``np.arange(start, end + step, step)``, minus
    ``excluded_freq``. For each target ``f`` the factor is ``cur_fund_freq / f``,
    rounded to ``rounding_precision`` decimals. Only factors strictly between
    0.1 and 2.0 are kept.

    Args:
        cur_fund_freq: current fundamental frequency.
        start, end, step: target frequency grid (``end`` inclusive when it lies on the grid).
        rounding_precision: decimals kept in each factor.
        excluded_freq: target frequency to skip (default 60, matching the previous
            hard-coded exclusion); ``None`` disables the exclusion.

    Returns:
        (resample_factors, freq_generated): the kept factors and their target frequencies.
    """
    freq_lst = np.arange(start, end + step, step)
    if excluded_freq is not None:
        # arange accumulates float error, so an exact != can miss the excluded value;
        # use a tolerance far smaller than any realistic grid step.
        freq_lst = freq_lst[~np.isclose(freq_lst, excluded_freq, rtol=0.0, atol=1e-9)]

    resample_factors = np.round(cur_fund_freq / freq_lst, rounding_precision)

    mask = (resample_factors > 0.1) & (resample_factors < 2.0)
    return resample_factors[mask], freq_lst[mask]


# strt_pnt = 0.3
# resampl_factor_lst = generate_resample_factors(start=strt_pnt, end=1.9, step=0.01)
def scale_psd(orig_psd_db, final_length_psd, scale_factor, method="average"):
    """Shrink a PSD along frequency by ``scale_factor`` and keep ``final_length_psd`` bins.

    Args:
        orig_psd_db: 1-D numpy array (the notebook passes a log-PSD).
        final_length_psd: number of output bins to keep.
        scale_factor: reduction factor; ``None`` or 1 returns the first
            ``final_length_psd`` bins unchanged. Expected to be a positive int.
        method: one of
            "decimate"      - keep every ``scale_factor``-th bin;
            "nbins_average" - mean of consecutive windows (Python loop, returns a list);
            "average" / "median" / "max" - non-overlapping pooling over windows of
                ``scale_factor`` bins;
            "softmax"       - softmax-weighted (temperature 1) mean of each window;
            "resample"      - Fourier resampling via ``scipy.signal.resample``.

    Returns:
        At most ``final_length_psd`` scaled bins.

    Raises:
        ValueError: for an unknown ``method``. The pooling methods also raise
            (from ``reshape``) when ``orig_psd_db`` has fewer than
            ``final_length_psd * scale_factor`` bins.
    """
    if scale_factor is None or scale_factor == 1:
        return orig_psd_db[:final_length_psd]

    def pooled_windows():
        # Non-overlapping windows of `scale_factor` bins, one row per output bin.
        return orig_psd_db[:final_length_psd * scale_factor].reshape(final_length_psd, scale_factor)

    if method == "decimate":
        ### pick every nth bin
        cur_psd = orig_psd_db[::scale_factor]

    elif method == "nbins_average":
        orig_psd_copy = orig_psd_db[:final_length_psd * scale_factor].copy()
        cur_psd = [np.mean(orig_psd_copy[ind:ind + scale_factor])
                   for ind in range(0, len(orig_psd_copy), scale_factor)]

    elif method == "average":
        cur_psd = pooled_windows().mean(axis=1)

    elif method == "median":
        # ndarray has no .median() method; use the numpy function.
        cur_psd = np.median(pooled_windows(), axis=1)

    elif method == "max":
        cur_psd = pooled_windows().max(axis=1)

    elif method == "softmax":
        blocks = pooled_windows()
        # subtract the per-window max for numerical stability (temperature = 1.0)
        exps = np.exp((blocks - np.max(blocks, axis=1, keepdims=True)) / 1.0)
        weights = exps / np.sum(exps, axis=1, keepdims=True)
        cur_psd = np.sum(blocks * weights, axis=1)

    elif method == "resample":
        ### resample the psd to desired length
        cur_psd = orig_psd_db[:int(final_length_psd * scale_factor)]
        cur_psd = signal.resample(cur_psd, len(cur_psd) // scale_factor)

    else:
        raise ValueError("method must be one of 'decimate', 'nbins_average', 'average', "
                         "'median', 'max', 'softmax', or 'resample'")

    return cur_psd[:final_length_psd]


def compute_harmonic_peaks(fundamental_freq: int, spectral_window: int = 1024, max_iter: int = 500):
    """Return the rounded positions of successive multiples of ``fundamental_freq``.

    Multiples are accumulated by repeated addition. Accumulation stops before a
    multiple would exceed ``spectral_window``, or after ``max_iter`` multiples.
    A float ``fundamental_freq`` is accepted; each position is rounded to the
    nearest int.
    """
    accumulated = 0
    raw_harmonics = []
    while len(raw_harmonics) < max_iter:
        candidate = accumulated + fundamental_freq
        if candidate > spectral_window:
            break
        accumulated = candidate
        raw_harmonics.append(accumulated)

    return [int(np.round(h, 0)) for h in raw_harmonics]


import numpy as np


def make_true_mask_v2(cur_fund_freq, psd_length=1024, sigma=4.0, n_harmonics=100, mode="gaussian", band_width=2):
    """Build a ground-truth harmonic mask over ``psd_length`` bins.

    Harmonic positions are produced by repeatedly adding ``cur_fund_freq``, at
    most ``n_harmonics`` times, while the next position stays below
    ``psd_length - 2``. Each position is then rounded to an integer bin.

    Args:
        cur_fund_freq: spacing between harmonics, in bins.
        psd_length: number of bins in the mask.
        sigma: standard deviation of each Gaussian bump (``mode="gaussian"``).
        n_harmonics: maximum number of harmonics placed.
        mode: ``"gaussian"``: element-wise max of Gaussian bumps, rescaled so the
            peak is 1. ``"block"``: 1.0 over ``[f - band_width, f + band_width]``
            around each harmonic.
        band_width: half-width of each block (``mode="block"``).

    Returns:
        (mask, harmonic_bins): float array of length ``psd_length`` and the array
        of integer harmonic bins.

    Raises:
        ValueError: if ``mode`` is neither "gaussian" nor "block".
    """
    positions = []
    running = 0
    for _ in range(n_harmonics):
        nxt = running + cur_fund_freq
        if nxt >= psd_length - 2:
            break
        running = nxt
        positions.append(running)

    harmonic_bins = np.array([int(np.round(p)) for p in positions])

    mask = np.zeros(psd_length)
    bins = np.arange(psd_length)

    if mode == "block":
        for center in harmonic_bins:
            mask[max(0, center - band_width):min(psd_length, center + band_width + 1)] = 1.0
    elif mode == "gaussian":
        for center in harmonic_bins:
            scale = 1 / np.sqrt(2 * np.pi * sigma ** 2)
            bump = scale * np.exp(-0.5 * ((bins - center) / sigma) ** 2)
            mask = np.maximum(mask, bump)
        # Rescale so the tallest bump reaches exactly 1.
        peak = np.max(mask)
        if peak > 0:
            mask /= peak
    else:
        raise ValueError(f"Unknown mode: {mode}, choose 'gaussian' or 'block'")

    return mask, harmonic_bins


def plot_psds_with_mask(cur_psds, true_mask, file="psd_plot",
                        resamp_factor=1, cur_fund_freq=None,
                        psd_scale_down_factors=None,
                        save_fig=False, out_dir="./"):
    """Plot each PSD in ``cur_psds`` on its own row, with ``true_mask`` overlaid (dashed).

    Args:
        cur_psds: sequence of 1-D PSD arrays, one subplot row each.
        true_mask: mask drawn on every row.
        file: base name used in the figure title and, when saving, the file
            name ``<out_dir>/<file>.png``.
        resamp_factor: resample factor shown in the title.
        cur_fund_freq: optional fundamental frequency appended to the title (2 decimals).
        psd_scale_down_factors: optional per-row factors shown in the subplot titles.
        save_fig: if True, save the figure to disk and close it; otherwise ``plt.show()``.
        out_dir: directory for saved figures (created if missing).
    """
    n_psds = len(cur_psds)
    # squeeze=False always yields a 2-D axes grid, so the single-PSD case needs no special-casing.
    fig, axes = plt.subplots(n_psds, 1, figsize=(8, 1.5 * n_psds), sharex=False, squeeze=False)
    axes = axes[:, 0]

    title_str = f"{file}_{resamp_factor}"
    if cur_fund_freq is not None:
        title_str += f"_{cur_fund_freq:.2f}"

    fig.suptitle(title_str)

    for idx, psd in enumerate(cur_psds):
        axes[idx].plot(psd, label="PSD")
        axes[idx].plot(true_mask, label="True Mask", linestyle="--")

        if psd_scale_down_factors is not None:
            axes[idx].set_title(f"PSD {idx + 1} (factor {psd_scale_down_factors[idx]})")
        else:
            axes[idx].set_title(f"PSD {idx + 1}")

        axes[idx].legend(loc="upper right")

    plt.tight_layout()

    if save_fig:
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, f"{file}.png"))
        # Close saved figures: this may be called once per augmented sample, and
        # open figures otherwise accumulate (memory growth, "too many figures" warnings).
        plt.close(fig)
    else:
        plt.show()

import numpy as np

def comb_scores_linear(psd, fmin=10, fmax=None, Kmin=3, weights="inv"):
    # psd: shape [F], 1 Hz spacing assumed
    F = len(psd)
    if fmax is None: fmax = F - 1
    # # normalize PSD to tame huge lines
    # P = psd / (np.median(psd[(fmin//2):fmax]) + 1e-12)
    P = psd

    def get_w(K):
        k = np.arange(1, K+1, dtype=float)
        if weights == "inv": w = 1.0 / k
        elif weights == "inv2": w = 1.0 / (k**2)
        else: w = np.ones_like(k)
        return w

    H = np.zeros(F)
    for F0 in range(fmin, fmax+1):
        K = fmax // F0
        if K < Kmin:
            continue
        w = get_w(K)
        vals = []
        for k in range(1, K+1):
            fk = k * F0
            lo = int(np.floor(fk))
            hi = min(lo + 1, fmax)
            t = fk - lo
            pk = (1 - t) * P[lo] + t * P[hi]
            vals.append(pk)
        w = w / (w.sum() + 1e-12)
        H[F0] = np.dot(w, np.array(vals))
    return H  # maximal H index is your F0 estimate



### Training
2. Compute the HPS over the different Welch scales.
3. Store the Welch PSDs, filenames, and ground-truth masks for PPSP.

1. Accept a list of training folders.
    - A variable selects the PSD length.
    - For each file:
      - Generate augmented files.
        - For each augmented file:
          - Compute the Welch PSD.
          - Compute the summed version of the scaled-down Welch PSDs.


In [None]:
from collections import OrderedDict

root_folder_pth = "../data/"
# Earlier folder selections, kept for reference:
# train_folders_lst = [("fan5_3spd_augment", 41.97), ("fan3_3spd_augment", 96.7), ("bldc_1_augment", 94.98),
#                      ("bldc_2_augment", 80.25)("bldc_5_augment",45.5),("fan5_1spd_augment",27.82)("bldc_3_augment",55.55),]
# train_folders_lst = [("fan5_2spd_augment", 35.55),("bldc_2_augment", 80.25),("fan5_3spd_augment", 41.97),("bldc_1_augment", 94.98)]
# (folder name, fundamental frequency); for "bldc*" folders the frequency is
# overridden per file from the RPM field of the filename (RPM / 60).
train_folders_lst = [("fan5_3spd_augment", 41.97), ("bldc_2_augment", 80.25)]


psd_length = 1024          # number of PSD bins kept per row
fs = 44100                 # original sample rate passed to librosa.resample / signal.welch
ss_num_chunks, welch_num_chunks = 3, 2
psd_scale_down_factors = [1, 2, 3, 'sum']
plot_fig, save_fig = True, False
pattern = re.compile(r"(\d+)cm")   # distance token "<N>cm" in filenames
block_width = 3
augmented_freq_gen_lst = []
train_val_dict = OrderedDict()
save_train_data = False
mode = "block"
if save_train_data:
    for train_folder, fund_freq in train_folders_lst:
        print(train_folder, fund_freq)

        cur_folder_pth = os.path.join(root_folder_pth, train_folder)
        if not os.path.exists(cur_folder_pth):
            print(f"{cur_folder_pth} does not exist")
            continue

        data_folder_pth = os.path.join(cur_folder_pth, "orig_files")
        files_lst = os.listdir(data_folder_pth)

        distances_lst = []
        train_x, train_y = [], []

        for file in files_lst:
            if file in [".DS_Store"]:
                continue
            print(file)

            if train_folder.split("_")[0] in ["bldc"]:
                # bldc filenames carry the RPM as the third '_'-separated field.
                cur_file_rpm = int(file.split('_')[2])
                fund_freq = cur_file_rpm / 60
            cur_file_pth = os.path.join(data_folder_pth, file)
            cur_signal = np.sum(read_files(cur_file_pth), axis=0)

            # Distance parsed once per file (0 when the name has no "<N>cm" token);
            # one entry is recorded per augmented sample below so it stays aligned with train_x.
            match = pattern.search(file)
            cm_value = int(match.group(1)) if match else 0

            resampl_factor_lst, freq_generated = generate_resample_factors_v2(fund_freq, 22, 250, 3, 4)
            # resampl_factor_lst, freq_generated = [1],[fund_freq]

            for resamp_factor in resampl_factor_lst:
                distances_lst.append(cm_value)

                cur_psds = []
                # The resampled signal is analysed at the original fs, so every
                # frequency (and the fundamental) is divided by the resample factor.
                cur_fund_freq = fund_freq / resamp_factor

                true_mask, cur_fund_freq_lst = make_true_mask_v2(cur_fund_freq, psd_length=psd_length,
                                                                 sigma=block_width, n_harmonics=100,
                                                                 mode=mode, band_width=block_width)
                augmented_freq_gen_lst.append(cur_fund_freq_lst[0])

                if 0.0 < resamp_factor < 3.0:
                    resample_sig = librosa.resample(cur_signal, orig_sr=fs, target_sr=int(fs * resamp_factor))

                    resamp_num_pnts = len(resample_sig) // welch_num_chunks

                    # nfft = fs gives 1 Hz bins, matching the mask built above.
                    # NOTE(review): for factors >= 2 nfft = fs * factor gives 1/factor Hz bins,
                    # which would not line up with the mask; generate_resample_factors_v2
                    # currently only yields factors < 2.0, so that branch is unused.
                    welch_nfft = int(fs * resamp_factor) if resamp_factor >= 2.0 else fs
                    resamp_freq_ss, resamp_Pxx_ss = signal.welch(resample_sig, fs, nperseg=resamp_num_pnts,
                                                                 nfft=welch_nfft)

                    resamp_log_Pxx_ss = np.log(resamp_Pxx_ss)
                    # Row 0: full-resolution log-PSD, min-max normalised.
                    cur_psds = [min_max_norm(resamp_log_Pxx_ss[:psd_length])]

                    # Rows 1..: average-pooled (scaled-down) copies for each integer factor other than 1.
                    for scale_factor in psd_scale_down_factors:
                        if scale_factor not in [1, 'sum']:
                            cur_scaled_psd = scale_psd(resamp_log_Pxx_ss, psd_length, scale_factor,
                                                       method="average")
                            cur_psds.append(min_max_norm(cur_scaled_psd))

                    # Last row (the 'sum' slot): element-wise MEDIAN of the rows above.
                    stacked = np.stack(cur_psds, axis=0)
                    max_psd = np.median(stacked, axis=0)
                    cur_psds.append(min_max_norm(max_psd))

                cur_x, cur_y = np.array(cur_psds), true_mask
                train_x.append(cur_x)
                train_y.append(cur_y)
                # plot_psds_with_mask(cur_psds,true_mask,file=f"{file}_{resamp_factor}.png", resamp_factor=resamp_factor,cur_fund_freq=cur_fund_freq,psd_scale_down_factors=psd_scale_down_factors,save_fig=save_fig,out_dir=f"./conv2d_data/train_plots/")

        train_val_dict[train_folder] = [train_x, train_y, distances_lst, psd_scale_down_factors]

    # Histogram of the first (rounded) harmonic bin across all augmented samples.
    df = pd.DataFrame({
        'augmented_frequency': augmented_freq_gen_lst,
    })

    #### Count occurrences of each frequency
    df_counts = (
        df['augmented_frequency']
        .value_counts()
        .rename_axis("augmented_frequency")
        .reset_index(name="count")
        .sort_values("augmented_frequency")
    )

    # Save to CSV
    df_counts.to_csv("augmented_frequency_lst.csv", index=False)

    # Make sure the output directory exists before writing the pickle.
    os.makedirs("./conv2d_data", exist_ok=True)
    with open(f"./conv2d_data/conv2d_psd_scaled_down_1up_{mode}.pkl", "wb") as f:
        pickle.dump(train_val_dict, f)

    print(f"Saved dictionary to conv2d_psd_scaled_down_1up_{mode}.pkl")


In [None]:
# augmented_freq_csv = pd.read_csv(f"augmented_frequency_lst.csv", header=0, index_col=None)
# print()

In [None]:
from sklearn.model_selection import train_test_split


def process_loaded_dict(loaded_dict, row_indices=None, col_size=512, val_ratio=0.001,
                        random_state=42, output_format="channels_first"):
    """Turn the per-folder training dictionary into model-ready train/val splits.

    Args:
        loaded_dict: mapping ``folder -> [train_x, train_y, distances, ...]``, as
            pickled by the training-data cell. Extra trailing entries are ignored.
        row_indices: rows (PSD variants) of each sample to keep; defaults to ``[0, 4]``.
            NOTE(review): with ``psd_scale_down_factors = [1, 2, 3, 'sum']`` the
            training cell stores 4 rows per sample (indices 0-3), so the default index 4
            would be out of range. The callers in this notebook pass ``row_indices``
            explicitly.
        col_size: number of leading frequency bins kept from each row.
        val_ratio: fraction passed to ``train_test_split`` as ``test_size``.
        random_state: seed for ``train_test_split``.
        output_format: ``"channels_first"`` gives X of shape (N, R, 1, col_size);
            ``"channels_last"`` gives (N, 1, R, col_size), where R = len(row_indices).

    Returns:
        dict ``folder -> (X_train, X_val, y_train, y_val, dist_train, dist_val)``.

    Raises:
        ValueError: for an unknown ``output_format``.
    """
    # None sentinel instead of a mutable list default.
    if row_indices is None:
        row_indices = [0, 4]
    if output_format not in ("channels_first", "channels_last"):
        raise ValueError("output_format must be 'channels_first' or 'channels_last'")

    processed_data = {}
    for key, (train_x, train_y, distances, *_) in loaded_dict.items():
        train_x = np.array(train_x)
        train_y = np.array(train_y)
        distances = np.array(distances)

        ### sort by distance (the split below still shuffles)
        order = np.argsort(distances)
        train_x, train_y, distances = train_x[order], train_y[order], distances[order]

        ### select rows and columns size
        X = train_x[:, list(row_indices), :col_size]

        ### insert the singleton axis required by the chosen layout
        if output_format == "channels_first":
            X = X[:, :, np.newaxis, :]
        else:
            X = X[:, np.newaxis, :, :]

        ### split into train/val sets
        processed_data[key] = tuple(train_test_split(
            X, train_y, distances, test_size=val_ratio, random_state=random_state
        ))

    return processed_data


import itertools

# Row indices of the stored PSD variants that can be combined as input channels.
nums = [0, 1, 2, 3]

# --- List of tuples (default from itertools.combinations) ---
# Every combination of 2..len(nums) row indices. NOTE: currently only feeds the
# commented-out filter below; all_combs_lists is hard-coded instead.
all_combs_tuples = []
for r in range(2, len(nums) + 1):
    all_combs_tuples.extend(itertools.combinations(nums, r))

# all_combs_lists = [list(c) for c in all_combs_tuples if 0 in c]
# Combinations actually used: row 0 alone, row 3 alone, and all four rows.
# In the training-data cell, row 0 is the full-resolution PSD and row 3 is the median row.
all_combs_lists = [[0], [3], [0, 1, 2, 3]]
print(all_combs_lists)

# Load the dictionary written by the training-data cell; the file name depends on the global `mode`.
# NOTE(review): pickle.load can execute code embedded in the file; only load pickles this notebook produced.
with open(f"./conv2d_data/conv2d_psd_scaled_down_1up_{mode}.pkl", "rb") as f:
    loaded_dict = pickle.load(f)

# ### [N, 4, 1, 512] - channels first
# processed_channels_first = process_loaded_dict(
#     loaded_dict,
#     row_indices=[0, 4],
#     col_size=1024,
#     val_ratio=0.2,
#     output_format="channels_first"
# )


In [None]:

### X_train, X_val, y_train, y_val, dist_train, dist_val = processed['fan3_3spd_augment']
# Build one pooled dataset per channel combination in all_combs_lists. Each entry of
# aggregated_combs_data_lst is (X_train, X_val, y_train, y_val, dist_train, dist_val),
# concatenated over every folder in loaded_dict. After the loop, the globals
# X_train etc. hold the LAST combination's data.
aggregated_combs_data_lst = []
for comb in all_combs_lists:

    ### [N, 1, len(comb), 1024] - channels last (one row per selected PSD variant, first 1024 bins)
    processed_channels_last = process_loaded_dict(
        loaded_dict,
        row_indices=comb,
        col_size=1024,
        val_ratio=0.2,
        output_format="channels_last"
    )
    all_X_train, all_X_val = [], []
    all_y_train, all_y_val = [], []
    all_dist_train, all_dist_val = [], []

    # Collect every folder's split so the folders can be pooled below.
    for key, (X_train, X_val, y_train, y_val, dist_train, dist_val) in processed_channels_last.items():
        all_X_train.append(X_train)
        all_X_val.append(X_val)
        all_y_train.append(y_train)
        all_y_val.append(y_val)
        all_dist_train.append(dist_train)
        all_dist_val.append(dist_val)

    ### Concatenate along first axis
    # (Earlier variant without the squeeze, kept for reference.)
    # X_train = np.concatenate(all_X_train, axis=0)
    # X_val   = np.concatenate(all_X_val, axis=0)
    # y_train = np.concatenate(all_y_train, axis=0)
    # y_val   = np.concatenate(all_y_val, axis=0)
    # dist_train = np.concatenate(all_dist_train, axis=0)
    # dist_val   = np.concatenate(all_dist_val, axis=0)

    # squeeze(1) drops the singleton axis of the channels-last layout: X becomes [N, len(comb), 1024].
    X_train = np.concatenate(all_X_train, axis=0).squeeze(1)
    X_val = np.concatenate(all_X_val, axis=0).squeeze(1)
    y_train = np.concatenate(all_y_train, axis=0)
    y_val = np.concatenate(all_y_val, axis=0)
    dist_train = np.concatenate(all_dist_train, axis=0)
    dist_val = np.concatenate(all_dist_val, axis=0)

    aggregated_combs_data_lst.append((X_train, X_val, y_train, y_val, dist_train, dist_val))

print(f"Datasets length={len(aggregated_combs_data_lst)}, combination_length={len(all_combs_lists)}")


In [None]:
from magtach.op_codes.nn_functions import DiceLoss, OverlapDiceLoss
from magtach.op_codes.nn_models import FPN_2D, FPN_1D, FPN_2D_regression, FPN_1D_regression, FPN_2_1up, FPN_2, \
    FPN_2_1up, FCN1D

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from magtach.op_codes.mtl_models import MTL_1, MTL_2
from magtach.op_codes.ppsp_mtl import FPN_2_mtl
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn

def plot_region_masks(model_1_output, model_2_output, index=0):
    """
    Plot the region-mask targets of one sample in two stacked panels.

    model_1_output: indexable per sample as [region_mask_f0, gt_y]
        (e.g. the model1 targets from remake_targets).
    model_2_output: indexable per sample as [mask_f0, mask_2f0, mask_3f0, mask_4f0, ...].
    index: which sample to visualize.

    Top panel: fundamental region vs. the full ground-truth mask. Bottom panel:
    one curve per harmonic region, labelled 'f0', '2f0', '3f0', ...
    """
    # unpack
    f0_mask, gt_y = model_1_output[index]
    masks = model_2_output[index]  # list of harmonic masks
    x = np.arange(len(f0_mask))

    fig, (ax_top, ax_bottom) = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

    # ---- Row 1: fundamental vs ground truth ----
    ax_top.plot(x, gt_y, label='Ground Truth (all harmonics)', color='gray', linewidth=2, alpha=0.7)
    ax_top.plot(x, f0_mask, label='Fundamental region (Head1)', color='blue', linewidth=2)
    ax_top.set_title('Model 1 Output — Fundamental vs Ground Truth')
    ax_top.legend()
    ax_top.grid(True, linestyle='--', alpha=0.4)

    # ---- Row 2: individual harmonic regions ----
    colors = ['r', 'g', 'b', 'orange']
    for i, mask in enumerate(masks):
        # Labels are generated ('f0', '2f0', '3f0', ...) so more than four masks
        # no longer raise IndexError; colors cycle.
        label = 'f0' if i == 0 else f'{i + 1}f0'
        ax_bottom.plot(x, mask, color=colors[i % len(colors)], label=label, linewidth=2)

    ax_bottom.set_title('Model 2 Output — Individual Harmonic Regions')
    ax_bottom.legend()
    ax_bottom.grid(True, linestyle='--', alpha=0.4)
    ax_bottom.set_xlabel('Frequency Bin')

    plt.tight_layout()
    plt.show()

def make_gaussian(zero_mask, center_idx, psd_length=1024, sigma=block_width):
    """Overlay a unit-peak Gaussian centred at ``center_idx`` onto ``zero_mask``.

    The Gaussian (standard deviation ``sigma``, evaluated on ``psd_length`` bins)
    is rescaled so its maximum is 1. It is then combined with ``zero_mask`` by
    element-wise max, and the result is returned as float32.
    """
    offsets = (np.arange(psd_length) - center_idx) / sigma
    peak_norm = 1 / np.sqrt(2 * np.pi * sigma ** 2)
    bump = peak_norm * np.exp(-0.5 * offsets ** 2)
    bump = bump / bump.max()
    return np.maximum(zero_mask, bump).astype(np.float32)


def make_block_mask(zero_mask, center_idx, psd_length=1024, block_width=block_width):
    """Return a float32 copy of ``zero_mask`` with 1.0 written over
    ``[center_idx - block_width, center_idx + block_width]``, clipped to
    ``[0, psd_length)``. The input mask is not modified."""
    window = slice(max(0, center_idx - block_width),
                   min(psd_length, center_idx + block_width + 1))
    result = zero_mask.copy()
    result[window] = 1.0
    return result.astype(np.float32)


def remake_targets(batch_y, total_num_outputs=6, mode=mode, psd_length=1024, block_width=block_width):
    """
    Build multi-task targets for the fundamental and synthetic harmonics (2F0, 3F0, 4F0).

    Args:
        batch_y: CPU tensor of ground-truth masks, one row of ``psd_length`` bins per sample.
        total_num_outputs: unused; kept for call compatibility.
        mode: 'gaussian' or 'block', the shape of the generated region masks.
            The default is the global ``mode`` at the time the function is defined.
        psd_length: number of bins per mask.
        block_width: half-width of block regions, or sigma of Gaussian regions.

    The fundamental bin is taken as the first bin equal to 1 plus ``block_width``.
    This assumes block-style input masks whose first block is not clipped at bin 0.
    NOTE(review): for Gaussian-style inputs the bin equal to 1 is the peak itself,
    so the offset would be wrong. Harmonic bins beyond the window are clamped to
    the last bin.

    Returns:
        (model1_targets, model2_targets), float tensors on the global ``device``:
        model1_targets is (B, 2, psd_length) = [F0 region, original mask];
        model2_targets is (B, 4, psd_length) = [F0, 2F0, 3F0, 4F0 regions].
        B always equals the batch size. Samples with no labelled bin get all-zero
        region masks instead of being dropped, so targets stay row-aligned with
        the model inputs.

    Raises:
        ValueError: if ``mode`` is not 'gaussian' or 'block'.
    """
    if mode == "gaussian":
        make_region = make_gaussian
    elif mode == "block":
        make_region = make_block_mask
    else:
        raise ValueError("mode must be 'gaussian' or 'block'")

    batch_y_np = batch_y.numpy()
    model1_list, model2_list = [], []

    for cur_mask in batch_y_np:
        zero_mask = np.zeros(psd_length, dtype=np.float32)
        f0_indices = np.where(cur_mask == 1)[0]

        if len(f0_indices) == 0:
            # Keep this sample (with empty regions) so row b of the targets still matches row b of batch_x.
            region_masks = [zero_mask] * 4
        else:
            f0_idx = int(f0_indices[0]) + block_width
            # --- synthetic harmonic indices F0, 2F0, 3F0, 4F0 (clamped to the last bin) ---
            harmonic_indices = [min(f0_idx * i, psd_length - 1) for i in range(1, 5)]
            # Fourth positional argument is sigma (gaussian) or block_width (block).
            region_masks = [make_region(zero_mask, idx, psd_length, block_width) for idx in harmonic_indices]

        # --- stack for both model heads ---
        model1_list.append(np.stack([region_masks[0], cur_mask], axis=0))
        model2_list.append(np.stack(region_masks, axis=0))

    model1_targets = torch.from_numpy(np.stack(model1_list)).float().to(device)
    model2_targets = torch.from_numpy(np.stack(model2_list)).float().to(device)
    return model1_targets, model2_targets


# Global compute device, read at call time by remake_targets and train_fpn2d_model.
# Hard-coded to CPU; the commented line auto-selects MPS/CUDA when available.
# device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
print(f"Using device: {device}")


### Training function
def train_fpn2d_model(X_train, y_train, X_val, y_val, num_epochs=50, batch_size=100, learning_rate=0.001, model=None,
                      model_name="", train_model_flag=False, plot_targets=False):
    """
    Optionally train ``model`` on the fundamental-region target and return it.

    The inputs are converted to float tensors and batched. The model is always
    moved to the global ``device``. When ``train_model_flag`` is True:

    - The model is trained with Adam (weight decay 1e-3) against
      ``OverlapDiceLoss``, comparing ``model(batch_x).squeeze(1)`` with the F0
      region mask from ``remake_targets``.
    - The weights with the lowest validation loss are saved to ``model_name``
      and restored into the model at the end.
    - A loss-history plot is written to ``training_history_<model_name stem>.png``.

    Otherwise the model is returned untrained.

    Args:
        X_train, X_val: input arrays. The notebook passes (N, channels, psd_len)
            arrays (assumed; confirm for other callers).
        y_train, y_val: ground-truth harmonic masks, one row per sample.
        num_epochs, batch_size, learning_rate: optimisation settings.
        model: the ``nn.Module`` to train (required).
        model_name: path the best state_dict is saved to.
        train_model_flag: run the training loop when True.
        plot_targets: when True, plot the generated region targets for every
            training batch. This is a debug aid; ``plt.show`` may block in
            interactive backends.

    Returns:
        The (possibly trained) model.

    Raises:
        ValueError: if ``model`` is None.
    """
    if model is None:
        raise ValueError("model must be provided")

    ### Convert to tensors
    X_train_tensor = torch.FloatTensor(X_train)
    y_train_tensor = torch.FloatTensor(y_train)
    X_val_tensor = torch.FloatTensor(X_val)
    y_val_tensor = torch.FloatTensor(y_val)

    print(f"Training data shapes:")
    print(f"X_train_tensor: {X_train_tensor.shape}")  # [num_samples, selected_rows, psd_len]
    print(f"y_train_tensor: {y_train_tensor.shape}")  # [num_samples, psd_len]
    print(f"X_val_tensor: {X_val_tensor.shape}")  # [num_samples, selected_rows, psd_len]
    print(f"y_val_tensor: {y_val_tensor.shape}")  # [num_samples, psd_len]

    ### Create datasets and dataloaders
    train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val_tensor, y_val_tensor), batch_size=batch_size, shuffle=False)

    model.to(device)

    if not train_model_flag:
        # Inference-only path; to reuse saved weights:
        # model.load_state_dict(torch.load(f"{model_name}", map_location=torch.device(device)))
        # model.eval()
        return model

    # criterion = DiceLoss()
    criterion = OverlapDiceLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-3)

    def fundamental_loss(batch_x, batch_y, show_targets=False):
        # Loss of the single-output (fundamental-region) head on one batch.
        # For multi-head MTL models, combine per-head losses against
        # model1_targets[:, i, :] / model2_targets[:, i, :] here instead.
        model1_targets, model2_targets = remake_targets(batch_y)
        if show_targets:
            plot_region_masks(model1_targets, model2_targets)
        outputs = model(batch_x.to(device))
        return criterion(outputs.squeeze(1), model1_targets[:, 0, :])

    train_losses, val_losses = [], []
    best_val_loss = float('inf')
    best_model_weights = None

    ### Training loop
    for epoch in range(num_epochs):
        ### Training phase
        model.train()
        train_loss = 0.0
        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            loss = fundamental_loss(batch_x, batch_y, show_targets=plot_targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * batch_x.size(0)

        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)

        ### Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                loss = fundamental_loss(batch_x, batch_y)
                val_loss += loss.item() * batch_x.size(0)

        val_loss /= len(val_loader.dataset)
        val_losses.append(val_loss)

        ### Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # deepcopy: state_dict() references the live parameter tensors, so a
            # shallow copy would keep changing as training continues.
            best_model_weights = copy.deepcopy(model.state_dict())
            torch.save(best_model_weights, f'{model_name}')

        if (epoch + 1) % 50 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], '
                  f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
                  f'Best Val: {best_val_loss:.4f}')

    # Name the history plot after the model file stem so every mode / channel
    # combination gets its own image.
    history_tag = os.path.splitext(os.path.basename(model_name))[0]
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')  # OverlapDiceLoss
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'training_history_{history_tag}.png', dpi=300, bbox_inches='tight')
    plt.close()

    ### Restore the best weights (none exist if the validation loss was never finite)
    if best_model_weights is not None:
        model.load_state_dict(best_model_weights)
    model.eval()
    print(f"Best validation loss: {best_val_loss:.6f}")
    return model


for ind, (X_train, X_val, y_train, y_val, dist_train, dist_val) in enumerate(aggregated_combs_data_lst):
    # Alternative architectures tried (swap in as needed):
    # model = FPN_2(in_channels=X_train.shape[1], out_channels=32)
    # model = FPN_2_1up(in_channels=X_train.shape[1], encoded_channels=32, output_size=1024)
    # model = FCN1D(in_channels=X_train.shape[1], num_classes=1024, kernel_size=7)
    # model1 = MTL_1(in_ch=X_train.shape[1], seq_len=1024, base=32, hidden=512, dilations=(1, 1, 1))
    # model1 = MTL_2(in_ch=X_train.shape[1], seq_len=1024, base=32, hidden=512, num_heads=4,dilations=(1, 1, 1))
    # model = FPN_2D(in_channels=1,base_channels=32)
    # One model per channel combination; X_train is (N, channels, bins), so shape[1] is the channel count.
    model1 = FPN_2_mtl(in_channels=X_train.shape[1], out_channels=32)
    # Log the same name that is used for the saved checkpoint.
    model_name = f"best_fpn2_1up_model_{mode}_{all_combs_lists[ind]}.pth"
    print(f"Training model: {model_name}")
    ### Train the model (with train_model_flag=False the model is returned untrained)
    trained_model = train_fpn2d_model(
        X_train, y_train, X_val, y_val,
        num_epochs=200,
        batch_size=100,
        learning_rate=0.001,
        model=model1,
        model_name=model_name,
        train_model_flag=False
    )


In [None]:

# model = FPN_1D(in_channels=1, base_channels=16, output_size=y_train.shape[1])
# model = FPN_2D(in_channels=1, base_channels=16, output_size=y_train.shape[1])

# model = FPN_1D_regression(in_channels=1, base_channels=32, input_size=psd_length)
# model = FPN_2D_regression(in_channels=1, base_channels=32,input_height=X_train.shape[2], input_width=psd_length)
# criterion = torch.nn.MSELoss()

# model = Conv2d_1up(in_channels=1,nb_classes=y_train.shape[1])

# model = FPN_2_1up(in_channels=X_train.shape[1])
# model = FPN_2()
# model.to(device)
# criterion = DiceLoss()
# optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-3)

## Testing


In [None]:
%load_ext autoreload
%autoreload 1
# %matplotlib widget
import matplotlib
import librosa

matplotlib.use('QT5Agg')
import numpy as np
import pandas as pd
import os, pickle, re
from scipy import signal
import matplotlib.pyplot as plt
from magtach.op_codes.preprocess_functions import read_files, min_max_norm

In [None]:
# is_model_training = not(trained_model.training)
# print(F"Model in testing mode: {is_model_training}")
# Configuration for building the test set (mirrors the training-data cell's settings).
root_folder_pth = f"../data/"
### testing folders: bldc_5, bldc_2 and bldc_6 as (folder name, fallback fundamental frequency)
### the fallback is used when the RPM cannot be parsed from a filename
from collections import OrderedDict
### ("bldc_5", 45.5),
test_folders_lst = [ ("bldc_5", 45.5), ("bldc_2", 80.25), ("bldc_6", 131.81)]

psd_length = 1024
fs = 44100
ss_num_chunks, welch_num_chunks = 3, 2
psd_scale_down_factors = [1, 2, 3, 'sum']
plot_fig, save_fig = True, False
### distance token "<N>cm" in filenames
pattern = re.compile(r"(\d+)cm")

test_dict = OrderedDict()
# Set to True to rebuild the test set from the raw files in <folder>/testing.
save_test_data = False
# NOTE(review): the training data was built with mode = "block"; confirm the test masks should be "gaussian".
# This also rebinds the global `mode` for any cell re-run afterwards that reads it.
mode = "gaussian"
if save_test_data:
    for _, (test_folder, fund_freq) in enumerate(test_folders_lst):
        print(test_folder, fund_freq)

        cur_folder_pth = os.path.join(root_folder_pth, test_folder)
        if not os.path.exists(cur_folder_pth):
            print(f"{cur_folder_pth} does not exist")
            continue
        else:
            data_folder_pth = os.path.join(cur_folder_pth, "testing")
            files_lst = os.listdir(data_folder_pth)

            orig_signal_lst, file_names_lst = [], []
            distances_lst, fund_freq_lst = [], []
            test_x, test_y = [], []

            for file in files_lst:

                if file not in [".DS_Store"]:
                    print(file)

                    cur_file_pth = os.path.join(data_folder_pth, file)
                    cur_signal = np.sum(read_files(cur_file_pth), axis=0)

                    orig_signal_lst.append(cur_signal)
                    file_names_lst.append(file)

                    try:
                        match = pattern.search(file)
                        cm_value = int(match.group(1))
                        distances_lst.append(cm_value)
                    except:
                        distances_lst.append(0)

                    try:
                        cur_fund_freq = round(int(file.split("_")[2]) / 60, 4)
                        fund_freq_lst.append(cur_fund_freq)
                    except:
                        cur_fund_freq = fund_freq
                        fund_freq_lst.append(cur_fund_freq)

                    num_pnts = len(cur_signal) // welch_num_chunks

                    fxx, cur_pxx = signal.welch(cur_signal, fs, nperseg=num_pnts, nfft=fs)
                    log_pxx = np.log(cur_pxx)

                    # true_mask, cur_fund_freq_lst= make_true_mask_v2(cur_fund_freq, psd_length=psd_length, sigma=4.0, n_harmonics=100,
                    #                               mode=mode, band_width=4)
                    true_mask, cur_fund_freq_lst = make_true_mask_v2(cur_fund_freq, psd_length=psd_length, sigma=3.0, n_harmonics=100,
                                                      mode=mode, band_width=3)
                    cur_psds = []

                    for scale_factor in psd_scale_down_factors:
                        if scale_factor not in ['sum']:
                            cur_scaled_psd = scale_psd(log_pxx, psd_length, scale_factor, method="average")
                            # cur_psds.append(cur_scaled_psd)
                            cur_psds.append(min_max_norm(cur_scaled_psd))

                    stacked = np.stack(cur_psds, axis=0)
                    max_psd = np.median(stacked, axis=0)
                    cur_psds.append(min_max_norm(max_psd))

                    cur_x, cur_y = np.array(cur_psds), true_mask

                    # plot_psds_with_mask(cur_psds,true_mask,file=f"{file}.png", resamp_factor=1,cur_fund_freq=cur_fund_freq,psd_scale_down_factors=psd_scale_down_factors,save_fig=save_fig,out_dir=f"./")

                    test_x.append(cur_x)
                    test_y.append(cur_y)

            test_dict[test_folder] = [test_x, test_y, fund_freq_lst, distances_lst, file_names_lst, orig_signal_lst]

    with open(f"./conv2d_data/conv2d_psd_scaled_down_1up_{mode}_test.pkl", "wb") as f:
        pickle.dump(test_dict, f)

    print(f"Saved dictionary to conv2d_psd_scaled_down_1up_{mode}_test.pkl")



In [None]:
def process_test_dict(loaded_dict, row_indices=None, col_size=512, output_format="channels_first"):
    """Turn the pickled per-folder test lists into model-ready numpy arrays.

    Parameters
    ----------
    loaded_dict : dict
        Maps a key (test-folder name) to a 6-item sequence
        ``(test_x, test_y, fund_freq_lst, distances_lst, file_names_lst, orig_signal_lst)``.
        ``test_x`` must stack into a 3-D array (samples, rows, columns).
    row_indices : list of int, optional
        Rows of each sample's ``test_x`` entry to keep; defaults to ``[0, 4]``.
        NOTE(review): the data-generation cell in this notebook produces 4 rows
        (indices 0-3), so the default would be out of range for that data; callers
        here pass ``row_indices`` explicitly.
    col_size : int
        Number of leading columns kept from each selected row.
    output_format : {"channels_first", "channels_last"}
        ``"channels_first"`` -> X has shape (N, len(row_indices), 1, col_size);
        ``"channels_last"``  -> X has shape (N, 1, len(row_indices), col_size).

    Returns
    -------
    dict
        Same keys as ``loaded_dict``; each value is
        ``(X, y, fund_freq_lst, distances_lst, file_names_lst, orig_signal_lst)``.
        Everything except ``file_names_lst`` is converted to a numpy array. Sample
        order is preserved (no sorting, no train/val split).

    Raises
    ------
    ValueError
        If ``output_format`` is not one of the two supported values.
    """
    # None sentinel instead of a shared mutable list default.
    if row_indices is None:
        row_indices = [0, 4]
    if output_format not in ("channels_first", "channels_last"):
        raise ValueError("output_format must be 'channels_first' or 'channels_last'")

    def _as_signal_array(signals):
        # Raw recordings can differ in length; NumPy >= 1.24 refuses to build a
        # ragged ndarray, so fall back to a 1-D object array (signals[i] still works).
        try:
            return np.array(signals)
        except ValueError:
            out = np.empty(len(signals), dtype=object)
            for i, sig in enumerate(signals):
                out[i] = sig
            return out

    processed_data = {}

    for key, (test_x, test_y, fund_freq_lst, distances_lst, file_names_lst, orig_signal_lst) in loaded_dict.items():
        test_x = np.array(test_x)
        test_y = np.array(test_y)
        distances_lst = np.array(distances_lst)
        fund_freq_lst = np.array(fund_freq_lst)
        orig_signal_lst = _as_signal_array(orig_signal_lst)

        # Select rows and truncate columns (original sample order preserved).
        X = test_x[:, row_indices, :col_size]
        y = test_y

        if output_format == "channels_first":
            X = X[:, :, np.newaxis, :]  # [N, num_channels, 1, col_size]
        else:
            X = X[:, np.newaxis, :, :]  # [N, 1, num_channels, col_size]

        processed_data[key] = (X, y, fund_freq_lst, distances_lst, file_names_lst, orig_signal_lst)

    return processed_data

In [None]:
import fuzzy_logic_functions as fuzzy_funcs
import warnings, pickle

# Compare the PPSP + fundamental-head model (trained_model2) against the 1up/"crepe"
# model (trained_model3) on bldc_5 / bldc_6 test samples, one figure per sample.
warnings.filterwarnings("ignore")
mode = "gaussian"
# all_combs_lists = [[0], [3], [0, 1, 2, 3]]
all_combs_lists = [[0]]
import matplotlib
matplotlib.use('QT5Agg')

import matplotlib.pyplot as plt
import torch, os
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
# from magtach.op_codes.nn_models import FPN_2D, FPN_1D, FPN_2D_regression, FPN_1D_regression, FPN_2_1up, FPN_2, \
#     FPN_2_1up, FCN1D, FCN1D_resnet, SEResNet1D_FC
from magtach.op_codes.mtl_models import MTL_1, MTL_2, MTL_1_v1
# from magtach.op_codes.ppsp_mtl import FPN_2_mtl
# from magtach.op_codes.ppsp_mtl_6 import FPN_2_MTL_Residual
from magtach.op_codes.ppsp_mtl_7 import MTL_V1_Residual_FundGuidesHarm
# from magtach.op_codes.ppsp_mtl_8 import FPN_2, FundamentalFromHarmonics
from magtach.op_codes.ppsp_mtl_9 import  FPN_2,FPN2_withFundamental
from magtach.op_codes.ppsp_mtl_1up import FPN_2_mtl
# from magtach.op_codes.ppsp_mtl_3 import FPN_2_mtl
# from magtach.op_codes.ppsp_mtl_5 import DualBranchDenseMTL
# from magtach.op_codes.ppsp_mtl_1 import FPN_2_MTL_Dual
# from magtach.op_codes.ppsp_1up_head import PPSP_withFundamental, PPSP
import copy
# device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
fpn2_weights_pth = "../data/train_test_data/"

fpn_weights_file = "best_model_weights_fan5_fan3_bldc_fpn2"
model_name2 = "best_fpn2_1up_model_gaussian_[0].pth"

# Harmonic-network checkpoint is read once and applied twice below.
harm_state = torch.load(os.path.join(fpn2_weights_pth, fpn_weights_file), map_location="cpu")

# ---- Load base FPN_2 ----
harm_model = FPN_2()
harm_model.load_state_dict(harm_state)
harm_model.eval()

# ---- Create wrapper ----
trained_model2 = FPN2_withFundamental(copy.deepcopy(harm_model))
trained_model2.eval()

# ---- Load fundamental head weights (strict=False: checkpoint may cover only part of the wrapper — TODO confirm) ----
trained_model2.load_state_dict(torch.load(model_name2, map_location="cpu"), strict=False)

# ---- Reapply harmonic weights to ensure identical BN stats ----
trained_model2.harm_net.load_state_dict(harm_state)

trained_model3 = FPN_2_mtl(in_channels=1, out_channels=32)
model_name3 = "crepe"
trained_model3.load_state_dict(torch.load('./1up_weights/best_fpn2_1up_model_gaussian_[0].pth', map_location="cpu"))
trained_model3.eval()

print(F"ppsp+1up in testing mode: {not(trained_model2.training)}")
# print(F"crepe in testing mode: {not(trained_model3.training)}")

# Context manager closes the file handle (bare open() leaked it).
with open("../data/train_test_data/test_x_y_bldc_correct", "rb") as fh:
    files = pickle.load(fh)
# Each entry: values[0] = dict of sample lists, values[1] = folder name,
# values[2] = fundamental frequency (as used below).
test_dict_lst = files[0]

dir_path = "./conv2d_data/pred_plots/[0]/"
os.makedirs(dir_path, exist_ok=True)

for ind, values in enumerate(test_dict_lst):
    if values[1] not in ["bldc_6", "bldc_5"]:
        continue

    print(f"{values[1]}")
    test_dict, fundamental_freq = values[0], values[2]
    res_file = f"{values[1]}_results"

    for key, val in test_dict.items():
        print(key, len(val))
        cur_results = []

        for sig_ind in range(len(val)):
            # Per sample: [0] input PSD, [1] 2-D target (first row used),
            # [3] raw signal, [4] file name.
            cur_fil_name = val[sig_ind][4]
            print(cur_fil_name)

            orig_sig = np.array(val[sig_ind][3])
            cur_x, cur_y = val[sig_ind][0], val[sig_ind][1][:1, :][0]

            # Batch of one single-channel sequence: (1, 1, L)
            cur_x = torch.tensor(cur_x, dtype=torch.float32).reshape(1, 1, -1)

            # Inference only: no autograd graph needed.
            with torch.no_grad():
                crepe_prediction = trained_model3(cur_x)
                ppsp_1up_prediction = trained_model2(cur_x)

            crepe_fund = torch.sigmoid(crepe_prediction).squeeze(1).squeeze(0).detach().cpu().numpy()

            # trained_model2 returns a sequence: [0] fundamental, [1] harmonics
            ppsp_fund = torch.sigmoid(ppsp_1up_prediction[0]).squeeze(1).squeeze(0).detach().cpu().numpy()
            ppsp_harmonics = torch.sigmoid(ppsp_1up_prediction[1]).squeeze(1).squeeze(0).detach().cpu().numpy()

            # Min-max normalise to [0, 1]; the epsilon avoids division by zero on flat outputs.
            ppsp_fund = (ppsp_fund - np.min(ppsp_fund)) / (np.max(ppsp_fund) - np.min(ppsp_fund) + 1e-12)
            crepe_fund = (crepe_fund - np.min(crepe_fund)) / (np.max(crepe_fund) - np.min(crepe_fund) + 1e-12)

            plt_fil_name = f"{fundamental_freq}-{cur_fil_name}"

            fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)

            # --- Left: PPSP+1up prediction + input channels ---
            for ind_x in range(cur_x.squeeze(0).shape[0]):
                axes[0].plot(cur_x.squeeze(0)[ind_x], linewidth=1.2, alpha=0.9)
            axes[0].plot(cur_y, linewidth=1.1, label="True", color="orange", alpha=0.7)
            axes[0].plot(ppsp_fund, '--', linewidth=0.8, color="green", label="ppsp_1up", alpha=0.8)
            # Harmonics are scaled by 0.8 and floored at 0.3 for display only.
            axes[0].plot(np.maximum(0.3, ppsp_harmonics * 0.8), '--', linewidth=0.7, color="red", label="non normalized_harmonics", alpha=0.9)
            axes[0].set_title("ppsp_1up")
            axes[0].legend(loc="lower right")

            # --- Right: 1up (crepe) prediction + same inputs for reference ---
            for ind_x in range(cur_x.squeeze(0).shape[0]):
                axes[1].plot(cur_x.squeeze(0)[ind_x], linewidth=1.2, alpha=0.7)
            axes[1].plot(cur_y, linewidth=1, label="True", alpha=0.8)
            axes[1].plot(crepe_fund, '--', linewidth=0.8, color="green", label="1up", alpha=0.9)
            axes[1].set_title("1up(crepe)")
            axes[1].legend(loc="lower right")

            plt.suptitle(f"{fundamental_freq} Hz – {cur_fil_name}", fontsize=11)
            plt.tight_layout()

            # plt.savefig(os.path.join(dir_path, f"{plt_fil_name}_compare.png"), dpi=150)
            # plt.close()
            plt.show()





In [None]:
# ---- Setup for the per-combination evaluation cell: imports, config, device ----
import fuzzy_logic_functions as fuzzy_funcs
import warnings, pickle

# Silence all warnings for the rest of the session (this also hides numpy/torch warnings).
warnings.filterwarnings("ignore")
# Mask mode; used in the file name of the pickled test set loaded in this cell.
mode = "gaussian"
# all_combs_lists = [[0], [3], [0, 1, 2, 3]]
# Each entry is a list of row indices passed as process_test_dict(row_indices=...),
# i.e. which PSD rows of each sample become model input channels.
all_combs_lists = [[0],[3]]
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# from magtach.op_codes.nn_models import FPN_2D, FPN_1D, FPN_2D_regression, FPN_1D_regression, FPN_2_1up, FPN_2, \
#     FPN_2_1up, FCN1D, FCN1D_resnet, SEResNet1D_FC
# from magtach.op_codes.mtl_models import MTL_1, MTL_2, MTL_1_v1
# from magtach.op_codes.ppsp_mtl import FPN_2_mtl
# from magtach.op_codes.ppsp_mtl_6 import FPN_2_MTL_Residual
# from magtach.op_codes.ppsp_mtl_7 import MTL_V1_Residual_FundGuidesHarm
# from magtach.op_codes.ppsp_mtl_8 import FPN_2, FundamentalFromHarmonics
from magtach.op_codes.ppsp_mtl_9 import FPN_2, FPN2_withFundamental
# from magtach.op_codes.ppsp_mtl_1up import FPN_2_mtl
# from magtach.op_codes.ppsp_mtl_3 import FPN_2_mtl
# from magtach.op_codes.ppsp_mtl_5 import DualBranchDenseMTL
# from magtach.op_codes.ppsp_mtl_1 import FPN_2_MTL_Dual
# NOTE: this rebinds FPN_2_mtl to the crepe_1up variant, replacing the
# ppsp_mtl_1up version imported by an earlier cell.
from magtach.op_codes.crepe_1up import FPN_2_mtl
# from magtach.op_codes.fpn_2 import PPSP
# from magtach.op_codes.ppsp_1up_head import PPSP_withFundamental, PPSP
# from magtach.op_codes.ppsp_1up import PPSP_1up
import copy
# device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
# Evaluation is pinned to CPU.
device = 'cpu'
# fpn2_weights_pth="../data/train_test_data/"
# fpn_weights_file= "best_model_weights_fan5_fan3_bldc_fpn2"

def dice_coeff(predictions, targets, smooth=1):
    """Soft Dice coefficient between two numeric arrays.

    Both inputs are cast to float (they must support ``.astype``). The result is
    ``(2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``; ``smooth`` keeps
    the ratio defined (equal to 1) when both arrays are all zeros.
    """
    p, t = predictions.astype(float), targets.astype(float)
    overlap = np.sum(p * t)
    denominator = np.sum(p) + np.sum(t) + smooth
    return (2. * overlap + smooth) / denominator


def overlap_dice(predictions, targets, threshold=0.3, smooth=1e-6, target_threshold=0.1):
    """Return the F1 (binary Dice) score between thresholded predictions and targets.

    Parameters
    ----------
    predictions : array-like of float
        Prediction scores; positions with value ``> threshold`` count as positive.
    targets : array-like of float
        Ground-truth mask; positions with value ``> target_threshold`` count as positive.
    threshold : float
        Binarization cut-off for ``predictions``.
    smooth : float
        Unused; kept so existing keyword calls keep working.
    target_threshold : float
        Binarization cut-off for ``targets`` (previously hard-coded to 0.1).

    Returns
    -------
    float
        ``2*TP / (2*TP + FP + FN)``. When neither array has a positive position,
        returns 1.0; the plain formula would be 0/0 and yield NaN.
    """
    preds_bin = (np.asarray(predictions) > threshold).astype(float)
    targets_bin = (np.asarray(targets) > target_threshold).astype(float)

    tp = np.sum(preds_bin * targets_bin)          # predicted and true
    fp = np.sum(preds_bin * (1 - targets_bin))    # predicted outside true regions
    fn = np.sum((1 - preds_bin) * targets_bin)    # true regions missed

    denominator = 2 * tp + fp + fn
    if denominator == 0:
        # Both masks empty: perfect agreement rather than NaN.
        return 1.0
    return float(2 * tp / denominator)


def region_accuracy(y_true, y_pred, threshold=0.5):
    """Binary hit/miss score: 1.0 if the prediction covers enough of the true region.

    Positions equal to 1 are positive in both arrays. The score is 1.0 when the
    fraction of true positions also predicted is at least ``threshold``, else 0.0.
    Both empty counts as a hit; exactly one empty counts as a miss.
    """
    gt_positions = np.where(y_true == 1)[0]
    pred_positions = np.where(y_pred == 1)[0]
    n_true, n_pred = len(gt_positions), len(pred_positions)

    if not n_true or not n_pred:
        return 1.0 if n_true == n_pred == 0 else 0.0

    shared = len(np.intersect1d(gt_positions, pred_positions))
    covered_fraction = shared / n_true
    return 1.0 if covered_fraction >= threshold else 0.0


# Evaluate the PPSP + fundamental-head model on the pickled test set for each
# channel combination in all_combs_lists. One plot is drawn per bldc_6 sample, and
# per-sample F1 overlap scores are written to a CSV.
with open(f"./conv2d_data/conv2d_psd_scaled_down_1up_{mode}_test.pkl", "rb") as f:
    loaded_dict_test = pickle.load(f)

fpn2_weights_pth = "../data/train_test_data/"
fpn_weights_file = "best_model_weights_fan5_fan3_bldc_fpn2"
model_name2 = "best_fpn2_1up_model_gaussian_[0].pth"
model_name3 = "crepe"

# The models do not depend on comb_lst, so they are built once rather than once
# per combination (they are in eval mode and only used under no_grad).
harm_state = torch.load(os.path.join(fpn2_weights_pth, fpn_weights_file), map_location="cpu")

# ---- Load base FPN_2 ----
harm_model = FPN_2()
harm_model.load_state_dict(harm_state)
harm_model.eval()

# ---- Create wrapper ----
trained_model2 = FPN2_withFundamental(copy.deepcopy(harm_model))
trained_model2.eval()

# ---- Load fundamental head weights (strict=False: checkpoint may cover only part of the wrapper — TODO confirm) ----
trained_model2.load_state_dict(torch.load(model_name2, map_location="cpu"), strict=False)

# ---- Reapply harmonic weights to ensure identical BN stats ----
trained_model2.harm_net.load_state_dict(harm_state)

trained_model3 = FPN_2_mtl(in_channels=1, out_channels=32)
trained_model3.load_state_dict(torch.load('./1up_weights/best_fpn2_1up_model_gaussian_[0].pth', map_location="cpu"))
trained_model3.eval()

for comb_lst in all_combs_lists:

    dir_path = f"./conv2d_data/pred_plots/{comb_lst}/"
    os.makedirs(dir_path, exist_ok=True)

    processed_test = process_test_dict(
        loaded_dict_test,
        row_indices=comb_lst,
        col_size=1024,
        output_format="channels_last"  # or "channels_first"
    )
    # device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    for key, (X_test, y_test, fund_freq, distances, file_names, orig_sig) in processed_test.items():
        print(f"Key: {key}")
        prediction_lst = []
        freq_prediction_lst, fund_freq_lst = [], []
        if key != "bldc_6":
            continue

        for sample_ind in range(len(distances)):
            cur_fund_freq = fund_freq[sample_ind]
            cur_fil_name = file_names[sample_ind]

            # With "channels_last", X_test[i] is (1, len(comb_lst), 1024).
            cur_x = np.array(X_test[sample_ind])
            torch_x = torch.FloatTensor(cur_x).to('cpu')
            cur_y = np.array(y_test[sample_ind])

            # Inference only: no autograd graph needed.
            with torch.no_grad():
                cur_prediction = trained_model2(torch_x)

            # cur_prediction is a sequence: [0] fundamental, [1] all harmonics
            fund_pred = torch.sigmoid(cur_prediction[0]).squeeze(1).squeeze(0).detach().cpu().numpy()
            all_harmonic_pred1 = torch.sigmoid(cur_prediction[1]).squeeze(1).squeeze(0).detach().cpu().numpy()

            # Min-max normalise to [0, 1]; epsilon guards against flat outputs.
            cur_fund_prediction_norm = (fund_pred - np.min(fund_pred)) / (np.max(fund_pred) - np.min(fund_pred) + 1e-12)
            cur_harmonic_prediction_norm1 = (all_harmonic_pred1 - np.min(all_harmonic_pred1)) / (np.max(all_harmonic_pred1) - np.min(all_harmonic_pred1) + 1e-12)

            # Not actually binary: overlap_dice thresholds it (> 0.3) itself.
            binary_cur_prediction = cur_fund_prediction_norm
            binary_cur_truth = cur_y

            plt_fil_name = f"{cur_fund_freq}-{cur_fil_name}_{comb_lst}"

            plt.title(f"{cur_fund_freq} - {cur_fil_name}_{comb_lst}")
            for ind_x in range(cur_x.squeeze(0).shape[0]):
                plt.plot(cur_x.squeeze(0)[ind_x], linewidth=1.2)

            plt.plot(binary_cur_truth, linewidth=1.1, label="True", alpha=0.8)
            plt.plot(cur_fund_prediction_norm, linewidth=0.8, label=" f0", alpha=0.8)
            plt.plot(cur_harmonic_prediction_norm1, '--', linewidth=0.7, label="raw all_harmonics", alpha=0.8)
            plt.legend(loc='lower right')

            # plt.savefig(os.path.join(dir_path, f"{plt_fil_name}.png"), dpi=150)
            # plt.close()
            plt.show()

            cur_overlap_accuracy = overlap_dice(binary_cur_prediction, binary_cur_truth)

            ### fuzzy logic (disabled): would predict a frequency instead of an overlap score
            # lr_strt, lr_end, cur_peak = fuzzy_funcs.lr_prediction(binary_cur_prediction, cur_x.squeeze(0).squeeze(0)[0], cur_fund_freq, distances[sample_ind], plt_fil_name)
            # freq_pred = fuzzy_funcs.predict_freq(orig_sig[sample_ind], orig_sig[sample_ind], lr_strt, lr_end, cur_peak)
            # prediction_lst.append(freq_pred)

            prediction_lst.append(cur_overlap_accuracy)

        # Mean F1 overlap over the folder's samples (NaN when there are no samples).
        accuracy = sum(prediction_lst) / len(prediction_lst) if prediction_lst else float("nan")
        cur_gtruth = np.array(fund_freq) * 60
        # NOTE(review): prediction_lst holds F1 overlap scores, not frequencies, so
        # cur_predict / abs_err / mean_err are NOT frequency errors with the current
        # pipeline; they only make sense with the disabled fuzzy-logic path above.
        cur_predict = np.array(prediction_lst) * 60
        abs_err = np.abs(cur_gtruth - cur_predict)
        mean_err = np.mean((abs_err / cur_gtruth) * 100)

        df = pd.DataFrame({
            'fundamental_frequency': list(fund_freq),
            'predictions_array': prediction_lst,
            'distance': list(distances),
            'file_name': file_names
        })
        df.to_csv(os.path.join(dir_path, f"{key}_{accuracy}_conv1d_ppsp.csv"), index=False)



In [None]:
# import fuzzy_logic_functions as fuzzy_funcs
# import warnings, pickle
#
# warnings.filterwarnings("ignore")
# mode = "gaussian"
# # all_combs_lists = [[0], [3], [0, 1, 2, 3]]
# all_combs_lists = [[0],[3]]
# import torch
# import torch.optim as optim
# from torch.utils.data import DataLoader, TensorDataset
# # from magtach.op_codes.nn_models import FPN_2D, FPN_1D, FPN_2D_regression, FPN_1D_regression, FPN_2_1up, FPN_2, \
# #     FPN_2_1up, FCN1D, FCN1D_resnet, SEResNet1D_FC
# from magtach.op_codes.mtl_models import MTL_1, MTL_2, MTL_1_v1
# # from magtach.op_codes.ppsp_mtl import FPN_2_mtl
# # from magtach.op_codes.ppsp_mtl_6 import FPN_2_MTL_Residual
# # from magtach.op_codes.ppsp_mtl_7 import MTL_V1_Residual_FundGuidesHarm
# # from magtach.op_codes.ppsp_mtl_8 import FPN_2, FundamentalFromHarmonics
# # from magtach.op_codes.ppsp_mtl_9 import FPN_2, FPN2_withFundamental
# # from magtach.op_codes.ppsp_mtl_1up import FPN_2_mtl
# # from magtach.op_codes.ppsp_mtl_3 import FPN_2_mtl
# # from magtach.op_codes.ppsp_mtl_5 import DualBranchDenseMTL
# # from magtach.op_codes.ppsp_mtl_1 import FPN_2_MTL_Dual
# from magtach.op_codes.crepe_1up import FPN_2_mtl
# # from magtach.op_codes.fpn_2 import PPSP
# from magtach.op_codes.ppsp_1up_head import PPSP_withFundamental, PPSP
# # from magtach.op_codes.ppsp_1up import PPSP_1up
# import copy
# # device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
# # fpn2_weights_pth="../data/train_test_data/"
# # fpn_weights_file= "best_model_weights_fan5_fan3_bldc_fpn2"
#
# def dice_coeff(predictions, targets, smooth=1):
#     predictions = predictions.astype(float)
#     targets = targets.astype(float)
#     intersection = (predictions * targets).sum()
#     dice = (2. * intersection + smooth) / (predictions.sum() + targets.sum() + smooth)
#     return dice
#
#
# def overlap_dice(predictions, targets, threshold=0.3, smooth=1e-6):
#     ## Binarize predictions
#     preds_bin = (predictions > threshold).astype(float)
#     targets_bin = (targets > 0.1).astype(float)
#     ## True positive region: predictions overlapping target
#     TP = np.sum(preds_bin * targets_bin)
#     ## False positives outside true regions
#     FP = np.sum(preds_bin * (1 - targets_bin))
#     FN = np.sum((1 - preds_bin) * targets_bin)
#     ## Denominator = TP + FP + smoothing
#     # dice = TP / (TP + FP + smooth)
#     ## F1 score as the loss
#     f1_score = (2 * TP) / ((2 * TP) + FP + FN)
#     return f1_score
#
#
# def region_accuracy(y_true, y_pred, threshold=0.5):
#     true_idx = np.where(y_true == 1)[0]
#     pred_idx = np.where(y_pred == 1)[0]
#
#     if len(true_idx) == 0 and len(pred_idx) == 0:
#         return 1.0
#     if len(true_idx) == 0 or len(pred_idx) == 0:
#         return 0.0
#
#     overlap = len(np.intersect1d(true_idx, pred_idx))
#
#     frac_overlap = overlap / len(true_idx)
#
#     return 1.0 if frac_overlap >= threshold else 0.0
#
#
# with open(f"./conv2d_data/conv2d_psd_scaled_down_1up_{mode}_test.pkl", "rb") as f:
#     loaded_dict_test = pickle.load(f)
#
# for ind, comb_lst in enumerate(all_combs_lists):
#
#     dir_path = f"./conv2d_data/pred_plots/{comb_lst}/"
#     os.makedirs(dir_path, exist_ok=True)
#
#     # trained_model = FPN_2(in_channels=len(comb_lst), out_channels=32)
#     # trained_model = FPN_2_1up(in_channels=len(comb_lst), encoded_channels=32, output_size=1024)
#
#     # trained_model = FCN1D(in_channels=len(comb_lst), num_classes=1024, kernel_size=7)
#     # trained_model =  SEResNet1D_FC(in_channels=len(comb_lst), kernel_size=7)
#     # trained_model = MTL_1(in_ch=len(comb_lst), seq_len=1024, base=8, hidden=512, dilations=(1, 1,  1))
#     # trained_model = MTL_2(in_ch=len(comb_lst), seq_len=1024, base=32, hidden=512,dilations=(1, 1, 1))
#     # trained_model = MTL_2(in_ch=len(comb_lst), seq_len=1024, base=8, hidden=256,dilations=(1, 1, 1))
#     # trained_model = FPN_2_mtl(in_channels=len(comb_lst),fusion_type="local")
#     # trained_model = FPN_2_MTL_Dual(in_channels=len(comb_lst))
#     # trained_model = FCN1D_resnet(in_channels=len(comb_lst), num_classes=1024, kernel_size=3)
#     # trained_model = FCN1D(in_channels=len(comb_lst), num_classes=1024, kernel_size=7)
#
#
#     # trained_model = FPN_2_mtl(in_channels=len(comb_lst),out_channels=32)
#     trained_model = PPSP(in_channels=len(comb_lst),out_channels=32)
#     # trained_model = PPSP_1up(in_channels=len(comb_lst),out_channels=32)
#     model_name = f"ppsp[0].pth"
#     # model_name = f"best_fpn2_1up_model_{mode}_{comb_lst}.pth"
#     # model_name = f"best_ppsp_model_gaussian_[0]_100cm.pth"
#     trained_model.load_state_dict(torch.load(f"{model_name}", map_location=torch.device('cpu')))
#     trained_model.eval()
#
#     trained_model2 = PPSP_withFundamental(trained_model, freeze=True, hidden=256)
#     # trained_model = PPSP_1up(in_channels=len(comb_lst),out_channels=32)
#     model_name2 = f"best_fpn2_1up_model_gaussian_[0].pth"
#     # model_name = f"best_fpn2_1up_model_{mode}_{comb_lst}.pth"
#     # model_name = f"best_ppsp_model_gaussian_[0]_100cm.pth"
#     trained_model2.load_state_dict(torch.load(f"{model_name2}", map_location=torch.device('cpu')))
#     trained_model2.eval()
#
#
#
#     processed_test = process_test_dict(
#         loaded_dict_test,
#         row_indices=comb_lst,
#         col_size=1024,
#         output_format="channels_last"  # or "channels_first"
#     )
#     # device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
#     print(f"Using device: {device}")
#
#     for key, (X_test, y_test, fund_freq, distances, file_names, orig_sig) in processed_test.items():
#         print(f"Key: {key}")
#         prediction_lst = []
#         freq_prediction_lst, fund_freq_lst = [], []
#         if key == "bldc_6":
#             for ind in range(len(distances)):
#                 if ind == ind:
#                     cur_fund_freq = fund_freq[ind]
#                     cur_fil_name = file_names[ind]
#
#                     # cur_x = np.array(X_test[ind])[np.newaxis,:,:,:]
#                     cur_x = np.array(X_test[ind])[:, :, :]
#                     torch_x = torch.FloatTensor(cur_x).to('cpu')
#                     cur_y = np.array(y_test[ind])
#
#                     # cur_prediction = trained_model(torch_x).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     # cur_prediction = torch.sigmoid(trained_model(torch_x)).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     cur_prediction = trained_model2(torch_x)
#
#
#
#
#                     # fund_pred = torch.softmax(cur_prediction[0], dim=1).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     fund_pred = torch.sigmoid(cur_prediction[0]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     # fund_pred = torch.sigmoid(cur_prediction).squeeze(1).squeeze(0).detach().cpu().numpy()
#
#                     # fund_pred2 = torch.sigmoid(cur_prediction2[0]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     # all_harmonic_pred = torch.sigmoid(cur_prediction2[1]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     all_harmonic_pred1 = torch.sigmoid(cur_prediction[1]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     # all_harmonic_pred2 = torch.sigmoid(cur_prediction[3]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     # prediction_lst.append(cur_prediction)
#                     # cur_prediction_norm = (cur_prediction - np.min(cur_prediction)) / (np.max(cur_prediction) - np.min(cur_prediction) + 1e-12)
#
#                     cur_fund_prediction_norm = (fund_pred - np.min(fund_pred)) / (np.max(fund_pred) - np.min(fund_pred) + 1e-12)
#
#                     # cur_fund_prediction_norm2 = (fund_pred2 - np.min(fund_pred2)) / (np.max(fund_pred2) - np.min(fund_pred2) + 1e-12)
#                     #
#                     # cur_harmonic_prediction_norm = (all_harmonic_pred - np.min(all_harmonic_pred)) / (np.max(all_harmonic_pred) - np.min(all_harmonic_pred) + 1e-12)
#                     cur_harmonic_prediction_norm1 = (all_harmonic_pred1 - np.min(all_harmonic_pred1)) / (np.max(all_harmonic_pred1) - np.min(all_harmonic_pred1) + 1e-12)
#                     # cur_harmonic_prediction_norm2 = (all_harmonic_pred2 - np.min(all_harmonic_pred2)) / (np.max(all_harmonic_pred2) - np.min(all_harmonic_pred2) + 1e-12)
#
#
#                     # binary_cur_prediction=np.where(cur_prediction_norm<=0.35,0,1)
#                     binary_cur_prediction=cur_fund_prediction_norm
#                     binary_cur_truth=cur_y
#
#
#                     plt_fil_name = f"{cur_fund_freq}-{cur_fil_name}_{comb_lst}"
#
#                     plt.title(f"{cur_fund_freq} - {cur_fil_name}_{comb_lst}")
#                     for ind_x in range(cur_x.squeeze(0).shape[0]):
#                         plt.plot(cur_x.squeeze(0)[ind_x],linewidth=1.2)
#                     # plt.plot(cur_x.squeeze(0)[0])
#                     # plt.plot(cur_x.squeeze(0)[1])
#
#                     plt.plot(binary_cur_truth, linewidth=1.1, label="True",alpha=0.8)
#                     # plt.plot(binary_cur_prediction, '--', linewidth=1, label="Predicted")
#                     # plt.plot(cur_prediction, linewidth=0.8, label="raw prediction")
#
#                     plt.plot(cur_fund_prediction_norm, linewidth=0.8, label=" f0",alpha=0.8)
#                     plt.plot(cur_harmonic_prediction_norm1, '--', linewidth=0.7, label="raw all_harmonics",alpha=0.8)
#                     # plt.plot(cur_fund_prediction_norm, linewidth=0.8,label="ppsp_original",alpha=0.8)
#
#                     # plt.plot(fund_pred2, linewidth=0.8, label="raw f0",alpha=0.8)
#                     # plt.plot(all_harmonic_pred, linewidth=0.7, label="raw all_harmonics",alpha=0.8)
#                     # plt.plot(fund_pred, linewidth=0.8,label="ppsp_original",alpha=0.8)
#                     # plt.plot(cur_harmonic_prediction_norm2, linewidth=0.8, label="f3")
#                     plt.legend(loc='lower right')
#
#                     # plt.savefig(f"./conv2d_data/pred_plots/{plt_fil_name}.png")
#                     # plt.close()
#                     #
#                     # plt.savefig(os.path.join(dir_path, f"{plt_fil_name}.png"), dpi=150)
#                     # plt.close()
#
#                     plt.show()
#
#
#                     # fig, axes = plt.subplots(2, 2, figsize=(10, 4), sharey=True)
#                     #
#                     # # --- Left: FPN_2 prediction + input channels ---
#                     # for ind_x in range(cur_x.squeeze(0).shape[0]):
#                     #     axes[0].plot(cur_x.squeeze(0)[ind_x], linewidth=1.2, alpha=0.9)
#                     # axes[0].plot(binary_cur_truth, linewidth=1.1, label="True",color="orange", alpha=0.7)
#                     # axes[0].plot(cur_fund_prediction_norm2,'--', linewidth=0.7, color="green", label="ppsp_1up", alpha=0.8)
#                     # axes[0].plot(all_harmonic_pred*0.8,'--', linewidth=0.7, color="red", label="non normalized_harmonics", alpha=0.9)
#                     # axes[0].set_title("ppsp_1up")
#                     # axes[0].legend(loc="lower right")
#                     #
#                     # # --- Right: FPN_2_mtl prediction + same inputs for reference ---
#                     # for ind_x in range(cur_x.squeeze(0).shape[0]):
#                     #     axes[1].plot(cur_x.squeeze(0)[ind_x], linewidth=1.2, alpha=0.7)
#                     # axes[1].plot(binary_cur_truth, linewidth=1, label="True", alpha=0.8)
#                     # axes[1].plot(cur_fund_prediction_norm,'--', linewidth=0.8, color="green", label="1up", alpha=0.9)
#                     # axes[1].set_title("1up(crepe)")
#                     # axes[1].legend(loc="lower right")
#                     #
#                     # plt.suptitle(f"{cur_fund_freq} Hz – {cur_fil_name}_{comb_lst}", fontsize=11)
#                     # plt.tight_layout()
#                     #
#                     # # plt.savefig(os.path.join(dir_path, f"{plt_fil_name}_compare.png"), dpi=150)
#                     # # plt.close()
#                     # plt.show()
#
#
#                     # cur_pred_accuracy = dice_coeff(binary_cur_prediction, binary_cur_truth)
#                     cur_overlap_accuracy = overlap_dice(binary_cur_prediction, binary_cur_truth)
#
#                     ### fuzzy logic
#                     # lr_strt, lr_end, cur_peak = fuzzy_funcs.lr_prediction(binary_cur_prediction,cur_x.squeeze(0).squeeze(0)[0],cur_fund_freq, distances[ind],plt_fil_name)
#                     # print(f"lrstart: {lr_strt},lrend: {lr_end}, lrpeak: {cur_peak}")
#
#                     # freq_pred = fuzzy_funcs.predict_freq(orig_sig[ind], orig_sig[ind], lr_strt, lr_end,
#                     #                                      cur_peak)
#                     # print(f"fundamental freq: {cur_fund_freq},freq_pred: {freq_pred}")
#
#                     # fund_freq_lst.append(cur_fund_freq)
#                     # freq_prediction_lst.append(freq_pred)
#                     # criterion = torch.nn.BCELoss()
#                     # loss = criterion(binary_cur_prediction, binary_cur_truth)
#                     # cur_pred_accuracy2= torch.nn.BCELoss()(binary_cur_prediction, binary_cur_truth)
#                     # cur_pred_accuracy = region_accuracy(binary_cur_truth, binary_cur_prediction, threshold=0.3)
#
#                     prediction_lst.append(cur_overlap_accuracy)
#                     #
#                     # print(f"freqpred: {freq_pred}")
#                     # prediction_lst.append(freq_pred)
#
#             accuracy = sum(prediction_lst) / len(prediction_lst)
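#             # Convert to numpy arrays
#             # NOTE(review): prediction_lst holds overlap F1 scores (from overlap_dice), not frequencies,
#             # so the "*60" RPM conversion and the abs/percentage error below are not meaningful as written.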
#             cur_gtruth = np.array(fund_freq) * 60
#             cur_predict = np.array(prediction_lst) * 60
#
#             #### filter out rows where cur_predict == 0
#             # filtered_data = [(gt, pred) for gt, pred in zip(cur_gtruth, cur_predict) if pred != 0]
#             #
#             # filtered_gtruth, filtered_predict = zip(*filtered_data)
#             # filtered_gtruth = np.array(filtered_gtruth)
#             # filtered_predict = np.array(filtered_predict)
#
#             #### Calculate absolute error and mean percentage error
#             # abs_err = np.abs(filtered_gtruth - filtered_predict)
#             # mean_err = np.mean((abs_err / filtered_gtruth) * 100)
#             abs_err = np.abs(cur_gtruth - cur_predict)
#             mean_err = np.mean((abs_err/cur_gtruth)*100)
#
#             df = pd.DataFrame({
#                 'fundamental_frequency': list(fund_freq),
#                 'predictions_array': prediction_lst,
#                 'distance': list(distances),
#                 'file_name': file_names
#                 # 'fund_freq': fund_freq_lst,
#                 # 'predictions': freq_prediction_lst
#             })
#             # # df = df[df['predictions_array'] != 0]
#             # df['abs_err'] = np.abs((df['fundamental_frequency'] * 60) - (df['predictions_array'] * 60))
#             # #### Save to CSV
#             df.to_csv(f"{dir_path}/{key}_{accuracy}_conv1d_ppsp.csv", index=False)
#


In [None]:
# import fuzzy_logic_functions as fuzzy_funcs
# import warnings, pickle
#
# warnings.filterwarnings("ignore")
# mode = "gaussian"
# # all_combs_lists = [[0], [3], [0, 1, 2, 3]]
# all_combs_lists = [[0]]
# import torch
# import torch.optim as optim
# from torch.utils.data import DataLoader, TensorDataset
# # from magtach.op_codes.nn_models import FPN_2D, FPN_1D, FPN_2D_regression, FPN_1D_regression, FPN_2_1up, FPN_2, \
# #     FPN_2_1up, FCN1D, FCN1D_resnet, SEResNet1D_FC
# from magtach.op_codes.mtl_models import MTL_1, MTL_2, MTL_1_v1
# # from magtach.op_codes.ppsp_mtl import FPN_2_mtl
# # from magtach.op_codes.ppsp_mtl_6 import FPN_2_MTL_Residual
# from magtach.op_codes.ppsp_mtl_7 import MTL_V1_Residual_FundGuidesHarm
# # from magtach.op_codes.ppsp_mtl_8 import FPN_2, FundamentalFromHarmonics
# from magtach.op_codes.ppsp_mtl_9 import FPN_2, FPN2_withFundamental
# from magtach.op_codes.ppsp_mtl_1up import FPN_2_mtl
# # from magtach.op_codes.ppsp_mtl_3 import FPN_2_mtl
# # from magtach.op_codes.ppsp_mtl_5 import DualBranchDenseMTL
# # from magtach.op_codes.ppsp_mtl_1 import FPN_2_MTL_Dual
# import copy
# # device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
# fpn2_weights_pth="../data/train_test_data/"
# fpn_weights_file= "best_model_weights_fan5_fan3_bldc_fpn2"
#
# def dice_coeff(predictions, targets, smooth=1):
#     predictions = predictions.astype(float)
#     targets = targets.astype(float)
#     intersection = (predictions * targets).sum()
#     dice = (2. * intersection + smooth) / (predictions.sum() + targets.sum() + smooth)
#     return dice
#
#
# def overlap_dice(predictions, targets, threshold=0.3, smooth=1e-6):
#     ## Binarize predictions (at `threshold`) and targets (at a fixed 0.1)
#     preds_bin = (predictions > threshold).astype(float)
#     targets_bin = (targets > 0.1).astype(float)
#     ## True positive region: predictions overlapping target
#     TP = np.sum(preds_bin * targets_bin)
#     ## False positives outside true regions
#     FP = np.sum(preds_bin * (1 - targets_bin))
#     FN = np.sum((1 - preds_bin) * targets_bin)
#     ## Denominator = TP + FP + smoothing
#     # dice = TP / (TP + FP + smooth)
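#     ## F1 score (a similarity score, higher is better; not a loss).
#     ## NOTE(review): `smooth` is unused here, so when no bin in either input exceeds its threshold
#     ## (TP = FP = FN = 0) this evaluates 0/0 and yields NaN.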
#     f1_score = (2 * TP) / ((2 * TP) + FP + FN)
#     return f1_score
#
#
# def region_accuracy(y_true, y_pred, threshold=0.5):
#     true_idx = np.where(y_true == 1)[0]
#     pred_idx = np.where(y_pred == 1)[0]
#
#     if len(true_idx) == 0 and len(pred_idx) == 0:
#         return 1.0
#     if len(true_idx) == 0 or len(pred_idx) == 0:
#         return 0.0
#
#     overlap = len(np.intersect1d(true_idx, pred_idx))
#
#     frac_overlap = overlap / len(true_idx)
#
#     return 1.0 if frac_overlap >= threshold else 0.0
#
#
# with open(f"./conv2d_data/conv2d_psd_scaled_down_1up_{mode}_test.pkl", "rb") as f:
#     loaded_dict_test = pickle.load(f)
#
# for ind, comb_lst in enumerate(all_combs_lists):
#
#     dir_path = f"./conv2d_data/pred_plots/{comb_lst}/"
#     os.makedirs(dir_path, exist_ok=True)
#
#     # trained_model = FPN_2(in_channels=len(comb_lst), out_channels=32)
#     # trained_model = FPN_2_1up(in_channels=len(comb_lst), encoded_channels=32, output_size=1024)
#
#     # trained_model = FCN1D(in_channels=len(comb_lst), num_classes=1024, kernel_size=7)
#     # trained_model =  SEResNet1D_FC(in_channels=len(comb_lst), kernel_size=7)
#     # trained_model = MTL_1(in_ch=len(comb_lst), seq_len=1024, base=8, hidden=512, dilations=(1, 1,  1))
#     # trained_model = MTL_2(in_ch=len(comb_lst), seq_len=1024, base=32, hidden=512,dilations=(1, 1, 1))
#     # trained_model = MTL_2(in_ch=len(comb_lst), seq_len=1024, base=8, hidden=256,dilations=(1, 1, 1))
#     # trained_model = FPN_2_mtl(in_channels=len(comb_lst),fusion_type="local")
#     # trained_model = FPN_2_MTL_Dual(in_channels=len(comb_lst))
#     # trained_model = FCN1D_resnet(in_channels=len(comb_lst), num_classes=1024, kernel_size=3)
#     # trained_model = FCN1D(in_channels=len(comb_lst), num_classes=1024, kernel_size=7)
#
#     # trained_model = FPN_2()
#     # model_name= "FPN2"
#     # trained_model.load_state_dict(torch.load(f'../data/train_test_data/{fpn_weights_file}', map_location=torch.device('cpu')))
#     # trained_model.eval()
#
#     # trained_model2 = MTL_V1_Residual_FundGuidesHarm(in_channels=len(comb_lst), gamma=0.5)
#     # model_name2 = f"best_fpn2_1up_model_{mode}_{comb_lst}.pth"
#     # trained_model2.load_state_dict(torch.load(f"{model_name2}", map_location=torch.device('cpu')))
#     # trained_model2.eval()
#     # harmonic_model = FPN_2(in_channels=1, out_channels=32)
#     # harmonic_model.load_state_dict(torch.load(f'../data/train_test_data/{fpn_weights_file}',
#     #                                           map_location="cpu"))
#     # trained_model = FPN_2()
#     # model_name="ppsp"
#     # trained_model.load_state_dict(torch.load(f'../data/train_test_data/{fpn_weights_file}', map_location="cpu"))
#
#     # trained_model3 = FPN_2_mtl(in_channels=len(comb_lst),out_channels=32)
#     # model_name3="ppsp_1up"
#     # trained_model3.load_state_dict(torch.load(f'./1up_weights/best_fpn2_1up_model_gaussian_[0].pth', map_location="cpu"))
#     # trained_model3.eval()
#     #
#     # trained_model2 = FPN2_withFundamental(trained_model)
#     # model_name2 = f"best_fpn2_1up_model_{mode}_{comb_lst}.pth"
#     # trained_model2.load_state_dict(torch.load(f"{model_name2}", map_location=torch.device('cpu')))
#     #
#     # trained_model2.eval()
#     #
#     # trained_model.eval()
#     #
#     # is_model_training = not (trained_model.training)
#     # print(F"{model_name} in testing mode: {is_model_training}")
#     # print(F"{model_name2} in testing mode: {trained_model2.training}")
#     model_name2 = "best_fpn2_1up_model_gaussian_[0].pth"
#     # # ---- Load base FPN_2 ----
#     harm_model = FPN_2()
#     harm_model.load_state_dict(torch.load(f'../data/train_test_data/{fpn_weights_file}', map_location="cpu"))
#     harm_model.eval()
#
#     # ---- Create wrapper ----
#     trained_model2 = FPN2_withFundamental(copy.deepcopy(harm_model), freeze=True)
#     trained_model2.eval()
#
#     # ---- Load fundamental head weights ----
#     trained_model2.load_state_dict(torch.load(f"{model_name2}", map_location="cpu"), strict=False)
#
#     # ---- Reapply harmonic weights to ensure identical BN stats ----
#     trained_model2.harm_net.load_state_dict(torch.load(f'../data/train_test_data/{fpn_weights_file}', map_location="cpu"))
#
#
#     trained_model3 = FPN_2_mtl(in_channels=1,out_channels=32)
#     model_name3="crepe"
#     trained_model3.load_state_dict(torch.load(f'./1up_weights/best_fpn2_1up_model_gaussian_[0].pth', map_location="cpu"))
#     trained_model3.eval()
#
#
#
#     print(F"ppsp+1up in testing mode: {not(trained_model2.training)}")
#     print(F"crepe in testing mode: {not(trained_model3.training)}")
#
#     processed_test = process_test_dict(
#         loaded_dict_test,
#         row_indices=comb_lst,
#         col_size=1024,
#         output_format="channels_last"  # or "channels_first"
#     )
#     # device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
#     print(f"Using device: {device}")
#
#
#     # """crepe CAM"""
#     # activations = {}
#     # gradients = {}
#     #
#     # def save_activation(name):
#     #     def hook(module, input, output):
#     #         activations[name] = output.detach()
#     #     return hook
#     #
#     # def save_gradient(name):
#     #     def hook(module, grad_input, grad_output):
#     #         gradients[name] = grad_output[0].detach()
#     #     return hook
#     #
#     # target_layer = trained_model3.conv_block10.c
#     # target_layer.register_forward_hook(save_activation("feat"))
#     # target_layer.register_backward_hook(save_gradient("feat"))
#
#
#     for key, (X_test, y_test, fund_freq, distances, file_names, orig_sig) in processed_test.items():
#         print(f"Key: {key}")
#         prediction_lst = []
#         freq_prediction_lst, fund_freq_lst = [], []
#         if key == key:
#             for ind in range(len(distances)):
#                 if ind == ind:
#                     cur_fund_freq = fund_freq[ind]
#                     cur_fil_name = file_names[ind]
#
#                     # cur_x = np.array(X_test[ind])[np.newaxis,:,:,:]
#                     cur_x = np.array(X_test[ind])[:, :, :]
#                     torch_x = torch.FloatTensor(cur_x).to('cpu')
#                     cur_y = np.array(y_test[ind])
#
#                     # cur_prediction = trained_model(torch_x).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     # cur_prediction = torch.sigmoid(trained_model(torch_x)).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     cur_prediction = trained_model3(torch_x)
#                     cur_prediction2 = trained_model2(torch_x)
#                     """
#                     PPSP1up saliency map: gradient of the peak fundamental score w.r.t. the input
#                     """
#                     torch_x1 = torch.FloatTensor(cur_x).to('cpu').requires_grad_(True)
#                     fund_pred, harm_pred = trained_model2(torch_x1)
#                     target_index1 = np.argmax(fund_pred.detach().cpu().numpy())
#                     score1 = fund_pred[0, 0, target_index1]
#
#                     trained_model2.zero_grad()
#                     score1.backward(retain_graph=True)
#
#                     # Gradient w.r.t. input
#                     input_grad1 = torch_x1.grad.detach().cpu().numpy().squeeze()
#                     input_grad_abs1 = np.abs(input_grad1)
#                     input_grad_norm1 = input_grad_abs1 / (np.max(input_grad_abs1) + 1e-12)
#
#
#
#                     # fund_pred = torch.softmax(cur_prediction[0], dim=1).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     fund_pred = torch.sigmoid(cur_prediction).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     fund_pred2 = torch.sigmoid(cur_prediction2[0]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     all_harmonic_pred = torch.sigmoid(cur_prediction2[1]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     # all_harmonic_pred1 = torch.sigmoid(cur_prediction[2]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     # all_harmonic_pred2 = torch.sigmoid(cur_prediction[3]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     # prediction_lst.append(cur_prediction)
#                     # cur_prediction_norm = (cur_prediction - np.min(cur_prediction)) / (np.max(cur_prediction) - np.min(cur_prediction) + 1e-12)
#
#                     cur_fund_prediction_norm = (fund_pred - np.min(fund_pred)) / (np.max(fund_pred) - np.min(fund_pred) + 1e-12)
#
#                     cur_fund_prediction_norm2 = (fund_pred2 - np.min(fund_pred2)) / (np.max(fund_pred2) - np.min(fund_pred2) + 1e-12)
#
#                     cur_harmonic_prediction_norm = (all_harmonic_pred - np.min(all_harmonic_pred)) / (np.max(all_harmonic_pred) - np.min(all_harmonic_pred) + 1e-12)
#                     # cur_harmonic_prediction_norm1 = (all_harmonic_pred1 - np.min(all_harmonic_pred1)) / (np.max(all_harmonic_pred1) - np.min(all_harmonic_pred1) + 1e-12)
#                     # cur_harmonic_prediction_norm2 = (all_harmonic_pred2 - np.min(all_harmonic_pred2)) / (np.max(all_harmonic_pred2) - np.min(all_harmonic_pred2) + 1e-12)
#                     # Apply threshold (keep >=0.85, zero out rest)
#                     # cur_fund_prediction_norm2[cur_fund_prediction_norm2 < 0.85] = 0
#                     # cur_harmonic_prediction_norm[cur_harmonic_prediction_norm < 0.85] = 0
#                     # cur_fund_prediction_norm[cur_harmonic_prediction_norm < 0.85] = 0
#
#
#                     # binary_cur_prediction=np.where(cur_prediction_norm<=0.35,0,1)
#                     binary_cur_prediction=all_harmonic_pred
#                     binary_cur_truth=cur_y
#                     # binary_cur_truth=np.where(cur_y<=0.35,0,1)
#
#                     # mod_cur_prediction = np.zeros(len(cur_prediction))
#                     # max_f_pred = np.argmax(cur_prediction)
#                     #
#                     # mod_cur_prediction[max_f_pred - 1:max_f_pred + 5] = 1
#                     # binary_cur_prediction = mod_cur_prediction
#                     # binary_cur_truth = cur_y
#
#                     plt_fil_name = f"{cur_fund_freq}-{cur_fil_name}_{comb_lst}"
#
#                     # plt.title(f"{cur_fund_freq} - {cur_fil_name}_{comb_lst}")
#                     # for ind_x in range(cur_x.squeeze(0).shape[0]):
#                     #     plt.plot(cur_x.squeeze(0)[ind_x],linewidth=1.2)
#                     # # plt.plot(cur_x.squeeze(0)[0])
#                     # # plt.plot(cur_x.squeeze(0)[1])
#                     #
#                     # plt.plot(binary_cur_truth, linewidth=1.1, label="True",alpha=0.8)
#                     # # plt.plot(binary_cur_prediction, '--', linewidth=1, label="Predicted")
#                     # # plt.plot(cur_prediction, linewidth=0.8, label="raw prediction")
#                     #
#                     # plt.plot(cur_fund_prediction_norm2, linewidth=0.8, label=" f0",alpha=0.8)
#                     # # plt.plot(cur_harmonic_prediction_norm, '--', linewidth=0.7, label="raw all_harmonics",alpha=0.8)
#                     # # plt.plot(cur_fund_prediction_norm, linewidth=0.8,label="ppsp_original",alpha=0.8)
#                     #
#                     # # plt.plot(fund_pred2, linewidth=0.8, label="raw f0",alpha=0.8)
#                     # # plt.plot(all_harmonic_pred, linewidth=0.7, label="raw all_harmonics",alpha=0.8)
#                     # # plt.plot(fund_pred, linewidth=0.8,label="ppsp_original",alpha=0.8)
#                     # # plt.plot(cur_harmonic_prediction_norm2, linewidth=0.8, label="f3")
#                     # plt.legend(loc='lower right')
#                     #
#                     # # plt.savefig(f"./conv2d_data/pred_plots/{plt_fil_name}.png")
#                     # # plt.close()
#                     # #
#                     # # plt.savefig(os.path.join(dir_path, f"{plt_fil_name}.png"), dpi=150)
#                     # # plt.close()
#                     #
#                     # plt.show()
#                     """
#                     Input-gradient saliency map for crepe (absolute input gradient of the peak score; not a Grad-CAM class activation map)
#                     """
#                     torch_x.requires_grad = True
#                     output = trained_model3(torch_x)
#                     target_index = np.argmax(output.detach().cpu().numpy())
#                     score = output[0, 0, target_index]
#                     trained_model3.zero_grad()
#                     score.backward(retain_graph=True)
#
#                     # Compute gradient of output w.r.t. input PSD
#                     input_grad = torch_x.grad.detach().cpu().numpy().squeeze()
#                     input_grad_abs = np.abs(input_grad)
#                     input_grad_norm = input_grad_abs / (np.max(input_grad_abs) + 1e-12)
#
#
#                     fig, axes = plt.subplots(2, 2, figsize=(12, 6), sharey=False)
#                     for ind_x in range(cur_x.squeeze(0).shape[0]):
#                         axes[0, 0].plot(cur_x.squeeze(0)[ind_x], linewidth=1.2, alpha=0.9)
#                     axes[0, 0].plot(binary_cur_truth, linewidth=1.1, label="True", color="orange", alpha=0.7)
#                     axes[0, 0].plot(cur_fund_prediction_norm2, '--', linewidth=0.7, color="green", label="ppsp_1up", alpha=0.8)
#                     axes[0, 0].plot(all_harmonic_pred * 0.8, '--', linewidth=0.7, color="red", label="non normalized_harmonics", alpha=0.9)
#                     axes[0, 0].set_title("ppsp_1up")
#                     axes[0, 0].legend(loc="lower right")
#
#
#                     for ind_x in range(cur_x.squeeze(0).shape[0]):
#                         axes[0, 1].plot(cur_x.squeeze(0)[ind_x], linewidth=1.2, alpha=0.7)
#                     axes[0, 1].plot(binary_cur_truth, linewidth=1, label="True", alpha=0.8)
#                     axes[0, 1].plot(cur_fund_prediction_norm, '--', linewidth=0.8, color="green", label="1up", alpha=0.9)
#                     axes[0, 1].set_title("1up(crepe)")
#                     axes[0, 1].legend(loc="lower right")
#
#                     # axes[1,0].axis('off')
#                     axes[1, 0].plot(cur_x.squeeze(), color="gray",linewidth=1.8, alpha=0.8, label="PSD")
#                     axes[1, 0].plot(binary_cur_truth, linewidth=1, label="True", alpha=0.8)
#                     axes[1,0].plot(input_grad_norm1, '--', color="red", linewidth=0.7, label="sailency",alpha=0.9 )
#                     axes[1, 0].set_title(f"PPSP1up Input Saliency ")
#                     axes[1, 0].legend(loc="lower right",fontsize=8)
#
#
#                     axes[1, 1].plot(cur_x.squeeze(), color="gray",linewidth=1.8, alpha=0.8, label="PSD")
#                     axes[1, 1].plot(binary_cur_truth, linewidth=1, label="True", alpha=0.8)
#                     axes[1, 1].plot(cur_fund_prediction_norm, color="green",linewidth=1, label="crepe")
#                     axes[1, 1].plot(input_grad_norm,'--', color="red", linewidth=0.7, label="CAM",alpha=0.9)
#                     axes[1, 1].set_title(f"Crepe CAM ")
#                     axes[1, 1].legend(loc="lower right")
#
#                     plt.suptitle(f"{cur_fund_freq} Hz – {cur_fil_name}_{comb_lst}",fontsize=8)
#                     plt.tight_layout()
#                     plt.savefig(os.path.join(dir_path, f"{plt_fil_name}_compare.png"), dpi=150)
#                     plt.close()
#                     # plt.show()
#
#                     # fig, axes = plt.subplots(2, 2, figsize=(10, 4), sharey=True)
#                     #
#                     # # --- Left: FPN_2 prediction + input channels ---
#                     # for ind_x in range(cur_x.squeeze(0).shape[0]):
#                     #     axes[0].plot(cur_x.squeeze(0)[ind_x], linewidth=1.2, alpha=0.9)
#                     # axes[0].plot(binary_cur_truth, linewidth=1.1, label="True",color="orange", alpha=0.7)
#                     # axes[0].plot(cur_fund_prediction_norm2,'--', linewidth=0.7, color="green", label="ppsp_1up", alpha=0.8)
#                     # axes[0].plot(all_harmonic_pred*0.8,'--', linewidth=0.7, color="red", label="non normalized_harmonics", alpha=0.9)
#                     # axes[0].set_title("ppsp_1up")
#                     # axes[0].legend(loc="lower right")
#                     #
#                     # # --- Right: FPN_2_mtl prediction + same inputs for reference ---
#                     # for ind_x in range(cur_x.squeeze(0).shape[0]):
#                     #     axes[1].plot(cur_x.squeeze(0)[ind_x], linewidth=1.2, alpha=0.7)
#                     # axes[1].plot(binary_cur_truth, linewidth=1, label="True", alpha=0.8)
#                     # axes[1].plot(cur_fund_prediction_norm,'--', linewidth=0.8, color="green", label="1up", alpha=0.9)
#                     # axes[1].set_title("1up(crepe)")
#                     # axes[1].legend(loc="lower right")
#                     #
#                     # plt.suptitle(f"{cur_fund_freq} Hz – {cur_fil_name}_{comb_lst}", fontsize=11)
#                     # plt.tight_layout()
#                     #
#                     # # plt.savefig(os.path.join(dir_path, f"{plt_fil_name}_compare.png"), dpi=150)
#                     # # plt.close()
#                     # plt.show()
#
#
#                     # cur_pred_accuracy = dice_coeff(binary_cur_prediction, binary_cur_truth)
#                     cur_overlap_accuracy = overlap_dice(binary_cur_prediction, binary_cur_truth)
#
#                     ### fuzzy logic
#                     # lr_strt, lr_end, cur_peak = fuzzy_funcs.lr_prediction(binary_cur_prediction,cur_x.squeeze(0).squeeze(0)[0],cur_fund_freq, distances[ind],plt_fil_name)
#                     # print(f"lrstart: {lr_strt},lrend: {lr_end}, lrpeak: {cur_peak}")
#
#                     # freq_pred = fuzzy_funcs.predict_freq(orig_sig[ind], orig_sig[ind], lr_strt, lr_end,
#                     #                                      cur_peak)
#                     # print(f"fundamental freq: {cur_fund_freq},freq_pred: {freq_pred}")
#
#                     # fund_freq_lst.append(cur_fund_freq)
#                     # freq_prediction_lst.append(freq_pred)
#                     # criterion = torch.nn.BCELoss()
#                     # loss = criterion(binary_cur_prediction, binary_cur_truth)
#                     # cur_pred_accuracy2= torch.nn.BCELoss()(binary_cur_prediction, binary_cur_truth)
#                     # cur_pred_accuracy = region_accuracy(binary_cur_truth, binary_cur_prediction, threshold=0.3)
#
#                     prediction_lst.append(cur_overlap_accuracy)
#
#                     # print(f"freqpred: {freq_pred}")
#                     # prediction_lst.append(freq_pred)
#
#             accuracy = sum(prediction_lst) / len(prediction_lst)
#             ## Convert to numpy arrays
#             # cur_gtruth = np.array(fund_freq) * 60
#             # cur_predict = np.array(prediction_lst) * 60
#             #
#             # #### filter out rows where cur_predict == 0
#             # # filtered_data = [(gt, pred) for gt, pred in zip(cur_gtruth, cur_predict) if pred != 0]
#             # #
#             # # filtered_gtruth, filtered_predict = zip(*filtered_data)
#             # # filtered_gtruth = np.array(filtered_gtruth)
#             # # filtered_predict = np.array(filtered_predict)
#             #
#             # #### Calculate absolute error and mean percentage error
#             # # abs_err = np.abs(filtered_gtruth - filtered_predict)
#             # # mean_err = np.mean((abs_err / filtered_gtruth) * 100)
#             # abs_err = np.abs(cur_gtruth - cur_predict)
#             # mean_err = np.mean((abs_err/cur_gtruth)*100)
#             #
#             df = pd.DataFrame({
#                 'fundamental_frequency': list(fund_freq),
#                 'predictions_array': prediction_lst,
#                 'distance': list(distances),
#                 'file_name': file_names
#                 # 'fund_freq': fund_freq_lst,
#                 # 'predictions': freq_prediction_lst
#             })
#             # # df = df[df['predictions_array'] != 0]
#             # df['abs_err'] = np.abs((df['fundamental_frequency'] * 60) - (df['predictions_array'] * 60))
#             # #### Save to CSV
#             df.to_csv(f"{dir_path}/{key}_{accuracy}_conv1d_ppsp.csv", index=False)
#


In [None]:
# ### use the full train data for training and whole test dataset for validation.
# ### once this is done, pickle dump it in the same format used for model training
#
# import numpy as np
# import pickle
# all_combs_lists = [[0],[3],[0, 1,2, 3]]
# print(all_combs_lists)
#
# def process_dict_test(loaded_dict, row_indices=[0,4], col_size=512, output_format="channels_first"):
#
#     processed_data = {}
#
#     for key, (test_x, test_y, fund_freq_lst, distances_lst, file_names_lst,orig_signal_lst) in loaded_dict.items():
#         test_x = np.array(test_x)
#         test_y = np.array(test_y)
#         distances_lst = np.array(distances_lst)
#         fund_freq_lst = np.array(fund_freq_lst)
#         orig_signal_lst = np.array(orig_signal_lst)
#
#         # Select rows and columns size (preserve original order - no sorting)
#         X = test_x[:, row_indices, :col_size]
#         y = test_y
#
#         # Reshape based on desired output format
#         if output_format == "channels_first":
#             X = X[:, :, np.newaxis, :]  # [N, num_channels, 1, col_size]
#         elif output_format == "channels_last":
#
#             X = X[:, np.newaxis, :, :]  # [N, 1, num_channels, col_size]
#         else:
#             raise ValueError("output_format must be 'channels_first' or 'channels_last'")
#
#         # Return all data in original order (no train/val split)
#         processed_data[key] = (X, y, distances_lst)
#
#     return processed_data
#
# def process_dict_train(loaded_dict, row_indices=[0,4], col_size=512, output_format="channels_first"):
#
#     processed_data = {}
#     for key, (train_x, train_y, distances, *_) in loaded_dict.items():
#         train_x = np.array(train_x)
#         train_y = np.array(train_y)
#         distances = np.array(distances)
#
#         ### sort by distance
#         sorted_indices = np.argsort(distances)
#         train_x = train_x[sorted_indices]
#         train_y = train_y[sorted_indices]
#         distances = distances[sorted_indices]
#
#         ### select rows and columns size
#         X = train_x[:, row_indices, :col_size]
#         # train_y = train_y[:,:col_size]
#
#         ### reshape based on desired output format
#         if output_format == "channels_first":
#             # [N, num_channels, 1, col_size] - channels first
#             X = X[:, :, np.newaxis, :]
#         elif output_format == "channels_last":
#             # [N, 1, num_channels, col_size] - channels last
#             X = X[:, np.newaxis, :, :]
#         else:
#             raise ValueError("output_format must be 'channels_first' or 'channels_last'")
#
#         ### no train/val split here: the whole training set is kept (validation comes from the separate test dict)
#         X_train, y_train = X, train_y
#         # X_train, X_val, y_train, y_val, dist_train, dist_val = train_test_split(
#         #     X, train_y, distances, test_size=val_ratio, random_state=random_state
#         # )
#
#         processed_data[key] = (X_train, y_train, distances)
#
#     return processed_data
#
# mode = "block"
# with open(f"./conv2d_data/conv2d_psd_scaled_down_1up_{mode}.pkl", "rb") as f:
#     loaded_train_dict = pickle.load(f)
#
# with open(f"./conv2d_data/conv2d_psd_scaled_down_1up_{mode}_test.pkl", "rb") as f:
#     loaded_test_dict = pickle.load(f)


In [None]:
# aggregated_combs_data_lst=[]
# for comb in all_combs_lists:
#
#     ### [N, 1, len(comb), 1024] - channels last (col_size=1024 below)
#     train_processed_channels_last = process_dict_train(
#         loaded_train_dict,
#         row_indices=comb,
#         col_size=1024,
#         output_format="channels_last"
#     )
#     test_processed_channels_last = process_dict_test(
#         loaded_test_dict,
#         row_indices=comb,
#         col_size=1024,
#         output_format="channels_last"
#     )
#
#     all_X_train, all_X_val = [], []
#     all_y_train, all_y_val = [], []
#     all_dist_train, all_dist_val = [], []
#
#     for key, (X_train, y_train, dist_train) in train_processed_channels_last.items():
#         all_X_train.append(X_train)
#         all_y_train.append(y_train)
#         all_dist_train.append(dist_train)
#
#     for key, (X_train, y_train, dist_train) in test_processed_channels_last.items():
#         all_X_val.append(X_train)
#         all_y_val.append(y_train)
#         all_dist_val.append(dist_train)
#
#     X_train = np.concatenate(all_X_train, axis=0).squeeze(1)
#     X_val   = np.concatenate(all_X_val, axis=0).squeeze(1)
#     y_train = np.concatenate(all_y_train, axis=0)
#     y_val   = np.concatenate(all_y_val, axis=0)
#     dist_train = np.concatenate(all_dist_train, axis=0)
#     dist_val   = np.concatenate(all_dist_val, axis=0)
#
#     aggregated_combs_data_lst.append((X_train, X_val, y_train, y_val, dist_train, dist_val))
#
# with open(f"./conv2d_data/conv2d_psd_scaled_down_1up_{mode}_1.pkl", "wb") as f:
#         pickle.dump(aggregated_combs_data_lst, f)

In [None]:
# mode = "gaussian"
# with open(f"./conv2d_data/conv2d_psd_scaled_down_1up_{mode}_1.pkl", "rb") as f:
#     loaded_dict = pickle.load(f)
# print()

In [None]:

# all_X_train, all_X_val = [], []
# all_y_train, all_y_val = [], []
# all_dist_train, all_dist_val = [], []
#
# for key, (X_train, X_val, y_train, y_val, dist_train, dist_val) in processed_channels_last.items():
#     all_X_train.append(X_train)
#     all_X_val.append(X_val)
#     all_y_train.append(y_train)
#     all_y_val.append(y_val)
#     all_dist_train.append(dist_train)
#     all_dist_val.append(dist_val)
#
# ### Concatenate along first axis
# # X_train = np.concatenate(all_X_train, axis=0)
# # X_val   = np.concatenate(all_X_val, axis=0)
# # y_train = np.concatenate(all_y_train, axis=0)
# # y_val   = np.concatenate(all_y_val, axis=0)
# # dist_train = np.concatenate(all_dist_train, axis=0)
# # dist_val   = np.concatenate(all_dist_val, axis=0)
#
# X_train = np.concatenate(all_X_train, axis=0).squeeze(1)
# X_val   = np.concatenate(all_X_val, axis=0).squeeze(1)
# y_train = np.concatenate(all_y_train, axis=0)
# y_val   = np.concatenate(all_y_val, axis=0)
# dist_train = np.concatenate(all_dist_train, axis=0)
# dist_val   = np.concatenate(all_dist_val, axis=0)
#
# aggregated_combs_data_lst.append((X_train, X_val, y_train, y_val, dist_train, dist_val))

In [None]:
# model = FPN_2D(in_channels=1, base_channels=16)
# model.load_state_dict(torch.load('best_fpn2d_model.pth', map_location=device))
# model.eval()
#
# # X_train = np.concatenate(all_X_train, axis=0)
# # X_val   = np.concatenate(all_X_val, axis=0)
# # y_train = np.concatenate(all_y_train, axis=0)
# # y_val   = np.concatenate(all_y_val, axis=0)
# # dist_train = np.concatenate(all_dist_train, axis=0)
# # dist_val   = np.concatenate(all_dist_val, axis=0)
# for ind in list(np.where(dist_train>80)[0]):
#     X_test = torch.FloatTensor(np.array([[X_train[ind]]]))  # [num_samples, 1, 2, 512]
#     y_test = torch.FloatTensor(y_train[ind])
#
#     # test_x1= X_train[1605]
#     # test_x1= X_train[1614]
#     prediction=model(X_test)
#
#     plt.title(f'{dist_train[ind]}')
#     plt.plot(X_train[ind].T)
#     plt.plot(y_test, label='True Mask', linestyle='--',linewidth=2, color='green')
#     plt.plot(prediction.detach().numpy()[0].T, label='Predicted Mask', linestyle='--', color='red')
#     plt.legend()
#     plt.show()
#
# # print()


In [None]:

# # ===== Case 1: [N, 1, 4, 1024] =====
# model_case1 = FPN_2D(in_channels=1, base_channels=16)  # base_channels smaller for quick test
# x1 = torch.randn(2, 1, 4, 1024)   # batch=2, 1 channel, height=4, width=1024
# y1 = model_case1(x1)
# print("Case 1 input:", x1.shape, " -> output:", y1.shape)
#
#
# # ===== Case 2: [N, 4, 1, 1024] =====
# model_case2 = FPN_2D(in_channels=4, base_channels=16)
# x2 = torch.randn(2, 4, 1, 1024)   # batch=2, 4 channels, height=1, width=1024
# y2 = model_case2(x2)
# print("Case 2 input:", x2.shape, " -> output:", y2.shape)

In [None]:
# root_folder_pth = f"../data/"
# train_folders_lst = [("fan3_3spd_augment",96.7),("fan5_3spd_augment",41.97), ("bldc_1_augment",94.98),  ("bldc_2_augment",80.25)]
#
# psd_length = 1024
# fs = 44100
# ss_num_chunks, welch_num_chunks = 3, 2
# psd_scale_down_factors = [1,2,3,4,'sum']
# plot_fig, save_fig = False, False
# ### calculating resample factors
# strt_pnt = 0.3
# resampl_factor_lst = generate_resample_factors(start=strt_pnt, end=3.1, step=0.1)
#
# for _,(train_folder,cur_fund_freq) in enumerate(train_folders_lst):
#     print(train_folder, cur_fund_freq)
#
#     cur_folder_pth = os.path.join(root_folder_pth, train_folder)
#     if not os.path.exists(cur_folder_pth):
#         print(f"{cur_folder_pth} does not exist")
#         continue
#     else:
#         data_folder_pth = os.path.join(cur_folder_pth, "orig_files")
#         files_lst = os.listdir(data_folder_pth)
#
#         for file in files_lst:
#             print(file)
#             if file not in [".DS_Store"]:
#                 cur_file_pth = os.path.join(data_folder_pth, file)
#                 cur_signal = np.sum(read_files(cur_file_pth), axis=0)
#
#                 # num_pnts = len(cur_signal) // welch_num_chunks
#                 #
#                 # ### original signal psd (1st scaled or no scaled version)
#                 # orig_freq_ss, orig_Pxx_ss = signal.welch(cur_signal, fs, nperseg=num_pnts, nfft=fs)
#                 # log_Pxx_ss = np.log(orig_Pxx_ss)
#                 # log_Pxx_ss = min_max_norm(log_Pxx_ss)
#                 # # orig_norm_Pxx = min_max_norm(log_Pxx_ss[:psd_length])
#                 # orig_norm_Pxx = log_Pxx_ss[:psd_length]
#                 #
#                 # cur_psds = [orig_norm_Pxx]
#                 # for scale_factor in psd_scale_down_factors:
#                 #     if scale_factor not in [1 ,'sum']:
#                 #         cur_scaled_psd = scale_psd(log_Pxx_ss, psd_length, scale_factor, method="average")
#                 #         # cur_psds.append(cur_scaled_psd)
#                 #         cur_psds.append(min_max_norm(cur_scaled_psd))
#                 #
#                 # # summed_psd = cur_psds[0].copy()
#                 # # for p in cur_psds[1:]:
#                 # #     summed_psd += p
#                 # #
#                 # # cur_psds.append(min_max_norm(summed_psd))
#                 #
#                 # stacked = np.stack(cur_psds, axis=0)
#                 # max_psd = np.median(stacked, axis=0)
#                 # cur_psds.append(min_max_norm(max_psd))
#                 #
#                 # peaks = compute_harmonic_peaks(fundamental_freq=cur_fund_freq)
#
#                 for resamp_ind, resamp_factor in enumerate(resampl_factor_lst):
#                     if resamp_factor not in [0.0, 3.0]:
#                         print(f"Resample factor {resamp_factor}, cur signal id {file}")
#                         resample_sig = librosa.resample(cur_signal, orig_sr=fs, target_sr=int(fs * resamp_factor))
#
#                         resamp_num_pnts = len(resample_sig) // welch_num_chunks
#
#                         if resamp_factor >= 2.0:
#                             resamp_freq_ss, resamp_Pxx_ss = signal.welch(resample_sig, fs, nperseg=resamp_num_pnts,
#                                                                          nfft=int(fs * resamp_factor))
#                         else:
#                             resamp_freq_ss, resamp_Pxx_ss = signal.welch(resample_sig, fs, nperseg=resamp_num_pnts,
#                                                                          nfft=fs)
#
#                         resamp_log_Pxx_ss = np.log(resamp_Pxx_ss)
#                         #### perform spectrum downsampling
#
#                         resamp_norm_Pxx_ss = resamp_log_Pxx_ss[:psd_length]
#                         resamp_norm_Pxx_ss[:5]=0
#                         resamp_norm_Pxx_ss = min_max_norm(resamp_norm_Pxx_ss)
#
#                         cur_fund_freq = (cur_fund_freq / resamp_factor)[0]
#
#                         cur_fund_freq_lst = []
#                         j = 0
#                         for i in range(50):
#                             if j + cur_fund_freq >= psd_length-2:
#                                 break
#                             j += cur_fund_freq
#                             cur_fund_freq_lst.append(j)
#
#                         cur_fund_freq_lst = [int(np.round(i,0)) for i in cur_fund_freq_lst]
#
#                 if plot_fig:
#                     fig, axes = plt.subplots(len(cur_psds), 1, figsize=(8, 1.5*len(cur_psds)), sharex=False)
#                     fig.suptitle(f"{file}")
#                     for idx, psd in enumerate(cur_psds):
#                         axes[idx].plot(psd)
#                         axes[idx].set_title(f"PSD {idx+1} (factor {psd_scale_down_factors[idx]})")
#                         for h in peaks:
#                             if h <= orig_freq_ss[psd_length-1]:  # only mark if inside plotted range
#                                 cur_idx = np.argmin(np.abs(orig_freq_ss[:psd_length] - h))  # closest index
#                                 axes[idx].plot(orig_freq_ss[cur_idx], psd[cur_idx], "r*", markersize=10)
#
#                     plt.tight_layout()
#
#                     if save_fig:
#                         plt.savefig(os.path.join("./", f"{file}.png"))
#                     else:
#                         plt.show()
#
#
#
#
#


In [None]:
# with open("./conv2d_data/conv2d_psd_scaled_down.pkl", "rb") as f:
#     loaded_dict = pickle.load(f)
#
# processed = process_loaded_dict(
#     loaded_dict,
#     row_indices=[0, 4],
#     col_size=512,
#     val_ratio=0.2
# )
#
# all_X_train, all_X_val = [], []
# all_y_train, all_y_val = [], []
#
# for key, (X_train, X_val, y_train, y_val, dist_train, dist_val) in processed.items():
#     all_X_train.append(X_train)
#     all_X_val.append(X_val)
#     all_y_train.append(y_train)
#     all_y_val.append(y_val)
#
# # Concatenate data
# X_train = np.concatenate(all_X_train, axis=0)  # Shape: (num_samples, 2, 512)
# X_val = np.concatenate(all_X_val, axis=0)      # Shape: (num_samples, 2, 512)
# y_train = np.concatenate(all_y_train, axis=0)  # Shape: (num_samples, 1024)
# y_val = np.concatenate(all_y_val, axis=0)      # Shape: (num_samples, 1024)
#
# # Since model output is 512 width, resize y to match
# y_train = y_train[:, :512]  # Shape: (num_samples, 512)
# y_val = y_val[:, :512]      # Shape: (num_samples, 512)
#
# print("Data shapes:")
# print(f"X_train: {X_train.shape}")  # (num_samples, 2, 512)
# print(f"y_train: {y_train.shape}")  # (num_samples, 512)
# print(f"X_val: {X_val.shape}")      # (num_samples, 2, 512)
# print(f"y_val: {y_val.shape}")      # (num_samples, 512)


# def plot_predictions(model, X_data, y_true, num_samples=5, device='mps'):
#     """
#     Plot predictions vs true outputs along with input data
#     """
#     model.eval()
#
#     # Select random samples
#     indices = np.random.choice(len(X_data), num_samples, replace=False)
#
#     fig, axes = plt.subplots(num_samples, 3, figsize=(15, 3*num_samples))
#
#     if num_samples == 1:
#         axes = axes.reshape(1, -1)
#
#     for i, idx in enumerate(indices):
#         ### Prepare input - X_data shape: (num_samples, 2, 512)
#         x_input = torch.FloatTensor(X_data[idx:idx+1, np.newaxis, :, :]).to(device)  # [1, 1, 2, 512]
#
#         ### Get prediction
#         with torch.no_grad():
#             y_pred = model(x_input).cpu().numpy()  # [1, 1, 1, 512]
#             y_pred = y_pred.squeeze()  # [512] - remove batch, channel, and height dimensions
#
#         ### Get true output and input data
#         y_true_sample = y_true[idx]  # [512]
#         x_input_sample = X_data[idx]  # [2, 512] - two input channels
#
#         ### Plot input (both channels)
#         axes[i, 0].plot(x_input_sample[0], label='Input Channel 1', alpha=0.7, color='blue')
#         axes[i, 0].plot(x_input_sample[1], label='Input Channel 2', alpha=0.7, color='green')
#         axes[i, 0].set_title(f'Sample {idx}: Input PSDs')
#         axes[i, 0].legend()
#         axes[i, 0].grid(True)
#
#         ### Plot true output
#         axes[i, 1].plot(y_true_sample, label='True Mask', color='red', linewidth=2)
#         axes[i, 1].set_title('True Harmonic Mask')
#         axes[i, 1].set_ylim(0, 1)
#         axes[i, 1].grid(True)
#
#         ### Plot prediction vs true
#         axes[i, 2].plot(y_true_sample, label='True', color='red', alpha=0.7, linewidth=2)
#         axes[i, 2].plot(y_pred, label='Predicted', color='blue', alpha=0.7)
#         axes[i, 2].set_title('Prediction vs True')
#         axes[i, 2].set_ylim(0, 1)
#         axes[i, 2].legend()
#         axes[i, 2].grid(True)
#
#     plt.tight_layout()
#     # plt.savefig('predictions_comparison.png', dpi=300, bbox_inches='tight')
#     plt.show()


# def make_true_mask(cur_fund_freq, psd_length=1024, sigma=4.0, n_harmonics=100):
#
#     cur_fund_freq_lst = []
#     j = 0
#     for i in range(n_harmonics):
#         if j + cur_fund_freq >= psd_length - 2:
#             break
#         j += cur_fund_freq
#         cur_fund_freq_lst.append(j)
#
#     cur_fund_freq_lst = np.array([int(np.round(i)) for i in cur_fund_freq_lst])
#
#     ### Build mask with Gaussian peaks
#     true_mask = np.zeros(psd_length)
#     x = np.arange(psd_length)
#
#     for f in cur_fund_freq_lst:
#         gaussian = (1 / np.sqrt(2 * np.pi * sigma**2)) * np.exp(-0.5 * ((x - f)/sigma)**2)
#         true_mask = np.maximum(true_mask, gaussian)  #### take max to keep strongest peak at each bin
#
#     #### Normalize so max = 1
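#     ### scales the tallest peak to 1 (this normalizes the max, not the sum); overlapping Gaussians were already merged with np.maximum above
#     ### NOTE: if no harmonic fits below psd_length - 2, true_mask is all zeros and this division yields NaNs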
#     true_mask /= np.max(true_mask)
#
#     return true_mask
