In [88]:
# encoding=utf8
import dynet as dy
import random
import sys
import time
import utils
from utils import *

# Configure the DyNet runtime before any model/parameter is created.
# NOTE(review): DyNet's docs describe manual DynetParams initialisation together with
# `import _dynet`; with a plain `import dynet` the library may already have been
# initialised at import time -- confirm these settings actually take effect.
dyparams = dy.DynetParams()
# Let DyNet pick up its own options from the command line first.
dyparams.from_args()
# Memory budget for DyNet -- presumably in MB; verify against the DyNet docs.
dyparams.set_mem(4096)
# Fixed seed for DyNet's random number generator. Python's `random` module
# (used by random.shuffle in train) is NOT seeded here.
dyparams.set_random_seed(1)
dyparams.init()

# IPython magics: re-import edited modules (e.g. utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Python 2-only hack: reload(sys) makes sys.setdefaultencoding reachable again so the
# implicit str/unicode conversion codec can be switched to UTF-8.
# Neither call exists in this form on Python 3.
reload(sys)
sys.setdefaultencoding('utf8')

# Training / development treebanks in CoNLL-U format.
# NOTE(review): absolute, user-specific paths -- the notebook will not run on another machine as is.
conll_train = "/Users/huseyinalecakir/NLP_LAB/data/tr_imst-ud-train.conllu"
conll_dev = "/Users/huseyinalecakir/NLP_LAB/data/tr_imst-ud-dev.conllu"

# Vocabularies built from the training file. c2i is used below as a symbol -> integer id
# mapping and w2i is passed on to read_conll; `features` is not used in the visible code.
c2i, w2i, features = utils.vocab(conll_train)

# Marker symbol; the same symbol is used both to open and to close a sequence.
EOS = '<s>'

# Inverse mapping (id -> symbol) and an alias for c2i.
int2char = {c2i[i] : i for i in c2i}
char2int = c2i

VOCAB_SIZE = len(c2i)

# Network hyper-parameters.
LSTM_NUM_OF_LAYERS = 2
EMBEDDINGS_SIZE = 128
STATE_SIZE = 256
ATTENTION_SIZE = 64

# Container for all trainable parameters.
model = dy.Model()

# Forward and backward encoder LSTMs; their outputs are concatenated per position
# (see encode_sentence), which is where the STATE_SIZE*2 dimensions below come from.
enc_fwd_lstm = dy.LSTMBuilder(LSTM_NUM_OF_LAYERS, EMBEDDINGS_SIZE, STATE_SIZE, model)
enc_bwd_lstm = dy.LSTMBuilder(LSTM_NUM_OF_LAYERS, EMBEDDINGS_SIZE, STATE_SIZE, model)

# Decoder input at each step = attention context (STATE_SIZE*2) + previous output embedding.
dec_lstm = dy.LSTMBuilder(LSTM_NUM_OF_LAYERS, STATE_SIZE*2+EMBEDDINGS_SIZE, STATE_SIZE, model)

# Embeddings for encoder input symbols.
input_lookup = model.add_lookup_parameters((VOCAB_SIZE, EMBEDDINGS_SIZE))
# attention_w1 projects the encoder output matrix; attention_w2 projects the concatenated
# decoder state (state.s() in attend); attention_v reduces the result to one score per position.
attention_w1 = model.add_parameters( (ATTENTION_SIZE, STATE_SIZE*2))
attention_w2 = model.add_parameters( (ATTENTION_SIZE, STATE_SIZE*LSTM_NUM_OF_LAYERS*2))
attention_v = model.add_parameters( (1, ATTENTION_SIZE))
# Output projection from the decoder output to vocabulary scores.
decoder_w = model.add_parameters( (VOCAB_SIZE, STATE_SIZE))
# Note: (VOCAB_SIZE) is a plain int, not a 1-tuple.
decoder_b = model.add_parameters( (VOCAB_SIZE))
# Embeddings for previously emitted output symbols (fed back into the decoder).
output_lookup = model.add_lookup_parameters((VOCAB_SIZE, EMBEDDINGS_SIZE))

In [89]:
def embed_sentence(sentence):
    """Return the input embeddings of ``sentence`` wrapped in "<s>" markers.

    Args:
        sentence: iterable of symbols that are keys of ``char2int``.

    Returns:
        A list with one ``input_lookup`` entry per symbol, including the
        leading and trailing "<s>" marker.

    Raises:
        KeyError: if a symbol is missing from ``char2int``.
    """
    symbols = ["<s>"]
    symbols.extend(sentence)
    symbols.append("<s>")

    # Resolve every id first so an unknown symbol fails before any lookup is made.
    symbol_ids = []
    for symbol in symbols:
        symbol_ids.append(char2int[symbol])

    embeddings = []
    for symbol_id in symbol_ids:
        embeddings.append(input_lookup[symbol_id])
    return embeddings

def run_lstm(init_state, input_vecs):
    """Feed ``input_vecs`` one by one into a recurrent state and collect the outputs.

    Args:
        init_state: object exposing ``add_input(vec)`` (returning the next
            state) and ``output()``.
        input_vecs: iterable of inputs, consumed in order.

    Returns:
        List with the ``output()`` of the state reached after each input.
    """
    state = init_state
    outputs = []
    for vec in input_vecs:
        state = state.add_input(vec)
        outputs.append(state.output())
    return outputs

def encode_sentence(enc_fwd_lstm, enc_bwd_lstm, sentence):
    """Encode a sequence of input vectors with a bidirectional LSTM.

    Args:
        enc_fwd_lstm: builder whose ``initial_state()`` starts the left-to-right pass.
        enc_bwd_lstm: builder whose ``initial_state()`` starts the right-to-left pass.
        sentence: sequence of input vectors (as produced by ``embed_sentence``).

    Returns:
        One vector per position: the forward output concatenated with the
        backward output for that same position.
    """
    reversed_inputs = list(reversed(sentence))

    forward_outputs = run_lstm(enc_fwd_lstm.initial_state(), sentence)
    backward_outputs = run_lstm(enc_bwd_lstm.initial_state(), reversed_inputs)
    # Re-align the backward outputs with the original left-to-right positions.
    backward_outputs.reverse()

    encoded = []
    for fwd_out, bwd_out in zip(forward_outputs, backward_outputs):
        encoded.append(dy.concatenate([fwd_out, bwd_out]))
    return encoded


def attend(input_mat, state, w1dt):
    """Compute the attention context vector for the current decoder state.

    Args:
        input_mat: encoder outputs concatenated as columns (encoder_state x seqlen).
        state: current decoder LSTM state; everything returned by ``state.s()``
            is concatenated into a single vector.
        w1dt: pre-computed ``attention_w1 * input_mat`` (attdim x seqlen).

    Returns:
        ``input_mat`` times the softmax-normalised attention weights, i.e. a
        weighted combination of the encoder columns.
    """
    w2 = dy.parameter(attention_w2)
    v = dy.parameter(attention_v)

    # Projection of the flattened decoder state into attention space --
    # presumably a single (attdim) vector; confirm against state.s() sizes.
    flat_state = dy.concatenate(list(state.s()))
    w2dt = w2 * flat_state

    # Add the state projection to every column, squash, and score each position with v.
    scores = v * dy.tanh(dy.colwise_add(w1dt, w2dt))
    # Transpose so the softmax normalises over sequence positions.
    att_weights = dy.softmax(dy.transpose(scores))

    return input_mat * att_weights


def decode(dec_lstm, vectors, output):
    """Return the teacher-forced negative log-likelihood of ``output``.

    The gold ids are wrapped in "<s>" markers on both sides. At every step the
    decoder receives the attention context concatenated with the embedding of
    the previous gold symbol and is scored on the current gold symbol.

    Args:
        dec_lstm: decoder LSTM builder.
        vectors: list of encoder output expressions (one per input position).
        output: sequence of integer symbol ids (gold decoder output, without markers).

    Returns:
        A DyNet expression: the sum of the per-step negative log-likelihoods.
    """
    boundary_id = c2i["<s>"]
    # list() lets callers pass any sequence type, not only a list.
    output = [boundary_id] + list(output) + [boundary_id]

    w = dy.parameter(decoder_w)
    b = dy.parameter(decoder_b)
    w1 = dy.parameter(attention_w1)
    input_mat = dy.concatenate_cols(vectors)
    # Loop-invariant projection of the encoder outputs: compute it exactly once
    # instead of relying on the truthiness of an expression (`w1dt or ...`).
    w1dt = w1 * input_mat

    last_output_embeddings = output_lookup[boundary_id]
    # Prime the decoder with a zero context vector and the "<s>" embedding.
    s = dec_lstm.initial_state().add_input(
        dy.concatenate([dy.vecInput(STATE_SIZE*2), last_output_embeddings]))
    loss = []

    for char in output:
        vector = dy.concatenate([attend(input_mat, s, w1dt), last_output_embeddings])
        s = s.add_input(vector)
        out_vector = w * s.output() + b
        # Same quantity as -log(softmax(out_vector)[char]) but computed in one
        # numerically stable operation (no log(0) when a probability underflows).
        loss.append(dy.pickneglogsoftmax(out_vector, char))
        last_output_embeddings = output_lookup[char]
    return dy.esum(loss)


def generate(in_seq, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
    """Greedily decode an output id sequence for ``in_seq``.

    A fresh computation graph is started on every call; previously this
    function kept appending to whatever graph was current, so the graph grew
    with every decoded word during evaluation. Any DyNet expression created
    before the call is therefore invalid afterwards.

    Args:
        in_seq: sequence of input symbols (keys of ``char2int``).
        enc_fwd_lstm: forward encoder LSTM builder.
        enc_bwd_lstm: backward encoder LSTM builder.
        dec_lstm: decoder LSTM builder.

    Returns:
        List of predicted integer ids with the "<s>" markers removed. Decoding
        stops after the second marker or after ``2 * len(in_seq)`` steps.
    """
    dy.renew_cg()
    embedded = embed_sentence(in_seq)
    encoded = encode_sentence(enc_fwd_lstm, enc_bwd_lstm, embedded)

    w = dy.parameter(decoder_w)
    b = dy.parameter(decoder_b)
    w1 = dy.parameter(attention_w1)
    input_mat = dy.concatenate_cols(encoded)
    # Loop-invariant projection of the encoder outputs, computed once.
    w1dt = w1 * input_mat

    last_output_embeddings = output_lookup[c2i["<s>"]]
    # Prime the decoder with a zero context vector and the "<s>" embedding.
    s = dec_lstm.initial_state().add_input(
        dy.concatenate([dy.vecInput(STATE_SIZE * 2), last_output_embeddings]))

    out = []
    count_EOS = 0
    for _ in range(len(in_seq)*2):
        vector = dy.concatenate([attend(input_mat, s, w1dt), last_output_embeddings])
        s = s.add_input(vector)
        out_vector = w * s.output() + b
        probs = dy.softmax(out_vector).vec_value()
        # Greedy choice: first index holding the highest probability.
        next_char = probs.index(max(probs))
        last_output_embeddings = output_lookup[next_char]
        if int2char[next_char] == EOS:
            # The first marker opens the sequence, the second one closes it.
            count_EOS += 1
            if count_EOS == 2:
                break
            continue

        out.append(next_char)
    return out


def get_loss(input_sentence, output_sentence, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
    """Build the loss expression for one (input, gold output) pair.

    Starts a new computation graph, encodes ``input_sentence`` and returns
    whatever ``decode`` produces for ``output_sentence``.
    """
    # Every training example gets its own computation graph.
    dy.renew_cg()
    encoder_states = encode_sentence(
        enc_fwd_lstm, enc_bwd_lstm, embed_sentence(input_sentence))
    loss_expr = decode(dec_lstm, encoder_states, output_sentence)
    return loss_expr

def convert2chars(ints, i2c):
    """Translate each id in ``ints`` through the ``i2c`` mapping, preserving order."""
    symbols = []
    for symbol_id in ints:
        symbols.append(i2c[symbol_id])
    return symbols

def compute_accuracy(gold, predicted, metric="set"):
    """Score a predicted sequence against the gold sequence.

    Args:
        gold: sequence of gold items.
        predicted: sequence of predicted items.
        metric: one of
            * ``"set_match"`` (or the default ``"set"``, treated as an alias):
              fraction of gold items that occur anywhere in ``predicted``.
              Two empty sequences score 1.0; an empty gold sequence with a
              non-empty prediction scores 0.0.
            * ``"exact_match"``: 1.0 if both sequences have the same length
              and the same items in the same order, otherwise 0.0.
            Any other value yields 0.0.

    Returns:
        A float in [0.0, 1.0].
    """
    # Compare metric names by value (`==` / `in`); `is` only worked by accident
    # of string interning.
    if metric in ("set", "set_match"):
        if len(gold) == 0:
            # Avoid dividing by zero; stay consistent with exact_match on empty input.
            return 1.0 if len(predicted) == 0 else 0.0
        # Test each gold item (not the whole gold list) for membership.
        hits = sum(1 for g in gold if g in predicted)
        return float(hits) / len(gold)
    if metric == "exact_match":
        return 1.0 if list(gold) == list(predicted) else 0.0
    return 0.0

In [90]:
def train(model, conll_path):
    """Run one pass over the training data, updating after every entry.

    A new ``dy.AdamTrainer`` is created on each call, so optimizer state does
    not carry over between calls. Sentences are visited in random order; the
    running average loss is printed for every 500th sentence.

    Args:
        model: DyNet model holding the parameters to optimise.
        conll_path: path of the CoNLL file passed to ``read_conll``.
    """
    trainer = dy.AdamTrainer(model)
    total_loss = 0.0
    entry_count = 0
    start = time.time()
    # Materialise the data inside the `with` so the file is closed before training starts.
    with open(conll_path, 'r') as conllFP:
        shuffled_data = list(read_conll(conllFP, c2i, w2i))
    random.shuffle(shuffled_data)
    for iSentence, sentence in enumerate(shuffled_data):
        conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
        for entry in conll_sentence:
            loss = get_loss(entry.chars, entry.decoder_gold_output, enc_fwd_lstm, enc_bwd_lstm, dec_lstm)
            # Read the value (forward pass) before backpropagating and updating.
            loss_value = loss.value()
            loss.backward()
            trainer.update()
            total_loss += loss_value
            entry_count += 1
        if iSentence % 500 == 0:
            # No entry may have been processed yet (e.g. a first sentence without
            # ConllEntry items); report NaN instead of dividing by zero.
            average_loss = total_loss / entry_count if entry_count else float('nan')
            print("Sentence: {} Loss: {} Time: {}".format(iSentence, average_loss, time.time() - start))
            start = time.time()

In [None]:
def evaluate(model, conll_path):
    """Return the exact-match accuracy (in percent) over all entries of a CoNLL file.

    Args:
        model: unused; kept so existing callers keep working.
        conll_path: path of the CoNLL file passed to ``read_conll``.

    Returns:
        Percentage of entries whose generated sequence equals
        ``entry.decoder_gold_output`` exactly; 0.0 if the file has no entries.
    """
    count = 0
    correct = 0.0
    start = time.time()
    with open(conll_path, 'r') as conllFP:
        for sentence in read_conll(conllFP, c2i, w2i):
            for entry in sentence:
                if not isinstance(entry, utils.ConllEntry):
                    continue
                predicted_sequence = generate(entry.chars, enc_fwd_lstm, enc_bwd_lstm, dec_lstm)
                # Accumulate: assigning here would keep only the last entry's result.
                correct += compute_accuracy(entry.decoder_gold_output, predicted_sequence, "exact_match")
                count += 1
    # Guard against an empty evaluation set.
    score = correct * 100 / count if count else 0.0
    print("Evaluation duration : {}".format(time.time()-start))
    return score

In [None]:
# Train for a fixed number of epochs, evaluating on the development set after each one
# and remembering the best score seen so far.
num_epoch = 3
highestScore = 0.0
eId = 0  # 1-based index of the best epoch; 0 means "none yet"
start = time.time()
for epoch in range(num_epoch):
    print("--- epoch {} --- ".format(epoch))
    train(model, conll_train)
    score = evaluate(model, conll_dev)
    # Single-argument print(...) calls behave the same with the Python 2 print
    # statement and the Python 3 print function, unlike the bare print statement.
    print("---\nAccuracy:\t%.2f" % score)
    # `>=` means a tie is credited to the later epoch.
    if score >= highestScore:
        highestScore = score
        eId = epoch + 1
    print("Highest: %.2f at epoch %d" % (highestScore, eId))
    print("Epoch: {} Total duration: {}".format(epoch, time.time()-start))
    start = time.time()

In [None]:
# TODO: add further evaluation metrics, and call dy.renew_cg() at the start of generate() so every prediction uses a fresh computation graph.