In [1]:
# https://medium.com/doma/using-nlp-bert-to-improve-ocr-accuracy-385c98ae174c
# https://medium.com/@yashj302/spell-check-and-correction-nlp-python-f6a000e3709d better spell echeck
# https://gist.github.com/yuchenlin/a2f42d3c4378ed7b83de65c7a2222eb2
from difflib import SequenceMatcher
import json
from pathlib import Path
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from spylls.hunspell import Dictionary #https://github.com/zverok/spylls
import regex as re 
import torch

import sys
sys.path.append(str(Path.cwd().parent.parent))

In [2]:
# lets read in the files we will continue to work on
src = Path.cwd().parent.parent.parent.parent / 'processing' / 'nro_declassified' / 'bert_ocr'
files = list(src.glob('*json'))
output_loc = src.parent / 'bert_results'
output_loc.mkdir(exist_ok=True)
good = [file for file in files]
dictionary = Dictionary.from_files('en_US')

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [4]:
def predict_masked_sent(text, top_k=5):
    text = "[CLS] %s [SEP]"%text
    tokenized_text = tokenizer.tokenize(text)
    masked_index = [idx for idx, token in enumerate(tokenized_text) if token =='[MASK]']
    preds = []
    for idx, mask in enumerate(masked_index): 
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        with torch.no_grad():
            outputs = model(tokens_tensor)
            predictions = outputs[0]
        probs = torch.nn.functional.softmax(predictions[0, mask], dim=-1)
        top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
        prediction = []
        for i, pred_idx in enumerate(top_k_indices):
            predicted_token = tokenizer.convert_ids_to_tokens([pred_idx])[0]
            token_weight = top_k_weights[i]
            print("[MASK]: '%s'"%predicted_token, " | weights:", float(token_weight))
            prediction.append(predicted_token)
        preds.append(prediction)
    return preds

In [5]:
original_text = ("The BERT pre-trained language model is useful for predicting multiple viable replacements for the masked words. " 
    "With that said, the model is not aware of any characters uncovered by OCR. We can augment this deficiency with our suggested word" 
    " list from SpellChecker, which incorporates characters from the garbled OCR output. Combining BERT’s context-based suggtion with" 
    " SpellChecker’s word-based suggestions  better predictions than relying solely on BERT")

masked_text = ("The BERT pre-trained language model is useful for predicting multiple viable replacements for the masked words. " 
    "With that said, the model is not aware of any characters [MASK] by OCR. We can augment this deficiency with our suggested word" 
    " list from SpellChecker, which [MASK] characters from the garbled OCR output. Combining BERT’s context-based suggtion with" 
    " SpellChecker’s word-based suggestions yield better predictions than [MASK] solely on BERT")
predictions = predict_masked_sent(masked_text, top_k=2)

[MASK]: 'generated'  | weights: 0.08679242432117462
[MASK]: 'provided'  | weights: 0.06315764784812927
[MASK]: 'includes'  | weights: 0.3903972804546356
[MASK]: 'contains'  | weights: 0.22848112881183624
[MASK]: 'based'  | weights: 0.6670317649841309
[MASK]: 'relying'  | weights: 0.2214537411928177


In [6]:
peds = predict_masked_sent(masked_text, top_k=1)

[MASK]: 'generated'  | weights: 0.08679242432117462
[MASK]: 'includes'  | weights: 0.3903972804546356
[MASK]: 'based'  | weights: 0.6670317649841309


In [None]:
for file in files:
    results = {}
    # let's open it, open the text, and join it all so we can use BERT to create a token
    # then we need to grab the mask ids
    with open(file, 'r') as f:
        data = json.load(f)
    doc_words = []
    for pg_num in data.keys():
        doc_words.extend( data[pg_num]['text'] )
    print(file)
    # we could do it by paragraph or max token length, we'll do it by both
    doc_text = (' ').join(doc_words)
    doc_text = re.sub("\s\s+" , " ", doc_text)
    max_token = 1500 # bad
    results = []
    for idx in range(0, len(doc_text), max_token):
        text = ('').join(doc_text[idx: idx+max_token])
        if '[MASK]' in text:
            try:
                predictions = predict_masked_sent(text, top_k=1)
                tokenized_text = text.split(' ')
                masked_index = [idx for idx, token in enumerate(tokenized_text) if token =='[MASK]']
                for idx, mask in enumerate(masked_index): 
                    tokenized_text[mask] = predictions[idx][0]
                replaced = (' ').join([str(t) for t in tokenized_text])
                results.append({'text': text, 'predictions': replaced})
            except Exception as e:
                print(f'Exception encountered: {e} with {str(file.name)}')

    with open(output_loc / file.name, 'w') as f:
        f.write(json.dumps(results))

In [7]:
with open(r'c:\Users\brasw\Desktop\School\Spring 24\GGS 590\project\processing\nro_declassified\bert_results\1958-04-16MEMO DEAR GOODPASTER COVER LETTER PROJECT CORONA_812.json', 'r') as f:
    data = json.loads(f.read())

In [8]:
def display_data(data: dict):
    max_tokens = 5 # it has to use tokens instead of words since [MASK] is not a one to one index
    text = data['text'].split(' ')
    pred = data['predictions'].split(' ')
    masked_idx = [i for i, t in enumerate(text) if t == '[MASK]']
    for idx, mask in enumerate(masked_idx):
        print('\n\nSubmitted Text:\n')
        start = mask - max_tokens
        end = mask + max_tokens
        t = (' ').join(text[start:end])
        p = (' ').join(pred[start:end])
        print(f"{t}")
        print('\n\nGenerated Text:\n')
        print(f"{p}")

In [9]:
display_data(data[1])



Submitted Text:

sites under construction previously assbeerved [MASK] OR other Major installations


Generated Text:

sites under construction previously assbeerved sites OR other Major installations


Submitted Text:

As the Seviet far North [MASK] Sots 3 program IT


Generated Text:

As the Seviet far North and Sots 3 program IT


In [10]:
file = r'c:\Users\brasw\Desktop\School\Spring 24\GGS 590\project\processing\nro_declassified\bert_results\1958-07-23DOC PROGRESS REPORT ON MILITARY RECONNAISSANCE SATELLITE PRO_2312.json'
with open(file, 'r') as f:
    data = json.loads(f.read())

In [11]:
display_data(data[0])



Submitted Text:

for release 20231018 CO51449 29 [MASK] [MASK] July 23 1958


Generated Text:

for release 20231018 CO51449 29 approved issued July 23 1958


Submitted Text:

release 20231018 CO51449 29 [MASK] [MASK] July 23 1958 memorandum


Generated Text:

release 20231018 CO51449 29 approved issued July 23 1958 memorandum


In [12]:
file = r'c:\Users\brasw\Desktop\School\Spring 24\GGS 590\project\processing\nro_declassified\bert_results\1958-09-30MEMO STAFF MEETING MINUTES 30 SEPTEMBER 1958 (CORONA ITEMS) _409.json'
with open(file, 'r') as f:
    data = json.loads(f.read())

In [13]:
display_data(data[0])



Submitted Text:

need to Be reexamined He [MASK] QM to undertake this


Generated Text:

need to Be reexamined He asked QM to undertake this


Submitted Text:

[MASK] QM to undertake this [MASK] and suggested that later


Generated Text:

asked QM to undertake this study and suggested that later
