In [1]:
import yt_dlp as ydl
import subprocess
import os
import dotenv

In [2]:
dotenv.load_dotenv()

# Root directory for all artefacts; must come from .env or the environment.
BASE_PATH = os.getenv('BASE_PATH')
if not BASE_PATH:
    # Fail fast with an actionable message instead of a TypeError from os.path.join(None, ...).
    raise RuntimeError("BASE_PATH is not set; define it in .env or the environment")
TEST_PATH = os.path.join(BASE_PATH, 'test')
os.makedirs(TEST_PATH, exist_ok=True)

# Raw download (removed after trimming) and the trimmed clip used for inference.
full_audio = os.path.join(TEST_PATH, "audio_full.wav")
trimmed_audio = os.path.join(TEST_PATH, "audio_trimmed.wav")

In [None]:
# Clip used for this smoke test; the trailing comment gives the segment of interest.
# Other clips tried (segment of interest, in seconds):
#   https://www.youtube.com/watch?v=GAWKHBycWvQ  (0-14)
#   https://www.youtube.com/watch?v=qa5wIivTuxY  (3-8)
url = "https://www.youtube.com/watch?v=Ymyq2tmv8d4"  # 0-20 sec

# Segment to keep, as ffmpeg HH:MM:SS timestamps.
start = "00:00:00"
end   = "00:00:20"

ydl_opts = {
    'format': 'bestaudio/best',
    # No postprocessor is configured, so yt-dlp saves the stream in its
    # original container under this exact name. Despite the ".wav" suffix,
    # the file is presumably webm/m4a; ffmpeg below probes the real format.
    'outtmpl': full_audio,
    # Re-download rather than silently reusing a leftover file from an
    # earlier run (e.g. one that failed before the cleanup below).
    'overwrites': True,
    'quiet': False,
}

with ydl.YoutubeDL(ydl_opts) as dl:
    dl.download([url])

# Decode to real PCM WAV instead of stream-copying: "-c copy" would put the
# source codec (opus/aac) inside a .wav container. ffmpeg either rejects that
# or writes a non-standard file.
# "-vn" drops any video stream picked by the 'best' fallback.
# "-y" overwrites an existing trimmed clip instead of blocking on ffmpeg's
# interactive prompt.
# check=True raises CalledProcessError if ffmpeg fails.
subprocess.run(
    [
        "ffmpeg",
        "-y",
        "-i", full_audio,
        "-ss", start,
        "-to", end,
        "-vn",
        "-c:a", "pcm_s16le",
        trimmed_audio,
    ],
    check=True,
)

if os.path.exists(full_audio):
    os.remove(full_audio)
assert os.path.exists(trimmed_audio), f'{trimmed_audio} was not found'
print("Finished fragment:", trimmed_audio)

In [None]:
import torch
from transformers import AutoFeatureExtractor, ASTForAudioClassification, ASTConfig
from safetensors.torch import load_file
# Base checkpoint: supplies the AST architecture config and the matching
# feature extractor. Its classification head is resized below.
model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
# Fine-tuned weights (safetensors). Read from the WEIGHT_PATH environment
# variable (populated from .env by dotenv.load_dotenv earlier) so no local
# path is hard-coded. '...' is the original placeholder.
weight_path = os.getenv('WEIGHT_PATH', '...')
if not os.path.isfile(weight_path):
    raise FileNotFoundError(
        f"Fine-tuned weights not found at {weight_path!r}; set WEIGHT_PATH in .env"
    )
# 10 s at 16 kHz. NOTE(review): not referenced by any visible cell.
TARGET_LEN = 16000 * 10
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Same architecture as the base checkpoint, but with a 4-way output head.
# NOTE(review): presumably matches the 4-entry label mapping used later — confirm.
config = ASTConfig.from_pretrained(model_name)
config.num_labels = 4

# Build the model from config, then overwrite its weights with the fine-tuned
# file. load_state_dict is strict by default, so key mismatches raise.
model = ASTForAudioClassification(config)
state_dict = load_file(weight_path)
model.load_state_dict(state_dict)
# Switch to eval mode for inference. The trailing ';' suppresses the very
# long module repr in the notebook output.
model.eval();

In [5]:
import librosa

def preprocess_for_inference(path, sampling_rate=16000):
    """Load an audio file and turn it into AST model input features.

    Args:
        path: Path to an audio file readable by librosa.
        sampling_rate: Rate (Hz) to resample to. The same value is passed to
            the feature extractor so the two can never disagree. Defaults to
            16000, the rate previously hard-coded in both places.

    Returns:
        The ``input_values`` tensor produced by the module-level
        ``feature_extractor`` with ``return_tensors="pt"``.
    """
    # librosa.load resamples to `sampling_rate`; its default mono=True downmixes
    # multi-channel audio. The returned rate equals `sampling_rate`, so it is discarded.
    waveform, _ = librosa.load(path, sr=sampling_rate)
    inputs = feature_extractor(
        waveform,
        sampling_rate=sampling_rate,
        return_tensors="pt",
    )
    return inputs["input_values"]

In [None]:
input_values = preprocess_for_inference(trimmed_audio)
# One clip in, so the batch should hold exactly one item: (1, frames, mel_bins).
assert input_values.ndim == 3 and input_values.shape[0] == 1, input_values.shape
# Display the shape rather than dumping raw spectrogram values.
input_values.shape

In [10]:
%%time
with torch.no_grad():
    outputs = model(input_values)
    logits = outputs.logits
    predicted_class = logits.argmax(dim=1).item()

print("Predicted class:", predicted_class)

Predicted class: 0
CPU times: total: 11.9 s
Wall time: 5.06 s


In [8]:
# Label name -> class index for the fine-tuned 4-way head.
# NOTE(review): order must match the head the checkpoint was trained with — confirm.
answer = {'siren': 0, 'gunshot': 1, 'explosion': 2, 'casual': 3}
# Invert once so the lookup is a direct key access. An unknown index raises a
# KeyError naming that index, instead of an opaque IndexError from `[...][0]`.
id_to_label = {idx: name for name, idx in answer.items()}
label_name = id_to_label[predicted_class]
print(label_name)

siren
