# Universal Adversarial Audio Attack on Whisper

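**Proof of Concept**: Creating a universal noise pattern that, when added to any audio clip, causes Whisper to mistranscribe it. Three ways of generating the pattern are compared: Fast Feature Fool (FFF), GD-UAP, and PSP-UAP.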

In [None]:
import numpy as np
import librosa
import whisper
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from jiwer import wer
import torch
import random

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
LIBRISPEECH_PATH = Path("LibriSpeech") / "test-clean"
SAMPLE_RATE = 16_000  # Hz; Whisper's expected sample rate
MAX_SAMPLES = 25      # Upper bound on the number of utterances loaded
MODEL = "base"        # Whisper checkpoint: "tiny", "base", "small", "medium", "large"

# Optimisation settings shared by every attack.
# ITERATIONS = 1 runs a single optimisation step per attack.
ITERATIONS = 1
LEARNING_RATE = 0.01
EPSILON = 0.02  # each attack clamps its perturbation to [-EPSILON, EPSILON]

# Fast Feature Fool (FFF): modules whose activations are captured with hooks
ATTACK_LAYERS = ['encoder.conv1', 'encoder.conv2'] + [
    f'encoder.blocks.{block_idx}.attn' for block_idx in (2, 4)
]

# GD-UAP only uses the shared settings above.

# PSP-UAP: weights of the two terms in its objective
PSP_LAMBDA_PERCEPTUAL = 1.0
PSP_LAMBDA_FOOLING = 1.0

# Load the Whisper checkpoint used for both attack generation and evaluation
model = whisper.load_model(MODEL)
print(f"Loaded Whisper {MODEL} model")

# Seed every RNG in use so runs are repeatable
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

In [None]:
def load_librispeech_samples(data_path, max_samples=20):
    """Load up to ``max_samples`` utterances with transcripts from LibriSpeech.

    The directory is expected to be laid out as ``<speaker>/<chapter>/``, each
    chapter directory holding ``<speaker>-<chapter>.trans.txt`` next to its
    ``*.flac`` files. Directories and files are visited in sorted order so the
    same utterances are selected on every run, independent of the order in
    which the filesystem happens to list them.

    Args:
        data_path: Root of the dataset split (``str`` or ``pathlib.Path``).
        max_samples: Maximum number of utterances to return.

    Returns:
        A list of dicts with the keys ``'id'``, ``'audio'`` (waveform loaded by
        ``librosa.load`` at ``SAMPLE_RATE``), ``'transcript'`` and ``'path'``.
        The list is empty when ``data_path`` is not a directory or
        ``max_samples`` is not positive.
    """
    samples = []
    data_path = Path(data_path)

    if max_samples <= 0 or not data_path.is_dir():
        return samples

    speaker_dirs = sorted(p for p in data_path.iterdir() if p.is_dir())
    for speaker_dir in speaker_dirs:
        chapter_dirs = sorted(p for p in speaker_dir.iterdir() if p.is_dir())
        for chapter_dir in chapter_dirs:
            trans_file = chapter_dir / f"{speaker_dir.name}-{chapter_dir.name}.trans.txt"
            if not trans_file.exists():
                continue

            # Each transcript line is "<utterance-id> <text>"
            transcriptions = {}
            with open(trans_file, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split(' ', 1)
                    if len(parts) == 2:
                        transcriptions[parts[0]] = parts[1]

            for audio_file in sorted(chapter_dir.glob("*.flac")):
                if len(samples) >= max_samples:
                    return samples

                file_id = audio_file.stem
                if file_id not in transcriptions:
                    continue

                try:
                    audio, _ = librosa.load(audio_file, sr=SAMPLE_RATE)
                except Exception as exc:
                    # Best effort: one unreadable file must not abort the
                    # whole load, but the skip should be visible.
                    print(f"Skipping {audio_file}: {exc}")
                    continue

                samples.append({
                    'id': file_id,
                    'audio': audio,
                    'transcript': transcriptions[file_id],
                    'path': str(audio_file)
                })

    return samples

# Read at most MAX_SAMPLES utterances from the configured LibriSpeech split
samples = load_librispeech_samples(data_path=LIBRISPEECH_PATH, max_samples=MAX_SAMPLES)
print(f"Loaded {len(samples)} audio samples")

In [None]:
class FastFeatureFoolAttack:
    """
    Implementation of the Fast Feature Fool attack for Whisper in PyTorch.
    Generates a universal perturbation to maximize activations in specified layers.

    Forward hooks capture the output of every module named in
    ``attack_layers``; the loss is the negated sum of the norms of those
    outputs, so minimizing it increases the activations.
    """
    def __init__(self, model, attack_layers, epsilon=0.02, learning_rate=0.01, iterations=100):
        self.model = model
        self.attack_layers = attack_layers
        self.epsilon = epsilon
        self.learning_rate = learning_rate
        self.iterations = iterations
        self.hooks = []
        self.features = {}
        self.device = next(model.parameters()).device

    def _get_feature_hook(self, name):
        """Return a forward hook that stores the module output under ``name``."""
        def hook(module, inputs, output):
            # Modules that return a tuple contribute only their first element.
            self.features[name] = output[0] if isinstance(output, tuple) else output
        return hook

    def _register_hooks(self):
        """Attach capture hooks to every module listed in ``attack_layers``.

        Raises:
            ValueError: If none of the requested names exists in the model;
                without any captured feature there is nothing to optimize.
        """
        # Drop hooks left over from an earlier call so they never stack up.
        self._remove_hooks()
        wanted = set(self.attack_layers)
        for name, module in self.model.named_modules():
            if name in wanted:
                self.hooks.append(module.register_forward_hook(self._get_feature_hook(name)))
        if not self.hooks:
            raise ValueError(
                f"None of the attack layers {list(self.attack_layers)!r} match a module of the model"
            )

    def _remove_hooks(self):
        """Detach every hook registered by this attack."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    @staticmethod
    def _fit_perturbation(perturbation, length):
        """Tile or truncate ``perturbation`` to exactly ``length`` samples."""
        if perturbation.shape[0] < length:
            n_repeats = -(-length // perturbation.shape[0])  # ceil division
            perturbation = perturbation.repeat(n_repeats)
        return perturbation[:length]

    def generate(self, audio_samples, sample_rate):
        """Optimize one perturbation over ``audio_samples``.

        Args:
            audio_samples: Sequence of 1-D waveforms.
            sample_rate: Accepted for interface parity with the other attacks;
                not used by this method.

        Returns:
            The perturbation as a NumPy array whose length equals the longest
            input waveform, with values in ``[-epsilon, epsilon]``.

        Raises:
            ValueError: If ``audio_samples`` is empty or only holds empty
                waveforms, or if no attack layer matches the model.
        """
        audio_samples = list(audio_samples)
        if not audio_samples:
            raise ValueError("audio_samples must contain at least one waveform")
        max_length = max(len(audio) for audio in audio_samples)
        if max_length == 0:
            raise ValueError("audio_samples only contains empty waveforms")

        print("Generating universal perturbation with Fast Feature Fool...")
        perturbation = torch.zeros(max_length, requires_grad=True, device=self.device)
        optimizer = torch.optim.Adam([perturbation], lr=self.learning_rate)
        n_samples = whisper.audio.N_SAMPLES
        # Use the checkpoint's own mel-bin count so the encoder input always
        # has the shape it was trained with.
        n_mels = self.model.dims.n_mels

        self._register_hooks()
        try:
            for _ in tqdm(range(self.iterations), desc="FFF Iterations"):
                optimizer.zero_grad()
                total_loss = 0
                batch_samples = random.sample(audio_samples, min(len(audio_samples), 4))
                for audio in batch_samples:
                    audio_tensor = torch.tensor(audio, dtype=torch.float32, device=self.device)
                    pert = self._fit_perturbation(perturbation, len(audio_tensor))
                    attacked_audio = torch.clamp(audio_tensor + pert, -1.0, 1.0)
                    if attacked_audio.shape[0] > n_samples:
                        padded_audio = attacked_audio[:n_samples]
                    else:
                        padded_audio = torch.nn.functional.pad(
                            attacked_audio, (0, n_samples - attacked_audio.shape[0])
                        )
                    mel = whisper.log_mel_spectrogram(padded_audio, n_mels=n_mels).to(self.device)
                    # Only activations from this forward pass may enter the loss.
                    self.features.clear()
                    self.model.encoder(mel.unsqueeze(0))
                    loss = 0
                    for name in self.features:
                        loss -= torch.norm(self.features[name])
                    total_loss += loss
                total_loss.backward()
                optimizer.step()
                with torch.no_grad():
                    perturbation.clamp_(-self.epsilon, self.epsilon)
        finally:
            # Always leave the model hook-free and release captured graphs,
            # even when optimization fails part-way.
            self._remove_hooks()
            self.features.clear()
        print("Finished generating perturbation.")
        return perturbation.detach().cpu().numpy()

class GDUAPAttack:
    """Universal perturbation that raises Whisper's teacher-forced loss.

    For every sampled utterance the decoder is fed the reference transcript
    and the cross-entropy of its next-token predictions is computed. The
    optimizer minimizes the negated cross-entropy, i.e. it pushes the loss on
    the correct transcript up.
    """
    def __init__(self, model, epsilon=0.02, learning_rate=0.01, iterations=100):
        self.model = model
        self.epsilon = epsilon
        self.learning_rate = learning_rate
        self.iterations = iterations
        self.device = next(model.parameters()).device

    def _build_target_tokens(self, tokenizer, transcript):
        """Return ``prefix + encode(transcript) + [eot]`` as a long tensor.

        The prefix is ``tokenizer.sot_sequence`` when available (attribute or
        zero-argument callable), otherwise ``[tokenizer.sot]`` if present. The
        result is truncated to ``model.dims.n_text_ctx`` tokens.
        """
        # NOTE(review): the transcript text is encoded as-is; confirm its
        # casing/formatting matches what the decoder is expected to emit.
        prefix = []
        sot_seq = getattr(tokenizer, 'sot_sequence', None)
        if sot_seq is not None:
            if callable(sot_seq):
                try:
                    sot_seq = sot_seq()
                except TypeError:
                    pass
            prefix = list(sot_seq)
        elif hasattr(tokenizer, 'sot'):
            prefix = [tokenizer.sot]
        encoded = tokenizer.encode(transcript)
        if hasattr(tokenizer, 'eot'):
            encoded = encoded + [tokenizer.eot]
        tokens = (prefix + encoded)[: self.model.dims.n_text_ctx]
        return torch.tensor(tokens, device=self.device, dtype=torch.long)

    @staticmethod
    def _fit_perturbation(perturbation, length):
        """Tile or truncate ``perturbation`` to exactly ``length`` samples."""
        if perturbation.shape[0] < length:
            n_repeats = -(-length // perturbation.shape[0])  # ceil division
            perturbation = perturbation.repeat(n_repeats)
        return perturbation[:length]

    def _get_tokenizer(self):
        """Build the English transcription tokenizer matching the model."""
        kwargs = {"language": "en", "task": "transcribe"}
        # Whisper versions that expose ``num_languages`` on the model need it
        # passed through so language/task token ids line up with the checkpoint.
        num_languages = getattr(self.model, "num_languages", None)
        if num_languages is not None:
            kwargs["num_languages"] = num_languages
        return whisper.tokenizer.get_tokenizer(self.model.is_multilingual, **kwargs)

    def generate(self, audio_samples, sample_rate):
        """Optimize one perturbation over ``audio_samples``.

        Args:
            audio_samples: Sequence of dicts with an ``'audio'`` waveform and
                a ``'transcript'`` string.
            sample_rate: Accepted for interface parity; not used here.

        Returns:
            The perturbation as a NumPy array whose length equals the longest
            input waveform, with values in ``[-epsilon, epsilon]``.

        Raises:
            ValueError: If ``audio_samples`` is empty or only holds empty
                waveforms.
        """
        audio_samples = list(audio_samples)
        if not audio_samples:
            raise ValueError("audio_samples must contain at least one sample")
        max_length = max(len(s['audio']) for s in audio_samples)
        if max_length == 0:
            raise ValueError("audio_samples only contains empty waveforms")

        print("Generating universal perturbation with GD-UAP...")
        perturbation = torch.zeros(max_length, requires_grad=True, device=self.device)
        optimizer = torch.optim.Adam([perturbation], lr=self.learning_rate)
        tokenizer = self._get_tokenizer()
        n_samples = whisper.audio.N_SAMPLES
        # Use the checkpoint's own mel-bin count so the encoder input always
        # has the shape it was trained with.
        n_mels = self.model.dims.n_mels

        for _ in tqdm(range(self.iterations), desc="GD-UAP Iterations"):
            optimizer.zero_grad()
            losses = []
            batch_samples = random.sample(audio_samples, min(len(audio_samples), 4))
            for sample in batch_samples:
                audio_tensor = torch.tensor(sample['audio'], dtype=torch.float32, device=self.device)
                pert = self._fit_perturbation(perturbation, len(audio_tensor))
                attacked_audio = torch.clamp(audio_tensor + pert, -1.0, 1.0)
                if attacked_audio.shape[0] > n_samples:
                    padded_audio = attacked_audio[:n_samples]
                else:
                    padded_audio = torch.nn.functional.pad(
                        attacked_audio, (0, n_samples - attacked_audio.shape[0])
                    )
                mel = whisper.log_mel_spectrogram(padded_audio, n_mels=n_mels).to(self.device)

                full_tokens = self._build_target_tokens(tokenizer, sample['transcript'])
                if full_tokens.shape[0] < 2:
                    continue  # need at least one (input, target) pair
                # Teacher forcing: predict token t+1 from tokens up to t.
                tokens_in = full_tokens[:-1].unsqueeze(0)
                targets = full_tokens[1:]
                logits = self.model(mel.unsqueeze(0), tokens_in)
                if isinstance(logits, dict) and 'logits' in logits:
                    logits = logits['logits']
                if isinstance(logits, (list, tuple)):
                    logits = logits[0]
                logits = logits[0, :targets.shape[0], :]
                losses.append(torch.nn.functional.cross_entropy(logits, targets))

            if not losses:
                continue  # nothing usable in this batch
            # Minimizing the negated cross-entropy maximizes the model's loss.
            total_obj = -torch.stack(losses).sum()
            total_obj.backward()
            optimizer.step()
            with torch.no_grad():
                perturbation.clamp_(-self.epsilon, self.epsilon)
        print("Finished generating GD-UAP perturbation.")
        return perturbation.detach().cpu().numpy()

class PSPUAPAttack:
    """Universal perturbation trading off fooling strength against a spectral penalty.

    The fooling term is the negated teacher-forced cross-entropy (minimizing
    it raises the model's loss on the reference transcript). The perceptual
    term is a frequency-weighted power of the perturbation itself. Both are
    combined into one objective that the optimizer minimizes, so the
    perceptual term enters with a positive sign: it is a penalty to reduce,
    not a quantity to grow.
    """
    def __init__(self, model, epsilon=0.02, learning_rate=0.01, iterations=100, lambda_perceptual=1.0, lambda_fooling=1.0):
        self.model = model
        self.epsilon = epsilon
        self.learning_rate = learning_rate
        self.iterations = iterations
        self.lambda_perceptual = lambda_perceptual
        self.lambda_fooling = lambda_fooling
        self.device = next(model.parameters()).device

    def _perceptual_loss(self, perturbation):
        """Return the mean STFT power of ``perturbation``, weighted by frequency.

        Frequency bin ``k`` of ``K`` is weighted by ``(k / (K - 1)) ** 2``, so
        the lowest bin contributes nothing and the highest bin counts fully.
        """
        # Shrink the FFT for very short inputs: with centered reflect padding
        # torch.stft needs more than n_fft // 2 samples.
        n_fft = min(2048, perturbation.shape[0])
        hop_length = max(1, n_fft // 4)  # n_fft // 4 is torch.stft's default hop
        window = torch.hann_window(n_fft, device=self.device)
        spec = torch.stft(
            perturbation, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True
        )
        power_spec = torch.abs(spec) ** 2  # (freq_bins, frames)
        freq_bins = power_spec.shape[0]
        weights = (torch.linspace(0, 1, freq_bins, device=self.device) ** 2)[:, None]  # (freq_bins, 1)
        return (power_spec * weights).mean()

    def _total_objective(self, fooling_obj, perceptual):
        """Combine both terms into the scalar the optimizer minimizes."""
        return self.lambda_fooling * fooling_obj + self.lambda_perceptual * perceptual

    def _build_target_tokens(self, tokenizer, transcript):
        """Return ``prefix + encode(transcript) + [eot]`` as a long tensor.

        The prefix is ``tokenizer.sot_sequence`` when available (attribute or
        zero-argument callable), otherwise ``[tokenizer.sot]`` if present. The
        result is truncated to ``model.dims.n_text_ctx`` tokens.
        """
        # NOTE(review): the transcript text is encoded as-is; confirm its
        # casing/formatting matches what the decoder is expected to emit.
        prefix = []
        sot_seq = getattr(tokenizer, 'sot_sequence', None)
        if sot_seq is not None:
            if callable(sot_seq):
                try:
                    sot_seq = sot_seq()
                except TypeError:
                    pass
            prefix = list(sot_seq)
        elif hasattr(tokenizer, 'sot'):
            prefix = [tokenizer.sot]
        encoded = tokenizer.encode(transcript)
        if hasattr(tokenizer, 'eot'):
            encoded = encoded + [tokenizer.eot]
        tokens = (prefix + encoded)[: self.model.dims.n_text_ctx]
        return torch.tensor(tokens, device=self.device, dtype=torch.long)

    @staticmethod
    def _fit_perturbation(perturbation, length):
        """Tile or truncate ``perturbation`` to exactly ``length`` samples."""
        if perturbation.shape[0] < length:
            n_repeats = -(-length // perturbation.shape[0])  # ceil division
            perturbation = perturbation.repeat(n_repeats)
        return perturbation[:length]

    def _get_tokenizer(self):
        """Build the English transcription tokenizer matching the model."""
        kwargs = {"language": "en", "task": "transcribe"}
        # Whisper versions that expose ``num_languages`` on the model need it
        # passed through so language/task token ids line up with the checkpoint.
        num_languages = getattr(self.model, "num_languages", None)
        if num_languages is not None:
            kwargs["num_languages"] = num_languages
        return whisper.tokenizer.get_tokenizer(self.model.is_multilingual, **kwargs)

    def generate(self, audio_samples, sample_rate):
        """Optimize one perturbation over ``audio_samples``.

        Args:
            audio_samples: Sequence of dicts with an ``'audio'`` waveform and
                a ``'transcript'`` string.
            sample_rate: Accepted for interface parity; not used here.

        Returns:
            The perturbation as a NumPy array whose length equals the longest
            input waveform, with values in ``[-epsilon, epsilon]``.

        Raises:
            ValueError: If ``audio_samples`` is empty or only holds empty
                waveforms.
        """
        audio_samples = list(audio_samples)
        if not audio_samples:
            raise ValueError("audio_samples must contain at least one sample")
        max_length = max(len(s['audio']) for s in audio_samples)
        if max_length == 0:
            raise ValueError("audio_samples only contains empty waveforms")

        print("Generating universal perturbation with PSP-UAP...")
        perturbation = torch.zeros(max_length, requires_grad=True, device=self.device)
        optimizer = torch.optim.Adam([perturbation], lr=self.learning_rate)
        tokenizer = self._get_tokenizer()
        n_samples = whisper.audio.N_SAMPLES
        # Use the checkpoint's own mel-bin count so the encoder input always
        # has the shape it was trained with.
        n_mels = self.model.dims.n_mels

        for _ in tqdm(range(self.iterations), desc="PSP-UAP Iterations"):
            optimizer.zero_grad()
            ce_losses = []
            batch_samples = random.sample(audio_samples, min(len(audio_samples), 4))
            for sample in batch_samples:
                audio_tensor = torch.tensor(sample['audio'], dtype=torch.float32, device=self.device)
                pert = self._fit_perturbation(perturbation, len(audio_tensor))
                attacked_audio = torch.clamp(audio_tensor + pert, -1.0, 1.0)
                if attacked_audio.shape[0] > n_samples:
                    padded_audio = attacked_audio[:n_samples]
                else:
                    padded_audio = torch.nn.functional.pad(
                        attacked_audio, (0, n_samples - attacked_audio.shape[0])
                    )
                mel = whisper.log_mel_spectrogram(padded_audio, n_mels=n_mels).to(self.device)

                full_tokens = self._build_target_tokens(tokenizer, sample['transcript'])
                if full_tokens.shape[0] < 2:
                    continue  # need at least one (input, target) pair
                # Teacher forcing: predict token t+1 from tokens up to t.
                tokens_in = full_tokens[:-1].unsqueeze(0)
                targets = full_tokens[1:]
                logits = self.model(mel.unsqueeze(0), tokens_in)
                if isinstance(logits, dict) and 'logits' in logits:
                    logits = logits['logits']
                if isinstance(logits, (list, tuple)):
                    logits = logits[0]
                logits = logits[0, :targets.shape[0], :]
                ce_losses.append(torch.nn.functional.cross_entropy(logits, targets))

            if not ce_losses:
                continue  # nothing usable in this batch
            # Minimizing the negated cross-entropy maximizes the model's loss.
            fooling_obj = -torch.stack(ce_losses).sum()
            perceptual = self._perceptual_loss(perturbation)
            total_obj = self._total_objective(fooling_obj, perceptual)
            total_obj.backward()
            optimizer.step()
            with torch.no_grad():
                perturbation.clamp_(-self.epsilon, self.epsilon)
        print("Finished generating PSP-UAP perturbation.")
        return perturbation.detach().cpu().numpy()

print("Setting up attacks...")

# Hyper-parameters every attack constructor takes
_shared_kwargs = dict(epsilon=EPSILON, learning_rate=LEARNING_RATE, iterations=ITERATIONS)

fff_attack = FastFeatureFoolAttack(model, ATTACK_LAYERS, **_shared_kwargs)
gduap_attack = GDUAPAttack(model, **_shared_kwargs)
pspuap_attack = PSPUAPAttack(
    model,
    lambda_perceptual=PSP_LAMBDA_PERCEPTUAL,
    lambda_fooling=PSP_LAMBDA_FOOLING,
    **_shared_kwargs,
)

# FFF only needs waveforms; the other two also read each sample's transcript.
clean_audio_list = [sample['audio'] for sample in samples]

# Generated one after another in this fixed order: each generate() call draws
# its batches from the shared `random` state.
noise_patterns = {}
noise_patterns['fff'] = fff_attack.generate(clean_audio_list, SAMPLE_RATE)
noise_patterns['gduap'] = gduap_attack.generate(samples, SAMPLE_RATE)
noise_patterns['pspuap'] = pspuap_attack.generate(samples, SAMPLE_RATE)

print("\nGenerated noise patterns:")
for name, noise in noise_patterns.items():
    print(f"  {name}: amplitude range [{noise.min():.4f}, {noise.max():.4f}]")

In [None]:
def apply_adversarial_noise(audio, noise):
    """Add a universal noise pattern to one audio clip.

    The noise is truncated when it is longer than the audio and repeated
    end-to-end when it is shorter, which is the same tiling the attacks use
    while optimizing the pattern. The sum is clipped to ``[-1, 1]``.

    Args:
        audio: 1-D waveform.
        noise: 1-D noise pattern; an empty pattern leaves the audio unchanged
            apart from clipping.

    Returns:
        A new array of the same length as ``audio``.
    """
    audio = np.asarray(audio)
    noise = np.asarray(noise)
    n_audio = len(audio)

    if len(noise) == 0:
        # Nothing to add; np.pad(mode='wrap') cannot extend an empty array.
        return np.clip(audio, -1.0, 1.0)

    if len(noise) >= n_audio:
        noise_fitted = noise[:n_audio]
    else:
        # np.resize fills with repeated copies of the input, giving strict
        # tiling regardless of how many repetitions are needed.
        noise_fitted = np.resize(noise, n_audio)

    # Keep the attacked audio inside the valid range [-1, 1]
    return np.clip(audio + noise_fitted, -1.0, 1.0)

# Store one attacked copy of every sample per attack method
print("Applying different adversarial noise patterns to all samples...")

for sample in tqdm(samples, desc="Processing samples"):
    sample.update({
        f'attacked_audio_{attack_name}': apply_adversarial_noise(sample['audio'], noise_pattern)
        for attack_name, noise_pattern in noise_patterns.items()
    })

print(f"Applied {len(noise_patterns)} different attack methods to all samples")

In [None]:
def transcribe_batch(model, audio_list, batch_name=""):
    """Run Whisper on each waveform in ``audio_list``, one at a time.

    Every waveform is converted to float32 and, if its peak magnitude exceeds
    1.0, divided by that peak. The returned list has one entry per input:
    the stripped, upper-cased transcription, or ``""`` when transcription of
    that item raised.
    """
    print(f"Transcribing {len(audio_list)} {batch_name} samples...")

    texts = []
    for waveform in tqdm(audio_list, desc=f"Transcribing {batch_name}"):
        try:
            waveform = np.array(waveform, dtype=np.float32)

            # Rescale only when the signal leaves the [-1, 1] range
            peak = np.max(np.abs(waveform))
            if peak > 1.0:
                waveform = waveform / peak

            text = model.transcribe(waveform, fp16=False)['text'].strip().upper()
        except Exception as e:
            print(f"Transcription error: {e}")
            text = ""
        texts.append(text)

    return texts

# Baseline: transcribe the unmodified recordings and attach the predictions
clean_audio_list = [sample['audio'] for sample in samples]
clean_transcriptions = transcribe_batch(model, clean_audio_list, "clean")
for sample, prediction in zip(samples, clean_transcriptions):
    sample['clean_prediction'] = prediction

# Same again for the attacked copy of each sample, one pass per attack
attack_transcriptions = {}
for attack_name in noise_patterns:
    audio_key = f'attacked_audio_{attack_name}'
    attacked_audio_list = [sample[audio_key] for sample in samples]
    predictions = transcribe_batch(model, attacked_audio_list, f"attacked ({attack_name})")
    attack_transcriptions[attack_name] = predictions

    for sample, prediction in zip(samples, predictions):
        sample[f'attacked_prediction_{attack_name}'] = prediction

print("All transcriptions completed")


In [None]:
# Calculate WER for all samples and attack methods
def _normalize_for_wer(text):
    """Upper-case ``text`` and keep only letters, digits and apostrophes.

    Every other character becomes a space and whitespace runs are collapsed.
    The reference transcripts carry no punctuation, so without this step
    punctuation in a prediction (e.g. "HELLO," vs "HELLO") is scored as a
    word error for clean and attacked audio alike.
    """
    text = text.upper().replace("\u2019", "'")  # treat curly apostrophes as plain ones
    kept = ''.join(ch if (ch.isalnum() or ch == "'") else ' ' for ch in text)
    return ' '.join(kept.split())


results = []

for sample in samples:
    ground_truth = sample['transcript'].upper()
    reference = _normalize_for_wer(ground_truth)
    clean_pred = sample['clean_prediction']
    clean_wer = wer(reference, _normalize_for_wer(clean_pred))

    # The table keeps the raw strings; only the scoring uses normalized text.
    result_row = {
        'id': sample['id'],
        'ground_truth': ground_truth,
        'clean_prediction': clean_pred,
        'clean_wer': clean_wer,
    }

    # Add results for each attack method
    for attack_name in noise_patterns:
        attacked_pred = sample[f'attacked_prediction_{attack_name}']
        attacked_wer = wer(reference, _normalize_for_wer(attacked_pred))

        result_row[f'attacked_prediction_{attack_name}'] = attacked_pred
        result_row[f'attacked_wer_{attack_name}'] = attacked_wer
        result_row[f'wer_increase_{attack_name}'] = attacked_wer - clean_wer

    results.append(result_row)

# Create results DataFrame
df_results = pd.DataFrame(results)

# Calculate summary statistics
print("EVALUATION RESULTS:")
print("=" * 60)

avg_clean_wer = df_results['clean_wer'].mean()
print(f"Average Clean WER: {avg_clean_wer:.3f} ({avg_clean_wer*100:.1f}%)")
print()

attack_summary = {}
for attack_name in noise_patterns:
    avg_attacked_wer = df_results[f'attacked_wer_{attack_name}'].mean()
    avg_wer_increase = df_results[f'wer_increase_{attack_name}'].mean()
    # An attack "succeeds" on a sample when it makes that sample's WER worse
    success_rate = (df_results[f'wer_increase_{attack_name}'] > 0).mean() * 100

    attack_summary[attack_name] = {
        'avg_wer': avg_attacked_wer,
        'avg_increase': avg_wer_increase,
        'success_rate': success_rate
    }

    print(f"{attack_name.upper()} ATTACK:")
    print(f"  Average WER: {avg_attacked_wer:.3f} ({avg_attacked_wer*100:.1f}%)")
    print(f"  Average WER Increase: {avg_wer_increase:.3f} ({avg_wer_increase*100:.1f}%)")
    print(f"  Success Rate: {success_rate:.1f}% (samples with WER increase > 0)")
    print()

# Find the most effective attack (largest mean WER increase)
best_attack = max(attack_summary.items(), key=lambda x: x[1]['avg_increase'])
print(f"MOST EFFECTIVE ATTACK: {best_attack[0].upper()}")
print(f"  WER Increase: {best_attack[1]['avg_increase']:.3f} ({best_attack[1]['avg_increase']*100:.1f}%)")


In [None]:
# Bar chart: average WER on clean audio next to each attack's average WER
attack_names = list(noise_patterns)
categories = ['Clean'] + [name.upper() for name in attack_names]
wer_values = [avg_clean_wer] + [attack_summary[name]['avg_wer'] for name in attack_names]
colors = ['lightblue', 'lightcoral', 'lightgreen', 'lightsalmon']

fig, ax = plt.subplots(1, 1, figsize=(12, 8))
fig.suptitle('Universal Adversarial Audio Attack - Average WER Comparison', fontsize=16, fontweight='bold')

bars = ax.bar(categories, wer_values, color=colors[:len(categories)], alpha=0.7)
ax.set_title('Average WER Comparison', fontweight='bold', fontsize=14)
ax.set_ylabel('Word Error Rate', fontsize=12)

# Leave 20% headroom above the tallest bar for the value labels
y_top = max(wer_values) * 1.2 if wer_values else 1
ax.set_ylim(0, y_top)
ax.tick_params(axis='x', rotation=45)

# Print each bar's value just above it
for bar, val in zip(bars, wer_values):
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() + 0.01
    ax.text(label_x, label_y, f'{val:.3f}',
            ha='center', va='bottom', fontweight='bold', fontsize=10)

plt.tight_layout()
plt.show()


In [None]:
# Audio Playback - Listen to Clean vs Attacked Audio
from IPython.display import Audio, display

def play_audio_comparison(sample_idx=0):
    """Show audio players for one sample: clean, each attacked version, each noise.

    For the clean clip the ground truth and Whisper's prediction are printed;
    for every attack the prediction on the attacked clip is printed, followed
    by a player for the attacked clip and one for the raw noise pattern.
    """
    chosen = samples[sample_idx]

    # Clean reference
    print("\nCLEAN AUDIO:")
    print(f"  Ground Truth: {chosen['transcript'].upper()}")
    print(f"  Whisper Prediction: {chosen['clean_prediction']}")
    display(Audio(chosen['audio'], rate=SAMPLE_RATE))

    # One attacked clip and one noise clip per attack method
    for attack_name in noise_patterns:
        attacked_audio = chosen[f'attacked_audio_{attack_name}']
        print(f"\n{attack_name.upper()} ATTACKED AUDIO:")
        print(f"  Whisper Prediction: {chosen[f'attacked_prediction_{attack_name}']}")
        display(Audio(attacked_audio, rate=SAMPLE_RATE))

        print(f"\n{attack_name.upper()} ADVERSARIAL NOISE:")
        display(Audio(noise_patterns[attack_name], rate=SAMPLE_RATE))

# Demo: listen to the first loaded sample
play_audio_comparison(sample_idx=0)