In [15]:
import numpy as np
from collections import Counter
from sklearn.utils.linear_assignment_ import linear_assignment
import sys
import time
import os
from os.path import join, isdir, basename
from tqdm import tqdm
import json
import collections
import tensorflow as tf
import operator
import copy

In [8]:
def flatten(l):
  return [item for sublist in l for item in sublist]

class DocumentState(object):
    def __init__(self):
        self.doc_key = None
        self.text = []
        self.text_speakers = []
        self.speakers = []
        self.sentences = []
        self.clusters = collections.defaultdict(list)
        self.stacks = collections.defaultdict(list)

    def assert_empty(self):
        assert self.doc_key is None
        assert len(self.text) == 0
        assert len(self.text_speakers) == 0
        assert len(self.sentences) == 0
        assert len(self.speakers) == 0
        assert len(self.clusters) == 0
        assert len(self.stacks) == 0

    def assert_finalizable(self):
        assert self.doc_key is not None
        assert len(self.text) == 0
        assert len(self.text_speakers) == 0
        assert len(self.sentences) > 0
        assert len(self.speakers) > 0
        assert all(len(s) == 0 for s in self.stacks.values())

    def finalize(self):
        merged_clusters = []
        for c1 in self.clusters.values():
            existing = None
            for m in c1:
                for c2 in merged_clusters:
                    if m in c2:
                        existing = c2
                        break
                if existing is not None:
                    break
            if existing is not None:
                print("Merging clusters (shouldn't happen very often.)")
                existing.update(c1)
            else:
                merged_clusters.append(set(c1))
        merged_clusters = [list(c) for c in merged_clusters]
        all_mentions = flatten(merged_clusters)
        # print len(all_mentions), len(set(all_mentions))

        if len(all_mentions) != len(set(all_mentions)):
            c = Counter(all_mentions)
            for x in c:
                if c[x] > 1:
                    z = x
                    break
            for i in range(len(all_mentions)):
                if all_mentions[i] == z:
                    all_mentions.remove(all_mentions[i])
                    break
        assert len(all_mentions) == len(set(all_mentions))

        return {
            "doc_key": self.doc_key,
            "sentences": self.sentences,
            "speakers": self.speakers,
            "clusters": merged_clusters
        }

def normalize_word(word):
    if word == "/." or word == "/?":
        return word[1:]
    else:
        return word

def conll2modeldata(data):
    document_state = DocumentState()
    document_state.assert_empty()
    document_state.doc_key = "{}_{}".format(data['doc_id'][0], data['part_id'][0])
    for i in range(len(data['doc_id'])):
        word = normalize_word(data['word'][i])
        coref = data['coreference'][i]
        speaker = data['speaker'][i]
        word_index = i + 1
        document_state.text.append(word)
        document_state.text_speakers.append(speaker)

        if coref != "-":
            for segment in coref.split("|"):
                if segment[0] == "(":
                    if segment[-1] == ")":
                        cluster_id = int(segment[1:-1]) # Need Int
                        document_state.clusters[cluster_id].append((word_index, word_index))
                    else:
                        cluster_id = int(segment[1:])
                        document_state.stacks[cluster_id].append(word_index)
                else:
                    cluster_id = int(segment[:-1])
                    start = document_state.stacks[cluster_id].pop()
                    document_state.clusters[cluster_id].append((start, word_index))
        else:                 
            if (data['part_of_speech'][i] == 'End_of_sentence'):
                document_state.sentences.append(tuple(document_state.text))
                del document_state.text[:]
                document_state.speakers.append(tuple(document_state.text_speakers))
                del document_state.text_speakers[:]
                continue
            else:
                continue
    
    document_state.assert_finalizable()
    return document_state.finalize()

In [9]:
def conll2dict(iter_id, conll, agent, mode, epoch_done=False):
    data = {'doc_id': [],
            'part_id': [],
            'word_number': [],
            'word': [],
            'part_of_speech': [],
            'parse_bit': [],
            'lemma': [],
            'sense': [],
            'speaker': [],
            'entiti': [],
            'predict': [],
            'coreference': [],
            'iter_id': iter_id,
            'id': agent,
            'epoch_done': epoch_done,
            'mode': mode}

    with open(conll, 'r') as f:
        for line in f:
            row = line.split('\t')
            if row[0].startswith('#'):
                continue
            elif row[0] == '\n':
                data['doc_id'].append('bc')
                data['part_id'].append('0')
                data['word_number'].append('0')
                data['word'].append('SeNt')
                data['part_of_speech'].append('End_of_sentence')
                data['parse_bit'].append('-')
                data['lemma'].append('-')
                data['sense'].append('-')
                data['speaker'].append('-')
                data['entiti'].append('-')
                data['predict'].append('-')
                data['coreference'].append('-')
            else:
                assert len(row) >= 12
                data['doc_id'].append(row[0])
                data['part_id'].append(row[1])
                data['word_number'].append(row[2])
                data['word'].append(row[3])
                data['part_of_speech'].append(row[4])
                data['parse_bit'].append(row[5])
                data['lemma'].append(row[6])
                data['sense'].append(row[7])
                data['speaker'].append(row[8])
                data['entiti'].append(row[9])
                data['predict'].append(row[10])
                data['coreference'].append(row[11][0:-1])
        f.close()
    return data

def dict2conll(data, predict):
    #
    with open(predict, 'w') as CoNLL:
        for i in range(len(data['doc_id'])):
            if i == 0:
                CoNLL.write('#begin document ({}); part {}\n'.format(data['doc_id'][i], data["part_id"][i]))
                CoNLL.write(u'{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(data['doc_id'][i],
                                                    data["part_id"][i],
                                                    data["word_number"][i],
                                                    data["word"][i],
                                                    data["part_of_speech"][i],
                                                    data["parse_bit"][i],
                                                    data["lemma"][i],
                                                    data["sense"][i],
                                                    data["speaker"][i],
                                                    data["entiti"][i],
                                                    data["predict"][i],
                                                    data["coreference"][i]))
            elif i == len(data['doc_id'])-1 and data['part_of_speech'][i] == 'End_of_sentence':
                CoNLL.write('#end document\n')
            elif data['part_of_speech'][i] == 'End_of_sentence':
                continue
            else:
                if data['doc_id'][i] == data['doc_id'][i+1]:
                    CoNLL.write(u'{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(data['doc_id'][i],
                                                        data["part_id"][i],
                                                        data["word_number"][i],
                                                        data["word"][i],
                                                        data["part_of_speech"][i],
                                                        data["parse_bit"][i],
                                                        data["lemma"][i],
                                                        data["sense"][i],
                                                        data["speaker"][i],
                                                        data["entiti"][i],
                                                        data["predict"][i],
                                                        data["coreference"][i]))
                elif data['part_of_speech'][i] == 'End_of_sentence':
                    continue
                else:
                    CoNLL.write(u'{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(data['doc_id'][i],
                                                    data["part_id"][i],
                                                    data["word_number"][i],
                                                    data["word"][i],
                                                    data["part_of_speech"][i],
                                                    data["parse_bit"][i],
                                                    data["lemma"][i],
                                                    data["sense"][i],
                                                    data["speaker"][i],
                                                    data["entiti"][i],
                                                    data["predict"][i],
                                                    data["coreference"][i]))
                    CoNLL.write('\n')
        CoNLL.close()
    return None

In [10]:
conll = '/home/petrov/coreference_kpi/coreference/src/parlai/data/coreference/russian/train/0.russian.v4_conll'
data = conll2dict(0, conll, 'agent', 'test', epoch_done=False)

In [None]:
print(data['coreference'])

In [11]:
predict = './test.conll'
dict2conll(data, predict)

In [12]:
a = conll2modeldata(data)

In [13]:
print(a)

{'doc_key': 'bc1_0', 'sentences': [('Во', 'время', 'своих', 'прогулок', 'в', 'окрестностях', 'Симеиза', 'я', 'обратил', 'внимание', 'на', 'одинокую', 'дачу', ',', 'стоявшую', 'на', 'крутом', 'склоне', 'горы', '.', 'SeNt'), ('К', 'этой', 'даче', 'не', 'было', 'проведено', 'даже', 'дороги', '.', 'SeNt'), ('Кругом', 'она', 'была', 'обнесена', 'высоким', 'забором', ',', 'с', 'единственной', 'низкой', 'калиткой', ',', 'которая', 'всегда', 'была', 'плотно', 'прикрыта', '.', 'SeNt'), ('И', 'ни', 'куста', 'зелени', ',', 'ни', 'дерева', 'не', 'виднелось', 'над', 'забором', '.', 'SeNt'), ('Кругом', 'дачи', '-', 'голые', 'уступы', 'желтоватых', 'скал', ';', 'меж', 'ними', 'кое', '-', 'где', 'росли', 'чахлые', 'можжевельники', 'и', 'низкорослые', ',', 'кривые', 'горные', 'сосны', '.', 'SeNt'), ('"', 'Что', 'за', 'фантазия', 'пришла', 'кому', '-', 'то', 'в', 'голову', 'поселиться', 'на', 'этом', 'диком', ',', 'голом', 'утесе', '?', 'Да', 'и', 'живет', 'ли', 'там', 'кто', '-', 'нибудь', '?', '"', '-

In [14]:
from os.path import join
from os import listdir
from tqdm import tqdm

path = '/home/petrov/coreference_kpi/coreference/src/parlai/data/coreference/russian/train/'
conll_list = listdir(path)
for x in tqdm(conll_list):
    data = conll2dict(0, join(path,x), 'agent', 'test', epoch_done=False)
    a = conll2modeldata(data)

 57%|█████▋    | 57/100 [00:00<00:00, 292.27it/s]

Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very 

100%|██████████| 100/100 [00:00<00:00, 315.50it/s]

Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)
Merging clusters (shouldn't happen very often.)





In [16]:
path1 = '../../../src/parlai/data/coreference/russian/vocab/char_vocab.russian.txt'
path2 = '/home/petrov/coreference_kpi/code/e2e-coref/vocab/char_vocab.russian.txt'

In [49]:
with open(path1,'r') as new:
    x = new.readlines()
with open(path2,'r') as old:
    y = old.readlines()

print(len(sorted(set(x))), len(sorted(set(y))))

199 197


In [56]:
z = list(set(x)&set(y))
print(len(z))

197


In [44]:
if ' \n' in set(x):
    print(True)
else:
    print(False)

True


In [55]:
print(set(z))

{'\n', ' \n'}


In [48]:
s = [' \n','\n']
with open(path2,'r+') as file:
    for c in z:
        file.write(c)