In [1]:
# Reload imported modules (e.g. gnn_lib) automatically before each cell runs,
# so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [66]:
import os
import omegaconf
import torch
from tqdm.notebook import tqdm

from spell_checking import BASE_DIR, DATA_DIR

from gnn_lib.data import variants, datasets, utils
from gnn_lib.utils import io

# Point gnn_lib at the project's data and config directories via environment variables.
# NOTE(review): these are set after `gnn_lib` is imported; if gnn_lib reads them at import
# time they would not take effect — confirm they are only read lazily.
os.environ.update(
    {
        "GNN_LIB_DATA_DIR": DATA_DIR,
        "GNN_LIB_CONFIG_DIR": os.path.join(BASE_DIR, "configs"),
    }
)

In [32]:
def _load_variant_config(file_name):
    """Load a variant YAML config from ``$GNN_LIB_CONFIG_DIR/variant/<file_name>``."""
    config_path = os.path.join(os.environ["GNN_LIB_CONFIG_DIR"], "variant", file_name)
    return omegaconf.OmegaConf.load(config_path)


# NOTE(review): the second argument (22) to get_variant_from_config is forwarded unchanged;
# its meaning is not visible in this notebook (possibly a seed) — confirm.
tok_plus_conf = _load_variant_config("tokenization_repair_plus_sed_plus_sec.yaml")
tok_plus = variants.get_variant_from_config(tok_plus_conf, 22)

sec_words_conf = _load_variant_config("sec_words_nmt_transformer.yaml")
sec_words = variants.get_variant_from_config(sec_words_conf, 22)

In [35]:
# Both preprocessed datasets live under DATA_DIR/processed/<name>.
# NOTE(review): the trailing 22 is forwarded unchanged; its meaning is not visible here — confirm.
processed_dir = os.path.join(DATA_DIR, "processed")
tok_plus_dataset = datasets.PreprocessedDataset(
    os.path.join(processed_dir, "wikidump_paragraphs_tokenization_repair_plus"),
    tok_plus_conf,
    22,
)
sec_words_dataset = datasets.PreprocessedDataset(
    os.path.join(processed_dir, "wikidump_paragraphs_sed_words_and_sec"),
    sec_words_conf,
    22,
)

In [27]:
# Largest value of sum(sec_label_splits) - len(sec_label_splits) over all dataset items.
# NOTE(review): this previously iterated over `dataset`, which no cell in this notebook
# defines (leftover kernel state), so it failed on a fresh kernel. `tok_plus_dataset` is
# assumed to be the intended dataset: its data directory is the one whose LMDB reports
# 12129 elements, matching the recorded progress bar — confirm.
max(sum(info["sec_label_splits"]) - len(info["sec_label_splits"]) for _, info in tqdm(tok_plus_dataset))

  0%|          | 0/12129 [00:00<?, ?it/s]

352

In [52]:
# Open `lmdb_0` under the tokenization-repair-plus processed directory and keep a
# read-only transaction (`txn`) around for the cells below; then read the stored length.
# NOTE(review): the name `lmdb` would shadow an `lmdb` module import if one is added later.
lmdb_path = os.path.join(
    DATA_DIR, "processed", "wikidump_paragraphs_tokenization_repair_plus", "lmdb_0"
)
lmdb = utils.open_lmdb(lmdb_path)
txn = lmdb.begin(write=False)
# NOTE(review): relies on the private helper `datasets._decompress`.
raw_length = txn.get(b"dataset_length")
num_elements = datasets._decompress(raw_length)
num_elements

12129

In [63]:
def load_element(idx: int, transaction=None):
    """Load one stored element and its target text from the LMDB store.

    Args:
        idx: Element index. The serialized sample is stored under ``str(idx)`` and the
            UTF-8 target text under ``f"{idx}_target"``.
        transaction: Read transaction to use. Defaults to the notebook-level ``txn``.

    Returns:
        ``(sample, target, org_sequence)`` where ``org_sequence`` is
        ``sample.info["org_sequence"]``.

    Raises:
        KeyError: If either key is missing from the store (``get`` returned ``None``).
    """
    # Resolve the notebook-global `txn` lazily so callers may pass their own transaction.
    source = txn if transaction is None else transaction
    sample_key = str(idx).encode("utf8")
    target_key = f"{idx}_target".encode("utf8")

    # Fail fast with the offending key instead of an obscure error from
    # deserialize_samples(None) or None.decode().
    data = source.get(sample_key)
    if data is None:
        raise KeyError(f"no sample stored under key {sample_key!r}")
    raw_target = source.get(target_key)
    if raw_target is None:
        raise KeyError(f"no target stored under key {target_key!r}")

    sample = utils.deserialize_samples([data])[0]
    return sample, raw_target.decode("utf8"), sample.info["org_sequence"]

In [64]:
# Load one element as a worked example: the deserialized sample, its target text,
# and the sample's original sequence. NOTE(review): index 9 appears to be arbitrary.
sample, target, org = load_element(9)

In [68]:
# For every stored element, check that the labels from the word-level SEC variant
# equal the `sec_label` produced by the tokenization-repair-plus variant.
# NOTE(review): assumes elements are stored under keys 0..num_elements-1 — confirm.
for idx in tqdm(range(num_elements)):
    # Load element `idx` itself; re-checking one pre-loaded example on every
    # iteration would only ever validate that single element.
    elem_sample, elem_target, elem_org = load_element(idx)
    _, sec_info = sec_words.get_inputs(elem_target, elem_org)
    _, tok_plus_info = tok_plus.get_inputs(elem_sample, elem_target)
    assert torch.equal(sec_info["label"], tok_plus_info["sec_label"]), f"label mismatch at element {idx}"

  0%|          | 0/12129 [00:00<?, ?it/s]

In [115]:
# Toy example: `sequence` is `target_sequence` with whitespace errors — spaces inserted
# inside words, a missing space ("isa"), and leading/trailing spaces.
sequence = " thi s isa t e s t sequen c e "
target_sequence = "this is a test sequence"

import re
def get_word_boundaries(sequence, target_sequence):
    """Locate each word of ``target_sequence`` inside a whitespace-corrupted ``sequence``.

    ``sequence`` must equal ``target_sequence`` up to whitespace. At most one whitespace
    character is allowed at the start, between two words, and between any two characters
    of a word. Words may also be glued together with no space.

    Args:
        sequence: The (possibly whitespace-corrupted) input string.
        target_sequence: The correctly spaced reference; split on whitespace into words.

    Returns:
        A list of ``(start, end)`` offsets into ``sequence``, one per target word, such that
        ``sequence[start:end]`` is that word including any whitespace inserted inside it.

    Raises:
        AssertionError: If ``sequence`` cannot be aligned to ``target_sequence``.
    """
    # Raw strings: "\s" in a plain string literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    pattern = r"\s?"
    for word in target_sequence.split():
        pattern += "(" + r"\s?".join(re.escape(char) for char in word) + r")\s?"
    match = re.fullmatch(pattern, sequence)
    assert match is not None, f"cannot align {sequence!r} to {target_sequence!r}"
    # Group k (1-based) captures the k-th target word.
    return [match.span(group) for group in range(1, len(match.groups()) + 1)]

# One (start, end) span per target word; e.g. the first span, (1, 6), covers "thi s".
get_word_boundaries(sequence, target_sequence)

[(1, 6), (7, 9), (9, 10), (11, 18), (19, 29)]

In [172]:
# Hand-written version of the alignment regex for a two-word prefix: each capture group
# spans one target word, including any whitespace inserted inside it.
# Raw string avoids the invalid "\s" escape warning of a plain literal (same regex).
match = re.match(r"\s?(t\s?h\s?i\s?s)\s?(i\s?s)\s?", "thi s is")
print(match)
if match:
    print(match.groups())
    for group in range(1, len(match.groups()) + 1):
        print(*match.span(group))

<re.Match object; span=(0, 8), match='thi s is'>
('thi s', 'is')
0 5
6 8


In [208]:
# Split into runs of word characters or runs of other non-space characters; whitespace is dropped.
# Raw string avoids the invalid "\w"/"\s" escape warnings of a plain literal (same regex).
re.findall(r"\w+|[^\w\s]+", "!!!a.9d.s,, ")

['!!!', 'a', '.', '9d', '.', 's', ',,']

In [213]:
# Runs of characters that are neither word characters nor whitespace (e.g. punctuation).
re.compile(r"[^\w\s]+").findall("a!!!")

['!!!']

In [223]:
# "a" is a word character, so the negated class does not match at position 0 -> None.
re.compile(r"[^\w]").match("a")

In [262]:
# The first alternative keeps runs like "a.t.t" together: at least three non-space characters
# that start and end with a word character (greedy \S+ backtracks to the last word char).
# Otherwise fall back to plain word runs or runs of non-word, non-space characters.
re.compile(r"\w\S+\w|\w+|[^\w\s]+").findall("!!!a.t.t!!! this")

['!!!', 'a.t.t', '!!!', 'this']

In [209]:
# Print every token (word run or non-word, non-space run) on its own line.
pattern = re.compile(r"\w+|[^\w\s]+")
for match in pattern.finditer("!!!a.9d.s,, this is the-end the end"):
    print(match.group())

!!!
a
.
9d
.
s
,,
this
is
the
-
end
the
end
