In [1]:
!pip -qqq install transformers datasets nnAudio

In [2]:
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio.transforms as T
from datasets import Dataset, Audio, concatenate_datasets, Split
import os

In [3]:
# mount drive and set path to dataset
from google.colab import drive
drive.mount('/content/drive')
data_dir = "/content/drive/Shareddrives/DeepLearningProject/minibabyslakh"
# make sure 
os.listdir(data_dir)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


['train', 'test']

In [4]:
# loading our model weights
model = AutoModel.from_pretrained("m-a-p/MERT-v0", trust_remote_code=True)
# loading the corresponding preprocessor config
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0",trust_remote_code=True)

In [5]:
# # load demo audio and set processor
# dataset = Dataset.load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
# dataset = dataset.sort("id")
# sampling_rate = dataset.features["audio"].sampling_rate

In [6]:
# Function to load the audio files from the directory structure
def get_data_files(directory):
    bass_files = []
    residual_files = []
    tracks = []
    for track_dir in os.listdir(directory):
        track_path = os.path.join(directory, track_dir)
        if os.path.isdir(track_path):
            bass_audio_dir = os.path.join(track_path, 'bass')
            # bass_file = os.path.join(bass_audio_dir, 'bass.wav')
            # residual_file = os.path.join(bass_audio_dir, 'residuals.wav')
            if os.path.isdir(bass_audio_dir):
                for file in os.listdir(bass_audio_dir):
                    if file.startswith('bass') and file.endswith('.wav'):
                        bass_file = os.path.join(bass_audio_dir, file)
                        bass_files.append(bass_file)
                        residual_file = os.path.join(bass_audio_dir, 'residuals' + file[4:])
                        residual_files.append(residual_file)
                        tracks.append(track_dir)
        
    return {"bass": bass_files, "residuals": residual_files, "track": tracks}

In [7]:
# Get the audio filenames from the dataset directory
train_files = get_data_files(os.path.join(data_dir, "train"))
test_files = get_data_files(os.path.join(data_dir, "test"))
# validation_data = load_audio_files(os.path.join(data_dir, "validation"))
train_files

{'bass': ['/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00002/bass/bass.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00001/bass/bass.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00003/bass/bass.wav'],
 'residuals': ['/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00002/bass/residuals.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00001/bass/residuals.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00003/bass/residuals.wav'],
 'track': ['Track00002', 'Track00001', 'Track00003']}

In [8]:
# Create the dataset objects
train_dataset = Dataset.from_dict(train_files, split="train") \
                    .cast_column("bass", Audio()) \
                    .cast_column("residuals", Audio()) \
                    .sort("track")
test_dataset = Dataset.from_dict(test_files, split="test") \
                    .cast_column("bass", Audio()) \
                    .cast_column("residuals", Audio()) \
                    .sort("track")
combined_dataset = concatenate_datasets([train_dataset, test_dataset])

train_dataset

Dataset({
    features: ['bass', 'residuals', 'track'],
    num_rows: 3
})

In [9]:
sampling_rate = train_dataset["residuals"][0]['sampling_rate']
resample_rate = processor.sampling_rate
# make sure the sample_rate aligned
if resample_rate != sampling_rate:
    print(f'setting rate from {sampling_rate} to {resample_rate}')
    resampler = T.Resample(sampling_rate, resample_rate)
else:
    resampler = None

In [10]:
# audio file is decoded on the fly
if resampler is None:
    src_audio = train_dataset[0]["residuals"]["array"]
    tgt_audio = train_dataset[0]["bass"]["array"]
else:
  src_audio = resampler(torch.from_numpy(train_dataset[0]["residuals"]["array"]))
  tgt_audio = resampler(torch.from_numpy(train_dataset[0]["bass"]["array"]))

In [11]:
# The whole audio file is too big to run in colab
src_audio = src_audio[0:50000]
tgt_audio = tgt_audio[0:50000]

In [12]:
src_inputs = processor(src_audio, sampling_rate=resample_rate, return_tensors="pt")
tgt_inputs = processor(tgt_audio, sampling_rate=resample_rate, return_tensors="pt")
with torch.no_grad():
    src_outputs = model(**src_inputs, output_hidden_states=True)
    tgt_outputs = model(**tgt_inputs, output_hidden_states=True)

In [13]:
# take a look at the output shape, there are 13 layers of representation
# each layer performs differently in different downstream tasks, you should choose empirically
src_hidden_states = torch.stack(src_outputs.hidden_states).squeeze()
tgt_hidden_states = torch.stack(tgt_outputs.hidden_states).squeeze()

print("src_hidden_states.shape: " + str(src_hidden_states.shape))
print("tgt_hidden_states.shape: " + str(tgt_hidden_states.shape))

src_hidden_states.shape: torch.Size([13, 156, 768])
tgt_hidden_states.shape: torch.Size([13, 156, 768])


#Define Model

In [14]:
seq_length = src_hidden_states.size(dim = 1)
d_model = src_hidden_states.size(dim = 2)

#TODO: Change the hyperparameters and consider custom implementations once we get something working
transformer_decoder_layer = nn.TransformerDecoderLayer(d_model = 2 * d_model, nhead = 8)
transformer_decoder = nn.TransformerDecoder(transformer_decoder_layer, num_layers = 6)

print(transformer_decoder)

TransformerDecoder(
  (layers): ModuleList(
    (0-5): 6 x TransformerDecoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1536, out_features=1536, bias=True)
      )
      (multihead_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1536, out_features=1536, bias=True)
      )
      (linear1): Linear(in_features=1536, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=1536, bias=True)
      (norm1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
      (norm3): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
      (dropout3): Dropout(p=0.1, inplace=False)
    )
  )
)


# Train Model

In [15]:
# TODO: Just taking the last layer for fine grain acoustic tokens and middle layer for semantic tokens, perhaps this should be tuned?
src_fine_grain_tokens = src_hidden_states[-1]
src_semantic_tokens = src_hidden_states[int(src_hidden_states.size(dim = 0) / 2)]

# TODO: Is just concatenating the fine grain tokens and semantic tokens acceptable?
# Ideally the dimension here needs to match with the eventual encoded dimensions
src_hybrid_tokens = torch.cat((src_fine_grain_tokens, src_semantic_tokens), dim=1)

# TODO: This should come from encodec, just putting as an example here
tgt_fine_grain_tokens = tgt_hidden_states[-1]
tgt_semantic_tokens = tgt_hidden_states[int(tgt_hidden_states.size(dim = 0) / 2)]
tgt_hybrid_tokens = torch.cat((tgt_fine_grain_tokens, tgt_semantic_tokens), dim=1)

# TODO: What is the appropriate starting token?
previously_generated_token = torch.zeros([1, 2 * d_model])

generated_tokens = []

for index, tgt_hybrid_token in enumerate(tgt_hybrid_tokens):
  print("Generating token: " + str(index) + "/" + str(seq_length))

  generated_token = transformer_decoder(tgt = previously_generated_token, memory = src_hybrid_tokens)
  #TODO: Calculate loss between generated_token and tgt_fine_grain_token and then backpropogate once per loop?

  generated_tokens.append(generated_token)
  previously_generated_token = generated_token

print("Generated tokens: " + str(generated_tokens))


Generating token: 0/156
Generating token: 1/156
Generating token: 2/156
Generating token: 3/156
Generating token: 4/156
Generating token: 5/156
Generating token: 6/156
Generating token: 7/156
Generating token: 8/156
Generating token: 9/156
Generating token: 10/156
Generating token: 11/156
Generating token: 12/156
Generating token: 13/156
Generating token: 14/156
Generating token: 15/156
Generating token: 16/156
Generating token: 17/156
Generating token: 18/156
Generating token: 19/156
Generating token: 20/156
Generating token: 21/156
Generating token: 22/156
Generating token: 23/156
Generating token: 24/156
Generating token: 25/156
Generating token: 26/156
Generating token: 27/156
Generating token: 28/156
Generating token: 29/156
Generating token: 30/156
Generating token: 31/156
Generating token: 32/156
Generating token: 33/156
Generating token: 34/156
Generating token: 35/156
Generating token: 36/156
Generating token: 37/156
Generating token: 38/156
Generating token: 39/156
Generating