<a href="https://colab.research.google.com/github/dungdt-infopstats/Device-Directed-Speech-Segmentation/blob/main/src_prototype/DDSS_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Log in to Kaggle through kagglehub (presumably interactive — it asks for
# credentials) so the private datasets in the next cell can be downloaded.
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Download the train/test datasets and a saved model through kagglehub; each
# call returns a local path (presumably a cache directory — confirm in the
# kagglehub docs).
# NOTE(review): none of these path variables is referenced later in the
# visible notebook; the data cells read hard-coded /kaggle/input/... paths.
tridungdo_100_150_test_path = kagglehub.dataset_download('tridungdo/100-150-test')
tridungdo_test_100_150_path = kagglehub.dataset_download('tridungdo/test-100-150')
tridungdo_train_100_150_path = kagglehub.dataset_download('tridungdo/train-100-150')
tridungdo_model_split_pytorch_default_1_path = kagglehub.model_download('tridungdo/model-split/PyTorch/default/1')

print('Data source import complete.')


# Data

In [None]:
import os
import json
import pandas as pd

def collect_json_to_df(root_dir: str, flatten: bool = True) -> pd.DataFrame:
    """
    Collect every JSON annotation file under ``root_dir`` into a DataFrame.

    The tree is walked recursively. For each ``*.json`` file, the matching
    audio path is taken to be ``<dir>/<folder>_aug.wav``, where ``<folder>`` is
    the name of the directory that contains the JSON file.

    Directories and file names are visited in sorted order, so the row order
    (and therefore any index-based train/val split downstream) is reproducible
    across filesystems.

    Args:
        root_dir (str): root directory containing the per-sample sub-folders.
        flatten (bool): if True, each entry of the JSON ``labels`` list becomes
            its own row with ``label``/``start``/``end`` columns.
            If False, one row per JSON file, with the ``labels`` list kept
            unchanged in a single ``labels`` column.

    Returns:
        pd.DataFrame: columns ``id, audio_path, label, start, end`` when
        ``flatten`` is True, otherwise ``id, audio_path, labels``.

    Raises:
        KeyError: if a JSON file has no top-level ``labels`` key.
    """
    records = []

    for dirpath, dirnames, filenames in os.walk(root_dir):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirnames.sort()
        folder_name = os.path.basename(dirpath)
        # The audio sits next to the JSON and is named after its folder + "_aug".
        audio_path = os.path.join(dirpath, folder_name + "_aug" + ".wav")

        for file in sorted(filenames):
            if not file.endswith(".json"):
                continue
            json_path = os.path.join(dirpath, file)
            with open(json_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            if flatten:
                for entry in data["labels"]:
                    records.append({
                        "id": folder_name,
                        "audio_path": audio_path,
                        "label": entry.get("label"),
                        "start": entry.get("start"),
                        "end": entry.get("end"),
                    })
            else:
                # Keep the labels list itself (not the whole JSON object), as documented.
                records.append({
                    "id": folder_name,
                    "audio_path": audio_path,
                    "labels": data["labels"],
                })

    return pd.DataFrame(records)


In [None]:
import json
import pandas as pd
import os
def get_information(root_folder):
    """
    Load the JSON metadata of every sample folder into one DataFrame.

    Each immediate sub-directory of ``root_folder`` is treated as one sample.
    Every ``*.json`` file directly inside it is parsed and tagged with an
    ``id`` key equal to the sub-directory name. Non-directory entries at the
    top level (stray files) are skipped. Folders and files are visited in
    sorted order so the row order is reproducible.

    Args:
        root_folder: directory whose sub-folders each hold one sample.

    Returns:
        pd.DataFrame with one row per JSON file; its columns are the JSON's
        top-level keys plus ``id``.
    """
    list_info = []
    for folder in sorted(os.listdir(root_folder)):
        folder_path = os.path.join(root_folder, folder)
        # os.listdir on a plain file would raise NotADirectoryError.
        if not os.path.isdir(folder_path):
            continue
        for file in sorted(os.listdir(folder_path)):
            if not file.endswith('.json'):
                continue
            # Explicit UTF-8, matching collect_json_to_df, so parsing does not
            # depend on the platform's default locale encoding.
            with open(os.path.join(folder_path, file), 'r', encoding='utf-8') as f:
                data = json.load(f)
            data['id'] = folder
            list_info.append(data)
    return pd.DataFrame(list_info)

# Per-recording JSON metadata for the train and test splits.
# NOTE(review): info_train / info_val are not referenced anywhere else in the
# visible notebook; training uses the flattened segments from collect_json_to_df.
info_train = get_information("/kaggle/input/train-100-150/train-100-150-1/train")
info_val = get_information("/kaggle/input/test-100-150/test")

In [None]:
# One row per labelled segment (flatten=True default) for the train and test splits.
df_train = collect_json_to_df("/kaggle/input/train-100-150/train-100-150-1/train")
df_test = collect_json_to_df("/kaggle/input/test-100-150/test")

In [None]:
import torch
from torch.utils.data import Dataset
import torchaudio

def frame_wav(wav, frame_size, hop_size):
    """
    Slice a waveform into overlapping fixed-size frames.

    A leading singleton channel dimension (``[1, N]``) is dropped first, and
    then windows of ``frame_size`` samples are taken every ``hop_size``
    samples. A trailing partial window is discarded.

    Returns:
        Tensor of shape (num_frames, frame_size), a view of the input.
    """
    mono = wav.squeeze(dim=0)
    return mono.unfold(dimension=0, size=frame_size, step=hop_size)

def label_frames(num_frames, frame_size, hop_size, sr, annotations):
    """
    Turn time-stamped segment annotations into a per-frame 0/1 target.

    Frame ``i`` covers samples ``[i*hop_size, i*hop_size + frame_size)``. For
    each row labelled ``'active'``, the frames from the first one overlapping
    the segment start through frame ``end_sample // hop_size`` (clamped to the
    sequence) are set to 1. Rows with any other label leave the default 0.

    Args:
        num_frames: length of the output sequence.
        frame_size, hop_size: framing parameters, in samples.
        sr: sample rate used to turn ``start``/``end`` into sample indices.
        annotations: DataFrame with ``label``, ``start`` and ``end`` columns.

    Returns:
        torch.long Tensor of shape (num_frames,).
    """
    targets = torch.zeros(num_frames, dtype=torch.long)

    segments = zip(annotations['label'], annotations['start'], annotations['end'])
    for tag, seg_start, seg_end in segments:
        first_sample = int(seg_start * sr)
        last_sample = int(seg_end * sr)

        # Map sample positions to frame indices.
        first = max(0, 1 + (first_sample - frame_size) // hop_size)
        last = min(num_frames, last_sample // hop_size)

        if tag != 'active':
            continue
        targets[first:last + 1] = 1

    return targets

class DDSSDataset(Dataset):
    """
    Per-recording dataset built from a flattened annotation DataFrame.

    Each item is one unique ``id``. The waveform is loaded from that id's
    ``audio_path``, down-mixed to mono if needed, resampled to ``target_sr``,
    optionally transformed, framed with ``frame_wav`` and labelled per frame
    with ``label_frames``.

    Args:
        annotations_file: DataFrame with at least ``id``, ``audio_path``,
            ``label``, ``start`` and ``end`` columns (e.g. the output of
            ``collect_json_to_df(..., flatten=True)``).
        frame_size, hop_size: framing parameters, in samples.
        target_sr: sample rate every waveform is resampled to.
        transform: optional callable applied to the mono, resampled waveform
            (shape ``[1, N]``) before framing. ``None`` disables it.

    Items are tuples ``(wav [1, N], frames, labels, id, rows_for_id)``.
    """

    def __init__(self, annotations_file, frame_size=400, hop_size=160, target_sr=16000, transform=None):
        self.annotations_file = annotations_file
        self.list_id = self.annotations_file['id'].unique()
        self.frame_size = frame_size
        self.hop_size = hop_size
        self.target_sr = target_sr
        self.transform = transform
        # Group the rows by id once; filtering the whole DataFrame inside every
        # __getitem__ call costs O(total rows) per item.
        self._rows_by_id = {
            key: rows for key, rows in self.annotations_file.groupby('id', sort=False)
        }

    def __len__(self):
        return len(self.list_id)

    def __getitem__(self, idx):
        id_file = self.list_id[idx]
        df_file = self._rows_by_id[id_file]
        audio_path = df_file['audio_path'].iloc[0]
        wav, sr = torchaudio.load(audio_path)  # [channels, samples]

        # The framing code and the model expect one channel: average
        # multi-channel recordings instead of crashing in frame_wav.
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        if sr != self.target_sr:
            resampler = torchaudio.transforms.Resample(sr, self.target_sr)
            wav = resampler(wav)
            sr = self.target_sr

        # Previously accepted but never applied.
        if self.transform is not None:
            wav = self.transform(wav)

        frames = frame_wav(wav, self.frame_size, self.hop_size)
        labels = label_frames(frames.shape[0], self.frame_size, self.hop_size, sr, df_file)

        return wav, frames, labels, id_file, df_file

In [None]:
import torch
import torch.nn.functional as F

def convert_labels(labels, new_num_frames):
    """
    Resize 0/1 label sequences along time with nearest-neighbour sampling.

    Args:
        labels: Tensor (B, old_num_frames), or a single (old_num_frames,)
            sequence, which is treated as a batch of one.
        new_num_frames: desired number of frames.

    Returns:
        torch.long Tensor of shape (B, new_num_frames).
    """
    batched = labels.unsqueeze(0) if labels.ndim == 1 else labels
    with_channel = batched.float().unsqueeze(1)  # (B, 1, old_num_frames) for interpolate
    resized = F.interpolate(with_channel, size=new_num_frames, mode="nearest")
    return resized.squeeze(1).long()


# Model

In [None]:
import torch
import torch.nn as nn
import torchaudio


class DDSSModel(nn.Module):
    """
    Frame-level binary classifier on top of a frozen Wav2Vec2 ASR encoder.

    Pipeline: WAV2VEC2_ASR_BASE_960H (frozen) -> last-layer features ->
    (Bi)LSTM -> per-frame MLP that emits one logit per encoder frame.
    """

    def __init__(self, hidden_dim=256, num_layers=2, bidirectional=True, dropout=0.3):
        super(DDSSModel, self).__init__()

        # Base model: Wav2Vec2 (ASR Base)
        bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
        self.feature_extractor = bundle.get_model()

        # Freeze Wav2Vec2: only the LSTM and classifier are trained.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # LSTM over the encoder's frame sequence
        self.lstm = nn.LSTM(
            input_size=bundle._params['encoder_embed_dim'],
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional
        )

        lstm_output_dim = hidden_dim * (2 if bidirectional else 1)

        # Frame-level classifier (an MLP instead of a single linear layer)
        self.classifier = nn.Sequential(
            nn.Linear(lstm_output_dim, lstm_output_dim // 2),
            nn.BatchNorm1d(lstm_output_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(lstm_output_dim // 2, lstm_output_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(lstm_output_dim // 4, 1)  # binary logit
        )

    def train(self, mode=True):
        """
        Set train/eval mode, but always keep the frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every sub-module, which would switch
        on any dropout layers inside the frozen Wav2Vec2 encoder during
        training. Its weights never update, so keeping it in eval mode makes
        its features deterministic and identical between training and
        evaluation.
        """
        super().train(mode)
        self.feature_extractor.eval()
        return self

    def forward(self, waveforms, lengths=None):
        """
        Args:
            waveforms: Tensor [B, T_audio], mono audio expected at 16 kHz.
            lengths: optional Tensor [B] of valid sample counts per waveform,
                forwarded to Wav2Vec2 so it can mask zero padding. ``None``
                (the default) keeps the unmasked behaviour.

        Returns:
            Tensor [B, T_feat] of raw (pre-sigmoid) logits, one per encoder frame.
        """
        with torch.no_grad():
            features, _ = self.feature_extractor.extract_features(waveforms, lengths=lengths)
            x = features[-1]  # [B, T_feat, F] from the last transformer layer

        # LSTM
        x, _ = self.lstm(x)  # [B, T_feat, H]

        # Classify every frame independently (flatten so BatchNorm1d sees [N, H])
        B, T, H = x.shape
        x = x.reshape(-1, H)          # [B*T, H]
        out = self.classifier(x)      # [B*T, 1]
        out = out.view(B, T)          # [B, T]
        return out


In [None]:
import torch
import torch.nn as nn
import torchaudio

class DDSSModelFusion(nn.Module):
    """
    Frame-level binary classifier on fused features from two frozen Wav2Vec2 encoders.

    Pipeline: [WAV2VEC2_BASE ; WAV2VEC2_ASR_BASE_960H] last-layer features,
    concatenated along the feature axis -> (Bi)LSTM -> per-frame MLP that
    emits one logit per encoder frame.
    """

    def __init__(self, hidden_dim=512, num_layers=3, bidirectional=True, dropout=0.3):
        super(DDSSModelFusion, self).__init__()

        # Base model 1: Wav2Vec2 Base (self-supervised)
        bundle_base = torchaudio.pipelines.WAV2VEC2_BASE
        self.feature_extractor_base = bundle_base.get_model()

        # Base model 2: Wav2Vec2 ASR Base (fine-tuned for ASR)
        bundle_asr = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
        self.feature_extractor_asr = bundle_asr.get_model()

        # Freeze both Wav2Vec2 models: only the LSTM and classifier are trained.
        for param in self.feature_extractor_base.parameters():
            param.requires_grad = False
        for param in self.feature_extractor_asr.parameters():
            param.requires_grad = False

        # Get feature dimensions from both models
        base_dim = bundle_base._params['encoder_embed_dim']  # 768
        asr_dim = bundle_asr._params['encoder_embed_dim']    # 768

        # Concatenated feature dimension
        concat_dim = base_dim + asr_dim  # 1536

        # Larger BiLSTM to handle the wider fused input
        self.lstm = nn.LSTM(
            input_size=concat_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional
        )

        lstm_output_dim = hidden_dim * (2 if bidirectional else 1)

        # Frame-level classifier (MLP)
        self.classifier = nn.Sequential(
            nn.Linear(lstm_output_dim, lstm_output_dim // 2),
            nn.BatchNorm1d(lstm_output_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(lstm_output_dim // 2, lstm_output_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(lstm_output_dim // 4, 1)  # binary logit
        )

    def train(self, mode=True):
        """
        Set train/eval mode, but always keep both frozen encoders in eval mode.

        ``nn.Module.train`` recurses into every sub-module, which would switch
        on any dropout layers inside the frozen Wav2Vec2 encoders during
        training. Their weights never update, so keeping them in eval mode
        makes the fused features deterministic and identical between training
        and evaluation.
        """
        super().train(mode)
        self.feature_extractor_base.eval()
        self.feature_extractor_asr.eval()
        return self

    def forward(self, waveforms, lengths=None):
        """
        Args:
            waveforms: Tensor [B, T_audio], mono audio expected at 16 kHz.
            lengths: optional Tensor [B] of valid sample counts per waveform,
                forwarded to both encoders so they can mask zero padding.
                ``None`` (the default) keeps the unmasked behaviour.

        Returns:
            Tensor [B, T_feat] of raw (pre-sigmoid) logits, one per encoder frame.
        """
        with torch.no_grad():
            # Extract features from Wav2Vec2 Base
            features_base, _ = self.feature_extractor_base.extract_features(waveforms, lengths=lengths)
            x_base = features_base[-1]  # [B, T_feat, base_dim]

            # Extract features from Wav2Vec2 ASR Base
            features_asr, _ = self.feature_extractor_asr.extract_features(waveforms, lengths=lengths)
            x_asr = features_asr[-1]    # [B, T_feat, asr_dim]

            # Concatenate features along the feature dimension
            x = torch.cat([x_base, x_asr], dim=-1)  # [B, T_feat, concat_dim]

        # BiLSTM
        x, _ = self.lstm(x)  # [B, T_feat, lstm_output_dim]

        # Classify every frame independently (flatten so BatchNorm1d sees [N, H])
        B, T, H = x.shape
        x = x.reshape(-1, H)          # [B*T, H]
        out = self.classifier(x)      # [B*T, 1]
        out = out.view(B, T)          # [B, T]

        return out

In [None]:
import torch
import torch.nn.functional as F

def label_resampler(labels: torch.Tensor, target_len: int) -> torch.Tensor:
    """
    Resample a 1-D per-frame label sequence to ``target_len`` time steps.

    Binary (0/1) labels are linearly interpolated and rounded back to 0/1,
    which places each boundary at the closest output step. Any other values
    (multi-class indices) use nearest-neighbour resampling. Linear
    interpolation between classes 0 and 2 would otherwise invent class 1 at
    the boundary.

    Args:
        labels: Tensor [num_frames] of integer class indices.
        target_len: number of output time steps (e.g. the model's T).

    Returns:
        Tensor [target_len] (always 1-D, including ``target_len == 1``).
        When ``num_frames == target_len`` a clone of ``labels`` with its
        original dtype is returned; otherwise the dtype is long.
    """
    num_frames = labels.shape[0]

    if num_frames == target_len:
        return labels.clone()

    # interpolate expects [batch, channels, length]
    labels_f = labels.float().unsqueeze(0).unsqueeze(0)  # [1, 1, N]

    is_binary = bool(((labels == 0) | (labels == 1)).all())
    if is_binary:
        resampled = F.interpolate(
            labels_f, size=target_len, mode="linear", align_corners=False
        ).round()
    else:
        resampled = F.interpolate(labels_f, size=target_len, mode="nearest")

    # reshape (not squeeze): squeeze() would collapse target_len == 1 to a 0-d tensor.
    return resampled.reshape(target_len).long()


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

# ====== 1) Collate: pad waveforms within a batch, keep the original lengths ======
def collate_fn(batch):
    """
    Zero-pad a batch of dataset items to a common waveform length.

    Args:
        batch: list of 5-tuples ``(wav, frames, labels, id, rows)`` as produced
            by ``DDSSDataset``. ``wav`` is Tensor [1, N] and ``labels`` is
            Tensor [num_frames]; the frames, id and row entries are dropped.

    Returns:
        batch_wav: Tensor [B, max_len], each waveform right-padded with zeros.
        lengths: long Tensor [B] with each waveform's original sample count.
        labels_list: tuple of the per-item label tensors, left un-padded so
            they can be resampled to the model's output length later.
    """
    wavs, _frames, labels_list, _ids, _rows = zip(*batch)
    lengths = torch.tensor([wav.shape[-1] for wav in wavs], dtype=torch.long)
    target = int(lengths.max().item())

    def _pad_right(wav):
        missing = target - wav.shape[-1]
        if missing <= 0:
            return wav
        filler = torch.zeros((1, missing), dtype=wav.dtype)
        return torch.cat([wav, filler], dim=-1)

    batch_wav = torch.stack([_pad_right(wav) for wav in wavs], dim=0).squeeze(1)
    return batch_wav, lengths, labels_list


class TemporalBCELoss(nn.Module):
    """
    Masked per-timestep BCE plus a temporal-smoothness penalty.

        loss = mean over valid t of BCE(logit_t, target_t)
               + lambda_smooth * mean over valid (t, t+1) of (p_{t+1} - p_t)^2

    where ``p = sigmoid(logit)``. Padded steps (``t >= valid_steps[b]``) are
    excluded from both terms. The loss is computed in float32, so it stays
    numerically stable when the logits come out of an fp16/bf16 autocast
    region.
    """
    def __init__(self, lambda_smooth=0.1):
        """
        Args:
            lambda_smooth: weight of the temporal-smoothing term; 0 disables it.
        """
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.lambda_smooth = lambda_smooth

    def forward(self, logits, targets, valid_steps):
        """
        Args:
            logits: [B, T] raw scores (any floating dtype).
            targets: [B, T] 0/1 targets.
            valid_steps: [B] number of valid (un-padded) timesteps per sample.

        Returns:
            Scalar float32 loss tensor.
        """
        # Upcast: BCE/sigmoid in half precision lose accuracy (and some
        # half-precision kernels are unavailable on CPU).
        logits = logits.float()
        targets = targets.float()
        B, T = logits.shape
        device = logits.device

        # ===== Basic BCE =====
        loss = self.bce(logits, targets)  # [B, T]

        # Mask out the padded tail of each sequence.
        steps = torch.arange(T, device=device).unsqueeze(0).expand(B, T)
        mask = (steps < valid_steps.unsqueeze(1)).float()
        loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)

        # ===== Temporal smoothing =====
        if self.lambda_smooth > 0 and T > 1:
            prob = torch.sigmoid(logits)  # [B, T]
            # Penalise changes between consecutive timesteps.
            smooth_loss = (prob[:, 1:] - prob[:, :-1]) ** 2
            smooth_mask = mask[:, 1:] * mask[:, :-1]  # both frames must be valid
            smooth_loss = (smooth_loss * smooth_mask).sum() / smooth_mask.sum().clamp_min(1.0)
            loss = loss + self.lambda_smooth * smooth_loss

        return loss



from tqdm import tqdm

def _build_targets(labels_list, lengths, max_len, num_steps, label_resampler):
    """
    Align each sample's frame labels with the model's padded output time axis.

    The model sees waveforms zero-padded to ``max_len`` samples and emits
    ``num_steps`` outputs for the whole padded length. Sample ``i`` occupies
    only the first ``lengths[i] / max_len`` fraction of that axis. Its labels
    are therefore resampled onto exactly that many steps, and the padded tail
    is left at 0 (it is excluded from the loss via ``valid_steps``).

    Returns:
        targets: float Tensor [B, num_steps] on CPU.
        valid_steps: long Tensor [B] of un-padded step counts on CPU.
    """
    targets = torch.zeros(len(labels_list), num_steps)
    valid_steps = []
    for i, (labels, n_samples) in enumerate(zip(labels_list, lengths.tolist())):
        t_valid = min(num_steps, int((n_samples / max_len) * num_steps))
        if t_valid > 0:
            targets[i, :t_valid] = label_resampler(labels, t_valid).float()
        valid_steps.append(t_valid)
    return targets, torch.tensor(valid_steps, dtype=torch.long)


def train_one_epoch(model, dataloader, optimizer, device, label_resampler, scaler=None, grad_clip=1.0):
    """
    Run one training epoch and return the sample-weighted mean loss.

    Args:
        model: maps padded waveforms [B, N] to logits [B, T].
        dataloader: yields ``(batch_wav, lengths, labels_list)`` batches.
        optimizer: optimizer over the model's parameters.
        device: torch.device to run on.
        label_resampler: ``f(labels, target_len) -> Tensor[target_len]``.
        scaler: optional GradScaler. Autocast is used only when a scaler is
            given AND enabled; a disabled scaler means full precision.
        grad_clip: max global gradient norm, or None to skip clipping.
    """
    model.train()
    criterion = TemporalBCELoss()
    total_loss, total_steps = 0.0, 0
    # A GradScaler built with enabled=False must not trigger low-precision
    # autocast: that would run bf16 on CPU, or fp16 without loss scaling on CUDA.
    use_autocast = scaler is not None and scaler.is_enabled()

    pbar = tqdm(dataloader, desc="Training", leave=False)
    for batch_wav, lengths, labels_list in pbar:
        batch_wav = batch_wav.to(device)

        optimizer.zero_grad(set_to_none=True)

        # Forward
        if use_autocast:
            with torch.autocast(device_type=device.type, dtype=torch.float16 if device.type == "cuda" else torch.bfloat16):
                logits = model(batch_wav)
        else:
            logits = model(batch_wav)         # [B, T]

        B, T = logits.shape

        # Labels aligned to the un-padded part of each sample's T steps.
        targets, valid_steps = _build_targets(labels_list, lengths, batch_wav.shape[1], T, label_resampler)
        targets = targets.to(device)
        valid_steps = valid_steps.to(device)

        # Compute the loss in float32 even when logits come from autocast.
        loss = criterion(logits.float(), targets, valid_steps)

        if scaler is None:
            loss.backward()
            if grad_clip is not None:
                clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            if grad_clip is not None:
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()

        total_loss += loss.item() * B
        total_steps += B

        pbar.set_postfix(loss=f"{loss.item():.4f}")

    return total_loss / max(total_steps, 1)


@torch.no_grad()
def evaluate(model, dataloader, device, label_resampler):
    """
    Compute the sample-weighted mean TemporalBCELoss over ``dataloader``.

    Uses the same label alignment as training (labels resampled onto each
    sample's un-padded steps) so train and validation losses are comparable.
    """
    model.eval()
    criterion = TemporalBCELoss()
    total_loss, total_steps = 0.0, 0

    pbar = tqdm(dataloader, desc="Evaluating", leave=False)
    for batch_wav, lengths, labels_list in pbar:
        batch_wav = batch_wav.to(device)
        logits = model(batch_wav)
        B, T = logits.shape

        targets, valid_steps = _build_targets(labels_list, lengths, batch_wav.shape[1], T, label_resampler)
        targets = targets.to(device)
        valid_steps = valid_steps.to(device)

        loss = criterion(logits.float(), targets, valid_steps)
        total_loss += loss.item() * B
        total_steps += B

        pbar.set_postfix(loss=f"{loss.item():.4f}")

    return total_loss / max(total_steps, 1)



# ====== 4) Run the full training loop ======
from torch.utils.data import random_split, DataLoader

def fit(
    model,
    dataset,                # shared dataset; split into train/val when val_ds is None
    val_ds=None,
    val_split=0.1,          # validation fraction used when only `dataset` is given
    epochs=10,
    batch_size=8,
    lr=1e-3,
    weight_decay=0.0,
    num_workers=4,
    label_resampler=label_resampler,   # required: resamples label sequences to T steps
    use_amp=True,
    grad_clip=1.0,
    seed=42,
):
    """
    Train ``model`` with AdamW and save the best-validation weights to ``best_ddss.pt``.

    Args:
        model: module mapping padded waveforms [B, N] to logits [B, T].
        dataset: training dataset whose items are consumed by ``collate_fn``.
        val_ds: optional validation dataset. When None, ``val_split`` of
            ``dataset`` is held out with ``random_split``. If that rounds down
            to zero items, training runs without validation.
        val_split: fraction held out when ``val_ds`` is None.
        epochs, batch_size, lr, weight_decay, num_workers: usual training knobs.
        label_resampler: ``f(labels, target_len)`` used to align labels to T.
        use_amp: enable fp16 autocast + GradScaler (only effective on CUDA).
        grad_clip: max global gradient norm, or None to disable clipping.
        seed: passed to ``torch.manual_seed`` before splitting/training.

    Returns:
        The model after the final epoch (the best checkpoint is on disk).
    """
    assert label_resampler is not None, "Bạn cần truyền hàm label_resampler(seq_labels, target_len)."

    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Split train/val when no validation dataset is given
    if val_ds is None:
        val_size = int(len(dataset) * val_split)
        train_size = len(dataset) - val_size
        if val_size > 0:
            train_ds, val_ds = random_split(dataset, [train_size, val_size])
        else:
            # An empty val split would report val_loss == 0.0 and freeze the
            # "best" checkpoint at epoch 1, so skip validation instead.
            train_ds = dataset
    else:
        train_ds = dataset

    # Pinned host memory only helps (and only avoids warnings) for CUDA transfers.
    pin_memory = device.type == "cuda"

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
    )
    val_loader = None
    if val_ds is not None:
        val_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=(use_amp and device.type == "cuda"))

    best_val = float("inf")
    for epoch in range(1, epochs + 1):
        train_loss = train_one_epoch(
            model, train_loader, optimizer, device, label_resampler, scaler=scaler, grad_clip=grad_clip
        )
        if val_loader is not None:
            val_loss = evaluate(model, val_loader, device, label_resampler)
            print(f"[Epoch {epoch:02d}] train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")
            if val_loss < best_val:
                best_val = val_loss
                torch.save(model.state_dict(), "best_ddss.pt")
        else:
            print(f"[Epoch {epoch:02d}] train_loss={train_loss:.4f}")

    return model


In [None]:
# One dataset item per recording id, with default framing (400/160 samples at 16 kHz).
# NOTE(review): val_dataset is built from the test split, but the fit() call in
# the training cell passes val_ds=None, so it is currently unused; validation
# instead comes from a random split of train_dataset.
train_dataset = DDSSDataset(
    annotations_file = df_train,
)

val_dataset = DDSSDataset(
    annotations_file = df_test
)

In [None]:
# Fusion variant: frozen Wav2Vec2 Base + ASR Base features -> BiLSTM -> per-frame MLP.
model = DDSSModelFusion()

In [None]:
# Train for 200 epochs. With val_ds=None, fit() holds out val_split (default
# 10%) of train_dataset for validation and writes the best weights to
# "best_ddss.pt".
trained_model = fit(
    model=model,
    dataset=train_dataset,
    val_ds=None,                   # or a validation dataset, if available
    epochs=200,
    batch_size=8,
    lr=1e-3,
    label_resampler=label_resampler,   # the resampling helper defined above
    use_amp=True
)
