In [1]:
import sounddevice as sd
import numpy as np
import queue
import threading
import time
import torch
import torchinfo
import torch.nn.functional as F
import torchaudio.transforms as T
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
from models import *
from datasets import GunshotDataset

In [2]:
# Producer parameters
fs = 44100                                # sampling rate in Hz
chunk_duration = 2                        # seconds of audio per spectrogram / inference window
chunk_samples = int(chunk_duration * fs)  # samples per window (88200 with the values above)
spec_dims = (256, 256)                    # (frequency bins, time frames) targeted by make_spectrogram
power = 2.0                               # exponent passed to T.Spectrogram (2.0 -> power spectrogram)

# Hand-off queues:
#   audio_callback -> spec_queue -> monitor_queue -> inference_queue -> stream_and_infer
spec_queue = queue.Queue()
inference_queue = queue.Queue()

# Cache of (Spectrogram, AmplitudeToDB) transform pairs keyed by their settings, so the
# real-time audio callback does not rebuild the STFT transforms for every chunk.
_spec_transform_cache = {}


def _spectrogram_transforms(n_fft, hop_length, power):
    """Return a cached (Spectrogram, AmplitudeToDB) pair for these STFT settings."""
    key = (n_fft, hop_length, power)
    cached = _spec_transform_cache.get(key)
    if cached is None:
        # AmplitudeToDB must be told whether its input is a power spectrogram
        # (|X|^2 -> 10*log10) or a magnitude spectrogram (|X| -> 20*log10);
        # a fixed stype='power' is only correct for power=2.
        stype = "magnitude" if power == 1 else "power"
        cached = (
            T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power),
            T.AmplitudeToDB(stype=stype),
        )
        _spec_transform_cache[key] = cached
    return cached


def make_spectrogram(waveform, n_samples, spec_dims, power):
    """Convert a 1-D waveform into a dB-scaled spectrogram sized for ``spec_dims``.

    Args:
        waveform: 1-D float tensor of audio samples.
        n_samples: number of samples the waveform is expected to hold; it only
            drives the hop-length choice (``waveform``'s real length is not checked).
        spec_dims: ``(freq_bins, time_frames)`` target. ``n_fft = 2 * freq_bins - 1``
            yields exactly ``freq_bins`` frequency rows; the hop length is chosen so
            the frame count lands on/near ``time_frames`` (with 88200 samples and
            (256, 256) it gives exactly 256 frames).
        power: exponent of the magnitude spectrogram (2.0 = power, 1.0 = magnitude).

    Returns:
        Tensor of shape ``(freq_bins, n_frames)`` in decibels.
    """
    n_ffts = (spec_dims[0] * 2) - 1
    # The "+ 2" fudge is kept byte-for-byte: the classifier was presumably trained on
    # spectrograms made with this exact hop length, so changing it would shift inputs.
    hop_length = max(1, int((n_samples - n_ffts) / (spec_dims[1] - 1)) + 2)
    to_specgram, power_to_dB = _spectrogram_transforms(n_ffts, hop_length, power)
    return power_to_dB(to_specgram(waveform))

def audio_callback(indata, frames, time, status):
    """sounddevice ``InputStream`` callback: spectrogram one audio block and enqueue it.

    ``indata`` is indexed as a 2-D (frames, channels) array and only channel 0 is
    used. The 2-D spectrogram gets two leading singleton dimensions (presumably
    batch and channel for the model) before being put on ``spec_queue``.
    ``frames`` and ``time`` are unused; note that ``time`` here is sounddevice's
    timing argument and shadows the ``time`` module inside this function.
    """
    if status:
        # Stream condition flags reported by sounddevice (e.g. input overflow).
        print(status)
    # torch.tensor copies, so the result does not alias sounddevice's buffer.
    mono = torch.tensor(indata[:, 0], dtype=torch.float32)
    spectrogram = make_spectrogram(mono, chunk_samples, spec_dims, power)
    spec_queue.put(spectrogram[None, None])

def record_audio(stop_event=None):
    """Open the microphone input stream and keep it open until ``stop_event`` is set.

    Audio is delivered to ``audio_callback`` by sounddevice; this function only
    has to keep the ``InputStream`` context alive.

    Args:
        stop_event: optional ``threading.Event``. When it becomes set, the stream
            is closed and the function returns. With the default ``None`` the
            stream stays open for the life of the process (the previous behavior).
    """
    if stop_event is None:
        stop_event = threading.Event()  # never set -> record forever
    with sd.InputStream(callback=audio_callback, channels=1, samplerate=fs, blocksize=chunk_samples):
        print("Recording...")
        # Sleep on the event instead of `while True: pass`, which spun a full CPU
        # core and contended for the GIL with the relay and inference threads.
        # The timeout keeps the wait responsive rather than blocking indefinitely.
        while not stop_event.wait(timeout=0.5):
            pass

def monitor_queue(stop_event=None):
    """Relay spectrograms from ``spec_queue`` to ``inference_queue``.

    Args:
        stop_event: optional ``threading.Event``; the loop exits once it is set.
            With the default ``None`` the relay runs forever (the previous behavior).
    """
    try:
        while stop_event is None or not stop_event.is_set():
            try:
                # Timed get so a set stop_event is noticed even when no audio arrives.
                spectrogram = spec_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            inference_queue.put(spectrogram)
    except KeyboardInterrupt:
        # Only reachable when this runs on the main thread: CPython delivers
        # KeyboardInterrupt to the main thread, never to worker threads.
        print("Stopped monitoring the queue.")
        
def stream_and_infer(checkpoint_path="checkpoints/resnet18.pth", threshold=0.99):
    """Load the ResNet-18 checkpoint and classify spectrograms as they arrive.

    Blocks on ``inference_queue`` for spectrogram batches (as enqueued by
    ``audio_callback``), prints ``<probability> : <decision>`` for each one, and
    returns after printing "Interrupted." when a KeyboardInterrupt arrives.

    Args:
        checkpoint_path: checkpoint file containing a ``"model_state_dict"`` entry.
        threshold: sigmoid probability above which a chunk is reported "positive".
    """
    model = build_resnet18()
    # The model and the incoming spectrograms stay on the CPU, so load
    # GPU-saved checkpoints straight onto the CPU as well.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    try:
        # inference_mode: no autograd graph is recorded for these forward passes,
        # saving memory and time on every chunk.
        with torch.inference_mode():
            while True:
                spec = inference_queue.get()
                # assumes the model emits a single logit per batch item (batch size 1 here)
                prob = torch.sigmoid(model(spec)).item()
                decision = "positive" if prob > threshold else "negative"
                print(prob, ":", decision)
    except KeyboardInterrupt:
        print("Interrupted.")

# def plot_spectrograms():
#     """Plot spectrograms from the plot queue on the main thread."""
#     try:
#         while True:
#             spectrogram = plot_queue.get()
#             plt.figure(figsize=(10, 4))
#             plt.imshow(spectrogram.numpy(), origin='lower')
#             plt.colorbar(format='%+2.0f dB')
#             plt.title('Spectrogram (dB)')
#             plt.ylabel('Frequency Bin')
#             plt.xlabel('Time Frame')
#             plt.tight_layout()
#             plt.show()
#     except KeyboardInterrupt:
#         clear_output()
#         print("Stopped plotting.")


# Start the producer (microphone) and relay threads.
# daemon=True so these endless loops do not keep a plain-Python process alive after
# the inference loop below is interrupted. Inside a Jupyter kernel they keep running
# until the kernel restarts, so re-running this cell starts a second set of threads
# (and a second input stream).
recording_thread = threading.Thread(target=record_audio, name="audio-recorder", daemon=True)
monitoring_thread = threading.Thread(target=monitor_queue, name="queue-monitor", daemon=True)

recording_thread.start()
monitoring_thread.start()

# Run the inference loop on the main thread: a kernel interrupt (KeyboardInterrupt)
# is only delivered to the main thread, so this is the loop that can be stopped.
stream_and_infer()

Recording...
0.017445573583245277 : negative
0.02835758402943611 : negative
0.3268868029117584 : negative
0.09615269303321838 : negative
0.3005329668521881 : negative
0.043674346059560776 : negative
0.5081961154937744 : negative
0.7469927668571472 : negative
0.03297499194741249 : negative
0.6285988688468933 : negative
0.0794837698340416 : negative
0.06883645057678223 : negative
0.03200659155845642 : negative
0.5939410924911499 : negative
0.041939057409763336 : negative
0.988934338092804 : negative
0.04536587744951248 : negative
0.021998178213834763 : negative
0.0055379318073391914 : negative
0.0031186279375106096 : negative
0.0015643545193597674 : negative
0.0012173077557235956 : negative
0.0011251705000177026 : negative
0.02721616066992283 : negative
0.8690611720085144 : negative
0.1265464723110199 : negative
0.27860596776008606 : negative
0.19945329427719116 : negative
0.7288821339607239 : negative
0.30214905738830566 : negative
0.016690634191036224 : negative
0.12767919898033142 : n