# Filter by WER

> "Also, find determiner location"

- branch: master
- hidden: true
- categories: [hsi, wer, groundinggpt]

In [21]:
PATHS = (
    ("/tmp/numered_all.json", "/tmp/gpt4o-generated-speech"),
    ("/tmp/ggpt_numbered.json", "/tmp/groundinggpt-generated-speech"),
    ("/tmp/ggpt_numbered2.json", "/tmp/groundinggpt-generated-speech2")
)

In [1]:
import json

with open("/tmp/numered_all.json") as inf:
    data = json.load(inf)

In [9]:
from string import punctuation

punct = set(punctuation)

In [None]:
def gather_phrase(text):
    words = text.split(" ")
    phrase = []
    inphrase = False
    start = 0
    for idx, word in enumerate(words):
        if inphrase:
            if word.endswith("*") or (word[-1] in punct and word[-2] == "*"):
                word = word.replace("*", "")
                phrase.append(word)
                inphrase = False
            else:
                phrase.append(word)
        else:
            if word.startswith("*"):
                if word.endswith("*") or (word[-1] in punct and word[-2] == "*"):
                    return idx, word.replace("*", "")
                else:
                    phrase.append(word.strip("*"))
                    start = idx
                inphrase = True
            else:
                continue
    return start, (" ".join(phrase)).strip("*")

In [None]:
def index_of_determiner(phrase):
    dp = gather_phrase(phrase)
    words = phrase.split(" ")
    for i, word in enumerate(words):
        if word in ["this", "that", "these", "those"]:
            return i
    return -1

In [None]:
def clean_text(text):
    text = text.replace("’", "'")
    text = text.replace("—", " ")
    words = text.split(" ")
    cleaned = [w.lower().strip(punctuation) for w in words]
    return " ".join(cleaned)

In [None]:
!pip install jiwer

In [24]:
from jiwer import wer

In [22]:
def prune_fillers(text):
    FILLERS = ["uh", "um"]
    words = [x for x in text.split(" ") if x not in FILLERS]
    return " ".join(words)

In [None]:
FIXES = {
    "hsi_4_0717_222_002__0__8": ("racket", "racquet")
}

In [49]:
def read_tsv(tsvfile):
    tsvdata = []
    with open(tsvfile) as inf:
        for line in inf:
            line = line.strip()
            if not line:
                continue
            tsvdata.append(line.split("\t"))
    return tsvdata

In [77]:
def get_phrase_in_tsv(tsvdata, phrase):
    phrase_parts = clean_text(phrase).split(" ")
    tsvwords = [clean_text(x[2]) for x in tsvdata]
    for i in range(len(tsvwords) - len(phrase_parts) + 1):
        if tsvwords[i:i + len(phrase_parts)] == phrase_parts:
            return i
    return -1

In [140]:
from pathlib import Path
try:
    import librosa
    LIBROSA_AVAILABLE = True
except ImportError:
    LIBROSA_AVAILABLE = False

def process_data(data, tsvpath, gatherable_phrase = True, THRESHOLD = 0.3):
    discard_ids = []
    collected_data = []

    for item in data:
        person = item["person"]
        fileid = item["id"]
        text = item["snippet"]

        tsvfile = Path(tsvpath) / f"{fileid}.tsv"
        tsvdata = read_tsv(tsvfile)
        
        tsvwords = [x[2] for x in tsvdata]
        tsvtext = " ".join(tsvwords)

        cleaned_text = clean_text(text.strip())
        cleaned_tsv = clean_text(tsvtext)

        if gatherable_phrase:
            phrase = gather_phrase(text)
            phrase_parts = phrase.split(" ")
            if phrase_parts[0] in ["a", "the"]:
                discard_ids.append(fileid)
                continue

        if cleaned_text != cleaned_tsv:
            if cleaned_tsv.replace("-", " ") == cleaned_text:
                cleaned_tsv = cleaned_tsv.replace("-", " ")

        current = {
            "person": person,
            "fileid": fileid,
            "text": text,
            "tsv_text": tsvtext,
            "room": item["room"],
            "topic": item["topic"],
            "filename": item["filename"],
        }

        if LIBROSA_AVAILABLE:
            wavfile = Path(tsvpath) / f"{fileid}.wav"
            if wavfile.exists():
                y, sr = librosa.load(wavfile, sr=None)
                current["duration"] = librosa.get_duration(y=y, sr=sr)
            else:
                print("Wav file not found", wavfile)

        if gatherable_phrase:
            phrase_index = get_phrase_in_tsv(tsvdata, phrase)
            if phrase_index == -1 and not "determiner_index" in current:
                discard_ids.append(fileid)
                current["discarded"] = True
                current["disard_reason"] = "Phrase not in tsv"
                continue
            else:
                current["determiner_index"] = phrase_index
                current["determiner_start"] = tsvdata[phrase_index][0]
                current["determiner_end"] = tsvdata[phrase_index][1]
                current["determiner_duration"] = float(tsvdata[phrase_index][1]) - float(tsvdata[phrase_index][0])
        else:
            for idx, word in enumerate(tsvwords):
                word = clean_text(word)
                if word in ['this','that','one','those','these','there','here']:
                    current["determiner_index"] = idx
                    current["determiner_start"] = tsvdata[idx][0]
                    current["determiner_end"] = tsvdata[idx][1]
                    current["determiner_duration"] = float(tsvdata[idx][1]) - float(tsvdata[idx][0])
                    break

        if cleaned_text != cleaned_tsv:
            if gatherable_phrase:
                if not phrase in cleaned_tsv and not "determiner_index" in current:
                    discard_ids.append(fileid)
                    current["discarded"] = True
                    current["disard_reason"] = "Phrase not in tsv"
            cur_wer = wer(cleaned_text, cleaned_tsv)
            current["wer"] = cur_wer
            if cur_wer > THRESHOLD:
                discard_ids.append(fileid)
                current["discarded"] = True
                current["discard_reason"] = "wer"
        else:
            current["wer"] = 0.0
        collected_data.append(current)

    return collected_data, discard_ids


In [131]:
a, b = process_data(data, "/tmp/gpt4o-generated-speech")

In [132]:
import json

with open("/tmp/procced.json", "w") as outf:
    json.dump(a, outf, indent=4)

In [134]:
with open("/tmp/discarded.txt", "w") as outf:
    outf.write("\n".join(set(b)))


In [146]:
#("/tmp/ggpt_numbered.json", "/tmp/groundinggpt-generated-speech"),

with open("/tmp/ggpt_numbered.json") as inf:
    data = json.load(inf)

a, b = process_data(data, "/tmp/groundinggpt-generated-speech", gatherable_phrase=False)

In [147]:
import json

with open("/tmp/procced2.json", "w") as outf:
    json.dump(a, outf, indent=4)

In [148]:
with open("/tmp/discarded2.txt", "w") as outf:
    outf.write("\n".join(set(b)))


In [151]:
for item in a:
    if "determiner_start" in item:
        if float(item["determiner_start"]) > 5.0:
            item["discarded"] = True
            item["discard_reason"] = "start too late"
    else:
        if not "discarded" in item:
            item["discarded"] = True
            item["discard_reason"] = "no determiner start"

In [152]:
import json

with open("/tmp/procced2.1.json", "w") as outf:
    json.dump(a, outf, indent=4)

In [153]:
filtered = [x for x in a if not "discarded" in x]

In [156]:
import numpy as np

L = np.load('bvh_pt_lengths.npy',allow_pickle=True)
framerate = 120
point_length = {}
for thing in L:
    item = list(thing.keys())[0].split('/')[-1]
    item = item.replace('.bvh','')
    point_length[item] = list(thing.values())[0]/framerate

In [158]:
synth_length = {}
for item in filtered:
    synth_length[item["fileid"]] = item["duration"]

In [None]:
# get Anna's names from my list
synth_pre = {}
synth_post = {}
synth_data = {}
synth_times = []

pt_times = []
pt_names = []
pt_pre = {}
pt_post = {}

for item in filtered:
    fileid = item["fileid"]
    dem_start = float(item["determiner_start"])
    dem_end = float(item["determiner_end"])
    duration = float(item["duration"])

    synth_pre[fileid] = dem_start
    synth_post[fileid] = dem_end - dem_start
    synth_data[fileid] = (dem_start, duration)
    synth_times.append((duration, dem_start))

    pt_pre[fileid] = dem_start
    pt_post[fileid] = dem_end - dem_start
    pt_names.append(fileid)
    pt_times.append((duration, dem_start))


In [170]:
pt_times = [x for x in synth_times]
pt_pre = {x: synth_pre[x] for x in synth_pre}
pt_post = {x: synth_post[x] for x in synth_post}
pt_names = [x for x in synth_times]

In [171]:
pt_names

[(4.576, 1.6),
 (3.509333333333333, 0.864),
 (2.8266666666666667, 0.913),
 (3.2213333333333334, 1.167),
 (5.066666666666666, 1.263),
 (4.597333333333333, 1.082),
 (3.8186666666666667, 0.779),
 (5.610666666666667, 1.326),
 (2.538666666666667, 0.958),
 (2.848, 0.765),
 (3.1786666666666665, 1.04),
 (2.8693333333333335, 1.058),
 (5.365333333333333, 1.533),
 (6.730666666666667, 1.788),
 (3.04, 1.032),
 (4.458666666666667, 0.841),
 (3.5946666666666665, 1.113),
 (5.024, 2.178),
 (2.3893333333333335, 0.756),
 (3.36, 0.943),
 (2.208, 0.623),
 (2.848, 0.745),
 (2.208, 0.623),
 (3.68, 1.263),
 (3.392, 1.309),
 (4.96, 1.354),
 (2.2613333333333334, 0.627),
 (7.370666666666667, 2.908),
 (6.922666666666666, 2.356),
 (7.210666666666667, 2.739),
 (3.7013333333333334, 1.652),
 (4.597333333333333, 2.193),
 (3.7653333333333334, 0.922),
 (3.530666666666667, 0.883),
 (5.28, 1.887),
 (2.7093333333333334, 0.829),
 (4.277333333333333, 1.896),
 (3.5733333333333333, 2.025),
 (4.597333333333333, 2.648),
 (5.61066