In [4]:
import os
import numpy as np
import torch
import pickle
import time
import sys

current_directory = os.getcwd()
models_dir = os.path.join(current_directory, '..')
print(models_dir)
sys.path.append(models_dir)

import torch
from torch.utils.data import DataLoader, Dataset
from models import Pose2AudioTransformer
from transformers import EncodecModel
from utils import DanceToMusic
from datetime import datetime
from torch.optim import Adam

/home/azeez/Documents/projects/DanceToMusicApp/ml/notebooks/..


  from .autonotebook import tqdm as notebook_tqdm


In [39]:
# assign GPU or CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# device = torch.device("cpu")

model_id = "facebook/encodec_24khz"
encodec_model = EncodecModel.from_pretrained(model_id)
encodec_model.to(device)
codebook_size = encodec_model.quantizer.codebook_size
sample_rate = 24000

data_dir = "/Users/azeez/Documents/pose_estimation/DanceToMusic/data/samples/5sec_min_data"
data_dir = "/home/azeez/Documents/projects/DanceToMusicApp/ml/data/samples/3sec_24fps_expando_dnb_min_training_data"
dataset = DanceToMusic(data_dir, encoder = encodec_model, sample_rate = sample_rate, device=device)
print("Dataset size: ", len(dataset))



Dataset size:  2095


In [40]:
src_pad_idx = 0
trg_pad_idx = 0
# learned_weights = '/Users/azeez/Documents/pose_estimation/DanceToMusic/weights/5_sec_best_model_weights_loss_6.733452348148122.pth' 
learned_weights = '/home/azeez/Documents/projects/model_saves/run_20240119-171022/gen_3_sec_dnb__best_model_0.0199.pt'

# device = torch.device("mps")
embed_size = dataset.data['poses'].shape[2] * dataset.data['poses'].shape[3]
pose_model = Pose2AudioTransformer(codebook_size, src_pad_idx, trg_pad_idx, device=device, num_layers=4, heads = 2, embed_size=embed_size, dropout = 0.1)
pose_model.load_state_dict(torch.load(learned_weights, map_location=device))
pose_model.to(device)

Pose2AudioTransformer(
  (encoder): Encoder(
    (position_embedding): Embedding(2000, 96)
    (layers): ModuleList(
      (0-3): 4 x TransformerBlock(
        (attention): SelfAttention(
          (values): Linear(in_features=48, out_features=48, bias=False)
          (keys): Linear(in_features=48, out_features=48, bias=False)
          (queries): Linear(in_features=48, out_features=48, bias=False)
          (fc_out): Linear(in_features=96, out_features=96, bias=True)
        )
        (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (feed_forward): Sequential(
          (0): Linear(in_features=96, out_features=384, bias=True)
          (1): ReLU()
          (2): Linear(in_features=384, out_features=96, bias=True)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (decoder): Decoder(
    (codebook_embedding): Embedding(1024

In [41]:
from IPython.display import Audio, display


def audioCodeToWav(audio_code, encodec_model, sample_rate = 24000):
    batch_size = audio_code.shape[0]
    print(audio_code.shape)
    audio_code = audio_code.reshape(batch_size,1,2,int(audio_code.shape[2]))

    # Check if the devices of audio_code and encodec_model.decoder are the same
    if encodec_model.device != audio_code.device:
        raise ValueError("The device of encodec_model.decoder and audio_code must be the same.")
    device = audio_code.device

    decoded_wavs = []
    # Iterate through each frame in audio_code and decode them individually
    for i in range(audio_code.size(0)):
        single_audio_code = audio_code[i:i+1]  # Extracting a single frame
        single_wav = encodec_model.decode(single_audio_code.int(), [None])
        decoded_wavs.append(single_wav[0])

    # Concatenate the decoded audio samples into a single tensor
    # Use torch.stack if each wav sample has more than one dimension, else use torch.cat
    print(len(decoded_wavs))
    print(decoded_wavs[0].shape)
    if decoded_wavs[0].ndim > 1:
        combined_wav = torch.stack(decoded_wavs, dim=0)
    else:
        combined_wav = torch.cat(decoded_wavs, dim=0)

    print(f"Combined_wa {combined_wav.shape}")
    print(combined_wav.squeeze(1).shape)
    return combined_wav.squeeze(1).to(device)

In [42]:
audio_codes, pose, pose_mask, wav, wav_mask, _, _ = dataset[0]
output = pose_model.generate(pose.unsqueeze(0).to(device), pose_mask.to(device), max_length = audio_codes.shape[0]+1, temperature = 1)
print(output[0][:20])
print(output.shape)
wav = audioCodeToWav(output.unsqueeze(0), encodec_model, sample_rate = 24000)
print(wav.shape)
display(Audio(wav[0].to('cpu').detach().numpy(), rate=24000))

tensor([[ 835,  688],
        [ 259,  474],
        [ 396,  583],
        [ 511,   36],
        [ 797,  281],
        [ 633,  551],
        [ 471,  145],
        [ 924,  904],
        [ 660,  738],
        [ 174,  237],
        [ 637, 1018],
        [ 702,  722],
        [ 334,   34],
        [ 463,  661],
        [ 104,  337],
        [ 674,  477],
        [ 734,  334],
        [ 606,  683],
        [ 704,  329],
        [ 309,  815]], device='cuda:0')
torch.Size([1, 227, 2])
torch.Size([1, 1, 227, 2])
1
torch.Size([1, 1, 72640])
Combined_wa torch.Size([1, 1, 1, 72640])
torch.Size([1, 1, 72640])
torch.Size([1, 1, 72640])


In [43]:
for i in range(5):
    audio_codes, pose, pose_mask, wav, wav_mask, _, _ = dataset[i]
    sample = pose_model.generate(pose.unsqueeze(0).to(device), pose_mask.to(device), max_length = 100)
    print(sample[0,:10])

tensor([[377, 403],
        [895,  32],
        [627, 538],
        [861, 643],
        [967, 132],
        [335, 895],
        [309, 239],
        [712, 110],
        [978, 921],
        [327, 519]], device='cuda:0')
tensor([[627, 132],
        [425, 968],
        [422, 483],
        [607, 630],
        [ 87, 811],
        [776, 176],
        [ 52, 714],
        [173, 580],
        [896, 724],
        [368, 764]], device='cuda:0')


tensor([[1002,  776],
        [ 688,  918],
        [ 143,  982],
        [ 520,  158],
        [ 636,  948],
        [ 993,  523],
        [ 714,   27],
        [ 906,  330],
        [ 967,   83],
        [ 637,  759]], device='cuda:0')
tensor([[ 808,  348],
        [ 996, 1009],
        [ 475,  526],
        [ 290,   36],
        [ 811,  529],
        [ 896,  893],
        [ 800,  232],
        [ 260,  115],
        [ 350,  123],
        [ 136,  256]], device='cuda:0')
tensor([[707, 636],
        [442, 229],
        [403, 494],
        [391, 970],
        [729, 415],
        [291, 272],
        [ 27, 973],
        [419, 407],
        [  1, 837],
        [789,  33]], device='cuda:0')


In [45]:
print(output.shape)
wav = audioCodeToWav(output.unsqueeze(0), encodec_model, sample_rate = 24000)
print(wav.shape)
display(Audio(wav[0][0].to('cpu').detach().numpy(), rate=24000))

torch.Size([1, 227, 2])
torch.Size([1, 1, 227, 2])
1
torch.Size([1, 1, 72640])
Combined_wa torch.Size([1, 1, 1, 72640])
torch.Size([1, 1, 72640])
torch.Size([1, 1, 72640])


: 

In [None]:
output[0].shape, audio_codes[0].shape

(torch.Size([60]), torch.Size([2, 759]))

In [None]:
output[0], audio_codes[0][0]

(tensor([401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401,
         401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401,
         401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401,
         401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401, 401,
         401, 401, 401, 401], device='mps:0'),
 tensor([ 121,  395,  537,  537,  662,  401,   34,  568,  844,  572,  231,  758,
          715,  637,  790,  568,  446,  657, 1021,  657,  419,  713,  322,  568,
          568,  568,  924,  560,  713,  384,  445,  754,  509,  362,  568,  434,
          797,  352,  246,  189,  568,  713,  659,  568,  568,  568,  568,  659,
          560,  169,  560,  701,  788,  659,  817,  437,  560,  531,  560,  782,
          568,  568,  568,  560,  543,  654,  631,  152,  152,  715,  388,  388,
          388,  366,  844,  568,  388,  388,  388,  388,  213,  213,  213,  560,
          388,  388,  659,  790,  830,  713, 1021,  790,  322,  560,  

In [None]:
print(output.shape)
print(output[0].reshape(1,1,2,int(output.size(1)/2)).shape)
wav = audioCodeToWav(output.unsqueeze(0), encodec_model, sample_rate = 24000, device=device)['audio_values']
display(Audio(wav[0].detach().numpy(), rate=24000))

torch.Size([1, 754])
torch.Size([1, 1, 2, 377])


IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)