In [None]:
import os
import re
import string
import sys

import gensim
import numpy as np
from gensim.models import Word2Vec
from nltk.tokenize import RegexpTokenizer

# Make the project-local `enums` package importable from the parent folder.
sys.path.insert(0, '..')

from enums.language import Language
from enums.configuration import Configuration
from enums.ocr_output_type import OCROutputType

In [None]:
# Load data

def get_folder_paths(language: Language):
    """Return the data folders that hold the corpus files for `language`.

    All paths are relative to the working directory and live under
    ``../data/newseye``.

    Args:
        language: Corpus language. Only `Language.English` and
            `Language.Dutch` have folders configured here.

    Returns:
        A list of folder paths (three for English, one for Dutch).

    Raises:
        ValueError: If no folders are configured for `language`.
    """
    newseye_path = os.path.join('..', 'data', 'newseye')

    if language == Language.English:
        return [
            os.path.join(newseye_path, '2017', 'full', 'eng_monograph'),
            os.path.join(newseye_path, '2017', 'full', 'eng_periodical'),
            os.path.join(newseye_path, '2019', 'full', 'EN'),
        ]

    if language == Language.Dutch:
        return [os.path.join(newseye_path, '2019', 'full', 'NL', 'NL1')]

    # Fail here with a clear message instead of returning None, which would
    # only surface later as "'NoneType' object is not iterable" in the caller.
    raise ValueError(f'No data folders configured for language: {language}')


In [None]:
# Shared tokenizer that is passed to read_documents; with the pattern r'\w+'
# each run of word characters becomes one token.
tokenizer = RegexpTokenizer(r'\w+')

In [None]:
def read_documents(tokenizer, language: Language, ocr_output_type: OCROutputType):
    """Read every corpus file for `language` and tokenize one line of each.

    For each file a single line is selected, its first 14 characters are
    dropped, '#', '@', digits and ASCII punctuation are removed, the text is
    lower-cased, whitespace is collapsed and the result is tokenized.

    Args:
        tokenizer: Object with a ``tokenize(text)`` method.
        language: Language whose folders (from `get_folder_paths`) are read.
        ocr_output_type: `OCROutputType.GroundTruth` selects the third line
            of each file; any other value selects the second line.

    Returns:
        A list with one ``tokenizer.tokenize`` result per file, ordered by
        folder and then by file name.

    Raises:
        IndexError: If a file has fewer lines than the selected line index.
    """
    # Each line starts with a fixed-width tag that is not part of the text.
    # NOTE(review): 14 presumably matches tags such as '[ GS_aligned] ' — confirm.
    line_prefix_length = 14
    line_index = 2 if ocr_output_type == OCROutputType.GroundTruth else 1

    documents = []

    for folder_path in get_folder_paths(language):
        # os.listdir yields entries in arbitrary order; sort so the returned
        # document order is reproducible across runs and machines.
        for filename in sorted(os.listdir(folder_path)):
            file_path = os.path.join(folder_path, filename)
            if not os.path.isfile(file_path):
                # Sub-directories cannot be opened as text files; skip them.
                continue

            with open(file_path, 'r', encoding='utf-8') as text_file:
                file_lines = text_file.readlines()

            selected_line = file_lines[line_index]
            processed_line = selected_line[line_prefix_length:].replace('#', '').replace('@', '')

            text_nonum = re.sub(r'\d+', '', processed_line)
            text_nopunct = "".join(char.lower() for char in text_nonum if char not in string.punctuation)
            # Raw string: '\s' in a plain string literal is an invalid escape sequence.
            text_no_doublespace = re.sub(r'\s+', ' ', text_nopunct).strip()

            documents.append(tokenizer.tokenize(text_no_doublespace))

    return documents

In [None]:
def get_model_path(language: Language, configuration: Configuration, randomly_initialized: bool, ocr_output_type: OCROutputType):
    """Return the path under 'results' where the model for these settings is stored.

    The 'results' folder (relative to the working directory) is created if
    it does not exist yet.

    Args:
        language: Its ``.value`` becomes part of the file name.
        configuration: Its ``.value`` becomes part of the file name.
        randomly_initialized: Adds the suffix 'random' when True, 'pretr' otherwise.
        ocr_output_type: Its ``.value`` becomes part of the file name.

    Returns:
        The model file path as a string.
    """
    rnd_suffix = 'random' if randomly_initialized else 'pretr'

    model_name = f'gensim_{language.value}_{configuration.value}_{rnd_suffix}_{ocr_output_type.value}.model'

    results_folder = 'results'
    # exist_ok avoids the check-then-create race of os.path.exists() + os.mkdir().
    os.makedirs(results_folder, exist_ok=True)

    return os.path.join(results_folder, model_name)

In [None]:
def load_model(model_path):
    """Load a saved Word2Vec model, or return None when the file is absent.

    Args:
        model_path: Path of the saved model file.

    Returns:
        The result of ``Word2Vec.load(model_path)`` if the path exists,
        otherwise None.
    """
    if os.path.exists(model_path):
        return Word2Vec.load(model_path)

    return None


In [None]:
def get_word2vec_model_info(language: Language):
    """Look up the pretrained word2vec file registered for `language`.

    Args:
        language: One of `Language.English`, `Language.Dutch`,
            `Language.French` or `Language.German`.

    Returns:
        A ``(file_name, is_binary)`` tuple.

    Raises:
        Exception: If `language` has no registered file.
    """
    if language == Language.English:
        model_info = ('GoogleNews-vectors-negative300.bin', True)
    elif language == Language.Dutch:
        model_info = ('combined-320.txt', False)
    elif language == Language.French:
        model_info = ('frwiki_20180420_300d.txt', False)
    elif language == Language.German:
        model_info = ('dewiki_20180420_300d.txt', False)
    else:
        raise Exception('Unsupported word2vec language')

    return model_info

def get_pretrained_matrix(language: Language):
    """Load the pretrained word2vec vectors for `language`.

    The file is looked up in ``../data/ocr-evaluation/word2vec/<language.value>``.

    Args:
        language: Language whose pretrained file should be loaded.

    Returns:
        A ``(vectors, file_path, is_binary)`` tuple: the loaded
        ``KeyedVectors``, the path they were read from, and the ``binary``
        flag that was used to read them.
    """
    vectors_dir = os.path.join('..', 'data', 'ocr-evaluation', 'word2vec', language.value)
    file_name, is_binary = get_word2vec_model_info(language)
    vectors_path = os.path.join(vectors_dir, file_name)

    keyed_vectors = gensim.models.KeyedVectors.load_word2vec_format(vectors_path, binary=is_binary)
    return keyed_vectors, vectors_path, is_binary

In [None]:
# TRAIN

def create_model(corpus, model_path: str, configuration: Configuration, randomly_initialized: bool, language: Language):
    """Train a Word2Vec model on `corpus`, save it to `model_path` and return it.

    Args:
        corpus: Tokenized documents (a re-iterable sequence of token lists);
            it is scanned once for the vocabulary and again for training.
        model_path: Destination file passed to ``model.save``.
        configuration: `Configuration.SkipGram` trains with ``sg=1``; any
            other value trains with ``sg=0``.
        randomly_initialized: If False, the vectors of corpus words that also
            occur in the pretrained file for `language` start from the
            pretrained values and remain trainable (``lockf=1.0``).
        language: Selects the pretrained file and the vector size
            (320 for Dutch, 300 otherwise).

    Returns:
        The trained ``Word2Vec`` model.
    """
    sg = 1 if configuration == Configuration.SkipGram else 0
    vector_size = 320 if language == Language.Dutch else 300

    # initialize the model
    model = Word2Vec(vector_size=vector_size, window=5, min_count=5, workers=2, sg=sg)

    # Build the corpus vocabulary FIRST. A non-update build_vocab re-initialises
    # every weight, and intersect_word2vec_format only overwrites vectors of
    # words that are already in the vocabulary, so merging pretrained vectors
    # before this call would have no lasting effect.
    model.build_vocab(corpus, progress_per=1000)

    # Remember the corpus size now: the update scan below runs over a
    # one-sentence pseudo-corpus and overwrites model.corpus_count.
    total_examples = model.corpus_count

    if not randomly_initialized:
        word2vec_weights, word2vec_model_path, word2vec_binary = get_pretrained_matrix(language)

        # An online update is only valid once a vocabulary exists (see above).
        # NOTE(review): each pretrained key occurs once in this pseudo-sentence,
        # so with min_count=5 it presumably adds no new words — confirm whether
        # this step is still wanted.
        model.build_vocab([list(word2vec_weights.key_to_index.keys())], update=True)

        # intersect_word2vec_format stores one lock factor per merged word, so
        # the lock-factor array must cover the whole vocabulary beforehand.
        # 1.0 everywhere keeps every vector fully trainable.
        model.wv.vectors_lockf = np.ones(len(model.wv), dtype=np.float32)

        # In gensim 4 the merge lives on the KeyedVectors (model.wv).
        model.wv.intersect_word2vec_format(word2vec_model_path, binary=word2vec_binary, lockf=1.0)

    # train the model
    model.train(corpus, total_examples=total_examples, epochs=300, report_delay=1)

    # save the model
    model.save(model_path)

    return model

In [None]:
# language = Language.Dutch
# configuration = Configuration.SkipGram
# randomly_initialized = False
# ocr_output_type = OCROutputType.GroundTruth

# Train (or load from 'results') one model per combination of the values below.
# Extend a list to cover more combinations.
for language in [Language.English]:  # get_folder_paths also has folders for Language.Dutch
    for configuration in [Configuration.SkipGram]:  # e.g. add Configuration.CBOW
        for randomly_initialized in [False]:  # True skips the pretrained initialisation
            for ocr_output_type in [OCROutputType.GroundTruth]:  # e.g. add OCROutputType.Raw
                print(f'Training: [\'{language.value}\', {configuration.value}, {randomly_initialized}, {ocr_output_type.value}]')
                model_path = get_model_path(language, configuration, randomly_initialized, ocr_output_type)
                model = load_model(model_path)
                if model is None:
                    print('Model is not loaded. Creating and training now...')
                    # Read and tokenize the corpus only when training is needed;
                    # a cached model makes this work unnecessary.
                    documents = read_documents(tokenizer, language, ocr_output_type)
                    model = create_model(documents, model_path, configuration, randomly_initialized, language)

In [None]:
# target_words = {
#     Language.English: ['man', 'new', 'time', 'day', 'good', 'old', 'little', 'one', 'two', 'three'],
#     Language.Dutch: ['man', 'jaar', 'tijd', 'dag', 'huis', 'dier', 'werk', 'naam', 'groot', 'kleine', 'twee', 'drie', 'vier', 'vijf']
# }

# for word in target_words[language]:
#     print(f'-- \'{word}\':')
#     print(model.wv.most_similar(positive=[word]))