In [86]:
!pip -qqq install transformers datasets nnAudio

In [87]:
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio.transforms as T
from datasets import Dataset, Audio, concatenate_datasets, Split
import os

In [88]:
# mount drive and set path to dataset
from google.colab import drive
drive.mount('/content/drive')
data_dir = "/content/drive/Shareddrives/DeepLearningProject/minibabyslakh"
# make sure 
os.listdir(data_dir)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


['train', 'test']

In [89]:
# loading our model weights
model = AutoModel.from_pretrained("m-a-p/MERT-v0", trust_remote_code=True)
# loading the corresponding preprocessor config
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0",trust_remote_code=True)

In [90]:
# # load demo audio and set processor
# dataset = Dataset.load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
# dataset = dataset.sort("id")
# sampling_rate = dataset.features["audio"].sampling_rate

In [91]:
# Function to load the audio files from the directory structure
def get_data_files(directory):
    bass_files = []
    residual_files = []
    tracks = []
    for track_dir in os.listdir(directory):
        track_path = os.path.join(directory, track_dir)
        if os.path.isdir(track_path):
            bass_audio_dir = os.path.join(track_path, 'bass')
            # bass_file = os.path.join(bass_audio_dir, 'bass.wav')
            # residual_file = os.path.join(bass_audio_dir, 'residuals.wav')
            if os.path.isdir(bass_audio_dir):
                for file in os.listdir(bass_audio_dir):
                    if file.startswith('bass') and file.endswith('.wav'):
                        bass_file = os.path.join(bass_audio_dir, file)
                        bass_files.append(bass_file)
                        residual_file = os.path.join(bass_audio_dir, 'residuals' + file[4:])
                        residual_files.append(residual_file)
                        tracks.append(track_dir)
        
    return {"bass": bass_files, "residuals": residual_files, "track": tracks}

In [92]:
# Get the audio filenames from the dataset directory
train_files = get_data_files(os.path.join(data_dir, "train"))
test_files = get_data_files(os.path.join(data_dir, "test"))
# validation_data = load_audio_files(os.path.join(data_dir, "validation"))
train_files

{'bass': ['/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00002/bass/bass.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00001/bass/bass.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00003/bass/bass.wav'],
 'residuals': ['/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00002/bass/residuals.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00001/bass/residuals.wav',
  '/content/drive/Shareddrives/DeepLearningProject/minibabyslakh/train/Track00003/bass/residuals.wav'],
 'track': ['Track00002', 'Track00001', 'Track00003']}

In [93]:
# Create the dataset objects
train_dataset = Dataset.from_dict(train_files, split="train") \
                    .cast_column("bass", Audio()) \
                    .cast_column("residuals", Audio()) \
                    .sort("track")
test_dataset = Dataset.from_dict(test_files, split="test") \
                    .cast_column("bass", Audio()) \
                    .cast_column("residuals", Audio()) \
                    .sort("track")
combined_dataset = concatenate_datasets([train_dataset, test_dataset])

train_dataset

Dataset({
    features: ['bass', 'residuals', 'track'],
    num_rows: 3
})

In [94]:
sampling_rate = train_dataset["residuals"][0]['sampling_rate']
resample_rate = processor.sampling_rate
# make sure the sample_rate aligned
if resample_rate != sampling_rate:
    print(f'setting rate from {sampling_rate} to {resample_rate}')
    resampler = T.Resample(sampling_rate, resample_rate)
else:
    resampler = None

In [95]:
# audio file is decoded on the fly
if resampler is None:
    src_audio = train_dataset[0]["residuals"]["array"]
    tgt_audio = train_dataset[0]["bass"]["array"]
else:
  src_audio = resampler(torch.from_numpy(train_dataset[0]["residuals"]["array"]))
  tgt_audio = resampler(torch.from_numpy(train_dataset[0]["bass"]["array"]))

In [96]:
# The whole audio file is too big to run in colab
src_audio = src_audio[0:93680]
tgt_audio = tgt_audio[0:93680]

In [97]:
src_inputs = processor(src_audio, sampling_rate=resample_rate, return_tensors="pt")
tgt_inputs = processor(tgt_audio, sampling_rate=resample_rate, return_tensors="pt")
with torch.no_grad():
    src_outputs = model(**src_inputs, output_hidden_states=True)
    tgt_outputs = model(**tgt_inputs, output_hidden_states=True)

In [98]:
# take a look at the output shape, there are 13 layers of representation
# each layer performs differently in different downstream tasks, you should choose empirically
src_hidden_states = torch.stack(src_outputs.hidden_states).squeeze()
tgt_hidden_states = torch.stack(tgt_outputs.hidden_states).squeeze()

print("src_hidden_states.shape: " + str(src_hidden_states.shape))
print("tgt_hidden_states.shape: " + str(tgt_hidden_states.shape))

src_hidden_states.shape: torch.Size([13, 292, 768])
tgt_hidden_states.shape: torch.Size([13, 292, 768])


#Define Model

In [99]:
seq_length = src_hidden_states.size(dim=1)
d_model = src_hidden_states.size(dim=2)

transformer_decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=8)
transformer_decoder = nn.TransformerDecoder(transformer_decoder_layer, num_layers=6)

print(transformer_decoder)

TransformerDecoder(
  (layers): ModuleList(
    (0-5): 6 x TransformerDecoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (multihead_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (linear1): Linear(in_features=768, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=768, bias=True)
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
      (dropout3): Dropout(p=0.1, inplace=False)
    )
  )
)


# Train Model

In [100]:
# TODO: Just taking the last layer for fine grain acoustic tokens, need to select a layer to use for semantic tokens as well and find out how to concat
src_fine_grain_tokens = src_hidden_states[-1]
tgt_fine_grain_tokens = tgt_hidden_states[-1]

# TODO: What is the appropriate starting token?
previously_generated_token = torch.zeros([1, d_model])

generated_tokens = []

for index, tgt_fine_grain_token in enumerate(tgt_fine_grain_tokens):
  print("Generating token: " + str(index))

  generated_token = transformer_decoder(tgt = previously_generated_token, memory = fine_grain_tokens)
  #TODO: Calculate loss between generated_token and tgt_fine_grain_token and then backpropogate once per loop?

  generated_tokens.append(generated_token)
  previously_generated_token = generated_token


Generating token: 0
Generating token: 1
Generating token: 2
Generating token: 3
Generating token: 4
Generating token: 5
Generating token: 6
Generating token: 7
Generating token: 8
Generating token: 9
Generating token: 10
Generating token: 11
Generating token: 12
Generating token: 13
Generating token: 14
Generating token: 15
Generating token: 16
Generating token: 17
Generating token: 18
Generating token: 19
Generating token: 20
Generating token: 21
Generating token: 22
Generating token: 23
Generating token: 24
Generating token: 25
Generating token: 26
Generating token: 27
Generating token: 28
Generating token: 29
Generating token: 30
Generating token: 31
Generating token: 32
Generating token: 33
Generating token: 34
Generating token: 35
Generating token: 36
Generating token: 37
Generating token: 38
Generating token: 39
Generating token: 40
Generating token: 41
Generating token: 42
Generating token: 43
Generating token: 44
Generating token: 45
Generating token: 46
Generating token: 47
Ge