In [1]:
import wandb
import matplotlib.pyplot as plt
import scipy
import os
from collections import Counter
import numpy as np
import csv

plt.rcParams["axes.grid"] = False

import sys
sys.path.append('..')

In [2]:
language = 'english'

# Command-line style arguments for the project code configured in the cells
# below (presumably parsed when the IoC container is created - TODO confirm).
#
# NOTE(review): the first element occupies the argv[0] slot, which standard
# parsers such as argparse treat as the program name and skip. It is also a
# single "--device cuda" string rather than a flag/value pair. It is left
# unchanged on purpose: splitting it would shift every following argument.
# Confirm whether the device option is actually meant to take effect.
sys.argv = [
    "--device cuda",
    # Built with os.path.join so the relative data folder also resolves on
    # non-Windows systems (the previous literal was "..\\data").
    "--data-folder", os.path.join("..", "data"),
    "--seed", "13",
    "--configuration", "char-to-char-encoder-decoder",
    "--language", language,
    "--challenge", "post-ocr-correction",
]

In [3]:
# Configure container:
from dependency_injection.ioc_container import IocContainer

# Create the dependency-injection container and resolve, in order, the
# services used by the rest of this notebook.
container = IocContainer()

# The vocabulary and cache services are not resolved here (they were disabled).
data_service, plot_service, metrics_service, process_service = (
    container.data_service(),
    container.plot_service(),
    container.metrics_service(),
    container.process_service(),
)

In [4]:
# vocabulary_data = cache_service.get_item_from_cache(
#     item_key='char-vocabulary')

# vocabulary_service.initialize_vocabulary_data(vocabulary_data)

In [5]:
# Experiment-tracking identifiers - presumably the Weights & Biases entity and
# project, since wandb is imported in the first cell (TODO confirm).
entity, project = 'eval-historical-texts', 'post-ocr-correction'

In [6]:
# Maps a short, human-readable configuration name to the unique string that
# identifies that configuration inside checkpoint/output file names.
# Insertion order is the order in which the runs are evaluated and printed.
unique_runs = dict([
    ('none', 'h512-e128-l2-bi-d0.50.0001'),
    ('fast-text', 'ft-h512-e128-l2-bi-d0.50.0001'),
    ('bert', 'pretr-h512-e128-l2-bi-d0.50.0001'),
    ('both', 'pretr-ft-h512-e128-l2-bi-d0.50.0001'),
    ('bert-finetune', 'pretr-h512-e128-l2-bi-d0.5-tune0.0001'),
    ('both-finetune', 'pretr-ft-h512-e128-l2-bi-d0.5-tune0.0001'),
    ('bert-finetune-ac', 'pretr-h512-e128-l2-bi-d0.5-tune-ac0.0001'),
    ('both-finetune-ac', 'pretr-ft-h512-e128-l2-bi-d0.5-tune-ac0.0001'),
])

In [7]:
def get_run_info(language: str, checkpoint_name: str):
    """Load the stored output of one evaluation run and compute its metrics.

    The run is identified by ``output-BEST_{language}--{checkpoint_name}``
    inside ``../results/post-ocr-correction/char-to-char-encoder-decoder/
    {language}/output``. The ``.csv`` file with that name is read row by row
    (columns ``Input``, ``Prediction`` and ``Target``) until the first row
    whose ``Target`` is empty; everything after that row is ignored.

    Args:
        language: Language folder / file-name component of the run.
        checkpoint_name: Checkpoint identifier used in the file name.

    Returns:
        ``None`` when the CSV file does not exist. Otherwise the object
        returned by ``data_service.load_python_obj`` for the same file stem
        (a new dict is used when that call yields ``None`` - presumably the
        case when no pickle exists; confirm against the data service),
        updated with these keys:

        * ``'improvement_percentage'``: relative reduction of the summed
          Levenshtein distance (prediction vs. input, both measured against
          the target) in percent, rounded to 3 decimals; ``nan`` when the
          summed input distance is zero.
        * ``'jaccard-similarities'``: per-row Jaccard similarity between
          target and prediction.
        * ``'edit-distances'``: per-row Levenshtein distance between target
          and prediction.
        * ``'original-edit-distances'``: per-row Levenshtein distance between
          target and input.
        * ``'original-jaccard-scores'``: per-row Jaccard similarity between
          target and input.
        * ``'original-jaccard-score'``: mean of the previous list rounded to
          3 decimals; ``nan`` when no rows were read.
    """
    output_path = os.path.join(
        '..', 'results', 'post-ocr-correction', 'char-to-char-encoder-decoder', language, 'output')
    file_stem = f'output-BEST_{language}--{checkpoint_name}'
    csv_path = os.path.join(output_path, f'{file_stem}.csv')

    # Nothing to evaluate without the CSV, so check before loading anything.
    if not os.path.exists(csv_path):
        return None

    run_info = data_service.load_python_obj(
        output_path, file_stem, print_on_success=False, print_on_error=False)
    if run_info is None:
        # Errors are silenced above, so guard against a missing object instead
        # of failing later on item assignment.
        run_info = {}

    input_characters = []
    predicted_characters = []
    target_characters = []
    # newline='' is what the csv module requires so that line breaks embedded
    # in quoted fields are parsed correctly.
    with open(csv_path, 'r', encoding='utf-8', newline='') as csv_file:
        for row in csv.DictReader(csv_file):
            target_data = row['Target']
            # The first row without a target ends the data section (the rest
            # of the file, e.g. a summary footer, is not sample data).
            if not target_data:
                break

            input_characters.append(row['Input'])
            predicted_characters.append(row['Prediction'])
            target_characters.append(target_data)

    jaccard_scores = [
        metrics_service.calculate_jaccard_similarity(target, original)
        for target, original in zip(target_characters, input_characters)]

    pr_jaccard_scores = [
        metrics_service.calculate_jaccard_similarity(target, predicted)
        for target, predicted in zip(target_characters, predicted_characters)]

    original_levenshtein_distances = [
        metrics_service.calculate_levenshtein_distance(target, original)
        for target, original in zip(target_characters, input_characters)]

    levenshtein_distances = [
        metrics_service.calculate_levenshtein_distance(target, predicted)
        for target, predicted in zip(target_characters, predicted_characters)]

    original_distance_total = sum(original_levenshtein_distances)
    if original_distance_total:
        improvement_percentage = round(
            (1 - (float(sum(levenshtein_distances)) / original_distance_total)) * 100, 3)
    else:
        # The relative improvement is undefined when the input already matches
        # the target everywhere (or when there are no rows at all).
        improvement_percentage = float('nan')

    run_info['improvement_percentage'] = improvement_percentage
    run_info['jaccard-similarities'] = pr_jaccard_scores
    run_info["edit-distances"] = levenshtein_distances
    run_info["original-edit-distances"] = original_levenshtein_distances

    run_info['original-jaccard-scores'] = jaccard_scores
    run_info['original-jaccard-score'] = (
        round(np.mean(jaccard_scores), 3) if jaccard_scores else float('nan'))

    return run_info

In [8]:
def plot_histogram(xs, ys, color='r'):
    """Draw a half-transparent filled area for the given points.

    Points are taken pairwise from ``xs`` and ``ys``; only those whose y value
    is greater than -1 are drawn.

    Args:
        xs: x coordinates.
        ys: y coordinates, paired with ``xs`` by position.
        color: Fill colour passed on to ``plt.fill_between``.
    """
    kept_points = [(x, y) for x, y in zip(xs, ys) if y > -1]
    kept_xs = [x for x, _ in kept_points]
    kept_ys = [y for _, y in kept_points]

    plt.fill_between(kept_xs, kept_ys, interpolate=True, color=color, alpha=0.5)

In [9]:
# Edit distances of the uncorrected input (taken from the last run that was
# found) and of the predictions of every run configuration.
original_edit_distances = []
edit_distances_per_run = {}

# Baseline statistics reported after the loop. They are initialised here so
# the summary below does not raise a NameError when no run output is found.
original_jaccard = None
original_edit_mean = None

for run_name, run_unique_str in unique_runs.items():
    for seed in [13, 7, 42]:
        checkpoint_name = f'{run_unique_str}-seed{seed}'
        run_info = get_run_info(language, checkpoint_name)
        if run_info is None:
            # No stored output for this configuration/seed combination.
            continue

        original_jaccard = run_info['original-jaccard-score']
        original_jaccards = sum(run_info['original-jaccard-scores'])
        jaccards = sum(run_info['jaccard-similarities'])
        # Relative change of the summed Jaccard similarity of the predictions
        # compared with the uncorrected input, in percent.
        if original_jaccards:
            jaccard_improvement_percentage = -round(((1 - (float(jaccards) / original_jaccards)) * 100), 3)
        else:
            # Undefined when the baseline similarity sums to zero.
            jaccard_improvement_percentage = float('nan')

        original_edit_mean = round(np.mean(run_info["original-edit-distances"]), 3)

        improvement_percentage = run_info['improvement_percentage']
        jaccard_similarity_mean = round(np.mean(run_info['jaccard-similarities']), 3)
        edit_dist = round(np.mean(run_info["edit-distances"]), 3)

        # The last column repeats the numbers joined by '&' so they can be
        # pasted into a table row. The second value is the mean edit distance
        # (it used to be labelled "sum").
        table_row = f'{edit_dist} & {improvement_percentage} & {jaccard_similarity_mean} & {jaccard_improvement_percentage}'
        print('{:16s}, seed {:2d} | {:8.3f}; edit mean: {:7.3f}; jacc mean: {:6.3f}; jacc percent: {:10.3f} || {:100s}'.format(
            run_name, seed, improvement_percentage, edit_dist, jaccard_similarity_mean,
            jaccard_improvement_percentage, table_row))

        original_edit_distances = run_info["original-edit-distances"]
        # NOTE: keyed by run name only, so when several seeds of the same
        # configuration exist, the last one found replaces the earlier ones.
        edit_distances_per_run[run_name] = list(run_info["edit-distances"])

print('-----------')
print(f"original jaccard similarity mean: {original_jaccard}")
print(f"original edit mean: {original_edit_mean}")


edit_distances_per_run['original'] = list(original_edit_distances)

none            , seed 13 |   -5.968; sum:   9.958; jacc mean:  0.821; jacc percent:     -0.450 || 9.958 & -5.968 & 0.821 & -0.45                                                                      
fast-text       , seed 13 |   -5.735; sum:   9.936; jacc mean:  0.824; jacc percent:     -0.035 || 9.936 & -5.735 & 0.824 & -0.035                                                                     
bert            , seed 13 |   -8.840; sum:  10.228; jacc mean:  0.822; jacc percent:     -0.364 || 10.228 & -8.84 & 0.822 & -0.364                                                                     
bert            , seed  7 |  -61.873; sum:  15.211; jacc mean:  0.732; jacc percent:    -11.253 || 15.211 & -61.873 & 0.732 & -11.253                                                                  
both            , seed 13 |   -7.658; sum:  10.116; jacc mean:  0.821; jacc percent:     -0.482 || 10.116 & -7.658 & 0.821 & -0.482                                                                    


In [10]:
# alpha_values = [1, 0.6, 0.3, 0.2]

# norm_values = {}

# # bins = np.arange(0, 10000, 2)

# for i, (run_name, edit_distances) in enumerate(edit_distances_per_run.items()):
    
#     rounded_distances = [x for x in edit_distances]
#     counter_values = Counter(rounded_distances)
# #     counter_values = Counter()
# #     current_index = 0
# #     for key, value in temp_values.items():
# #         if key >= bins[current_index]:
# #             current_index += 1
            
# #         counter_values[current_index] += value
        
#     x = [v for v in list(sorted(counter_values.keys())) if v < 20]
#     y = [counter_values[key] for key in x]
#     norm_y = [(float(i)/sum(y)) * 100 for i in y]
#     norm_values[run_name] = norm_y
    
    
#     plt.fill_between(x, norm_y, interpolate=True, alpha=alpha_values[i])
    

# plt.legend(edit_distances_per_run.keys())