In [1]:
import os
import sys
from collections import Counter

sys.path.append('..')

In [2]:
sys.argv = [
"--device cuda",
"--seed", "13",
"--configuration", "kbert",
"--language", "german",
"--challenge", "semantic-change",
"--evaluation-type", "cosine-distance", "euclidean-distance",
"--pretrained-weights", "bert-base-multilingual-cased",
"--pretrained-max-length", "512",
"--experiment-types", "word-similarity"]


In [3]:
# Configure container:
from dependency_injection.ioc_container import IocContainer

container = IocContainer()

In [4]:
file_service = container.file_service()
arguments_service = container.arguments_service()
metrics_service = container.metrics_service()

In [10]:
def read_language_text(corpus: int):
    data_path = os.path.join('..', file_service.get_challenge_path(), 'eval', str(arguments_service.language), f'corpus{corpus}')
    for dir_name in os.listdir(data_path):
        dir_path = os.path.join(data_path, dir_name)
        if not os.path.isdir(dir_path):
            continue

        for file_name in os.listdir(dir_path):
            if not file_name.endswith('.txt'):
                continue

            with open(os.path.join(dir_path, file_name), 'r', encoding='utf-8') as language_file:
                language_text = language_file.read().replace('\n', ' \n ').lower().split(' ')
                language_text = list(filter(None, language_text))
                return language_text
            
def get_target_words():
    data_path = os.path.join('..', file_service.get_challenge_path(), 'eval', str(arguments_service.language), 'targets.txt')
    with open(data_path, 'r', encoding='utf-8') as targets_file:
        targets = [x.replace('\n', '').lower() for x in targets_file.readlines() if x != '']
        return targets
    
def indices_of_targets(list_to_check, targets):
    result = [(x,i) for i, x in enumerate(list_to_check) if x in targets]
    return result

def get_word_contexts(target_words, all_tokens):
    close_word_contexts = {}
    far_word_contexts = {}
    all_tokens_length = len(all_tokens)
    min_threshold = 0
    window_sizes = list(range(1, 11))
    window_size_threshold = 5
    
    target_indices = indices_of_targets(all_tokens, target_words)
    for target_word in target_words:
        current_word_close_context = Counter()
        current_word_far_context = Counter()
        indices = [x[1] for x in target_indices if x[0] == target_word]
        for index in indices:
            back_window_depleted = False
            forward_window_depleted = False
            for window_size in window_sizes:
                if index - window_size >= 0 and not back_window_depleted:
                    if all_tokens[index-window_size] == '\n':
                        back_window_depleted = True
                        continue
                        
                    if window_size <= window_size_threshold:
                        current_word_close_context.update({all_tokens[index-window_size] : 1})
                    else:
                        current_word_far_context.update({all_tokens[index-window_size] : 1})
                        
                if index + window_size < all_tokens_length and not forward_window_depleted:
                    if all_tokens[index+window_size] == '\n':
                        forward_window_depleted = True
                        continue
                        
                    if window_size <= window_size_threshold:
                        current_word_close_context.update({all_tokens[index+window_size] : 1})
                    else:
                        current_word_far_context.update({all_tokens[index+window_size] : 1})
        
        current_word_close_context = Counter({x : current_word_close_context[x] for x in current_word_close_context if current_word_close_context[x] >= min_threshold})
        current_word_far_context = Counter({x : current_word_far_context[x] for x in current_word_far_context if current_word_far_context[x] >= min_threshold})
        close_word_contexts[target_word] = current_word_close_context
        far_word_contexts[target_word] = current_word_far_context
    
    return close_word_contexts, far_word_contexts

In [11]:
language_tokens_1 = read_language_text(corpus=1)
language_tokens_2 = read_language_text(corpus=2)

target_words = get_target_words()

close_word_contexts_1, far_word_contexts_1 = get_word_contexts(target_words, language_tokens_1)
close_word_contexts_2, far_word_contexts_2 = get_word_contexts(target_words, language_tokens_2)

In [13]:
close_jaccard_similarities = {}
far_jaccard_similarities = {}

close_cosine_distances = {}
far_cosine_distances = {}

target_words.sort(key=lambda x: x.upper())
for target_word in target_words:
    close_context_words_1 = list(close_word_contexts_1[target_word].elements())
    close_context_words_1.sort(key=lambda x: x.upper())
    
    close_context_words_2 = list(close_word_contexts_2[target_word].elements())
    close_context_words_2.sort(key=lambda x: x.upper())
    
    close_jaccard_similarity = metrics_service.calculate_jaccard_similarity(close_context_words_1, close_context_words_2)
    if not close_jaccard_similarity:
        raise Exception(f'nan close similarity for {target_word}')
    close_jaccard_similarities[target_word] = close_jaccard_similarity
    
    close_context_words = list(set(list(close_word_contexts_1[target_word].keys()) + list(close_word_contexts_2[target_word].keys())))
    new_close_context_words_1 = [close_word_contexts_1[target_word][key] for key in close_context_words]
    new_close_context_words_2 = [close_word_contexts_2[target_word][key] for key in close_context_words]
    
    close_cosine_distance = metrics_service.calculate_cosine_distance(new_close_context_words_1, new_close_context_words_2)
    if not close_cosine_distance:
        raise Exception(f'nan close distance for {target_word}')
    close_cosine_distances[target_word] = close_cosine_distance
    
    far_context_words_1 = list(far_word_contexts_1[target_word].elements())
    far_context_words_1.sort()
    
    far_context_words_2 = list(far_word_contexts_2[target_word].elements())
    far_context_words_2.sort()
    
    far_jaccard_similarity = metrics_service.calculate_jaccard_similarity(far_context_words_1, far_context_words_2)
    far_jaccard_similarities[target_word] = far_jaccard_similarity
    
    far_context_words = list(set(list(far_word_contexts_1[target_word].keys()) + list(far_word_contexts_2[target_word].keys())))
    new_far_context_words_1 = [far_word_contexts_1[target_word][key] for key in far_context_words]
    new_far_context_words_2 = [far_word_contexts_2[target_word][key] for key in far_context_words]
    
    far_cosine_distance = metrics_service.calculate_cosine_distance(new_far_context_words_1, new_far_context_words_2)
    far_cosine_distances[target_word] = far_cosine_distance

   
print('\n'.join([str(x) for x in list(close_jaccard_similarities.values())]))
print('---')
print('\n'.join([str(x) for x in list(far_jaccard_similarities.values())]))
print('---')
print('\n'.join([str(x) for x in list(close_cosine_distances.values())]))
print('---')
print('\n'.join([str(x) for x in list(far_cosine_distances.values())]))
print('---')

# print(far_jaccard_similarities)
# print(close_cosine_distances)
# print(far_cosine_distances)

0.04924660051451672
0.05355776587605203
0.07653061224489796
0.10153482880755609
0.10149253731343283
0.11389236545682102
0.08180839612486544
0.04941634241245136
0.10682110682110682
0.12409420289855072
0.12167606768734891
0.07036374478234943
0.11987860394537178
0.09480122324159021
0.08578856152512998
0.16620206434114368
0.06388459556929418
0.1357142857142857
0.11180937418172296
0.11425462459194777
0.10512483574244415
0.06690528634361234
0.10901001112347053
0.10444177671068428
0.12546125461254612
0.1111111111111111
0.1506172839506173
0.09345794392523364
0.11693548387096774
0.165303086649312
0.09930715935334873
0.11976047904191617
0.07334766254701773
0.0836283185840708
0.13396481732070364
0.15020303483650352
0.12107034860500941
0.1056782334384858
0.10320284697508897
0.11696658097686376
0.11425339366515837
0.12293388429752067
0.12055507372072853
0.164159365184328
0.07205067300079177
0.11240875912408758
0.10609697551608258
0.1278127812781278
---
0.04863344051446945
0.06506238859180036
0.0957