In [None]:
%load_ext autoreload
%autoreload 1
# %matplotlib widget
import matplotlib
import librosa
import numpy as np
import torch
import torch.nn.functional as F

# Select the Qt5Agg Matplotlib backend for the whole notebook (the
# `%matplotlib widget` magic above is commented out).
# NOTE(review): the plotting cells further down only call savefig()/close(), so a
# non-interactive backend such as 'Agg' may be sufficient - confirm before changing.
matplotlib.use('QT5Agg')

In [None]:
def process_test_dict(loaded_dict, row_indices=None, col_size=512, output_format="channels_first"):
    """Select rows/columns of every test set in ``loaded_dict`` and add a singleton axis.

    Sample order is preserved; nothing is sorted, shuffled or split.

    Args:
        loaded_dict: Mapping ``key -> (test_x, test_y, fund_freq_lst, distances_lst,
            file_names_lst, orig_signal_lst)``. ``test_x`` must be convertible to an
            array that can be indexed as ``[sample, row, column]``.
        row_indices: Rows of ``test_x`` to keep, in the given order.
            ``None`` (the default) selects rows ``[0, 4]``.
        col_size: Number of leading columns to keep (``[:col_size]`` slice).
        output_format: ``"channels_first"`` inserts the singleton axis after the row
            axis (``[N, rows, 1, cols]``); ``"channels_last"`` inserts it before the
            row axis (``[N, 1, rows, cols]``).

    Returns:
        dict: ``key -> (X, y, fund_freq_lst, distances_lst, file_names_lst,
        orig_signal_lst)``. Every element except ``file_names_lst`` is a NumPy array;
        ``file_names_lst`` is passed through unchanged.

    Raises:
        ValueError: If ``output_format`` is not one of the two supported values.
            The check runs before any data is touched, so it also fires for an
            empty ``loaded_dict``.
    """
    # Validate once, up front, instead of once per key after the conversions.
    if output_format not in ("channels_first", "channels_last"):
        raise ValueError("output_format must be 'channels_first' or 'channels_last'")

    # A None sentinel avoids sharing one mutable list object between calls.
    if row_indices is None:
        row_indices = [0, 4]

    processed_data = {}

    for key, (test_x, test_y, fund_freq_lst, distances_lst, file_names_lst, orig_signal_lst) in loaded_dict.items():
        test_x = np.array(test_x)
        test_y = np.array(test_y)
        distances_lst = np.array(distances_lst)
        fund_freq_lst = np.array(fund_freq_lst)
        orig_signal_lst = np.array(orig_signal_lst)

        # Select rows and columns size (preserve original order - no sorting)
        X = test_x[:, row_indices, :col_size]
        y = test_y

        # Insert the singleton axis required by the requested layout
        if output_format == "channels_first":
            X = X[:, :, np.newaxis, :]  # [N, num_channels, 1, col_size]
        else:
            X = X[:, np.newaxis, :, :]  # [N, 1, num_channels, col_size]

        # Return all data in original order (no train/val split)
        processed_data[key] = (X, y, fund_freq_lst, distances_lst, file_names_lst, orig_signal_lst)

    return processed_data

In [None]:
import os

# Show the current working directory: the data, weight and output paths used in
# the cells below are all relative ("../..." and "./...").
print(os.getcwd())

In [None]:
### GradCAM class
class GradCAM:
    """1-D Grad-CAM for a model whose output is indexed to pick the explained head.

    A forward hook stores the output of ``target_layer`` and a backward hook stores
    the gradient flowing into that output. The layer output must be 3-D,
    ``[batch_size, channels, length]``.

    Args:
        model: Callable module; ``model(x)`` is indexed with ``target_category``
            (index ``0`` when ``target_category`` is ``None``).
        target_layer: Sub-module of ``model`` whose activations are explained.

    Call :meth:`remove_hooks` when the explainer is no longer needed, otherwise
    the hooks stay attached to ``target_layer``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        ### Register hooks (keep the handles so the hooks can be detached again)
        # NOTE(review): register_backward_hook is deprecated in favour of
        # register_full_backward_hook. Only grad_output is read here; confirm the
        # model applies no in-place op to this layer's output before switching.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self.save_activation),
            self.target_layer.register_backward_hook(self.save_gradient),
        ]

    def remove_hooks(self):
        """Detach the forward/backward hooks from ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached reference to the layer output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def __call__(self, x, target_category=None):
        """Compute the class-activation map for ``x``.

        Args:
            x: Model input.
            target_category: Index into the model output selecting the head to
                explain; ``None`` selects index ``0``.

        Returns:
            tuple: ``(cam, output)`` where ``cam`` is a NumPy array of shape
            ``[batch_size, length]`` scaled per sample into ``[0, 1]`` and
            ``output`` is the raw model output.

        Raises:
            RuntimeError: If the hooks captured no activations or gradients
                during this call.
        """
        # Forget what a previous call captured: if a hook does not fire this time
        # we must fail loudly instead of explaining the previous sample.
        self.gradients = None
        self.activations = None

        ### Forward pass
        output = self.model(x)

        if target_category is None:
            ### Use the fundamental prediction [batch_size, 1, 1024] for Grad-CAM
            target = output[0]
        else:
            target = output[target_category]

        ### Zero gradients
        self.model.zero_grad()

        ### For Grad-CAM, we need a scalar value (mean)
        target_scalar = target.mean()

        ### Backward pass for target (graph retained so the returned output stays usable)
        target_scalar.backward(retain_graph=True)

        if self.gradients is None or self.activations is None:
            raise RuntimeError("Gradients or activations not captured. Check target layer.")

        #### Get gradients and activations
        # [batch_size, channels, length]
        gradients = self.gradients
        activations = self.activations

        ### Global average pooling of gradients across spatial dimension (length)
        weights = torch.mean(gradients, dim=2)  # [batch_size, channels]

        # Channel-weighted sum of the activations -> [batch_size, length]
        cam = (weights.unsqueeze(-1) * activations).sum(dim=1)

        # Apply ReLU
        cam = F.relu(cam)

        # Normalize each sample independently; the epsilon guards an all-zero map
        cam = cam - cam.min(dim=1, keepdim=True)[0]
        cam = cam / (cam.max(dim=1, keepdim=True)[0] + 1e-8)

        return cam.detach().cpu().numpy(), output

In [None]:
# CORRECTED Grad-CAM for FPN_2_mtl model (from crepe.py)
class GradCAM_FPN2MTL:
    """1-D Grad-CAM for a model whose forward pass returns a single tensor.

    A forward hook stores the output of ``target_layer`` and a backward hook stores
    the gradient flowing into that output. The layer output must be 3-D,
    ``[batch_size, channels, length]``.

    Args:
        model: Callable module; the mean of ``model(x)`` is back-propagated.
        target_layer: Sub-module of ``model`` whose activations are explained.

    Call :meth:`remove_hooks` when the explainer is no longer needed, otherwise
    the hooks stay attached to ``target_layer``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Register hooks (keep the handles so the hooks can be detached again)
        # NOTE(review): register_backward_hook is deprecated in favour of
        # register_full_backward_hook. Only grad_output is read here; confirm the
        # model applies no in-place op to this layer's output before switching.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self.save_activation),
            self.target_layer.register_backward_hook(self.save_gradient),
        ]

    def remove_hooks(self):
        """Detach the forward/backward hooks from ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached reference to the layer output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def __call__(self, x, target_category=None):
        """Compute the class-activation map for ``x``.

        Args:
            x: Model input.
            target_category: Accepted for signature compatibility with
                ``GradCAM``; it is not used because the model has one output.

        Returns:
            tuple: ``(cam, output)`` where ``cam`` is a NumPy array of shape
            ``[batch_size, length]`` and ``output`` is the raw model output.
            ``cam`` is min-max scaled into ``[0, 1]`` using the extrema of the
            whole batch, or all zeros when the map is (numerically) constant.

        Raises:
            RuntimeError: If the hooks captured no activations or gradients
                during this call.
        """
        # Forget what a previous call captured: without this, the check below
        # would pass on stale tensors from an earlier sample.
        self.gradients = None
        self.activations = None

        # Forward pass
        output = self.model(x)

        # For FPN_2_mtl, output is already the fundamental prediction
        target = output

        # Zero gradients
        self.model.zero_grad()

        # Reduce the prediction to a scalar (its mean) so it can be back-propagated
        target_scalar = target.mean()

        # Backward pass for target (graph retained so the returned output stays usable)
        target_scalar.backward(retain_graph=True)

        # Get gradients and activations
        if self.gradients is None or self.activations is None:
            raise RuntimeError("Gradients or activations not captured. Check target layer.")

        gradients = self.gradients  # [batch_size, channels, length]
        activations = self.activations  # [batch_size, channels, length]

        # Global average pooling of gradients across spatial dimension
        weights = torch.mean(gradients, dim=2)  # [batch_size, channels]

        # Channel-weighted sum of the activations -> [batch_size, length]
        cam = (weights.unsqueeze(-1) * activations).sum(dim=1)

        # Apply ReLU
        cam = F.relu(cam)

        # Normalize
        # NOTE(review): min/max are taken over the whole batch, so for
        # batch_size > 1 a sample's map depends on the other samples.
        cam_min = cam.min()
        cam_span = cam.max() - cam_min
        if cam_span > 1e-8:
            cam = (cam - cam_min) / cam_span
        else:
            cam = torch.zeros_like(cam)

        return cam.detach().cpu().numpy(), output

In [None]:
# CORRECTED Grad-CAM for PPSP model (from fpn_2.py)
class GradCAM_PPSP:
    """1-D Grad-CAM for a model whose forward pass returns a single tensor.

    A forward hook stores the output of ``target_layer`` and a backward hook stores
    the gradient flowing into that output. The layer output must be 3-D,
    ``[batch_size, channels, length]``.

    Args:
        model: Callable module; the mean of ``model(x)`` is back-propagated.
        target_layer: Sub-module of ``model`` whose activations are explained.

    Call :meth:`remove_hooks` when the explainer is no longer needed, otherwise
    the hooks stay attached to ``target_layer``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Register hooks (keep the handles so the hooks can be detached again)
        # NOTE(review): register_backward_hook is deprecated in favour of
        # register_full_backward_hook. Only grad_output is read here; confirm the
        # model applies no in-place op to this layer's output before switching.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self.save_activation),
            self.target_layer.register_backward_hook(self.save_gradient),
        ]

    def remove_hooks(self):
        """Detach the forward/backward hooks from ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached reference to the layer output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def __call__(self, x, target_category=None):
        """Compute the class-activation map for ``x``.

        Args:
            x: Model input.
            target_category: Accepted for signature compatibility with
                ``GradCAM``; it is not used because the model has one output.

        Returns:
            tuple: ``(cam, output)`` where ``cam`` is a NumPy array of shape
            ``[batch_size, length]`` and ``output`` is the raw model output.
            ``cam`` is min-max scaled into ``[0, 1]`` using the extrema of the
            whole batch, or all zeros when the map is (numerically) constant.

        Raises:
            RuntimeError: If the hooks captured no activations or gradients
                during this call.
        """
        # Forget what a previous call captured: without this, the check below
        # would pass on stale tensors from an earlier sample.
        self.gradients = None
        self.activations = None

        # Forward pass
        output = self.model(x)

        # For PPSP, output is the final output3 [batch_size, 1, 1024]
        target = output

        # Zero gradients
        self.model.zero_grad()

        # Convert to scalar for backprop
        target_scalar = target.mean()

        # Backward pass for target (graph retained so the returned output stays usable)
        target_scalar.backward(retain_graph=True)

        # Get gradients and activations
        if self.gradients is None or self.activations is None:
            raise RuntimeError("Gradients or activations not captured. Check target layer.")

        gradients = self.gradients  # [batch_size, channels, length]
        activations = self.activations  # [batch_size, channels, length]

        # Global average pooling of gradients across spatial dimension
        weights = torch.mean(gradients, dim=2)  # [batch_size, channels]

        # Channel-weighted sum of the activations -> [batch_size, length]
        cam = (weights.unsqueeze(-1) * activations).sum(dim=1)

        # Apply ReLU
        cam = F.relu(cam)

        # Normalize
        # NOTE(review): min/max are taken over the whole batch, so for
        # batch_size > 1 a sample's map depends on the other samples.
        cam_min = cam.min()
        cam_span = cam.max() - cam_min
        if cam_span > 1e-8:
            cam = (cam - cam_min) / cam_span
        else:
            cam = torch.zeros_like(cam)

        return cam.detach().cpu().numpy(), output

In [None]:
"""
testing crepe and ppsp - CORRECTED VERSION
"""
# Runs the FPN_2_mtl model on every test sample, computes a Grad-CAM map for the
# chosen layer and saves one two-panel PNG per sample (signals + saliency).
import warnings
import pickle
import os
import csv

# Silence warnings before the remaining imports so import-time warnings are covered too.
warnings.filterwarnings("ignore")

import torch
from models import fpn_2
from models import crepe
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import fuzzy_logic as fl

mode = "block"
all_combs_lists = [[3]]
device = 'cpu'
# Normalised predictions above this value are drawn as 1 in the "Binary f0" trace.
bin_threshold = 0.01

# NOTE: pickle.load can execute arbitrary code - only open files you trust.
with open(f"../conv2d_data/conv2d_psd_scaled_sfnds_1up_{mode}_test.pkl", "rb") as f:
    loaded_dict_test = pickle.load(f)

for comb_lst in all_combs_lists:

    dir_path = f"./conv2d_data/pred_plots/{comb_lst}/"
    os.makedirs(dir_path, exist_ok=True)

    # Initialize FPN_2_mtl model
    crepe_model = crepe.FPN_2_mtl(in_channels=len(comb_lst))

    # NOTE(review): the weight file name hard-codes "[3]" and does not follow
    # comb_lst - confirm before adding more channel combinations.
    model_name = "../1up_weights/best_fpn2_1up_model_gaussian_[3].pth"
    crepe_model.load_state_dict(torch.load(model_name, map_location=device))
    crepe_model.to(device)
    crepe_model.eval()

    # Layer explained by Grad-CAM: the `.c` sub-module of conv_block10
    # (presumably the last convolution of the encoder - verify in crepe.py).
    target_layer_crepe = crepe_model.conv_block10.c
    grad_cam_crepe = GradCAM_FPN2MTL(crepe_model, target_layer_crepe)

    processed_test = process_test_dict(
        loaded_dict_test,
        row_indices=comb_lst,
        col_size=1024,
        output_format="channels_last"
    )
    print(f"Using device: {device}")

    for key, (X_test, y_test, fund_freq, distances, file_names, orig_sig) in processed_test.items():
        print(f"Key: {key}")

        # Every sample of the test set is processed.
        for sample_ind in range(len(distances)):
            cur_fund_freq = fund_freq[sample_ind]
            cur_fil_name = file_names[sample_ind]

            # One sample, used as a batch of size 1 (leading axis of length 1).
            cur_x = np.array(X_test[sample_ind])
            cur_y = np.array(y_test[sample_ind])
            torch_x = torch.tensor(cur_x, dtype=torch.float32, device=device)

            ### Enable gradients for input
            torch_x.requires_grad_()

            ### Single forward/backward pass: Grad-CAM returns the saliency map and
            ### the raw model output, so the plotted prediction is the explained one.
            with torch.enable_grad():
                cam_batch, cur_prediction = grad_cam_crepe(torch_x)

            fund_pred = torch.sigmoid(cur_prediction).squeeze(1).squeeze(0).detach().cpu().numpy()

            cur_fund_prediction_norm = (fund_pred - np.min(fund_pred)) / (
                    np.max(fund_pred) - np.min(fund_pred) + 1e-12)
            binary_cur_truth = cur_y

            ### apply thresholding
            bin_fund_pred = np.where(cur_fund_prediction_norm > bin_threshold, 1, 0)

            ### CAM of the only sample in the batch; stays 1-D even if its length is 1
            cam_crepe = cam_batch[0]

            # The CAM might be shorter than the input due to downsampling, so
            # interpolate it back onto the input grid.
            n_points = cur_x.shape[-1]
            cam_resized = np.interp(np.linspace(0, len(cam_crepe) - 1, n_points),
                                    np.arange(len(cam_crepe)), cam_crepe)

            # Plot with Grad-CAM
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

            # Plot 1: Input signals with predictions (only the first input gets a legend entry)
            input_rows = cur_x.squeeze(0)
            for ind_x in range(input_rows.shape[0]):
                ax1.plot(input_rows[ind_x], linewidth=1.2, alpha=0.7, label=f'Input {ind_x}' if ind_x == 0 else "")

            ax1.plot(binary_cur_truth, linewidth=1.5, label="Ground Truth", alpha=0.9, color='black')
            ax1.plot(cur_fund_prediction_norm, linewidth=1.5, label="f0 Prediction", alpha=0.8, color='blue')
            ax1.plot(bin_fund_pred, linewidth=1.2, label="Binary f0", alpha=0.6, color='green')
            ax1.legend(loc='upper right')
            ax1.set_title(f"FPN_2_mtl - {cur_fund_freq}Hz - {cur_fil_name}")
            ax1.set_ylabel('Amplitude')
            ax1.grid(True, alpha=0.3)

            # Plot 2: Grad-CAM saliency
            ax2.plot(cam_resized, color='red', linewidth=2, label='Saliency')
            ax2.fill_between(range(len(cam_resized)), cam_resized, alpha=0.3, color='red')
            ax2.set_xlabel('Time steps')
            ax2.set_ylabel('Saliency')
            ax2.set_ylim(0, 1)
            ax2.set_title('Grad-CAM Saliency Map (Which features the model used for prediction)')
            ax2.legend()
            ax2.grid(True, alpha=0.3)

            plt.tight_layout()

            plt_fil_name = f"{cur_fund_freq}-{cur_fil_name}_{comb_lst}_crepe"
            fig.savefig(os.path.join(dir_path, f"{plt_fil_name}_gradcam_crepe.png"), dpi=150)
            # Close this figure explicitly so figures do not accumulate across samples.
            plt.close(fig)


In [None]:
"""
testing only ppsp - CORRECTED VERSION
"""
# Runs the PPSP model on every test sample, computes a Grad-CAM map for the
# chosen layer and saves one three-panel PNG per sample.
import warnings
import pickle
import os

# Silence warnings before the remaining imports so import-time warnings are covered too.
warnings.filterwarnings("ignore")

import torch
from models import fpn_2
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np

mode = "block"
all_combs_lists = [[3]]
device = 'cpu'
# Normalised predictions above this value count as 1 in the binary f0 trace.
bin_threshold = 0.01

# NOTE: pickle.load can execute arbitrary code - only open files you trust.
with open(f"../conv2d_data/conv2d_psd_scaled_sfnds_1up_{mode}_test.pkl", "rb") as f:
    loaded_dict_test = pickle.load(f)

for comb_lst in all_combs_lists:

    dir_path = f"./conv2d_data/pred_plots/{comb_lst}/"
    os.makedirs(dir_path, exist_ok=True)

    # Initialize PPSP model (not ppsp_1up)
    ppsp_model = fpn_2.PPSP(in_channels=len(comb_lst), out_channels=32)

    model_name = "../1ppsp_weights/best_model_weights_fan5_fan3_bldc_fpn2"  # Update with your actual PPSP weights path
    ppsp_model.load_state_dict(torch.load(model_name, map_location=device))
    ppsp_model.to(device)
    ppsp_model.eval()

    ### Choose the layer explained by Grad-CAM
    # Option 1: Late decoder layer (good for final predictions)
    target_layer_ppsp = ppsp_model.conv_output1.c

    # Option 2: Bottleneck layer (good for understanding encoder features)
    # target_layer_ppsp = ppsp_model.conv_block10.c

    # Option 3: Early decoder layer (good for understanding feature fusion)
    # target_layer_ppsp = ppsp_model.conv_concat.c

    grad_cam_ppsp = GradCAM_PPSP(ppsp_model, target_layer_ppsp)

    processed_test = process_test_dict(
        loaded_dict_test,
        row_indices=comb_lst,
        col_size=1024,
        output_format="channels_last"
    )
    print(f"Using device: {device}")

    for key, (X_test, y_test, fund_freq, distances, file_names, orig_sig) in processed_test.items():
        print(f"Key: {key}")

        # Every sample of the test set is processed.
        for sample_ind in range(len(distances)):
            cur_fund_freq = fund_freq[sample_ind]
            cur_fil_name = file_names[sample_ind]

            # One sample, used as a batch of size 1 (leading axis of length 1).
            cur_x = np.array(X_test[sample_ind])
            cur_y = np.array(y_test[sample_ind])
            torch_x = torch.tensor(cur_x, dtype=torch.float32, device=device)

            torch_x.requires_grad_()

            ### Single forward/backward pass: Grad-CAM returns the saliency map and
            ### the raw model output, so the plotted prediction is the explained one.
            with torch.enable_grad():
                cam_batch, cur_prediction = grad_cam_ppsp(torch_x)

            fund_pred = torch.sigmoid(cur_prediction).squeeze(1).squeeze(0).detach().cpu().numpy()

            cur_fund_prediction_norm = (fund_pred - np.min(fund_pred)) / (np.max(fund_pred) - np.min(fund_pred) + 1e-12)
            binary_cur_truth = cur_y
            bin_fund_pred = np.where(cur_fund_prediction_norm > bin_threshold, 1, 0)

            ### CAM of the only sample in the batch; stays 1-D even if its length is 1
            cam_ppsp = cam_batch[0]

            # Interpolate CAM to match the input length
            n_points = cur_x.shape[-1]
            cam_resized = np.interp(np.linspace(0, len(cam_ppsp) - 1, n_points),
                                    np.arange(len(cam_ppsp)), cam_ppsp)

            # Create comprehensive visualization
            fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 10))

            # Plot 1: Input signals (only the first channel gets a legend entry)
            input_rows = cur_x.squeeze(0)
            for ind_x in range(input_rows.shape[0]):
                ax1.plot(input_rows[ind_x], linewidth=1.2, alpha=0.7,
                         label=f'Channel {ind_x}' if ind_x == 0 else "")

            ax1.plot(binary_cur_truth, linewidth=2, label="Ground Truth", alpha=0.9, color='black')
            ax1.legend(loc='upper right')
            ax1.set_title(f"PPSP - Input Signals - {cur_fund_freq}Hz - {cur_fil_name}")
            ax1.set_ylabel('Amplitude')
            ax1.grid(True, alpha=0.3)

            # Plot 2: Predictions with saliency overlay
            ax2.plot(cur_fund_prediction_norm, linewidth=2, label="f0 Prediction", alpha=0.8, color='blue')
            ax2.plot(binary_cur_truth, linewidth=1.5, label="Ground Truth", alpha=0.6, color='black', linestyle='--')

            # Add saliency as background color: > 0.7 is "high", (0.3, 0.7] is "medium"
            x_axis = np.arange(len(cam_resized))
            ax2.fill_between(x_axis, 0, 1, where=cam_resized > 0.7,
                             alpha=0.3, color='red', label='High Saliency')
            ax2.fill_between(x_axis, 0, 1, where=(cam_resized > 0.3) & (cam_resized <= 0.7),
                             alpha=0.2, color='orange', label='Medium Saliency')

            ax2.legend(loc='upper right')
            ax2.set_ylabel('Normalized Output')
            ax2.set_ylim(0, 1)
            ax2.set_title('Model Predictions with Saliency Overlay')
            ax2.grid(True, alpha=0.3)

            # Plot 3: Saliency map
            ax3.plot(cam_resized, color='red', linewidth=2, label='Saliency')
            ax3.fill_between(range(len(cam_resized)), cam_resized, alpha=0.3, color='red')
            ax3.set_xlabel('Time steps')
            ax3.set_ylabel('Saliency')
            ax3.set_ylim(0, 1)
            ax3.set_title('Grad-CAM Saliency Map (Decoder Features Used for Prediction)')
            ax3.legend()
            ax3.grid(True, alpha=0.3)

            plt.tight_layout()

            plt_fil_name = f"{cur_fund_freq}-{cur_fil_name}_{comb_lst}_ppsp"
            fig.savefig(os.path.join(dir_path, f"{plt_fil_name}_gradcam_ppsp.png"), dpi=150)
            # Close this figure explicitly so figures do not accumulate across samples.
            plt.close(fig)


