In [1]:
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

class KSSDataset(Dataset):
    """Train/validation view over the ``Bingsu/KSS_Dataset`` Hugging Face dataset.

    The dataset's ``'train'`` split is loaded once and its row indices are
    partitioned with ``sklearn.model_selection.train_test_split``; this object
    exposes either the training or the validation portion.

    Args:
        split: ``'train'`` selects the training indices; any other value
            selects the validation indices.
        valid_size: Forwarded to ``train_test_split`` as ``test_size``.
        seed: Forwarded to ``train_test_split`` as ``random_state`` so the
            partition is reproducible across instances.
        target_sr: Sampling rate every returned clip is resampled to. ``None``
            (the default) disables resampling and returns each clip at the
            rate stored in the dataset.
    """

    def __init__(self, split='train', valid_size=0.1, seed=42, target_sr=None):
        super().__init__()

        # Output sampling rate; None means "keep each clip's native rate".
        self.target_sr = target_sr
        # Resample transforms keyed by (orig_freq, new_freq), built on demand
        # so the transform is not re-created for every item.
        self._resamplers = {}

        # Load the dataset; only its 'train' split is used.
        dataset = load_dataset("Bingsu/KSS_Dataset")
        full_dataset = dataset['train']

        # Partition row indices into train / validation subsets.
        train_idx, valid_idx = train_test_split(
            range(len(full_dataset)),
            test_size=valid_size,
            random_state=seed
        )

        # Keep only the indices belonging to the requested split.
        self.indices = train_idx if split == 'train' else valid_idx
        self.dataset = full_dataset

    def __len__(self):
        """Return the number of examples in the selected split."""
        return len(self.indices)

    def _get_resampler(self, orig_freq, new_freq):
        """Return a cached ``torchaudio`` Resample transform for the rate pair."""
        key = (orig_freq, new_freq)
        resampler = self._resamplers.get(key)
        if resampler is None:
            # Imported lazily: torchaudio is only needed when resampling is
            # actually requested, and the notebook's import cell lacks it.
            import torchaudio

            resampler = torchaudio.transforms.Resample(
                orig_freq=orig_freq,
                new_freq=new_freq
            )
            self._resamplers[key] = resampler
        return resampler

    def __getitem__(self, idx):
        """Fetch one example of the selected split.

        Args:
            idx: Position within this split (not within the full dataset).

        Returns:
            dict with keys ``'audio'`` (float32 tensor built from the item's
            ``audio['array']``, resampled when needed), ``'text'`` (the item's
            ``'original_script'`` field) and ``'sampling_rate'`` (the rate of
            the returned audio).
        """
        # Map the split-local position to a row of the underlying dataset.
        real_idx = self.indices[idx]
        item = self.dataset[real_idx]

        # Load the waveform and note the rate it was stored at.
        audio = torch.from_numpy(item['audio']['array']).float()
        original_sr = item['audio']['sampling_rate']

        # With no target rate configured the clip is returned unchanged.
        target_sr = original_sr if self.target_sr is None else self.target_sr
        if original_sr != target_sr:
            audio = self._get_resampler(original_sr, target_sr)(audio)

        return {
            'audio': audio,
            'text': item['original_script'],
            'sampling_rate': target_sr
        }