In [None]:
import cv2
import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients
from scipy.ndimage import gaussian_filter
from torchvision import models

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ResNet-18 classifier adapted to single-channel (spectrogram) input.
class ResNetClassifier(nn.Module):
    """ResNet-18 backbone with a 1-channel stem and a ``num_classes``-way head.

    The backbone is built with ``weights=None`` (random initialisation); trained
    weights are loaded afterwards with ``load_state_dict``.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(weights=None)
        # Swap the 3-channel stem for a 1-channel conv with the same geometry.
        backbone.conv1 = nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        # Replace the classification head with one sized for our classes.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.resnet = backbone

    def forward(self, x):
        return self.resnet(x)

# Build the model and load the trained weights.
model = ResNetClassifier(num_classes=2)
# map_location lets a checkpoint saved on another device be read here.
model.load_state_dict(torch.load("./resnet18_cla.pth", map_location=device, weights_only=True))
# load_state_dict copies values into the module's existing parameters, which were
# created on the CPU; move the module explicitly because later cells send every
# batch to `device`.
model = model.to(device)
model.eval()


# Audio -> model-input preprocessing; must stay identical to the training-time transform.
def process_audio(audio_path, n_fft=2048, hop_length=512, n_mels=128, target_size=(224, 224)):
    """Convert an audio file into a ``[1, 1, height, width]`` float32 log-mel tensor.

    Args:
        audio_path: file readable by ``librosa.load``; loaded at its native sample rate.
        n_fft: FFT window size for the mel spectrogram.
        hop_length: hop (in samples) between STFT frames.
        n_mels: number of mel bands.
        target_size: ``(width, height)`` handed to ``cv2.resize`` (cv2 takes width first).

    Returns:
        torch.Tensor of shape ``[1, 1, target_size[1], target_size[0]]``, dtype float32, on the CPU.
    """
    audio, sr = librosa.load(audio_path, sr=None)
    mel_spec = librosa.feature.melspectrogram(
        y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
    )
    # Decibels relative to the loudest bin (so the maximum is 0 dB).
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    mel_spec_db_resized = cv2.resize(mel_spec_db, target_size, interpolation=cv2.INTER_CUBIC)
    # Add batch and channel axes: [H, W] -> [1, 1, H, W].
    spectrogram_tensor = torch.tensor(
        mel_spec_db_resized[np.newaxis, np.newaxis, :, :], dtype=torch.float32
    )
    return spectrogram_tensor


# Visualisation

In [None]:
import cv2
import numpy as np
import torch
import librosa
import matplotlib.pyplot as plt
from captum.attr import IntegratedGradients, LayerGradCam, Saliency, Occlusion,Lime
from matplotlib.colors import LinearSegmentedColormap
from scipy.ndimage import zoom

# Custom colour map: blue -> white -> red, with the white midpoint fully transparent.
def create_red_blue_transparent_cmap():
    """Return a diverging ``LinearSegmentedColormap`` named 'RedBlueTransparent'.

    The low end is opaque blue, the high end opaque red, and the midpoint (0.5)
    is white with alpha 0, so an overlay leaves the background visible wherever
    its value sits in the middle of the colour range.
    """
    def _ramp(start, middle, end):
        # Anchors are (x, value_left, value_right); no channel jumps, so left == right.
        return [(0.0, start, start), (0.5, middle, middle), (1.0, end, end)]

    segment_data = {
        'red': _ramp(0.0, 1.0, 1.0),
        'green': _ramp(0.0, 1.0, 0.0),
        'blue': _ramp(1.0, 1.0, 0.0),
        'alpha': _ramp(1.0, 0.0, 1.0),
    }
    return LinearSegmentedColormap('RedBlueTransparent', segmentdata=segment_data)





# Integrated Gradients
def generate_ig_heatmap(model, input_tensor, baseline=None, target_class=None):
    """Integrated Gradients attribution map for a single input.

    Args:
        model: classifier to explain.
        input_tensor: model input; callers in this notebook pass a batch of one.
        baseline: reference input; defaults to zeros with the input's shape/device.
        target_class: output index to attribute.

    Returns:
        numpy array of attributions with singleton axes squeezed, on the CPU.
    """
    if baseline is None:
        baseline = torch.zeros_like(input_tensor)
    # Attribute a detached copy: flagging the caller's tensor with requires_grad
    # makes any later `.numpy()` on it (e.g. when masking) raise.
    inputs = input_tensor.detach().clone().requires_grad_(True)

    ig = IntegratedGradients(model)
    attributions = ig.attribute(inputs, baselines=baseline, target=target_class)

    # .cpu() is required before .numpy() when the model runs on a GPU.
    return attributions.squeeze().detach().cpu().numpy()



# LIME
def generate_lime_heatmap(
    model, spectrogram, target_class, grid_size=(28, 28), n_samples=500, top_percent=10
):
    """LIME attribution map over a grid of spectrogram cells, keeping only extremes.

    Args:
        model: classifier to explain.
        spectrogram: model input tensor; its squeezed form is treated as 2-D
            (a [1, 1, H, W] input is assumed).
        target_class: output index to explain.
        grid_size: (row bands, column bands) of the interpretable feature grid.
        n_samples: number of LIME perturbations.
        top_percent: percentage of the highest and of the lowest attributions to
            keep; all other entries are set to 0 before normalising.

    Returns:
        numpy array min-max normalised to [0, 1] (a constant map divides 0 by 0).
    """
    lime = Lime(model)

    # One interpretable feature per grid cell; LIME switches whole cells on/off.
    feature_mask = create_feature_mask(spectrogram, grid_size=grid_size)
    # Build the mask on the input's device so captum does not mix CPU and GPU tensors.
    feature_mask_tensor = torch.as_tensor(
        feature_mask, dtype=torch.long, device=spectrogram.device
    ).unsqueeze(0)

    attributions = lime.attribute(
        inputs=spectrogram,
        target=target_class,
        n_samples=n_samples,
        feature_mask=feature_mask_tensor,
    )

    attributions_np = attributions.squeeze().detach().cpu().numpy()

    # Keep only the strongest positive and strongest negative attributions.
    pos_mask = attributions_np >= np.percentile(attributions_np, 100 - top_percent)
    neg_mask = attributions_np <= np.percentile(attributions_np, top_percent)
    highlighted = np.zeros_like(attributions_np)
    highlighted[pos_mask] = attributions_np[pos_mask]
    highlighted[neg_mask] = attributions_np[neg_mask]

    # Min-max normalisation to [0, 1].
    highlighted_normalized = (highlighted - highlighted.min()) / (
        highlighted.max() - highlighted.min()
    )
    return highlighted_normalized


# Grad-CAM
def generate_gradcam_heatmap(model, input_tensor, target_class=None, layer_name=None, percentile=90):
    """Grad-CAM map at ``layer_name``, resized to the input's spatial size.

    Positive values are scaled to [0, 1]. Negative values keep only those whose
    magnitude reaches the ``percentile``-th percentile of the negative magnitudes,
    scaled to [-1, 0]. The two parts are summed and resized with ``cv2.resize``.

    Args:
        model: classifier to explain.
        input_tensor: model input (batch of one assumed, as the map is squeezed).
        target_class: output index to explain.
        layer_name: the layer module handed to captum's ``LayerGradCam``.
        percentile: cutoff percentile applied to the negative magnitudes.
    """
    grad_cam = LayerGradCam(model, layer_name)
    raw = grad_cam.attribute(input_tensor, target=target_class).squeeze().cpu().detach().numpy()

    pos = np.maximum(raw, 0)
    peak = pos.max()
    if peak > 0:
        pos /= peak

    neg = np.minimum(raw, 0)
    strictly_negative = neg < 0
    if not np.any(strictly_negative):
        neg[:] = 0
    else:
        cutoff = np.percentile(np.abs(neg[strictly_negative]), percentile)
        neg[neg > -cutoff] = 0
        neg /= np.abs(neg).max()

    height, width = input_tensor.shape[-2], input_tensor.shape[-1]
    return cv2.resize(pos + neg, (width, height))

# CAM
def generate_cam_heatmap(model, input_tensor, target_class=None, percentile=90):
    """Class Activation Map from ``model.resnet.layer4`` weighted by the fc row.

    The positive part keeps cells at or above the ``percentile``-th percentile of
    the positive values and is scaled to [0, 1]. The negative part keeps
    magnitudes at or above the ``percentile``-th percentile and is scaled to
    [-1, 0], so both halves share one scale, as in the Grad-CAM map. The sum is
    resized to the input's spatial size.

    Args:
        model: a ``ResNetClassifier``-style model exposing ``resnet.layer4`` and ``resnet.fc``.
        input_tensor: model input (batch of one assumed, as features are squeezed).
        target_class: fc row to use; ``None`` uses the model's predicted class.
        percentile: cutoff percentile for both the positive and negative parts.
    """
    features = []

    def forward_hook(module, inputs, output):
        features.append(output)

    hook = model.resnet.layer4.register_forward_hook(forward_hook)
    try:
        with torch.no_grad():
            logits = model(input_tensor)
    finally:
        # Always detach the hook so a failed forward cannot leave it registered.
        hook.remove()

    if target_class is None:
        target_class = int(torch.argmax(logits, dim=1)[0])

    # [channels, h, w] once the batch axis of size one is squeezed away.
    feature_map = features[0].squeeze().cpu().numpy()
    weights = model.resnet.fc.weight[target_class].detach().cpu().numpy()
    # cam[y, x] = sum_c weights[c] * feature_map[c, y, x]
    cam = np.tensordot(weights, feature_map, axes=1).astype(np.float32)

    positive_cam = np.maximum(cam, 0)
    negative_cam = np.minimum(cam, 0)

    if positive_cam.max() > 0:
        threshold_pos = np.percentile(positive_cam[positive_cam > 0], percentile)
        positive_cam[positive_cam < threshold_pos] = 0
        # The maximum survives the threshold, so this division is safe.
        positive_cam /= positive_cam.max()

    if np.any(negative_cam < 0):
        threshold_neg = np.percentile(np.abs(negative_cam[negative_cam < 0]), percentile)
        negative_cam[negative_cam > -threshold_neg] = 0
        negative_cam /= np.abs(negative_cam).max()
    else:
        negative_cam[:] = 0

    cam_filtered = positive_cam + negative_cam

    return cv2.resize(cam_filtered, (input_tensor.shape[-1], input_tensor.shape[-2]))



# Occlusion

def generate_occlusion_heatmap(model, input_tensor, target_class=None, strides=(1, 8, 8), window_shapes=(1, 8, 8), top_percent=10):
    """Occlusion-sensitivity map that keeps only its most extreme values.

    Args:
        model: classifier to explain.
        input_tensor: model input; the squeezed result is used as a 2-D map, so a
            [1, 1, H, W] input is assumed.
        target_class: output index whose score change is attributed.
        strides: step of the occluding patch (captum excludes the batch axis).
        window_shapes: size of the occluding patch (captum excludes the batch axis).
        top_percent: percentage of the highest and of the lowest attributions to
            keep; all other entries are set to 0 before normalising.

    Returns:
        numpy array min-max normalised to [0, 1] (a constant map divides 0 by 0).
    """
    occlusion = Occlusion(model)
    attributions = occlusion.attribute(
        input_tensor,
        strides=strides,
        sliding_window_shapes=window_shapes,
        target=target_class,
    )

    # .cpu() is required before .numpy() when the model runs on a GPU.
    attributions_np = attributions.squeeze().detach().cpu().numpy()

    # Keep only the strongest positive and strongest negative attributions.
    pos_mask = attributions_np >= np.percentile(attributions_np, 100 - top_percent)
    neg_mask = attributions_np <= np.percentile(attributions_np, top_percent)
    highlighted = np.zeros_like(attributions_np)
    highlighted[pos_mask] = attributions_np[pos_mask]
    highlighted[neg_mask] = attributions_np[neg_mask]

    highlighted_normalized = (highlighted - highlighted.min()) / (highlighted.max() - highlighted.min())
    return highlighted_normalized


# Grid-cell feature mask (used by the LIME explainer)
def create_feature_mask(spectrogram, grid_size=(28, 28)):
    """Label every element of the squeezed spectrogram with its grid-cell index.

    The first two axes of the squeezed array are cut into ``grid_size[0]`` row
    bands and ``grid_size[1]`` column bands at integer-truncated
    ``np.linspace`` edges; cell (i, j) is labelled ``i * grid_size[1] + j``.
    Any further axes share the label of their (row, column) position.

    Args:
        spectrogram: torch tensor, e.g. [1, 1, H, W]; singleton axes are squeezed.
        grid_size: (number of row bands, number of column bands).

    Returns:
        int numpy array with the squeezed spectrogram's shape.
    """
    n_row_bands, n_col_bands = grid_size
    spec = spectrogram.detach().cpu().squeeze().numpy()
    height, width = spec.shape[0], spec.shape[1]

    row_edges = np.linspace(0, height, n_row_bands + 1, dtype=int)
    col_edges = np.linspace(0, width, n_col_bands + 1, dtype=int)

    # Band of each position = last edge <= position. side='right' resolves repeated
    # edges (empty bands) the same way as slicing [edge[k]:edge[k + 1]].
    row_band = np.searchsorted(row_edges, np.arange(height), side='right') - 1
    col_band = np.searchsorted(col_edges, np.arange(width), side='right') - 1

    labels = row_band[:, None] * n_col_bands + col_band[None, :]
    # Repeat the 2-D labels across any trailing axes of the squeezed input.
    labels = labels.reshape(labels.shape + (1,) * (spec.ndim - 2))
    return np.broadcast_to(labels, spec.shape).astype(int)



def visualize_all_with_lime(original_spec, ig_attributions, gc_attributions, cam_attributions, occlusion_map, lime_map, sr, hop_length):
    """Plot the spectrogram beside five attribution overlays in a single row.

    Args:
        original_spec: 2-D spectrogram (rows = frequency bins, columns = frames)
            drawn in grey as the background of every panel.
        ig_attributions, gc_attributions, cam_attributions, occlusion_map, lime_map:
            2-D attribution maps; each is resampled to ``original_spec``'s shape
            with its own zoom factor, so maps of different resolution all align.
        sr: sample rate in Hz (x/y axis extents only).
        hop_length: STFT hop in samples (x axis extent only).
    """
    cmap = create_red_blue_transparent_cmap()

    def _fit_to_background(attr):
        # Bilinear resample of one map onto the background's grid.
        factor = (original_spec.shape[0] / attr.shape[0],
                  original_spec.shape[1] / attr.shape[1])
        return zoom(attr, factor, order=1)

    time_axis = np.linspace(0, original_spec.shape[1] * hop_length / sr, original_spec.shape[1])
    # NOTE: labelled linearly from 0 to Nyquist although mel bins are not linear in Hz.
    freq_axis = np.linspace(0, sr / 2, original_spec.shape[0])
    extent = [time_axis.min(), time_axis.max(), freq_axis.min(), freq_axis.max()]

    overlays = [
        ("Integrated Gradients", ig_attributions),
        ("Grad-CAM", gc_attributions),
        ("Class Activation Map (CAM)", cam_attributions),
        ("Occlusion", occlusion_map),
        ("LIME", lime_map),
    ]

    fig, axes = plt.subplots(1, len(overlays) + 1, figsize=(30, 6))

    # Original Mel-spectrogram
    axes[0].imshow(original_spec, cmap='gray', aspect='auto', origin='lower', extent=extent)
    axes[0].set_title("Original Mel-Spectrogram")
    axes[0].set_xlabel("Time (s)")
    axes[0].set_ylabel("Frequency (Hz)")

    for ax, (title, attr) in zip(axes[1:], overlays):
        ax.imshow(original_spec, cmap='gray', aspect='auto', origin='lower', extent=extent)
        ax.imshow(_fit_to_background(attr), cmap=cmap, aspect='auto', origin='lower',
                  extent=extent, alpha=0.6)
        ax.set_title(title)
        ax.set_xlabel("Time (s)")

    plt.tight_layout()
    plt.show()




audio_path = "./FakeMusicCaps/MusicCaps/-0Gj8-vB1q4.wav"
HOP_LENGTH = 512  # STFT hop in samples; must match the hop used by process_audio

# process_audio returns only the resized model input, so the un-resized log-mel
# background and the sample rate used for the plot axes are computed here with
# the same mel parameters.
audio, sr = librosa.load(audio_path, sr=None)
original_spec = librosa.power_to_db(
    librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=2048, hop_length=HOP_LENGTH, n_mels=128),
    ref=np.max,
)

# Put the input on whatever device the model's parameters live on (CPU or GPU).
model_device = next(model.parameters()).device
input_tensor = process_audio(audio_path).to(model_device)

with torch.no_grad():
    output = model(input_tensor)
predicted_class = torch.argmax(output, dim=1).item()

ig_attributions = generate_ig_heatmap(model, input_tensor, target_class=predicted_class)
gc_attributions = generate_gradcam_heatmap(model, input_tensor, target_class=predicted_class, layer_name=model.resnet.layer4)
cam_attributions = generate_cam_heatmap(model, input_tensor, target_class=predicted_class)
occlusion_map = generate_occlusion_heatmap(
    model, input_tensor, target_class=predicted_class,
    strides=(1, 8, 8), window_shapes=(1, 8, 8), top_percent=10
)

lime_map = generate_lime_heatmap(
    model, input_tensor, target_class=predicted_class, grid_size=(32, 32), n_samples=500, top_percent=10
)

visualize_all_with_lime(original_spec, ig_attributions, gc_attributions, cam_attributions, occlusion_map, lime_map, sr=sr, hop_length=HOP_LENGTH)


# Ablation

## Single-method ablation: mask the top-attributed regions and re-evaluate

In [None]:
#  Baseline: evaluate on the un-masked validation set
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, recall_score
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt




val_csv = 'val_dataset_paths.csv'
val_df = pd.read_csv(val_csv)
# 'File Path' entries are passed to np.load by AudioDataset; 'Label' holds the targets.
val_paths, val_labels = val_df['File Path'].values, val_df['Label'].values


class AudioDataset(Dataset):
    """Dataset of pre-computed spectrograms stored as ``.npy`` files.

    Args:
        file_paths: sequence of paths readable by ``np.load``.
        labels: sequence of labels aligned with ``file_paths``.
    """

    def __init__(self, file_paths, labels):
        self.file_paths = file_paths
        self.labels = labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(spectrogram, label)``; the spectrogram gets a leading channel axis."""
        mel_spec_db = np.load(self.file_paths[idx])
        # Force float32 to match the model's weights: an array saved as float64 would
        # otherwise become a Double tensor that the conv layers reject.
        mel_spec_db = torch.as_tensor(mel_spec_db, dtype=torch.float32).unsqueeze(0)  # add channel axis
        label = self.labels[idx]
        return mel_spec_db, label

# Validation dataset and loader (iterates in file order).
val_dataset = AudioDataset(val_paths, val_labels)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=32)


def evaluate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` and plot the confusion matrix.

    Args:
        model: classifier returning logits of shape [batch, 2] (the metrics and
            plot assume a binary task with labels 0/1).
        val_loader: yields ``(inputs, integer labels)`` batches.
        criterion: loss applied to ``(logits, labels)``; assumed to average over the
            batch (the ``CrossEntropyLoss`` default).

    Returns:
        (accuracy, f1, sensitivity, confusion_matrix). F1 and sensitivity are for
        class 1 (sklearn's binary default).

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()  # make BatchNorm/dropout deterministic regardless of prior state
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    running_loss = 0.0

    with torch.no_grad():
        for mel_spec, labels in tqdm(val_loader, desc="Evaluating"):
            mel_spec = mel_spec.to(device)
            labels = labels.to(device)
            outputs = model(mel_spec)

            loss = criterion(outputs, labels)
            # Weight each batch mean by its size so a short last batch is not over-weighted.
            running_loss += loss.item() * labels.size(0)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("The validation loader yielded no samples.")

    accuracy = correct / total
    f1 = f1_score(all_labels, all_preds)
    sensitivity = recall_score(all_labels, all_preds)
    avg_loss = running_loss / total

    # Fix the label set so the matrix is always 2x2, matching the two tick labels.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    print(f'Validation Accuracy: {accuracy:.4f}')
    print(f'Validation F1 Score: {f1:.4f}')
    print(f'Validation Sensitivity (Recall): {sensitivity:.4f}')
    print(f'Validation Loss: {avg_loss:.4f}')

    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=['Class 0', 'Class 1'], yticklabels=['Class 0', 'Class 1'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

    return accuracy, f1, sensitivity, cm


# Cross-entropy over the class logits; score the un-masked baseline.
criterion = torch.nn.CrossEntropyLoss()
accuracy, f1, sensitivity, cm = evaluate(model, val_loader, criterion=criterion)


In [None]:
# ig
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import confusion_matrix, f1_score, recall_score
import seaborn as sns
import matplotlib.pyplot as plt
import random


def mask_positive_contributions(input_tensor, ig_attributions, mask_threshold=90):
    """Zero the input where its attribution is among the strongest positive values.

    Args:
        input_tensor: model input; the attribution map is broadcast against it
            (a [1, 1, H, W] input with an [H, W] map in this notebook).
        ig_attributions: numpy attribution map.
        mask_threshold: percentile (0-100) of the clipped-at-zero map used as the cutoff.

    Returns:
        (masked float32 CPU tensor, percentage of map entries that were masked).
    """
    positive_contributions = np.maximum(ig_attributions, 0)  # keep positive contributions only
    threshold = np.percentile(positive_contributions, mask_threshold)
    # Also require a strictly positive value: when most entries are non-positive the
    # percentile is 0, and ">= threshold" alone would mask the entire input.
    mask = ((positive_contributions >= threshold) & (positive_contributions > 0)).astype(float)

    masked_percentage = mask.sum() / mask.size * 100

    # detach(): the input may carry requires_grad, which .numpy() refuses.
    masked_tensor = input_tensor.detach().cpu().numpy() * (1 - mask)
    return torch.tensor(masked_tensor, dtype=torch.float32), masked_percentage


def compute_and_mask_ig_features(model, val_loader, mask_threshold=90, max_samples=10):
    """IG-mask a random subset of the validation set, one sample at a time.

    Each sample's IG target is the model's own prediction for it.

    Args:
        model: classifier to explain.
        val_loader: loader whose ``.dataset`` is sampled (its batching is not used).
        mask_threshold: percentile cutoff forwarded to ``mask_positive_contributions``.
        max_samples: number of random samples; clamped to the dataset size.

    Returns:
        (stacked masked inputs on the CPU, tensor of their labels).

    Raises:
        ValueError: if no sample was selected.
    """
    dataset = val_loader.dataset
    all_indices = list(range(len(dataset)))
    # random.sample raises if asked for more items than exist.
    random_indices = random.sample(all_indices, min(max_samples, len(all_indices)))
    subset_loader = DataLoader(Subset(dataset, random_indices), batch_size=1, shuffle=False)

    masked_features = []
    masked_labels = []

    with torch.no_grad():
        for mel_spec, labels in tqdm(subset_loader, desc="Computing IG and masking features"):
            input_tensor = mel_spec.to(device)  # [1, channels, height, width]
            output = model(input_tensor)
            target_class = torch.argmax(output, dim=1).item()

            ig_attributions = generate_ig_heatmap(model, input_tensor, target_class=target_class)

            # detach(): an attribution helper may have set requires_grad on the input,
            # which the masker's numpy conversion rejects.
            masked_tensor, _ = mask_positive_contributions(
                input_tensor.detach(), ig_attributions, mask_threshold
            )
            masked_features.append(masked_tensor.squeeze(0))
            masked_labels.append(labels[0].item())

    if not masked_features:
        raise ValueError("No samples were selected for IG masking.")

    return torch.stack(masked_features), torch.tensor(masked_labels)


def evaluate_with_masked_features(model, masked_features, masked_labels, criterion):
    """Score ``model`` on pre-masked inputs and plot their confusion matrix.

    Args:
        model: classifier returning [batch, 2] logits (binary task assumed).
        masked_features: tensor of masked inputs, first axis = sample.
        masked_labels: tensor of integer labels aligned with ``masked_features``.
        criterion: loss on ``(logits, labels)``; assumed to average over the batch.

    Returns:
        (accuracy, f1, sensitivity, confusion_matrix); F1 and sensitivity are for
        class 1, matching the un-masked ``evaluate`` baseline.

    Raises:
        ValueError: if there are no samples to evaluate.
    """
    model.eval()
    masked_dataset = torch.utils.data.TensorDataset(masked_features, masked_labels)
    # Loss is accumulated per sample below, so the batch size does not change the average.
    masked_loader = DataLoader(masked_dataset, batch_size=32, shuffle=False)

    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    running_loss = 0.0

    with torch.no_grad():
        for mel_spec, labels in tqdm(masked_loader, desc="Evaluating with masked features"):
            mel_spec = mel_spec.to(device)
            labels = labels.to(device)
            outputs = model(mel_spec)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * labels.size(0)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("No masked samples to evaluate.")

    accuracy = correct / total
    f1 = f1_score(all_labels, all_preds)
    sensitivity = recall_score(all_labels, all_preds)
    avg_loss = running_loss / total

    # Fixed label set keeps the matrix 2x2 even if one class is absent.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    print(f'Masked Validation Accuracy: {accuracy:.4f}')
    print(f'Masked Validation F1 Score: {f1:.4f}')
    print(f'Masked Validation Sensitivity (Recall): {sensitivity:.4f}')
    print(f'Masked Validation Loss: {avg_loss:.4f}')

    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=['Class 0', 'Class 1'], yticklabels=['Class 0', 'Class 1'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix (Masked Features)')
    plt.show()

    return accuracy, f1, sensitivity, cm


# Fresh loader; every validation sample is IG-masked and then re-scored.
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

masked_features, masked_labels = compute_and_mask_ig_features(
    model,
    val_loader,
    mask_threshold=90,
    max_samples=len(val_dataset),
)
masked_accuracy, masked_f1, masked_sensitivity, masked_cm = evaluate_with_masked_features(
    model, masked_features, masked_labels, criterion
)


In [None]:
# grad-cam
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import confusion_matrix, f1_score, recall_score
import seaborn as sns
import matplotlib.pyplot as plt
import random


def mask_positive_contributions_gradcam(input_tensor, gradcam_attributions, mask_threshold=90):
    """Zero the input where its Grad-CAM value is among the strongest positive values.

    Args:
        input_tensor: model input; the map is broadcast against it.
        gradcam_attributions: numpy attribution map.
        mask_threshold: percentile (0-100) of the clipped-at-zero map used as the cutoff.

    Returns:
        Masked float32 tensor on the CPU.
    """
    positive_contributions = np.maximum(gradcam_attributions, 0)
    threshold = np.percentile(positive_contributions, mask_threshold)
    # Require strict positivity too, otherwise a zero percentile masks everything.
    mask = ((positive_contributions >= threshold) & (positive_contributions > 0)).astype(float)

    masked_tensor = input_tensor.detach().cpu().numpy() * (1 - mask)  # zero the selected regions
    return torch.tensor(masked_tensor, dtype=torch.float32)


def compute_and_mask_gradcam_features(model, val_loader, layer_name, mask_threshold=90, max_samples=100):
    """Grad-CAM-mask a random subset of the validation set, one sample at a time.

    Args:
        model: classifier; each sample's Grad-CAM target is its predicted class.
        val_loader: loader whose ``.dataset`` is sampled (its batching is not used).
        layer_name: the layer module passed to ``LayerGradCam`` (e.g. ``model.resnet.layer4``).
        mask_threshold: percentile cutoff forwarded to ``mask_positive_contributions_gradcam``.
        max_samples: number of random samples; clamped to the dataset size.

    Returns:
        (stacked masked inputs, tensor of labels) for samples whose map has a
        positive value; the others are skipped with a message.

    Raises:
        ValueError: if no sample had a positive Grad-CAM value.
    """
    dataset = val_loader.dataset
    all_indices = list(range(len(dataset)))
    # random.sample raises if asked for more items than exist.
    random_indices = random.sample(all_indices, min(max_samples, len(all_indices)))
    subset_loader = DataLoader(Subset(dataset, random_indices), batch_size=1, shuffle=False)

    masked_features = []
    masked_labels = []

    with torch.no_grad():
        for mel_spec, labels in tqdm(subset_loader, desc="Computing Grad-CAM and masking features"):
            input_tensor = mel_spec.to(device)  # [1, channels, height, width]
            output = model(input_tensor)
            target_class = torch.argmax(output, dim=1).item()

            gradcam_attributions = generate_gradcam_heatmap(
                model, input_tensor, target_class=target_class, layer_name=layer_name
            )

            # Nothing to mask when the map has no positive value.
            if not np.any(gradcam_attributions > 0):
                print("Sample skipped: No positive contributions for this sample.")
                continue

            masked_tensor = mask_positive_contributions_gradcam(input_tensor, gradcam_attributions, mask_threshold)
            masked_features.append(masked_tensor.squeeze(0))
            masked_labels.append(labels.item())

    if len(masked_features) == 0:
        raise ValueError("No valid samples with positive contributions were found.")

    return torch.stack(masked_features), torch.tensor(masked_labels)


def evaluate_with_masked_features(model, masked_features, masked_labels, criterion):
    """Score ``model`` on Grad-CAM-masked inputs and plot their confusion matrix.

    Args:
        model: classifier returning [batch, 2] logits (binary task assumed).
        masked_features: tensor of masked inputs, first axis = sample.
        masked_labels: tensor of integer labels aligned with ``masked_features``.
        criterion: loss on ``(logits, labels)``; assumed to average over the batch.

    Returns:
        (accuracy, f1, sensitivity, confusion_matrix); F1 and sensitivity are for
        class 1, matching the un-masked ``evaluate`` baseline.

    Raises:
        ValueError: if there are no samples to evaluate.
    """
    model.eval()
    masked_dataset = torch.utils.data.TensorDataset(masked_features, masked_labels)
    masked_loader = DataLoader(masked_dataset, batch_size=32, shuffle=False)

    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    running_loss = 0.0

    with torch.no_grad():
        for mel_spec, labels in tqdm(masked_loader, desc="Evaluating with masked features"):
            mel_spec = mel_spec.to(device)
            labels = labels.to(device)
            outputs = model(mel_spec)
            loss = criterion(outputs, labels)
            # Per-sample weighting so a short final batch is not over-weighted.
            running_loss += loss.item() * labels.size(0)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("No masked samples to evaluate.")

    accuracy = correct / total
    # Binary (class-1) metrics: support-weighted recall is algebraically equal to
    # accuracy, so it cannot serve as "sensitivity".
    f1 = f1_score(all_labels, all_preds)
    sensitivity = recall_score(all_labels, all_preds)
    avg_loss = running_loss / total

    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    print(f'Masked Validation Accuracy: {accuracy:.4f}')
    print(f'Masked Validation F1 Score: {f1:.4f}')
    print(f'Masked Validation Sensitivity (Recall): {sensitivity:.4f}')
    print(f'Masked Validation Loss: {avg_loss:.4f}')

    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=['Class 0', 'Class 1'], yticklabels=['Class 0', 'Class 1'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix (Masked Features)')
    plt.show()

    return accuracy, f1, sensitivity, cm

val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Grad-CAM is taken at the last residual stage of the ResNet.
layer_name = model.resnet.layer4

masked_features, masked_labels = compute_and_mask_gradcam_features(
    model, val_loader, layer_name, mask_threshold=90, max_samples=len(val_dataset)
)
masked_accuracy, masked_f1, masked_sensitivity, masked_cm = evaluate_with_masked_features(
    model, masked_features, masked_labels, criterion
)


In [None]:
# cam
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import confusion_matrix, f1_score, recall_score
import seaborn as sns
import matplotlib.pyplot as plt
import random


def mask_positive_contributions_cam(input_tensor, cam_attributions, mask_threshold=90):
    """Zero the input where its CAM value is among the strongest positive values.

    Args:
        input_tensor: model input; the map is broadcast against it.
        cam_attributions: numpy CAM map.
        mask_threshold: percentile (0-100) of the clipped-at-zero map used as the cutoff.

    Returns:
        Masked float32 tensor on the CPU.
    """
    positive_contributions = np.maximum(cam_attributions, 0)
    threshold = np.percentile(positive_contributions, mask_threshold)
    # The CAM map keeps few positive cells, so its percentile is often 0; without the
    # strict-positivity test, ">= 0" would mask the entire input.
    mask = ((positive_contributions >= threshold) & (positive_contributions > 0)).astype(float)
    masked_tensor = input_tensor.detach().cpu().numpy() * (1 - mask)
    return torch.tensor(masked_tensor, dtype=torch.float32)


def compute_and_mask_cam_features(model, val_loader, mask_threshold=90, max_samples=1000):
    """CAM-mask up to ``max_samples`` random validation samples, one at a time.

    Each sample's CAM uses the model's predicted class. Samples whose CAM has no
    positive value are reported and skipped.

    Returns:
        (stacked masked inputs, tensor of labels).

    Raises:
        ValueError: if every sampled item was skipped.
    """
    dataset = val_loader.dataset
    population = list(range(len(dataset)))
    chosen = random.sample(population, min(max_samples, len(population)))
    single_loader = DataLoader(Subset(dataset, chosen), batch_size=1, shuffle=False)

    kept_inputs, kept_labels = [], []

    with torch.no_grad():
        for mel_spec, labels in tqdm(single_loader, desc="Computing CAM and masking features"):
            sample = mel_spec.to(device)
            predicted = torch.argmax(model(sample)).item()
            cam_map = generate_cam_heatmap(model, sample, target_class=predicted)

            if not np.any(cam_map > 0):
                print("Skipping sample with no positive contributions in CAM.")
                continue

            masked = mask_positive_contributions_cam(sample, cam_map, mask_threshold)
            kept_inputs.append(masked.squeeze(0))
            kept_labels.append(labels.item())

    if not kept_inputs:
        raise ValueError("No valid samples with positive contributions were found in CAM.")

    return torch.stack(kept_inputs), torch.tensor(kept_labels)

def evaluate_with_masked_features(model, masked_features, masked_labels, criterion):
    """Score ``model`` on CAM-masked inputs and plot their confusion matrix.

    Args:
        model: classifier returning [batch, 2] logits (binary task assumed).
        masked_features: tensor of masked inputs, first axis = sample.
        masked_labels: tensor of integer labels aligned with ``masked_features``.
        criterion: loss on ``(logits, labels)``; assumed to average over the batch.

    Returns:
        (accuracy, f1, sensitivity, confusion_matrix); F1 and sensitivity are for
        class 1, matching the un-masked ``evaluate`` baseline.

    Raises:
        ValueError: if there are no samples to evaluate.
    """
    model.eval()
    masked_dataset = torch.utils.data.TensorDataset(masked_features, masked_labels)
    masked_loader = DataLoader(masked_dataset, batch_size=32, shuffle=False)

    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    running_loss = 0.0

    with torch.no_grad():
        for mel_spec, labels in tqdm(masked_loader, desc="Evaluating with masked features"):
            mel_spec = mel_spec.to(device)
            labels = labels.to(device)
            outputs = model(mel_spec)
            loss = criterion(outputs, labels)
            # Per-sample weighting so a short final batch is not over-weighted.
            running_loss += loss.item() * labels.size(0)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("No masked samples to evaluate.")

    accuracy = correct / total
    # Binary (class-1) metrics: support-weighted recall is algebraically equal to
    # accuracy, so it cannot serve as "sensitivity".
    f1 = f1_score(all_labels, all_preds)
    sensitivity = recall_score(all_labels, all_preds)
    avg_loss = running_loss / total

    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    print(f'Masked Validation Accuracy: {accuracy:.4f}')
    print(f'Masked Validation F1 Score: {f1:.4f}')
    print(f'Masked Validation Sensitivity (Recall): {sensitivity:.4f}')
    print(f'Masked Validation Loss: {avg_loss:.4f}')

    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=['Class 0', 'Class 1'], yticklabels=['Class 0', 'Class 1'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix (Masked Features)')
    plt.show()

    return accuracy, f1, sensitivity, cm


# Fresh loader; CAM-mask all validation samples (those without positive CAM are skipped) and re-score.
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

masked_features, masked_labels = compute_and_mask_cam_features(
    model, val_loader, mask_threshold=90, max_samples=len(val_dataset)
)
masked_accuracy, masked_f1, masked_sensitivity, masked_cm = evaluate_with_masked_features(
    model, masked_features, masked_labels, criterion
)


In [None]:
# occ

import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import confusion_matrix, f1_score, recall_score
import seaborn as sns
import matplotlib.pyplot as plt
import random


def mask_positive_contributions_occlusion(input_tensor, occlusion_attributions, mask_threshold=90):
    """Zero the input where the occlusion map reaches its ``mask_threshold`` percentile.

    Negative attributions are clipped to 0 before the percentile is taken.

    Returns:
        (masked float32 tensor, percentage of map entries that were masked).
    """
    positive_part = np.maximum(occlusion_attributions, 0)
    is_masked = positive_part >= np.percentile(positive_part, mask_threshold)
    masked_float = is_masked.astype(float)

    coverage = masked_float.mean() * 100
    masked_np = input_tensor.cpu().numpy() * (1 - masked_float)
    return torch.tensor(masked_np, dtype=torch.float32), coverage


def compute_and_mask_occlusion_features(model, val_loader, mask_threshold=90, max_samples=10):
    """Occlusion-mask a random subset of the validation set, one sample at a time.

    Each sample's occlusion target is the model's own prediction.

    Args:
        model: classifier to explain.
        val_loader: loader whose ``.dataset`` is sampled (its batching is not used).
        mask_threshold: percentile cutoff forwarded to ``mask_positive_contributions_occlusion``.
        max_samples: number of random samples; clamped to the dataset size.

    Returns:
        (stacked masked inputs, tensor of labels).

    Raises:
        ValueError: if no sample was selected.
    """
    dataset = val_loader.dataset
    all_indices = list(range(len(dataset)))
    # random.sample raises if asked for more items than exist.
    random_indices = random.sample(all_indices, min(max_samples, len(all_indices)))
    subset_loader = DataLoader(Subset(dataset, random_indices), batch_size=1, shuffle=False)

    masked_features = []
    masked_labels = []

    with torch.no_grad():
        for mel_spec, labels in tqdm(subset_loader, desc="Computing Occlusion and masking features"):
            input_tensor = mel_spec.to(device)
            output = model(input_tensor)
            target_class = torch.argmax(output, dim=1).item()

            occlusion_attributions = generate_occlusion_heatmap(model, input_tensor, target_class=target_class)

            masked_tensor, _ = mask_positive_contributions_occlusion(input_tensor, occlusion_attributions, mask_threshold)
            masked_features.append(masked_tensor.squeeze(0))
            masked_labels.append(labels.item())

    if not masked_features:
        raise ValueError("No samples were selected for occlusion masking.")

    return torch.stack(masked_features), torch.tensor(masked_labels)


def evaluate_with_masked_features(model, masked_features, masked_labels, criterion):
    """Score ``model`` on occlusion-masked inputs and plot their confusion matrix.

    Args:
        model: classifier returning [batch, 2] logits (binary task assumed).
        masked_features: tensor of masked inputs, first axis = sample.
        masked_labels: tensor of integer labels aligned with ``masked_features``.
        criterion: loss on ``(logits, labels)``; assumed to average over the batch.

    Returns:
        (accuracy, f1, sensitivity, confusion_matrix); F1 and sensitivity are for class 1.

    Raises:
        ValueError: if there are no samples to evaluate.
    """
    model.eval()
    masked_dataset = torch.utils.data.TensorDataset(masked_features, masked_labels)
    masked_loader = DataLoader(masked_dataset, batch_size=32, shuffle=False)

    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    running_loss = 0.0

    with torch.no_grad():
        for mel_spec, labels in tqdm(masked_loader, desc="Evaluating with masked features"):
            mel_spec = mel_spec.to(device)
            labels = labels.to(device)
            outputs = model(mel_spec)
            loss = criterion(outputs, labels)
            # Per-sample weighting so a short final batch is not over-weighted.
            running_loss += loss.item() * labels.size(0)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("No masked samples to evaluate.")

    accuracy = correct / total
    f1 = f1_score(all_labels, all_preds)
    sensitivity = recall_score(all_labels, all_preds)
    avg_loss = running_loss / total

    # Fixed label set keeps the matrix 2x2 even if one class is absent.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    print(f'Masked Validation Accuracy: {accuracy:.4f}')
    print(f'Masked Validation F1 Score: {f1:.4f}')
    print(f'Masked Validation Sensitivity (Recall): {sensitivity:.4f}')
    print(f'Masked Validation Loss: {avg_loss:.4f}')

    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=['Class 0', 'Class 1'], yticklabels=['Class 0', 'Class 1'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix (Masked Features)')
    plt.show()

    return accuracy, f1, sensitivity, cm


# Fresh loader; every validation sample is occlusion-masked and then re-scored.
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

masked_features, masked_labels = compute_and_mask_occlusion_features(
    model,
    val_loader,
    mask_threshold=90,
    max_samples=len(val_dataset),
)
masked_accuracy, masked_f1, masked_sensitivity, masked_cm = evaluate_with_masked_features(
    model, masked_features, masked_labels, criterion
)


In [None]:
# lime
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import confusion_matrix, f1_score, recall_score
import seaborn as sns
import matplotlib.pyplot as plt
import random


def mask_positive_contributions_lime(input_tensor, lime_attributions, mask_threshold=90):
    """Zero the input where the LIME map reaches its ``mask_threshold`` percentile.

    Negative attributions are clipped to 0 before the percentile is taken.

    Returns:
        (masked float32 tensor, percentage of map entries that were masked).
    """
    clipped = np.maximum(lime_attributions, 0)
    cutoff = np.percentile(clipped, mask_threshold)
    selected = (clipped >= cutoff).astype(float)

    coverage_pct = 100 * selected.sum() / selected.size
    kept_input = input_tensor.cpu().numpy() * (1 - selected)
    return torch.tensor(kept_input, dtype=torch.float32), coverage_pct


def compute_and_mask_lime_features(model, val_loader, mask_threshold=90, grid_size=(28, 28), max_samples=10):
    """Explain a random validation subset with LIME and mask the top-attributed inputs.

    For each sampled item, the model's predicted class is explained with
    ``generate_lime_heatmap`` (``n_samples=500``, ``grid_size`` passed through).
    ``mask_positive_contributions_lime`` then zeroes the elements whose positive
    attribution is at or above the ``mask_threshold`` percentile.

    Args:
        model: classifier used both to pick the target class and for LIME.
        val_loader: loader whose ``.dataset`` is sampled; its own batching is not used.
        mask_threshold: percentile passed to ``mask_positive_contributions_lime``.
        grid_size: forwarded to ``generate_lime_heatmap``.
        max_samples: number of items drawn without replacement. It is clamped to
            the dataset size; ``None`` means the whole dataset.

    Returns:
        (masked_features, masked_labels, mask_percentages): the stacked masked
        inputs, a tensor of their labels, and the per-sample masked percentage.

    Raises:
        ValueError: if nothing would be processed (empty dataset or max_samples <= 0).
    """
    dataset = val_loader.dataset
    n_available = len(dataset)
    n_samples = n_available if max_samples is None else min(max_samples, n_available)
    if n_samples <= 0:
        raise ValueError(
            f"compute_and_mask_lime_features: nothing to sample "
            f"(dataset size {n_available}, max_samples {max_samples})"
        )
    # random.sample raises if asked for more items than exist, hence the clamp above.
    random_indices = random.sample(range(n_available), n_samples)
    subset_loader = DataLoader(Subset(dataset, random_indices), batch_size=1, shuffle=False)

    masked_features = []
    masked_labels = []
    mask_percentages = []

    with torch.no_grad():
        for mel_spec, labels in tqdm(subset_loader, desc="Computing LIME and masking features"):
            mel_spec = mel_spec.to(device)
            input_tensor = mel_spec[0].unsqueeze(0)  # single item with the batch dim restored
            output = model(input_tensor)
            target_class = torch.argmax(output).item()

            lime_attributions = generate_lime_heatmap(
                model,
                spectrogram=input_tensor,
                target_class=target_class,
                grid_size=grid_size,
                n_samples=500,
            )

            masked_tensor, mask_coverage = mask_positive_contributions_lime(
                input_tensor, lime_attributions, mask_threshold
            )
            masked_features.append(masked_tensor.squeeze(0))  # drop the batch dim before stacking
            masked_labels.append(labels[0].item())
            mask_percentages.append(mask_coverage)

    return torch.stack(masked_features), torch.tensor(masked_labels), mask_percentages


def evaluate_with_masked_features(model, masked_features, masked_labels, criterion):
    """Evaluate ``model`` on pre-masked inputs and report binary classification metrics.

    Samples are fed one at a time (batch_size=1) on the module-level ``device``.
    The function prints accuracy, F1, sensitivity (recall of class 1, sklearn's
    default ``pos_label``) and mean loss, and shows a confusion-matrix heatmap.

    Args:
        model: classifier whose output has class scores along dim 1.
        masked_features: tensor of inputs, indexed along dim 0.
        masked_labels: integer class labels (0/1) aligned with ``masked_features``.
        criterion: loss function called as ``criterion(outputs, labels)``.

    Returns:
        (accuracy, f1, sensitivity, cm), where ``cm`` is always 2x2
        (rows = true class 0/1, columns = predicted class 0/1).

    Raises:
        ValueError: if there are no samples to evaluate.
    """
    if len(masked_labels) == 0:
        raise ValueError("evaluate_with_masked_features: no masked samples to evaluate")

    model.eval()
    masked_dataset = torch.utils.data.TensorDataset(masked_features, masked_labels)
    masked_loader = DataLoader(masked_dataset, batch_size=1, shuffle=False)

    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    running_loss = 0.0

    with torch.no_grad():
        for mel_spec, labels in tqdm(masked_loader, desc="Evaluating with masked features"):
            mel_spec = mel_spec.to(device)
            labels = labels.to(device)
            outputs = model(mel_spec)
            running_loss += criterion(outputs, labels).item()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = correct / total
    f1 = f1_score(all_labels, all_preds)
    sensitivity = recall_score(all_labels, all_preds)
    avg_loss = running_loss / len(masked_loader)

    # Pin the label set so the matrix is always 2x2, matching the two tick labels
    # below, even when one class is absent from both labels and predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    print(f'Masked Validation Accuracy: {accuracy:.4f}')
    print(f'Masked Validation F1 Score: {f1:.4f}')
    print(f'Masked Validation Sensitivity (Recall): {sensitivity:.4f}')
    print(f'Masked Validation Loss: {avg_loss:.4f}')

    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=['Class 0', 'Class 1'], yticklabels=['Class 0', 'Class 1'], ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix (Masked Features)')
    plt.show()

    return accuracy, f1, sensitivity, cm


# Seed the RNGs so reruns draw the same validation subset (random.sample in
# compute_and_mask_lime_features). Seeding numpy/torch presumably also fixes the
# LIME perturbations; this assumes generate_lime_heatmap draws from those RNGs
# (TODO confirm).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Explain up to 100 random validation samples. The cap is clamped so
# random.sample never asks for more items than the validation set has.
masked_features, masked_labels, mask_percentages = compute_and_mask_lime_features(
    model, val_loader, mask_threshold=90, grid_size=(28, 28), max_samples=min(100, len(val_dataset))
)

print(f"Average Masked Percentage: {np.mean(mask_percentages):.2f}%")
print(f"Min Masked Percentage: {np.min(mask_percentages):.2f}%")
print(f"Max Masked Percentage: {np.max(mask_percentages):.2f}%")

masked_accuracy, masked_f1, masked_sensitivity, masked_cm = evaluate_with_masked_features(
    model, masked_features, masked_labels, criterion
)


## Overlapping attributions: mask features flagged by several XAI methods

In [None]:
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import confusion_matrix, f1_score, recall_score
import seaborn as sns
import matplotlib.pyplot as plt
import random


def binarize_contributions(contributions, threshold=90):
    """Return a 0/1 int map marking the strongest positive attributions.

    Negative attributions are clipped to 0. Elements at or above the
    ``threshold``-th percentile of the clipped map are then marked 1. Elements
    with no positive attribution are never marked. When most of the map is
    non-positive (e.g. a map already cut down to its top 10%), the percentile can
    be exactly 0, and a bare ``>= percentile`` test would mark the entire map.

    Args:
        contributions: array-like attribution map of any shape.
        threshold: percentile in [0, 100].

    Returns:
        np.ndarray of ints with the same shape as ``contributions``.
    """
    positive = np.maximum(np.asarray(contributions, dtype=float), 0)
    cutoff = np.percentile(positive, threshold)
    return ((positive >= cutoff) & (positive > 0)).astype(int)


def combine_and_mask_contributions(input_tensor, contributions_list, min_overlap=2, threshold=90):
    """Mask the input elements that several attribution maps agree on.

    Each map is binarized with ``binarize_contributions(map, threshold)``. The
    binary maps are summed, and every element flagged by at least
    ``min_overlap`` maps is set to 0 in the input.

    Args:
        input_tensor: model input; must broadcast against the attribution maps
            (presumably [1, 1, H, W] against [H, W] maps — TODO confirm).
        contributions_list: non-empty iterable of equally shaped attribution maps.
        min_overlap: minimum number of maps that must flag an element.
        threshold: percentile forwarded to ``binarize_contributions``. It defaults
            to 90, the value previously fixed by binarize's own default.

    Returns:
        (masked_tensor, masked_percentage): float32 masked input, and the
        percentage of map elements that were masked.

    Raises:
        ValueError: if ``contributions_list`` is empty.
    """
    contributions_list = list(contributions_list)
    if not contributions_list:
        raise ValueError("combine_and_mask_contributions: contributions_list is empty")

    combined_mask = sum(binarize_contributions(contributions, threshold) for contributions in contributions_list)
    mask = (combined_mask >= min_overlap).astype(float)
    masked_tensor = input_tensor.detach().cpu().numpy() * (1 - mask)

    masked_percentage = (np.sum(mask) / mask.size) * 100
    return torch.tensor(masked_tensor, dtype=torch.float32), masked_percentage


def process_multiple_xai_techniques(model, input_tensor, target_class):
    """Compute five attribution maps for ``target_class`` on one input.

    The maps are returned in this fixed order:
      1. Grad-CAM on ``model.resnet.layer4``
      2. Integrated Gradients
      3. CAM
      4. Occlusion, with window/stride (1, 8, 8)
      5. LIME, with a 32x32 grid and 500 samples

    Occlusion and LIME are also called with ``top_percent=10``. That presumably
    keeps only the top 10% of each map (TODO confirm against those helpers).
    """
    # Grad-CAM
    gradcam_attributions = generate_gradcam_heatmap(
        model,
        input_tensor,
        target_class=target_class,
        layer_name=model.resnet.layer4
    )

    # Integrated Gradients
    ig_attributions = generate_ig_heatmap(
        model,
        input_tensor,
        target_class=target_class
    )

    # CAM
    cam_attributions = generate_cam_heatmap(
        model,
        input_tensor,
        target_class=target_class
    )

    # Occlusion
    occlusion_map = generate_occlusion_heatmap(
        model,
        input_tensor,
        target_class=target_class,
        strides=(1, 8, 8),
        window_shapes=(1, 8, 8),
        top_percent=10
    )

    # LIME
    lime_map = generate_lime_heatmap(
        model,
        spectrogram=input_tensor,
        target_class=target_class,
        grid_size=(32, 32),
        n_samples=500,
        top_percent=10
    )

    return [gradcam_attributions, ig_attributions, cam_attributions, occlusion_map, lime_map]


def evaluate_with_combined_contributions(model, val_loader, criterion, mask_threshold=90, min_overlap=2, max_samples=10):
    """Mask consensus features from five XAI maps and evaluate the model on the result.

    For a random subset of ``val_loader.dataset``:
      - predict the target class;
      - compute the five maps from ``process_multiple_xai_techniques``;
      - binarize each map at the ``mask_threshold`` percentile;
      - mask the elements flagged by at least ``min_overlap`` maps;
      - evaluate the masked inputs with ``evaluate_with_masked_features``.

    Args:
        model: classifier being explained and evaluated.
        val_loader: loader whose ``.dataset`` is sampled; its own batching is not used.
        criterion: loss forwarded to ``evaluate_with_masked_features``.
        mask_threshold: percentile used to binarize each attribution map.
        min_overlap: minimum number of agreeing maps for an element to be masked.
        max_samples: number of items drawn without replacement. It is clamped to
            the dataset size; ``None`` means the whole dataset.

    Returns:
        (accuracy, f1, sensitivity, cm, avg_masked_percentage).

    Raises:
        ValueError: if nothing would be processed (empty dataset or max_samples <= 0).
    """
    dataset = val_loader.dataset
    n_available = len(dataset)
    n_samples = n_available if max_samples is None else min(max_samples, n_available)
    if n_samples <= 0:
        raise ValueError(
            f"evaluate_with_combined_contributions: nothing to sample "
            f"(dataset size {n_available}, max_samples {max_samples})"
        )
    random_indices = random.sample(range(n_available), n_samples)
    subset_loader = DataLoader(Subset(dataset, random_indices), batch_size=1, shuffle=False)

    masked_features = []
    masked_labels = []
    masked_percentages = []

    for mel_spec, labels in tqdm(subset_loader, desc="Processing multiple XAI techniques"):
        mel_spec = mel_spec.to(device)
        input_tensor = mel_spec[0].unsqueeze(0)

        # Only the prediction runs without autograd. The gradient-based explainers
        # (Grad-CAM, Integrated Gradients) need gradients. Under an outer no_grad
        # they only work if each helper re-enables autograd itself.
        with torch.no_grad():
            target_class = torch.argmax(model(input_tensor)).item()

        # Attribution maps from all XAI techniques
        contributions_list = process_multiple_xai_techniques(model, input_tensor, target_class)

        masked_tensor, masked_percentage = combine_and_mask_contributions(
            input_tensor, contributions_list, min_overlap=min_overlap, threshold=mask_threshold
        )
        masked_features.append(masked_tensor.squeeze(0))  # drop the batch dim before stacking
        masked_labels.append(labels[0].item())
        masked_percentages.append(masked_percentage)

    masked_features = torch.stack(masked_features)
    masked_labels = torch.tensor(masked_labels)

    accuracy, f1, sensitivity, cm = evaluate_with_masked_features(model, masked_features, masked_labels, criterion)

    avg_masked_percentage = np.mean(masked_percentages)
    print(f"Average Masked Percentage: {avg_masked_percentage:.2f}%")

    return accuracy, f1, sensitivity, cm, avg_masked_percentage

# min_overlap is the key parameter (useful range 2-5). An element is masked only
# when at least this many of the five attribution maps (Grad-CAM, IG, CAM,
# Occlusion, LIME) place it among their strongest positive contributions.
MIN_OVERLAP = 2

# Seed the RNGs so the sampled order and any stochastic explainer (e.g. LIME)
# are reproducible across reruns. This assumes the helpers draw from these RNGs
# (TODO confirm).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

accuracy, f1, sensitivity, cm, avg_masked_percentage = evaluate_with_combined_contributions(
    model, val_loader, criterion, mask_threshold=90, min_overlap=MIN_OVERLAP, max_samples=len(val_dataset)
)

print(f"Masked Validation Accuracy: {accuracy:.4f}")
print(f"Masked Validation F1 Score: {f1:.4f}")
print(f"Masked Validation Sensitivity (Recall): {sensitivity:.4f}")
print(f"Average Masked Percentage: {avg_masked_percentage:.2f}%")
