This notebook generates the results for the planted-audio detection part of the paper.

# Planted audio detection

In [None]:
import os
import argparse

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from sklearn import svm

import cv2
import joblib
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.io import wavfile

import torch
from torch.utils.data import DataLoader, TensorDataset

from stingray.lightcurve import Lightcurve
from stingray.bispectrum import Bispectrum

import numpy as np
import random
random.seed(42)
import librosa
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from tqdm import tqdm

In [None]:
# Parameters used in the research.
# Several of these are module-level globals that the functions below read at
# call time and that the experiment cells at the end of the notebook rebind.

# Length, in samples, of each segment that get_features cuts out of a slice.
SEGMENT_SIZE = 400
# Step, in samples, between the starts of consecutive segments in get_features.
SEGMENT_OVERLAP = 200

# Inclusive bounds (passed to random.randint) for the length, in samples, of
# the chunks that merge_2_audio takes from the real / fake signals.
min_slice_size = 6400
max_slice_size = 9600
# Length, in samples, of each sliding window classified by the model, and the
# step between window starts (both read by get_slice_features).
SLICE_SIZE = 3200
SLICE_OVERLAP = 1600

# Seed handed to random.seed() at the start of every merge_2_audio call.
random_controller = 0
# torch device; assigned inside get_slice_probs before the model is run.
device = None

In [None]:
# model to use
class ResNetMulti(nn.Module):
    """ResNet-50 classifier with a single sigmoid output over 5-channel input.

    The torchvision ResNet-50 is created without pretrained weights and with
    one output unit. Its first convolution is replaced by a 5x5, stride-2
    convolution that accepts 5 input channels. The network is stored under
    the attribute ``model``, so checkpoint keys are prefixed with ``model.``.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=False, num_classes=1)
        # Swap the stem so the network takes the 5 stacked feature channels.
        backbone.conv1 = nn.Conv2d(
            5, 64, kernel_size=(5, 5), stride=(2, 2), padding=(3, 3), bias=True
        )
        self.model = backbone

    def forward(self, x):
        """Return the sigmoid of the backbone output for the batch ``x``."""
        logits = self.model(x)
        return torch.sigmoid(logits)

# merging 2 audios' random slices
def merge_2_audio(real, fake):
    """Interleave random-length chunks of two signals into one sequence.

    The random generator is re-seeded with the module-level
    ``random_controller`` on every call, so the same inputs always give the
    same interleaving. On each step a chunk length is drawn with
    ``random.randint(min_slice_size, max_slice_size)`` and a coin decides
    whether the chunk is taken from the front of ``real`` or ``fake``; the
    chunk is clipped to whatever is left of the chosen signal.

    Args:
        real: Sliceable sequence of samples; its chunks get label 0.
        fake: Sliceable sequence of samples; its chunks get label 1.

    Returns:
        ``(parts, labels)``: the list of non-empty chunks in merge order and
        the parallel list of 0/1 labels. Both lists are empty when both
        inputs are empty.
    """
    random.seed(random_controller)

    parts = []
    labels = []

    while len(real) + len(fake) > 0:
        length = random.randint(min_slice_size, max_slice_size)
        # The coin is drawn before the length check, so it is always consumed.
        take_real = random.random() < 0.5 and len(real) > 0
        source = real if take_real else fake
        chunk_len = min(length, len(source))

        if chunk_len == 0:
            # "fake" was picked but is already used up. Skip this draw rather
            # than emit an empty chunk labelled as fake. The random numbers
            # are still consumed, so the non-empty chunks are unchanged.
            continue

        parts.append(source[:chunk_len])
        if take_real:
            labels.append(0)
            real = real[chunk_len:]
        else:
            labels.append(1)
            fake = fake[chunk_len:]

    return parts, labels

# check the create_image_dataset.py for comments
def get_features(data, samplerate, max_K=-1):
    """Compute per-segment bispectrum layers and the mean third-order cumulant.

    ``data`` is cut into segments of SEGMENT_SIZE samples whose starts are
    SEGMENT_OVERLAP samples apart; only segments that fit entirely inside
    ``data`` are used. Each segment is wrapped in a stingray ``Lightcurve``
    and passed to ``Bispectrum`` with a Hamming window.

    Args:
        data: 1-D sequence of audio samples.
        samplerate: Sample rate used to build the time axis of each segment.
        max_K: If positive, upper bound on the number of segments processed.

    Returns:
        ``(RC_layers, cum3_avg)``: a complex array with one layer per segment
        holding ``mag*cos(phase) + 1j*mag*sin(phase)`` of the bispectrum, and
        the sum of the segments' ``cum3`` divided by the number of segments.
    """
    segment_count = (len(data) - SEGMENT_SIZE) // SEGMENT_OVERLAP + 1
    if max_K > 0:
        segment_count = min(segment_count, max_K)

    side = SEGMENT_SIZE + 1
    RC_layers = np.zeros((segment_count, side, side), dtype=complex)
    cum3_total = np.zeros((side, side))

    # Every segment shares the same time axis.
    time_values = np.linspace(0, SEGMENT_SIZE / samplerate, SEGMENT_SIZE)
    for layer in range(segment_count):
        begin = layer * SEGMENT_OVERLAP
        lc = Lightcurve(time_values, data[begin:begin + SEGMENT_SIZE])
        bs = Bispectrum(lc, window="hamming")

        magnitude, phase = bs.bispec_mag, bs.bispec_phase
        cum3_total = cum3_total + bs.cum3

        real_part = magnitude * np.cos(phase)
        imag_part = magnitude * np.sin(phase)
        RC_layers[layer] = real_part + imag_part * 1j

    return RC_layers, cum3_total / segment_count

# check the create_image_dataset.py for comments
def create_signature_image(RC_layers):
    """Collapse a stack of complex layers into one normalized complex image.

    For every pixel the values across the layers are summed and divided by
    the Euclidean norm of those values (plus 0.0001 to avoid dividing by
    zero):

        image[r, c] = sum_k z_k / (sqrt(sum_k |z_k|^2) + 0.0001)

    The whole image is computed with array operations instead of one
    ``np.dot`` call per pixel.

    Args:
        RC_layers: Array of shape (layers, rows, cols).

    Returns:
        Array of shape (rows, cols, 1) holding the normalized sums.
    """
    RC_layers = np.asarray(RC_layers)[..., np.newaxis]

    tops = np.sum(RC_layers, axis=0)
    # z * conj(z) is |z|^2; its imaginary part is zero, so keep the real part.
    energy = np.sum((RC_layers * np.conjugate(RC_layers)).real, axis=0)

    return tops / (np.sqrt(energy) + 0.0001)

def _min_max_normalize(values):
    """Rescale ``values`` linearly so its minimum maps to 0 and its maximum to 1.

    A constant array has zero range; it is mapped to all zeros instead of
    being divided by zero, which would fill the result with NaN.
    """
    low = values.min()
    span = values.max() - low
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return (values - low) / span


# check the create_image_dataset.py for comments
def audio_to_images(data, samplerate):
    """Turn one audio slice into five min-max normalized 2-D feature images.

    The slice is passed through get_features and create_signature_image; the
    magnitude, phase angle, real part and imaginary part of the resulting
    signature image, plus the averaged third-order cumulant, are each rescaled
    to the range [0, 1].

    Args:
        data: 1-D sequence of audio samples.
        samplerate: Sample rate forwarded to get_features.

    Returns:
        Tuple ``(absolute, angle, real, imag, cum3)`` of 2-D arrays.
    """
    RC_layers, cum3_avg = get_features(data, samplerate, max_K=-1)
    signature_image = create_signature_image(RC_layers)

    absolute_norm = _min_max_normalize(np.absolute(signature_image))
    angle_norm = _min_max_normalize(np.angle(signature_image))
    real_norm = _min_max_normalize(signature_image.real)
    imag_norm = _min_max_normalize(signature_image.imag)
    cum3_norm = _min_max_normalize(cum3_avg)

    # The signature image carries a trailing axis of length 1; drop it.
    return (
        absolute_norm[:, :, 0],
        angle_norm[:, :, 0],
        real_norm[:, :, 0],
        imag_norm[:, :, 0],
        cum3_norm,
    )

# check the create_image_dataset.py for comments
def get_slice_features(audio):
    """Compute the 5-channel feature stack for every sliding window of ``audio``.

    Windows start every SLICE_OVERLAP samples and span up to SLICE_SIZE
    samples; the windows at the end of ``audio`` are shorter because slicing
    is clipped at the end of the signal. Each window is converted with
    audio_to_images using a sample rate of 16000.

    Returns:
        List with one array per window. Each array has a leading axis of
        length 1 followed by the five channels in the order
        absolute, angle, cum3, imag, real.
    """
    features = []
    window_starts = range(0, len(audio), SLICE_OVERLAP)
    for begin in tqdm(window_starts):
        window = audio[begin:begin + SLICE_SIZE]
        channels = audio_to_images(window, 16000)
        absolute_norm, angle_norm, real_norm, imag_norm, cum3_norm = channels
        # Channel order is part of the model's input contract.
        image = np.stack([absolute_norm, angle_norm, cum3_norm, imag_norm, real_norm])
        features.append(image[np.newaxis, ...])

    return features

# passes features to the model
def pass_features_to_model(model, slice_features, batch_size=4):
    """Run ``model`` over all slice features in mini-batches.

    Args:
        model: Callable taking a float32 tensor batch and returning a tensor
            with one row per input item.
        slice_features: Sequence of arrays that ``np.vstack`` can stack along
            the first axis (one or more items per array).
        batch_size: Number of items passed to the model at once. This value
            is honored; it is no longer overwritten by a hard-coded 4.

    Returns:
        NumPy array with the model outputs of all batches concatenated along
        the first axis, in input order.
    """
    features = np.vstack(slice_features).astype(np.float32)

    # Keep the full dataset on the CPU and move one batch at a time to the
    # module-level ``device``, so device memory use scales with batch_size.
    dataset = TensorDataset(torch.from_numpy(features))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    probs_list = []
    with torch.no_grad():
        for (batch,) in dataloader:
            probs = model(batch.to(device))
            probs_list.append(probs.cpu().numpy())

    # Concatenate all batches into a single numpy array
    return np.concatenate(probs_list, axis=0)

In [None]:
# calculates each slices' probability predicted by the model
def get_slice_probs(real_audio_path, fake_audio_path, model_path=""):
    """Merge a real and a fake recording and score every slice with the model.

    Both files are loaded with librosa at 16 kHz, interleaved by
    merge_2_audio, cut into sliding windows by get_slice_features and passed
    through a ResNetMulti whose weights are read from ``model_path``.

    Side effect: sets the module-level ``device`` (CUDA when available,
    otherwise CPU), which pass_features_to_model reads.

    Args:
        real_audio_path: Path of the real recording.
        fake_audio_path: Path of the fake recording.
        model_path: Path of the saved ``state_dict`` for ResNetMulti. The
            default is the empty placeholder from the original notebook and
            must be replaced with a real checkpoint path.

    Returns:
        ``(audio, parts, labels, probs)``: the merged signal, the chunks it
        was built from, their 0/1 labels, and the per-slice model outputs.
    """
    global device
    real, _ = librosa.load(real_audio_path, sr=16000)
    fake, _ = librosa.load(fake_audio_path, sr=16000)

    parts, labels = merge_2_audio(real, fake)

    audio = np.hstack(parts)
    slice_features = get_slice_features(audio)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNetMulti().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.eval()

    probs = pass_features_to_model(model, slice_features, 4)

    return audio, parts, labels, probs

# draws plot of predictions
def draw_plot(audio, parts, labels, probs_list, SLICE_SIZE_AND_OVERLAPS, title):
    """Plot the merged waveform with the model's per-slice confidence on top.

    One subplot is drawn per entry of ``probs_list``. Each subplot shows the
    chunks of ``parts`` end to end (blue for label 0, red for label 1) and a
    dashed curve of ``probability - 0.5`` placed at the centre of every
    slice; the area under the curve is shaded blue below zero and red at or
    above zero. The figure is saved to ``title`` and then shown.

    Args:
        audio: Merged signal. Kept for interface compatibility; the x
            positions are derived from the number of probabilities instead.
        parts: Chunks of the merged signal, in order.
        labels: 0/1 label of each chunk in ``parts``.
        probs_list: One array of shape (num_slices, 1) per subplot.
        SLICE_SIZE_AND_OVERLAPS: One ``(slice_size, slice_step)`` pair per
            subplot, matching ``probs_list``.
        title: Output path passed to ``plt.savefig``.
    """
    num_plots = len(probs_list)
    # squeeze=False always yields a 2-D axes array, so a single subplot needs
    # no special-casing.
    fig, axs = plt.subplots(
        num_plots, 1, figsize=(14, 5 * num_plots), sharex=True, squeeze=False
    )

    for ax, probs, (slice_size, slice_step) in zip(axs[:, 0], probs_list, SLICE_SIZE_AND_OVERLAPS):
        start_index = 0
        for chunk, label in zip(parts, labels):
            end_index = start_index + len(chunk)
            color = "red" if label == 1 else "blue"
            ax.plot(np.arange(start_index, end_index), chunk, color=color)
            start_index = end_index

        # confidence limits
        ax.axhline(y=0.5, color='black', linestyle=':')
        ax.axhline(y=-0.5, color='black', linestyle=':')

        # confidence
        prob_points = probs[:, 0] - 0.5
        # Slice k starts at k * slice_step, so its centre is slice_size // 2
        # further on. Building the x positions from the number of
        # probabilities keeps x and y the same length for any size/step pair.
        detection_points = slice_size // 2 + slice_step * np.arange(len(prob_points))
        zero_line = np.zeros_like(prob_points)
        ax.plot(detection_points, prob_points, color="black", zorder=100, linestyle='--', alpha=0.7, marker="o", markersize=3)
        ax.fill_between(detection_points, zero_line, prob_points, where=(zero_line > prob_points), color='blue', alpha=0.3, interpolate=True)
        ax.fill_between(detection_points, zero_line, prob_points, where=(zero_line <= prob_points), color='red', alpha=0.3, interpolate=True)

        ax.tick_params(axis='both', which='major', labelsize=24)
        ax.set_title(f"Detection on segments with {slice_size} samples overlapping at {slice_step} samples", fontsize=24)

    plt.tight_layout()
    plt.savefig(title, bbox_inches='tight')
    plt.show()

In [None]:
# Paths of the real and the fake recording that every experiment cell below
# passes to get_slice_probs.
real_path = "real.wav"
fake_path = "fake.wav"

In [None]:
# Experiment: planted chunks of 3200-6400 samples.
# These assignments rebind the module-level globals read by merge_2_audio.
random_controller = 0
min_slice_size = 3200
max_slice_size = 6400
# (window length, step between window starts) pairs, one subplot per pair.
SLICE_SIZE_AND_OVERLAPS = [(2400, 1200), (4000, 2000), (6400, 3200), (8000, 4000)]
probs_list = []

# The loop variables rebind the module-level SLICE_SIZE / SLICE_OVERLAP, which
# get_slice_features reads as globals on every call.
for (SLICE_SIZE, SLICE_OVERLAP) in SLICE_SIZE_AND_OVERLAPS:
    audio, parts, labels, probs = get_slice_probs(real_path, fake_path)
    probs_list.append(probs)
# audio / parts / labels come from the last iteration.
draw_plot(audio, parts, labels, probs_list, SLICE_SIZE_AND_OVERLAPS, "3200-6400.png")

In [None]:
# Experiment: planted chunks of 6400-9600 samples.
# These assignments rebind the module-level globals read by merge_2_audio.
random_controller = 0
min_slice_size = 6400
max_slice_size = 9600
# (window length, step between window starts) pairs, one subplot per pair.
SLICE_SIZE_AND_OVERLAPS = [(2400, 1200), (4000, 2000), (6400, 3200), (8000, 4000)]
probs_list = []

# The loop variables rebind the module-level SLICE_SIZE / SLICE_OVERLAP, which
# get_slice_features reads as globals on every call.
for (SLICE_SIZE, SLICE_OVERLAP) in SLICE_SIZE_AND_OVERLAPS:
    audio, parts, labels, probs = get_slice_probs(real_path, fake_path)
    probs_list.append(probs)
# audio / parts / labels come from the last iteration.
draw_plot(audio, parts, labels, probs_list, SLICE_SIZE_AND_OVERLAPS, "6400-9600.png")

In [None]:
# Experiment: planted chunks of 9600-12800 samples.
# These assignments rebind the module-level globals read by merge_2_audio.
random_controller = 0
min_slice_size = 9600
max_slice_size = 12800
# (window length, step between window starts) pairs, one subplot per pair.
SLICE_SIZE_AND_OVERLAPS = [(2400, 1200), (4000, 2000), (6400, 3200), (8000, 4000)]
probs_list = []

# The loop variables rebind the module-level SLICE_SIZE / SLICE_OVERLAP, which
# get_slice_features reads as globals on every call.
for (SLICE_SIZE, SLICE_OVERLAP) in SLICE_SIZE_AND_OVERLAPS:
    audio, parts, labels, probs = get_slice_probs(real_path, fake_path)
    probs_list.append(probs)
# audio / parts / labels come from the last iteration.
draw_plot(audio, parts, labels, probs_list, SLICE_SIZE_AND_OVERLAPS, "9600-12800.png")

In [None]:
# Experiment: planted chunks of 12800-16000 samples.
# These assignments rebind the module-level globals read by merge_2_audio.
random_controller = 0
min_slice_size = 12800
max_slice_size = 16000
# (window length, step between window starts) pairs, one subplot per pair.
SLICE_SIZE_AND_OVERLAPS = [(2400, 1200), (4000, 2000), (6400, 3200), (8000, 4000)]
probs_list = []

# The loop variables rebind the module-level SLICE_SIZE / SLICE_OVERLAP, which
# get_slice_features reads as globals on every call.
for (SLICE_SIZE, SLICE_OVERLAP) in SLICE_SIZE_AND_OVERLAPS:
    audio, parts, labels, probs = get_slice_probs(real_path, fake_path)
    probs_list.append(probs)
# audio / parts / labels come from the last iteration.
draw_plot(audio, parts, labels, probs_list, SLICE_SIZE_AND_OVERLAPS, "12800-16000.png")