In [137]:
# https://medium.com/doma/using-nlp-bert-to-improve-ocr-accuracy-385c98ae174c
# https://medium.com/@yashj302/spell-check-and-correction-nlp-python-f6a000e3709d better spell echeck
import json
import nltk
from pathlib import Path
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
import regex as re 
from spylls.hunspell import Dictionary #https://github.com/zverok/spylls
import torch

import sys
sys.path.append(str(Path.cwd().parent.parent))
from utils import create_fh_logger

In [138]:
# locations of json files 
good_docs =  Path.cwd().parent.parent.parent.parent / 'processing' / 'nro_declassified' / 'good_docs'
src = Path.cwd().parent.parent.parent.parent / 'processing' / 'nro_declassified' / 'ocr'
files = list(src.glob('*json'))
output_loc = src.parent / 'good_docs'
output_loc.mkdir(exist_ok=True)
dest = Path.cwd().parent.parent.parent.parent / 'processing' / 'nro_declassified' / 'bert_ocr'
dest.mkdir(exist_ok = True)

In [139]:
dictionary = Dictionary.from_files('en_US')
def validate_word(word, confidence, next_words):
    og = word
    word = word.lower()
    valid = dictionary.lookup(word)
    try:
        suggestion = list(dictionary.suggest(word))[0]
    except:
        suggestion = ''
    distance = nltk.edit_distance(word, suggestion)
    if word == suggestion.lower() or ( distance < 1 and not valid):
        return suggestion
    elif not valid and confidence < .8:
        # mask it for bert
        return '[MASK]'
    else:
        # keep it
        return og

In [140]:
# lets read in the files we will continue to work on
with open( good_docs / 'analyze.json','r') as f:
    docs = json.loads(f.read())['documents']
good = [file for file in files if file.stem in docs]

In [141]:
examples = {}
for file in good:
    with open(file, 'r') as f:
        data = json.load(f)
    data = data[file.stem] 
    print(file)
    print(len(data.keys()))
    for pg_num in data.keys():
        words = data[pg_num]['text'] 
        conf = data[pg_num]['conf']
        print(pg_num)
        for idx, word in enumerate(words):
            if word != '':
                word = re.sub("[^A-Za-z0-9 ]+", '', word)
                replacement = validate_word(word, conf[idx], next_words=words[idx:idx+2])
                if replacement != word:
                    examples[word] = replacement
                data[pg_num]['text'][idx] = replacement
    # save the masked data
    with open(dest / file.name, 'w') as f:
        f.write(json.dumps(data))

c:\Users\brasw\Desktop\School\Spring 24\GGS 590\project\processing\nro_declassified\ocr\1956-06-26MEMO PHYSICAL RECOVERY OF SATELLITE PAYLOADS A PRELIMINARY I_586.json
36
0
1
10
11
12
13
14
15
16
17
18
19
2
20
21
22
23
24
25
26
27
28
29
3
30
31
32
33
34
35
4
5
6
7
8
9
c:\Users\brasw\Desktop\School\Spring 24\GGS 590\project\processing\nro_declassified\ocr\1957-05-16MEMO DATA SHEET ON EARTH SATELLITE PROJECT (AIR FORCE 117L)_2307.json
2
0
1
c:\Users\brasw\Desktop\School\Spring 24\GGS 590\project\processing\nro_declassified\ocr\1957-09-27EXCERPTS FROM MEMORANDUM FOR RECORD BY LB KIRKPATRICK ON SUB_587.json
1
0
c:\Users\brasw\Desktop\School\Spring 24\GGS 590\project\processing\nro_declassified\ocr\1957-11-12DOC RAND RESEARCH MEMORANDUM A FAMILY OF RECOVERABLE RECONNA_588.json
128
0
1
10
100
101
102
103
104
105
106
107
108
109
11
110
111
112
113
114
115
116
117
118
119
12
120
121
122
123
124
125
126
127
13
14
15
16
17
18
19
2
20
21
22
23
24
25
26
27
28
29
3
30
31
32
33


In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#Predict words for mask using BERT; 
from difflib import SequenceMatcher
#refine prediction by matching with proposals from SpellChecker
def predict_word(text_original, predictions, maskids): # MEDIUM: 
    for i in range(len(maskids)):
        preds = torch.topk(predictions[0, maskids[i]], k=50) 
        indices = preds.indices.tolist()
        simmax = 0
        predicted_token = ''
        for pred_word in tokenizer.convert_ids_to_tokens(indices):
            for spell_checked in list(dictionary.suggest(word)):
                s = SequenceMatcher(None, pred_word, spell_checked).ratio()
                if s is not None and s > simmax:
                    simmax = s
                    predicted_token = pred_word
        text_original = text_original.replace('[MASK]', predicted_token, 1)
    return text_original

In [None]:
files = list(dest.glob('*json')) # these are ready for bert
bert_results = dest.parent / 'bert_results'
bert_results.mkdir(exist_ok = True)

In [133]:
results = {}
for file in files[:3]:
    # let's open it, open the text, and join it all so we can use BERT to create a token
    # then we need to grab the mask ids
    with open(file, 'r') as f:
        data = json.load(f)
    doc_words = []
    for pg_num in data.keys():
        doc_words.extend( data[pg_num]['text'] )
    # we could do it by paragraph or max token length, we'll do it by both
    doc_text = (' ').join(words)
    doc_text = re.sub("\s\s+" , " ", doc_text)
    max_token = 500 # bad
    for idx in range(0, len(doc_text), max_token):
        text = ('').join(doc_text[idx: idx+max_token])
        # let's add periods every 10 words
        words = []
        if '.' not in text:
            for idx, word in enumerate(text.split(' ')):
                if idx%10 == 0 and idx != 0: # bad
                    words.extend(['.', word])
                else:
                    words.append(word)
            words.append('.')
            text = (' ').join(words)
        tokened = tokenizer.tokenize(text)
        ids = tokenizer.convert_tokens_to_ids(tokened)
        masked_idx = [i for i, e in enumerate(tokened) if e == '[MASK]']
        if len(masked_idx) > 0:
            segs = [i for i, e in enumerate(tokened) if e == "."]
            segments_ids=[]
            prev=-1
            for k, s in enumerate(segs):
                segments_ids = segments_ids + [k] * (s-prev)
                prev=s
            segments_ids = segments_ids + [len(segs)] * (len(tokened) - len(segments_ids))
            segments_tensors = torch.tensor([segments_ids])
            tokens_tensor = torch.tensor([ids])
            model = BertForMaskedLM.from_pretrained('bert-base-uncased')
            with torch.no_grad():
                predictions = model(tokens_tensor, segments_tensors)
        else:
            print('no predictions')
        
        replaced = predict_word(text, predictions, maskids)
        results[file][idx] = {'text': text, 'predictions': replaced}
    with open(bert_results / file, 'w') as f:
        f.write(json.dumps(results))