In [1]:
# Load IPython's autoreload extension and switch it to mode 2 so that edits to
# the imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [5]:
import os
from tqdm.notebook import tqdm

from spelling_correction.baselines import sec
from spelling_correction import BENCHMARK_DIR

from gnn_lib.api import utils
from gnn_lib.data.utils import clean_sequence

In [13]:
def run_neuspell_with_detections(baseline: sec.SECNeuspellBaseline, input_file: str, detection_file: str, batch_size: int = 16) -> list:
    """Run a Neuspell baseline on an input file, guided by precomputed detections.

    Args:
        baseline: Baseline whose ``inference`` method is called once per batch
            with the input sequences and the matching ``detections`` keyword.
        input_file: Path of the text file loaded via ``utils.load_text_file``
            (presumably one input sequence per line -- TODO confirm).
        detection_file: Path of the text file holding, per input sequence, one
            line of whitespace-separated integers (presumably one flag per
            word -- TODO confirm).
        batch_size: Number of sequences handed to ``baseline.inference`` per call.

    Returns:
        All outputs returned by ``baseline.inference``, concatenated in batch order.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if the input file and
            the detection file do not contain the same number of lines.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    inputs = utils.load_text_file(input_file)
    detections = [
        [int(d) for d in line.split()]
        for line in utils.load_text_file(detection_file)
    ]
    # zip() below would silently drop trailing sequences on a length mismatch,
    # producing an output file that no longer lines up with the benchmark.
    if len(inputs) != len(detections):
        raise ValueError(
            f"number of inputs ({len(inputs)}) in {input_file} does not match "
            f"number of detections ({len(detections)}) in {detection_file}"
        )

    all_outputs = []
    batch_starts = range(0, len(inputs), batch_size)
    for i in tqdm(batch_starts, desc=f"running neuspell baseline {baseline.name} on {os.path.relpath(input_file, BENCHMARK_DIR)}"):
        batch_inputs = []
        batch_detections = []
        for ipt, detection in zip(inputs[i:i+batch_size], detections[i:i+batch_size]):
            cleaned_ipt = clean_sequence(ipt, fix_unicode_errors=True)
            if len(cleaned_ipt.split()) != len(ipt.split()):
                print("found input containing unicode that will be removed by neuspell, adapting detections:", ipt, cleaned_ipt, detection)
                # NOTE(review): this assumes exactly one word disappears and
                # that it is the first word of the sequence -- confirm this
                # holds for every affected benchmark line.
                detection = detection[1:]
            # The uncleaned input is passed on; only the detections are adapted.
            batch_inputs.append(ipt)
            batch_detections.append(detection)
        outputs = baseline.inference(batch_inputs, detections=batch_detections)
        all_outputs.extend(outputs)
    return all_outputs

In [10]:
# File name of the word-level detection results that are read for each benchmark.
sed_words_file = "gnn_cliques_wfc_high_rec.txt"

# Benchmarks to run, given as sub-paths below the benchmark directories.
benchmarks = [
    "bookcorpus/realistic",
    "bookcorpus/artificial",
    "wikidump/realistic",
    "wikidump/artificial",
    "neuspell/bea60k",
]

# Neuspell baseline using the "bert" model variant.
bert = sec.SECNeuspellBaseline("bert")

loading vocab from path:/home/sebastian/anaconda3/envs/masters_thesis/lib/python3.8/site-packages/neuspell/../data/checkpoints/subwordbert-probwordnoise/vocab.pkl
initializing model


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


SubwordBert(
  (bert_dropout): Dropout(p=0.2, inplace=False)
  (bert_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): 

In [14]:
# For every benchmark: read the corrupted sequences and the stored detections,
# run the baseline on them and write its outputs to the results directory.
for benchmark in benchmarks:
    input_file = os.path.join(
        BENCHMARK_DIR, "test", "sec", benchmark, "corrupt.txt"
    )
    detection_file = os.path.join(
        BENCHMARK_DIR, "test", "sed_words", "results", benchmark, sed_words_file
    )
    out_file = os.path.join(
        BENCHMARK_DIR,
        "test",
        "sec",
        "results",
        benchmark,
        "gnn_cliques_wfc_plus_baseline_neuspell_bert.txt",
    )

    outputs = run_neuspell_with_detections(
        baseline=bert,
        input_file=input_file,
        detection_file=detection_file,
    )
    utils.save_text_file(out_file, outputs)

running neuspell baseline neuspell_bert on test/sec/bookcorpus/realistic/corrupt.txt:   0%|          | 0/625 […

found input containing unicode that will be removed by neuspell, adapting detections  decodation. decodation. [0, 1]
found input containing unicode that will be removed by neuspell, adapting detections  Critical issue. Critical issue. [0, 1, 0]


running neuspell baseline neuspell_bert on test/sec/bookcorpus/artificial/corrupt.txt:   0%|          | 0/625 …

found input containing unicode that will be removed by neuspell, adapting detections  DecofrCation. DecofrCation. [0, 1]
found input containing unicode that will be removed by neuspell, adapting detections  Critical issue. Critical issue. [0, 1, 0]


running neuspell baseline neuspell_bert on test/sec/wikidump/realistic/corrupt.txt:   0%|          | 0/625 [00…

running neuspell baseline neuspell_bert on test/sec/wikidump/artificial/corrupt.txt:   0%|          | 0/625 [0…

running neuspell baseline neuspell_bert on test/sec/neuspell/bea60k/corrupt.txt:   0%|          | 0/3941 [00:0…