In [1]:
import numpy as np
import random
from keras.layers import LSTM, Dense, Activation, Input, Lambda, Concatenate
from keras.models import Model
from keras import backend as K
from scipy.spatial import distance_matrix

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
len_sample = 12
dims = 32

In [3]:
document, summary = [], []
for i in range(len_sample):
    if i % 2 == 0:
        s = np.random.normal(loc = 0., scale = 0.2, size = (dims,))
        document.append(s)
    else:
        s = np.random.normal(loc = 1, scale = 0.1, size = (dims,))
        document.append(s)
        summary.append(s)
document = np.array(document)
summary = np.array(summary)

In [4]:
def build_model():
    d_lstm = 16
    inp = Input(shape=(None, dims))
    state_h = Input(shape=(d_lstm,))
    state_c = Input(shape=(d_lstm,))
    gen_summ = Input(shape=(None, dims))
    
    lstm = LSTM(d_lstm, activation = "tanh", name = "lstm_1", return_sequences=False, return_state=True)
    
    o1, lstm_state_h, lstm_state_c  = lstm(inp, initial_state = [state_h, state_c])
    o2, _, _  = lstm(gen_summ, initial_state = [state_h, state_c])
    
    diff = Lambda(lambda x: K.abs(x[0] - x[1]))([o1, o2])
    diff = Concatenate()([o1, o2, diff])
    o = Dense(2, activation = "linear")(diff)
    
    model_s = Model(inputs = [inp, state_h, state_c], outputs = [o1, lstm_state_h, lstm_state_c])
    model_q = Model(inputs = [inp, state_h, state_c, gen_summ], outputs = o)
    model_q.compile(optimizer = "adam", loss = "mse")
    
    return model_s, model_q

In [5]:
def act(logits, epsilon):
    if np.random.rand() <= epsilon:
        return random.randrange(2)
    else:
        return np.argmax(logits)

In [6]:
def get_reward(gen_summary, summary):
    acum_sims = 0.
    for i in range(len(gen_summary)):
        for j in range(len(summary)):
            acum_sims += int((gen_summary[i] == summary[j]).all())
    if len(gen_summary) == len(summary) and int(acum_sims) == len(summary):
        return 100
    return acum_sims / (len(gen_summary) + 1e-16 + 1)

In [7]:
n_episodes = 100000
epsilon = 1.0
epsilon_decay = 0.999
epsilon_min = 0.01
discount_factor = 0.99
TARGET_UPDATE = 1
model_s, model_q = build_model()
_, target_model_q = build_model()
target_model_q.set_weights(model_q.get_weights())

for i in range(n_episodes):
    # Leer el documento para sacar el cell_state, con eso se condicionan las futuras selecciones #
    c_state = np.zeros((1, 16)) + 1e-16
    h_state = np.zeros((1, 16)) + 1e-16
    gen_summary = [np.zeros(dims) + 1e-16]
    actions = []
    _, lstm_h_state, lstm_c_state = model_s.predict([np.array([document]), h_state, c_state])
    
    for j in range(len(document)):
        s = np.array([document[0 : j + 1]])
        next_s = np.array([document[0 : j + 2]])
        q = target_model_q.predict([s, lstm_h_state, lstm_c_state, np.array([gen_summary])])[0]
        next_gen_summary = gen_summary[:]
        
        if j < len(document)-1:
            a = act(q, epsilon)
            actions.append(a)
            if a == 1:
                next_gen_summary.append(document[j].tolist())
            next_q = model_q.predict([next_s, lstm_h_state, lstm_c_state, np.array([next_gen_summary])])[0]
            reward = get_reward(np.array(next_gen_summary[1:]), summary)
            td_target = reward + (discount_factor * np.amax(next_q))
            q[a] = td_target
            
        else:
            a = act(q, epsilon)     
            if a == 1:
                next_gen_summary.append(document[j].tolist())
            reward = get_reward(np.array(next_gen_summary[1:]), summary)
            q[a] = reward
        
        model_q.fit([s, h_state, c_state, np.array([gen_summary])], np.array([q]), verbose=0)
        gen_summary = next_gen_summary[:]
        
        if epsilon > epsilon_min:
            epsilon *= epsilon_decay

    if i % TARGET_UPDATE == 0:
        target_model_q.set_weights(model_q.get_weights()) 
        
    print("With: %s, reward: %.3f" % (str(actions), reward))

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.
With: [0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0], reward: 0.400
With: [1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1], reward: 0.333
With: [1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0], reward: 0.500
With: [0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0], reward: 0.667
With: [1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1], reward: 0.300
With: [0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0], reward: 0.750
With: [0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0], reward: 0.500
With: [1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0], reward: 0.333
With: [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1], reward: 0.333
With: [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], reward: 0.500
With: [0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1], reward: 0.429
With: [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1], reward: 0.400
With: [0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1], reward: 0.400
With: [0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0], reward: 0.625
With: [1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1], reward: 0.429
With: [1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0], reward: 0.500


With: [1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1], reward: 0.444
With: [1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1], reward: 0.400
With: [1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1], reward: 0.429
With: [1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0], reward: 0.333
With: [1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0], reward: 0.500
With: [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], reward: 0.500
With: [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], reward: 0.333
With: [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], reward: 0.333
With: [1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0], reward: 0.375
With: [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], reward: 0.333
With: [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], reward: 0.400
With: [1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0], reward: 0.500
With: [1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0], reward: 0.500
With: [0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1], reward: 0.500
With: [1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0], reward: 0.429
With: [1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0], reward: 0.429
With: [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], reward: 0.500
With: [1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1], reward: 0.444
With: [1, 

With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0], reward: 0.375
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1], reward: 0.455
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1], reward: 0.455
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1], reward: 0.455
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1], reward: 0.455
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1], reward: 0.455
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1], reward: 0.455
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1], reward: 0.455
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1], reward: 0.455
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1], reward: 0.455
With: [1, 

With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.500
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444
With: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0], reward: 0.444


KeyboardInterrupt: 

In [61]:
np.array(gen_summary)

array([[ 1.00000000e-16,  1.00000000e-16,  1.00000000e-16,
         1.00000000e-16,  1.00000000e-16,  1.00000000e-16,
         1.00000000e-16,  1.00000000e-16,  1.00000000e-16,
         1.00000000e-16,  1.00000000e-16,  1.00000000e-16,
         1.00000000e-16,  1.00000000e-16,  1.00000000e-16],
       [-1.95860523e-01, -1.48624463e-01,  2.05362670e-01,
        -3.56251706e-01,  3.48223227e-01,  2.25886253e-01,
         2.39534579e-01,  1.20167298e-01, -8.50210747e-02,
         6.33946391e-02,  6.65442656e-02,  1.96169176e-02,
         8.68051540e-02,  1.29136797e-01, -2.15384343e-02],
       [ 1.55930239e-01,  1.38189397e-02,  1.81326074e-01,
        -7.47221953e-03,  2.50736286e-01,  1.43381466e-01,
         3.53293954e-02, -2.20677282e-01,  1.83077639e-01,
         2.11772725e-03,  9.26315797e-02,  1.33916583e-01,
        -1.67737636e-01,  1.48920313e-01,  1.13809057e-01],
       [ 1.04235380e+00,  9.94427380e-01,  8.79743740e-01,
         9.69194678e-01,  1.06138919e+00,  8.63963548

In [62]:
summary

array([[0.89410697, 0.87822673, 0.81418227, 0.91827562, 0.86433149,
        1.02317608, 0.8692232 , 0.80247144, 0.97223869, 0.87627265,
        0.90806928, 1.04980662, 1.05057083, 1.02164273, 1.09206619],
       [0.85960949, 1.09164693, 0.99243014, 1.06337246, 1.04717202,
        1.01252072, 0.9654333 , 1.10102357, 1.05334094, 0.88277796,
        0.92554983, 0.90313714, 0.98633856, 0.95520528, 1.13562806],
       [1.0423538 , 0.99442738, 0.87974374, 0.96919468, 1.06138919,
        0.86396355, 1.07804401, 1.01316914, 0.92851321, 1.28972365,
        1.16725639, 0.86980464, 0.90320178, 0.88144851, 0.79701791]])

In [63]:
document

array([[-0.19586052, -0.14862446,  0.20536267, -0.35625171,  0.34822323,
         0.22588625,  0.23953458,  0.1201673 , -0.08502107,  0.06339464,
         0.06654427,  0.01961692,  0.08680515,  0.1291368 , -0.02153843],
       [ 0.89410697,  0.87822673,  0.81418227,  0.91827562,  0.86433149,
         1.02317608,  0.8692232 ,  0.80247144,  0.97223869,  0.87627265,
         0.90806928,  1.04980662,  1.05057083,  1.02164273,  1.09206619],
       [-0.00701799,  0.0058077 ,  0.33801603, -0.15885691,  0.08222441,
         0.02576878,  0.29602875,  0.21940847,  0.0065998 , -0.48437767,
        -0.00557872,  0.04545019, -0.02583858, -0.03731495,  0.06318176],
       [ 0.85960949,  1.09164693,  0.99243014,  1.06337246,  1.04717202,
         1.01252072,  0.9654333 ,  1.10102357,  1.05334094,  0.88277796,
         0.92554983,  0.90313714,  0.98633856,  0.95520528,  1.13562806],
       [ 0.15593024,  0.01381894,  0.18132607, -0.00747222,  0.25073629,
         0.14338147,  0.0353294 , -0.22067728, 