In [37]:
from torch.utils.data import DataLoader, Dataset
import librosa
import os
import numpy as np
import pickle

# Assuming you have a dataset of spectrograms
class SpectrogramDataset(Dataset):
    def __init__(self, directory):
        self.sample_rate = 24000
        self.spectrograms_keys = {}
        self.spectrograms = self._load_spectrograms(directory)
        self.poses = self._load_poses(directory)
        self.signals = self._load_signals(directory)
        print(len(self.spectrograms), len(self.poses), len(self.signals))

    def __len__(self):
        return len(self.spectrograms)

    def __getitem__(self, idx):
        dic = {
            'spectrogram': self.spectrograms[idx],
            'pose': self.poses[idx],
            'signal': self.signals[idx]
        }
        return dic
    
    def _load_spectrograms(self, directory):
        # Load your spectrograms here
        spectrograms = []
        for root, dirs, files in os.walk(directory):
            for f in files:
                if '.npy' in f:
                    spec = np.load(os.path.join(root, f))
                    spec = spec[np.newaxis,:,:]
                    spectrograms.append(spec)
                    self.spectrograms_keys[len(spectrograms)-1] = f.split('.')[0]
        return spectrograms

    def _load_poses(self, directory):
        # Load your poses here
        poses = []
        directory = os.path.join(directory, 'poses')
        for root, dirs, files in os.walk(directory):
            for f in files:
                pkl_path = os.path.join(root, f)
                with open(pkl_path, 'rb') as file:
                    pose = pickle.load(file)
                poses.append(pose)
        return poses

    def _load_signals(self, directory):
        signals = []
        directory = os.path.join(directory, 'wavs')
        for root, dirs, files in os.walk(directory):
            for f in files:
                sig_path = os.path.join(root, f)
                s = librosa.load(sig_path, sr=self.sample_rate, mono=True)[0]
                signals.append(s)
        return signals

directory = "/Users/azeez/Documents/pose_estimation/Learning2Dance/processed_data"
dataset = SpectrogramDataset(directory)

print(dataset[0].keys())
print(dataset[0]['spectrogram'].shape)
print(dataset[0]['pose'])
print(dataset[0]['signal'].shape)

51 51 51
dict_keys(['spectrogram', 'pose', 'signal'])
(1, 256, 469)
[array([[0.45631936, 0.06601439],
       [0.45137922, 0.18240485],
       [0.35465093, 0.18734499],
       [0.34004438, 0.36676389],
       [0.34023685, 0.52667887],
       [0.534057  , 0.16803355],
       [0.60691866, 0.25995216],
       [0.68901902, 0.31343394],
       [0.44682403, 0.4637702 ],
       [0.39346629, 0.46854994],
       [0.49028012, 0.65780206],
       [0.53416393, 0.90981316],
       [0.49060091, 0.45912946],
       [0.42744843, 0.69160885],
       [0.33551058, 0.90498637],
       [0.43217471, 0.03727178],
       [0.46617825, 0.03710497],
       [0.38382126, 0.06121754],
       [0.49040844, 0.04692322],
       [0.41786756, 0.95802119],
       [0.40805145, 0.93884234],
       [0.31098098, 0.92414383],
       [0.59220518, 0.9580939 ],
       [0.56786806, 0.96289503],
       [0.5339073 , 0.92445179]]), array([[0.48026226, 0.05367237],
       [0.48017521, 0.17670358],
       [0.39628436, 0.18195683],
     

In [69]:
print(f"Number of Samples: {len(dataset)}")
for i in range(len(dataset)):
    print(dataset[i]['spectrogram'].shape, len(dataset[i]['pose']), dataset[i]['signal'].shape)

Number of Samples: 51
(1, 256, 469) 53 (130613,)
(1, 256, 469) 94 (130613,)
(1, 256, 469) 107 (130613,)
(1, 256, 469) 52 (130613,)
(1, 256, 469) 73 (130613,)
(1, 256, 469) 124 (130613,)
(1, 256, 469) 194 (130613,)
(1, 256, 469) 197 (130613,)
(1, 256, 469) 101 (130613,)
(1, 256, 469) 101 (130613,)
(1, 256, 469) 185 (130613,)
(1, 256, 469) 82 (130613,)
(1, 256, 469) 108 (130613,)
(1, 256, 469) 240 (130613,)
(1, 256, 469) 76 (130613,)
(1, 256, 469) 79 (130613,)
(1, 256, 469) 90 (130613,)
(1, 256, 469) 91 (130613,)
(1, 256, 469) 166 (130613,)
(1, 256, 469) 153 (130613,)
(1, 256, 469) 70 (130613,)
(1, 256, 469) 136 (130613,)
(1, 256, 469) 109 (130613,)
(1, 256, 469) 162 (130613,)
(1, 256, 469) 224 (130613,)
(1, 256, 469) 243 (130613,)
(1, 256, 469) 145 (130613,)
(1, 256, 469) 124 (130613,)
(1, 256, 469) 97 (130613,)
(1, 256, 469) 122 (130613,)
(1, 256, 469) 94 (130613,)
(1, 256, 469) 98 (130613,)
(1, 256, 469) 107 (130613,)
(1, 256, 469) 79 (130613,)
(1, 256, 469) 142 (130613,)
(1, 256, 469

In [70]:
from datasets import load_dataset
from transformers import AutoProcessor, EncodecModel, EncodecFeatureExtractor

# dataset = load_dataset("ashraq/esc50")
# audio_sample = dataset["train"]["audio"][0]["array"]
sampling_rate = 24000

audio_sample = dataset[8]['signal']
model_id = "facebook/encodec_24khz"
model = EncodecModel.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

inputs = processor(raw_audio=audio_sample, return_tensors="pt", sampling_rate=sampling_rate)

outputs = model(**inputs)
audio_codes = outputs.audio_codes # the latent space encoding
audio_values = outputs.audio_values # the reconstructed audio

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [71]:
audio_codes.shape, audio_values.shape

(torch.Size([1, 1, 2, 409]), torch.Size([1, 1, 130613]))

In [72]:
print(audio_codes.shape)
print(audio_codes)

torch.Size([1, 1, 2, 409])
tensor([[[[  62,  408,  835,  835,  835,  659,  661,  307,  817,  372,  661,
            900,  999,   25,  999,  887,  325,  819,  951,   25,  627,  393,
            999,   25,  276,  276,  276,  887,  228,  887,  935,  731,  887,
            370,  690,  228,  339,  370, 1017,  876,  475,  395,  463,  647,
            401,  388,  951,  559,  602,  602,  276,  276,  537,  779,  325,
            339,  731,   63,  341,  499,  404,  475,  404,  834,  257,  395,
            887,  951,  731,  461,  347,  347,  855,  835,  835,  855,  855,
            430, 1017,  463,  662,  401,  511,  341,  602,  602,  602,   23,
            276,  276,  537,  325,  537,  627,  951,  677,  876,  430,  339,
            257,  395,  887,  724,  457,  373,  401,  388,  341,  602,  559,
            677,   23,  935,  537,  537,  208,  887,  543,  814,  756,  629,
            756,  274,  300,  228,  325,  783,  900,  380,  774,  778,  300,
            690,  157,  274,  157,  325,  629,  7

In [73]:
from IPython.display import Audio, display

print("Original audio:")
display(Audio(audio_sample, rate=sampling_rate))

print("Reconstructed audio:")
display(Audio(audio_values[0].detach().numpy(), rate=sampling_rate))

Original audio:


Reconstructed audio:
