<a href="https://colab.research.google.com/github/dungdt-infopstats/Device-Directed-Speech-Segmentation/blob/main/src_prototype/DDSS_Test_Experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Authenticate with Kaggle before the download cell below runs; the private
# datasets/models it requests presumably cannot be fetched without this login.
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Each call fetches one Kaggle dataset/model and binds the value returned by
# kagglehub to a module-level name.
# NOTE(review): the rest of this notebook uses hard-coded "/kaggle/input/..."
# paths instead of these variables — confirm the two agree in your environment.
tridungdo_test_ratio_res_analysis_path = kagglehub.dataset_download('tridungdo/test-ratio-res-analysis')
tridungdo_test_extend_ratio_speaker_split_path = kagglehub.dataset_download('tridungdo/test-extend-ratio-speaker-split')
tridungdo_test_100_150_path = kagglehub.dataset_download('tridungdo/test-100-150')
tridungdo_merge_raw_speech_wer_100_150_path = kagglehub.dataset_download('tridungdo/merge-raw-speech-wer-100-150')
tridungdo_test_infer_100_150_old_model_json_path = kagglehub.dataset_download('tridungdo/test-infer-100-150-old-model-json')
# Model checkpoints (handle format: owner/model/framework/variation/version).
tridungdo_model_split_pytorch_default_1_path = kagglehub.model_download('tridungdo/model-split/PyTorch/default/1')
tridungdo_wav2vec2_asr_100_150_pytorch_default_1_path = kagglehub.model_download('tridungdo/wav2vec2-asr-100-150/PyTorch/default/1')
tridungdo_wav2vec2_base_100_150_pytorch_default_1_path = kagglehub.model_download('tridungdo/wav2vec2-base-100-150/PyTorch/default/1')

print('Data source import complete.')


# Model

In [None]:
import os
import json
import pandas as pd
from tqdm import tqdm

def get_info(root_path):
    """Load every ``.json`` file under ``root_path`` into one DataFrame.

    The directory tree is walked recursively. Each parsed JSON object becomes
    one row, and its ``'id'`` entry is set to the file name without extension
    (overwriting any ``'id'`` already present in the file).

    Args:
        root_path: Directory to search.

    Returns:
        pd.DataFrame with one row per JSON file.
    """
    # Gather the paths up front so tqdm knows the total and can show an
    # accurate progress bar.
    json_paths = []
    for folder, _, names in os.walk(root_path):
        json_paths.extend(
            os.path.join(folder, name) for name in names if name.endswith('.json')
        )

    rows = []
    for json_path in tqdm(json_paths, desc="Processing JSON files"):
        with open(json_path, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        # The file stem doubles as the record identifier.
        stem, _ = os.path.splitext(os.path.basename(json_path))
        payload['id'] = stem
        rows.append(payload)

    return pd.DataFrame(rows)

In [None]:
import torch
import torch.nn as nn
import torchaudio


class DDSSModel(nn.Module):
    """Frame-level binary classifier on top of a frozen Wav2Vec2 encoder.

    Pipeline: pretrained Wav2Vec2 feature extractor (parameters frozen) ->
    (bi)LSTM -> small MLP that emits one logit per encoder frame.

    Args:
        hidden_dim: Hidden size of the LSTM (per direction).
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM is bidirectional.
        dropout: Dropout probability used between LSTM layers and in the MLP.
        bundle_name: Which torchaudio pipeline to load: ``'asr'`` for
            ``WAV2VEC2_ASR_BASE_960H`` or ``'base'`` for ``WAV2VEC2_BASE``.

    Raises:
        ValueError: If ``bundle_name`` is not ``'asr'`` or ``'base'``.
    """

    def __init__(self, hidden_dim=256, num_layers=2, bidirectional=True, dropout=0.3, bundle_name = 'asr'):
        super(DDSSModel, self).__init__()

        # Base model: Wav2Vec2. Reject unknown names explicitly; previously an
        # unknown name left `bundle` unbound and failed with a confusing
        # UnboundLocalError further down.
        if bundle_name == 'asr':
            bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
        elif bundle_name == 'base':
            bundle = torchaudio.pipelines.WAV2VEC2_BASE
        else:
            raise ValueError(
                f"Unknown bundle_name {bundle_name!r}; expected 'asr' or 'base'."
            )
        self.feature_extractor = bundle.get_model()

        # Freeze Wav2Vec2 so only the LSTM and the classifier are trainable.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # LSTM over the encoder features.
        # NOTE(review): `_params` is a private attribute of the torchaudio
        # bundle; it may change between torchaudio versions.
        self.lstm = nn.LSTM(
            input_size=bundle._params['encoder_embed_dim'],
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional
        )

        lstm_output_dim = hidden_dim * (2 if bidirectional else 1)

        # Frame-level classifier (an MLP rather than a single linear layer).
        self.classifier = nn.Sequential(
            nn.Linear(lstm_output_dim, lstm_output_dim // 2),
            nn.BatchNorm1d(lstm_output_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(lstm_output_dim // 2, lstm_output_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(lstm_output_dim // 4, 1)  # binary logit
        )

    def forward(self, waveforms):
        """Compute one logit per encoder frame.

        Args:
            waveforms: Tensor [B, T_audio] (mono audio, 16 kHz).

        Returns:
            Tensor [B, T_feat] of raw (pre-sigmoid) logits.
        """
        # The encoder is frozen, so skip building its autograd graph.
        with torch.no_grad():
            features, _ = self.feature_extractor.extract_features(waveforms)
            x = features[-1]  # last layer's features: [B, T_feat, F]

        # LSTM
        x, _ = self.lstm(x)  # [B, T_feat, H]

        # Classify every frame: flatten batch and time so the MLP (and its
        # BatchNorm1d) sees a 2-D input, then restore the [B, T] layout.
        B, T, H = x.shape
        x = x.reshape(-1, H)          # [B*T, H]
        out = self.classifier(x)      # [B*T, 1]
        out = out.view(B, T)          # [B, T]
        return out

# Dataset

In [None]:
import os
import json
import pandas as pd
from tqdm import tqdm

def collect_json_to_df(root_dir: str, flatten: bool = True) -> pd.DataFrame:
    """Collect every JSON annotation file under ``root_dir`` into a DataFrame.

    Args:
        root_dir (str): Root directory; it is searched recursively.
        flatten (bool): If True, each entry of the JSON's ``'labels'`` list
            becomes its own row (columns ``id``, ``id_detail``, ``audio_path``,
            ``label``, ``start``, ``end``). If False, one row per file is
            produced and the whole parsed JSON object is stored in the
            ``labels`` column.

    Returns:
        pd.DataFrame

    Raises:
        KeyError: If ``flatten`` is True and a JSON file has no ``'labels'`` key.
    """
    records = []

    # Collect the paths first so tqdm can display an accurate total.
    json_files = []
    for dirpath, _, filenames in os.walk(root_dir):
        json_files.extend(
            os.path.join(dirpath, name) for name in filenames if name.endswith(".json")
        )

    for json_path in tqdm(json_files, desc="Processing JSON files"):
        dirpath, file = os.path.split(json_path)
        folder_name = os.path.basename(dirpath)

        # Strip only the final extension. The previous `split('.')[0]` cut the
        # id at the first dot and `replace('.json', ...)` rewrote every
        # occurrence, so names containing dots produced wrong/colliding values.
        # splitext also matches how get_info derives its 'id'.
        stem = os.path.splitext(file)[0]

        # The audio sits next to the JSON and shares the JSON's file stem.
        audio_path = os.path.join(dirpath, stem + ".wav")

        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if flatten:
            for entry in data['labels']:
                records.append({
                    "id": folder_name,
                    "id_detail": stem,
                    "audio_path": audio_path,
                    "label": entry.get("label"),
                    "start": entry.get("start"),
                    "end": entry.get("end")
                })
        else:
            records.append({
                "id": folder_name,
                "audio_path": audio_path,
                "labels": data
            })

    return pd.DataFrame(records)


In [None]:
import torch
from torch.utils.data import Dataset
import torchaudio

def frame_wav(wav, frame_size, hop_size):
    """Slice a waveform into overlapping fixed-size frames.

    Args:
        wav: Waveform tensor; a leading dimension of size 1 is dropped first.
        frame_size: Number of samples per frame.
        hop_size: Step (in samples) between consecutive frame starts.

    Returns:
        Tensor of shape (num_frames, frame_size); it is a view produced by
        ``Tensor.unfold`` along dimension 0.
    """
    samples = torch.squeeze(wav, 0)  # (N,)
    return samples.unfold(dimension=0, size=frame_size, step=hop_size)

def label_frames(num_frames, frame_size, hop_size, sr, annotations):
    """Build a per-frame 0/1 label vector from time-stamped annotations.

    Args:
        num_frames: Length of the label vector to produce.
        frame_size: Samples per frame (used to find the first frame that
            reaches the segment start).
        hop_size: Samples between consecutive frame starts.
        sr: Sample rate used to convert ``start``/``end`` to samples.
        annotations: DataFrame with ``start``, ``end`` and ``label`` columns.

    Returns:
        Long tensor of shape [num_frames]; frames covered by a row whose
        label is ``'active'`` are 1, everything else stays 0.
    """
    frame_labels = torch.zeros(num_frames, dtype=torch.long)  # default = 0

    for _, segment in annotations.iterrows():
        # Segment boundaries converted to sample indices.
        first_sample = int(segment['start'] * sr)
        last_sample = int(segment['end'] * sr)

        # Map sample positions to frame indices, clamped to the valid range.
        lo = max(0, (first_sample - frame_size) // hop_size + 1)
        hi = min(num_frames, last_sample // hop_size)

        if segment['label'] != 'active':
            continue
        # The end frame is inclusive, hence the +1 on the slice bound.
        frame_labels[lo:hi + 1] = 1

    return frame_labels

class DDSSDataset(Dataset):
    """Dataset yielding one audio file with its frame-level labels per item.

    Args:
        annotations_file: DataFrame with at least ``audio_path``, ``start``,
            ``end`` and ``label`` columns plus the grouping column.
        frame_size: Samples per frame.
        hop_size: Samples between consecutive frame starts.
        target_sr: Sample rate every waveform is resampled to.
        transform: Stored on the instance; not applied anywhere in this class.
        key: Column that identifies one item; falls back to ``'id'`` when None.

    Each item is ``(wav, frames, labels, id_file)`` where ``wav`` is [1, N].
    """

    def __init__(self, annotations_file, frame_size=400, hop_size=160, target_sr=16000, transform=None, key = None):
        self.annotations_file = annotations_file
        self.key = key
        group_col = 'id' if key is None else key
        self.list_id = self.annotations_file[group_col].unique()
        # Positional row indices per id, computed once. This replaces a full
        # boolean scan of the DataFrame on every __getitem__ call.
        self._rows_by_id = self.annotations_file.groupby(group_col, sort=False).indices
        self._group_col = group_col
        self.frame_size = frame_size
        self.hop_size = hop_size
        self.target_sr = target_sr
        self.transform = transform
        # One Resample module per source rate, built lazily and reused.
        self._resamplers = {}

    def __len__(self):
        return len(self.list_id)

    def _rows_for(self, id_file):
        """Return the annotation rows belonging to ``id_file``."""
        positions = self._rows_by_id.get(id_file)
        if positions is None:
            # Ids missing from the index (e.g. NaN keys, which groupby drops)
            # fall back to the plain equality filter.
            mask = self.annotations_file[self._group_col] == id_file
            return self.annotations_file[mask]
        return self.annotations_file.iloc[positions]

    def _prepare_waveform(self, wav, sr):
        """Downmix to mono and resample to ``target_sr``; returns (wav, sr)."""
        # Multi-channel audio previously crashed in frame_wav, because
        # squeeze(0) only removes a channel dimension of size 1.
        if wav.dim() == 2 and wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        if sr != self.target_sr:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(sr, self.target_sr)
                self._resamplers[sr] = resampler
            wav = resampler(wav)
            sr = self.target_sr
        return wav, sr

    def __getitem__(self, idx):
        id_file = self.list_id[idx]
        df_file = self._rows_for(id_file)
        # All rows of one item share the same audio file; take the first.
        audio_path = df_file['audio_path'].iloc[0]
        wav, sr = torchaudio.load(audio_path)

        # mono + resample
        wav, sr = self._prepare_waveform(wav, sr)

        # frame
        frames = frame_wav(wav, self.frame_size, self.hop_size)

        # label sequence, one label per frame
        labels = label_frames(frames.shape[0], self.frame_size, self.hop_size, sr, df_file)

        return wav, frames, labels, id_file

In [None]:
import torch
import torch.nn.functional as F

def label_resampler(labels: torch.Tensor, target_len: int) -> torch.Tensor:
    """Resample a 1-D label sequence to ``target_len`` time steps.

    Args:
        labels: Tensor [num_frames], integer labels (0/1).
        target_len: Number of time steps in the output.

    Returns:
        Tensor [target_len]. When the length already matches, a clone of the
        input is returned with its dtype untouched; otherwise the labels are
        linearly interpolated, rounded and cast to long.

    Note:
        Linear interpolation followed by rounding is only meaningful for
        binary labels; for multi-class indices it can invent intermediate
        classes.
    """
    num_frames = labels.shape[0]

    if num_frames == target_len:
        return labels.clone()

    # interpolate expects [batch, channels, length], so lift to [1, 1, N].
    labels_f = labels.float().unsqueeze(0).unsqueeze(0)  # [1, 1, N]

    labels_resampled = F.interpolate(
        labels_f, size=target_len, mode="linear", align_corners=False
    )  # [1, 1, target_len]

    # Index out the two leading singleton dims explicitly. A bare squeeze()
    # would also drop the time axis when target_len == 1 and return a 0-dim
    # tensor instead of [1].
    return labels_resampled[0, 0].round().long()  # [target_len]

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

# ====== 1) Collate: pad waveform theo batch, giữ độ dài gốc ======
def collate_fn(batch):
    """Pad the waveforms of a batch to a common length.

    Args:
        batch: list of ``(wav, frames, labels)`` or
            ``(wav, frames, labels, id)`` tuples, where
              - wav: Tensor [1, N]
              - frames: Tensor [num_frames, frame_size] (unused here)
              - labels: Tensor [num_frames] (0/1)

    Returns:
        ``(batch_wav, lengths, labels_list)``:
          - batch_wav: Tensor [B, max_len], zero-padded on the right
          - lengths: LongTensor [B] with the original sample counts
          - labels_list: tuple of the per-item label tensors (left unpadded
            so they can be interpolated later)
    """
    # Index by position instead of unpacking a fixed arity: DDSSDataset yields
    # 4-tuples (with a trailing id), which made `a, b, c = zip(*batch)` raise.
    wavs = [item[0] for item in batch]
    labels_list = tuple(item[2] for item in batch)

    # original number of samples of each wav
    lengths = torch.tensor([w.shape[-1] for w in wavs], dtype=torch.long)

    # Right-pad every waveform with zeros up to the longest one in the batch.
    max_len = int(lengths.max().item())
    batch_wav = wavs[0].new_zeros((len(wavs), max_len))  # [B, max_len]
    for row, w in enumerate(wavs):
        batch_wav[row, : w.shape[-1]] = w.squeeze(0)

    return batch_wav, lengths, labels_list  # labels stay a sequence for later interpolation

# Infer

In [None]:
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm

def infer_and_save(model, dataset, threshold=0.5, device="cuda"):
    """Run the model over every item of ``dataset`` and collect the results.

    Despite its name this function writes nothing to disk; it only returns
    the DataFrame.

    Args:
        model: Module mapping a waveform [1, N] to logits [1, T_feat].
        dataset: Iterable of ``(wav, frames, labels, id)`` items.
        threshold: Probability at or above which a frame is predicted as 1.
        device: Device on which inference runs.

    Returns:
        pd.DataFrame with one row per item and columns ``outputs``
        (per-frame probabilities, numpy array), ``preds`` (0/1 numpy array),
        ``labels`` (numpy array), ``text`` (always None) and ``id``.
    """
    model.eval()
    model.to(device)
    records = []
    with torch.no_grad():
        for wav, frames, labels, id_f in tqdm(dataset, desc="Inferencing", unit="batch"):
            wav = wav.to(device)
            logits = model(wav)  # [1, T_feat]

            # Bring the logits to the label resolution. Interpolate the raw
            # float logits directly: routing them through label_resampler
            # rounded them to integers (and cast to long) before the sigmoid,
            # which quantised the probabilities and flipped predictions for
            # logits in [-0.5, 0).
            target_len = labels.shape[-1]
            if logits.shape[-1] != target_len:
                logits = torch.nn.functional.interpolate(
                    logits.float().unsqueeze(1),  # [1, 1, T_feat]
                    size=target_len,
                    mode="linear",
                    align_corners=False,
                ).squeeze(1)  # [1, target_len]

            probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
            preds = (probs >= threshold).astype(int)

            records.append({
                "outputs": probs,
                "preds": preds,
                "labels": labels.numpy(),
                "text": None,   # placeholder for now
                "id": id_f
            })

    df = pd.DataFrame(records)
    return df



# Flow

In [None]:
# Root folder of the test dataset; the cells below read its 'test' subfolder.
# NOTE(review): hard-coded Kaggle mount point — outside Kaggle the dataset
# presumably lives at tridungdo_test_100_150_path (set by the kagglehub
# download cell); confirm before running elsewhere.
input_path = "/kaggle/input/test-100-150"

In [None]:
# Load every JSON file under <input_path>/test into one DataFrame (one row per file).
# NOTE(review): df_info is not referenced again in the visible part of this notebook.
df_info = get_info(root_path=os.path.join(input_path, 'test'))

Processing JSON files: 100%|██████████| 23946/23946 [00:14<00:00, 1671.15it/s]


In [None]:
# Flattened annotation table: one row per label entry, with the path of the matching .wav file.
df_test = collect_json_to_df(f'{input_path}/test')

Processing JSON files: 100%|██████████| 23946/23946 [00:14<00:00, 1686.33it/s]


In [None]:
# Group rows by 'id_detail' (the JSON file stem) so each dataset item is one audio file.
test_dataset = DDSSDataset(df_test, key = 'id_detail')

In [None]:
import os
import torch
def get_unique_filename(base_name, ext):
    """Return a file name that does not exist yet.

    ``<base_name><ext>`` is tried first; if it exists, ``<base_name>_1<ext>``,
    ``<base_name>_2<ext>``, ... are tried until a free name is found.

    Args:
        base_name: File name (or path) without extension.
        ext: Extension including the leading dot.

    Returns:
        The first candidate for which ``os.path.exists`` is False.
    """
    candidate = f"{base_name}{ext}"
    attempt = 0
    while os.path.exists(candidate):
        attempt += 1
        candidate = f"{base_name}_{attempt}{ext}"
    return candidate

# (bundle_name, checkpoint path) pairs; bundle_name selects the Wav2Vec2
# variant the checkpoint was trained with.
# NOTE(review): hard-coded Kaggle mount points — confirm they match the paths
# returned by the kagglehub model downloads when running outside Kaggle.
model_paths = [
    ("base", "/kaggle/input/wav2vec2-base-100-150/pytorch/default/1/wav2vec2-base-100-150.pt"),
    ("asr", "/kaggle/input/wav2vec2-asr-100-150/pytorch/default/1/wav2vec2-asr-100-150.pt"),
]

for bundle_name, model_path in model_paths:
    model = DDSSModel(bundle_name = bundle_name)
    # Load onto the CPU first so a checkpoint saved from a GPU run can also be
    # read on a CPU-only machine; infer_and_save moves the model to its device.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    df_res = infer_and_save(model, test_dataset)
    model_name = os.path.splitext(os.path.basename(model_path))[0]  # drop the .pt extension if present

    # Export CSV. Array cells are converted to plain lists first: pandas
    # writes str(cell), and numpy abbreviates arrays longer than 1000
    # elements with "...", which would silently drop most of the values.
    csv_path = get_unique_filename(model_name, ".csv")
    df_csv = df_res.copy()
    for column in ("outputs", "preds", "labels"):
        if column in df_csv.columns:
            df_csv[column] = df_csv[column].map(
                lambda cell: cell.tolist() if hasattr(cell, "tolist") else cell
            )
    df_csv.to_csv(csv_path, index=False)

    # Export JSON
    json_path = get_unique_filename(model_name, ".json")
    df_res.to_json(json_path, orient="records", force_ascii=False)


Inferencing:  64%|██████▎   | 15208/23946 [21:17<04:17, 33.92batch/s]  