In [1]:
import yt_dlp as ydl
import subprocess
import os
import dotenv

In [2]:
dotenv.load_dotenv()

BASE_PATH = os.getenv('BASE_PATH')
TEST_PATH = os.path.join(BASE_PATH, 'test')
os.makedirs(TEST_PATH, exist_ok=True)

full_audio = os.path.join(TEST_PATH, "audio_full.wav")
trimmed_audio = os.path.join(TEST_PATH, "audio_trimmed.wav")

In [3]:
# url = "https://www.youtube.com/watch?v=GAWKHBycWvQ"# 0 - 14 sec
# url = "https://www.youtube.com/watch?v=qa5wIivTuxY" # 3-8 sec
url = "https://www.youtube.com/watch?v=Ymyq2tmv8d4" #0 - 20 sec


start = "00:00:00"
end   = "00:00:20"

ydl_opts = {
    'format': 'bestaudio/best',
    'outtmpl': full_audio,
    'quiet': False,
}

with ydl.YoutubeDL(ydl_opts) as dl:
    dl.download([url])

subprocess.run([
    "ffmpeg",
    "-i", full_audio,
    "-ss", start,
    "-to", end,
    "-c", "copy",
    trimmed_audio
])

if os.path.exists(full_audio):
    os.remove(full_audio)
assert os.path.exists(trimmed_audio), f'{trimmed_audio} was not found'
print("Finished fragment:", trimmed_audio)

[youtube] Extracting URL: https://www.youtube.com/watch?v=Ymyq2tmv8d4
[youtube] Ymyq2tmv8d4: Downloading webpage
[youtube] Ymyq2tmv8d4: Downloading tv client config
[youtube] Ymyq2tmv8d4: Downloading tv player API JSON
[youtube] Ymyq2tmv8d4: Downloading ios player API JSON
[youtube] Ymyq2tmv8d4: Downloading player ef5f17ca-main


         player = https://www.youtube.com/s/player/ef5f17ca/player_ias.vflset/en_US/base.js
         n = CwIQUtQzkdYE8hIs ; player = https://www.youtube.com/s/player/ef5f17ca/player_ias.vflset/en_US/base.js
         Please report this issue on  https://github.com/yt-dlp/yt-dlp/issues?q= , filling out the appropriate issue template. Confirm you are on the latest version using  yt-dlp -U


[youtube] Ymyq2tmv8d4: Downloading m3u8 information
[info] Testing format 234
[info] Ymyq2tmv8d4: Downloading 1 format(s): 234
[hlsnative] Downloading m3u8 manifest
[hlsnative] Total fragments: 4
[download] Destination: D:\audio_cls_coursework\test\audio_full.wav
[download] 100% of  437.06KiB in 00:00:00 at 608.71KiB/s               
Finished fragment: D:\audio_cls_coursework\test\audio_trimmed.wav


In [4]:
import torch
from transformers import AutoFeatureExtractor, ASTForAudioClassification, ASTConfig
from safetensors.torch import load_file
model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
weight_path = 'D:\\audio_cls_coursework\\result\\AST\\model.safetensors'
TARGET_LEN = 16000 * 10
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

config = ASTConfig.from_pretrained(model_name)
config.num_labels = 4

model = ASTForAudioClassification(config)
state_dict = load_file(weight_path)
model.load_state_dict(state_dict)
model.eval()

ASTForAudioClassification(
  (audio_spectrogram_transformer): ASTModel(
    (embeddings): ASTEmbeddings(
      (patch_embeddings): ASTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ASTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ASTLayer(
          (attention): ASTAttention(
            (attention): ASTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ASTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ASTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=T

In [5]:
import librosa

def preprocess_for_inference(path):
    waveform, sr = librosa.load(path, sr=16000)
    inputs = feature_extractor(
        waveform,
        sampling_rate=16000,
        return_tensors="pt"
    )
    return inputs["input_values"]

In [None]:
input_values = preprocess_for_inference(trimmed_audio)
input_values[0]

In [10]:
%%time
with torch.no_grad():
    outputs = model(input_values)
    logits = outputs.logits
    predicted_class = logits.argmax(dim=1).item()

print("Predicted class:", predicted_class)

Predicted class: 0
CPU times: total: 11.9 s
Wall time: 5.06 s


In [8]:
%%time

answer = {'siren': 0, 'gunshot': 1, 'explosion': 2, 'casual': 3}
label_name = [k for k, v in answer.items() if v == predicted_class][0]
print(label_name)

siren
