In [1]:
import torch
import torch.nn as nn
from models.audio_encoder import AudioEncoder
from models.text_encoder import TextEncoder
import torch.nn.functional as F
import copy
from tools.losses import AudioTextContrastiveLoss, NTXent
from tools.utils import remove_grad
import ruamel.yaml as yaml
import librosa
import random
import numpy as np


INFO:numexpr.utils:Note: NumExpr detected 56 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.


In [2]:
config_file_path = './settings/vamp.yaml'

with open(config_file_path, "r") as f:
        config = yaml.safe_load(f)

In [3]:
# Load audio signal file
wav_file_path = '../../dac/audio_samples/at2_cvt.wav'
waveform, _ = librosa.load(wav_file_path, sr=config["audio_args"]["sr"], mono=True)
print('waveform shape before crop: ', waveform.shape)
if config["audio_args"]["max_length"] != 0:
            # if audio length is longer than max_length, we random crop it
            max_length = config["audio_args"]["max_length"] * config["audio_args"]["sr"]
            if waveform.shape[-1] > max_length:
                max_start = waveform.shape[-1] - max_length
                start = random.randint(0, max_start)
                waveform = waveform[start: start + max_length]
                
print('waveform shape: ', waveform.shape)
waveform_tensor = torch.tensor(waveform[None, :])
print('waveform_tensor shape: ', waveform_tensor.shape)


waveform shape before crop:  (661500,)
waveform shape:  (441000,)
waveform_tensor shape:  torch.Size([1, 441000])


In [4]:
batch_size = 5
batch_waveform_tensor = waveform_tensor.repeat(batch_size, 1)
print(batch_waveform_tensor.shape)

torch.Size([5, 441000])


In [1]:
from encodec import EncodecModel
from encodec.utils import convert_audio

import torchaudio
import torch

# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
print("target_bandwidths have:", model.target_bandwidths) # [1.5, 3.0, 6, 12.0, 24.0] 
model.set_target_bandwidth(6.0) # means 8 quantizers, 24.0 == 32 n_q

# Load and pre-process the audio waveform
wav, sr = torchaudio.load("../../dac/audio_samples/at2_cvt.wav")
wav = wav.unsqueeze(0)
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
print("wav shape:", wav.shape)

# Extract discrete codes from EnCodec
encoded_frames = model.encode(wav)
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # [B, n_q, T]
print("codes shape:", codes.shape)

target_bandwidths have: [1.5, 3.0, 6, 12.0, 24.0]
wav shape: torch.Size([1, 1, 360000])
codes shape: torch.Size([1, 8, 1125])


In [4]:
model.quantizer.vq.layers[0]._codebook.embed.shape

torch.Size([1024, 128])

In [3]:
model.quantizer.vq.layers[0]._codebook.embed # to show the 5+1=6th codebook of RVQ. shape is (1024,128) corresponding to (nbr_entries, dimensionality)


tensor([[ 5.3395, 13.1336, -3.3514,  ...,  2.2543, -4.5506,  3.7425],
        [ 2.5562, 13.8098, -5.7393,  ...,  0.4362, -2.5406,  1.5548],
        [ 3.9551, 12.0306, -6.5480,  ...,  1.6861, -5.3334,  1.3966],
        ...,
        [ 2.3868, 11.8062, -3.8374,  ..., -0.3132, -3.2393,  1.8929],
        [ 1.1349, 11.0860, -2.8491,  ..., -0.6624, -1.4591,  1.9885],
        [ 3.7719, 12.2859, -3.8640,  ...,  1.1728, -3.3949,  3.3238]])

In [12]:
encoded_frames[0][0].shape

torch.Size([1, 8, 1125])

In [13]:
encoded_frames[0][0][0][0]

tensor([738, 244, 843,  ..., 106, 106, 121])

In [10]:

from datasets import load_dataset, Audio
from transformers import EncodecModel, AutoProcessor, EncodecFeatureExtractor
import torchaudio
from encodec.utils import convert_audio

# wav, sr = torchaudio.load("../../dac/audio_samples/at2_cvt.wav")
# wav = wav.unsqueeze(0)
# wav = convert_audio(wav, sr, 24000, 1)
# print("wav shape:", wav.shape)

librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

model = EncodecModel.from_pretrained("facebook/encodec_24khz")
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[-1]["audio"]["array"]

# audio_sample = wav
print("audio_sample shape:", audio_sample.shape)
inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
print("inputs:", inputs)
encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
print("encoder_outputs.audio_codes shape:", encoder_outputs.audio_codes.shape)
# print( encoder_outputs.audio_codes)
print("encoder_outputs.audio_scales length:", len(encoder_outputs.audio_scales))
print( encoder_outputs.audio_scales)
# audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]
# print("audio_values shape:", audio_values.shape)
# print(audio_values)
# or the equivalent with a forward pass
# audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values
# print("audio_values shape:", audio_values.shape)
# print(audio_values)


audio_sample shape: (107520,)
inputs: {'padding_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], dtype=torch.int32), 'input_values': tensor([[[ 0.0020,  0.0008, -0.0011,  ..., -0.0006, -0.0008, -0.0005]]])}
encoder_outputs.audio_codes shape: torch.Size([1, 1, 2, 336])
encoder_outputs.audio_scales length: 1
[None]


In [5]:
audio_encoder = AudioEncoder(config)
# settings for projection layers
embed_size = config["embed_size"]
audio_width = audio_encoder.audio_width

loading model hugggof/vampnet-models:vampnet-9codebook-linear-sched-best from the huggingface hub.


In [6]:
audio_feats = audio_encoder(batch_waveform_tensor)
# audio_embeds = F.normalize(self.audio_proj(audio_feats), dim=-1)
print('audio_encoded.shape: ', audio_feats.shape)
audio_feats

audio_encoded.shape:  torch.Size([1, 1280, 861])


tensor([[[ 0.1130,  0.1130,  0.1130,  ...,  0.1130,  0.1130,  0.1130],
         [-0.1961, -0.1961, -0.1961,  ..., -0.1961, -0.1961, -0.1961],
         [-0.1603, -0.1603, -0.1603,  ..., -0.1603, -0.1603, -0.1603],
         ...,
         [ 0.1852,  0.1852,  0.1852,  ...,  0.1852,  0.1852,  0.1852],
         [ 0.1420,  0.1420,  0.1420,  ...,  0.1420,  0.1420,  0.1420],
         [ 0.0394,  0.0394,  0.0394,  ...,  0.0394,  0.0394,  0.0394]]],
       grad_fn=<CatBackward0>)