<a href="https://colab.research.google.com/github/10udCryp7/TV-command-synthesis/blob/main/src_prototype/Phase4_NoiseAugmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import random
import numpy as np
import torch
import torchaudio
from copy import deepcopy
import json
from typing import List, Dict

class NoiseAugmentor:
    """Augment labelled speech clips with noise drawn from a pool of ``.wav`` files.

    Three augmentation modes are supported:

    * ``"prepend"`` - a noise segment is concatenated before the speech and
      every existing label is shifted by the segment duration.
    * ``"append"``  - a noise segment is concatenated after the speech.
    * ``"overlap"`` - a noise segment is scaled to a target SNR and added on
      top of part of the speech.

    Labels are dicts with the keys ``"label"``, ``"start"`` and ``"end"``
    (start/end are times in seconds).
    """

    # Modes accepted by ``augment``.
    _MODES = ("prepend", "append", "overlap")

    def __init__(self,
                 noise_folders,
                 target_sr=16000,
                 snr_range=(-5, 20),
                 prepend_range_s=(2, 5),   # seconds
                 append_range_s=(2, 5),    # seconds
                 overlap_range_s=(2, 5),   # seconds
                 seed=None):
        """Collect the noise files and store the sampling configuration.

        Args:
            noise_folders: iterable of directories; every entry whose name ends
                with ".wav" (non-recursive) becomes a candidate noise file.
            target_sr: sample rate every loaded clip is resampled to.
            snr_range: (low, high) bounds in dB for the sampled SNR.
            prepend_range_s: (low, high) duration bounds, in seconds, of the
                noise placed before the speech.
            append_range_s: (low, high) duration bounds, in seconds, of the
                noise placed after the speech.
            overlap_range_s: (low, high) duration bounds, in seconds, of the
                noise mixed on top of the speech.
            seed: when not None, seeds the *global* ``random``, ``numpy`` and
                ``torch`` generators.

        Raises:
            AssertionError: if no ".wav" file is found in ``noise_folders``.
        """
        self.noise_files = []
        for folder in noise_folders:
            # os.listdir returns entries in arbitrary order; sort so that a
            # given seed always selects the same noise files.
            for f in sorted(os.listdir(folder)):
                if f.endswith(".wav"):
                    self.noise_files.append(os.path.join(folder, f))
        # Explicit raise (instead of ``assert``) so the check survives ``python -O``;
        # the exception type is kept for existing callers.
        if not self.noise_files:
            raise AssertionError("No noise files found!")

        self.sr = target_sr
        self.snr_range = snr_range
        self.prepend_range_s = prepend_range_s
        self.append_range_s = append_range_s
        self.overlap_range_s = overlap_range_s
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    def _sec2samples(self, sec):
        """Convert a duration in seconds to a (truncated) sample count."""
        return int(sec * self.sr)

    def _load_audio(self, path):
        """Load ``path`` as a 1-D mono tensor at ``self.sr``."""
        wav, sr = torchaudio.load(path)
        wav = wav.mean(dim=0, keepdim=True)  # mono
        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)
        return wav.squeeze(0)

    def _scale_noise(self, clean, noise, snr_db):
        """Return ``noise`` scaled so that mean(clean^2) / mean(noise^2) matches ``snr_db``.

        Both powers are means over the full tensors that are passed in.
        """
        Px = clean.pow(2).mean()
        Pn = noise.pow(2).mean() + 1e-12  # epsilon avoids division by zero on silent noise
        k = torch.sqrt(Px / (Pn * (10**(snr_db/10))))
        return noise * k

    def _sample_noise_file(self):
        """Pick a random noise file and return ``(waveform, path)``."""
        f = random.choice(self.noise_files)
        return self._load_audio(f), f

    def _sample_noise_segment(self, noise, dur_samples):
        """Return exactly ``dur_samples`` samples taken from ``noise``.

        A clip that is not longer than the request is tiled and truncated;
        otherwise a random window is cut out.

        Raises:
            ValueError: if ``noise`` contains no samples.
        """
        n = noise.shape[-1]
        if n == 0:
            raise ValueError("Cannot sample a segment from an empty noise clip")
        if n <= dur_samples:
            # noise too short -> repeat it until it covers the request
            reps = (dur_samples // n) + 1
            return noise.repeat(reps)[:dur_samples]
        start = random.randint(0, n - dur_samples)
        return noise[start:start+dur_samples]

    def _shift_labels(self, labels, shift_s):
        """Return new label dicts with start/end moved by ``shift_s`` seconds.

        Only the "label", "start" and "end" keys are carried over.
        """
        return [
            {
                "label": l["label"],
                "start": l["start"] + shift_s,
                "end": l["end"] + shift_s,
            }
            for l in labels
        ]

    def _append_noise_label(self, labels, start_s, end_s):
        """Return a deep copy of ``labels`` with one extra "noise" segment."""
        labels = deepcopy(labels)
        labels.append({"label": "noise", "start": start_s, "end": end_s})
        return labels

    def _merge_overlap_labels(self, labels, noise_start, noise_end):
        """Add "noise" labels for the parts of the noise not covered by "active" segments.

        The noise interval ``[noise_start, noise_end]`` is only labelled when it
        overlaps at least one "active" segment; in that case every gap of the
        interval that lies outside the overlapping "active" segments becomes a
        "noise" label. Noise that touches no "active" segment adds nothing.

        Returns:
            A deep copy of ``labels`` followed by the new "noise" segments.
        """
        new_labels = deepcopy(labels)
        overlapping = sorted(
            (l["start"], l["end"])
            for l in labels
            if l["label"] == "active"
            and noise_end > l["start"]
            and noise_start < l["end"]
        )
        if not overlapping:
            return new_labels

        # Sweep left to right, emitting the uncovered gaps. This keeps noise
        # labels from covering a second active segment when the noise spans
        # several of them.
        cursor = noise_start
        for start, end in overlapping:
            if start > cursor:
                new_labels.append({"label": "noise", "start": cursor, "end": start})
            cursor = max(cursor, end)
        if cursor < noise_end:
            new_labels.append({"label": "noise", "start": cursor, "end": noise_end})
        return new_labels

    def _load_labels(self, json_path: str) -> List[Dict]:
        """
        Read a JSON label file (a list of segments) and return a list of dicts.

        Segments that lack any of the required keys are dropped.

        Args:
            json_path (str): path to the .json file

        Returns:
            List[Dict]: label segments, each with the keys "label", "start", "end"
        """
        with open(json_path, "r", encoding="utf-8") as f:
            labels = json.load(f)

        # keep only well-formed segments and normalise their types
        cleaned = []
        for seg in labels:
            if all(k in seg for k in ["label", "start", "end"]):
                cleaned.append({
                    "label": str(seg["label"]),
                    "start": float(seg["start"]),
                    "end": float(seg["end"])
                })
        return cleaned

    def _mix(self, clean, labels, noise, mode, snr_db):
        """Apply ``mode`` to in-memory data and return ``(audio, new_labels)``.

        Args:
            clean: 1-D speech tensor at ``self.sr``.
            labels: label dicts belonging to ``clean``.
            noise: 1-D noise tensor at ``self.sr``.
            mode: one of "prepend", "append", "overlap".
            snr_db: target SNR in dB (only used by "overlap").

        Raises:
            ValueError: if ``mode`` is not a supported mode.
        """
        speech_len = clean.shape[-1]
        speech_dur = speech_len / self.sr

        # NOTE(review): prepend/append concatenate the noise at its original
        # level; snr_db is not applied there - confirm this is intended.
        if mode == "prepend":
            dur_s = random.uniform(*self.prepend_range_s)
            dur_samples = self._sec2samples(dur_s)
            seg = self._sample_noise_segment(noise, dur_samples)
            out = torch.cat([seg, clean])
            new_labels = self._shift_labels(labels, dur_s)
            new_labels = self._append_noise_label(new_labels, 0, dur_s)

        elif mode == "append":
            dur_s = random.uniform(*self.append_range_s)
            dur_samples = self._sec2samples(dur_s)
            seg = self._sample_noise_segment(noise, dur_samples)
            out = torch.cat([clean, seg])
            new_labels = self._append_noise_label(labels, speech_dur, speech_dur + dur_s)

        elif mode == "overlap":
            dur_s = random.uniform(*self.overlap_range_s)
            # the noise can never be longer than the speech it is mixed into
            dur_samples = min(self._sec2samples(dur_s), speech_len)
            offset = random.randint(0, speech_len - dur_samples)
            seg = self._sample_noise_segment(noise, dur_samples)

            noise_aligned = torch.zeros_like(clean)
            noise_aligned[offset:offset+dur_samples] = seg
            # NOTE(review): the noise power is averaged over the whole
            # zero-padded clip, so snr_db is a whole-utterance SNR.
            noise_scaled = self._scale_noise(clean, noise_aligned, snr_db)
            out = clean + noise_scaled

            # Derive the label span from the samples actually written so a
            # clamped duration cannot produce a label past the end of the audio.
            noise_start = offset / self.sr
            noise_end = (offset + dur_samples) / self.sr
            new_labels = self._merge_overlap_labels(labels, noise_start, noise_end)

        else:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self._MODES}")

        return out, new_labels

    def augment(self, clean_path, labels_path, mode=None):
        """Augment one speech file with a randomly chosen noise file.

        Args:
            clean_path: path of the speech audio file.
            labels_path: path of the JSON label file for ``clean_path``.
            mode: "prepend", "append" or "overlap"; when None a mode is drawn
                at random, with "overlap" three times as likely as each of the
                other two.

        Returns:
            ``(audio, metadata)`` where ``audio`` is a 1-D tensor and
            ``metadata`` holds "clean_path", "noise_path", "mode", "snr_db"
            and the updated "labels".

        Raises:
            ValueError: if ``mode`` is given but not supported.
        """
        # Validate before doing any I/O so a typo fails fast and clearly.
        if mode is not None and mode not in self._MODES:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self._MODES}")

        labels = self._load_labels(labels_path)
        clean = self._load_audio(clean_path)

        noise, noise_path = self._sample_noise_file()
        if mode is None:
            mode = random.choice(["overlap","prepend", "overlap", "append", "overlap"])
        snr_db = random.uniform(*self.snr_range)

        out, new_labels = self._mix(clean, labels, noise, mode, snr_db)

        metadata = {
            "clean_path": clean_path,
            "noise_path": noise_path,
            "mode": mode,
            "snr_db": snr_db,
            "labels": new_labels
        }
        return out, metadata


In [3]:
!gdown 1ri2VCkL9gkwvgxstEbE6zTqS1GpgSXZy

Downloading...
From (original): https://drive.google.com/uc?id=1ri2VCkL9gkwvgxstEbE6zTqS1GpgSXZy
From (redirected): https://drive.google.com/uc?id=1ri2VCkL9gkwvgxstEbE6zTqS1GpgSXZy&confirm=t&uuid=4e914a25-dad3-45e5-b84b-4499570e2eff
To: /content/sample-5000-vctk-concat.zip
100% 413M/413M [00:09<00:00, 45.8MB/s]


In [4]:
!unzip -q sample-5000-vctk-concat.zip

In [5]:
!wget -O musan.tar.gz https://www.openslr.org/resources/17/musan.tar.gz

--2025-08-28 03:36:33--  https://www.openslr.org/resources/17/musan.tar.gz
Resolving www.openslr.org (www.openslr.org)... 136.243.171.4
Connecting to www.openslr.org (www.openslr.org)|136.243.171.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11086114085 (10G) [application/x-gzip]
Saving to: ‘musan.tar.gz’


2025-08-28 03:49:23 (13.8 MB/s) - ‘musan.tar.gz’ saved [11086114085/11086114085]



In [6]:
!tar -xzf musan.tar.gz

In [None]:
# Directories handed to NoiseAugmentor as the noise pool (MUSAN speech subsets).
# Only ".wav" files directly inside each directory are picked up.
NOISE_FOLDERS = [
    "/content/musan/speech/us-gov",
    "/content/musan/speech/librivox"
]

In [None]:
import os
import json
import soundfile as sf
from tqdm import tqdm  # thêm tqdm

# Default root directory for the augmented audio and metadata.
OUTPUT_DIR = "/content/concat_speech/augmented"


def main(input_root="/content/concat_speech/sample", output_dir=None):
    """Augment every sample folder under ``input_root`` and write the results.

    Each sub-folder ``<name>`` is expected to contain ``<name>_concat.wav`` and
    ``<name>.json``. For every such folder the function writes
    ``<output_dir>/<name>/<name>_aug.wav`` and ``<name>_aug.json``.

    Args:
        input_root: directory holding one sub-folder per sample.
        output_dir: destination root; defaults to the module-level
            ``OUTPUT_DIR`` (looked up at call time).
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR

    na = NoiseAugmentor(
        noise_folders=NOISE_FOLDERS
    )

    # tqdm gives a progress bar over the sample folders
    for file_folder in tqdm(os.listdir(input_root), desc="Augmenting"):
        folder_path = os.path.join(input_root, file_folder)
        if not os.path.isdir(folder_path):
            continue

        audio_name = f"{file_folder}_concat.wav"
        json_name = f"{file_folder}.json"

        file_path = os.path.join(folder_path, audio_name)
        label_path = os.path.join(folder_path, json_name)

        # One incomplete folder should not abort the whole batch run.
        if not (os.path.isfile(file_path) and os.path.isfile(label_path)):
            tqdm.write(f"Skipping {file_folder}: missing {audio_name} or {json_name}")
            continue

        # augment
        out, meta = na.augment(
            clean_path=file_path,
            labels_path=label_path,
        )

        # make output folder
        out_folder = os.path.join(output_dir, file_folder)
        os.makedirs(out_folder, exist_ok=True)

        # export audio (soundfile expects [T, C] or [T,])
        # NOTE(review): soundfile's default WAV subtype is 16-bit PCM, so mixed
        # samples outside [-1, 1] would be clipped on write - confirm acceptable.
        out_audio_path = os.path.join(out_folder, audio_name.replace("_concat.wav", "_aug.wav"))
        sf.write(out_audio_path, out.squeeze(0).numpy().T, na.sr)

        # export metadata
        out_json_path = os.path.join(out_folder, json_name.replace(".json", "_aug.json"))
        with open(out_json_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2, ensure_ascii=False)


In [None]:
# Run the augmentation over every sample folder; results are written under OUTPUT_DIR.
main()

Augmenting: 100%|██████████| 3021/3021 [04:09<00:00, 12.11it/s]


In [None]:
!zip aug_speech.zip concat_speech/augmented -r

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  adding: concat_speech/augmented/chain_active_ba6bb8b3/chain_active_ba6bb8b3_aug.wav (deflated 16%)
  adding: concat_speech/augmented/chain_active_ba6bb8b3/chain_active_ba6bb8b3_aug.json (deflated 42%)
  adding: concat_speech/augmented/single_active_8b7f03b1/ (stored 0%)
  adding: concat_speech/augmented/single_active_8b7f03b1/single_active_8b7f03b1_aug.json (deflated 50%)
  adding: concat_speech/augmented/single_active_8b7f03b1/single_active_8b7f03b1_aug.wav (deflated 17%)
  adding: concat_speech/augmented/non_active_da020fa8/ (stored 0%)
  adding: concat_speech/augmented/non_active_da020fa8/non_active_da020fa8_aug.json (deflated 42%)
  adding: concat_speech/augmented/non_active_da020fa8/non_active_da020fa8_aug.wav (deflated 13%)
  adding: concat_speech/augmented/single_active_af39ead4/ (stored 0%)
  adding: concat_speech/augmented/single_active_af39ead4/single_active_af39ead4_aug.json (deflated 48%)
  adding: concat_sp