In [None]:
"""
Adapts the asteroid library's LibriMix ``Dataset`` subclass into a custom
``torch.utils.data.Dataset`` that fits our needs.

Original code:
https://github.com/daea69twins/voiceseperation/wiki/Time-Log/_edit
"""

import numpy as np
import pandas as pd
import soundfile as sf
import torch
from torch import hub
from torch.utils.data import Dataset, DataLoader
import random as random
import os
import shutil
import zipfile

# from .wham_dataset import wham_noise_license

# MINI_URL = "https://zenodo.org/record/3871592/files/MiniLibriMix.zip?download=1"


class CustomOverlay(Dataset):
    """Dataset of (mixture, sources) pairs listed in a metadata CSV file.

    Adapted from asteroid's ``LibriMix`` dataset class. Each CSV row must
    provide a ``mixture_path`` column and one ``source_{i}_path`` column per
    source (``i`` starting at 1). Audio files are read in full with
    ``soundfile`` as ``float32``.

    Args:
        csv_dir (str): Directory containing the metadata CSV file(s).
        dataset_name (str): Substring identifying which file in ``csv_dir``
            to load. Directory entries are scanned in sorted order and the
            first one whose name contains this substring is used.
        sample_rate (int, optional): Nominal sample rate of the audio. It is
            only stored on the instance; audio is neither resampled nor
            checked against it.
        n_src (int, optional): Number of sources per mixture. When ``None``
            (default), it is inferred by counting the consecutive
            ``source_1_path``, ``source_2_path``, ... columns of the CSV.
        task (str, optional): Task label reported by :meth:`get_infos`
            (e.g. ``'sep_clean'`` or ``'sep_noisy'``). It does not change what
            :meth:`__getitem__` returns.

    Raises:
        FileNotFoundError: If no entry of ``csv_dir`` contains ``dataset_name``.
        ValueError: If ``n_src`` is not given and the CSV has no
            ``source_1_path`` column.

    References
        [1] "LibriMix: An Open-Source Dataset for Generalizable Speech Separation",
        Cosentino et al. 2020.
    """

    def __init__(
        self,
        csv_dir,
        dataset_name,
        sample_rate=22040,
        n_src=None,
        task="sep_clean",
    ):
        self.csv_dir = csv_dir
        self.dataset_name = dataset_name
        # NOTE(review): 22040 looks like a typo for the common 22050 Hz rate;
        # kept unchanged so existing callers see the same default — confirm.
        self.sample_rate = sample_rate
        self.task = task

        # os.listdir order is arbitrary; sort so the chosen file is
        # deterministic when several names contain `dataset_name`.
        matches = [f for f in sorted(os.listdir(csv_dir)) if dataset_name in f]
        if not matches:
            raise FileNotFoundError(
                f"No metadata file containing {dataset_name!r} found in {csv_dir!r}"
            )
        self.csv_path = os.path.join(self.csv_dir, matches[0])

        self.df = pd.read_csv(self.csv_path)

        if n_src is None:
            n_src = self._count_source_columns()
            if n_src == 0:
                raise ValueError(
                    f"{self.csv_path!r} has no 'source_1_path' column; "
                    "pass n_src explicitly"
                )
        self.n_src = n_src

    def _count_source_columns(self):
        """Return the number of consecutive ``source_{i}_path`` columns (i = 1, 2, ...)."""
        columns = set(self.df.columns)
        count = 0
        while f"source_{count + 1}_path" in columns:
            count += 1
        return count

    def __len__(self):
        """Return the number of rows (examples) in the metadata CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the mixture and its stacked sources for CSV row ``idx``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(mixture, sources)``, both
            ``float32``. For mono files, ``mixture`` has shape ``(n_samples,)``
            and ``sources`` has shape ``(n_src, n_samples)``. Stacking requires
            all sources to have the same length.
        """
        row = self.df.iloc[idx]

        mixture_path = row["mixture_path"]
        # Kept for backward compatibility: exposes the path of the most
        # recently loaded mixture on the instance.
        self.mixture_path = mixture_path

        # Read every source in full (soundfile's default start/stop).
        sources_list = [
            sf.read(row[f"source_{i + 1}_path"], dtype="float32")[0]
            for i in range(self.n_src)
        ]

        mixture, _ = sf.read(mixture_path, dtype="float32")
        mixture = torch.from_numpy(mixture)

        # Stack the sources along a new leading axis; they are not summed.
        sources = torch.from_numpy(np.vstack(sources_list))

        return mixture, sources

    def get_infos(self):
        """Get dataset infos (for publishing models).

        Returns:
            dict, dataset infos with keys `dataset`, `task` and `licenses`.

        Raises:
            NameError: If ``task`` is not ``'sep_clean'`` and
                ``wham_noise_license`` is not defined in this module.
        """
        infos = dict()
        infos["dataset"] = self._dataset_name()
        infos["task"] = self.task
        if self.task == "sep_clean":
            data_license = [librispeech_license]
        else:
            # `wham_noise_license` normally comes from the (currently
            # commented-out) `.wham_dataset` import; fail with a clear message
            # rather than silently omitting a license.
            noise_license = globals().get("wham_noise_license")
            if noise_license is None:
                raise NameError(
                    "wham_noise_license is not defined; restore the "
                    "wham_dataset import to report licenses for task "
                    f"{self.task!r}"
                )
            data_license = [librispeech_license, noise_license]
        infos["licenses"] = data_license
        return infos

    def _dataset_name(self):
        """Return a LibriMix-style name based on the source count (e.g. ``Libri2Mix``)."""
        return f"Libri{self.n_src}Mix"


# License metadata for the LibriSpeech corpus, included in the "licenses"
# list returned by CustomOverlay.get_infos().
librispeech_license = {
    "title": "LibriSpeech ASR corpus",
    "title_link": "http://www.openslr.org/12",
    "author": "Vassil Panayotov",
    "author_link": "https://github.com/vdp",
    "license": "CC BY 4.0",
    "license_link": "https://creativecommons.org/licenses/by/4.0/",
    "non_commercial": False,
}
