In [1]:
import glob
import tqdm
import unicodedata
import re

In [2]:
# Read the training label list. The context manager closes the file as soon
# as its lines have been read instead of leaving the handle open.
with open("/home/akoelsch/Downloads/rvl-cdip/labels/train.txt") as label_file:
    train_files = label_file.readlines()

# Derive one OCR text path per label line: take the first whitespace-separated
# field, drop its last four characters (presumably the image extension --
# TODO confirm) and append ".txt". Blank lines are skipped because they have
# no first field and would otherwise raise IndexError.
ocr_files = [
    "/home/akoelsch/Downloads/rvl-cdip/ocr/" + path.split()[0][:-4] + ".txt"
    for path in train_files
    if path.strip()
]


In [3]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token


class Vocabulary:
    """Bidirectional word <-> index mapping with per-word occurrence counts.

    Indices 0, 1 and 2 are reserved for the PAD, SOS and EOS tokens, so the
    first real word receives index 3.

    Attributes (all set in ``__init__``):
        name: label given to this vocabulary by the caller.
        trimmed: True once ``trim`` has run; later ``trim`` calls are no-ops.
        word2index: maps word -> integer index.
        word2count: maps word -> number of times the word was added.
        index2word: maps integer index -> word, including the three
            reserved tokens.
        num_words: next free index, i.e. reserved tokens plus known words.
    """

    def __init__(self, name):
        """Create an empty vocabulary that only holds the reserved tokens."""
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Add every whitespace-separated word of ``sentence``.

        ``split()`` without a separator is used so that empty input and
        repeated whitespace never register the empty string as a word; this
        also matches how ``tokenize`` splits its input.
        """
        for word in sentence.split():
            self.addWord(word)

    def addWord(self, word):
        """Register ``word`` under a new index, or bump its count if known."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        """Keep only words seen at least ``min_count`` times and re-index them.

        Runs at most once per instance: once ``trimmed`` is set, further
        calls return immediately. Surviving words are re-numbered from 3 in
        their original insertion order and keep their occurrence counts.
        A summary line is printed with the kept/total ratio.
        """
        if self.trimmed:
            return
        self.trimmed = True

        # Remember the real counts of the survivors so they can be restored
        # after the mappings are rebuilt.
        kept_counts = {
            word: count
            for word, count in self.word2count.items()
            if count >= min_count
        }

        total = len(self.word2index)
        # Guard the ratio so trimming an empty vocabulary does not divide by zero.
        ratio = len(kept_counts) / total if total else 0.0
        print('keep_words {} / {} = {:.4f}'.format(len(kept_counts), total, ratio))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count default tokens

        for word, count in kept_counts.items():
            self.addWord(word)
            # addWord starts the count at 1; put the true frequency back.
            self.word2count[word] = count
            
            
def unicodeToAscii(s):
    """Return ``s`` with combining marks removed.

    The string is NFD-normalized, which splits accented characters into a
    base character followed by combining marks, and every code point whose
    Unicode category is ``Mn`` is then dropped. All other characters are
    kept unchanged, so the result is not guaranteed to be pure ASCII.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

def removeShortWords(s):
    """Drop whitespace-separated tokens shorter than three characters.

    The stand-alone tokens ".", "?" and "!" are always kept. The surviving
    tokens are re-joined with single spaces.
    """
    survivors = []
    for token in s.split():
        if len(token) >= 3 or token in (".", "?", "!"):
            survivors.append(token)
    return ' '.join(survivors)

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    """Normalize raw text into a space-separated string of tokens.

    Steps, in order: lowercase and strip the input, remove combining marks
    via ``unicodeToAscii``, put a space in front of each ".", "!" or "?",
    replace every run of characters other than letters and those three
    punctuation marks with one space, collapse whitespace, and finally drop
    short tokens via ``removeShortWords``.
    """
    text = s.lower().strip()
    text = unicodeToAscii(text)
    # Detach sentence punctuation so it becomes a token of its own.
    text = re.sub(r"([.!?])", r" \1", text)
    # Anything that is not a letter or sentence punctuation becomes a space.
    text = re.sub(r"[^a-zA-Z.!?]+", r" ", text)
    text = re.sub(r"\s+", r" ", text)
    return removeShortWords(text.strip())



In [4]:
voc = Vocabulary("cdip")

# Feed the normalized text of every OCR file into the vocabulary.
for f in tqdm.tqdm_notebook(ocr_files):
    # Close each file right after reading; without the context manager the
    # loop leaves one open handle per document for the garbage collector.
    with open(f) as ocr_file:
        content = ocr_file.read()
    voc.addSentence(normalizeString(content))


HBox(children=(IntProgress(value=0, max=320000), HTML(value='')))




In [5]:
# Number of distinct words before pruning.
print(len(voc.word2count))
# Discard every word that was seen fewer than 5 times.
voc.trim(min_count=5)
# Number of distinct words left after pruning.
print(len(voc.word2count))

3161486
keep_words 241952 / 3161486 = 0.0765
241952


In [6]:
def tokenize(voc, string):
    """Encode ``string`` as a list of word indices.

    The string is split on whitespace and each word is looked up in
    ``voc.word2index``. Words that are not in the mapping are encoded as 0,
    which is the value this notebook assigns to ``PAD_token``.
    """
    lookup = voc.word2index.get
    return [lookup(word, 0) for word in string.split()]


In [7]:
# Length, in tokens, of the longest document after normalization.
max_len = 0

for f in tqdm.tqdm_notebook(ocr_files):
    # Close each file right after reading; without the context manager the
    # loop leaves one open handle per document for the garbage collector.
    with open(f) as ocr_file:
        content = ocr_file.read()
    tokens = tokenize(voc, normalizeString(content))
    max_len = max(max_len, len(tokens))

print(max_len)

HBox(children=(IntProgress(value=0, max=320000), HTML(value='')))


6918
