In [10]:
import sys
sys.path.append('../')

from src.data.components.dataio import load_audio, pad

In [17]:
from IPython.display import Audio, display
import librosa.display
import matplotlib.pyplot as plt
import random


def play_and_show(file_path):
    """
    Play an audio file and plot its log-magnitude spectrogram.

    The file is decoded with librosa at its native sampling rate (``sr=None``).
    An inline audio player for the file is displayed, followed by a
    linear-frequency spectrogram expressed in dB relative to the peak magnitude.

    :param file_path: path to audio file
    """
    # Imported locally so the function does not rely on a later notebook cell
    # having already put ``np`` into the kernel namespace.
    import numpy as np

    # STFT settings, shared by the analysis and by the axis labelling below.
    n_fft, hop_length, win_length = 2048, 240, 480

    y, sr = librosa.load(file_path, sr=None)

    # The player embeds the file itself, so no ``rate`` is passed: IPython only
    # uses ``rate`` when it is handed raw samples rather than a file.
    display(Audio(file_path))

    # Compute the spectrogram
    S = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window='hamming'))
    D = librosa.amplitude_to_db(S, ref=np.max)

    fig, ax = plt.subplots(nrows=1, ncols=1)
    librosa.display.specshow(D, y_axis='linear', x_axis='time', n_fft=n_fft, hop_length=hop_length,
                             win_length=win_length, sr=sr, ax=ax)
    plt.show()



In [24]:
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import librosa
import os

def genSpoof_list(dir_meta, is_train=False, is_eval=False):
    """
    Parse an ASVspoof-style protocol file into utterance keys and labels.

    This function is from the following source: https://github.com/TakHemlata/SSL_Anti-spoofing/blob/main/data_utils_SSL.py#L17
    Official source: https://arxiv.org/abs/2202.12233
    Automatic speaker verification spoofing and deepfake detection using wav2vec 2.0 and data augmentation

    Two file layouts are understood:

    * labelled (default, or ``is_train=True``): every non-blank line has five
      whitespace-separated fields; the second is the utterance key and the
      fifth is the label. ``'bonafide'`` maps to 1, anything else to 0.
    * keys only (``is_eval=True`` and ``is_train`` false): every non-blank
      line, stripped, is an utterance key.

    Blank lines are skipped in both layouts.

    :param dir_meta: path to the protocol text file
    :param is_train: request the labelled layout; wins over ``is_eval``
    :param is_eval: request the keys-only layout
    :return: ``(d_meta, file_list)`` for the labelled layout, where ``d_meta``
        maps key -> 0/1 and ``file_list`` holds the keys in file order;
        just ``file_list`` for the keys-only layout
    :raises ValueError: if a labelled line does not have exactly five fields
    """
    # The original if/elif chain tested is_train first, so it takes precedence.
    keys_only = is_eval and not is_train

    d_meta = {}
    file_list = []
    with open(dir_meta, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate empty lines (e.g. a trailing newline at end of file).
                continue
            if keys_only:
                file_list.append(line)
                continue
            _, key, _, _, label = line.split()
            file_list.append(key)
            d_meta[key] = 1 if label == 'bonafide' else 0

    if keys_only:
        return file_list
    return d_meta, file_list
    
class Dataset_ASVspoof2019_train(Dataset):
    """
    Map-style dataset that loads one ``.flac`` utterance per item.

    Each item is ``(waveform, label)``: the waveform is the 1-D output of
    ``librosa.load`` at 16 kHz wrapped in a tensor, and the label is whatever
    ``labels`` stores for that utterance key.
    """

    def __init__(self, list_IDs, labels, base_dir):
        """
        :param list_IDs: list of strings (each string: utt key)
        :param labels: dictionary (key: utt key, value: label integer)
        :param base_dir: directory containing ``<utt key>.flac`` files
        """
        self.list_IDs = list_IDs
        self.labels = labels
        self.base_dir = base_dir

    def __len__(self):
        """Return the number of utterance keys."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """
        Load the utterance at ``index``.

        :return: ``(x_inp, target)``, the waveform tensor and its label
        :raises KeyError: if the utterance key is missing from ``labels``
        """
        utt_id = self.list_IDs[index]
        # os.path.join yields the same path whether or not base_dir ends with
        # a separator; plain concatenation silently broke without one.
        flac_path = os.path.join(self.base_dir, utt_id + '.flac')
        X, _ = librosa.load(flac_path, sr=16000)
        x_inp = torch.from_numpy(X)
        target = self.labels[utt_id]

        return x_inp, target
# Create a synthetic dataset
class SyntheticAudioDataset(Dataset):
    """
    Random-noise stand-in for an audio dataset.

    Per-item lengths (in samples) and binary labels are drawn once with
    ``np.random.randint`` at construction time. The waveform itself is drawn
    afresh with ``torch.randn`` on every access.
    """

    def __init__(self, num_samples=100, sample_rate=16000):
        """
        :param num_samples: number of items in the dataset
        :param sample_rate: samples per second, used to size the length range
        """
        self.num_samples = num_samples
        self.sample_rate = sample_rate

        # Lengths span half a second (inclusive) up to four seconds (exclusive).
        shortest = sample_rate // 2
        longest = sample_rate * 4
        self.lengths = np.random.randint(low=shortest, high=longest, size=self.num_samples)

        # Labels are 0 or 1 (binary classification).
        self.labels = np.random.randint(low=0, high=2, size=self.num_samples)

    def __len__(self):
        """Return the configured number of items."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(noise, label)`` where ``noise`` has the length stored for ``idx``."""
        return torch.randn(self.lengths[idx]), self.labels[idx]


# Custom collate function
def collate_fn(batch, views=(1, 2, 3, 4), sample_rate=16000):
    """
    Collate ``(waveform, label)`` samples into one fixed-length batch per view.

    For every view ``v`` each waveform is passed through ``pad`` with
    ``max_len = v * sample_rate``, and the results are stacked.
    NOTE(review): ``pad`` is imported from the project; this function assumes
    it returns a 1-D array or tensor of exactly ``max_len`` samples, cropping
    longer inputs and zero-padding shorter ones. Confirm against its source.

    :param batch: iterable of ``(x, label)`` pairs as produced by the dataset
    :param views: view durations, multiplied by ``sample_rate`` to get sample
        counts; repeated entries are collapsed, first-seen order is kept
    :param sample_rate: multiplier applied to each view to get ``max_len``
    :return: dict mapping each view to ``(sequences, labels)``, where
        ``sequences`` is the stacked ``pad`` outputs and ``labels`` is a
        ``torch.long`` tensor in batch order
    :raises ValueError: if ``batch`` is empty while ``views`` is not
    """
    # Collapse repeated views: the result is keyed by view, so a repeat would
    # otherwise append every sample twice under the same key.
    view_keys = list(dict.fromkeys(views))

    samples = {view: [] for view in view_keys}
    targets = []

    # Process each sample in the batch
    for x, label in batch:
        targets.append(label)
        # Pad each sample for each view
        for view in view_keys:
            x_view = pad(x, padding_type='zero', max_len=view * sample_rate, random_start=True)
            # pad may hand back a numpy array; stacking needs tensors.
            if not torch.is_tensor(x_view):
                x_view = torch.from_numpy(x_view)
            samples[view].append(x_view)

    # Convert lists to tensors
    view_batches = {}
    for view in view_keys:
        if not samples[view]:
            raise ValueError("collate_fn received an empty batch; nothing to stack")
        view_batches[view] = (
            torch.stack(samples[view]),
            torch.tensor(targets, dtype=torch.long),
        )

    return view_batches

# Create the dataset and dataloader
# NOTE: these are machine-specific absolute paths; edit them for your environment.
protocols_path = "/data/hungdx/Datasets/protocols/database/"
database_path = "/home/hungdx/code/Lightning-hydra/data/Datasets/"

SAMPLE_RATE = 16000          # matches the rate Dataset_ASVspoof2019_train loads at
VIEW_SECONDS = [1, 2, 3, 4]  # one fixed-length batch is produced per entry

train_protocol = os.path.join(protocols_path, 'ASVspoof_LA_cm_protocols', 'ASVspoof2019.LA.cm.train.trn.txt')
d_label_trn, file_train = genSpoof_list(dir_meta=train_protocol, is_train=True, is_eval=False)

# The trailing separator is kept so base_dir also works with plain string concatenation.
data_train = Dataset_ASVspoof2019_train(
    list_IDs=file_train,
    labels=d_label_trn,
    base_dir=os.path.join(database_path, 'ASVspoof2019_LA_train/'),
)
# To smoke-test without the corpus, pass SyntheticAudioDataset() to the DataLoader instead.
dataloader = DataLoader(
    data_train,
    batch_size=10,
    collate_fn=lambda x: collate_fn(x, views=VIEW_SECONDS, sample_rate=SAMPLE_RATE),
    shuffle=True,
)

# Inspect only the first batch for brevity
batch = next(iter(dataloader))
print("Batch 1")
for view, (data, labels) in batch.items():
    print(f"  View Duration: {view}s - Batch Size: {data.size(0)}, Sequence Length: {data.size(1)}")
    # Play a random audio file from the batch
    random_idx = random.randint(0, data.size(0) - 1)
    # Show the label
    print(f"    Label: {labels[random_idx]}")
    display(Audio(data[random_idx].numpy(), rate=SAMPLE_RATE))



Batch 1
  View Duration: 1s - Batch Size: 10, Sequence Length: 16000
    Label: 0


  View Duration: 2s - Batch Size: 10, Sequence Length: 32000
    Label: 0


  View Duration: 3s - Batch Size: 10, Sequence Length: 48000
    Label: 1


  View Duration: 4s - Batch Size: 10, Sequence Length: 64000
    Label: 1
