In [1]:
import os
import sys
import csv
import math
from collections import Counter, OrderedDict
from enum import Enum
from typing import Dict, List, Tuple, Callable
import itertools
from termcolor import colored
from IPython.display import Markdown, display
import numpy as np


import matplotlib.pyplot as plt
plt.rcParams["axes.grid"] = False

sys.path.append('..')

from enums.language import Language

In [2]:
def printmd(string):
    """Render `string` as Markdown in the notebook output area.

    Args:
        string: Markdown source text to display.
    """
    rendered = Markdown(string)
    display(rendered)

In [3]:
# Emulate the command line the project expects, so that code reading
# `sys.argv` sees the same options inside the notebook.
#
# NOTE(review): the first element sits in the `sys.argv[0]` slot. Parsers that
# read `sys.argv[1:]` (argparse's default) treat that slot as the program name,
# so "--device cuda" is presumably NOT parsed as an option -- confirm against
# the project's argument parsing before relying on the device setting.
sys.argv = [
    "--device cuda",
    # Built with os.path.join so the relative path also resolves on non-Windows
    # hosts (a literal "..\\data" is a single file name on POSIX).
    "--data-folder", os.path.join("..", "data"),
    "--seed", "13",
    "--configuration", "rnn-simple",
    "--challenge", "named-entity-recognition",
    "--entity-tag-types", "literal-coarse"]

In [4]:
# Configure container:
# The import is placed here rather than in the first cell; it runs after
# `sys.path.append('..')` and after `sys.argv` has been overridden.
# NOTE(review): presumably the container reads its options from `sys.argv`
# when it is imported or instantiated -- confirm before moving this import.
from dependency_injection.ioc_container import IocContainer

# Single container instance from which the notebook resolves its services.
container = IocContainer()

In [5]:
# Resolve the services used by the rest of the notebook:
# `plot_service` creates and saves the histograms, `file_service` provides
# the data and experiments paths.
plot_service = container.plot_service()
file_service = container.file_service()


In [6]:
# Languages covered by the analysis. Their order is also the order of the
# entries in `counters_per_language`, from which the plot labels are built.
languages = [Language.English, Language.French, Language.German]

# One independent Counter per language, keyed by the Language member; it is
# filled with "decade -> number of tokens" by the counting loop.
counters_per_language = dict(
    (language_key, Counter()) for language_key in languages)


# Dataset release to analyse: only file names containing `v{version}` are read.
version = '1.2'


def _iter_dataset_rows(language):
    """Yield every row of every selected data file of `language` as a dict.

    Files are taken from `file_service.get_data_path(language=language.value)`.
    A file is selected when its name contains `v{version}` and contains neither
    'Copy' nor 'old'; for English, names containing 'train' are skipped too.
    Files are parsed as tab-separated text with quoting disabled, so quote
    characters inside a field are kept literally.

    Args:
        language: member of `Language` whose data folder is read.

    Yields:
        One dict per data row, keyed by the header of the file it came from.
    """
    language_path = file_service.get_data_path(language=language.value)
    filenames = [
        x for x in os.listdir(language_path)
        if f'v{version}' in x and 'Copy' not in x and 'old' not in x]
    if language == Language.English:
        filenames = [x for x in filenames if 'train' not in x]

    for filename in filenames:
        with open(os.path.join(language_path, filename), 'r', encoding='utf-8') as csv_file:
            yield from csv.DictReader(csv_file, dialect=csv.excel_tab, quoting=csv.QUOTE_NONE)


# tag name (the 'NE-COARSE-LIT' value without its 'B-' prefix)
#   -> Counter(decade -> number of mentions)
tags_per_decade = { }

# Each file is read once; token counts and tag counts are collected together.
for language in languages:
    tokens_counter = counters_per_language[language]
    current_tokens_count = 0
    current_decade = None  # stays None until the first '# date' row is seen

    for row in _iter_dataset_rows(language):
        token = row['TOKEN']
        if token.startswith('# date'):
            # A new document starts: book the tokens of the previous one.
            if current_decade is not None:
                tokens_counter[current_decade] += current_tokens_count

            # Characters 9..12 of the row are parsed as the year (presumably
            # the row looks like '# date = YYYY-...' -- confirm against the
            # data) and rounded down to the decade.
            current_decade = math.floor(float(token[9:13]) / 10) * 10
            current_tokens_count = 0
        elif not token.startswith(('#', ' ')):
            # Rows starting with '#' (other metadata) or ' ' are not tokens.
            current_tokens_count += 1

            # Only tags starting with 'B-' are counted, i.e. once per mention
            # (presumably IOB-style begin markers).
            coarse_entity_str = row['NE-COARSE-LIT']
            if coarse_entity_str.startswith('B-'):
                coarse_entity = coarse_entity_str[2:]
                if coarse_entity not in tags_per_decade:
                    tags_per_decade[coarse_entity] = Counter()

                # As before, a mention seen before any '# date' row is booked
                # under the key None.
                tags_per_decade[coarse_entity][current_decade] += 1

    # Fix: the last document of a language is not followed by another
    # '# date' row, so its tokens were previously never added to the counter.
    if current_decade is not None:
        tokens_counter[current_decade] += current_tokens_count

# Sorted list of every decade that occurs in at least one language.
unique_decades = sorted(
    {decade for counter in counters_per_language.values() for decade in counter})

In [10]:
# NOTE(review): `plot` is not referenced again in this cell and the saved cell
# output is an empty figure ("0 Axes"), so this call may be unnecessary --
# confirm against plot_service before removing it.
plot = plot_service.create_plot()

# Legend labels: the capitalised `value` of each Language key, in dict order.
language_labels = [x.value.capitalize() for x in counters_per_language.keys()]

# Histogram 1: number of tokens per decade, one series per language.
# Three colors are given for the three languages (presumably matched to the
# counters by position -- confirm).
plot_service.plot_counters_histogram(
    counter_labels=language_labels,
    counters=counters_per_language.values(),
    counter_colors=['firebrick', 'royalblue', 'black'],
    save_path=file_service.get_experiments_path(),
    filename='ner_tokens_per_language_and_decade',
    title='Tokens per decade and language',
    ylabel='amount of tokens',
    xlabel='decade')

# Despite its name, `decade_labels` holds the capitalised entity tag names
# (the keys of `tags_per_decade`), not decades.
decade_labels = [str(x).capitalize() for x in tags_per_decade.keys()]

# Histogram 2: number of entity mentions per decade, one series per tag.
# NOTE(review): only six colors are listed; verify that the data never
# contains more than six distinct tag types.
plot_service.plot_counters_histogram(
    counter_labels=decade_labels,
    counters=tags_per_decade.values(),
    counter_colors=['darkred', 'olive', 'orange', 'forestgreen', 'lightseagreen', 'lightcoral'],
    save_path=file_service.get_experiments_path(),
    filename='ner_mentions_per_tag_and_decade',
    title='Mentions per tag and decade',
    ylabel='amount of mentions',
    xlabel='decade')


<Figure size 1440x720 with 0 Axes>