# Testing our model

## Step 1: Pull our model from Hugging Face

In [12]:
# Load model directly
import torch
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor

# The processor (feature extractor + tokenizer) comes from the base English
# checkpoint, while the weights come from the fine-tuned checkpoint.
processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")

# Use Apple's Metal (MPS) backend when this machine has it; otherwise fall back
# to CPU so the notebook still runs instead of failing on `.to("mps")`.
device = "mps" if torch.backends.mps.is_available() else "cpu"
model = AutoModelForSpeechSeq2Seq.from_pretrained("f-azm17/whisper-small-singapore-aphasia").to(device)

In [13]:
# Put the model into evaluation mode before running inference. `eval()` returns
# the module itself, so as the cell's last expression it also echoes the
# architecture shown in the output below.
model.eval()

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 768)
      (layers): ModuleList(
        (0-11): 12 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        

In [19]:
import numpy as np
import torch

def transcribe(
    model: AutoModelForSpeechSeq2Seq,
    processor: WhisperProcessor,
    waveform_path: str = '',
    device: str = "cpu",
    sample_rate: int = 16000,
) -> str:
    """Transcribe a waveform stored as a NumPy ``.npy`` file.

    Args:
        model: Speech seq2seq model used for generation. This function does
            not move the model, so it should already be on ``device``.
        processor: Processor that turns the raw waveform into model input
            features and decodes the generated token ids back to text.
        waveform_path: Path to the ``.npy`` file holding the waveform; it is
            read with ``np.load``.
        device: Device the processed inputs are moved to before generation.
        sample_rate: Sampling rate of the waveform in Hz, forwarded to the
            processor as ``sampling_rate``.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        stripped.

    Raises:
        OSError: If the waveform cannot be loaded from ``waveform_path``. The
            underlying error is attached as ``__cause__``.
    """
    try:
        waveform = np.load(waveform_path)
    except Exception as exc:
        # Callers still get an OSError, but now with the offending path and
        # the original error chained instead of a bare, message-less raise.
        raise OSError(f"Could not load waveform from {waveform_path!r}") from exc

    # The keyword the processor reads is `sampling_rate`. The previous `sr=`
    # was silently ignored, which triggered the "strongly recommended to pass
    # the `sampling_rate` argument" warning and skipped the rate check.
    inputs = processor(waveform, sampling_rate=sample_rate, return_tensors="pt").to(device)

    # Inference only: no gradients are needed for generation.
    with torch.no_grad():
        generated_ids = model.generate(inputs["input_features"])

    # batch_decode returns one string per sequence; one waveform goes in, so
    # the first entry is the transcription.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [39]:
# Send the inputs to whatever device the model already lives on, instead of
# repeating a hard-coded device string that can drift from the loading cell.
transcription = transcribe(
    model,
    processor,
    'samples/waveform/waveform_al_e048_C-10.wav.npy',
    device=str(model.device),
)
print(transcription)

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


! Postman is delivering letter.


In [1]:
from model import WhisperForSingaporeAphasia

# NOTE(review): this rebinds `model`. Earlier in the notebook the same name held
# the Hugging Face model object; from here on it is the project wrapper whose
# `.transcribe(path)` method the following cells call. The execution counts also
# restart at this cell, so it was presumably run on a fresh kernel — confirm the
# notebook still works under Restart & Run All.
model = WhisperForSingaporeAphasia()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Transcribe sample e048 / C-10 through the wrapper and print the text.
transcription = model.transcribe(
    'samples/waveform/waveform_al_e048_C-10.wav.npy'
)
print(transcription)

! Postman is delivering letter.


In [4]:
# Transcribe sample e028 / C-05 through the wrapper and print the text.
transcription = model.transcribe(
    'samples/waveform/waveform_al_e028_C-05.wav.npy'
)
print(transcription)

! The person is flying a plane.


In [22]:
# Transcribe sample e085 / C-04 through the wrapper and print the text.
transcription = model.transcribe(
    'samples/waveform/waveform_al_e085_C-04.wav.npy'
)
print(transcription)

! What was that? This man was reading the books.
