# Whisper zero-shot classification

## Data
Data provided by Gion Sialm (SBB) consisting of 14 samples. All samples have been converted from \*.m4a with 48000 Hz sample rate to \*.wav with 16000 Hz sample rate.

## Import packages

In [1]:
# ASR packages
import whisper

# General packages
import os
from pathlib import Path
import numpy as np

In [2]:
MAIN_DIR = Path("/home/user/code/sbb_project")
DATA_DIR = MAIN_DIR.joinpath("data_sbb")
TEST_DATA_DIR = DATA_DIR.joinpath("zeroshot_data")

In [3]:
import glob
files = glob.glob(os.path.join(TEST_DATA_DIR, '*.wav'))

In [4]:
ground_truth = "Gleis Alpha 44 via Gleis Beta 45"

In [5]:
def whisper_transcribe(files, model_type="small"):
    output = list()
    model = whisper.load_model(model_type)
    options = whisper.DecodingOptions()
    for file in files:
        audio = whisper.load_audio(file)
        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio).to(model.device)
        result = whisper.decode(model, mel, options)
        output.append(result.text)
    return output

In [6]:
def wer(ref, hyp ,debug=True):
    r = ref.split()
    h = hyp.split()
    #costs will holds the costs, like in the Levenshtein distance algorithm
    costs = [[0 for inner in range(len(h)+1)] for outer in range(len(r)+1)]
    # backtrace will hold the operations we've done.
    # so we could later backtrace, like the WER algorithm requires us to.
    backtrace = [[0 for inner in range(len(h)+1)] for outer in range(len(r)+1)]
 
    OP_OK = 0
    OP_SUB = 1
    OP_INS = 2
    OP_DEL = 3
    DEL_PENALTY = 1
    INS_PENALTY = 1
    SUB_PENALTY = 1
    
    # First column represents the case where we achieve zero
    # hypothesis words by deleting all reference words.
    for i in range(1, len(r)+1):
        costs[i][0] = DEL_PENALTY*i
        backtrace[i][0] = OP_DEL
    
    # First row represents the case where we achieve the hypothesis
    # by inserting all hypothesis words into a zero-length reference.
    for j in range(1, len(h) + 1):
        costs[0][j] = INS_PENALTY * j
        backtrace[0][j] = OP_INS
    
    # computation
    for i in range(1, len(r)+1):
        for j in range(1, len(h)+1):
            if r[i-1] == h[j-1]:
                costs[i][j] = costs[i-1][j-1]
                backtrace[i][j] = OP_OK
            else:
                substitutionCost = costs[i-1][j-1] + SUB_PENALTY # penalty is always 1
                insertionCost    = costs[i][j-1] + INS_PENALTY   # penalty is always 1
                deletionCost     = costs[i-1][j] + DEL_PENALTY   # penalty is always 1
                 
                costs[i][j] = min(substitutionCost, insertionCost, deletionCost)
                if costs[i][j] == substitutionCost:
                    backtrace[i][j] = OP_SUB
                elif costs[i][j] == insertionCost:
                    backtrace[i][j] = OP_INS
                else:
                    backtrace[i][j] = OP_DEL
                 
    # back trace though the best route:
    i = len(r)
    j = len(h)
    numSub = 0
    numDel = 0
    numIns = 0
    numCor = 0
    if debug:
        print("OP\tREF\tHYP")
        lines = []
    while i > 0 or j > 0:
        if backtrace[i][j] == OP_OK:
            numCor += 1
            i-=1
            j-=1
            if debug:
                lines.append("OK\t" + r[i]+"\t"+h[j])
        elif backtrace[i][j] == OP_SUB:
            numSub +=1
            i-=1
            j-=1
            if debug:
                lines.append("SUB\t" + r[i]+"\t"+h[j])
        elif backtrace[i][j] == OP_INS:
            numIns += 1
            j-=1
            if debug:
                lines.append("INS\t" + "****" + "\t" + h[j])
        elif backtrace[i][j] == OP_DEL:
            numDel += 1
            i-=1
            if debug:
                lines.append("DEL\t" + r[i]+"\t"+"****")
    if debug:
        lines = reversed(lines)
        for line in lines:
            print(line)
        print("#cor " + str(numCor))
        print("#sub " + str(numSub))
        print("#del " + str(numDel))
        print("#ins " + str(numIns))
    # return (numSub + numDel + numIns) / (float) (len(r))
    wer_result = round( (numSub + numDel + numIns) / (float) (len(r)), 3)
    return {'WER':wer_result, 'numCor':numCor, 'numSub':numSub, 'numIns':numIns, 'numDel':numDel, "numCount": len(r)}

In [7]:
output = whisper_transcribe(files, "large")

In [8]:
print(output)

['Preis Alpha 444, Preis Beta 45.', 'Gleis A44 via Gleis beta 45', 'Gleis A44 via Gleis B45', 'Gleis A44 via Gleis B45', 'Gleis A44 via Gleis B45', 'Gleis Alpha 4-4 via Gleis Beta 4-5', 'Gleis, A44, via Gleis, Vetter 4, 5', 'Gleis A44 via Gleis B45', 'Gleis A44 via Gleis E45', 'Da ist A244, hier gleich B45.', 'Gleis A44 via Gleis Hegde 45', 'Reise aus der 444, die Jagdreise mit der 445.']


In [9]:
wers = list()
for pred in output:
    word_error = wer(pred, ground_truth, debug=False)
    wers.append(word_error["WER"])

In [10]:
print("WER for all sentences: \n", wers)
print("Average WER: {} \n Minimum WER: {} \n Maximum WER: {}".format(np.mean(wers), np.min(wers), np.max(wers)))

WER for all sentences: 
 [0.833, 0.5, 0.8, 0.8, 0.8, 0.286, 1.0, 0.8, 0.8, 1.167, 0.5, 1.0]
Average WER: 0.7738333333333333 
 Minimum WER: 0.286 
 Maximum WER: 1.167
