In [None]:
%load_ext autoreload
%autoreload 1
# %matplotlib widget
import matplotlib
import librosa
import numpy as np
import torch
import torch.nn.functional as F

matplotlib.use('QT5Agg')

In [None]:
def process_test_dict(loaded_dict, row_indices=None, col_size=512, output_format="channels_first"):
    """Select input rows/columns from every test split and reshape for the model.

    Parameters
    ----------
    loaded_dict : dict
        Maps a key to a 6-tuple ``(test_x, test_y, fund_freq_lst, distances_lst,
        file_names_lst, orig_signal_lst)``. ``test_x`` must be at least 3-D after
        ``np.array`` conversion, since it is indexed as ``test_x[:, rows, :cols]``.
    row_indices : int or sequence of int, optional
        Rows (input channels) of ``test_x`` to keep, in the given order.
        A single int keeps one row without dropping that axis.
        Defaults to ``[0, 4]``.
    col_size : int
        Number of leading entries kept along the last axis.
    output_format : {"channels_first", "channels_last"}
        ``"channels_first"`` -> ``X`` is ``[N, num_rows, 1, col_size]``;
        ``"channels_last"``  -> ``X`` is ``[N, 1, num_rows, col_size]``.

    Returns
    -------
    dict
        Same keys as ``loaded_dict``. Each value is
        ``(X, y, fund_freq, distances, file_names, orig_signal)`` in the original
        sample order (no sorting, shuffling or train/val split). ``file_names`` is
        passed through unchanged; the other entries are NumPy arrays.

    Raises
    ------
    ValueError
        If ``output_format`` is not one of the two supported values.
    """
    # Validate once, before the loop, so a bad argument fails even for an empty dict.
    if output_format == "channels_first":
        new_axis = 2  # [N, num_rows, col_size] -> [N, num_rows, 1, col_size]
    elif output_format == "channels_last":
        new_axis = 1  # [N, num_rows, col_size] -> [N, 1, num_rows, col_size]
    else:
        raise ValueError("output_format must be 'channels_first' or 'channels_last'")

    # None sentinel instead of a shared mutable list default; a scalar index is
    # wrapped in a list so the row axis is kept.
    if row_indices is None:
        rows = [0, 4]
    elif np.isscalar(row_indices):
        rows = [row_indices]
    else:
        rows = list(row_indices)

    processed_data = {}
    for key, (test_x, test_y, fund_freq_lst, distances_lst, file_names_lst, orig_signal_lst) in loaded_dict.items():
        # Select rows and column size (original sample order preserved).
        X = np.expand_dims(np.array(test_x)[:, rows, :col_size], axis=new_axis)
        processed_data[key] = (
            X,
            np.array(test_y),
            np.array(fund_freq_lst),
            np.array(distances_lst),
            file_names_lst,
            np.array(orig_signal_lst),
        )

    return processed_data

In [None]:
import os

# Show the current working directory; the data, model and plot paths used in
# later cells are relative ("../conv2d_data/...", "./conv2d_data/...").
current_dir = os.getcwd()
print(current_dir)

In [None]:
### GradCAM class
class GradCAM:
    """Grad-CAM saliency for one layer of a PyTorch model.

    A forward hook on ``target_layer`` caches the layer's output activations and
    attaches a tensor hook that records the gradient of the scalar target with
    respect to those activations. ``__call__`` runs one forward/backward pass and
    returns a per-sample saliency map min-max normalised to ``[0, 1]``.

    Call :meth:`remove` when finished to detach the hooks from the layer.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        ### Register hooks. Gradients are captured by a tensor hook attached in
        ### save_activation rather than the deprecated
        ### Module.register_backward_hook, whose grad_output can be wrong for
        ### modules that perform several autograd operations.
        self._handles = [self.target_layer.register_forward_hook(self.save_activation)]

    def save_activation(self, module, input, output):
        """Forward hook: cache the layer output and hook its gradient."""
        # clone() so later in-place ops on the layer output (e.g. ReLU(inplace=True))
        # cannot silently change the cached activations.
        self.activations = output.detach().clone()
        # A tensor hook can only be attached when autograd tracks the output
        # (not, e.g., under torch.no_grad()).
        if output.requires_grad:
            output.register_hook(self._save_output_gradient)

    def _save_output_gradient(self, grad):
        """Tensor hook: store d(target)/d(layer output)."""
        self.gradients = grad.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Module-backward-hook style setter, kept for backward compatibility.

        Not registered by this class; stores ``grad_output[0]`` as the gradient.
        """
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Detach every hook this instance registered on ``target_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __call__(self, x, target_category=None):
        """Compute the Grad-CAM map for input ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Model input.
        target_category : int, optional
            Index into the model output used as the Grad-CAM target;
            ``None`` behaves like ``0`` (``output[0]``).

        Returns
        -------
        (numpy.ndarray, object)
            The CAM with shape ``[batch, *spatial]`` of the hooked activation
            (``[batch, length]`` for a 3-D activation), min-max normalised per
            sample. The second element is the raw model output; its graph is
            retained.

        Raises
        ------
        RuntimeError
            If the hooks captured no activations/gradients (layer unused in the
            forward pass, target independent of it, or gradients disabled).
        ValueError
            If the hooked activation has fewer than 3 dimensions.
        """
        # Clear previous results so stale tensors are never reused.
        self.gradients = None
        self.activations = None

        ### Forward pass
        output = self.model(x)

        # None -> the first model output (the fundamental prediction in this notebook).
        target = output[0 if target_category is None else target_category]

        ### Zero gradients, then backprop a scalar (mean) of the target
        self.model.zero_grad()
        target.mean().backward(retain_graph=True)

        gradients = self.gradients
        activations = self.activations
        if activations is None or gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients; "
                "check target_layer and that gradients are enabled."
            )
        if activations.dim() < 3:
            raise ValueError(
                f"Grad-CAM expects activations of shape [batch, channels, ...], got {tuple(activations.shape)}"
            )

        ### Channel weights: global average pooling of gradients over all spatial axes
        spatial_dims = tuple(range(2, activations.dim()))
        weights = gradients.mean(dim=spatial_dims, keepdim=True)

        # Weighted sum over channels, then ReLU.
        cam = F.relu((weights * activations).sum(dim=1))

        # Per-sample min-max normalisation over every spatial position.
        flat = cam.flatten(start_dim=1)
        flat = flat - flat.min(dim=1, keepdim=True)[0]
        flat = flat / (flat.max(dim=1, keepdim=True)[0] + 1e-8)

        return flat.reshape(cam.shape).detach().cpu().numpy(), output

In [None]:
"""
testing only ppsp mtl
"""
import warnings, pickle
import os, csv

warnings.filterwarnings("ignore")
mode = "block"
# all_combs_lists = [[0], [3], [0, 1, 2, 3]]
all_combs_lists = [[3]]
import torch
from models import fpn_2
from models import ppsp_1up
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import fuzzy_logic as fl

device = 'cpu'

with open(f"../conv2d_data/conv2d_psd_scaled_sfnds_1up_{mode}_test.pkl", "rb") as f:
    loaded_dict_test = pickle.load(f)

for ind, comb_lst in enumerate(all_combs_lists):

    dir_path = f"./conv2d_data/pred_plots/{comb_lst}/"
    os.makedirs(dir_path, exist_ok=True)

    ppsp_backbone = fpn_2.PPSP(in_channels=len(comb_lst), out_channels=32)
    ppsp_1up_model = ppsp_1up.PPSP_1up(ppsp_backbone=ppsp_backbone, hidden_nodes=256)

    model_name = f"../best_fpn2_1up_model_{mode}_{comb_lst}.pth"
    ppsp_1up_model.load_state_dict(torch.load(f"{model_name}", map_location=torch.device('cpu')))
    ppsp_1up_model.eval()

    ### Initialize Grad-CAM for the model
    target_layer = ppsp_1up_model.ppsp_backbone.conv_output1.c
    grad_cam = GradCAM(ppsp_1up_model, target_layer)

    processed_test = process_test_dict(
        loaded_dict_test,
        row_indices=comb_lst,
        col_size=1024,
        output_format="channels_last"
    )
    print(f"Using device: {device}")

    for key, (X_test, y_test, fund_freq, distances, file_names, orig_sig) in processed_test.items():
        print(f"Key: {key}")
        cam_res_dict={}
        cur_res_dict = {}
        if key == key:
            # Initialize CSV file for this key (no extra indent)
            # cur_results = []
            # csv_path = os.path.join(dir_path, f"{key}_results.csv")
            # csv_file = open(csv_path, mode="w", newline="")
            # writer = csv.writer(csv_file)
            # writer.writerow(["filename", "gtruth_fund_freq", "predicted_fund_freq", "predicted_fund_freq_lst","clusters"])

            for sample_ind in range(len(distances)):
                cur_original_sig = orig_sig[sample_ind]
                cur_fund_freq = fund_freq[sample_ind]
                cur_fil_name = file_names[sample_ind]

                cur_x = np.array(X_test[sample_ind])[:, :, :]
                torch_x = torch.FloatTensor(cur_x).to('cpu')
                cur_y = np.array(y_test[sample_ind])

                ### Enable gradients for input
                torch_x.requires_grad_()

                with torch.enable_grad():
                    cur_prediction = ppsp_1up_model(torch_x)

                fund_pred = torch.sigmoid(cur_prediction[0]).squeeze(1).squeeze(0).detach().cpu().numpy()
                harmonic_pred = torch.sigmoid(cur_prediction[1]).squeeze(1).squeeze(0).detach().cpu().numpy()

                cur_fund_prediction_norm = (fund_pred - np.min(fund_pred)) / (
                        np.max(fund_pred) - np.min(fund_pred) + 1e-12)
                cur_harm_prediction_norm = (harmonic_pred - np.min(harmonic_pred)) / (
                        np.max(harmonic_pred) - np.min(harmonic_pred) + 1e-12)
                binary_cur_truth = cur_y

                # if cur_fil_name in ["fan_1721415536_7910_35cm_0"]:
                # print(cur_fil_name)
                ### apply thresholding
                bin_harmonic_pred = np.where(cur_harm_prediction_norm <= 0.5, 0, 1)
                bin_fund_pred = np.where(cur_fund_prediction_norm > 0.01, 1, 0)

                # ###
                # fund_regions_lst, fund_central_freq_lst = fl.find_windows(bin_fund_pred)
                # harm_regions_lst, harm_central_freq_lst = fl.find_windows(bin_harmonic_pred)
                # ###
                # clusters = fl.harmonic_clustering(fund_central_freq_lst, harm_central_freq_lst, tolerance=5,
                #                                   max_harmonic=15, max_freq=1024)
                # # if cur_fil_name in ["fan_1721413988_2795_90cm_2"]:
                # if len(clusters) > 0:
                #
                #     fine_pred_freq_lst = []
                #     ### find the region (start and end of these f0 candidates) for finer speed estimation
                #     for cur_ind in range(len(clusters)):
                #         # cur_pred_freq = clusters[cur_ind]['f0']
                #         # if clusters[cur_ind]['match_count']>1:
                #         try:
                #             index_of_region = fund_central_freq_lst.index(clusters[cur_ind]['f0'])
                #             region_indexes = fund_regions_lst[index_of_region]
                #             strt_ind, end_ind = region_indexes[0], region_indexes[1]
                #             cur_fine_pred_freq = fl.predict_freq(cur_original_sig, cur_original_sig, strt_ind, end_ind,
                #                                                  int((strt_ind + end_ind) // 2))
                #             fine_pred_freq_lst.append(cur_fine_pred_freq)
                #
                #         except:
                #             index_of_region = harm_central_freq_lst.index(clusters[cur_ind]['f0'])
                #             region_indexes = harm_regions_lst[index_of_region]
                #             strt_ind, end_ind = region_indexes[0], region_indexes[1]
                #             cur_fine_pred_freq = fl.predict_freq(cur_original_sig, cur_original_sig, strt_ind, end_ind,
                #                                                  int((strt_ind + end_ind) // 2))
                #             fine_pred_freq_lst.append(cur_fine_pred_freq)
                #         # else:
                #         #     fine_pred_freq_lst.append(0)
                #
                #
                # else:
                #     fine_pred_freq_lst = [0]
                #
                # print(f"filename={cur_fil_name} gtruth={cur_fund_freq} predicted={fine_pred_freq_lst}")
                #
                # ### fine tuning
                # fine_pred_freq_arr = np.array(fine_pred_freq_lst)
                # diff_arr = np.abs(fine_pred_freq_arr - cur_fund_freq)
                # min_idx = np.argmin(diff_arr)
                #
                # fine_pred_freq = fine_pred_freq_lst[min_idx]
                # """
                # write results to file
                # """
                # writer.writerow([cur_fil_name, cur_fund_freq,fine_pred_freq, fine_pred_freq_lst, clusters])
                # cur_results.append(
                #     [[bin_fund_pred, bin_harmonic_pred], binary_cur_truth, cur_x, cur_fund_freq, fine_pred_freq,
                #      cur_fil_name, clusters,fine_pred_freq_lst,])

                # ### Generate Grad-CAM
                with torch.enable_grad():
                    cam, _ = grad_cam(torch_x, target_category=0)

                ### Convert CAM to same length as input
                cam = cam.squeeze()

                ### Handle case where cam might be 2D (batch, features)
                if cam.ndim > 1:
                    cam = cam[0]

                cam_resized = np.interp(np.linspace(0, len(cam) - 1, len(cur_fund_prediction_norm)),
                                        np.arange(len(cam)), cam)


                plt.figure(figsize=(7, 4))
                # plt.subplot(2, 1, 1)
                for ind_x in range(cur_x.squeeze(0).shape[0]):
                    plt.plot(cur_x.squeeze(0)[ind_x], linewidth=1.2, alpha=0.7)
                plt.plot(binary_cur_truth, linewidth=1.5, label="True", alpha=0.9, color='black')
                # x_axis = np.arange(len(cam_resized))
                # plt.fill_between(x_axis, np.min(cur_x), np.max(cur_x),
                #                 where=cam_resized > 0.5,
                #                 alpha=0.3, color='red', label='Saliency')

                # plt.plot(cur_fund_prediction_norm, linewidth=1.2, label="f0", alpha=0.8, color='blue')
                plt.plot(bin_fund_pred, linewidth=1.2, label="f0", alpha=0.8, color='blue')
                # plt.plot(harmonic_pred, '--', linewidth=1.2, label="raw all_harmonics", alpha=0.8, color='green')
                plt.plot(bin_harmonic_pred, '--', linewidth=1.2, label="raw all_harmonics", alpha=0.8, color='green')

                plt.plot(cam_resized, color='red', linewidth=1.5, label='CAM')
                plt.fill_between(np.arange(len(cam_resized)), cam_resized, alpha=0.3, color='red')

                plt.legend(loc='lower right')
                plt.title(f"{cur_fund_freq} - {cur_fil_name}_{comb_lst} - Saliency Overlay")
                plt.ylabel('Amplitude')

                # plt.subplot(2, 1, 2)
                # plt.plot(cam_resized, color='red', linewidth=1.5, label='Saliency')
                # plt.fill_between(np.arange(len(cam_resized)), cam_resized, alpha=0.3, color='red')
                # plt.xlabel('Time steps')
                # plt.ylabel('Saliency')
                # plt.ylim(0, 1)
                # plt.title('Grad-CAM Saliency Map')
                # plt.legend()

                plt.tight_layout()

                plt_fil_name = f"{cur_fund_freq}-{cur_fil_name}_{comb_lst}"

                plt.savefig(os.path.join(dir_path, f"{plt_fil_name}_gradcam_mtl.png"), dpi=150)
                plt.close()

                # plt.show()

                if distances[sample_ind] not in cam_res_dict:
                    cam_res_dict[distances[sample_ind]] = [[cur_fil_name, cur_fund_freq,X_test[sample_ind],y_test[sample_ind],fund_pred,cur_fund_prediction_norm,binary_cur_truth,cam,cam_resized,harmonic_pred,cur_harm_prediction_norm]]
                else:
                    cam_res_dict[distances[sample_ind]].append([cur_fil_name, cur_fund_freq,X_test[sample_ind],y_test[sample_ind],fund_pred,cur_fund_prediction_norm,binary_cur_truth,cam,cam_resized,harmonic_pred,cur_harm_prediction_norm])

        pkl_path = os.path.join(dir_path, f"{key}_mtl_cam")
        pickle.dump([cam_res_dict], open(f"{pkl_path}", "wb"))

        #     csv_file.close()
        #
        # cur_res_dict[key] = cur_results
        # pkl_path = os.path.join(dir_path, f"{key}_results")
        # pickle.dump([cur_res_dict], open(f"{pkl_path}", "wb"))



In [None]:
# """
# testing crepe and ppsp
# """
# import warnings, pickle
# import os, csv
#
# warnings.filterwarnings("ignore")
# mode = "block"
# # all_combs_lists = [[0], [3], [0, 1, 2, 3]]
# all_combs_lists = [[0]]
# import torch
# from models import fpn_2
# from models import crepe
# import matplotlib.pyplot as plt
# import torch.nn.functional as F
# import numpy as np
# import fuzzy_logic as fl
#
# device = 'cpu'
#
# with open(f"../conv2d_data/conv2d_psd_scaled_sfnds_1up_{mode}_test_bldc7.pkl", "rb") as f:
#     loaded_dict_test = pickle.load(f)
#
# for ind, comb_lst in enumerate(all_combs_lists):
#
#     dir_path = f"./conv2d_data/pred_plots/{comb_lst}/"
#     os.makedirs(dir_path, exist_ok=True)
#
#     # ppsp_backbone = fpn_2.PPSP(in_channels=len(comb_lst), out_channels=32)
#     # crepe = crepe.FPN_2_mtl(in_channels=len(comb_lst))
#     crepe = fpn_2.PPSP(in_channels=len(comb_lst))
#
#     model_name = f"../1ppsp_weights/best_model_weights_fan5_fan3_bldc_fpn2"
#     crepe.load_state_dict(torch.load(f"{model_name}", map_location=torch.device('cpu')))
#     crepe.eval()
#
#     ### Initialize Grad-CAM for the model
#     # target_layer = ppsp_1up_model.ppsp_backbone.conv_output1.c
#     # grad_cam = GradCAM(ppsp_1up_model, target_layer)
#
#     processed_test = process_test_dict(
#         loaded_dict_test,
#         row_indices=comb_lst,
#         col_size=1024,
#         output_format="channels_last"
#     )
#     print(f"Using device: {device}")
#
#     for key, (X_test, y_test, fund_freq, distances, file_names, orig_sig) in processed_test.items():
#         print(f"Key: {key}")
#         cur_res_dict = {}
#         if key == key:
#             # Initialize CSV file for this key (no extra indent)
#             cur_results = []
#             csv_path = os.path.join(dir_path, f"{key}_results.csv")
#             csv_file = open(csv_path, mode="w", newline="")
#             writer = csv.writer(csv_file)
#             writer.writerow(["filename", "gtruth_fund_freq", "predicted_fund_freq", "predicted_fund_freq_lst","clusters"])
#
#             for sample_ind in range(len(distances)):
#                 cur_original_sig = orig_sig[sample_ind]
#                 cur_fund_freq = fund_freq[sample_ind]
#                 cur_fil_name = file_names[sample_ind]
#
#                 cur_x = np.array(X_test[sample_ind])[:, :, :]
#                 torch_x = torch.FloatTensor(cur_x).to('cpu')
#                 cur_y = np.array(y_test[sample_ind])
#
#                 ### Enable gradients for input
#                 # torch_x.requires_grad_()
#
#                 # with torch.enable_grad():
#                 cur_prediction = crepe(torch_x)
#
#                 fund_pred = torch.sigmoid(cur_prediction).squeeze(1).squeeze(0).detach().cpu().numpy()
#                 # harmonic_pred = torch.sigmoid(cur_prediction[1]).squeeze(1).squeeze(0).detach().cpu().numpy()
#
#                 cur_fund_prediction_norm = (fund_pred - np.min(fund_pred)) / (
#                         np.max(fund_pred) - np.min(fund_pred) + 1e-12)
#                 # cur_harm_prediction_norm = (harmonic_pred - np.min(harmonic_pred)) / (
#                 #         np.max(harmonic_pred) - np.min(harmonic_pred) + 1e-12)
#                 binary_cur_truth = cur_y
#
#                 # if cur_fil_name in ["fan_1721415536_7910_35cm_0"]:
#                 # print(cur_fil_name)
#                 ### apply thresholding
#                 # bin_harmonic_pred = np.where(cur_harm_prediction_norm <= 0.5, 0, 1)
#                 bin_fund_pred = np.where(cur_fund_prediction_norm > 0.01, 1, 0)
#
#                 ###
#                 fund_regions_lst, fund_central_freq_lst = fl.find_windows(bin_fund_pred)
#                 # harm_regions_lst, harm_central_freq_lst = fl.find_windows(bin_harmonic_pred)
#                 ###
#                 # clusters = fl.harmonic_clustering(fund_central_freq_lst, harm_central_freq_lst, tolerance=5,
#                 #                                   max_harmonic=15, max_freq=1024)
#                 # if cur_fil_name in ["fan_1721413988_2795_90cm_2"]:
#                 # if len(clusters) > 0:
#                 #
#                 #     fine_pred_freq_lst = []
#                 #     ### find the region (start and end of these f0 candidates) for finer speed estimation
#                 #     for cur_ind in range(len(clusters)):
#                 #         # cur_pred_freq = clusters[cur_ind]['f0']
#                 #         # if clusters[cur_ind]['match_count']>1:
#                 #         try:
#                 #             index_of_region = fund_central_freq_lst.index(clusters[cur_ind]['f0'])
#                 #             region_indexes = fund_regions_lst[index_of_region]
#                 #             strt_ind, end_ind = region_indexes[0], region_indexes[1]
#                 #             cur_fine_pred_freq = fl.predict_freq(cur_original_sig, cur_original_sig, strt_ind, end_ind,
#                 #                                                  int((strt_ind + end_ind) // 2))
#                 #             fine_pred_freq_lst.append(cur_fine_pred_freq)
#                 #
#                 #         except:
#                 #             index_of_region = harm_central_freq_lst.index(clusters[cur_ind]['f0'])
#                 #             region_indexes = harm_regions_lst[index_of_region]
#                 #             strt_ind, end_ind = region_indexes[0], region_indexes[1]
#                 #             cur_fine_pred_freq = fl.predict_freq(cur_original_sig, cur_original_sig, strt_ind, end_ind,
#                 #                                                  int((strt_ind + end_ind) // 2))
#                 #             fine_pred_freq_lst.append(cur_fine_pred_freq)
#                 #         # else:
#                 #         #     fine_pred_freq_lst.append(0)
#                 #
#                 #
#                 # else:
#                 #     fine_pred_freq_lst = [0]
#                 fine_pred_freq_lst=[fund_regions_lst[0]]
#                 print(f"filename={cur_fil_name} gtruth={cur_fund_freq} predicted={fine_pred_freq_lst}")
#
#                 ### fine tuning
#                 # fine_pred_freq_arr = np.array(fine_pred_freq_lst)
#                 # diff_arr = np.abs(fine_pred_freq_arr - cur_fund_freq)
#                 # min_idx = np.argmin(diff_arr)
#                 #
#                 # fine_pred_freq = fine_pred_freq_lst[min_idx]
#                 fine_pred_freq = fund_regions_lst[0]
#                 """
#                 write results to file
#                 """
#                 writer.writerow([cur_fil_name, cur_fund_freq,fine_pred_freq, fine_pred_freq_lst])
#                 cur_results.append(
#                     [[bin_fund_pred], binary_cur_truth, cur_x, cur_fund_freq, fine_pred_freq,
#                      cur_fil_name, fine_pred_freq_lst,])
#
#                 # # ### Generate Grad-CAM
#                 # with torch.enable_grad():
#                 #     cam, _ = grad_cam(torch_x, target_category=0)
#                 #
#                 # ### Convert CAM to same length as input
#                 # cam = cam.squeeze()
#                 #
#                 # ### Handle case where cam might be 2D (batch, features)
#                 # if cam.ndim > 1:
#                 #     cam = cam[0]
#                 #
#                 # cam_resized = np.interp(np.linspace(0, len(cam) - 1, len(cur_fund_prediction_norm)),
#                 #                         np.arange(len(cam)), cam)
#                 #
#                 # plt.figure(figsize=(7, 4))
#                 # # plt.subplot(2, 1, 1)
#                 # for ind_x in range(cur_x.squeeze(0).shape[0]):
#                 #     plt.plot(cur_x.squeeze(0)[ind_x], linewidth=1.2, alpha=0.7)
#                 # plt.plot(binary_cur_truth, linewidth=1.5, label="True", alpha=0.9, color='black')
#                 # # x_axis = np.arange(len(cam_resized))
#                 # # plt.fill_between(x_axis, np.min(cur_x), np.max(cur_x),
#                 # #                 where=cam_resized > 0.5,
#                 # #                 alpha=0.3, color='red', label='Saliency')
#                 #
#                 # # plt.plot(cur_fund_prediction_norm, linewidth=1.2, label="f0", alpha=0.8, color='blue')
#                 # plt.plot(bin_fund_pred, linewidth=1.2, label="f0", alpha=0.8, color='blue')
#                 # # plt.plot(harmonic_pred, '--', linewidth=1.2, label="raw all_harmonics", alpha=0.8, color='green')
#                 # plt.plot(bin_harmonic_pred, '--', linewidth=1.2, label="raw all_harmonics", alpha=0.8, color='green')
#                 #
#                 # # plt.plot(cam_resized, color='red', linewidth=1.5, label='CAM')
#                 # # plt.fill_between(np.arange(len(cam_resized)), cam_resized, alpha=0.3, color='red')
#                 #
#                 # plt.legend(loc='lower right')
#                 # plt.title(f"{cur_fund_freq} - {cur_fil_name}_{comb_lst} - Saliency Overlay")
#                 # plt.ylabel('Amplitude')
#                 #
#                 # # plt.subplot(2, 1, 2)
#                 # # plt.plot(cam_resized, color='red', linewidth=1.5, label='Saliency')
#                 # # plt.fill_between(np.arange(len(cam_resized)), cam_resized, alpha=0.3, color='red')
#                 # # plt.xlabel('Time steps')
#                 # # plt.ylabel('Saliency')
#                 # # plt.ylim(0, 1)
#                 # # plt.title('Grad-CAM Saliency Map')
#                 # # plt.legend()
#                 #
#                 # plt.tight_layout()
#                 #
#                 # plt_fil_name = f"{cur_fund_freq}-{cur_fil_name}_{comb_lst}"
#                 #
#                 # # plt.savefig(os.path.join(dir_path, f"{plt_fil_name}_gradcam.png"), dpi=150)
#                 # # plt.close()
#                 #
#                 # plt.show()
#
#             csv_file.close()
#
#         cur_res_dict[key] = cur_results
#         pkl_path = os.path.join(dir_path, f"{key}_results")
#         pickle.dump([cur_res_dict], open(f"{pkl_path}", "wb"))
#


In [None]:
# """
# testing only ppsp_1up head
# """
# import warnings, pickle
# import os, csv
# warnings.filterwarnings("ignore")
# mode = "block"
# # all_combs_lists = [[0], [3], [0, 1, 2, 3]]
# all_combs_lists = [[3]]
# import torch
# from models import fpn_2
# import matplotlib.pyplot as plt
# from models import ppsp_1up_head
# import fuzzy_logic as fl
# device = 'cpu'
#
#
# with open(f"../conv2d_data/conv2d_psd_scaled_sfnds_1up_{mode}_test.pkl", "rb") as f:
#     loaded_dict_test = pickle.load(f)
#
# for ind, comb_lst in enumerate(all_combs_lists):
#
#     dir_path = f"./conv2d_data/pred_plots/{comb_lst}/"
#     os.makedirs(dir_path, exist_ok=True)
#
#     ## load pretrained ppsp
#     trained_model = fpn_2.PPSP(in_channels=len(comb_lst),out_channels=32)
#     model_name = f"./model_weights/ppsp_weights.pth"
#     trained_model.load_state_dict(torch.load(f"{model_name}", map_location=torch.device('cpu')))
#
#     ### load the ppsp_1up head
#     ppsp_1up = ppsp_1up_head.PPSP_withFundamental(pretrained_ppsp=trained_model, freeze=True, hidden=256)
#     ppsp1up_model_name=f"../best_fpn2_1up_model_{mode}_{comb_lst}.pth"
#     ppsp_1up.load_state_dict((torch.load(f"{ppsp1up_model_name}", map_location=torch.device('cpu'))))
#     ppsp_1up.eval()
#
#     processed_test = process_test_dict(
#         loaded_dict_test,
#         row_indices=comb_lst,
#         col_size=1024,
#         output_format="channels_last"  # or "channels_first"
#     )
#     # device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
#     print(f"Using device: {device}")
#
#     for key, (X_test, y_test, fund_freq, distances, file_names, orig_sig) in processed_test.items():
#         print(f"Key: {key}")
#         prediction_lst = []
#         freq_prediction_lst, fund_freq_lst = [], []
#         cur_res_dict = {}
#         if key == key:
#             cur_results = []
#             csv_path = os.path.join(dir_path, f"{key}_results.csv")
#             csv_file = open(csv_path, mode="w", newline="")
#             writer = csv.writer(csv_file)
#             writer.writerow(["filename", "gtruth_fund_freq", "predicted_fund_freq", "predicted_fund_freq_lst","clusters"])
#             for ind in range(len(distances)):
#                 if ind == ind:
#                     cur_original_sig = orig_sig[ind]
#                     cur_fund_freq = fund_freq[ind]
#                     cur_fil_name = file_names[ind]
#
#                     # cur_x = np.array(X_test[ind])[np.newaxis,:,:,:]
#                     cur_x = np.array(X_test[ind])[:, :, :]
#                     torch_x = torch.FloatTensor(cur_x).to('cpu')
#                     cur_y = np.array(y_test[ind])
#
#                     cur_prediction = ppsp_1up(torch_x)
#
#                     fund_pred = torch.sigmoid(cur_prediction[0]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     harmonic_pred = torch.sigmoid(cur_prediction[1]).squeeze(1).squeeze(0).detach().cpu().numpy()
#
#                     cur_fund_prediction_norm = (fund_pred - np.min(fund_pred)) / (
#                             np.max(fund_pred) - np.min(fund_pred) + 1e-12)
#                     cur_harm_prediction_norm = (harmonic_pred - np.min(harmonic_pred)) / (
#                             np.max(harmonic_pred) - np.min(harmonic_pred) + 1e-12)
#                     binary_cur_truth = cur_y
#
#                     # if cur_fil_name in ["fan_1721415536_7910_35cm_0"]:
#                     # print(cur_fil_name)
#                     ### apply thresholding
#                     bin_harmonic_pred = np.where(cur_harm_prediction_norm <= 0.5, 0, 1)
#                     bin_fund_pred = np.where(cur_fund_prediction_norm > 0.01, 1, 0)
#
#                     ###
#                     fund_regions_lst, fund_central_freq_lst = fl.find_windows(bin_fund_pred)
#                     harm_regions_lst, harm_central_freq_lst = fl.find_windows(bin_harmonic_pred)
#                     ###
#                     clusters = fl.harmonic_clustering(fund_central_freq_lst, harm_central_freq_lst, tolerance=5,
#                                                       max_harmonic=15, max_freq=1024)
#                     # if cur_fil_name in ["fan_1721413988_2795_90cm_2"]:
#                     if len(clusters) > 0:
#
#                         fine_pred_freq_lst = []
#                         ### find the region (start and end of these f0 candidates) for finer speed estimation
#                         for cur_ind in range(len(clusters)):
#                             # cur_pred_freq = clusters[cur_ind]['f0']
#                             # if clusters[cur_ind]['match_count']>1:
#                             try:
#                                 index_of_region = fund_central_freq_lst.index(clusters[cur_ind]['f0'])
#                                 region_indexes = fund_regions_lst[index_of_region]
#                                 strt_ind, end_ind = region_indexes[0], region_indexes[1]
#                                 cur_fine_pred_freq = fl.predict_freq(cur_original_sig, cur_original_sig, strt_ind, end_ind,
#                                                                      int((strt_ind + end_ind) // 2))
#                                 fine_pred_freq_lst.append(cur_fine_pred_freq)
#
#                             except:
#                                 index_of_region = harm_central_freq_lst.index(clusters[cur_ind]['f0'])
#                                 region_indexes = harm_regions_lst[index_of_region]
#                                 strt_ind, end_ind = region_indexes[0], region_indexes[1]
#                                 cur_fine_pred_freq = fl.predict_freq(cur_original_sig, cur_original_sig, strt_ind, end_ind,
#                                                                      int((strt_ind + end_ind) // 2))
#                                 fine_pred_freq_lst.append(cur_fine_pred_freq)
#                             # else:
#                             #     fine_pred_freq_lst.append(0)
#
#
#                     else:
#                         fine_pred_freq_lst = [0]
#
#                     print(f"filename={cur_fil_name} gtruth={cur_fund_freq} predicted={fine_pred_freq_lst}")
#
#                     ### fine tuning
#                     fine_pred_freq_arr = np.array(fine_pred_freq_lst)
#                     diff_arr = np.abs(fine_pred_freq_arr - cur_fund_freq)
#                     min_idx = np.argmin(diff_arr)
#
#                     fine_pred_freq = fine_pred_freq_lst[min_idx]
#                     """
#                     write results to file
#                     """
#                     writer.writerow([cur_fil_name, cur_fund_freq,fine_pred_freq, fine_pred_freq_lst, clusters])
#                     cur_results.append(
#                         [[bin_fund_pred, bin_harmonic_pred], binary_cur_truth, cur_x, cur_fund_freq, fine_pred_freq,
#                          cur_fil_name, clusters,fine_pred_freq_lst,])
#
#                     """
#                     plot figures
#                     """
#                     # plt_fil_name = f"{cur_fund_freq}-{cur_fil_name}_{comb_lst}"
#                     #
#                     # plt.title(f"{cur_fund_freq} - {cur_fil_name}_{comb_lst}")
#                     # for ind_x in range(cur_x.squeeze(0).shape[0]):
#                     #     plt.plot(cur_x.squeeze(0)[ind_x],linewidth=1.2)
#                     #
#                     # plt.plot(binary_cur_truth, linewidth=1.1, label="True",alpha=0.8)
#                     #
#                     #
#                     # plt.plot(cur_fund_prediction_norm, linewidth=0.8, label=" f0",alpha=0.8)
#                     # plt.plot(harmonic_pred, '--', linewidth=0.8, label="raw all_harmonics",alpha=0.8)
#                     #
#                     # plt.legend(loc='lower right')
#                     #
#                     # # plt.savefig(os.path.join(dir_path, f"{plt_fil_name}.png"), dpi=150)
#                     # # plt.close()
#                     #
#                     # plt.show()
#
#             csv_file.close()
#
#         cur_res_dict[key] = cur_results
#         pkl_path = os.path.join(dir_path, f"{key}_results")
#         pickle.dump([cur_res_dict], open(f"{pkl_path}", "wb"))

In [None]:
# """
# testing only ppsp
# """
# import warnings, pickle
# import os
# warnings.filterwarnings("ignore")
# mode = "block"
# # all_combs_lists = [[0], [3], [0, 1, 2, 3]]
# all_combs_lists = [[0],[3]]
# import torch
# from models import fpn_2
# from models import ppsp_1up
# import matplotlib.pyplot as plt
# import torch.nn.functional as F
#
# device = 'cpu'
#
#
# with open(f"../conv2d_data/conv2d_psd_scaled_sfnds_1up_{mode}_test.pkl", "rb") as f:
#     loaded_dict_test = pickle.load(f)
#
# for ind, comb_lst in enumerate(all_combs_lists):
#
#     dir_path = f"./conv2d_data/pred_plots/{comb_lst}/"
#     os.makedirs(dir_path, exist_ok=True)
#
#     ppsp_backbone = fpn_2.PPSP(in_channels=len(comb_lst),out_channels=32)
#     ppsp_1up = ppsp_1up.PPSP_1up(ppsp_backbone=ppsp_backbone, hidden_nodes=256)
#
#     model_name = f"../best_fpn2_1up_model_{mode}_{comb_lst}.pth"
#     ppsp_1up.load_state_dict(torch.load(f"{model_name}", map_location=torch.device('cpu')))
#     ppsp_1up.eval()
#
#     processed_test = process_test_dict(
#         loaded_dict_test,
#         row_indices=comb_lst,
#         col_size=1024,
#         output_format="channels_last"  # or "channels_first"
#     )
#     # device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
#     print(f"Using device: {device}")
#
#     for key, (X_test, y_test, fund_freq, distances, file_names, orig_sig) in processed_test.items():
#         print(f"Key: {key}")
#         prediction_lst = []
#         freq_prediction_lst, fund_freq_lst = [], []
#         if key == key:
#             for ind in range(len(distances)):
#                 if ind == ind:
#                     cur_fund_freq = fund_freq[ind]
#                     cur_fil_name = file_names[ind]
#
#                     # cur_x = np.array(X_test[ind])[np.newaxis,:,:,:]
#                     cur_x = np.array(X_test[ind])[:, :, :]
#                     torch_x = torch.FloatTensor(cur_x).to('cpu')
#                     cur_y = np.array(y_test[ind])
#
#                     torch_x.requires_grad_()
#                     cur_prediction = ppsp_1up(torch_x)
#
#                     fund_pred = torch.sigmoid(cur_prediction[0]).squeeze(1).squeeze(0).detach().cpu().numpy()
#                     harmonic_pred = torch.sigmoid(cur_prediction[1]).squeeze(1).squeeze(0).detach().cpu().numpy()
#
#                     # all_harmonic_pred1 = torch.sigmoid(cur_prediction[1]).squeeze(1).squeeze(0).detach().cpu().numpy()
#
#
#                     cur_fund_prediction_norm = (fund_pred - np.min(fund_pred)) / (np.max(fund_pred) - np.min(fund_pred) + 1e-12)
#                     # cur_harmonic_prediction_norm1 = (all_harmonic_pred1 - np.min(all_harmonic_pred1)) / (np.max(all_harmonic_pred1) - np.min(all_harmonic_pred1) + 1e-12)
#
#
#                     binary_cur_truth=cur_y
#
#
#                     plt_fil_name = f"{cur_fund_freq}-{cur_fil_name}_{comb_lst}"
#
#                     plt.title(f"{cur_fund_freq} - {cur_fil_name}_{comb_lst}")
#                     for ind_x in range(cur_x.squeeze(0).shape[0]):
#                         plt.plot(cur_x.squeeze(0)[ind_x],linewidth=1.2)
#
#                     plt.plot(binary_cur_truth, linewidth=1.1, label="True",alpha=0.8)
#
#
#                     plt.plot(cur_fund_prediction_norm, linewidth=0.8, label=" f0",alpha=0.8)
#                     plt.plot(harmonic_pred, '--', linewidth=0.8, label="raw all_harmonics",alpha=0.8)
#                     # plt.plot(cur_harmonic_prediction_norm1, '--', linewidth=0.7, label="raw all_harmonics",alpha=0.8)
#
#                     plt.legend(loc='lower right')
#
#                     # plt.savefig(os.path.join(dir_path, f"{plt_fil_name}.png"), dpi=150)
#                     # plt.close()
#
#                     plt.show()
