# Evaluate Whisper on ANSP Dataset

In [None]:
# Evaluation configuration: dataset directory and the model identifier.
dts = 'ANSP/audio'
mdl = 'openai/whisper-large-v2'
# Everything after the first '-' in the model id (here: 'large-v2').
wsp = mdl.partition('-')[2]

# Echo the configuration, one "label value" pair per line.
print('\n'.join(
    ' '.join(pair)
    for pair in (('Dataset: ', dts), ('Model  : ', mdl), ('Whisper: ', wsp))
))

### Initialize

In [None]:
import os
import glob
import numpy as np
import pandas as pd
from datetime import datetime
import time

### List all Files

In [None]:
# Collect every .wav file located directly inside the dataset directory.
wav_files = glob.glob(f'{dts}/*.wav')

# Abort early when nothing matched: the rest of the evaluation needs audio files.
if not wav_files:
    raise Exception('No wav files found. Please check the path.')
print('Found {} audio files'.format(len(wav_files)))

### Run Inference

In [None]:
import whisper

# Load the Whisper checkpoint named by the part of `mdl` after its first '-'.
model = whisper.load_model(mdl.partition('-')[2])

In [None]:
# Transcribe every audio file twice -- once with a domain prompt and once
# without -- and collect both hypotheses in `df`, which is then saved to Excel.
print('Starting inference...')
df = pd.DataFrame(columns=['hyp-prmpt', 'hyp-clean', 'ref', 'file_name'])
time_start = time.time()
nato = "alpha,bravo,charlie,delta,echo,foxtrot,golf,hotel,india,juliett,kilo,lima,mike,november,oscar,papa,quebec,romeo,sierra,tango,uniform,victor,whiskey,xray,yankee,zulu"
terminology = "climb, climbing, descend, descending, passing, feet, knots, degrees, direct, maintain, identified, ILS, VFR, IFR, contact, frequency, turn, right, left, heading, altitude, flight, level, cleared, squawk, approach, runway, established, report, affirm, negative, wilco, roger, radio, radar, right, left, center"
sids = "BERGI WISPA ANDIK BETUS NOPSU SPY TORGA ARNEM ELPAT NYKER EDUPO IVLUT RENDI LOPIK OGINA ROVEN KUDAD LARAS WOODY IDRID VOLLA"

# The prompt is the same for every file, so it is built once, outside the loop.
# Commas in the word lists are replaced by spaces before concatenation.
prompt = 'Air Traffic Control Communications ' + sids.replace(',',' ') + ' ' + nato.replace(',',' ') + ' ' + terminology.replace(',',' ')

total_files = len(wav_files)
# enumerate() gives the 1-based position directly; looking it up with
# list.index() rescans the list on every iteration and returns the first
# match, which is wrong if a path appears twice.
for i, file in enumerate(wav_files, start=1):
    res_prmpt = model.transcribe(file, initial_prompt=prompt, language='en', fp16=False)
    res_clean = model.transcribe(file, language='en', fp16=False)
    # Rows are appended one at a time so `df` holds partial results if the
    # loop is interrupted. The 'ref' column is filled with a single-space
    # placeholder -- presumably the reference transcripts are supplied
    # separately; TODO confirm.
    df.loc[len(df.index)] = [res_prmpt['text'], res_clean['text'], ' ', file]

    # '\r' keeps the progress indicator on a single, continuously updated line.
    print('Inference: {:.3f} %'.format(i/total_files*100), end='\r')

time_end = time.time()
# The elapsed time is divided by 60, so the unit reported is minutes.
print('Finished {} files in {:.2f} minutes'.format(total_files, (time_end-time_start)/60))

# Use '-' rather than ':' in the timestamp: ':' is not allowed in Windows file names.
timestamp = datetime.today().strftime('%Y-%m-%d--%H-%M-%S')
df.to_excel('ANSP-'+mdl.split('/')[-1]+'-'+timestamp+'.xlsx')

In [None]:
# Sanity check: inference should have produced exactly one row per audio file.
if len(df) == len(wav_files):
    print('CHECK: DataFrame and amount of files are equal.')
elif len(df) > len(wav_files):
    print('WARNING: The length of the DataFrame is longer than the amount of files. Please check the DataFrame.')
else:
    print('WARNING: The length of the DataFrame is shorter than the amount of files. Please check the DataFrame.')

### Normalization

In [None]:
import sys
import os

# Make the sibling 'Evaluate' directory importable so `Normalizer` can be found.
# `__file__` exists only when this code runs as a script; in an interactive
# session (e.g. a Jupyter kernel) it is undefined and raises NameError, so fall
# back to the current working directory.
# NOTE(review): the fallback assumes the kernel's working directory is the
# directory containing this notebook -- confirm for your setup.
try:
    current = os.path.dirname(os.path.realpath(__file__))
except NameError:
    current = os.getcwd()
parent = os.path.dirname(current)
sys.path.append(parent+'/Evaluate')
from Normalizer import filterAndNormalize

In [None]:
# Add a '<column>-norm' column for the reference and for both hypotheses by
# passing each row's text through filterAndNormalize.
for src_col in ('ref', 'hyp-clean', 'hyp-prmpt'):
    # Bind the column name as a default argument so the lambda does not depend
    # on the loop variable's later values.
    df[src_col + '-norm'] = df.apply(
        lambda row, col=src_col: filterAndNormalize(row[col]),
        axis=1,
    )

### WER Calculation

In [None]:
import jiwer

In [None]:
def calcWER(df):
    """Print the word error rate of both hypotheses, raw and normalized.

    For each (reference column, hypothesis column) pair the WER is computed
    with ``jiwer.wer`` over the full column contents and printed as a
    percentage rounded to four decimals.

    Args:
        df: DataFrame providing the columns 'ref', 'hyp-clean', 'hyp-prmpt',
            'ref-norm', 'hyp-clean-norm' and 'hyp-prmpt-norm'.

    Returns:
        None. The results are only printed.
    """
    # (output template, reference column, hypothesis column)
    comparisons = (
        ('clean        : {} %', 'ref', 'hyp-clean'),
        ('prmpt        : {} %', 'ref', 'hyp-prmpt'),
        ('clean-norm   : {} %', 'ref-norm', 'hyp-clean-norm'),
        ('prmpt-norm   : {} %', 'ref-norm', 'hyp-prmpt-norm'),
    )

    # Compute every score before printing, so a failure prints nothing at all.
    results = [
        (template, jiwer.wer(list(df[ref_col]), list(df[hyp_col])))
        for template, ref_col, hyp_col in comparisons
    ]

    for template, score in results:
        print(template.format(round(score*100,4)))

In [None]:
# Final report: repeat the configuration, print the WER of each hypothesis
# against the reference, then compare the two hypotheses with each other.
wsp = '-'.join(mdl.split('-')[1:])

print('Dataset: ', dts)
print('Model  : ', mdl)
print('Whisper: ', wsp)

calcWER(df)

# WER between the two hypotheses; the prompted transcripts are passed as the
# first (reference) argument. `df` is used here because `dff` only exists as a
# local variable inside calcWER and is undefined at this level.
wer = jiwer.wer(list(df['hyp-prmpt']), list(df['hyp-clean']))
print()
print('WER - prmpt vs. clean: {} %'.format(round(wer*100,4)))