# DATA IMPORT

Data import and processing are crucial steps in both training and using machine learning models. Hence, `vschaos` provides high-level methods for data loading and processing, covering both general data and audio-specific operations (transforms, inversion, time conditioning).

## Dataset import in `vschaos` 

Data import is handled by the `Dataset` object, which implements most data and metadata operations.
Datasets in `vschaos` follow a specific architecture, which can be split across different folders. This
architecture is based on three separate folders, usually contained in a single root directory:
* *data/*, containing all the raw data (as wav or mp3 for audio) 
* *analysis/*, containing different transforms with various transformation parameters
* *metadata/*, containing the metadata of the dataset.
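
The folder layout above can be sketched with the standard library alone. This is only an illustration of the directory structure: `make_dataset_skeleton` is a hypothetical helper, not part of `vschaos`, and the layout is created in a temporary directory.

```python
import tempfile
from pathlib import Path

def make_dataset_skeleton(root):
    """Create data/, analysis/ and metadata/ under root (hypothetical helper)."""
    root = Path(root)
    for name in ("data", "analysis", "metadata"):
        (root / name).mkdir(parents=True, exist_ok=True)
    # return the sub-folder names found under the root directory
    return sorted(p.name for p in root.iterdir() if p.is_dir())

with tempfile.TemporaryDirectory() as tmp:
    layout = make_dataset_skeleton(Path(tmp) / "dataset")
    print(layout)
```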

The `Dataset` object has to be initialised using a dictionary, specifying the properties of the import. 

In [None]:
from vschaos.data import Dataset
# The ./dataset folder contains a dummy dataset
data_prefix = "dataset"
# Creating object 
# if data / metadata /analysis are not contained within a single location, 
#   locations can be specified with data_directory / metadata_directory / analysis_directory 
dataset = Dataset(data_prefix)
print(f"{dataset}, tasks : {dataset.tasks}")
print('files : ', dataset.files)

The generic `Dataset` class does not have an import function by default. A callback function has to be specified to import data with the `Dataset.import_data` method.

In [None]:
from vschaos.data import Dataset
import torchaudio

def audio_callback(self, file, **kwargs):
    return torchaudio.load(file)[0], None

data_prefix = "dataset"
dataset = Dataset(data_prefix, import_callback=audio_callback)
dataset.import_data()
data, metadata = dataset[:]
print("data : ", data.shape)
print("metadata : ", metadata)
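
As a dependency-free illustration of the callback pattern, the sketch below reads a 16-bit mono WAV file with the standard `wave` module and returns a `(data, metadata)` pair, mirroring the return value of the callbacks above. The name `wav_callback` and the `sample_rate` metadata key are assumptions made for this example; the first argument stands in for the dataset instance and is unused here.

```python
import math
import os
import struct
import tempfile
import wave

def wav_callback(dataset, file, **kwargs):
    """Hypothetical callback: read a 16-bit mono WAV file, return (samples, metadata)."""
    with wave.open(file, "rb") as f:
        sr = f.getframerate()
        n = f.getnframes()
        raw = f.readframes(n)
    # convert little-endian 16-bit integers to floats in [-1, 1)
    samples = [s / 32768.0 for s in struct.unpack(f"<{n}h", raw)]
    return samples, {"sample_rate": sr}

with tempfile.TemporaryDirectory() as tmp_dir:
    # write a short 440 Hz test tone so that the example is self-contained
    path = os.path.join(tmp_dir, "tone.wav")
    sr = 8000
    tone = [int(32767 * 0.5 * math.sin(2 * math.pi * 440 * t / sr)) for t in range(800)]
    with wave.open(path, "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)
        f.setframerate(sr)
        f.writeframes(struct.pack(f"<{len(tone)}h", *tone))

    data, metadata = wav_callback(None, path)
    print(len(data), metadata)
```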

Be careful: the `Dataset` object is able to deal with data of variable length, as it does here with audio files of different durations. However, depending on the final size, the collating operation can be expensive. It is thus strongly advised to fix the data size (see below).

In [None]:
for d in dataset.data:
    print(d.shape)
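
One simple way of fixing the data size is to crop or zero-pad every item to a common length before collating. The sketch below shows the idea on plain Python lists; `fix_length` is a hypothetical helper, not a `vschaos` function.

```python
def fix_length(sequence, length, pad_value=0.0):
    """Crop or right-pad a 1-D sequence to exactly `length` items (hypothetical helper)."""
    cropped = list(sequence[:length])
    return cropped + [pad_value] * (length - len(cropped))

# three clips of different durations
clips = [[0.1] * 5, [0.2] * 3, [0.3] * 8]
fixed = [fix_length(c, 4) for c in clips]
print([len(c) for c in fixed])
```

Once every item has the same length, batches can be stacked directly without any per-batch padding logic.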

Raw inputs generally have to be adapted to a model's input in order to be suitable for learning. This is done with `Transform` objects, contained in the `data_transforms.py` module. A sequence of transformations can be chained using `ComposeTransform`. This can be used for cropping, (un)squeezing, or transforming the data in several ways before loading.

In [None]:
from vschaos.data import Dataset
from vschaos.data.data_transforms import Squeeze, Sequence, ComposeTransform
import torchaudio

def audio_callback(self, file, **kwargs):
    return torchaudio.load(file)[0], {}

data_prefix = "dataset"
dataset = Dataset(data_prefix, import_callback=audio_callback)
dataset.import_data()

# The Sequence transform takes a sub-sequence of the input data along a given axis
transforms = ComposeTransform([Squeeze(0), Sequence(512, dim=0, random_start=True)])

x, y = dataset[0]
out = transforms(x)
x_inv = transforms.invert(out)

print(x.shape, '->', out.shape, '->', x_inv.shape)

# Transforms can be directly embedded in the Dataset object, such that they are applied when using the __getitem__ method
dataset = Dataset(data_prefix, import_callback=audio_callback, transforms=transforms)
dataset.import_data(scale=False) # Here we do not scale the transforms (further information below)

import matplotlib.pyplot as plt
fig, ax = plt.subplots(4,4)
for i in range(4):
    for j in range(4):
        ax[i, j].plot(dataset[0][0].numpy())
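
The compose-and-invert pattern used above can be sketched independently of `vschaos`: the forward pass applies the transforms in order, and the inversion applies each inverse in reverse order. The classes `Scale`, `Offset` and `Compose` below are hypothetical stand-ins and do not reproduce the library's implementation.

```python
class Scale:
    """Multiply every value by a constant factor."""
    def __init__(self, factor):
        self.factor = factor
    def __call__(self, x):
        return [v * self.factor for v in x]
    def invert(self, x):
        return [v / self.factor for v in x]

class Offset:
    """Add a constant offset to every value."""
    def __init__(self, offset):
        self.offset = offset
    def __call__(self, x):
        return [v + self.offset for v in x]
    def invert(self, x):
        return [v - self.offset for v in x]

class Compose:
    """Chain transforms; inversion walks the chain backwards."""
    def __init__(self, transforms):
        self.transforms = list(transforms)
    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x
    def invert(self, x):
        for t in reversed(self.transforms):
            x = t.invert(x)
        return x

pipeline = Compose([Scale(2.0), Offset(1.0)])
out = pipeline([1.0, 2.0, 3.0])
restored = pipeline.invert(out)
print(out, restored)
```

Exact reconstruction is only possible when every step is lossless; a cropping step, for instance, cannot restore the samples it discarded.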



Some transforms, such as normalization procedures, first need to be *scaled* to a given dataset. Hence, each embedded transform is scaled on an amount of data that can be specified using the `scale` keyword.

In [None]:
from vschaos.data import Dataset, Sequence, Normalize, ComposeTransform, Unsqueeze
import torchaudio

def audio_callback(self, file, **kwargs):
    return torchaudio.load(file)[0], None

data_prefix = "dataset"
dataset = Dataset(data_prefix, transforms=[Normalize(mode="gaussian", scale="bipolar")], import_callback=audio_callback)
# scale can be False, True (normalize on all the dataset), or an int (randomly picking files)
dataset.import_data(scale=2)

x, y = dataset[:]
print(x.min(), x.max(), x.mean(), x.std())
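
To make the idea of scaling concrete, the sketch below estimates the statistics of a Gaussian normalization on a dataset (or on a random subset of its items) before applying it. `GaussianNormalize` and its `n_items` and `seed` arguments are hypothetical and only mimic the behaviour described above.

```python
import random
import statistics

class GaussianNormalize:
    """Hypothetical normalization that must be scaled on data before use."""
    def __init__(self):
        self.mean, self.std = 0.0, 1.0

    def scale(self, items, n_items=None, seed=0):
        # optionally estimate the statistics on a random subset of the items
        if n_items is not None:
            items = random.Random(seed).sample(items, n_items)
        values = [v for item in items for v in item]
        self.mean = statistics.fmean(values)
        self.std = statistics.pstdev(values) or 1.0  # avoid dividing by zero

    def __call__(self, x):
        return [(v - self.mean) / self.std for v in x]

    def invert(self, x):
        return [v * self.std + self.mean for v in x]

items = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0, 7.0, 8.0, 9.0]]
norm = GaussianNormalize()
norm.scale(items)  # use every item, similar in spirit to scale=True
flat = [v for item in items for v in norm(item)]
print(round(statistics.fmean(flat), 6), round(statistics.pstdev(flat), 6))
```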


The `Dataset` object can be directly indexed to obtain a tuple `(x, y)` containing the data and the metadata of the target slice. Sub-datasets can be obtained with the `retrieve` method, which can be called either with a set of indices or with a partition name.

In [None]:
from vschaos.data import Dataset
from vschaos.data.data_transforms import Binary
from vschaos.utils.dataloader import DataLoader
import torchaudio

def audio_callback(self, file, **kwargs):
    return torchaudio.load(file)[0][0, :2000], None

data_prefix = "dataset"
dataset = Dataset(data_prefix, transforms=[Binary()], import_callback=audio_callback)
dataset.import_data(scale=True)
dataset.apply_transforms() # the apply_transforms method overwrites the data and flushes the Dataset's transforms

dataset.construct_partition(['train', 'test'], [0.8, 0.2])
print(dataset.partitions['train'])


train_dataset = dataset.retrieve('train')
# we have to set the loaded metadata with the `drop_tasks` method
data_loader = DataLoader(train_dataset, 64, tasks=dataset.tasks)
x, y = next(iter(data_loader))
print(x.shape, y.keys())
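
A random partition by ratios can be sketched as follows. `make_partitions` is a hypothetical helper that only illustrates the idea of splitting indices into named subsets; it is not the implementation of `construct_partition`.

```python
import random

def make_partitions(n_items, names, ratios, seed=0):
    """Split the indices 0..n_items-1 into named random subsets (hypothetical helper)."""
    ids = list(range(n_items))
    random.Random(seed).shuffle(ids)
    partitions, start = {}, 0
    for k, (name, ratio) in enumerate(zip(names, ratios)):
        # the last partition takes the remainder so that every index is used
        stop = n_items if k == len(names) - 1 else start + round(ratio * n_items)
        partitions[name] = sorted(ids[start:stop])
        start = stop
    return partitions

partitions = make_partitions(10, ["train", "test"], [0.8, 0.2])
print(partitions)
```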


### Saving / loading transforms

Transforms can also be saved in the dataset analysis path using the `Dataset.write_transform` method. 

In [None]:
import os, torch, torchaudio
from vschaos.data import Dataset, DatasetAudio
from vschaos.data.data_transforms import Mono, STFT, Squeeze
from vschaos.utils.dataloader import DataLoader

def audio_callback(self, file, **kwargs):
    return torchaudio.load(file)[0][:, :2000], {}

data_prefix = "dataset"
dataset = Dataset(data_prefix, transforms=[Mono(), Squeeze(0)], import_callback=audio_callback)
dataset.import_data()
# if the transforms keyword is None, write_transform takes the dataset's transforms
dataset.write_transform('stft-1024', transforms=[STFT(1024)], scale=True)
print("content of analysis dir : ", os.listdir(f'{data_prefix}/analysis'))


Transforms are saved as `numpy.memmap`, hence allowing asynchronous import using the `OfflineDataList` object. If the `offline` keyword is `True`, `Dataset.data` is then a collection of callbacks called dynamically by `__getitem__`. As with `numpy.memmap` arrays, selectors can be specified in order to load only a part of each file, allowing lightweight import when training on large files.

In [None]:
from vschaos.data import Dataset, SequencePick

dataset = Dataset(data_prefix)
dataset, transforms = dataset.load_transform('stft-1024', offline=False)
print('-- regular import')
print("type of dataset.data : ", type(dataset.data))
print(dataset[0][0].shape)

dataset = Dataset(data_prefix)
dataset, transforms = dataset.load_transform('stft-1024', offline=True)
print('-- asynchronous import')
print("type of dataset.data : ", type(dataset.data))
print(dataset[0][0].shape)

# keep the returned dataset, as in the two calls above
dataset, transforms = dataset.load_transform('stft-1024', offline=True,
                                             selector=SequencePick,
                                             selector_args={'axis':0, 'sequence_length':60, 'random_idx':False})
print('-- asynchronous import with selector')
print("type of dataset.data : ", type(dataset.data))

import matplotlib.pyplot as plt
fig, ax = plt.subplots(2,2)
for i in range(2):
    for j in range(2):
        ax[i, j].imshow(dataset[2*i+j][0].abs().numpy(), aspect="auto")
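
The principle of offline loading with a selector can be sketched with the standard library: each entry only stores where its data lives, and the file is read (possibly partially) when the entry is indexed. `LazyEntry` and `LazyList` are hypothetical stand-ins for the idea, using raw `float32` files instead of `numpy.memmap`; they are not the `OfflineDataList` implementation.

```python
import array
import os
import tempfile

class LazyEntry:
    """Hypothetical offline entry: reads float32 values from disk on demand."""
    def __init__(self, path, start=0, length=None):
        self.path, self.start, self.length = path, start, length

    def load(self):
        item_size = array.array("f").itemsize
        total = os.path.getsize(self.path) // item_size
        available = total - self.start
        count = available if self.length is None else min(self.length, available)
        values = array.array("f")
        with open(self.path, "rb") as f:
            f.seek(self.start * item_size)  # skip the part that is not selected
            values.fromfile(f, count)
        return values.tolist()

class LazyList:
    """Collection whose items are only loaded when indexed."""
    def __init__(self, entries):
        self.entries = entries
    def __len__(self):
        return len(self.entries)
    def __getitem__(self, i):
        return self.entries[i].load()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "frames.bin")
    with open(path, "wb") as f:
        array.array("f", [float(i) for i in range(100)]).tofile(f)
    full = LazyList([LazyEntry(path)])                      # whole file
    part = LazyList([LazyEntry(path, start=10, length=5)])  # selected slice only
    full_item, part_item = full[0], part[0]

print(len(full_item), part_item)
```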



## Audio import

While audio files were taken as an example for the generic `Dataset` class, the `DatasetAudio` class provides additional features for sound import, notably specific import callbacks and additional temporal features.

In [None]:
import os, torch, torchaudio

from vschaos.data import DatasetAudio, Normalize, SequencePick, ComposeAudioTransform
from vschaos.data.data_transforms import Mono, STFT, Squeeze, Polar
from vschaos.utils.dataloader import DataLoader

data_prefix = "dataset"
audioSet = DatasetAudio(data_prefix, drop_time="both") # drop_time can be "data", "meta", or "both"
audioSet.import_data(options={'resampleTo':22050}) 
# imported data corresponds to the raw signal
magnitude_args = {'constrast':'log1p', 'normalize':{'mode':'minmax', 'scale':'unipolar'}}
phase_args = {'unwrap':True, 'normalize':{'mode':'gaussian', 'scale':'bipolar'}}
audioSet.transforms = ComposeAudioTransform([Mono(), Squeeze(0), STFT(2048),  Polar(mag_options=magnitude_args, phase_options=phase_args)])
audioSet.scale_transforms(True)

data, metadata = audioSet[0]
print("magnitude : ", data[0].shape)
print("phase : ", data[1].shape)
print("time as data : ", data[2].shape)
print("time as metadata : ", metadata['time'][:10])
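
The magnitude / phase split of a complex spectrum can be sketched on a single frame as follows. Reading the `log1p` option as a `log(1 + x)` compression of the magnitude is an assumption based on its name, and `to_polar` / `from_polar` are hypothetical helpers, not the `Polar` transform itself; phase unwrapping and normalization are left out.

```python
import cmath
import math

def to_polar(frame):
    """Split complex bins into log1p-compressed magnitude and phase (hypothetical helper)."""
    magnitude = [math.log1p(abs(z)) for z in frame]
    phase = [cmath.phase(z) for z in frame]
    return magnitude, phase

def from_polar(magnitude, phase):
    """Invert to_polar: undo the log1p compression and rebuild the complex bins."""
    return [cmath.rect(math.expm1(m), p) for m, p in zip(magnitude, phase)]

frame = [1 + 1j, -2 + 0.5j, 0.25 - 3j]
magnitude, phase = to_polar(frame)
restored = from_polar(magnitude, phase)
print(magnitude, phase)
```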

## Toy data generation

`vschaos` also provides a toy dataset generation routine, which grid-samples the parameters of audio generation units defined in the `vschaos.data.toys.synthesis` module.


In [None]:
import torch
import vschaos.data.toys.synthesis as syn

generator = syn.additive_generator
parameters = {'n_partials':torch.arange(1,3),
              'harmonic_decay': torch.Tensor([0.01, 1.0, 2.0]),
              'f0':torch.linspace(10, 500, 5)}

# generate the dataset!
dataset, _ = syn.dataset_from_generator(1.0, 44100, generator, export="toy_additive_test", **parameters)


import matplotlib.pyplot as plt
ids = torch.randperm(len(dataset))[:12].tolist()
fig, ax = plt.subplots(3, 4)
for i in range(3):
    for j in range(4):
        ax[i, j].plot(dataset[ids[4*i+j]][0][2000:3000])  # 4 columns per row, so each of the 12 ids is plotted once
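
Grid sampling amounts to generating one example per combination of parameter values. The sketch below reproduces the idea with `itertools.product` and a naive additive synthesizer; `additive_tone`, its `n_samples` and `sr` arguments, and the exponential decay of the partial amplitudes are assumptions made for illustration and may differ from `syn.additive_generator`.

```python
import itertools
import math

def additive_tone(n_partials, harmonic_decay, f0, n_samples=80, sr=8000):
    """Sum of harmonics of f0 with exponentially decaying amplitudes (hypothetical)."""
    return [
        sum(math.exp(-harmonic_decay * k) * math.sin(2 * math.pi * f0 * (k + 1) * t / sr)
            for k in range(n_partials))
        for t in range(n_samples)
    ]

# same grid as in the cell above, written as plain Python lists
grid = {
    "n_partials": [1, 2],
    "harmonic_decay": [0.01, 1.0, 2.0],
    "f0": [10.0, 132.5, 255.0, 377.5, 500.0],
}

names = list(grid)
toy_data, toy_metadata = [], []
for values in itertools.product(*grid.values()):
    params = dict(zip(names, values))
    toy_data.append(additive_tone(**params))  # one signal per parameter combination
    toy_metadata.append(params)               # the parameters act as metadata

print(len(toy_data), len(toy_data[0]))
```

The size of the resulting dataset is the product of the number of values per parameter, here 2 x 3 x 5 combinations.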