In [200]:
import pandas as pd
import os
os.chdir('/Users/markjos/projects/malachor5')
import numpy as np
from typing import Literal
from tqdm import tqdm
tqdm.pandas()
from glob import glob
import math
from string import punctuation
from random import shuffle
import shutil

import sys
sys.path.append('scripts')
from eval import get_word_language
from longform import load_and_resample
import torchaudio
import torch
import json

from transformers import WhisperTokenizer
from datasets import load_dataset

# Steps to generate dataset
- Run text proc (use existing code)
  - Add 'valid' col and set excluded rows to False (both malformed English and autogen Tira)
- Load Tira CS idcs & map to ASR index
  - For each ASR index, check if an unvalid record exists within its time domain
  - `is_punct==True` should not be excluded here, under the assumption that it reflects silence
  - Create new dataset only containing valid records, save and upload


In [49]:
snippet_csv = 'data/elicitation-wavs/autotranscribed/metadata.csv'
df = pd.read_csv(snippet_csv, keep_default_na=False)
df['filestem']=df['wav_source'].apply(
    lambda x: os.path.basename(os.path.splitext(x)[0])
)
print(f"{df.shape=}")
df.head()

df.shape=(109393, 10)


Unnamed: 0.1,Unnamed: 0,wav_path,tier_name,start,end,transcription,eaf_path,wav_source,sli_pred,filestem
0,0,,asr,30,1414,"Hello, hello, hello.",/home/AD/mjsimmons/datasets/elicitation-wavs/a...,/home/AD/mjsimmons/datasets/elicitation-wavs/m...,ENG,HH01082021
1,1,,asr,2106,4755,"Hello, hello, hello, hello, hello.",/home/AD/mjsimmons/datasets/elicitation-wavs/a...,/home/AD/mjsimmons/datasets/elicitation-wavs/m...,ENG,HH01082021
2,2,,asr,6240,8502,"Hello, hello. This shit.",/home/AD/mjsimmons/datasets/elicitation-wavs/a...,/home/AD/mjsimmons/datasets/elicitation-wavs/m...,ENG,HH01082021
3,3,,asr,8535,8569,...,/home/AD/mjsimmons/datasets/elicitation-wavs/a...,/home/AD/mjsimmons/datasets/elicitation-wavs/m...,ENG,HH01082021
4,4,,asr,8721,11269,Hello. Hello. Hello.,/home/AD/mjsimmons/datasets/elicitation-wavs/a...,/home/AD/mjsimmons/datasets/elicitation-wavs/m...,ENG,HH01082021


In [50]:
balance_df = pd.read_csv('/Users/markjos/projects/malachor5/notebooks/longform_dataset/balance_df.csv', index_col='index')
balance_df.head()

Unnamed: 0_level_0,Unnamed: 0,tier_name,start,end,transcription,eaf_path,wav_source,sli_pred,asr_index,split,filestem,duration,lang_balanced_dataset,label_path,clip_name
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0,0,asr,30.0,1414.0,hello hello hello,/home/AD/mjsimmons/datasets/elicitation-wavs/a...,/home/AD/mjsimmons/datasets/elicitation-wavs/m...,ENG,,,HH01082021,1384.0,,data/elicitation-wavs/wav/labels/0.txt,data/elicitation-wavs/wav/clips/0.wav
1,1,asr,6240.0,8502.0,hello hello this shit,/home/AD/mjsimmons/datasets/elicitation-wavs/a...,/home/AD/mjsimmons/datasets/elicitation-wavs/m...,ENG,,,HH01082021,2262.0,,data/elicitation-wavs/wav/labels/1.txt,data/elicitation-wavs/wav/clips/1.wav
2,2,asr,11387.0,13244.0,i do not know,/home/AD/mjsimmons/datasets/elicitation-wavs/a...,/home/AD/mjsimmons/datasets/elicitation-wavs/m...,ENG,,,HH01082021,1857.0,,data/elicitation-wavs/wav/labels/2.txt,data/elicitation-wavs/wav/clips/2.wav
3,3,asr,13649.0,13986.0,yeah,/home/AD/mjsimmons/datasets/elicitation-wavs/a...,/home/AD/mjsimmons/datasets/elicitation-wavs/m...,ENG,,,HH01082021,337.0,,data/elicitation-wavs/wav/labels/3.txt,data/elicitation-wavs/wav/clips/3.wav
4,4,asr,26625.0,27064.0,thank you,/home/AD/mjsimmons/datasets/elicitation-wavs/a...,/home/AD/mjsimmons/datasets/elicitation-wavs/m...,ENG,,,HH01082021,439.0,,data/elicitation-wavs/wav/labels/4.txt,data/elicitation-wavs/wav/clips/4.wav


In [51]:
tira_cs_idcs_path = '/Users/markjos/projects/malachor5/notebooks/longform_dataset/data/tira_cs_indices.txt'
with open(tira_cs_idcs_path) as f:
    tira_cs_idcs = [int(line.strip()) for line in f.readlines()]
tira_cs_idcs

[47327,
 47328,
 47330,
 47334,
 47335,
 47339,
 47342,
 47343,
 47344,
 47345,
 47352,
 47353,
 47357,
 47358,
 47364,
 47366,
 47375,
 47376,
 47377,
 47378,
 47381,
 47386,
 47389,
 47391,
 47392,
 47395,
 47398,
 47399,
 47402,
 47403,
 47410,
 47411,
 47413,
 47424,
 47425,
 47428,
 47442,
 47445,
 47446,
 47447,
 47448,
 47457,
 47460,
 47473,
 47477,
 47480,
 47492,
 47494,
 47495,
 47499,
 47500,
 47506,
 47509,
 47510,
 47511,
 47515,
 47520,
 47522,
 47525,
 47530,
 47532,
 47533,
 47539,
 47540,
 47542,
 47549,
 47552,
 47553,
 47554,
 47555,
 47559,
 47560,
 47565,
 47569,
 47571,
 47576,
 47580,
 47583,
 47585,
 47588,
 47601,
 47608,
 47610,
 47614,
 47617,
 47618,
 47623,
 47630,
 47631,
 47634,
 47635,
 47636,
 47640,
 47641,
 47644,
 47647,
 47648,
 47650,
 47655,
 47658,
 47660,
 47663,
 47664,
 47666,
 47668,
 47675,
 47677,
 47678,
 47682,
 47683,
 47687,
 47688,
 47689,
 47691,
 47696,
 47697,
 47699,
 47702,
 47707,
 47710,
 47718,
 47724,
 47731,
 47736,
 47744,


In [None]:
tira_cs_asr_idcs = balance_df.loc[tira_cs_idcs, 'asr_index']
tira_cs_asr_idcs=tira_cs_asr_idcs.astype(int).tolist()
tira_cs_asr_idcs[:5]

[2, 3, 8, 14, 15]

In [53]:
long_csv_path = '/Users/markjos/projects/malachor5/data/elicitation-wavs/autotranscribed/longlabels.csv'
long_df = pd.read_csv(long_csv_path)

In [119]:
cs_idx_mask = long_df['asr_index'].isin(tira_cs_asr_idcs)
cs_idx_mask.value_counts()

asr_index
False    15054
True      5426
Name: count, dtype: int64

# Whisper output cleaning
Figure out what labels are excessive repetitions and exclude

In [54]:
tokenizer = WhisperTokenizer.from_pretrained('openai/whisper-medium')
df['normalized']=df['transcription'].apply(tokenizer.normalize)
df['normalized']

0                                 hello hello hello
1                     hello hello hello hello hello
2                             hello hello this shit
3                                                 .
4                                 hello hello hello
                            ...                    
109388                                         okay
109389                                         yeah
109390    but this really studied with the question
109391                                         okay
109392             yes i am going to stop recording
Name: normalized, Length: 109393, dtype: object

In [55]:
df['transcription']=df['normalized']

In [56]:
eng_only = lambda s: all(get_word_language(word) == 'eng' for word in s.split())
df['eng_only'] = df['transcription'].apply(eng_only)
df['eng_only'].value_counts()

eng_only
True     80307
False    29086
Name: count, dtype: int64

In [57]:
# get number of words per sentence
# and number of times most common word repeats
df['num_words']=df['transcription'].str.split().apply(len)
def get_most_freq_word_count(s:str):
    words = s.split()
    words = [x.lower() for x in words]
    if len(words)==0:
        return 0
    return max([words.count(w) for w in words])
df['max_word_count'] = df['transcription'].apply(get_most_freq_word_count)

In [58]:
(df['max_word_count']<5).value_counts()

max_word_count
True     105200
False      4193
Name: count, dtype: int64

In [59]:
str_is_punct = lambda s: all(c in punctuation for c in s.strip())
df['is_punct'] = df['transcription'].apply(str_is_punct)
df['is_punct'].value_counts()

is_punct
False    104488
True       4905
Name: count, dtype: int64

In [130]:
df['is_valid'] = (
    # df['eng_only']&
    (df['eng_only']|(df['sli_pred']=='TIC')) &
    # ~df['is_punct'] &
    (df['max_word_count']<5)
)
df['is_valid'].value_counts()

is_valid
True     79819
False    29574
Name: count, dtype: int64

In [131]:
long_df['start'].apply(type).value_counts()

start
<class 'float'>    20480
Name: count, dtype: int64

In [163]:
filestem_masks = {
    filestem: df['filestem']==filestem
    for filestem in long_df['filestem'].unique()
}
long_df['is_valid']=True
long_df['clean_transcription']=''
for i, row in tqdm(long_df[cs_idx_mask].iterrows(), total=cs_idx_mask.sum()):
    asr_i = row['asr_index']
    start_mask = df['start']>=row['start']
    end_mask = df['end']<=row['end']
    filestem_mask = filestem_masks[row['filestem']]
    # print(f"{start_mask.sum()=}, {end_mask.sum()=}, {filestem_mask.sum()=}")
    within_label_mask = start_mask&end_mask&filestem_mask
    if (within_label_mask.sum()==0) or ((~df.loc[within_label_mask,'is_valid']).sum()>0):
        long_df.at[i,'is_valid']=False
    transcription_start_tuples = list(zip(
        df.loc[within_label_mask, 'transcription'].tolist(),
        df.loc[within_label_mask, 'start'].tolist(),
    ))
    tira_row_transcription = balance_df.loc[balance_df['asr_index']==asr_i,'transcription'].item()
    tira_row_start = balance_df.loc[balance_df['asr_index']==asr_i,'start'].item()
    transcription_start_tuples.append((tira_row_transcription,tira_row_start))
    transcription_start_tuples.sort(key=lambda t:t[1])
    transcription = ' '.join(t[0] for t in transcription_start_tuples).strip()
    long_df.at[i, 'clean_transcription']=transcription
long_df.loc[tira_cs_asr_idcs, 'is_valid'].value_counts()

100%|██████████| 5426/5426 [00:09<00:00, 595.84it/s]


is_valid
True     4302
False    1124
Name: count, dtype: int64

In [164]:
long_df.loc[cs_idx_mask].pivot_table(values='duration', index='is_valid', aggfunc='sum')/3_600_000

Unnamed: 0_level_0,duration
is_valid,Unnamed: 1_level_1
False,12.944179
True,4.455045


In [182]:
cs_valid = long_df.loc[cs_idx_mask,'is_valid']
cs_clean_df = long_df.loc[cs_idx_mask][cs_valid]
cs_clean_df['clean_transcription'].sample().item()

'i will see you next time there is a couple of things on it i am going to go to the next slide and my recent effort was to get a quarter of a school on the standard thank you that he is proof of your hearing yesterday i am going to go ahead and close out the meeting thank you kìjɔ́ ŋɔ́mɔ̀ nɛ̀ ŋɛ́t̪ì ŋìjà unɛɾɛ'

In [186]:
cs_clean_df=cs_clean_df.rename(columns={
    'transcription': 'old_transcription',
    'clean_transcription': 'transcription',
})
cs_clean_df.head()

Unnamed: 0,asr_index,start,end,old_transcription,indices,duration,split,filestem,wav_source,is_valid,transcription
14,15002.0,2479097.0,2502863.0,lárɔ̄lɛ́ úrnɔ̀ lə̀gwɔ̀t̪ɔ́ ɔ́ɟɔ́ únɛ̀rɛ̀ Th...,"[116952, 47979, 129]",23766.0,train,HH04092021,data/elicitation-wavs/wav/HH04092021.wav,True,lárɔ̄lɛ́ úrnɔ̀ lə̀gwɔ̀t̪ɔ́ ɔ́ɟɔ́ únɛ̀rɛ̀ th...
116,15913.0,2976459.0,3006329.0,"íŋgánɔ́nà ŋòðà ŋɛ́ jámlá nd̪ɔ̀bà Okay,...","[117863, 541, 8229, 32428, 51514]",29870.0,train,HH20230420-2-Zoom,data/elicitation-wavs/wav/HH20230420-2-Zoom.wav,True,íŋgánɔ́nà ŋòðà ŋɛ́ jámlá nd̪ɔ̀bà okay ...
137,18476.0,577040.0,604965.0,ŋùdúŋàɲà ŋìcə̀lò But we actually have th...,"[120426, 4143, 587]",27925.0,train,HH20230717-1,data/elicitation-wavs/wav/HH20230717-1.wav,True,ŋùdúŋàɲà ŋìcə̀lò but we actually have th...
138,17383.0,193548.0,220165.0,"ðàmɔ̀cò ðə̀bùlìjí ápríɲâ Okay, I guess...","[119333, 588, 89877]",26617.0,train,HH20210326,data/elicitation-wavs/wav/HH20210326.wav,True,ðàmɔ̀cò ðə̀bùlìjí ápríɲâ okay i guess...
196,5597.0,966408.0,991538.0,"lìɟí là ílə̀ðɛ̀ dìjò Okay. All right, th...","[107547, 86807, 2033, 86808, 763, 99138]",25130.0,train,HH20220111-3,data/elicitation-wavs/wav/HH20220111-3.wav,True,lìɟí là ílə̀ðɛ̀ dìjò okay all right ther...


In [202]:
cs_clean_dir = '/Users/markjos/projects/malachor5/data/hf-datasets/tira_cs_clean'
cs_clean_metadata = '/Users/markjos/projects/malachor5/data/hf-datasets/tira_cs_clean/metadata.csv'
cs_clean_df.to_csv(cs_clean_metadata, index=False)

In [193]:
longlabels_hf_ds_metadata_path = '/Users/markjos/projects/malachor5/data/hf-datasets/tira-longlabels/metadata.csv'
longlabels_hf_dir = '/Users/markjos/projects/malachor5/data/hf-datasets/tira-longlabels'
longlabels_hf = pd.read_csv(longlabels_hf_ds_metadata_path, index_col='asr_index')
longlabels_hf.head()

Unnamed: 0_level_0,start,end,transcription,indices,duration,split,filestem,wav_source,file_name
asr_index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
11230.0,1229590.0,1259547.0,íŋgánɔ̀nà ɛ́léɲé kə́ náɾùwè nd̪ɔ̀bà T...,"[113180, 69522, 101459, 10, 36988, 69525, 7323...",29957.0,train,HH20220719-2,data/elicitation-wavs/wav/HH20220719-2.wav,clips/train/HH20220719-2-m20.0s29.0ms590.0-m20...
14569.0,902320.0,930910.0,weird tone patterns. And so that might either ...,"[42, 116519]",28590.0,train,HH20230414-Zoom-3,data/elicitation-wavs/wav/HH20230414-Zoom-3.wav,clips/train/HH20230414-Zoom-3-m15.0s2.0ms320.0...
14538.0,1318720.0,1344900.0,"làdɔ́ŋnɛ̀ nìðìnɔ́ŋù ùnɛ̀ɾɛ̀ So actually, ...","[116488, 59]",26180.0,train,HH20220629-2,data/elicitation-wavs/wav/HH20220629-2.wav,clips/train/HH20220629-2-m21.0s58.0ms720.0-m22...
12250.0,3154322.0,3184300.0,"from the Karcha, Karcha is actually far away. ...","[31526, 72251, 64, 49843, 37711, 19284, 70756,...",29978.0,train,HH20221127,data/elicitation-wavs/wav/HH20221127.wav,clips/train/HH20221127-m52.0s34.0ms322.0-m53.0...
6154.0,1148379.0,1177450.0,"íŋgáðə́rɔ̀ðà You know, our friend, El Yasse...","[108104, 86, 9008, 63666, 48904]",29071.0,train,HH07242020-Zoom2,data/elicitation-wavs/wav/HH07242020-Zoom2.wav,clips/train/HH07242020-Zoom2-m19.0s8.0ms379.0-...


In [None]:
os.makedirs(
    os.path.join(cs_clean_dir, 'clips', 'train'), exist_ok=True
)

def copy_to_clean_ds(asr_index):
    clip_relpath = longlabels_hf.loc[asr_index, 'file_name']
    assert type(clip_relpath) is str
    shutil.copy(
        os.path.join(longlabels_hf_dir, clip_relpath),
        os.path.join(cs_clean_dir, clip_relpath)
    )

    return clip_relpath

cs_clean_df['file_name'] = cs_clean_df['asr_index'].progress_apply(copy_to_clean_ds)

100%|██████████| 1152/1152 [00:02<00:00, 430.55it/s]


In [203]:
cs_clean_df.to_csv(cs_clean_metadata, index=False)

In [204]:
ds = load_dataset('audiofolder', data_dir=cs_clean_dir)
ds.save_to_disk(cs_clean_dir.replace('hf-datasets', 'pyarrow-datasets'))

Resolving data files: 100%|██████████| 1152/1152 [00:00<00:00, 168491.76it/s]
Downloading data files: 100%|██████████| 1153/1153 [00:00<00:00, 25235.78it/s]
Downloading data files: 0it [00:00, ?it/s]
Extracting data files: 0it [00:00, ?it/s]
Generating train split: 1152 examples [00:00, 2155.92 examples/s]
Saving the dataset (3/3 shards): 100%|██████████| 1152/1152 [00:03<00:00, 314.17 examples/s]
