In [None]:
import os
import torch
from torch.utils.data import Dataset
import numpy as np
from wave import open as wave_open

class WavDataset(Dataset):
    def __init__(self, audio_paths, labels=None, transform=None, desired_length=16000):
        self.audio_paths = audio_paths
        self.labels = labels
        self.transform = transform
        self.desired_length = desired_length

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        # 加载WAV文件
        file_path = self.audio_paths[idx]
        with wave_open(file_path, 'rb') as wav_file:
            frames = wav_file.readframes(wav_file.getnframes())
            wav_data = np.frombuffer(frames, dtype=np.int16)

        # 归一化
        wav_data = wav_data.astype(np.float32) / 32768.0

        # 长度标准化
        if len(wav_data) != self.desired_length:
            if len(wav_data) < self.desired_length:
                padding = np.zeros(self.desired_length - len(wav_data))
                wav_data = np.concatenate((wav_data, padding))
            else:
                wav_data = wav_data[:self.desired_length]

        # 应用transform
        if self.transform:
            wav_data = self.transform(wav_data)

        # 标签
        label = self.labels[idx] if self.labels is not None else None

        return wav_data, label

# 假设你已经有了WAV文件路径列表和对应的标签列表
audio_paths = ['path/to/wav1.wav', 'path/to/wav2.wav', ...]
labels = [0, 1, ...]  # 假设0和1是标签

# 创建数据集实例
dataset = WavDataset(audio_paths, labels)

# 现在你可以使用这个数据集实例与PyTorch的DataLoader一起使用
# 例如：
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

In [6]:
import os
import json


def get_labels(audio_paths):
    labels = []
    with open('config.json', 'r') as config_file:
        config = json.load(config_file)
    
    for path in audio_paths:
        word = path.split('\\')
        label.append(config["label"][word])


def get_path(mode):
    
    with open('data/testing_list.txt', 'r', encoding='utf-8') as f:
        test_path = f.readlines()
        test_path = [element for element in test_path if element]

    with open('data/validating_list.txt', 'r', encoding='utf-8') as f:
        val_path = f.readlines()
        val_path = [element for element in val_path if element]

    with open('data/training_list.txt', 'r', encoding='utf-8') as f:
        train_path = f.readlines()
        train_path = [element for element in train_path if element]

    # train_path = []
    # for root, dirs, files in os.walk('data/data_mini_merge'):
    #     for filename in files:
    #         if filename.endswith('.wav'):
    #             # 打印文件的完整路径
    #             path = os.path.join(root, filename)
    #             path = path.split('data_mini_merge\\')[-1]
    #             train_path.append(path)
    # with open('data/training_list.txt', 'w', encoding='utf-8') as f:
    #     f.write('\n'.join(train_path))
    if mode == 'test':
        return test_path
    elif mode == 'val':
        return val_path
    elif mode == 'train':
        return train_path
path = get_path('val')
label = get_labels(path)

NameError: name 'label' is not defined