In [1]:
import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

  from .autonotebook import tqdm as notebook_tqdm
2023-11-12 16:06:45 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


In [2]:
from datasets import load_dataset
# Load the `train` split of the `google/MusicCaps` dataset.
ds = load_dataset('google/MusicCaps', split='train')
# Echo the dataset summary (feature names and row count).
ds

Dataset({
    features: ['ytid', 'start_s', 'end_s', 'audioset_positive_labels', 'aspect_list', 'caption', 'author_id', 'is_balanced_subset', 'is_audioset_eval'],
    num_rows: 5521
})

In [3]:
import subprocess
import os
from pathlib import Path

def download_clip(
    video_id,
    output_file,
    start_time,
    end_time,
    tmpdir='/tmp/musiccaps',
    num_attempts=5,
    url_base='https://youtube.com/watch?v='
):
    """Download one audio section of a video as WAV by invoking ``yt-dlp``.

    Args:
        video_id: Video identifier; appended verbatim to ``url_base``.
        output_file: Output path handed to ``yt-dlp -o``.
        start_time: Section start, inserted into ``--download-sections "*start-end"``.
        end_time: Section end, inserted into ``--download-sections "*start-end"``.
        tmpdir: Unused; kept so callers that pass it keep working.
        num_attempts: Maximum number of ``yt-dlp`` invocations before giving up.
        url_base: URL prefix the video id is appended to.

    Returns:
        A ``(status, log)`` tuple. When every attempt fails, ``status`` is
        ``False`` and ``log`` is the captured output (bytes) of the last failed
        run. Otherwise ``status`` tells whether ``output_file`` exists and
        ``log`` is the string ``'Downloaded'``.
    """
    # An argument list (no shell) keeps ids/paths containing spaces, quotes or
    # shell metacharacters from being re-interpreted by a shell.
    command = [
        "yt-dlp",
        "--no-warnings",
        "-x",
        "--audio-format", "wav",
        "-f", "bestaudio",
        "-o", str(output_file),
        "--download-sections", f"*{start_time}-{end_time}",
        f"{url_base}{video_id}",
    ]

    attempts = 0
    while True:
        try:
            subprocess.check_output(command, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as err:
            attempts += 1
            # `>=` (not `==`) so num_attempts <= 0 cannot loop forever.
            if attempts >= num_attempts:
                return False, err.output
        except OSError as err:
            # The executable itself could not be started (e.g. yt-dlp is not
            # installed); retrying cannot help, so report the failure at once.
            return False, str(err).encode()
        else:
            # No exception: yt-dlp exited successfully, stop retrying.
            break

    return os.path.exists(output_file), 'Downloaded'

def process(example):
    """Ensure the clip for one dataset row is on disk and record where it is.

    The clip is stored as ``<data_dir>/<ytid>.wav``. If that file already
    exists nothing is downloaded; otherwise ``download_clip`` is called with the
    row's ``ytid``, ``start_s`` and ``end_s``.

    Args:
        example: A dataset row (mapping) with ``ytid``, ``start_s`` and ``end_s``.

    Returns:
        The same mapping with two added keys: ``audio`` (the output path as a
        string) and ``downloaded_status`` (whether the file is available).
    """
    # `data_dir` is a notebook global; wrap it in Path so this also works while
    # it is still the plain string assigned in the configuration cell.
    output_file = str(Path(data_dir) / f"{example['ytid']}.wav")

    if os.path.exists(output_file):
        status = True
    else:
        status, _log = download_clip(
            video_id=example['ytid'],
            output_file=output_file,
            start_time=example['start_s'],
            end_time=example['end_s']
        )

    example['audio'] = output_file
    example['downloaded_status'] = status
    return example

In [4]:
from datasets import Audio
# Number of dataset rows kept by the `ds.select(range(samples_to_load))` cell.
samples_to_load = 100
# Worker processes passed as `num_proc` to `Dataset.map`.
cores = 4
# Rate handed to `Audio(sampling_rate=...)` when the audio column is cast.
# NOTE(review): 7700 Hz looks unusually low for music — confirm it is intended.
sampling_rate = 7700
# Forwarded as `writer_batch_size` to `Dataset.map`.
writer_batch_size = 1000
# Folder for the downloaded clips; a later cell converts this string to a `Path`.
data_dir = "./music_data"

In [5]:
def numpy_to_tensor(example):
    """Replace ``example["audio"]["array"]`` with a float32 torch tensor.

    Args:
        example: Mapping with an ``audio`` entry whose ``array`` value is a
            NumPy array (any array-like accepted by ``torch.as_tensor`` works).

    Returns:
        The same mapping, modified in place.
    """
    # No per-example debug print: this function is mapped over the whole
    # dataset, so printing here floods the cell output.
    example["audio"]["array"] = torch.as_tensor(example["audio"]["array"], dtype=torch.float32)
    return example

In [5]:
# Keep the first `samples_to_load` rows, capped at the dataset size so a
# dataset shorter than that does not raise an index error.
ds = ds.select(range(min(samples_to_load, len(ds))))

# Make sure the download folder exists before any clip is written.
data_dir = Path(data_dir)
data_dir.mkdir(exist_ok=True, parents=True)

ds = (
    ds.map(
        process,
        num_proc=cores,
        writer_batch_size=writer_batch_size,
        keep_in_memory=False
    )
    # `process` records a path even when the download failed; drop those rows
    # here, while `audio` is still a plain string, so decoding never hits a
    # missing file.
    .filter(lambda example: example["downloaded_status"])
    .cast_column("audio", Audio(sampling_rate=sampling_rate))
)

In [None]:
# Apply `numpy_to_tensor` to every row using `cores` worker processes.
# NOTE(review): this cell has no execution count (`In [None]`), so it was not
# run in the saved session — confirm whether it is still needed.
ds = ds.map(numpy_to_tensor, num_proc=cores, writer_batch_size=writer_batch_size, keep_in_memory=False)

In [6]:
# Show the first row to check the added `audio` and `downloaded_status` fields.
ds[0]

{'ytid': '-0Gj8-vB1q4',
 'start_s': 30,
 'end_s': 40,
 'audioset_positive_labels': '/m/0140xf,/m/02cjck,/m/04rlf',
 'aspect_list': "['low quality', 'sustained strings melody', 'soft female vocal', 'mellow piano melody', 'sad', 'soulful', 'ballad']",
 'caption': 'The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.',
 'author_id': 4,
 'is_balanced_subset': False,
 'is_audioset_eval': True,
 'audio': {'path': 'music_data/-0Gj8-vB1q4.wav',
  'array': array([ 0.00042981,  0.00274746, -0.00614384, ..., -0.01440152,
         -0.01727214, -0.02183699]),
  'sampling_rate': 7700},
 'downloaded_status': True}

In [12]:
import numpy as np
from pathlib import Path
# Scratch check: `parents[0]` is the directory that contains the file.
Path("music_data/-0Gj8-vB1q4.wav").parents[0]

PosixPath('music_data')

In [19]:
# Imports come first so this cell does not rely on `Path` having been
# imported by an earlier cell.
import json
from pathlib import Path

import numpy as np

# Root folder the exported `data.json` index and the `audiocaption/` tree are written under.
muscall_data_dir = Path("./data/datasets")
# Same location as the download folder (`./music_data`) used above.
# NOTE(review): this name is not referenced anywhere else in the visible notebook.
audi_data_dir = Path("./music_data")

def update_data_json(track_data, data_file=None):
    """Append ``track_data`` to the JSON index unless its ``audio_id`` is already listed.

    Args:
        track_data: Mapping describing one track; must contain ``audio_id``.
        data_file: Optional path of the JSON index. Defaults to
            ``muscall_data_dir / "data.json"``.

    The index is a JSON list of track mappings. A missing index file is
    treated as an empty list, so the first export does not need the file to
    exist beforehand. The file is only rewritten when a new entry is added.
    """
    if data_file is not None:
        json_path = Path(data_file)
    else:
        json_path = muscall_data_dir.joinpath("data.json")

    if json_path.exists():
        with open(json_path, "r") as rd:
            data = json.load(rd)
    else:
        data = []

    # Nothing to do when a track with the same id is already recorded.
    if any(track_data["audio_id"] == entry.get("audio_id", "") for entry in data):
        return

    data.append(track_data)
    json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(json_path, "w") as wr:
        json.dump(data, wr, indent=4)

def export_data(sample, audio_id):
    """Save one sample's waveform as ``.npy`` and register it in the JSON index.

    Args:
        sample: Dataset row with ``audio`` (containing ``array`` and ``path``)
            and ``caption``.
        audio_id: Identifier stored for this track in the index.

    The array is written to
    ``muscall_data_dir/audiocaption/audio/<track_id>.npy``, where ``track_id``
    is the source file name without ``.wav``. The recorded ``audio_path`` is
    that exact file path.
    """
    array = sample["audio"]["array"]
    track_id = Path(sample["audio"]["path"]).name.replace(".wav", "")
    # Spell out the `.npy` suffix: `track_id` no longer contains ".wav", so
    # replacing ".wav" with ".npy" did nothing and the recorded path lacked the
    # extension that `np.save` adds to the file it writes.
    audio_path = muscall_data_dir.joinpath("audiocaption", "audio", f"{track_id}.npy")
    audio_path.parent.mkdir(parents=True, exist_ok=True)

    # Save first so the index never points at a file that failed to write.
    np.save(audio_path, array)
    update_data_json({"audio_id": audio_id, "caption": sample["caption"], "audio_path": str(audio_path)})
# ds = ds.map(export_data, num_proc=cores, writer_batch_size=writer_batch_size, keep_in_memory=False)

In [None]:
# Export every row, using its position in `ds` as the `audio_id`.
for idx, dt in enumerate(ds):
    export_data(dt, idx)

In [23]:
import os
# Load the exported index first so the split sizes come from the list that is
# actually split. Counting files in the audio folder can disagree with it
# (stale or hidden files), which skews the split.
with open(muscall_data_dir.joinpath("data.json"), "r") as rd:
    data = json.load(rd)

data_size = len(data)
test_size = int(data_size * 0.1)
validation_size = int(data_size * 0.1)
# Train takes whatever remains after the two 10% hold-out sets.
train_size = data_size - (test_size + validation_size)

# Contiguous, order-preserving split (no shuffling): train | test | validation.
train_data = data[:train_size]
test_data = data[train_size:train_size + test_size]
validation_data = data[train_size + test_size:]


In [24]:
# Sanity check: number of records that ended up in the test split.
len(test_data)

9

In [25]:
# Write each split to its own JSON file under `audiocaption/`, in the same
# order as before: train, test, validation.
split_files = (
    ("dataset_train.json", train_data),
    ("dataset_test.json", test_data),
    ("dataset_val.json", validation_data),
)
for split_filename, split_records in split_files:
    with open(muscall_data_dir.joinpath("audiocaption", split_filename), "w") as wr:
        json.dump(split_records, wr, indent=4)

In [9]:
# Switch `ds` to the "torch" output format, restricted to `caption` and `audio`.
# This changes `ds` in place, so later reads no longer include the other columns.
ds.set_format("torch", columns=["caption", "audio"])


In [9]:
# Show the first row again to confirm the effect of `set_format`.
ds[0]


{'caption': 'The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.',
 'audio': {'path': 'music_data/-0Gj8-vB1q4.wav',
  'array': tensor([ 0.0004,  0.0027, -0.0061,  ..., -0.0144, -0.0173, -0.0218]),
  'sampling_rate': tensor(7700)}}

In [10]:
class MusicapsDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields ``(waveform, caption)`` pairs.

    Wraps any indexable collection whose items look like
    ``{"audio": {"array": ...}, "caption": ...}``.
    """

    def __init__(self, ds):
        """Store the wrapped collection.

        Args:
            ds: Indexable collection supporting ``len()`` and ``ds[idx]``.
        """
        super().__init__()
        self.ds = ds

    def __len__(self):
        """Return the number of items in the wrapped collection."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(audio array, caption)`` for the item at ``idx``."""
        # Read from the wrapped collection, not the notebook-global `ds`, and
        # fetch the row once instead of indexing the dataset twice.
        example = self.ds[idx]
        return (example["audio"]["array"], example["caption"])

In [11]:
# Wrap the Hugging Face dataset so it yields (waveform, caption) tuples.
mod_ds = MusicapsDataset(ds)

In [12]:
# Show the first (waveform, caption) pair.
mod_ds[0]

(tensor([ 0.0004,  0.0027, -0.0062,  ..., -0.0144, -0.0173, -0.0219]),
 'The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.')

In [13]:
def custom_collate(data):
    """Collate ``(waveform, caption)`` pairs into a zero-padded batch.

    Args:
        data: Sequence of items whose first element is a 1-D tensor and whose
            second element is the caption.

    Returns:
        ``(arrays, captions)`` where ``arrays`` has shape
        ``(batch, longest_waveform)`` with shorter waveforms right-padded with
        zeros, and ``captions`` is the list of captions in input order.
    """
    from torch.nn.utils.rnn import pad_sequence

    arrays = [item[0] for item in data]
    captions = [item[1] for item in data]
    # pad_sequence allocates the padding with the dtype and device of the
    # inputs, unlike a separate `torch.zeros` buffer (always CPU float32).
    padded = pad_sequence(arrays, batch_first=True, padding_value=0.0)
    return padded, captions

In [14]:
# Both towers are built with the same width, depth and attention layout.
shared_transformer_kwargs = dict(dim=512, depth=6, heads=8, dim_head=64)

# Audio tower: the shared settings plus the spectrogram options.
audio_transformer = AudioSpectrogramTransformer(
    spec_n_fft=128,
    spec_win_length=24,
    spec_aug_stretch_factor=0.8,
    **shared_transformer_kwargs,
)

# Text tower: the shared settings only.
text_transformer = TextTransformer(**shared_transformer_kwargs)

# Joint model built from the two towers.
mulan = MuLaN(audio_transformer=audio_transformer, text_transformer=text_transformer)


In [None]:
from musiclm_pytorch import MuLaNTrainer
# Build a trainer for `mulan` over `mod_ds` with batches of 4, then start it.
# NOTE(review): this cell has no execution count (`In [None]`), so it was not
# run to completion in the saved session.
mulan_trainer = MuLaNTrainer(
    mulan=mulan,
    dataset=mod_ds,
    batch_size=4
)

mulan_trainer.train()

In [15]:
# Batch `mod_ds` with the padding collate function and print each batch.
# After the loop, `arrays` and `captions` still hold the LAST batch; later
# cells read those two names directly.
train_dataloader = torch.utils.data.DataLoader(dataset=mod_ds, batch_size=16, collate_fn=custom_collate)
for arrays, captions in train_dataloader:
    print(f"Arrays:\n{arrays.shape}")
    print(f"Captions:\n{captions}")

Arrays:
torch.Size([3, 153796])
Captions:
['The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.', 'This song features an electric guitar as the main instrument. The guitar plays a descending run in the beginning then plays an arpeggiated chord followed by a double stop hammer on to a higher note and a descending slide followed by a descending chord run. The percussion plays a simple beat using rim shots. The percussion plays in common time. The bass plays only one note on the first count of each bar. The piano plays backing chords. There are no voices in this song. The mood of this song is relaxing. This song can be played in a coffee shop.', 'a male voice is singing a melody with changing tempos while snipping his fingers rhythmically. The recording sounds like it has been recorded in an empty room. This song may be playi

In [16]:
# Shape of the first waveform after adding a leading batch dimension of 1.
mod_ds[0][0].reshape(1, -1).shape

torch.Size([1, 153796])

In [33]:



# get a ton of <sound, text> pairs and train

# NOTE(review): the random `wavs` below is overwritten on the next line by the
# first real clip, so only the real clip reaches the model.
wavs = torch.randn(2, 1024)
wavs = mod_ds[0][0].reshape(1, -1)
# Random token ids; not used in this cell (a later cell passes `texts` to
# `get_text_latents`).
texts = torch.randint(0, 20000, (2, 256))
raw_texts = ["The low quality recording features a ballad song", "This song features an electric guitar"]

# NOTE(review): `wavs` holds 1 clip while `raw_texts` holds 2 captions —
# confirm that mismatched batch sizes are intended here.
loss = mulan(wavs, raw_texts=raw_texts)
loss.backward()


In [36]:
# Show the loss value from the single forward/backward pass above.
loss

tensor(-44.6719, grad_fn=<SumBackward0>)

In [None]:
# after much training, you can embed sounds and text into a joint embedding space
# for conditioning the audio LM

# Audio embeddings for the current `wavs`.
embeds = mulan.get_audio_latents(wavs)  # during training

# Text embeddings for `texts`; note this rebinds `embeds`, so the audio
# embeddings computed above are discarded.
embeds = mulan.get_text_latents(texts)  # during inference

In [19]:
import torch
from transformers import BertTokenizer, BertModel

# Load the `bert-base-uncased` tokenizer.
# For more details - https://huggingface.co/bert-base-uncased
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Downloading (…)okenizer_config.json: 100%|██████████| 28.0/28.0 [00:00<00:00, 114kB/s]
Downloading (…)solve/main/vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 1.62MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 2.39MB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 570/570 [00:00<00:00, 2.01MB/s]


In [57]:
# `captions` is the last batch left over from the DataLoader loop.
print(captions[0])

The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.


In [37]:
# Word-piece tokens for one caption.
# NOTE(review): the literal starts with a stray apostrophe ("'The low ...") and
# `tokens` is not used anywhere else in the visible notebook.
tokens = tokenizer.tokenize("'The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.")
# Token ids for the leftover `captions` batch as NumPy arrays, padded or
# truncated to exactly 256 tokens per caption.
caption_tokens = tokenizer(captions, return_tensors="np", padding="max_length", max_length=256, truncation=True)["input_ids"]

In [None]:
# Embed the leftover `arrays` batch from the DataLoader loop with the audio tower.
audio_latents = mulan.get_audio_latents(arrays)
audio_latents

In [31]:
# Shape of the audio embeddings computed for the leftover batch.
audio_latents.shape

torch.Size([3, 128])

In [17]:
# Embed a single caption with the text tower; this rebinds `raw_texts`.
raw_texts = ["The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services."]
# raw_texts = ["happy music"]
text_latents = mulan.get_text_latents(raw_texts=raw_texts)
print(text_latents.shape)

torch.Size([1, 128])


In [24]:
# Dot product of every audio embedding with every text embedding
# (rows = audio clips, columns = texts).
logits_audio_text = audio_latents @ text_latents.T
# Softmax over dim 0, i.e. across the audio clips for each text.
logits_audio_text.softmax(0)

tensor([[0.3225],
        [0.3253],
        [0.3522]], grad_fn=<SoftmaxBackward0>)

In [45]:
# List the parameter names stored under the checkpoint's "model" entry.
# NOTE(review): `checkpoint` is assigned in a cell that appears BELOW this one,
# so this cell fails on a top-to-bottom run with a fresh kernel.
checkpoint["model"].keys()

odict_keys(['audio.to_patch_tokens.1.weight', 'audio.to_patch_tokens.1.bias', 'audio.to_patch_tokens.2.weight', 'audio.to_patch_tokens.2.bias', 'audio.to_patch_tokens.3.weight', 'audio.to_patch_tokens.3.bias', 'audio.spec.window', 'audio.aug.0.phase_advance', 'audio.transformer.layers.0.0.q_scale', 'audio.transformer.layers.0.0.k_scale', 'audio.transformer.layers.0.0.norm.learned_gamma', 'audio.transformer.layers.0.0.to_q.weight', 'audio.transformer.layers.0.0.to_kv.weight', 'audio.transformer.layers.0.0.to_out.0.weight', 'audio.transformer.layers.0.1.0.learned_gamma', 'audio.transformer.layers.0.1.1.weight', 'audio.transformer.layers.0.1.4.weight', 'audio.transformer.layers.1.0.q_scale', 'audio.transformer.layers.1.0.k_scale', 'audio.transformer.layers.1.0.norm.learned_gamma', 'audio.transformer.layers.1.0.to_q.weight', 'audio.transformer.layers.1.0.to_kv.weight', 'audio.transformer.layers.1.0.to_out.0.weight', 'audio.transformer.layers.1.1.0.learned_gamma', 'audio.transformer.layers.

In [23]:
# NOTE(review): `model` is built from the same `audio_transformer` and
# `text_transformer` objects as `mulan`, so loading weights here presumably
# changes `mulan` as well — confirm that is intended.
model = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

# Set MULAN_CHECKPOINT to point at a different file; the default is the
# original location.
checkpoint_path = os.environ.get(
    "MULAN_CHECKPOINT",
    "/Users/berkayg/Codes/music-project/muscall/mulan.1.pt",
)
# Use Apple's MPS backend when available, otherwise fall back to the CPU so
# the cell also runs on machines without MPS.
checkpoint_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=checkpoint_device)
model.load_state_dict(checkpoint['model'])

# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

MuLaN(
  (audio): AudioSpectrogramTransformer(
    (to_patch_tokens): Sequential(
      (0): Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1=16, p2=16)
      (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (2): Linear(in_features=256, out_features=512, bias=True)
      (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (spec): Spectrogram()
    (aug): Sequential(
      (0): TimeStretch()
      (1): FrequencyMasking()
      (2): TimeMasking()
    )
    (transformer): Transformer(
      (layers): ModuleList(
        (0-5): 6 x ModuleList(
          (0): Attention(
            (norm): LayerNorm()
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=512, out_features=512, bias=False)
            (to_kv): Linear(in_features=512, out_features=1024, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=False)
              (1): Dropout(p=0.0, inplace=False)

In [27]:
# Look at the `ytid` column of the dataset's underlying table.
ds.data["ytid"]

<pyarrow.lib.ChunkedArray object at 0x2a512bce0>
[
  [
    "-0Gj8-vB1q4"
  ],
  [
    "-0SdAVK79lg"
  ],
  [
    "-0vPFx-wRRI"
  ]
]

In [24]:
# Embed the leftover `arrays` batch with the checkpoint-loaded `model`; this
# rebinds `audio_latents`.
audio_latents = model.get_audio_latents(arrays)
# Number of embedded clips.
audio_latents.shape[0]

: 

In [20]:
# Dot-product scores between each audio embedding and the text embedding,
# shown without softmax.
logits_audio_text = audio_latents @ text_latents.T
logits_audio_text

tensor([[0.1528],
        [0.1335],
        [0.0916]], grad_fn=<MmBackward0>)

In [22]:
import gc
# Collect unreachable Python objects first so the tensors they still hold are
# released, then hand the now-unused cached MPS memory back to the system.
gc.collect()
torch.mps.empty_cache()

0