In [95]:
import torch
import os
import numpy as np
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Dataset of per-clip CSV matrices stored in one sub-folder per class.

    Expected layout::

        root_dir/
            <class_a>/clip1.csv, clip2.csv, ...
            <class_b>/...

    Each sample is a dict ``{'data': FloatTensor, 'label': int}``:

    * ``data`` is the CSV parsed with ``np.genfromtxt(..., delimiter=',')`` and
      converted to float32. A CSV with a single row or single column parses
      to a 1-D tensor (NumPy's genfromtxt behaviour).
    * ``label`` is the class folder's index in the sorted list of class folders.

    Args:
        root_dir: directory whose immediate sub-directories are the classes.
            Plain files directly inside ``root_dir`` are ignored.
        transform: optional callable applied to each sample dict before it is
            returned.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Only directories are classes: a stray file (e.g. a README) must not
        # occupy a label index and shift the ids of the real classes. Sorting
        # makes label ids independent of the filesystem's listdir order.
        class_names = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.labels_to_int = {name: i for i, name in enumerate(class_names)}

        # One (csv_path, class_name) pair per sample, in a deterministic
        # (sorted class, then sorted file name) order.
        self.file_list = []
        for class_name in class_names:
            class_path = os.path.join(root_dir, class_name)
            for file_name in sorted(os.listdir(class_path)):
                if file_name.endswith('.csv'):
                    self.file_list.append((os.path.join(class_path, file_name), class_name))

    def __len__(self):
        """Number of CSV samples found under ``root_dir``."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return ``{'data', 'label'}`` (after ``transform``)."""
        file_path, class_name = self.file_list[idx]
        data = torch.tensor(np.genfromtxt(file_path, delimiter=','), dtype=torch.float32)

        sample = {'data': data, 'label': self.labels_to_int[class_name]}

        if self.transform:
            sample = self.transform(sample)

        return sample
    
    
# for variable length input
def collate_fn(batch):
    """Collate samples whose 'data' tensors differ in their second dimension.

    Each ``sample['data']`` is right-padded with zeros along dim 1 up to the
    longest one in the batch, then all are stacked. The first dimension must
    already agree across the batch (``torch.stack`` raises otherwise).

    Args:
        batch: list of dicts with keys 'data' (2-D tensor) and 'label'.

    Returns:
        dict with
          'data':    tensor (batch, rows, max_cols), padded with zeros on the right;
          'labels':  tensor built from the samples' labels;
          'lengths': int64 tensor of each sample's original dim-1 size, so the
                     padding can be masked or fed to pack_padded_sequence.
    """
    data = [sample['data'] for sample in batch]
    labels = [sample['label'] for sample in batch]

    # Remember the true lengths before padding erases them.
    lengths = [seq.size(1) for seq in data]
    max_cols = max(lengths)

    # F.pad keeps each tensor's dtype and device, unlike a freshly built
    # torch.zeros (float32 on CPU by default).
    padded_data = [
        torch.nn.functional.pad(seq, (0, max_cols - seq.size(1)))
        for seq in data
    ]

    return {
        'data': torch.stack(padded_data),
        'labels': torch.tensor(labels),
        'lengths': torch.tensor(lengths, dtype=torch.long),
    }


In [97]:
# Raw string: in a normal literal, backslashes start escape sequences. '\G' and
# '\M' only survived because they are not valid escapes, and newer Pythons
# warn about them. The runtime value is unchanged.
root_dir = r'C:\GenshinVoice\Melspec_data'
custom_dataset = CustomDataset(root_dir)

# Use DataLoader to handle batching, shuffling, etc.
# shuffle=False yields batches in dataset order; set True for training.
dataloader = DataLoader(custom_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

# Iterate through the dataloader to get batches of data.
# NOTE(review): this walks the whole dataset just to print shapes; interrupt
# it (or add a `break`) when only a quick sanity check is needed.
for batch in dataloader:
    data, labels = batch['data'], batch['labels']
    print(data.shape)
    # Your training code here

torch.Size([64, 128, 458])
torch.Size([64, 128, 434])
torch.Size([64, 128, 461])
torch.Size([64, 128, 716])
torch.Size([64, 128, 1244])
torch.Size([64, 128, 501])
torch.Size([64, 128, 433])
torch.Size([64, 128, 386])
torch.Size([64, 128, 405])
torch.Size([64, 128, 475])
torch.Size([64, 128, 425])
torch.Size([64, 128, 465])
torch.Size([64, 128, 483])
torch.Size([64, 128, 404])
torch.Size([64, 128, 610])
torch.Size([64, 128, 446])
torch.Size([64, 128, 429])
torch.Size([64, 128, 1184])
torch.Size([64, 128, 1535])
torch.Size([64, 128, 620])
torch.Size([64, 128, 537])
torch.Size([64, 128, 622])
torch.Size([64, 128, 473])
torch.Size([64, 128, 454])
torch.Size([64, 128, 426])
torch.Size([64, 128, 412])
torch.Size([64, 128, 366])
torch.Size([64, 128, 478])
torch.Size([64, 128, 527])
torch.Size([64, 128, 779])
torch.Size([64, 128, 470])
torch.Size([64, 128, 445])
torch.Size([64, 128, 592])
torch.Size([64, 128, 586])
torch.Size([64, 128, 643])
torch.Size([64, 128, 535])
torch.Size([64, 128, 1334

KeyboardInterrupt: 

In [73]:
# Fetch one batch for inspection. A fresh iterator is created on every run, so
# with the shuffle=False dataloader this is the first batch each time.
# NOTE(review): this cell is repeated further down the notebook; one copy can go.
a = next(iter(dataloader))

In [79]:
# Display the zero-padded 'data' tensor of the second sample in the batch `a`.
a['data'][1]

torch.Size([128, 458])

In [94]:
# Invert one mel spectrogram from the batch `a` back to a waveform and save it.
# Requires `librosa` and `sf` (presumably `soundfile`) to be imported in an
# earlier cell; the visible imports cell does not import them.
# NOTE(review): collate_fn zero-pads on the right, so the trailing padded
# frames are inverted too. Trim to the sample's true length if that matters.
SR = 22050  # one sample rate for both the inversion and the written file
mel = a['data'][1].numpy()  # hand librosa a NumPy array rather than a torch tensor
# Keep the waveform in its own name so `a` (the batch) survives re-running this cell.
audio = librosa.feature.inverse.mel_to_audio(mel, sr=SR)
out_file_name = r'C:\Users\Joshua Ning\Desktop\VChanger\test4.wav'
sf.write(out_file_name, audio, SR)

torch.Size([64, 128, 458])

In [92]:
# Re-fetch a batch into `a`. A fresh iterator is created, so with the
# shuffle=False dataloader this is again the first batch.
# NOTE(review): duplicate of an earlier cell; likely needed only because a
# later cell used to overwrite `a`.
a = next(iter(dataloader))

In [93]:
# Display the second sample's padded 'data' tensor. The trailing zero columns
# in the output below are collate_fn padding.
# NOTE(review): duplicate of an earlier cell.
a['data'][1]

tensor([[8.1018e-08, 3.5602e-05, 1.0788e-03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.7335e-07, 8.9042e-05, 1.3769e-03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [3.0822e-06, 9.5129e-05, 6.6653e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [4.9465e-06, 7.9538e-04, 2.4249e-02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0612e-06, 3.3902e-04, 1.0264e-02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.5968e-07, 2.4259e-05, 9.0028e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])