In [1]:
import json
import random
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
import torchaudio
from onnxruntime.quantization import CalibrationDataReader, quantize_static
from onnxruntime.quantization.shape_inference import quant_pre_process
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from streamsad.feature_extractor import FeatureExtractor

In [2]:
@dataclass
class Config:
    """Settings for the calibration dataset (paths, framing, augmentation).

    Every attribute is annotated so that ``@dataclass`` registers it as a real
    field: ``Config(fs=8000)`` builds an overridden copy, and ``repr``/``==``
    reflect the values. The defaults are still reachable on the class itself
    (``Config.fs``), which is how the rest of the notebook reads them.

    NOTE(review): the paths below are absolute, machine-specific locations;
    override them when running elsewhere.
    """

    # data path
    sad_base_path: str = "/home/aj/mahsan/AVA/cropped_wave"
    sad_json_path: str = "/home/aj/repo/SAD-AVA/ava_labels_exsiting_files.json"
    noise_base_path: str = "/home/aj/additive_noise"
    noise_json_path: str = "/home/aj/repo/SAD-AVA/music_singking_files.json"

    # features
    duration: float = 60.0  # crop length; multiplied by fs to get a sample count
    fs: int = 16000  # sample rate used to convert sample indices to seconds
    n_fft: int = 512  # frame length in samples
    n_hop: int = 512  # equal to n_fft, i.e. non-overlapping frames
    feature_epsilon: float = 1e-6  # added to the power spectrum before log10

    # augmentation
    augment_prob: float = 0.4  # probability of mixing a noise clip into a sample
    min_amplitude_percent: int = 10  # lower bound of the noise gain, in percent
    max_amplitude_percent: int = 100  # upper bound of the noise gain, in percent


class RandomCrop:
    """Cut a random, fixed-length window out of a ``(channels, samples)`` tensor.

    Args:
        num_samples: Window length in samples. When omitted it is derived from
            the notebook configuration as ``int(Config.fs * Config.duration)``.
    """

    def __init__(self, num_samples: "int | None" = None) -> None:
        if num_samples is None:
            num_samples = int(Config.fs * Config.duration)
        self.num_samples = num_samples

    def crop(self, x: torch.Tensor) -> tuple[torch.Tensor, int]:
        """Return ``(window, offset)`` for a uniformly random start position.

        Args:
            x: 2-D tensor; the window is taken along the second dimension.

        Returns:
            A view of ``x`` with ``num_samples`` columns and the start offset
            (in samples) at which it was taken.

        Raises:
            ValueError: If ``x`` has fewer than ``num_samples`` columns.
        """
        _, samples = x.size()
        max_offset = samples - self.num_samples
        if max_offset < 0:
            raise ValueError(
                f"cannot crop {self.num_samples} samples from an input of "
                f"only {samples} samples"
            )
        # torch.randint's upper bound is exclusive, so +1 makes every valid
        # start (0..max_offset) reachable and keeps the call legal when the
        # input is exactly num_samples long.
        offset = int(torch.randint(0, max_offset + 1, (1,)).item())
        return x[:, offset : offset + self.num_samples], offset


class FeatureExtractor(nn.Module):
    """Log-power spectrum over non-overlapping Hann-windowed frames.

    NOTE(review): this definition shadows the ``FeatureExtractor`` imported
    from ``streamsad.feature_extractor`` at the top of the notebook; every
    later use of the name resolves to this class.

    Args:
        n_fft: Frame length in samples. Defaults to ``Config.n_fft``.
        feature_epsilon: Constant added to the power spectrum before the
            logarithm. Defaults to ``Config.feature_epsilon``.
    """

    def __init__(self, n_fft=None, feature_epsilon=None):
        super().__init__()
        self.n_fft = Config.n_fft if n_fft is None else n_fft
        self.feature_epsilon = (
            Config.feature_epsilon if feature_epsilon is None else feature_epsilon
        )
        self.window = np.hanning(self.n_fft)

    def compute_fft(self, x_np):
        """Return the ``(n_fft // 2 + 1, num_frames)`` log10 power spectrogram.

        ``x_np`` is a 1-D array. It is split into consecutive frames of
        ``n_fft`` samples with no overlap; trailing samples that do not fill a
        whole frame are dropped.
        """
        n_fft = self.n_fft
        num_frames = x_np.shape[0] // n_fft
        if num_frames == 0:
            # Keep the frequency axis so callers always get a 2-D result.
            return np.empty((n_fft // 2 + 1, 0), dtype=np.float64)
        frames = x_np[: num_frames * n_fft].reshape(num_frames, n_fft) * self.window
        spectrum = np.fft.rfft(frames, axis=-1)
        # |X|^2; the product with the conjugate is already real and >= 0.
        power = (spectrum * spectrum.conj()).real
        return np.log10(power + self.feature_epsilon).T

    def forward(self, x):
        """Map a ``1 x T`` waveform tensor to a ``1 x bins x frames`` float tensor."""
        # x of the shape 1xT
        x_np = x.detach().cpu().reshape(-1).numpy()
        fft_frames_real = self.compute_fft(x_np)
        return torch.from_numpy(fft_frames_real).unsqueeze(0).float()


class AVADS(Dataset):
    """Speech-activity dataset: random crops with gain and additive-noise augmentation.

    All speech and noise waveforms listed in the two JSON files are loaded into
    memory at construction time. Each item is a random ``config.duration``-long
    crop of one recording together with its log-spectral features and one
    speech/non-speech label per feature frame.

    Args:
        config: Object exposing the attributes read below (``sad_base_path``,
            ``sad_json_path``, ``noise_base_path``, ``noise_json_path``,
            ``n_fft``, ``fs``, ``augment_prob``, ``min_amplitude_percent``,
            ``max_amplitude_percent``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.sad_base_path = Path(self.config.sad_base_path)
        with open(config.sad_json_path) as f:
            self.sad = json.load(f)
        self.sad_wavs = [
            torchaudio.load(self.sad_base_path / i["filename"])[0]
            for i in tqdm(self.sad)
        ]
        self.noise_base_path = Path(self.config.noise_base_path)
        with open(config.noise_json_path) as f:
            self.noise = json.load(f)
        self.noise_wavs = [
            torchaudio.load(self.noise_base_path / i)[0] for i in tqdm(self.noise)
        ]
        self.feature_extractor = FeatureExtractor()
        self.random_crop = RandomCrop()

    def _has_overlap(self, interval0: dict, interval1: dict) -> int:
        """Return 1 if the two ``{"start", "end"}`` intervals overlap, else 0.

        Intervals that merely touch at an endpoint do not count as overlapping;
        identical intervals and intervals that share a start or an end do.
        """
        return int(
            interval0["start"] < interval1["end"]
            and interval1["start"] < interval0["end"]
        )

    def _create_label(
        self, labels: list, offset: int, num_samples: int
    ) -> list[int]:
        """Tag each analysis frame of the crop as speech (1) or non-speech (0).

        Args:
            labels: Speech segments as dicts with ``"start"``/``"end"`` in
                seconds. Assumed to be sorted by time (not validated here).
            offset: First sample of the crop within the full recording.
            num_samples: Length of the crop in samples.

        Returns:
            One 0/1 entry per non-overlapping ``n_fft``-sample frame that fits
            completely inside the crop.
        """
        win_len = self.config.n_fft
        hop_len = self.config.n_fft
        fs = self.config.fs
        crop_end = offset + num_samples
        # Sentinel that never overlaps a frame; used when labels run out.
        no_speech = {"start": torch.inf, "end": torch.inf}
        labels_gen = iter(labels)
        label = next(labels_gen, no_speech)
        result = []
        frame_start = offset
        frame_end = offset + win_len
        while frame_end <= crop_end:
            frame_start_s = frame_start / fs
            # Skip segments that end at or before this frame starts: they can
            # no longer overlap this or any later frame.
            while label["end"] <= frame_start_s:
                label = next(labels_gen, no_speech)
            # tag the frame
            interval0 = {"start": frame_start_s, "end": frame_end / fs}
            result.append(self._has_overlap(interval0, label))
            # step forward
            frame_start += hop_len
            frame_end += hop_len
        return result

    def __len__(self) -> int:
        return len(self.sad)

    def __getitem__(
        self, index: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(audio, features, target)`` for one randomly cropped recording.

        The crop is scaled by a random gain in [0.1, 2] and, with probability
        ``config.augment_prob``, mixed with a randomly chosen, randomly scaled
        noise clip. All augmentation is done out of place so the waveforms
        cached in ``sad_wavs``/``noise_wavs`` are never modified.
        """
        sample = self.sad[index]
        labels = sample["vad"]
        # input
        x, offset = self.random_crop.crop(self.sad_wavs[index])
        # amplitude augmentation (x is a view of the cached waveform, so do
        # not use in-place operators here)
        x = x * random.uniform(0.1, 2)
        if torch.rand(1).item() <= self.config.augment_prob:
            noise = random.choice(self.noise_wavs)
            if noise.size(1) == 0:
                raise ValueError("noise clip is empty and cannot be tiled")
            # Tile short clips until they are longer than the crop.
            while noise.size(1) <= x.size(1):
                noise = torch.cat([noise, noise], dim=1)
            noise, _ = self.random_crop.crop(noise)
            # randint's upper bound is exclusive; +1 makes the configured
            # maximum reachable and keeps min == max valid.
            noise_gain = (
                torch.randint(
                    self.config.min_amplitude_percent,
                    self.config.max_amplitude_percent + 1,
                    size=(1,),
                )
                / 100
            )
            x = x + noise * noise_gain
        num_samples = x.size(1)
        lmfb = self.feature_extractor(x).squeeze(0)
        # target
        target = self._create_label(labels, offset, num_samples)
        # check size: pad with the last label (or 0) up to the feature length
        missing = lmfb.size(1) - len(target)
        if missing > 0:
            target.extend([target[-1] if target else 0] * missing)
        target = torch.LongTensor(target)
        return x, lmfb, target


# Load every recording into memory once, then serve shuffled single-item
# batches from four worker processes.
sad_dataset = AVADS(Config)
sad_dataloader = DataLoader(
    dataset=sad_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
)

100%|██████████| 119/119 [00:07<00:00, 15.61it/s]
100%|██████████| 1300/1300 [00:13<00:00, 97.52it/s]


In [None]:
class SADCalibrationDataReader(CalibrationDataReader):
    """Feed batches from a dataloader to the static quantizer as calibration data.

    Each dataloader batch is expected to be indexable, with the feature tensor
    at position 1. Every calibration sample pairs those features with an
    all-zero state array.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.dl_iter = None
        # Start with a fresh iterator so get_next() is usable immediately.
        self.rewind()

    def get_state(self, x):
        """Return a zero float32 array of shape ``(x.size(0), 1, 64)``."""
        return np.zeros((x.size(0), 1, 64), dtype=np.float32)

    def get_next(self):
        """Return the next ``{"input", "input_state"}`` feed, or None when exhausted."""
        batch = next(self.dl_iter, None)
        if batch is None:
            return None
        # batch is a tuple consisting of: (audio, lmfb, target)
        features = batch[1]
        return {
            "input": features.numpy(),
            "input_state": self.get_state(features),
        }

    def rewind(self):
        """Restart iteration over the dataloader from the beginning."""
        self.dl_iter = iter(self.dataloader)

In [10]:
# Wrap the shuffled dataloader so the quantizer can pull calibration feeds from it.
sad_calibration_data_reader = SADCalibrationDataReader(sad_dataloader)

# Optional sanity check (disabled): pull one calibration feed, run it through the
# fp32 model, and inspect the input shapes/dtypes before quantizing.
# NOTE(review): if re-enabled, this consumes one batch from the reader; call
# sad_calibration_data_reader.rewind() afterwards so calibration sees all data.
# input_dict = sad_calibration_data_reader.get_next()
# model_fp32 = "../src/streamsad/models/model_2025-06-10.onnx"
# ort_session_fp32 = ort.InferenceSession(model_fp32)
# raw_output_fp32, _ = ort_session_fp32.run(
#     None,
#     input_dict,
# )
# input_dict["input"].shape, input_dict["input_state"].shape, input_dict.keys, input_dict["input"].dtype, input_dict["input_state"].dtype

In [11]:
# Model files: original fp32 export -> pre-processed graph -> static int8 model.
model_fp32 = "../src/streamsad/models/model_2025-06-10.onnx"
model_pre = "model_2025-06-10_pre.onnx"
model_int8_static = "model_2025-06-10_static_int8.onnx"

# The quantizer consumes the pre-processed graph. Create it from the fp32 model
# when it is not already on disk so the notebook runs from a fresh checkout;
# an existing file is left untouched.
if not Path(model_pre).exists():
    quant_pre_process(model_fp32, model_pre)

# Weights are always quantized symmetrically; activations are symmetric only
# when a CUDA device is available.
q_static_opts = {
    "ActivationSymmetric": torch.cuda.is_available(),
    "WeightSymmetric": True,
}

# Restart the reader so re-running this cell (or running it after the sanity
# check above has pulled a batch) calibrates on the full dataloader again.
sad_calibration_data_reader.rewind()

# The quantized model is written to `model_int8_static`.
# NOTE(review): quantize_static appears to return None, so `quantized_model`
# does not hold the model itself — load `model_int8_static` instead; confirm
# against the installed onnxruntime version.
quantized_model = quantize_static(
    model_input=model_pre,
    model_output=model_int8_static,
    calibration_data_reader=sad_calibration_data_reader,
    extra_options=q_static_opts,
)