In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F

from fpn_2 import PPSP
from ppsp_1up_head import PPSP_withFundamental

In [None]:
import os, pickle

# Target-mask flavour; it is interpolated into the pickle file name below.
mode = 'gaussian'

# Row-index combinations. The training loop indexes this list with the dataset
# position, so it presumably holds one entry per dataset in the pickle — TODO confirm.
all_combs_lists = [[0], [3], [0, 1, 2, 3]]

# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load pickles produced by this project.
with open(f"./conv2d_psd_scaled_down_1up_{mode}_1.pkl", "rb") as f:
    aggregated_combs_data_lst = pickle.load(f)

print(f"Datasets length={len(aggregated_combs_data_lst)}, combination_length={len(all_combs_lists)}")

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import torch.nn as nn
# Number of bins in one PSD row / target mask (same value as make_gaussian's default psd_length).
psd_length=1024

def plot_region_masks(model_1_output, model_2_output, index=0):
    """
    Plot the target masks of one sample in a two-row figure.

    Row 1 overlays the fundamental-region mask on the ground-truth vector,
    row 2 draws every harmonic mask of the same sample.

    model_1_output: sequence like [[region_mask_f0, gt_y], ...]
    model_2_output: sequence like [[mask_f0, mask_2f0, mask_3f0, mask_4f0], ...]
                    (any number of harmonic masks is accepted; masks beyond the
                    fourth are labelled '5f0', '6f0', ...)
    index: which sample to visualize

    Entries may be numpy arrays, lists or torch tensors; tensors are detached
    and copied to the CPU before plotting.
    """
    def _as_numpy(values):
        # torch tensors on the GPU or with requires_grad cannot be handed to matplotlib directly
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.asarray(values)

    # unpack
    f0_mask, gt_y = (_as_numpy(v) for v in model_1_output[index])
    masks = [_as_numpy(m) for m in model_2_output[index]]  # list of harmonic masks
    x = np.arange(len(f0_mask))

    fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

    # ---- Row 1: fundamental vs ground truth ----
    axes[0].plot(x, gt_y, label='Ground Truth (all harmonics)', color='gray', linewidth=2, alpha=0.7)
    axes[0].plot(x, f0_mask, label='Fundamental region (Head1)', color='blue', linewidth=2)
    axes[0].set_title('Model 1 Output — Fundamental vs Ground Truth')
    axes[0].legend()
    axes[0].grid(True, linestyle='--', alpha=0.4)

    # ---- Row 2: individual harmonic regions ----
    colors = ['r', 'g', 'b', 'orange']
    labels = ['f0', '2f0', '3f0', '4f0']

    for i, mask in enumerate(masks):
        # colours cycle; labels are generated once the predefined ones run out
        label = labels[i] if i < len(labels) else f'{i + 1}f0'
        axes[1].plot(x, mask, color=colors[i % len(colors)], label=label, linewidth=2)

    axes[1].set_title('Model 2 Output — Individual Harmonic Regions')
    axes[1].legend()
    axes[1].grid(True, linestyle='--', alpha=0.4)

    # the x axis is shared, so only the bottom panel carries the label
    axes[1].set_xlabel('Frequency Bin')
    fig.tight_layout()
    plt.show()
def make_gaussian(zero_mask, center_idx, psd_length=1024, sigma=2):
    """
    Overlay a peak-normalised Gaussian bump on ``zero_mask``.

    zero_mask:  array of length ``psd_length``; combined with the bump through an
                element-wise maximum, so existing content is kept (input is not modified)
    center_idx: bin index the bump is centred on
    psd_length: number of bins of the generated bump
    sigma:      standard deviation of the bump in bins; must be > 0

    Returns a float32 array. The bump is scaled so that its largest in-range
    value is 1. If the bump is numerically zero over the whole range (centre far
    outside ``[0, psd_length)``), it is left at zero instead of producing NaNs.

    Raises ValueError when ``sigma`` is not positive.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    x = np.arange(psd_length)
    # The 1/sqrt(2*pi*sigma^2) prefactor is omitted: it cancels in the peak normalisation below.
    gaussian = np.exp(-0.5 * ((x - center_idx) / sigma) ** 2)

    peak = gaussian.max() if gaussian.size else 0.0
    if peak > 0:
        # guard: dividing an all-zero bump by its zero peak would yield NaN everywhere
        gaussian = gaussian / peak

    final_mask = np.maximum(zero_mask, gaussian)

    return final_mask.astype(np.float32)

def remake_targets(batch_y, total_num_outputs=6):
    """
    Build Gaussian region targets from a batch of binary harmonic vectors.

    batch_y:           2-D torch tensor [batch, length]; bins equal to 1 mark harmonic positions
    total_num_outputs: the first ``total_num_outputs - 2`` marked bins of every row are
                       turned into one Gaussian mask each (default 6 -> 4 masks)

    Returns ``(model1_targets, model2_targets)``, float32 tensors on the module-level ``device``:
      model1_targets [batch, 2, length]                     -> (first-harmonic mask, original row)
      model2_targets [batch, total_num_outputs - 2, length] -> one mask per harmonic

    The mask length is taken from ``batch_y`` itself, so any target width works.

    Raises ValueError when ``total_num_outputs`` leaves no harmonic to encode or when a
    row contains fewer marked bins than required.
    """
    # detach/cpu make this safe for tensors that already live on the GPU or track gradients
    batch_y_ndarray = batch_y.detach().cpu().numpy()
    num_samples, mask_length = batch_y_ndarray.shape

    num_harmonics = total_num_outputs - 2
    if num_harmonics < 1:
        raise ValueError(f"total_num_outputs must be at least 3, got {total_num_outputs}")

    if num_samples == 0:
        # nothing to stack; return correctly shaped empty targets
        model1_array = np.zeros((0, 2, mask_length), dtype=np.float32)
        model2_array = np.zeros((0, num_harmonics, mask_length), dtype=np.float32)
    else:
        model1_list, model2_list = [], []

        for ind in range(num_samples):
            target_row = batch_y_ndarray[ind]
            harmonic_indices = np.where(target_row == 1)[0][:num_harmonics]

            if len(harmonic_indices) < num_harmonics:
                raise ValueError(
                    f"sample {ind} has only {len(harmonic_indices)} marked bins, "
                    f"but {num_harmonics} are required"
                )

            # one fresh zero mask per harmonic, sized like the target row
            region_masks = [
                make_gaussian(np.zeros(mask_length), center, psd_length=mask_length)
                for center in harmonic_indices
            ]

            ### plot_region_masks([[region_masks[0], target_row]], [region_masks], index=0)

            model1_list.append(np.stack([region_masks[0], target_row], axis=0))
            model2_list.append(np.stack(region_masks, axis=0))

        ### Force float32 before creating torch tensors
        model1_array = np.stack(model1_list).astype(np.float32)
        model2_array = np.stack(model2_list).astype(np.float32)

    model1_targets = torch.from_numpy(model1_array).to(device)
    model2_targets = torch.from_numpy(model2_array).to(device)

    return model1_targets, model2_targets

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

print(f"Using device: {device}")

### Training function
def train_fpn2d_model(X_train, y_train, X_val, y_val, num_epochs=50, batch_size=100, learning_rate=0.001, model=None,
                      model_name="", train_model_flag=False,
                      save_dir='/content/drive/My Drive/train_test_data/', log_every=50):
    """
    Train the fundamental-frequency head of ``model`` and keep the best checkpoint.

    X_train, X_val:   array-likes of model inputs (converted with torch.FloatTensor;
                      the exact layout is whatever ``model`` expects — TODO confirm)
    y_train, y_val:   array-likes of binary target rows, passed to remake_targets
    num_epochs:       number of passes over the training set
    batch_size:       mini-batch size for both loaders
    learning_rate:    Adam learning rate (weight decay is fixed at 1e-5)
    model:            nn.Module whose forward returns a sequence; element 0 holds the
                      fundamental logits and is compared with the f0 Gaussian target
    model_name:       checkpoint file name written inside ``save_dir``
    train_model_flag: when False the model is only moved to ``device`` and returned
    save_dir:         directory for the checkpoint and the loss-curve PNG (must exist)
    log_every:        print progress every ``log_every`` epochs (0/None disables it)

    Returns ``model``. After training it holds the weights of the epoch with the
    lowest validation loss (the same weights that were saved to disk).

    Raises ValueError when ``model`` is None.
    """
    # local imports keep the function usable regardless of which notebook cells were run
    import copy
    import os

    if model is None:
        raise ValueError("train_fpn2d_model requires a model instance (got model=None)")

    ### Convert to tensors
    X_train_tensor = torch.FloatTensor(X_train)
    y_train_tensor = torch.FloatTensor(y_train)
    X_val_tensor = torch.FloatTensor(X_val)
    y_val_tensor = torch.FloatTensor(y_val)

    print(f"Training data shapes:")
    print(f"X_train_tensor: {X_train_tensor.shape}")
    print(f"y_train_tensor: {y_train_tensor.shape}")
    print(f"X_val_tensor: {X_val_tensor.shape}")
    print(f"y_val_tensor: {y_val_tensor.shape}")

    ### Create datasets and dataloaders
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    ### Initialize model, loss, optimizer
    model.to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

    if not train_model_flag:
        return model

    def _fundamental_loss(batch_x, batch_y):
        """Loss of the fundamental head for one batch, plus the batch size."""
        # targets are built from the CPU batch; remake_targets moves them to `device`
        model1_targets, _ = remake_targets(batch_y)
        batch_x = batch_x.to(device)
        outputs = model(batch_x)
        loss = criterion(outputs[0].squeeze(1), model1_targets[:, 0, :])
        return loss, batch_x.size(0)

    # Training variables
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    best_model_weights = None
    weights_path = os.path.join(save_dir, model_name)

    ### Training loop
    for epoch in range(num_epochs):
        ### Training phase
        model.train()
        train_loss = 0.0

        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            loss, batch_count = _fundamental_loss(batch_x, batch_y)
            loss.backward()
            optimizer.step()

            # weight by batch size so the epoch value is a per-sample mean
            train_loss += loss.item() * batch_count

        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)

        ### Validation phase
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                loss, batch_count = _fundamental_loss(batch_x, batch_y)
                val_loss += loss.item() * batch_count

        val_loss /= len(val_loader.dataset)
        val_losses.append(val_loss)

        ### Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # deepcopy: state_dict() tensors alias the live parameters, so a shallow
            # copy would silently follow later optimizer steps
            best_model_weights = copy.deepcopy(model.state_dict())
            torch.save(best_model_weights, weights_path)

        if log_every and (epoch + 1) % log_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
                  f'Best Val: {best_val_loss:.4f}')

    # hand back the best epoch rather than whatever the last epoch produced
    if best_model_weights is not None:
        model.load_state_dict(best_model_weights)

    # The combination tag is the 6th '_'-separated token of names such as
    # "best_fpn2_1up_model_gaussian_[0].pth"; fall back to the bare file stem otherwise.
    name_parts = model_name.split('_')
    if len(name_parts) > 5:
        cur_combination = name_parts[5].split('.')[0]
    else:
        cur_combination = os.path.splitext(model_name)[0]

    # Plot training history
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, f'training_history_{cur_combination}.png'), dpi=300, bbox_inches='tight')
    plt.close()
    # plt.show()

    return model


# Build one PPSP_withFundamental wrapper per dataset and hand it to train_fpn2d_model.
# NOTE: train_model_flag=False (as set below) makes train_fpn2d_model skip its training
# loop and return the model untrained; set it to True to train and write checkpoints.
ppsp_weights_file = "best_model_weights_fan5_fan3_bldc_fpn2"

for ind, (X_train, X_val, y_train, y_val, dist_train, dist_val) in enumerate(aggregated_combs_data_lst):
    # checkpoint name handed to the trainer; the log line now reports this exact name
    model_name = f"best_fpn2_1up_model_{mode}_{all_combs_lists[ind]}.pth"
    print(f"Training model: {model_name}")

    # fresh pretrained base network for every dataset so runs do not share weights
    ppsp_model = PPSP(in_channels=1)
    ppsp_model.load_state_dict(torch.load(f'./ppsp_orig_weights/{ppsp_weights_file}', map_location="cpu"))

    ppsp_1up = PPSP_withFundamental(ppsp_model, freeze=True, hidden=256)

    ### Train the model
    trained_model = train_fpn2d_model(
        X_train, y_train, X_val, y_val,
        num_epochs=150,
        batch_size=10,
        learning_rate=0.001,
        model=ppsp_1up,
        model_name=model_name,
        train_model_flag=False
    )

In [None]:
%load_ext autoreload
%autoreload 1
# %matplotlib widget
import matplotlib
import librosa

# Select the Qt5 backend for matplotlib figures.
# NOTE(review): this needs a Qt binding and a display; confirm it is still wanted,
# since the training cell above writes to Colab-style paths — TODO confirm.
matplotlib.use('QT5Agg')
import numpy as np
import pandas as pd
import os, pickle, re
from scipy import signal
import matplotlib.pyplot as plt
from magtach.op_codes.preprocess_functions import read_files, min_max_norm

In [None]:
def process_test_dict(loaded_dict, row_indices=[0, 4], col_size=512, output_format="channels_first"):
    """
    Convert every entry of a loaded test dictionary into model-ready arrays.

    loaded_dict:   mapping key -> (test_x, test_y, fund_freq_lst, distances_lst,
                   file_names_lst, orig_signal_lst)
    row_indices:   rows of each sample in test_x to keep (order is preserved)
    col_size:      number of leading columns to keep
    output_format: "channels_first" -> X shaped [N, len(row_indices), 1, col_size]
                   "channels_last"  -> X shaped [N, 1, len(row_indices), col_size]

    Returns a dict with the same keys; each value is
    (X, y, fund_freq_lst, distances_lst, file_names_lst, orig_signal_lst), where
    everything except file_names_lst has been converted to a numpy array.
    Sample order is untouched (no shuffling, no train/val split).

    Raises ValueError for an unknown ``output_format``.
    """
    result = {}

    for name, record in loaded_dict.items():
        raw_x, raw_y, raw_freqs, raw_distances, file_names, raw_signals = record

        features = np.array(raw_x)
        targets = np.array(raw_y)
        distances = np.array(raw_distances)
        freqs = np.array(raw_freqs)
        signals = np.array(raw_signals)

        # keep the requested rows and the leading columns of every sample
        selected = features[:, row_indices, :col_size]

        # insert the singleton axis where the requested layout wants it
        if output_format == "channels_first":
            shaped = selected[:, :, None, :]
        elif output_format == "channels_last":
            shaped = selected[:, None, :, :]
        else:
            raise ValueError("output_format must be 'channels_first' or 'channels_last'")

        result[name] = (shaped, targets, freqs, distances, file_names, signals)

    return result

In [None]:

import warnings, pickle
import torch
from fpn_2 import PPSP
from ppsp_1up_head import PPSP_withFundamental
# NOTE(review): this silences every warning in the kernel, not only the ones raised by this cell.
warnings.filterwarnings("ignore")
mode = "gaussian"
# all_combs_lists = [[0], [3], [0, 1, 2, 3]]
all_combs_lists = [[0]]

# NOTE(review): this rebinds the notebook-wide `device` (a torch.device in the training cell)
# to a plain string; helpers that read `device` will use the CPU from here on.
device = 'cpu'

# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(f"./conv2d_psd_scaled_down_1up_{mode}_test.pkl", "rb") as f:
    loaded_dict_test = pickle.load(f)

ppsp_weights_file = "best_model_weights_fan5_fan3_bldc_fpn2"


def _min_max_unit(values):
    """Rescale an array to [0, 1]; the epsilon keeps a constant array from dividing by zero."""
    low, high = np.min(values), np.max(values)
    return (values - low) / (high - low + 1e-12)


for comb_lst in all_combs_lists:

    dir_path = f"./conv2d_data/pred_plots/{comb_lst}/"
    os.makedirs(dir_path, exist_ok=True)

    processed_test = process_test_dict(
        loaded_dict_test,
        row_indices=comb_lst,
        col_size=1024,
        output_format="channels_last"
    )
    print(f"Using device: {device}")

    ppsp_model = PPSP(in_channels=1)
    ppsp_model.load_state_dict(torch.load(f'../../data/train_test_data/{ppsp_weights_file}', map_location="cpu"))

    ppsp_1up = PPSP_withFundamental(ppsp_model, freeze=True, hidden=256)
    ppsp_1up.eval()
    print("ppsp_model loaded")

    for key, (X_test, y_test, fund_freq, distances, file_names, orig_sig) in processed_test.items():
        print(f"Key: {key}")
        # only this recording set is visualised
        if key != "bldc_6":
            continue

        # a dedicated index name: the previous inner `ind` shadowed the outer loop variable
        for sample_idx in range(len(distances)):
            cur_fund_freq = fund_freq[sample_idx]
            cur_fil_name = file_names[sample_idx]

            cur_x = np.array(X_test[sample_idx])[:, :, :]
            torch_x = torch.FloatTensor(cur_x).to(device)
            cur_y = np.array(y_test[sample_idx])

            # inference only: no autograd graph is needed
            with torch.no_grad():
                cur_prediction = ppsp_1up(torch_x)

            fund_pred = torch.sigmoid(cur_prediction[0]).squeeze(1).squeeze(0).detach().cpu().numpy()
            all_harmonic_pred1 = torch.sigmoid(cur_prediction[1]).squeeze(1).squeeze(0).detach().cpu().numpy()

            cur_fund_prediction_norm = _min_max_unit(fund_pred)
            cur_harmonic_prediction_norm1 = _min_max_unit(all_harmonic_pred1)

            binary_cur_truth = cur_y

            plt_fil_name = f"{cur_fund_freq}-{cur_fil_name}_{comb_lst}"

            fig, ax = plt.subplots()
            ax.set_title(f"{cur_fund_freq} - {cur_fil_name}_{comb_lst}")

            # input rows that were fed to the model
            for input_row in cur_x.squeeze(0):
                ax.plot(input_row, linewidth=1.2)

            ax.plot(binary_cur_truth, linewidth=1.1, label="True", alpha=0.8)

            ax.plot(cur_fund_prediction_norm, linewidth=0.8, label=" f0", alpha=0.8)
            ax.plot(cur_harmonic_prediction_norm1, '--', linewidth=0.7, label="raw all_harmonics", alpha=0.8)

            ax.legend(loc='lower right')

            # fig.savefig(os.path.join(dir_path, f"{plt_fil_name}.png"), dpi=150)

            plt.show()
            # release the figure so long runs do not accumulate open figures
            plt.close(fig)

In [None]:
from fpn_2 import PPSP
import torch
import numpy as np

from ppsp_1up_head import PPSP_withFundamental
from ppsp_mtl_9 import FPN2_withFundamental

# File name of the pretrained PPSP checkpoint, resolved under ../../data/train_test_data/.
ppsp_weights_file = "best_model_weights_fan5_fan3_bldc_fpn2"

# Build the base network and restore its weights onto the CPU.
ppsp_model = PPSP(in_channels=1)
ppsp_model.load_state_dict(torch.load(f'../../data/train_test_data/{ppsp_weights_file}', map_location="cpu"))
# NOTE(review): eval() on the base network is disabled here, yet the plotting cell calls
# ppsp_model directly; whether it ends up in eval mode depends on the wrapper below — confirm.
# ppsp_model.eval()

# Wrap the base network. NOTE(review): despite the "_noEval" suffix, eval() IS called on the
# wrapper right below; the name looks like a leftover from an earlier experiment — confirm.
ppsp_1up_noEval = FPN2_withFundamental(ppsp_model, freeze=True, hidden=256)
ppsp_1up_noEval.eval()

# Alternative wrapper kept for comparison:
# ppsp_1up_noEval = PPSP_withFundamental(ppsp_model, freeze=True, hidden=256)
# ppsp_1up_noEval.eval()
print("ppsp_model loaded")



In [None]:
import pickle, matplotlib
matplotlib.use("Qt5Agg")
import matplotlib.pyplot as plt

# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
# The context manager closes the file handle that the bare open() call used to leak.
with open("../../data/train_test_data/test_x_y_bldc_correct", "rb") as test_file:
    files = pickle.load(test_file)
test_dict_lst = files[0]

for values in test_dict_lst:

    # values = (per-distance sample dict, dataset name, fundamental frequency, ...)
    if values[1] not in ["bldc_1"]:
        continue

    print(f"{values[1]}")
    test_dict, fundamental_freq = values[0], values[2]
    res_file = f"{values[1]}_results"  # not used in this cell; kept in case a later cell reads it

    for key, val in test_dict.items():
        if key != "5cm":
            continue

        print(key, len(val))

        for sig_ind in range(len(val)):
            cur_fil_name = val[sig_ind][4]
            print(cur_fil_name)

            # first row of the target array is the ground-truth vector that gets plotted
            cur_x, cur_y = val[sig_ind][0], val[sig_ind][1][0, :]

            cur_x = torch.tensor(cur_x, dtype=torch.float32).reshape(1, 1, -1)
            sig = cur_x[0][0].detach().numpy()

            # inference only: no autograd graph is needed
            with torch.no_grad():
                cur_pred = ppsp_model(cur_x)
                ppsp_1up_noEval_pred = ppsp_1up_noEval(cur_x)
                # ppsp_1up_Eval_pred = ppsp_1up_Eval(cur_pred)

            cur_pred = torch.sigmoid(cur_pred)
            pred = cur_pred[0][0].detach().numpy()

            # element 1 of the wrapper output is the one compared against the base prediction
            ppsp_1up_noEval_pred = torch.sigmoid(ppsp_1up_noEval_pred[1])
            ppsp_1up_noEval_pred = ppsp_1up_noEval_pred[0][0].detach().numpy()

            # pred = np.where(pred <= 0.5, 0, pred)

            fig, ax = plt.subplots()
            ax.set_title(f"{cur_fil_name}_{fundamental_freq}")
            ax.plot(sig)
            ax.plot(pred)
            ax.plot(ppsp_1up_noEval_pred, '--', color='green', alpha=0.7)
            ax.plot(cur_y, "--", linewidth=0.8, color='red', alpha=0.5)

            fig.tight_layout()

            plt.show()
            # release the figure so long runs do not accumulate open figures
            plt.close(fig)