In [6]:
import tensorflow as tf
import numpy as np

In [8]:
class model(object):
    """Linear word-vector translation model conditioned on a language pair.

    Source word vectors (``emb_dim``-dimensional, 300 by default) are
    projected by ``encoder``. A learned representation of the
    (source, destination) language pair is projected by ``lang_encoder`` and
    added to every projected row. The sum is projected by ``decoder`` to
    predict the target word vectors. All parameters are trained with plain
    gradient descent on the summed squared error.

    Args:
        ckpt_path: Checkpoint prefix; concatenated directly with
            ``model_name`` (include a trailing path separator).
        lr: Gradient-descent learning rate.
        epochs: Stored on the instance; ``train`` uses its own ``num_epochs``.
        dim_lang: Size of each per-language representation vector.
        model_name: Base file name for checkpoints.
        langs: Languages that get a learned representation vector.
        emb_dim: Dimensionality of the source/target word vectors.

    Raises:
        ValueError: if ``langs`` is empty.
    """

    def __init__(self, ckpt_path, lr, epochs=1000, dim_lang=5, model_name="MLPM",
                 langs=("en", "pt", "es"), emb_dim=300):
        self.epochs = epochs
        self.ckpt_path = ckpt_path
        self.model_name = model_name
        self.lr = lr
        self.dim_lang = dim_lang
        self.emb_dim = emb_dim
        self.langs = list(langs)
        if not self.langs:
            raise ValueError("langs must not be empty")
        # Row index of each language in the stacked representation table.
        self.lang_ids = {lang: i for i, lang in enumerate(self.langs)}
        # lang -> tf.Variable of shape [1, dim_lang]; filled by _build_graph so
        # the variables live in self.graph (where the optimizer/saver are).
        self.all_lang_rep = dict()

        print('start building graph')
        self._build_graph()
        print('graph built')

    def _build_graph(self):
        """Create all placeholders, variables and training ops in ``self.graph``."""
        tf.reset_default_graph()
        self.graph = tf.Graph()
        with self.graph.as_default():

            # source and target vector representation
            self.source_words = tf.placeholder(tf.float32, shape=[None, self.emb_dim])
            self.target_words = tf.placeholder(tf.float32, shape=[None, self.emb_dim])

            # indices (into self.langs) of the language pair of the current batch
            self.src_lang_id = tf.placeholder(tf.int32, shape=[], name='src_lang_id')
            self.dest_lang_id = tf.placeholder(tf.int32, shape=[], name='dest_lang_id')

            # Python-float stddevs: stddev is an initializer argument, not a
            # tf.Variable argument, and tf.sqrt rejects integer inputs.
            emb_std = 1.0 / np.sqrt(float(self.emb_dim))
            lang_std = 1.0 / np.sqrt(float(self.dim_lang))

            # parameter matrices
            self.encoder = tf.Variable(
                tf.truncated_normal([self.emb_dim, self.emb_dim], stddev=emb_std),
                name='encoder')
            self.decoder = tf.Variable(
                tf.truncated_normal([self.emb_dim, self.emb_dim], stddev=emb_std),
                name='decoder')
            self.lang_encoder = tf.Variable(
                tf.truncated_normal([2 * self.dim_lang, self.emb_dim], stddev=emb_std),
                name='lang_encoder')

            # one trainable representation per language
            for lang in self.langs:
                self.all_lang_rep[lang] = tf.Variable(
                    tf.truncated_normal([1, self.dim_lang], stddev=lang_std),
                    name='lang_encoder' + lang)

            # language-pair representation [1, 2*dim_lang]: source row followed
            # by destination row, selected at run time via the id placeholders
            # so gradients update the representations actually used.
            lang_table = tf.concat([self.all_lang_rep[lang] for lang in self.langs], axis=0)
            pair_rows = tf.gather(lang_table, tf.stack([self.src_lang_id, self.dest_lang_id]))
            self.lang_rep = tf.reshape(pair_rows, [1, 2 * self.dim_lang])

            # model equation
            self.target_pred = self.get_model(self.encoder, self.decoder, self.source_words,
                                              self.lang_rep, self.lang_encoder)

            # squared loss
            self.loss = tf.reduce_sum(tf.square(self.target_words - self.target_pred))

            self.train_step = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss)

            self.init = tf.global_variables_initializer()
            self.saver = tf.train.Saver()

    def get_model(self, encoder, decoder, source_words, lang_rep, lang_encoder):
        """Return predicted target vectors, shape [None, emb_dim].

        Shapes as built in ``_build_graph``: encoder/decoder [emb_dim, emb_dim],
        source_words [None, emb_dim], lang_rep [1, 2*dim_lang],
        lang_encoder [2*dim_lang, emb_dim].
        """
        shared_source_words = tf.matmul(source_words, encoder)   # [None, emb_dim]
        shared_lang_rep = tf.matmul(lang_rep, lang_encoder)      # [1, emb_dim]
        # Broadcasting adds the language vector to every example row.
        shared_embedding_vector = shared_source_words + shared_lang_rep
        return tf.matmul(shared_embedding_vector, decoder)

    def get_feed(self, X, Y, src_lang, dest_lang):
        """Build the feed dict for one batch of the ``src_lang -> dest_lang`` pair.

        Raises:
            KeyError: if either language is not in ``self.langs``.
        """
        return {
            self.source_words: X,
            self.target_words: Y,
            self.src_lang_id: self.lang_ids[src_lang],
            self.dest_lang_id: self.lang_ids[dest_lang],
        }

    def train(self, train, train_lang_pairs, batch_size, validation, num_epochs=10, sess=None):
        """Train on several language pairs, interleaving their batches.

        Args:
            train: Sequence of per-pair datasets; ``train[i]`` is sliced along
                axis 0 and split into X = ``[:, 0]`` and Y = ``[:, 1]``
                (presumably shape [n, 2, emb_dim] -- TODO confirm with caller).
            train_lang_pairs: ``train_lang_pairs[i]`` is ``"src dest"``
                (space separated) for ``train[i]``.
            batch_size: Examples per batch (positive int).
            validation: Currently unused (TODO).
            num_epochs: Number of passes over the data.
            sess: Optional existing session; a new one is created and
                initialized when None.

        Returns:
            The session used, so the caller can continue or close it.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if sess is None:
            sess = tf.Session(graph=self.graph)
            sess.run(self.init)

        # Enough batches to cover the longest language pair, tail included.
        max_len = max([len(x) for x in train] or [0])
        num_batches = (max_len + batch_size - 1) // batch_size

        for epoch in range(num_epochs):
            for batch_number in range(num_batches):
                start = batch_number * batch_size
                # make batches: one batch of each language pair per step
                for i, lang_pair_data in enumerate(train):
                    cur_batch = lang_pair_data[start:start + batch_size]
                    if len(cur_batch) == 0:
                        # this pair is shorter than the longest one
                        continue

                    X = cur_batch[:, 0]
                    Y = cur_batch[:, 1]
                    lang = train_lang_pairs[i].split(' ')

                    _, train_loss = sess.run([self.train_step, self.loss],
                                             self.get_feed(X, Y, lang[0], lang[1]))

                    print("Batch:" + str(batch_number))
                    print("Loss:" + str(train_loss))
                    print(lang)
                    print("------------------------------")

            # save epoch
            if epoch and epoch % 10 == 0:
                self.saver.save(sess, self.ckpt_path + self.model_name + ".ckpt", global_step=epoch)

        self.saver.save(sess, self.ckpt_path + self.model_name + ".ckpt", global_step=num_epochs + 1)
        # TODO(review): export learned language representations here if needed.
        return sess



In [5]:
# Scratch check: slicing past the end of a list truncates the result instead
# of raising IndexError, so an over-long batch slice is still valid.
a = list(range(1, 5))
a[1:7]

[2, 3, 4]