In [25]:
# All imports for the notebook live in this first cell.
import wandb
import matplotlib.pyplot as plt
import scipy
import os
from collections import Counter
import numpy as np
import csv

# Disable matplotlib's axis grid by default.
plt.rcParams["axes.grid"] = False
import seaborn as sns

# Apply seaborn's default theme first, then switch to the "ticks" style.
sns.set()
sns.set_style("ticks")

import sys
# Put the parent directory on the import path -- presumably so that project packages
# imported in later cells (e.g. `dependency_injection`) can be resolved; confirm.
sys.path.append('..')

In [26]:
# Language whose results are analysed; reused by later cells for file names.
language = 'english'

# Emulate a command line for code that reads `sys.argv`.
# NOTE(review): the first element occupies sys.argv[0], the slot that normally holds the
# program name and that argparse-style parsers skip -- presumably "--device cuda" is
# therefore never parsed as an option; confirm before splitting it into two items.
sys.argv = [
"--device cuda",
# Built with os.path.join so the relative data folder is not tied to the Windows separator.
"--data-folder", os.path.join("..", "data"),
"--seed", "13",
"--configuration", "char-to-char-encoder-decoder",
"--language", language,
"--challenge", "post-ocr-correction"]

In [27]:
# Configure container:
# Build the project's IoC container and resolve the services used by the cells below.
from dependency_injection.ioc_container import IocContainer

container = IocContainer()

# Each provider call returns the service instance that later cells access by name.
# NOTE(review): presumably the container reads its settings from the `sys.argv`
# emulation set up in the previous cell -- confirm.
data_service = container.data_service()
plot_service = container.plot_service()
metrics_service = container.metrics_service()
# `process_service` is only referenced from commented-out code in the visible cells.
process_service = container.process_service()
file_service = container.file_service()

In [28]:
# Entity and project names -- presumably Weights & Biases identifiers, given the `wandb`
# import in the first cell; neither is referenced in the visible cells (TODO confirm usage).
entity = 'eval-historical-texts'
project = 'post-ocr-correction'

In [29]:
# Maps a human-readable model label (used in the printed rows and as plot legend title)
# to the run-specific part of the checkpoint file name; the seed suffix is appended later.
# Commented-out entries are excluded from the analysis.
unique_runs = {
    'Base': 'h512-e128-l2-bi-d0.50.0001',
#     'Base + FT': 'ft-h512-e128-l2-bi-d0.50.0001',
#     'Base + BERT': 'pretr-h512-e128-l2-bi-d0.50.0001',
#     'Base + FT + BERT': 'pretr-ft-h512-e128-l2-bi-d0.50.0001',
#     'Base + BERT (fine-tuned)': 'pretr-h512-e128-l2-bi-d0.5-tune0.0001',
#     'Base + FT + BERT (fine-tuned)': 'pretr-ft-h512-e128-l2-bi-d0.5-tune0.0001',
#     'Base + BERT (fine-tuned, after convergence)': 'pretr-h512-e128-l2-bi-d0.5-tune-ac0.0001',
    'Base + FT + BERT (fine-tuned, after convergence)': 'pretr-ft-h512-e128-l2-bi-d0.5-tune-ac0.0001',
}

In [30]:
def get_run_info(language: str, checkpoint_name: str):
    """
    Load the stored results of one evaluated checkpoint and attach per-sample metrics.

    The CSV file `output-BEST_{language}--{checkpoint_name}.csv` is read once; its
    `Input`, `Prediction` and `Target` columns are compared with `metrics_service`
    to obtain Jaccard similarities and Levenshtein distances. The improvement
    percentage is always recomputed from those distances instead of being parsed
    from any summary text in the file.

    :param language: language sub-folder of the results directory
    :param checkpoint_name: run identifier including the seed suffix
    :return: None when the CSV file does not exist; otherwise the run dictionary with
        the keys `improvement_percentage`, `jaccard-similarities`, `edit-distances`,
        `original-edit-distances`, `original-jaccard-scores` and
        `original-jaccard-score` set. `improvement_percentage` is NaN when the
        original edit distances sum to zero (nothing to improve on).
    """
    output_path = os.path.join('..', 'results', 'post-ocr-correction', 'char-to-char-encoder-decoder', language, 'output')
    pickle_name = f'output-BEST_{language}--{checkpoint_name}'
    csv_path = os.path.join(output_path, f'{pickle_name}.csv')

    # Check for the CSV first so that a missing run does not trigger a pickle load.
    if not os.path.exists(csv_path):
        return None

    run_info = data_service.load_python_obj(output_path, pickle_name, print_on_success=False, print_on_error=False)
    if run_info is None:
        # NOTE(review): presumably the loader returns None when nothing could be loaded
        # (it is called with print_on_error=False) -- confirm. Every key read by the
        # cells below is filled in here, so an empty dictionary is a safe fallback.
        run_info = {}

    input_characters = []
    predicted_characters = []
    target_characters = []
    with open(csv_path, 'r', encoding='utf-8') as csv_file:
        for row in csv.DictReader(csv_file):
            input_data = row['Input']
            predicted_data = row['Prediction']
            target_data = row['Target']

            # Stop at the first row without a target (DictReader yields None for a
            # missing column); presumably this is the trailing summary line -- confirm.
            if not target_data:
                break

            input_characters.append(input_data)
            predicted_characters.append(predicted_data)
            target_characters.append(target_data)

    # "Original" metrics compare the target with the raw input, the others with the prediction.
    jaccard_scores = [
        metrics_service.calculate_jaccard_similarity(target, source)
        for target, source in zip(target_characters, input_characters)]

    pr_jaccard_scores = [
        metrics_service.calculate_jaccard_similarity(target, predicted)
        for target, predicted in zip(target_characters, predicted_characters)]

    original_levenshtein_distances = [
        metrics_service.calculate_levenshtein_distance(target, source)
        for target, source in zip(target_characters, input_characters)]

    levenshtein_distances = [
        metrics_service.calculate_levenshtein_distance(target, predicted)
        for target, predicted in zip(target_characters, predicted_characters)]

    # Relative reduction of the summed edit distance, in percent; guarded against a
    # zero denominator (no rows, or inputs that already match their targets).
    original_total = sum(original_levenshtein_distances)
    if original_total == 0:
        improvement_percentage = float('nan')
    else:
        improvement_percentage = round((1 - (float(sum(levenshtein_distances)) / original_total)) * 100, 3)

    run_info['improvement_percentage'] = improvement_percentage
    run_info['jaccard-similarities'] = pr_jaccard_scores
    run_info["edit-distances"] = levenshtein_distances
    run_info["original-edit-distances"] = original_levenshtein_distances

    run_info['original-jaccard-scores'] = jaccard_scores
    # Avoid numpy's empty-slice warning when no rows were read.
    run_info['original-jaccard-score'] = round(np.mean(jaccard_scores), 3) if jaccard_scores else float('nan')

    return run_info

In [31]:
def plot_histogram(xs, ys, color='r'):
    """
    Fill the area under the points (xs, ys) on the current matplotlib axes.

    Only pairs whose y value is greater than -1 are drawn; the rest are dropped
    together with their x value. The area is filled half-transparently in `color`.
    """
    visible_points = [(x_value, y_value) for x_value, y_value in zip(xs, ys) if y_value > -1]
    kept_xs = [point[0] for point in visible_points]
    kept_ys = [point[1] for point in visible_points]

    plt.fill_between(kept_xs, kept_ys, interpolate=True, color=color, alpha=0.5)

In [32]:
# Per-sample edit distances of the uncorrected input, taken from the last loaded run.
original_edit_distances = []
# Legend label -> per-sample edit distances; consumed by the plotting cell below.
edit_distances_per_run = {}

# Pre-set so the summary prints after the loop cannot raise NameError when no run loads.
original_jaccard = None
original_edit_mean = None

for run_name, run_unique_str in unique_runs.items():
    for seed in [13, 7, 25]:
        checkpoint_name = f'{run_unique_str}-seed{seed}'
        run_info = get_run_info(language, checkpoint_name)
        if run_info is None:
            # No exported results for this run/seed combination.
            continue

        original_jaccard = run_info['original-jaccard-score']
        original_jaccards = sum(run_info['original-jaccard-scores'])
        jaccards = sum(run_info['jaccard-similarities'])
        # Relative change of the summed Jaccard similarity in percent
        # (positive when the predictions score higher than the raw input).
        jaccard_improvement_percentage = -round(((1 - (float(jaccards) / original_jaccards)) * 100), 3)

        original_edit_mean = round(np.mean(run_info["original-edit-distances"]), 3)

        improvement_percentage = run_info['improvement_percentage']
        jaccard_similarity_mean = round(np.mean(run_info['jaccard-similarities']), 3)
        edit_dist = round(np.mean(run_info["edit-distances"]), 3)

        # Ready-to-paste LaTeX table cells, shown as the last column of the printed row.
        latex_cells = f'{edit_dist} & {improvement_percentage} & {jaccard_similarity_mean} & {jaccard_improvement_percentage}'
        # NOTE(review): the field labelled "sum" shows the *mean* edit distance (`edit_dist`);
        # the label is kept so the output stays comparable with earlier runs.
        print('{:48s}, seed {:2d} | {:8.3f}; sum: {:7.3f}; jacc mean: {:6.3f}; jacc percent: {:10.3f} || {:100s}'.format(
            run_name, seed, improvement_percentage, edit_dist, jaccard_similarity_mean, jaccard_improvement_percentage,
            latex_cells))

        original_edit_distances = run_info["original-edit-distances"]
        edit_distances_per_run['No correction'] = list(original_edit_distances)
        # Overwritten for every seed, so the last successfully loaded seed of a run is the one plotted.
        edit_distances_per_run[run_name] = list(run_info["edit-distances"])

print('-----------')
print(f"original jaccard similarity mean: {original_jaccard}")
print(f"original edit mean: {original_edit_mean}")



Base                                            , seed 13 |   -5.968; sum:   9.958; jacc mean:  0.821; jacc percent:     -0.450 || 9.958 & -5.968 & 0.821 & -0.45                                                                      
Base + FT + BERT (fine-tuned, after convergence), seed 13 |   -8.382; sum:  10.184; jacc mean:  0.821; jacc percent:     -0.418 || 10.184 & -8.382 & 0.821 & -0.418                                                                    
Base + FT + BERT (fine-tuned, after convergence), seed  7 |   -7.215; sum:  10.075; jacc mean:  0.820; jacc percent:     -0.574 || 10.075 & -7.215 & 0.82 & -0.574                                                                     
Base + FT + BERT (fine-tuned, after convergence), seed 25 |   -5.639; sum:   9.927; jacc mean:  0.826; jacc percent:      0.108 || 9.927 & -5.639 & 0.826 & 0.108                                                                      
-----------
original jaccard similarity mean: 0.825
original edit mean: 

In [33]:
def merge_dicts(*dict_args):
    """
    Combine any number of dictionaries into one new dictionary.

    The inputs are left untouched (shallow copy). When a key occurs more than
    once, the value from the dictionary passed last is kept.
    """
    return {
        key: value
        for mapping in dict_args
        for key, value in dict(mapping).items()
    }

In [34]:
# Plot the edit-distance distributions collected per run (plus the "No correction"
# baseline) and store the figure under the experiments folder.
# The titles and labels are LaTeX snippets -- presumably the plot service renders text
# with LaTeX enabled; confirm.
plot_service.plot_overlapping_bars(
    numbers_per_type=list(edit_distances_per_run.values()),
    bar_titles=['\\textbf{' + x + '}' for x in edit_distances_per_run],
    colors=['seagreen', 'peru', 'darkkhaki', 'black', 'gold'],
    show_legend=True,
    save_path=os.path.join(file_service.get_experiments_path(), 'post-ocr'),
    filename=f'histogram-{language}',
    tight_layout=True,
    ylim=30,
    xlim=42,
    # '\\%' spells out the backslash: a bare '\%' is an invalid escape sequence that only
    # works by accident and triggers a warning on newer Python versions.
    ylabel='\\textbf{\\% of total}',
    xlabel='\\textbf{edit distance}')

<matplotlib.axes._subplots.AxesSubplot at 0x1c78d428cc8>

<Figure size 1440x720 with 0 Axes>