In [None]:
from transformers import AutoModel, AutoProcessor
import torch
import torchaudio
import numpy as np


# load audio
wav, sr = torchaudio.load("dataset/test_wavs/bronya.wav")
# resample to 16 kHz, the sampling rate passed to the processor below
wav = torchaudio.functional.resample(wav, sr, 16000)

# load model and processor
# NOTE: trust_remote_code=True runs Python code shipped with the Hub repository.
processor = AutoProcessor.from_pretrained("waveletdeboshir/gigaam-ctc", trust_remote_code=True)
model = AutoModel.from_pretrained("waveletdeboshir/gigaam-ctc", trust_remote_code=True)
model.eval()

# Only the first channel (wav[0]) is given to the processor.
input_features = processor(wav[0], sampling_rate=16000, return_tensors="pt", padding=True)
# Use the encoder sub-module directly instead of the full model.
encoder = model.model.encoder

# predict
# NOTE(review): this is the first element of the encoder's return value, so despite
# its name `logits` presumably holds encoder features rather than CTC logits
# (a later cell shows shape [1, 278, 768] after transposing) — confirm.
with torch.no_grad():
    logits = encoder(input_features['input_features'], length=input_features['input_lengths'])[0]
# greedy decoding
# NOTE(review): argmax is taken over the last axis of the encoder output; if that is
# not a vocabulary axis the decoded text below is not meaningful — TODO confirm.
greedy_ids = logits.argmax(dim=-1)
# decode token ids to text
transcription = processor.batch_decode(greedy_ids)[0]


You are using a model of type gigaam-ctc to instantiate a model of type . This is not supported for all configurations of models and can yield errors.


In [8]:
# Swap the last two axes of the 3-D encoder output; the shape shown afterwards
# is [1, 278, 768]. This cell rebinds `logits` to its own transpose, so running
# it a second time swaps the axes back — execute it once per fresh `logits`.
logits = logits.permute(0, 2, 1)
logits.size()

torch.Size([1, 278, 768])

In [12]:
# L2 norm of the first entry of `logits` after dropping a size-1 leading axis
# (with the [1, 278, 768] shape shown above this is one 768-dim vector).
np.linalg.norm(logits.squeeze(dim=0)[0].numpy())

25.936773

In [16]:
from transformers import AutoFeatureExtractor, WhisperModel

# Pick the best available backend instead of hard-coding Apple's MPS, so the
# notebook also runs on CUDA or CPU-only machines. MPS stays the first choice,
# which keeps the behaviour unchanged on Apple silicon.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

whisper_name = 'openai/whisper-small'
# Half precision on accelerators (as before); full precision on CPU, where
# float16 kernels are slow or unavailable.
whisper_dtype = torch.float32 if device == 'cpu' else torch.float16
whisper_model = WhisperModel.from_pretrained(
    whisper_name, torch_dtype=whisper_dtype
).to(device)
# Only the encoder is used below, so drop the decoder to free memory.
del whisper_model.decoder
whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)

def semantic_fn(waves_16k):
    """Extract Whisper encoder features for a 16 kHz waveform.

    Uses the module-level ``whisper_feature_extractor``, ``whisper_model`` and
    ``device``. Prints the shape of the returned tensor and the L2 norm of its
    first frame as a quick sanity check.

    Args:
        waves_16k: Waveform tensor sampled at 16 kHz. Its leading axis is
            squeezed before feature extraction (assumes that axis has size 1,
            i.e. mono audio — TODO confirm for multi-channel input).

    Returns:
        The encoder's last hidden state cast to float32 and truncated along
        axis 1 to ``waves_16k.size(-1) // 320 + 1`` frames.
    """
    ori_inputs = whisper_feature_extractor(
        [waves_16k.squeeze(0).cpu().numpy()],
        # Pass the rate explicitly: omitting it triggers the feature extractor's
        # "strongly recommended to pass the `sampling_rate` argument" warning
        # and lets a mismatched rate go unnoticed.
        sampling_rate=16000,
        return_tensors="pt",
        return_attention_mask=True,
    )
    # NOTE(review): private helper of the Whisper model; presumably a no-op
    # unless SpecAugment masking is enabled in the config — confirm.
    ori_input_features = whisper_model._mask_input_features(
        ori_inputs.input_features, attention_mask=ori_inputs.attention_mask
    ).to(device)
    with torch.no_grad():
        ori_outputs = whisper_model.encoder(
            # Match the encoder's parameter dtype.
            ori_input_features.to(whisper_model.encoder.dtype),
            head_mask=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
    S_ori = ori_outputs.last_hidden_state.to(torch.float32)
    # Keep one frame per 320 input samples (plus one) and drop the rest of the
    # encoder output.
    S_ori = S_ori[:, : waves_16k.size(-1) // 320 + 1]
    print(S_ori.shape)
    print(np.linalg.norm(S_ori.squeeze(0)[0].cpu().numpy()))
    return S_ori

# Whisper encoder features for the waveform resampled to 16 kHz in the first cell.
s_wisper = semantic_fn(wav)

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 556, 768])
37.410275


In [62]:
# Load model directly
from transformers import AutoProcessor, AutoModel, Wav2Vec2FeatureExtractor

model_hubert = AutoModel.from_pretrained("facebook/hubert-large-ll60k")
# NOTE: this rebinds `processor`, which previously held the GigaAM processor.
# `trust_remote_code` is not needed here: Wav2Vec2FeatureExtractor is a built-in
# Transformers class, so no repository code has to be executed.
processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-large-ll60k")

In [63]:
# Run the HuBERT feature extractor on the first channel of the waveform,
# requesting PyTorch tensors and an attention mask alongside the input values.
alt_inputs = processor(
    wav[0],
    sampling_rate=16000,
    padding=True,
    return_attention_mask=True,
    return_tensors='pt',
)

In [64]:
# Frames per clip: number of unmasked samples (attention-mask sum) divided by
# 320 samples per frame, i.e. 16 kHz input at a 50 Hz frame rate.
feature_lens = alt_inputs.attention_mask.sum(-1) // 320  # frame rate of hubert is 50 Hz

In [65]:
# Inference only: disable autograd so no graph is built or kept alive, matching
# how the GigaAM and Whisper encoders are run elsewhere in this notebook.
with torch.no_grad():
    output_hubert = model_hubert(alt_inputs.input_values, attention_mask=alt_inputs.attention_mask)

In [66]:
# Trim the HuBERT output along axis 1 to the longest estimated frame count.
last_hidden_states = output_hubert.last_hidden_state[:, : feature_lens.max()]
# The model may return fewer frames than the estimate, so cap the lengths at
# what is actually available.
feature_lens = torch.clamp(feature_lens, max=last_hidden_states.shape[1])
# Swap the last two axes (displayed below as [1, 1024, 555]).
last_hidden_states = last_hidden_states.permute(0, 2, 1)

In [80]:
# Display the per-clip frame counts after clamping to the encoder output length.
feature_lens

tensor([555])

In [79]:
# Display the shape of the trimmed, transposed HuBERT features.
last_hidden_states.size()

torch.Size([1, 1024, 555])