In [13]:
# --- Standard library --------------------------------------------------------
import os
import re
import string
import sys
from pathlib import Path

# Pin the visible GPU *before* torch (and any project module that pulls in
# torch) is imported, so the restriction is already in place whenever CUDA
# ends up being initialised.
os.environ["CUDA_VISIBLE_DEVICES"]="6"

# Make the parent directory importable; presumably that is where the
# `utils`, `streaming` and `dataloader` packages imported below live.
sys.path.append('..')

# --- Third party -------------------------------------------------------------
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import tqdm

# --- Project -----------------------------------------------------------------
from utils.utils import load_cfg_and_ckpt_path
from streaming.b3_streamer_v4 import MultiStreamer, build_dualstreamer_from_file
from dataloader.datamodule_bravo_multi import ECoGDataModule


In [2]:
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Whisper large-v2 is the ASR back-end used to transcribe synthesized audio:
# `processor` prepares model inputs and decodes token ids, `asr` is the
# generation model, moved to the GPU right after loading.
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
asr = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2").to('cuda')

In [15]:
# File handed to the data module as its `test_files` entry.
TEST_FILE_PATH = '/data/common/b3_paper/packaged_data/streaming_paper_data/tm1k_mimed_slow/tm1k_very_recent_test_b3_tm1k.txt'


In [16]:
# Keyword arguments for ECoGDataModule (unpacked with ** when it is built).
data_config = dict(
    # data locations
    root_dir='/data/common/b3_paper/packaged_data/streaming_paper_data',
    hb_dir='/data/cheoljun/b3_audio_scale-2/unit_label/hubert-l6_km100/',
    textidx_file=None,
    text_label_file='/data/cheoljun/b3_misc/audio_info/tm1k_text_20240108.txt',
    metainfo_path='/data/common/b3_paper/packaged_data/streaming_paper_data/tm1k_mimed_slow/tm1k_streamer_paper_final_session_dataframe.csv',
    paradigm='tm1k_mimed_slow',
    hb_paradigm='tm1k',
    # loader settings
    batch_size=24,
    num_workers=1,
    # which modalities / labels to include
    include_ecog=True,
    include_unit=True,
    include_phoneme=True,
    load_list=['ecog'],
    text_list=['unit', 'phoneme'],
    relabeled=True,
    # split definitions
    train_files='/data/common/b3_paper/packaged_data/streaming_paper_data/tm1k_mimed_slow/cleaned_train_b3_tm1k.txt',
    test_files=TEST_FILE_PATH,
)

In [None]:
# Build the data module from the configuration above and grab its test loader.
datamodule = ECoGDataModule(**data_config)
test_loader = datamodule.test_dataloader()

# Disable the dataset's transform so items are returned without it applied.
test_dataset = test_loader.dataset
test_dataset.transform = None

@@@@@@@@ 0 blocks are overlapping in train and val! @@@@@@@


In [None]:
def get_wer(gt_text,pred_text):
    """Word error rate of ``pred_text`` measured against ``gt_text``.

    Both strings are lower-cased, stripped of every character in
    ``string.punctuation`` and split into words on whitespace. The word-level
    edit distance is computed with ``torchaudio.functional.edit_distance``.

    Args:
        gt_text: Reference (ground-truth) transcript.
        pred_text: Hypothesis transcript.

    Returns:
        Edit distance divided by the number of reference words. The value can
        exceed 1.0 when the hypothesis contains many extra words. When the
        reference has no words the distance is divided by 1 instead, so the
        result is the number of hypothesis words (0.0 if both are empty)
        rather than a ``ZeroDivisionError``.
    """
    strip_punctuation = str.maketrans('', '', string.punctuation)

    def tokenize(text):
        # split() without an argument breaks on any whitespace run (spaces,
        # tabs, newlines) and never produces empty tokens.
        return text.lower().translate(strip_punctuation).split()

    pred_tokens = tokenize(pred_text)
    gt_tokens = tokenize(gt_text)

    distance = torchaudio.functional.edit_distance(pred_tokens, gt_tokens)
    # Guard the denominator so an empty reference cannot divide by zero.
    return distance / max(len(gt_tokens), 1)

def flatten_list(data):
    """Return one flat list holding the items of every iterable in ``data``, in order."""
    return [item for chunk in data for item in chunk]

def get_asr(wavs, sampling_rate=16000):
    """Transcribe audio with the notebook-level Whisper ``processor`` and ``asr``.

    Args:
        wavs: Audio samples in a form accepted by ``processor`` (the notebook
            passes the concatenated waveform produced by the streamer).
        sampling_rate: Sampling rate of ``wavs`` in Hz, forwarded to the
            processor. Defaults to 16000.

    Returns:
        The decoded text of the first sequence returned by ``asr.generate``,
        with special tokens skipped.
    """
    with torch.no_grad():
        input_features = processor(
            wavs, sampling_rate=sampling_rate, return_tensors="pt"
        ).input_features
        # Send the features to wherever the model lives rather than assuming
        # it sits on 'cuda'.
        predicted_ids = asr.generate(input_features.to(asr.device))
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

In [None]:
# Streamer configurations to evaluate; each one is loaded from
# '<name>.yaml' under MODEL_PACKAGE_DIR.
config_names = ['tm1k_restricted', 'tm1k_super_restricted','tm1k_recent_restricted',
                'tm1k_recent_super_restricted','tm1k_very_recent',
                'textonly_tm1k_restricted', 'textonly_tm1k_super_restricted','textonly_tm1k_recent_restricted',
                'textonly_tm1k_recent_super_restricted','textonly_tm1k_very_recent',]

MODEL_PACKAGE_DIR = '/data/common/b3_model_package_20240125'
MS_PER_SAMPLE = 5         # assumes ECoG is sampled at 200 Hz -- TODO confirm
N_FLUSH_CALLS = 100       # extra streamer calls with no input to drain buffered audio
WER_OUTLIER_CUTOFF = 20   # per-trial WERs at or above this are left out of the reported means

# Each streamer is built once, immediately before it is evaluated.
for config_name in config_names:
    multistreamer, configs = build_dualstreamer_from_file(f'{config_name}.yaml', base_dir=MODEL_PACKAGE_DIR)
    text_wer_results=[]
    synth_wer_results=[]
    # Placeholders for the non-realtime and ground-truth-unit synthesis
    # baselines, which are not evaluated at the moment.
    nonexpand_synth_wer_results=[]
    gt_synth_wer_results=[]
    first_text_emit_times = []
    buffer_size = configs['buffer_size']
    wavs_all = []
    test_dataset = test_loader.dataset

    for di in tqdm.tqdm(range(len(test_dataset))):
        x = test_dataset[di]
        input_all = x['ecog']

        wavs = []
        multistreamer.clear_cache()
        # Reset per trial: a recording shorter than one buffer never enters the
        # streaming loop and must not reuse the previous trial's text (or hit a
        # NameError on the very first trial).
        text = ''
        first_text_emit_time = -1   # -1 means "no text emitted during streaming"

        # Feed the recording in full buffer-sized chunks; a trailing partial
        # chunk is dropped.
        n_streamed = len(input_all) // buffer_size * buffer_size
        for i in range(0, n_streamed, buffer_size):
            outputs = multistreamer(input_all[i:i+buffer_size])
            wav = outputs['wav']
            # `text` always reflects the most recent chunk; it falls back to ''
            # when the streamer returned no 'text' entry for this chunk.
            text = outputs['text'] if 'text' in outputs else ''
            if len(text)>0 and first_text_emit_time <0:
                # NOTE(review): `i` is the index of the chunk's first sample. If
                # latency should be measured at the end of the chunk this would
                # be (i + buffer_size) * MS_PER_SAMPLE -- confirm intent.
                first_text_emit_time = (i+1)*MS_PER_SAMPLE # in ms unit
            wavs.append(wav)
        first_text_emit_times.append(first_text_emit_time)

        # Drain the streamer: keep calling it without input and collect the audio.
        for _ in range(N_FLUSH_CALLS):
            wavs.append(multistreamer(None)['wav'])
        wavs = np.concatenate(wavs)
        wavs_all.append(wavs)
        synth_text = get_asr(wavs)

        # Score the streamed text and the ASR transcript of the synthesized
        # audio against the same reference.
        gt_text = x['text']['phoneme']
        text_wer = get_wer(gt_text,text)
        synth_wer = get_wer(gt_text, synth_text)
        synth_wer_results.append(synth_wer)
        text_wer_results.append(text_wer)

    # Not used further in this cell; kept in case later cells rely on it.
    model_name = Path(configs['rnnt_ckpt_path']).stem
    wers = [text_wer_results, synth_wer_results]
    labels = ['BPE','HB100']
    # The box plot shows every trial, while the reported means skip outliers.
    # Compute the means once and reuse them for the figure and the printout.
    mean_wers = [np.mean([w for w in wer if w < WER_OUTLIER_CUTOFF]) for wer in wers]

    fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(5, 4))

    # rectangular box plot
    bplot1 = ax1.boxplot(wers,
                         vert=True,  # vertical box alignment
                         patch_artist=True,  # fill with color
                         labels=labels)  # will be used to label x-ticks
    ax1.set_ylabel('WER')
    for i, mean_wer in enumerate(mean_wers):
        # Annotate each box with its mean, just below the top of the axis.
        ax1.text(i+.85,1.6, f'{mean_wer:.02f}')
    ax1.set_ylim(0,1.7)
    ax1.set_title(f'{config_name}',fontsize=12)

    for mean_wer, label in zip(mean_wers, labels):
        l_ = label.replace('\n', ' ')
        print(f'[{config_name}] WER of {l_} - {mean_wer:.02f}')

    # Average only over trials that actually emitted text; the -1 sentinel
    # would otherwise drag the mean down.
    emitted_times = [t for t in first_text_emit_times if t >= 0]
    if emitted_times:
        print(f'[{config_name}] First text emit time - {np.mean(emitted_times):.01f}ms')
    else:
        print(f'[{config_name}] First text emit time - n/a (no trial emitted text)')
    n_silent = len(first_text_emit_times) - len(emitted_times)
    if emitted_times and n_silent:
        print(f'[{config_name}] {n_silent} of {len(first_text_emit_times)} trials emitted no text while streaming')

    plt.show()
    # Release the figure so looping over many configs does not accumulate them.
    plt.close(fig)
