# SpICE Classifier model (PyTorch Usage)
Licensed under the Apache License, Version 2.0

Paper: Speech Intelligibility Classifiers from Half-a-Million Utterances

This colab walks you through how to download the SpICE wav2vec2-based speech intelligibility classifier and how to use the model on a sample audio file.

In [None]:
#@title Imports

import os
import numpy as np
import pickle
import tensorflow as tf

import IPython
import matplotlib
import matplotlib.pyplot as plt
import requests
import torch
import torchaudio

from google.colab import drive

# Seed PyTorch's global random number generator with a fixed value.
torch.random.manual_seed(0)

# Use the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Report the library versions and the selected device, one per line.
print(torch.__version__, torchaudio.__version__, device, sep="\n")

# Remote location of the sample utterance and the local path it is stored at.
SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"  # noqa: E501
SPEECH_FILE = "_assets/speech.wav"
# Sample rate (Hz) that audio is resampled to before it is given to the model.
EXP_SAMPLE_RATE = 16000

# Make sure the asset directory exists before anything is downloaded into it.
if not os.path.exists(SPEECH_FILE):
    os.makedirs("_assets", exist_ok=True)

In [None]:
!wget "https${SPEECH_URL}"
!mv Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav _assets/speech.wav

In [None]:
#@title Give permissions to access Google Drive
# Mount the user's Google Drive at /content/gdrive (Colab prompts for access).
drive.mount('/content/gdrive')
# Directory on the mounted Drive from which the scripted SpICE model file is
# loaded; exposed as an editable Colab form field via the #@param marker.
MODEL_HOME = "/content/gdrive/MyDrive/euphonia/spice-w2v2-models/" #@param

In [None]:
# Load the TorchScript SpICE classifier stored under MODEL_HOME.
# map_location=device puts the model on the same device that get_waveform
# moves the audio to; without it the model stays on whatever device it was
# saved from, which can fail to load on a CPU-only runtime or cause a
# CPU/CUDA mismatch at inference time.
spice_w2v2_cls_model = torch.jit.load(
    os.path.join(MODEL_HOME, 'SpICE_w2v2_cls_scripted.pt'),
    map_location=device,
)
# Switch to evaluation mode for inference.
spice_w2v2_cls_model.eval()

In [None]:
def get_waveform(fpath):
  """Load the audio at `fpath` onto `device`, resampled to EXP_SAMPLE_RATE.

  Args:
    fpath: Path of the audio file, passed to `torchaudio.load`.

  Returns:
    The waveform tensor on `device`. It is resampled only when the file's
    own sample rate differs from EXP_SAMPLE_RATE.
  """
  audio, native_rate = torchaudio.load(fpath)
  audio = audio.to(device)

  # Already at the expected rate: nothing to convert.
  if native_rate == EXP_SAMPLE_RATE:
    return audio

  return torchaudio.functional.resample(audio, native_rate, EXP_SAMPLE_RATE)

def get_prediction(fpath):
  """Run the SpICE classifier on the audio file at `fpath`.

  Args:
    fpath: Path of the audio file to classify.

  Returns:
    Whatever `spice_w2v2_cls_model` returns for the loaded waveform.
  """
  audio = get_waveform(fpath)
  # inference_mode disables autograd tracking for the forward pass.
  with torch.inference_mode():
    return spice_w2v2_cls_model(audio)

In [None]:
# Classify the downloaded sample utterance; as the last expression in the
# cell, the model output is displayed by the notebook.
get_prediction(SPEECH_FILE)