In [None]:
import os
import random
from datetime import datetime

import numpy as np

from conllu import parse

from keras.preprocessing.sequence import pad_sequences

from keras import backend as K
from keras.models import Model, load_model
from keras.layers import *
from keras.optimizers import Adam
from keras_contrib.layers import CRF
from keras import callbacks
from keras.layers import Lambda

from sklearn.metrics import f1_score

from matplotlib import pyplot as plt
from IPython.display import clear_output

def check_word(word):
    """Return True when every element of `word` is alphanumeric or in ALLOWED_PUNCT.

    `word` is iterated element by element. For a string those elements are
    characters; for a mapping they are its keys.
    """
    return all(c.isalnum() or c in ALLOWED_PUNCT for c in word)

def fix_word(word):
    """Strip every character that is neither alphanumeric nor in ALLOWED_PUNCT.

    Unless the global UPPERCASE flag is truthy, the kept characters are also
    lowercased. Lowering is done one character at a time, so context-sensitive
    rules such as the Greek final sigma are not applied.
    """
    kept = [c for c in word if c.isalnum() or c in ALLOWED_PUNCT]
    if not UPPERCASE:
        return ''.join(c.lower() for c in kept).lower()
    return ''.join(kept)

def fix_punct(word):
    """Normalise a punctuation token's form into a tag string.

    Processing steps:
    1. Clean the form with ``fix_word``.
    2. Map every TAG_PUNCT mark except ',' to '.'.
    3. Collapse runs of '.' into a single '.'. Multi-character marks such as
       '?!', '!?', '..' and '...' produce these runs.
    4. Remove any ALLOWED_PUNCT mark that is not itself in TAG_PUNCT.
    """
    word = fix_word(word)
    for mark in TAG_PUNCT:
        if mark != ',':
            word = word.replace(mark, '.')
    # Step 2 turns multi-character marks into runs of dots. Without this collapse,
    # the removal loop below deleted '..', so '?!', '!?' and '..' all became ''
    # ("no punctuation"), while '...' only survived as '.' by accident.
    while '..' in word:
        word = word.replace('..', '.')
    for mark in ALLOWED_PUNCT:
        if mark not in TAG_PUNCT:
            word = word.replace(mark, '')
    return word

def stats(data):
    """Print corpus statistics for parsed sentences, ignoring PUNCT tokens.

    ``data`` is an iterable of sentences. Each sentence is an iterable of token
    mappings with 'form', 'lemma' and 'upostag' keys. The "word len" figures
    are measured on lemmas, not on surface forms.
    """
    content = [[tok for tok in sent if tok['upostag'] != 'PUNCT'] for sent in data]
    forms = [tok['form'] for sent in content for tok in sent]
    lemmas = [tok['lemma'] for sent in content for tok in sent]
    all_chars = ''.join(forms)
    lemma_lengths = [len(lemma) for lemma in lemmas]
    sentence_lengths = [len(sent) for sent in content]

    print('Total words:', len(forms))
    print('Total unique words:', len(set(forms)))
    print('Total unique lemma words:', len(set(lemmas)))
    print('Total chars:', len(all_chars))
    print('Total unique chars:', len(set(all_chars)))
    print('Total sentences:', len(content))
    print('Min word len:', min(lemma_lengths))
    print('Max word len:', max(lemma_lengths))
    print('Mean word len:', int(np.mean(lemma_lengths)))
    print('Min sentence len:', min(sentence_lengths))
    print('Max sentence len:', max(sentence_lengths))
    print('Mean sentence len:', int(np.mean(sentence_lengths)))


def create_text_and_tagged(data, dataset_path):
    """Turn raw CoNLL-U text into aligned (words, punctuation-tags) samples.

    Each non-PUNCT word is kept (cleaned by ``fix_word``). Its tag is the
    normalised punctuation token that follows it (``fix_punct``), or '' when no
    such token follows. A sentence is kept only if every word ends up with
    exactly one tag. For example, two punctuation tokens in a row, or a last
    word with no following mark, make the lengths disagree and the sentence is
    dropped.

    Args:
        data: CoNLL-U text as read from a treebank file.
        dataset_path: treebank identifier, prepended as element 0 of every
            feature sentence.

    Returns:
        (feature_sentences, tagged_sentences). Each feature sentence is
        ``[dataset_path, word1, word2, ...]``; each tag list has one tag per word.
    """
    sentences = parse(data)
    feature_sentences = []
    tagged_sentences = []
    skipped = 0
    stats(sentences)

    for sentence in sentences:
        check = True
        feature_sentence = []
        # Tag i+1 belongs to word i; tag_sentence[0] is the leading placeholder.
        tag_sentence = []
        prev_punct = False

        try:
            for token in sentence:
                # NOTE(review): `token` is the parsed token mapping, so check_word
                # iterates its field names (all alphanumeric) and this never rejects
                # anything. check_word(token['form']) was probably intended, but it
                # would also drop sentences containing quotes/brackets. Confirm before
                # changing, because it alters the dataset.
                check = check and check_word(token)

                form = fix_word(token['form'])
                if form == '':
                    continue
                if token['upostag'] == 'PUNCT':
                    if form in ALLOWED_PUNCT:
                        tag_sentence.append(fix_punct(token['form']))
                        prev_punct = True
                else:
                    if not prev_punct:
                        # No mark followed the previous word: tag it ''.
                        tag_sentence.append('')
                    feature_sentence.append(form)
                    prev_punct = False

            if check and len(feature_sentence) == len(tag_sentence[1:]):
                feature_sentences.append([dataset_path] + feature_sentence)
                tagged_sentences.append(tag_sentence[1:])
        except Exception:
            # Best effort: skip sentences that fail to process, but count them
            # instead of silently swallowing everything (including Ctrl-C).
            skipped += 1

    if skipped:
        print('Skipped', skipped, 'sentences with processing errors in', dataset_path)

    return feature_sentences, tagged_sentences

def _read_split(dataset_paths, split, ud_root):
    """Parse ``<ud_root><path><split>.conllu`` for every path and concatenate the samples."""
    all_sentences, all_tags = [], []
    for path in dataset_paths:
        filename = ud_root + path + split + '.conllu'
        with open(filename, encoding='utf-8', newline='') as f:
            raw = f.read()
        print('Parsing ' + filename)
        sentences, tags = create_text_and_tagged(raw, path)
        all_sentences += sentences
        all_tags += tags
    return all_sentences, all_tags


def prepare_dataset(ud_root='UD-2.3/ud-treebanks-v2.3/'):
    """Read the PATHS_TRAIN / PATHS_TEST treebanks and filter them by sentence length.

    Args:
        ud_root: directory prefix prepended to each treebank path.

    Returns:
        (train_sentences, train_tags, test_sentences, test_tags). Element 0 of
        every sentence is its treebank path, and the remaining elements are words.
        Only sentences whose word count lies in
        [MIN_SENT_LENGTH, MAX_SENT_LENGTH] are kept. A falsy MAX_SENT_LENGTH
        (e.g. None) means no upper bound.
    """
    print('Reading files')

    all_train_sentences, all_train_tags = _read_split(PATHS_TRAIN, 'train', ud_root)
    all_test_sentences, all_test_tags = _read_split(PATHS_TEST, 'test', ud_root)

    print('MAX_SENT_LENGTH =', MAX_SENT_LENGTH)
    upper = MAX_SENT_LENGTH if MAX_SENT_LENGTH else float('inf')

    def keep(sentences, tags):
        # Filter sentence/tag PAIRS so the two lists can never drift out of alignment.
        # s[1:] skips the treebank path stored in element 0.
        pairs = [(s, t) for s, t in zip(sentences, tags) if MIN_SENT_LENGTH <= len(s[1:]) <= upper]
        return [s for s, _ in pairs], [t for _, t in pairs]

    train_sentences, train_tags = keep(all_train_sentences, all_train_tags)
    test_sentences, test_tags = keep(all_test_sentences, all_test_tags)

    print(len(train_sentences), 'train samples,', len(test_sentences), 'test samples')

    return train_sentences, train_tags, test_sentences, test_tags


def get_dicts(train_sentences, train_tags, allowed_punct=None):
    """Build the character and tag vocabularies from the training data.

    Args:
        train_sentences: iterable of word lists (treebank path already stripped).
        train_tags: iterable of tag lists, one tag string per word.
        allowed_punct: tags always added to the tag vocabulary, even if unseen.
            Defaults to the global ALLOWED_PUNCT.

    Returns:
        (char2index, tag2index).
        - Characters get indices from 2 upward; 0 is 'CHAR_PAD' and 1 is
          'CHAR_OOV'.
        - Tags get indices from 1 upward; 0 is 'PAD'.
        Indices follow sorted order, so the same data always yields the same
        mapping. Set iteration order of strings depends on the per-process hash
        seed, so the previous mapping changed between kernel restarts.
    """
    print('Getting dicts')

    if allowed_punct is None:
        allowed_punct = ALLOWED_PUNCT

    chars = {c for sentence in train_sentences for word in sentence for c in word}
    char2index = {c: i + 2 for i, c in enumerate(sorted(chars))}
    char2index['CHAR_PAD'] = 0
    char2index['CHAR_OOV'] = 1

    tags = {t for sentence_tags in train_tags for t in sentence_tags}
    tags.update(allowed_punct)
    tag2index = {t: i + 1 for i, t in enumerate(sorted(tags))}
    tag2index['PAD'] = 0  # The special value used to padding

    return char2index, tag2index

def tokenize(sentences, tags, char2index, tag2index):
    """Encode words as padded character-index arrays and tags as tag indices.

    Element 0 of every sentence (the treebank path) is copied through
    unchanged. Every following word becomes an array of char2index values,
    post-padded / truncated to MAX_WORD_LENGTH. Characters missing from
    char2index map to 'CHAR_OOV'.

    Returns:
        (sentences_X, tags_y). Each ``sentences_X`` entry is
        ``[path, word_array, ...]``; each ``tags_y`` entry lists tag indices.
    """
    sentences_X = []
    for sentence in sentences:
        encoded_words = [
            [char2index[ch] if ch in char2index else char2index['CHAR_OOV'] for ch in word]
            for word in sentence[1:]
        ]
        padded_words = list(pad_sequences(encoded_words, maxlen=MAX_WORD_LENGTH, padding='post'))
        sentences_X.append([sentence[0]] + padded_words)

    tags_y = [[tag2index[t] for t in sentence_tags] for sentence_tags in tags]

    return sentences_X, tags_y

def shuffle_and_join(sentences, tags):
    """Build samples by concatenating SENTENCES_JOINING consecutive sentences.

    ``sentences`` are tokenized sentences whose element 0 is the treebank path.
    ``tags`` are the matching tag-index lists. Within each treebank, every run
    of SENTENCES_JOINING consecutive sentences (sliding by one sentence)
    becomes one sample: its words and tags are concatenated and post-padded to
    MAX_SAMPLE_LENGTH. Despite the name, nothing is shuffled.

    Returns:
        (X, y) as numpy arrays built from the padded samples.
    """
    all_sentences2 = []
    all_tags2 = []

    # dict.fromkeys de-duplicates while keeping list order. Iterating a set of
    # strings gives a hash-seed-dependent order, which changed the sample order
    # (and so training) from one kernel run to the next.
    for path in dict.fromkeys(PATHS_TRAIN + PATHS_TEST):
        indexes = [i for i, sentence in enumerate(sentences) if sentence[0] == path]
        path_sentences = [sentences[i][1:] for i in indexes]
        path_tags = [tags[i] for i in indexes]

        n_windows = len(path_sentences) - SENTENCES_JOINING + 1
        if n_windows <= 0:
            # Treebank absent from this split, or too short for a single window.
            continue

        joined_sentences = [
            [word for sent in path_sentences[k:k + SENTENCES_JOINING] for word in sent]
            for k in range(n_windows)
        ]
        joined_tags = [
            [tag for sent_tags in path_tags[k:k + SENTENCES_JOINING] for tag in sent_tags]
            for k in range(n_windows)
        ]

        all_sentences2 += list(pad_sequences(joined_sentences, maxlen=MAX_SAMPLE_LENGTH, padding='post'))
        all_tags2 += list(pad_sequences(joined_tags, maxlen=MAX_SAMPLE_LENGTH, padding='post'))

    return np.array(all_sentences2), np.array(all_tags2)

def to_categorical(sequences, categories):
    """One-hot encode each index of each sequence into a float vector of length `categories`.

    Returns a numpy array built from the nested list of one-hot vectors.
    """
    def one_hot(index):
        vec = np.zeros(categories)
        vec[index] = 1.0
        return vec

    return np.array([[one_hot(item) for item in seq] for seq in sequences])

def ignore_class_accuracy(to_ignore=0):
    """Return a Keras metric: token accuracy over positions whose TRUE class is not `to_ignore`.

    The inner function is named ``ignore_accuracy``; that name is the metric's
    key in the training logs (e.g. ``val_ignore_accuracy``).
    """
    def ignore_accuracy(y_true, y_pred):
        y_true_class = K.argmax(y_true, axis=-1)
        y_pred_class = K.argmax(y_pred, axis=-1)

        # Mask by the gold label so padding positions are excluded. Masking by the
        # prediction instead silently dropped real tokens the model mis-tagged
        # as padding from the denominator, which inflated the accuracy.
        ignore_mask = K.cast(K.not_equal(y_true_class, to_ignore), 'int32')
        matches = K.cast(K.equal(y_true_class, y_pred_class), 'int32') * ignore_mask
        accuracy = K.sum(matches) / K.maximum(K.sum(ignore_mask), 1)
        return accuracy
    return ignore_accuracy

def f1(y_true, y_pred, tag2index, c):
    """F1 score of tag `c` computed from one-hot targets and predicted scores.

    Keras evaluates metrics batch by batch, so this is an F1 estimate per batch
    rather than an exact corpus-level F1.

    Args:
        y_true, y_pred: tensors whose last axis indexes tags.
        tag2index: tag -> class index mapping.
        c: the tag to score (a key of tag2index).
    """
    y_true_class = K.argmax(y_true, axis=-1)
    y_pred_class = K.argmax(y_pred, axis=-1)

    true_mask = K.cast(K.equal(y_true_class, tag2index[c]), K.floatx())
    pred_mask = K.cast(K.equal(y_pred_class, tag2index[c]), K.floatx())

    eq = K.cast(K.equal(y_true_class, y_pred_class), K.floatx())
    neq = 1.0 - eq

    tp = K.sum(eq * true_mask)
    fp = K.sum(neq * pred_mask)  # predicted c, gold label is something else
    fn = K.sum(neq * true_mask)  # gold label is c, predicted something else

    # epsilon keeps the score at 0 instead of NaN when the batch contains no
    # `c` or the model never predicts it.
    eps = K.epsilon()
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * precision * recall / (precision + recall + eps)

def split_class_accuracy(char, tag2index):
    """Return a Keras metric giving the F1 score of one punctuation tag.

    The metric's ``__name__`` is ``split_accuracy_<dot|comma|qmark|emark>``,
    which is the key it gets in the training logs. Any `char` other than
    '.', ',', '?' or '!' raises KeyError.
    """
    suffix = {'.': 'dot', ',': 'comma', '?': 'qmark', '!': 'emark'}[char]

    def metric(y_true, y_pred):
        return f1(y_true, y_pred, tag2index, char)

    metric.__name__ = 'split_accuracy_' + suffix
    metric.__qualname__ = 'split_class_accuracy.<locals>.' + metric.__name__
    return metric


class PlotLosses(callbacks.Callback):
    """Keras callback that redraws training curves at the end of every epoch.

    Per-epoch values from ``logs`` are appended to the module-level
    ``draw_loss`` dict, so call ``reset_draw()`` before training.
    - The left panel shows training and validation loss.
    - The right panel shows the validation F1 of the '.' and ',' tags (when
      they are in TAG_PUNCT).
    - Every redraw is also saved as ``pics/<timestamp>.png``.
    """

    def on_epoch_end(self, epoch, logs=None):
        # None instead of a shared mutable default dict.
        if logs is None:
            logs = {}

        draw_loss['logs'].append(logs)
        draw_loss['x'].append(draw_loss['i'])
        draw_loss['losses'].append(logs.get('loss'))
        draw_loss['val_losses'].append(logs.get('val_loss'))
        draw_loss['acc'].append(logs.get('crf_viterbi_accuracy'))
        draw_loss['val_acc'].append(logs.get('val_crf_viterbi_accuracy'))
        draw_loss['val_ignore'].append(logs.get('val_ignore_accuracy'))
        draw_loss['val_split_dot'].append(logs.get('val_split_accuracy_dot'))
        draw_loss['val_split_comma'].append(logs.get('val_split_accuracy_comma'))
        draw_loss['val_split_qmark'].append(logs.get('val_split_accuracy_qmark'))
        draw_loss['val_split_emark'].append(logs.get('val_split_accuracy_emark'))
        draw_loss['i'] += 1
        f, (ax1, ax2) = plt.subplots(1, 2, sharex=True, figsize=(15, 8))

        clear_output(wait=True)

        ax1.plot(draw_loss['x'], draw_loss['losses'], label="loss")
        ax1.plot(draw_loss['x'], draw_loss['val_losses'], label="val_loss")
        ax1.legend()
        ax1.grid()
        ax1.set_xlabel('Number of epochs')
        ax1.set_ylabel('Loss')
        ax1.set_title('Loss functions')

        # Accuracy values and the '?' / '!' F1 values are recorded in draw_loss
        # above but are deliberately not plotted.
        if '.' in TAG_PUNCT:
            ax2.plot(draw_loss['x'], draw_loss['val_split_dot'], label="Period")
        if ',' in TAG_PUNCT:
            ax2.plot(draw_loss['x'], draw_loss['val_split_comma'], label="Comma")

        ax2.legend()
        ax2.grid()
        ax2.set_xlabel('Number of epochs')
        ax2.set_ylabel('F1 score')
        ax2.set_title('Punctuation restoration F1 scores')

        draw_loss['plt'] = plt

        plt.show()

        # savefig raises if the directory is missing, which used to abort
        # training after the first epoch.
        os.makedirs('pics', exist_ok=True)
        f.savefig('pics/' + str(datetime.timestamp(datetime.now())) + '.png')
        # A new figure is created every epoch; release it so they don't accumulate.
        plt.close(f)

def create_model(char2index, tag2index):
    """Build and compile the character-level BiLSTM + CRF punctuation tagger.

    Each sample is a (MAX_SAMPLE_LENGTH, MAX_WORD_LENGTH) grid of character
    indices. Processing steps:
    1. Characters are embedded.
    2. Each word's characters are encoded by a BiLSTM into one vector.
    3. A second BiLSTM runs over the word vectors.
    4. A per-word Dense projection feeds a CRF over len(tag2index) tags.

    Metrics: CRF accuracy, pad-ignoring accuracy, and one F1 metric per
    TAG_PUNCT mark.
    """
    print('Model creating')

    chars_in = Input(shape=(MAX_SAMPLE_LENGTH, MAX_WORD_LENGTH,))
    char_vectors = Embedding(len(char2index), EMBEDDING_SIZE)(chars_in)

    # Word encoder: BiLSTM over one word's characters, applied to every word.
    word_encoder = Bidirectional(LSTM(HID_SIZE, return_sequences=False))
    # NOTE(review): input_shape is presumably redundant since this layer is
    # called on an existing tensor; kept as-is.
    word_vectors = TimeDistributed(word_encoder, input_shape=(MAX_WORD_LENGTH, EMBEDDING_SIZE))(char_vectors)

    # Context encoder over the sequence of word vectors.
    context = Bidirectional(LSTM(HID_SIZE2, return_sequences=True, dropout=DROPOUT))(word_vectors)
    projected = TimeDistributed(Dense(TD_SIZE, activation='relu'))(context)

    crf = CRF(len(tag2index))
    model = Model(chars_in, crf(projected))
    model.summary()

    metrics = [crf.accuracy, ignore_class_accuracy(0)] + [split_class_accuracy(c, tag2index) for c in TAG_PUNCT]
    model.compile(optimizer='adam', loss=crf.loss_function, metrics=metrics)

    return model

def reset_draw():
    """Return a fresh metric-history dict for PlotLosses.

    Contents:
    - 'i': epoch counter.
    - 'fig': a new matplotlib figure.
    - every other key: an empty list.
    """
    history = {'i': 0}
    for key in ('x', 'losses', 'val_losses', 'acc', 'val_acc', 'val_ignore',
                'val_split_dot', 'val_split_comma', 'val_split_qmark', 'val_split_emark'):
        history[key] = []
    history['fig'] = plt.figure()
    history['logs'] = []
    return history

def train_model():
    """Build the model and fit it on the joined train windows, validating on the joined test windows.

    Reads the global tokenized data and hyper-parameters (BATCH, EPOCH) and
    returns the trained model.
    """
    model = create_model(char2index, tag2index)

    X_train, y_train = shuffle_and_join(train_sentences_X, train_tags_y)
    X_val, y_val = shuffle_and_join(test_sentences_X, test_tags_y)

    n_tags = len(tag2index)
    y_train_onehot = to_categorical(y_train, n_tags)
    y_val_onehot = to_categorical(y_val, n_tags)

    model.fit(
        X_train,
        y_train_onehot,
        batch_size=BATCH,
        epochs=EPOCH,
        callbacks=[PlotLosses()],
        validation_data=(X_val, y_val_onehot),
    )

    return model
    
    
def logits_to_tokens(sequences, index):
    """Map every per-position score vector to the label of its highest score.

    Args:
        sequences: iterable of sequences of score vectors (e.g. model.predict output).
        index: mapping from class index to label.

    Returns:
        A list of label lists, one per input sequence.
    """
    return [[index[np.argmax(scores)] for scores in sequence] for sequence in sequences]


def predict(test_sample):
    """Predict punctuation tags for every word sequence in ``test_sample``.

    Args:
        test_sample: list of word lists, without the treebank-path prefix.
            It is not modified.

    Returns:
        One list of tag strings per sample, MAX_SAMPLE_LENGTH long. Positions
        past the end of a sample are padding.
    """
    # tokenize() treats element 0 of each sentence as the treebank path and only
    # char-encodes the rest, so give every sample a dummy '' head. This is done
    # on copies: the old code prefixed test_sample[0] in place, and only for the
    # first sample.
    prefixed = [[''] + list(words) for words in test_sample]
    samples_X, _ = tokenize(prefixed, [], char2index, tag2index)
    # Drop the dummy head again, keeping only the per-word char-index arrays.
    samples_X = [sample[1:] for sample in samples_X]
    # truncating='post' keeps the FIRST MAX_SAMPLE_LENGTH words, so tag j always
    # belongs to word j.
    samples_X = pad_sequences(samples_X, maxlen=MAX_SAMPLE_LENGTH, padding='post', truncating='post')

    predictions = model.predict(samples_X)

    index2tag = {i: t for t, i in tag2index.items()}
    return logits_to_tokens(predictions, index2tag)


def transform(sentence_features):
    """Return the words of all samples joined by spaces, each followed by its predicted mark.

    Words predicted as 'PAD', and words past the model's MAX_SAMPLE_LENGTH
    window, get no mark.
    """
    tags = predict(sentence_features)
    pieces = []
    for words, sample_tags in zip(sentence_features, tags):
        for j, word in enumerate(words):
            tag = sample_tags[j] if j < len(sample_tags) else ''
            pieces.append(word + ('' if tag == 'PAD' else tag))
    return ' '.join(pieces)

def validate():
    """Print one random validation example: input words, gold tags and the restored text.

    Picks SENTENCES_JOINING consecutive test sentences, concatenates their
    words into a single sample and runs it through ``transform``.
    """
    start = random.randint(0, len(test_sentences) - SENTENCES_JOINING)
    window = list(range(start, start + SENTENCES_JOINING))
    gold_tags = [test_tags[q] for q in window]
    sample = [[word for q in window for word in test_sentences[q][1:]]]

    restored = transform(sample)

    print(sample)
    print(gold_tags)
    print(restored)

In [None]:
# Treebank file prefixes, relative to 'UD-2.3/ud-treebanks-v2.3/'.
# prepare_dataset() appends 'train.conllu' / 'test.conllu' to each prefix.
# The short keys are used to pick training/test treebanks.
paths = {
    'ru': 'UD_Russian-SynTagRus/ru_syntagrus-ud-',
    'ru-small': 'UD_Russian-GSD/ru_gsd-ud-',
    'en': 'UD_English-GUM/en_gum-ud-',
    'uk': 'UD_Ukrainian-IU/uk_iu-ud-',
    'be': 'UD_Belarusian-HSE/be_hse-ud-',
}

In [None]:
# --- Data configuration ---
PATHS_TRAIN = [paths[key] for key in ['ru']]
PATHS_TEST = [paths[key] for key in ['ru']]

# Marks kept in word forms, and the marks used as tag classes / F1 metrics.
ALLOWED_PUNCT = ['.',',','!','?','?!','!?','..','...']
TAG_PUNCT = ['.', ',', '?', '!']
NGRAM = 2  # NOTE(review): not referenced in the visible code
DO_SHUFFLE = False  # NOTE(review): not referenced; shuffle_and_join never shuffles

UPPERCASE = False

MIN_SENT_LENGTH = -1  # -1: no lower bound on the number of words
MAX_SENT_LENGTH = 40
MAX_WORD_LENGTH = 30

# --- Model / training configuration ---
SENTENCES_JOINING = 3
DROPOUT = 0.7
EMBEDDING_SIZE = 100
HID_SIZE = 200
HID_SIZE2 = 500
TD_SIZE = 10
BATCH = 50
EPOCH = 100

random.seed(0)
np.random.seed(0)

train_sentences, train_tags, test_sentences, test_tags = prepare_dataset()

if not MAX_SENT_LENGTH:
    # The fallback must come after prepare_dataset() (it needs the data).
    # Element 0 of each sentence is the treebank path, not a word, hence -1.
    MAX_SENT_LENGTH = max(len(s) - 1 for s in train_sentences)

MAX_SAMPLE_LENGTH = MAX_SENT_LENGTH * SENTENCES_JOINING

char2index, tag2index = get_dicts([i[1:] for i in train_sentences], train_tags)
train_sentences_X, train_tags_y = tokenize(train_sentences, train_tags, char2index, tag2index)
test_sentences_X, test_tags_y = tokenize(test_sentences, test_tags, char2index, tag2index)

draw_loss = reset_draw()

model = train_model()
validate()

In [None]:
# Restore punctuation for a longer stretch of test text. Every
# MAX_SAMPLE_LENGTH-word window is predicted and each word takes the tag it
# received most often across the windows covering it.
size = 50
r = random.randint(0, len(test_sentences) - size)
ind = list(range(r, r + size))
sentence_features = [test_sentences[q][1:] for q in ind]
sentence_tags = [test_tags[q] for q in ind]
sentence_features = [[y for x in sentence_features for y in x]]
words = sentence_features[0]

# One vote table per word. The model may predict any tag in tag2index
# (e.g. '...' from ALLOWED_PUNCT), so count with .get instead of pre-seeding a
# fixed key set (which raised KeyError on such tags).
counters = [{} for _ in words]
preds = [predict([words[start:start + MAX_SAMPLE_LENGTH]])
         for start in range(len(words) - MAX_SAMPLE_LENGTH + 1)]
for offset, pred in enumerate(preds):
    for j, tag in enumerate(pred[0]):
        counters[offset + j][tag] = counters[offset + j].get(tag, 0) + 1

# A word with no votes (text shorter than one window) or a 'PAD' winner gets no
# mark. Previously an all-zero table silently picked '.'.
best = []
for votes in counters:
    winner = max(votes, key=votes.get) if votes else ''
    best.append('' if winner == 'PAD' else winner)

print(' '.join(word + tag for word, tag in zip(words, best)))

In [None]:
# Print another random validation example: input words, gold tags and restored text.
validate()

In [None]:
# Reverse character vocabulary: index -> character (or 'CHAR_PAD' / 'CHAR_OOV').
index2char = {idx: ch for ch, idx in char2index.items()}

In [None]:
# test_sentences_X2 only exists inside train_model(), so rebuild the joined
# validation windows here instead of relying on leftover kernel state.
test_sentences_X2, test_tags_y2 = shuffle_and_join(test_sentences_X, test_tags_y)
# Decode validation window #10 back into words, dropping the padding markers.
[''.join(index2char[j] for j in word).replace('CHAR_PAD', '') for word in test_sentences_X2[10]]

In [None]:
# Parse the English (GUM) training treebank into `data` for standalone corpus statistics.
with open('UD-2.3/ud-treebanks-v2.3/' + paths['en'] + 'train.conllu', encoding='utf-8', newline='') as f:
    data = parse(f.read())

In [None]:
# Non-punctuation forms, lemmas, and per-sentence forms of `data`.
# These are kept as globals because later cells (the histogram) use them.
D1 = [y['form'] for x in data for y in x if y['upostag'] != 'PUNCT']
D2 = [y['lemma'] for x in data for y in x if y['upostag'] != 'PUNCT']
D3 = [[y['form'] for y in x if y['upostag'] != 'PUNCT'] for x in data]

# Same summary as printed for the training treebanks; reuse the helper instead
# of a copy-pasted list of prints.
stats(data)

In [None]:
import matplotlib.pyplot as plt
# Distribution of sentence lengths, counted in non-punctuation words.
fig, ax = plt.subplots()
ax.hist([len(sentence) for sentence in D3], bins=30)
ax.set_xlabel('Sentence length (non-punctuation words)')
ax.set_ylabel('Number of sentences')
ax.set_title('Sentence length distribution');