# In [2]:
# encoding=utf8

import dynet_config

dynet_config.set(mem=2048, random_seed=9)

import dynet as dy
import random
import sys
import time
from utils import read_conll, vocab, ConllEntry

# Python 2 only: make UTF-8 the process-wide default codec. Python 3 has no
# builtin reload() and no sys.setdefaultencoding(); its default is already
# UTF-8, so the hack is skipped there instead of raising NameError.
if sys.version_info[0] < 3:
    reload(sys)
    sys.setdefaultencoding('utf8')

# Fixed seed so the sentence shuffling done during training is repeatable.
random.seed(1)

# Static variables
# Tag whose embedding primes the decoder and whose prediction ends decoding.
EOS = '<s>'
# Character-vocabulary key used to build the sentence-level sentinel context.
SW = "start"

# HyperParameters
LSTM_NUM_OF_LAYERS = 2   # layers in every LSTM (both encoders and the decoder)
EMBEDDINGS_SIZE = 128    # width of the character and tag embeddings
STATE_SIZE = 256         # hidden size of each LSTM
ATTENTION_SIZE = 64      # width of the attention scoring space


class Learner(object):
    """Character-level encoder / attention-decoder tagger built on DyNet.

    Each word is embedded character by character and encoded with a forward
    and a backward LSTM. A decoder LSTM then emits a sequence of tag ids while
    attending over two memories: the character states of the word itself and
    the last encoder state of the other words of the sentence (plus a sentinel
    built from the ``SW`` symbol).
    """

    def __init__(self, mode):
        """Load the vocabularies and allocate every model parameter.

        Args:
            mode: truthy selects the ``/home/huseyin`` data files and full
                runs; falsy selects the ``/Users/huseyinalecakir`` files and
                makes ``train``/``evaluate`` use only a few sentences.
        """
        self.mode = mode

        if self.mode:
            self.conll_train = "/home/huseyin/Data/UD_Turkish-IMST/tr_imst-ud-train.conllu"
            self.conll_dev = "/home/huseyin/Data/UD_Turkish-IMST/tr_imst-ud-dev.conllu"
        else:
            self.conll_train = "/Users/huseyinalecakir/NLP_LAB/data/tr_imst-ud-train.conllu"
            self.conll_dev = "/Users/huseyinalecakir/NLP_LAB/data/tr_imst-ud-dev.conllu"

        # presumably char->index and tag->index maps; verify against utils.vocab
        self.c2i, self.t2i = vocab(self.conll_train)

        # Inverse maps (index -> symbol).
        self.i2c = {index: symbol for symbol, index in self.c2i.items()}
        self.i2t = {index: symbol for symbol, index in self.t2i.items()}

        char_vocab_size = len(self.c2i)
        tag_vocab_size = len(self.t2i)

        self.model = dy.Model()

        # Encoder
        self.enc_fwd_lstm = dy.LSTMBuilder(LSTM_NUM_OF_LAYERS, EMBEDDINGS_SIZE, STATE_SIZE, self.model)
        self.enc_bwd_lstm = dy.LSTMBuilder(LSTM_NUM_OF_LAYERS, EMBEDDINGS_SIZE, STATE_SIZE, self.model)

        # Decoder: its input is [char attention (2*STATE_SIZE), previous tag
        # embedding (EMBEDDINGS_SIZE), context attention (2*STATE_SIZE)].
        self.dec_lstm = dy.LSTMBuilder(LSTM_NUM_OF_LAYERS,
                                       2 * STATE_SIZE + EMBEDDINGS_SIZE + STATE_SIZE * 2,
                                       STATE_SIZE,
                                       self.model)

        # Attention over the characters of the current word.
        # w2 multiplies the concatenation of state.s(), hence the
        # STATE_SIZE * LSTM_NUM_OF_LAYERS * 2 input width.
        self.attention_w1 = self.model.add_parameters((ATTENTION_SIZE, STATE_SIZE * 2))
        self.attention_w2 = self.model.add_parameters((ATTENTION_SIZE, STATE_SIZE * LSTM_NUM_OF_LAYERS * 2))
        self.attention_v = self.model.add_parameters((1, ATTENTION_SIZE))

        # Attention over the sentence context (same shapes, separate weights).
        self.attention_w1_context = self.model.add_parameters((ATTENTION_SIZE, STATE_SIZE * 2))
        self.attention_w2_context = self.model.add_parameters((ATTENTION_SIZE, STATE_SIZE * LSTM_NUM_OF_LAYERS * 2))
        self.attention_v_context = self.model.add_parameters((1, ATTENTION_SIZE))

        # MLP - Softmax
        self.decoder_w = self.model.add_parameters((tag_vocab_size, STATE_SIZE))
        # A real 1-tuple: "(n)" is just the integer n.
        self.decoder_b = self.model.add_parameters((tag_vocab_size,))

        # Lookups
        self.input_lookup = self.model.add_lookup_parameters((char_vocab_size, EMBEDDINGS_SIZE))
        self.output_lookup = self.model.add_lookup_parameters((tag_vocab_size, EMBEDDINGS_SIZE))

    def embed_word(self, word):
        """Return the input-lookup embedding of every character id in ``word``."""
        return [self.input_lookup[char] for char in word]

    def run_lstm(self, init_state, input_vecs):
        """Feed ``input_vecs`` to an LSTM and return the output after each step."""
        state = init_state
        out_vectors = []
        for vector in input_vecs:
            state = state.add_input(vector)
            out_vectors.append(state.output())
        return out_vectors

    def encode_word(self, word):
        """Encode embedded characters with the forward and backward LSTMs.

        Args:
            word: list of character embeddings (the result of ``embed_word``).

        Returns:
            One vector per character: the forward output concatenated with the
            backward output aligned to the same position.
        """
        fwd_vectors = self.run_lstm(self.enc_fwd_lstm.initial_state(), word)
        bwd_vectors = self.run_lstm(self.enc_bwd_lstm.initial_state(), list(reversed(word)))
        # Re-align the backward outputs with the original character order.
        bwd_vectors.reverse()
        return [dy.concatenate([fwd, bwd]) for fwd, bwd in zip(fwd_vectors, bwd_vectors)]

    def convert2chars(self, ints):
        """Map tag ids back to their symbols through ``i2t``."""
        return [self.i2t[i] for i in ints]

    def compute_accuracy(self, gold, predicted, metric):
        """Score one prediction against its gold sequence.

        Args:
            gold: gold sequence.
            predicted: predicted sequence.
            metric: ``"set_match"`` (fraction of gold items that occur anywhere
                in ``predicted``) or ``"exact_match"`` (1.0 only when both
                sequences are identical, else 0.0).

        Returns:
            A float in [0.0, 1.0]. For an empty ``gold``, ``set_match`` returns
            1.0 when ``predicted`` is empty too and 0.0 otherwise, which is
            what ``exact_match`` gives for the same input.

        Raises:
            ValueError: if ``metric`` is not one of the two names above.
        """
        # Compare by value: "is" only works when both strings happen to be the
        # same interned object.
        if metric == "set_match":
            if not gold:
                # Avoid dividing by zero on an empty gold sequence.
                return 0.0 if predicted else 1.0
            hits = sum(1.0 for g in gold if g in predicted)
            return hits / len(gold)
        if metric == "exact_match":
            same = (len(predicted) == len(gold)
                    and all(g == p for g, p in zip(gold, predicted)))
            return 1.0 if same else 0.0
        raise ValueError("Undefined metric.")

    def _start_decoder(self):
        """Return ``(state, embedding)`` of a decoder primed with the EOS tag.

        The first input has zero vectors where later steps carry the two
        attention contexts.
        """
        start_embedding = self.output_lookup[self.t2i[EOS]]
        state = self.dec_lstm.initial_state().add_input(
            dy.concatenate([dy.vecInput(STATE_SIZE * 2),
                            start_embedding,
                            dy.vecInput(STATE_SIZE * 2)]))
        return state, start_embedding

    def _decoder_step(self, state, last_output_embeddings,
                      input_mat, w1dt, input_context, w1dt_context):
        """Advance the decoder by one step using both attention contexts."""
        vector = dy.concatenate([self.attend(input_mat, state, w1dt),
                                 last_output_embeddings,
                                 self.attend_context(input_context, state, w1dt_context)])
        return state.add_input(vector)

    def generate(self, encoded, word_context, limit_features=10):
        """Greedily decode a tag sequence for one word.

        Args:
            encoded: per-character encoder vectors of the word.
            word_context: context vectors of the sentence for this word.
            limit_features: maximum number of decoding steps.

        Returns:
            The predicted tag ids with EOS predictions left out. Decoding
            stops at ``limit_features`` steps or once EOS was predicted twice.
        """
        w = dy.parameter(self.decoder_w)
        b = dy.parameter(self.decoder_b)

        input_mat = dy.concatenate_cols(encoded)
        input_context = dy.concatenate_cols(word_context)

        # These products do not depend on the decoder state, so they are
        # computed once for the whole decoding phase.
        w1dt = dy.parameter(self.attention_w1) * input_mat
        w1dt_context = dy.parameter(self.attention_w1_context) * input_context

        s, last_output_embeddings = self._start_decoder()

        out = []
        count_EOS = 0
        for _ in range(limit_features):
            if count_EOS == 2:
                break
            s = self._decoder_step(s, last_output_embeddings,
                                   input_mat, w1dt, input_context, w1dt_context)
            probs = dy.softmax(w * s.output() + b).vec_value()
            next_char = probs.index(max(probs))
            # The prediction (EOS included) is fed back at the next step.
            last_output_embeddings = self.output_lookup[next_char]
            if self.i2t[next_char] == EOS:
                count_EOS += 1
                continue
            out.append(next_char)
        return out

    def _attend(self, input_mat, state, w1dt, w2_param, v_param):
        """Shared attention computation behind ``attend``/``attend_context``.

        Shape notes carried over from the original comments:
        input_mat is (encoder_state x seqlen), w1dt is (attdim x seqlen).
        """
        w2 = dy.parameter(w2_param)
        v = dy.parameter(v_param)

        # w2dt: (attdim, 1), built from every vector of the decoder state.
        w2dt = w2 * dy.concatenate(list(state.s()))
        # unnormalized: (seqlen,)
        unnormalized = dy.transpose(v * dy.tanh(dy.colwise_add(w1dt, w2dt)))
        att_weights = dy.softmax(unnormalized)
        # context: (encoder_state), the attention-weighted sum of the columns.
        return input_mat * att_weights

    def attend(self, input_mat, state, w1dt):
        """Attention over the character states of the current word."""
        return self._attend(input_mat, state, w1dt,
                            self.attention_w2, self.attention_v)

    def attend_context(self, input_mat, state, w1dt_context):
        """Attention over the sentence context, with its own weights."""
        return self._attend(input_mat, state, w1dt_context,
                            self.attention_w2_context, self.attention_v_context)

    def decode(self, vectors, decoder_seq, word_context):
        """Return the summed negative log-likelihood of ``decoder_seq``.

        Args:
            vectors: per-character encoder vectors of the word.
            decoder_seq: gold tag ids; each one is scored and then fed back
                as the next decoder input (teacher forcing).
            word_context: context vectors of the sentence for this word.
        """
        w = dy.parameter(self.decoder_w)
        b = dy.parameter(self.decoder_b)

        input_mat = dy.concatenate_cols(vectors)
        input_context = dy.concatenate_cols(word_context)

        # Independent of the decoder state: computed once, reused every step.
        w1dt = dy.parameter(self.attention_w1) * input_mat
        w1dt_context = dy.parameter(self.attention_w1_context) * input_context

        s, last_output_embeddings = self._start_decoder()

        loss = []
        for char in decoder_seq:
            s = self._decoder_step(s, last_output_embeddings,
                                   input_mat, w1dt, input_context, w1dt_context)
            probs = dy.softmax(w * s.output() + b)
            last_output_embeddings = self.output_lookup[char]
            loss.append(-dy.log(dy.pick(probs, char)))
        return dy.esum(loss)

    def get_loss_entry(self, encoded, decoder_seq, word_context):
        """Loss of one entry; thin wrapper around ``decode``."""
        return self.decode(encoded, decoder_seq, word_context)

    def _encode_sentence(self, conll_sentence):
        """Encode every entry of a sentence and return the context list.

        Stores ``encoded_all_s`` and ``encoded_last_s`` on each entry. The
        returned list holds the sentinel encoding of ``SW`` at index 0 and
        the last encoder vector of entry ``k`` at index ``k + 1``.
        """
        sentinel = self.encode_word(self.embed_word([self.c2i[SW]]))
        context = [sentinel[-1]]
        for entry in conll_sentence:
            encoded = self.encode_word(self.embed_word(entry.idChars))
            entry.encoded_all_s = encoded
            entry.encoded_last_s = encoded[-1]
            context.append(entry.encoded_last_s)
        return context

    @staticmethod
    def _context_without(context, idx):
        """Return ``context`` minus the vector of the entry at position ``idx``.

        ``context[0]`` is the sentinel, so entry ``idx`` lives at ``idx + 1``.
        Filtering on ``idx`` itself would drop the sentinel (or the previous
        word) and keep the current word in its own context.
        """
        own_position = idx + 1
        return [vec for pos, vec in enumerate(context) if pos != own_position]

    def train(self):
        """Run one pass over the training file, updating after every entry."""
        trainer = dy.AdamTrainer(self.model)
        total_loss = 0
        entry_count = 0
        start = time.time()
        with open(self.conll_train, 'r') as conllFP:
            shuffled_data = list(read_conll(conllFP, self.c2i, self.t2i))
        if not self.mode:
            # Debug mode: keep only the first ten sentences.
            shuffled_data = shuffled_data[:10]
        random.shuffle(shuffled_data)

        for iSentence, sentence in enumerate(shuffled_data):
            conll_sentence = [entry for entry in sentence if isinstance(entry, ConllEntry)]
            # One computation graph per sentence.
            dy.renew_cg()
            context = self._encode_sentence(conll_sentence)

            # NOTE(review): the same graph is reused after each
            # trainer.update() within a sentence; confirm this is intended.
            for idx, entry in enumerate(conll_sentence):
                word_context = self._context_without(context, idx)
                loss = self.get_loss_entry(entry.encoded_all_s, entry.decoder_gold_input, word_context)
                loss_value = loss.value()
                loss.backward()
                trainer.update()
                total_loss += loss_value
                entry_count += 1

            if iSentence % 500 == 0:
                # max(..., 1) keeps the report from dividing by zero when no
                # entry has been processed yet.
                print(
                    "Sentence: {} Loss: {} Time: {}".format(iSentence, total_loss / max(entry_count, 1), time.time() - start))
                start = time.time()

    def evaluate(self):
        """Decode the development file and return ``(set score, exact score)``.

        Both scores are percentages averaged over all evaluated entries;
        ``(0.0, 0.0)`` is returned when there was nothing to evaluate.
        """
        count = 0
        correct_set = 0.0
        correct_exact = 0.0
        start = time.time()
        with open(self.conll_dev, 'r') as conllFP:
            for iSentence, sentence in enumerate(read_conll(conllFP, self.c2i, self.t2i)):
                if not self.mode and iSentence > 2:
                    # Debug mode: evaluate the first three sentences only.
                    break
                conll_sentence = [entry for entry in sentence if isinstance(entry, ConllEntry)]

                # Fresh graph per sentence, as in train(); without it the
                # graph keeps growing over the whole development set.
                dy.renew_cg()
                context = self._encode_sentence(conll_sentence)

                for idx, entry in enumerate(conll_sentence):
                    word_context = self._context_without(context, idx)
                    predicted_sequence = self.generate(entry.encoded_all_s, word_context)
                    correct_set += self.compute_accuracy(entry.decoder_gold_output, predicted_sequence, "set_match")
                    correct_exact += self.compute_accuracy(entry.decoder_gold_output, predicted_sequence, "exact_match")
                    count += 1
        if count:
            score_set = float(correct_set) * 100 / count
            score_exact = float(correct_exact) * 100 / count
        else:
            score_set = score_exact = 0.0
        print("Evaluation duration : {}".format(time.time() - start))
        return score_set, score_exact

    def run(self, num_epoch=30):
        """Alternate training and evaluation, reporting the best scores so far.

        Args:
            num_epoch: number of train/evaluate rounds.
        """
        highestExactScore = 0.0
        highestSetScore = 0.0
        set_eId = 0
        exact_eId = 0
        start = time.time()
        for epoch in range(num_epoch):
            print("--- epoch {} --- ".format(epoch + 1))
            self.train()
            score_set, score_exact = self.evaluate()
            print("---Accuracy Set: {} Exact: {}".format(score_set, score_exact))
            # ">=" means a tie moves the record to the later epoch.
            if score_exact >= highestExactScore:
                highestExactScore = score_exact
                exact_eId = epoch + 1
            if score_set >= highestSetScore:
                highestSetScore = score_set
                set_eId = epoch + 1
            print("Highest Exact: {} at epoch {}".format(highestExactScore, exact_eId))
            print("Highest Set: {} at epoch {}".format(highestSetScore, set_eId))
            # Despite the label, this is the time since the previous epoch ended.
            print("Epoch: {} Total duration: {}".format(epoch + 1, time.time() - start))
            start = time.time()
            
def _parse_experiment_mode(argv):
    """Return True only when the first command-line argument is the integer 1.

    A missing argument, or one that is not an integer, selects False.
    """
    if len(argv) <= 1:
        return False
    try:
        # Compare by value; "is 1" only works through CPython's int caching.
        return int(argv[1]) == 1
    except ValueError:
        return False


experiment_mode = _parse_experiment_mode(sys.argv)

learner = Learner(experiment_mode)
learner.run()


# KeyboardInterrupt: