In [1]:
import os

import torch
import torch.nn as nn
import torchaudio
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

import lovely_tensors as lt
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from audiocraft.modules.conditioners import ClassifierFreeGuidanceDropout

# Print tensors as compact lovely_tensors summaries (shape, dtype, stats, device).
lt.monkey_patch()

# Pretrained MusicGen "small": its LM (`model.lm`) is fine-tuned below and its
# compression model (`model.compression_model`) turns audio into the discrete
# codes used as training targets.
model = MusicGen.get_pretrained('small')

In [2]:
# Keep the LM weights in float32; the forward pass later runs under
# `model.autocast`, which handles the reduced-precision compute.
model.lm = model.lm.float()

In [3]:
# Root directory of the (audio, caption) training segments. Change this one
# constant to point the notebook at another dataset location.
DATA_DIR = '/home/ubuntu/dataset'
NUM_SEGMENTS = 2

# 1-based pair id -> paths of an audio clip and its text caption.
# segment_NNN.wav is paired with segment_NNN.txt (NNN is zero-padded, 0-based).
pairs = {
    idx + 1: {
        'audio': os.path.join(DATA_DIR, f'segment_{idx:03d}.wav'),
        'label': os.path.join(DATA_DIR, f'segment_{idx:03d}.txt'),
    }
    for idx in range(NUM_SEGMENTS)
}

In [4]:
# Optimisation hyper-parameters for fine-tuning the MusicGen LM.
learning_rate = 1e-5
adam_betas = (0.9, 0.95)
weight_decay = 0.1

# Put the LM in training mode before building the optimizer over its parameters.
model.lm.train()
optimizer = AdamW(
    model.lm.parameters(),
    lr=learning_rate,
    betas=adam_betas,
    weight_decay=weight_decay,
)

# Loss between LM logits and the target codes.
criterion = nn.CrossEntropyLoss()

# Preferred compute device (CUDA when available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def count_nans(tensor):
    """Return the number of NaN elements in ``tensor`` as a Python ``int``."""
    return torch.isnan(tensor).sum().item()

In [5]:
def preprocess_audio(audio_path, model: MusicGen, duration: int = 30):
    """Load an audio file and encode it into MusicGen's discrete codes.

    The file is mixed down to mono, resampled to ``model.sample_rate``, cut to
    its first ``duration`` seconds and passed through
    ``model.compression_model.encode``.

    Args:
        audio_path: path to any file ``torchaudio.load`` can read.
        model: pretrained MusicGen whose compression model produces the codes.
        duration: clip length in seconds; the audio must be at least this long.

    Returns:
        The integer code tensor returned by ``compression_model.encode``
        (``[1, 4, 1500]`` for a 30 s clip with this notebook's checkpoint).

    Raises:
        ValueError: if the audio is shorter than ``duration`` seconds.
    """
    wav, sr = torchaudio.load(audio_path)
    # Mix down to mono *before* resampling. Resampling is linear per channel,
    # so the result is the same but only one channel has to be resampled.
    wav = wav.mean(dim=0, keepdim=True)
    wav = torchaudio.functional.resample(wav, sr, model.sample_rate)

    num_samples = int(model.sample_rate * duration)
    if wav.shape[1] < num_samples:
        raise ValueError(
            f"{audio_path} is {wav.shape[1] / model.sample_rate:.2f}s long at "
            f"{model.sample_rate} Hz; at least {duration}s are required"
        )
    wav = wav[:, :num_samples]

    # Send the waveform to whatever device the compression model lives on,
    # instead of assuming CUDA. unsqueeze(1) turns [1, T] into [1, 1, T]
    # (presumably the [batch, channels, time] layout encode() expects).
    device = next(model.compression_model.parameters()).device
    wav = wav.to(device).unsqueeze(1)

    with torch.no_grad():
        codes, scale = model.compression_model.encode(wav)

    # This pipeline only handles encoders that return no rescaling factor.
    assert scale is None, "unexpected non-None scale from compression_model.encode"
    return codes

In [6]:
sample = pairs[1]
text = sample['label']       # caption *path*; the next cell reads the caption itself
audio_path = sample['audio']
print("text:", text)
print("audio:", audio_path)

# Encode the waveform: from here on `audio` holds the discrete codes tensor.
audio = preprocess_audio(audio_path, model)
print("audio shape:", audio.shape)
print("audio:", audio)

text: /home/ubuntu/dataset/segment_000.txt
audio: /home/ubuntu/dataset/segment_000.wav
audio shape: torch.Size([1, 4, 1500])
audio: tensor[1, 4, 1500] i64 n=6000 (47Kb) x∈[0, 2047] μ=1.149e+03 σ=622.125 cuda:0


In [7]:
# Read the caption. `text` holds the caption *path* set in the previous cell and
# is rebound here to the caption string. Because of that, re-run the cell that
# sets `text` before re-running this one.
# `with` closes the file handle; an explicit encoding avoids depending on the
# machine's locale.
with open(text, 'r', encoding='utf-8') as caption_file:
    text = caption_file.read().strip()

# Build MusicGen ConditioningAttributes for the caption.
# NOTE(review): `_prepare_tokens_and_attributes` is a private MusicGen helper and
# may change between audiocraft versions.
attributes, _ = model._prepare_tokens_and_attributes([text], None)
print("attributes:", attributes)

attributes: [ConditioningAttributes(text={'description': 'funky song by moeshop, electropop, hype, fast paced, webcore'}, wav={'self_wav': WavCondition(wav=tensor[1, 1] cuda:0 [[0.]], length=tensor[1] i64 cuda:0 [0], path='null_wav')})]


In [8]:
# Classifier-free-guidance batch: row 0 keeps the real caption, row 1 is a copy
# with every condition dropped (dropout with p=1.0 -> description None).
conditions = attributes
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
print("null_conditions:", null_conditions, '\n')

conditions = [*conditions, *null_conditions]
print("conditions", conditions, '\n')

# Captions -> token ids (tokenize) -> conditioning embeddings (condition_provider).
tokenized = model.lm.condition_provider.tokenize(conditions)
print("tokenized", tokenized, '\n')

# The embeddings carry autograd history (their repr shows `grad`), so they are
# tied to a single backward pass. Presumably they must be recomputed for every
# training step.
condition_tensors = cfg_conditions = model.lm.condition_provider(tokenized)
print("Cfg", cfg_conditions, '\n')

null_conditions: [ConditioningAttributes(text={'description': None}, wav={'self_wav': WavCondition(wav=tensor[1, 1] cuda:0 [[0.]], length=tensor[1] i64 cuda:0 [0], path=['null_wav'])})] 

conditions [ConditioningAttributes(text={'description': 'funky song by moeshop, electropop, hype, fast paced, webcore'}, wav={'self_wav': WavCondition(wav=tensor[1, 1] cuda:0 [[0.]], length=tensor[1] i64 cuda:0 [0], path='null_wav')}), ConditioningAttributes(text={'description': None}, wav={'self_wav': WavCondition(wav=tensor[1, 1] cuda:0 [[0.]], length=tensor[1] i64 cuda:0 [0], path=['null_wav'])})] 

tokenized {'description': {'input_ids': tensor[2, 15] i64 n=30 x∈[0, 18789] μ=1.930e+03 σ=4.392e+03 cuda:0, 'attention_mask': tensor[2, 15] i64 n=30 x∈[0, 1] μ=0.500 σ=0.509 cuda:0}} 

Cfg {'description': (tensor[2, 15, 1024] n=30720 (0.1Mb) x∈[-8.109, 8.683] μ=0.002 σ=0.389 grad MulBackward0 cuda:0, tensor[2, 15] i64 n=30 x∈[0, 1] μ=0.500 σ=0.509 cuda:0)} 



In [9]:
import torch

def one_hot_encode(tensor, num_classes=6):
    """One-hot encode a tensor of integer class indices.

    Args:
        tensor: integer tensor of class indices in ``[0, num_classes)``. Any
            rank is accepted (the notebook passes a 2-D ``[4, 1500]`` slice of
            the audio codes).
        num_classes: length of the appended one-hot axis.

    Returns:
        ``float32`` tensor of shape ``(*tensor.shape, num_classes)`` on the
        same device as ``tensor``, with 1.0 at each index and 0.0 elsewhere.

    Raises:
        TypeError: if ``tensor`` holds floating-point or complex values.
        IndexError: if any index lies outside ``[0, num_classes)``.
    """
    if tensor.is_floating_point() or tensor.is_complex():
        raise TypeError(
            f"one_hot_encode expects integer class indices, got {tensor.dtype}"
        )
    indices = tensor.long()

    # Validate explicitly so out-of-range indices (including negative ones,
    # which a Python-indexing loop would silently wrap) fail with a clear error.
    if indices.numel() > 0:
        low, high = int(indices.min()), int(indices.max())
        if low < 0 or high >= num_classes:
            raise IndexError(
                f"class indices must lie in [0, {num_classes}), "
                f"got range [{low}, {high}]"
            )

    # Vectorised: no per-element .item() calls (each one is a host sync on CUDA).
    return torch.nn.functional.one_hot(indices, num_classes).to(torch.float32)

In [10]:
# Stack the encoded clip twice along the batch axis so the code batch lines up
# with the two-row (caption + null caption) conditioning batch.
codes = torch.cat((audio,) * 2)
print(codes.shape)
print(codes)

torch.Size([2, 4, 1500])
tensor[2, 4, 1500] i64 n=12000 (94Kb) x∈[0, 2047] μ=1.149e+03 σ=622.099 cuda:0


In [11]:
# Same float32 cast as the earlier cell. It is a no-op if that cell already ran
# and is kept only as a guard before the forward pass.
model.lm = model.lm.float()

In [12]:
# LM forward pass on the ground-truth codes (`codes` is duplicated to match the
# two rows of `condition_tensors`). Conditioning is pre-computed, presumably why
# `conditions=[]` is passed. Under `model.autocast` the logits come back as float16.
# NOTE(review): training under float16 autocast without a GradScaler risks
# gradient underflow; consider adding torch.cuda.amp.GradScaler.
with model.autocast:
    lm_output = model.lm.compute_predictions(
        codes=codes, conditions=[], condition_tensors=condition_tensors,
    )

In [15]:
# Keep only batch row 0 (the real-caption row); row 1 is the null-condition copy.
# Not idempotent: re-running this cell would index into `codes` a second time.
codes = codes[0]
logits, mask = lm_output.logits[0], lm_output.mask[0]

In [21]:
# One-hot targets sized to the LM vocabulary (the last logits axis) instead of a
# hard-coded 2048, so they always line up with `logits` in the loss.
# NOTE(review): nn.CrossEntropyLoss also accepts integer class indices directly,
# which would make this one-hot step (and its memory) unnecessary.
codes = one_hot_encode(codes, logits.shape[-1])

In [27]:
# Sanity check: targets, logits and mask shapes.
for shown in (codes, logits, mask):
    print(shown.shape)

torch.Size([4, 1500, 2048])
torch.Size([4, 1500, 2048])
torch.Size([6000])


In [26]:
# Flatten the mask to 1-D so it can select rows of the flattened logits/targets.
# Safe to re-run: viewing an already-flat mask leaves it unchanged.
mask = mask.view(mask.numel())

In [29]:
# Flatten logits to [positions, vocab] and keep the positions selected by `mask`.
# The vocab size is taken from the logits rather than a hard-coded 2048.
masked_logits = logits.view(-1, logits.shape[-1])[mask]

In [30]:
# lovely_tensors (enabled by lt.monkey_patch) `.v`: a one-line summary (shape,
# dtype, range, mean/std, device) followed by the regular tensor repr.
masked_logits.v

tensor[5994, 2048] f16 n=12275712 (23Mb) x∈[-26.906, 25.359] μ=-3.502 σ=4.516 grad IndexBackward0 cuda:0
tensor([[-4.9976e-01, -2.7988e+00, -4.4922e+00,  ..., -1.4417e-01,
         -3.6641e+00, -4.1719e+00],
        [ 1.6807e+00, -4.3477e+00, -4.0234e+00,  ...,  1.0430e+00,
         -1.9404e+00, -3.5020e+00],
        [ 8.9600e-02, -9.5312e+00, -4.5508e+00,  ...,  2.0137e+00,
         -3.5762e+00, -5.8594e+00],
        ...,
        [-1.4229e+00,  3.6836e+00, -7.8760e-01,  ...,  2.8149e-01,
         -3.6797e+00,  2.0488e+00],
        [-1.6533e+00,  6.6797e-01, -6.7529e-01,  ..., -4.2227e+00,
         -5.5389e-03, -7.9834e-01],
        [-3.8535e+00,  7.5703e+00, -9.3115e-01,  ...,  2.3789e+00,
          3.7480e+00,  5.0898e+00]], device='cuda:0', dtype=torch.float16,
       grad_fn=<IndexBackward0>)

In [32]:
# Put the targets and the mask on the same device (CPU) so the boolean mask can
# index `codes` in the next cell.
codes, mask = codes.cpu(), mask.cpu()

In [37]:
# Flatten the one-hot targets to [positions, vocab] and keep the same positions as
# `masked_logits`. The vocab size comes from the tensor, not a hard-coded 2048.
masked_codes = codes.view(-1, codes.shape[-1])[mask]

In [None]:
# (Removed) `logits.chans()`: a never-executed lovely_tensors visualisation left
# over from debugging. `logits` is a [4, 1500, 2048] logit tensor, which a
# channel-image view is presumably not meaningful for.

In [42]:
# Move the targets to wherever the logits already live (where the model ran),
# instead of hard-coding .cuda(), which fails on CPU-only machines.
masked_codes = masked_codes.to(masked_logits.device)

In [43]:
# Compute cross-entropy in float32. The logits are float16 (produced under
# autocast), and log-softmax over the whole vocabulary is precision-sensitive.
# `.float()` is differentiable, so gradients still flow back to the model.
loss = criterion(masked_logits.float(), masked_codes)

In [45]:
# Verbose lovely_tensors view of the scalar loss (summary line + plain repr).
loss.v

tensor grad DivBackward1 cuda:0 2.968
tensor(2.9675, device='cuda:0', grad_fn=<DivBackward1>)

In [49]:
# Back-propagate the loss into the LM parameters. This includes the conditioning
# graph built when `condition_tensors` was computed, whose repr shows `grad`.
# NOTE(review): that conditioning graph is presumably freed by this call, so a
# second step reusing the same `condition_tensors` would need them recomputed
# (or retain_graph=True). Confirm before turning these cells into a loop.
# NOTE(review): the forward ran under float16 autocast with no GradScaler, so
# small gradients may underflow.
loss.backward()

In [50]:
# Clip the global gradient norm of the LM parameters. The call returns the total
# gradient norm, which the cell displays.
# NOTE(review): the observed norm (~21) is far above 0.5, so every step is scaled
# down heavily. Confirm this limit is intended.
max_grad_norm = 0.5
torch.nn.utils.clip_grad_norm_(model.lm.parameters(), max_norm=max_grad_norm)

tensor cuda:0 21.087

In [51]:
# Apply one AdamW update using the gradients currently stored on the LM
# parameters (clipped by the previous cell).
optimizer.step()

In [52]:
# Reset the parameter gradients so the next backward() does not accumulate onto
# this step's gradients.
optimizer.zero_grad()