In [33]:
!pip -qqq install datasets audiolm-pytorch

In [34]:
import torch
from torch import nn
from audiolm_pytorch.encodec import EncodecWrapper
import torchaudio.transforms as T
from datasets import Dataset, Audio, concatenate_datasets, Split
import os
from scipy.io.wavfile import write

In [35]:
# mount drive and set path to dataset
from google.colab import drive
drive.mount('/content/drive')
data_dir = "/content/drive/Shareddrives/DeepLearningProject/minibabyslakh"
# make sure 
os.listdir(data_dir)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


['train', 'test']

In [36]:
# Use "encodec = EncodecWrapper().cuda()" when a GPU is available
encodec = EncodecWrapper()

In [37]:
# Function to load the audio files from the directory structure
def get_data_files(directory):
    bass_files = []
    residual_files = []
    tracks = []
    for track_dir in os.listdir(directory):
        track_path = os.path.join(directory, track_dir)
        if os.path.isdir(track_path):
            bass_audio_dir = os.path.join(track_path, 'bass')
            # bass_file = os.path.join(bass_audio_dir, 'bass.wav')
            # residual_file = os.path.join(bass_audio_dir, 'residuals.wav')
            if os.path.isdir(bass_audio_dir):
                for file in os.listdir(bass_audio_dir):
                    if file.startswith('bass') and file.endswith('.wav'):
                        bass_file = os.path.join(bass_audio_dir, file)
                        bass_files.append(bass_file)
                        residual_file = os.path.join(bass_audio_dir, 'residuals' + file[4:])
                        residual_files.append(residual_file)
                        tracks.append(track_dir)
        
    return {"bass": bass_files, "residuals": residual_files, "track": tracks}

In [38]:
# Get the audio filenames from the dataset directory
train_files = get_data_files(os.path.join(data_dir, "train"))
test_files = get_data_files(os.path.join(data_dir, "test"))
# validation_data = load_audio_files(os.path.join(data_dir, "validation"))
train_files

{'bass': ['/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00002/bass/bass.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00001/bass/bass.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00003/bass/bass.wav'],
 'residuals': ['/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00002/bass/residuals.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00001/bass/residuals.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00003/bass/residuals.wav'],
 'track': ['Track00002', 'Track00001', 'Track00003']}

In [39]:
# Create the dataset objects
train_dataset = Dataset.from_dict(train_files, split="train") \
                    .cast_column("bass", Audio()) \
                    .cast_column("residuals", Audio()) \
                    .sort("track")
test_dataset = Dataset.from_dict(test_files, split="test") \
                    .cast_column("bass", Audio()) \
                    .cast_column("residuals", Audio()) \
                    .sort("track")
combined_dataset = concatenate_datasets([train_dataset, test_dataset])

train_dataset

Dataset({
    features: ['bass', 'residuals', 'track'],
    num_rows: 3
})

In [40]:
src_audio = train_dataset[0]["residuals"]["array"]
tgt_audio = train_dataset[0]["bass"]["array"]

# The whole audio file is too big to run in colab
src_audio = src_audio[0:50000]
tgt_audio = tgt_audio[0:50000]

sampling_rate = train_dataset["residuals"][0]['sampling_rate']

write("test_pre_encoding_src.wav", sampling_rate, src_audio)
write("test_pre_encoding_tgt.wav", sampling_rate, tgt_audio)

# Required for encodec
src_audio = torch.from_numpy(src_audio).float()
tgt_audio = torch.from_numpy(tgt_audio).float()

In [41]:
with torch.no_grad():
    src_encodec_tokens = encodec(src_audio, return_encoded = False)
    tgt_encodec_tokens = encodec(tgt_audio, return_encoded = False)

In [42]:
# TODO: Figure out sequence number and number of dimensions
print(src_encodec_tokens[1].shape)
print(tgt_encodec_tokens[1].shape)

torch.Size([157, 8])
torch.Size([157, 8])


In [43]:
src_encodec_tokens = src_encodec_tokens[1].unsqueeze(dim=0)
tgt_encodec_tokens = tgt_encodec_tokens[1].unsqueeze(dim=0)
print(src_encodec_tokens.shape)
print(tgt_encodec_tokens.shape)

torch.Size([1, 157, 8])
torch.Size([1, 157, 8])


In [48]:
test_post_encoding_src = encodec.decode_from_codebook_indices(src_encodec_tokens)
test_post_encoding_tgt = encodec.decode_from_codebook_indices(tgt_encodec_tokens)

In [50]:
write("test_post_encoding_src.wav", sampling_rate, test_post_encoding_src.detach().numpy())
write("test_post_encoding_tgt.wav", sampling_rate, test_post_encoding_tgt.detach().numpy())