In [15]:
import numpy as np
import random
from keras.layers import LSTM, Dense, Activation, Input, Lambda, Concatenate
from keras.models import Model
from keras import backend as K
from keras.optimizers import Adam
from collections import deque

In [16]:
# First working version: advantage actor-critic (A2C) with experience replay, batch size 1

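## Build a synthetic summarization problem

Each document is a sequence of `len_sample` vectors of size `dims`. Every vector is, with equal probability, either background noise drawn around 0 (σ = 0.2) or a *relevant* vector drawn around 1 (σ = 0.1). The reference summary is the ordered list of the relevant vectors.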

In [17]:
SEED = 0          # fixes the synthetic data set; set to None for a fresh draw on every run
len_sample = 16   # vectors per document
dims = 16         # size of each vector
n_samples = 500   # number of (document, summary) pairs

# Only Python's `random` and NumPy's global RNG are seeded here.
if SEED is not None:
    random.seed(SEED)
    np.random.seed(SEED)

documents, summaries = [], []

for _ in range(n_samples):
    document, summary = [], []
    for _ in range(len_sample):
        # Fair coin: 0 -> background noise around 0; 1 -> "relevant" vector around 1,
        # which is also appended (same object) to the reference summary.
        if random.randint(0, 1) == 0:
            document.append(np.random.normal(loc = 0., scale = 0.2, size = (dims,)))
        else:
            relevant = np.random.normal(loc = 1, scale = 0.1, size = (dims,))
            document.append(relevant)
            summary.append(relevant)

    documents.append(document)
    summaries.append(summary)

documents = np.array(documents)  # shape (n_samples, len_sample, dims)

# Summaries have different lengths, so np.array(summaries) is a ragged array: NumPy >= 1.24
# raises ValueError for it. Build a 1-D object array element by element instead; each
# element is that document's list of summary vectors.
summaries_obj = np.empty(len(summaries), dtype=object)
for k, summ in enumerate(summaries):
    summaries_obj[k] = summ
summaries = summaries_obj

## Actor Critic Agent

In [18]:
class A2CAgent:
    """Advantage actor-critic (A2C) agent that builds a summary one document vector at a time.

    At every step the agent receives the document prefix read so far, the summary built so
    far, and an LSTM state ``(state_h, state_c)`` that initialises both encoder passes. It then
    picks an action in ``range(action_dim)``; the training loop treats action 1 as "copy the
    current vector into the summary".

    The actor and the critic share one LSTM that encodes both the document prefix and the
    summary, so ``doc_state_dim`` must equal ``summ_state_dim``. Each call to ``train_model``
    replays one random transition from an experience-replay buffer (batch size 1).
    """

    def __init__(self, doc_state_dim, summ_state_dim, action_dim=2, memory_size=None):
        """Create the agent and build its Keras models.

        Args:
            doc_state_dim: feature size of each document vector.
            summ_state_dim: feature size of each summary vector; must equal ``doc_state_dim``.
            action_dim: number of discrete actions.
            memory_size: replay-buffer capacity; ``train_model`` does nothing until the buffer
                is full. ``None`` (default) uses the notebook-level ``MEMORY_SIZE`` constant,
                read once here at construction time.

        Raises:
            ValueError: if ``doc_state_dim != summ_state_dim``.
        """
        if doc_state_dim != summ_state_dim:
            raise ValueError(
                "doc_state_dim ({}) must equal summ_state_dim ({}): a single LSTM encodes "
                "both the document and the summary".format(doc_state_dim, summ_state_dim))
        if memory_size is None:
            memory_size = MEMORY_SIZE

        self.doc_state_dim = doc_state_dim
        self.summ_state_dim = summ_state_dim
        self.action_dim = action_dim

        self.discount_factor = 0.99
        self.actor_lr = 0.001
        self.critic_lr = 0.005

        # Epsilon-greedy schedule: epsilon is multiplied by exploration_decay on every
        # training update and never drops below exploration_min.
        self.exploration_max = 1.0
        self.exploration_min = 0.01
        self.exploration_decay = 0.9995
        self.exploration_rate = self.exploration_max

        self.memory = deque(maxlen=memory_size)

        self.lstm_dims = 16
        self.reader, self.actor, self.critic = self.build_models()

    @staticmethod
    def _make_adam(learning_rate):
        """Create an Adam optimizer on both old and new Keras.

        Newer Keras only accepts ``learning_rate`` (``lr`` is rejected), while Keras < 2.3
        only accepts ``lr`` and raises TypeError for unknown keywords.
        """
        try:
            return Adam(learning_rate=learning_rate)
        except TypeError:
            return Adam(lr=learning_rate)

    def build_models(self):
        """Build the reader, actor and critic models around one shared LSTM.

        Returns:
            ``(reader, actor, critic)``:

            * ``reader``: ``[doc, h0, c0] -> [h, c]``, the LSTM state after reading ``doc``.
            * ``actor``: ``[doc, summ, h0, c0] ->`` softmax over ``action_dim`` actions.
            * ``critic``: ``[doc, summ, h0, c0] ->`` scalar state value.

            All three reuse the same LSTM layer object, so they share its weights.
        """
        doc_state = Input(shape=(None, self.doc_state_dim))
        summ_state = Input(shape=(None, self.summ_state_dim))
        state_h = Input(shape=(self.lstm_dims,))
        state_c = Input(shape=(self.lstm_dims,))

        lstm = LSTM(self.lstm_dims, activation="tanh", name="lstm_1",
                    return_sequences=False, return_state=True)

        # Encode the document prefix and the summary with the same LSTM, both from (h0, c0).
        o1, lstm_state_h, lstm_state_c = lstm(doc_state, initial_state=[state_h, state_c])
        o2, _, _ = lstm(summ_state, initial_state=[state_h, state_c])

        # Features: both encodings plus their element-wise absolute difference.
        diff = Lambda(lambda x: K.abs(x[0] - x[1]))([o1, o2])
        features = Concatenate()([o1, o2, diff])

        # Separate actor and critic heads on top of the shared encoder.
        actor_output = Dense(self.action_dim, activation="softmax")(features)
        critic_output = Dense(1, activation="linear")(features)

        reader = Model(inputs=[doc_state, state_h, state_c], outputs=[lstm_state_h, lstm_state_c])
        actor = Model(inputs=[doc_state, summ_state, state_h, state_c], outputs=actor_output)
        critic = Model(inputs=[doc_state, summ_state, state_h, state_c], outputs=critic_output)

        # The actor's cross-entropy target holds the clipped advantage at the taken action
        # (built in train_model), which turns the loss into an advantage-weighted log-likelihood.
        actor.compile(loss='categorical_crossentropy', optimizer=self._make_adam(self.actor_lr))
        critic.compile(loss="mse", optimizer=self._make_adam(self.critic_lr))

        return reader, actor, critic

    def get_action(self, doc_state, summ_state, state_h, state_c):
        """Pick an action.

        With probability ``exploration_rate`` the action is uniform random. Otherwise it is
        sampled from the actor's softmax policy.
        """
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_dim)
        policy = self.actor.predict([doc_state, summ_state, state_h, state_c])[0]
        return np.random.choice(self.action_dim, 1, p=policy)[0]

    def remember(self, doc_state, summ_state, state_h, state_c,
                 action, reward, next_doc_state, next_summ_state,
                 done):
        """Append one transition to the replay buffer (the oldest is evicted when full)."""
        self.memory.append((doc_state, summ_state, state_h, state_c,
                            action, reward, next_doc_state, next_summ_state,
                            done))

    @staticmethod
    def _td_advantage_and_target(reward, value, next_value, done, discount_factor):
        """Compute the one-step TD target and clipped advantage for a single transition.

        target    = reward                                 if done
                    reward + discount_factor * next_value  otherwise
        advantage = target - value, clipped to [0, 1]

        NOTE(review): the clip discards negative advantages. Such transitions give an
        all-zero actor target and therefore zero cross-entropy loss. Standard A2C keeps them.

        Returns:
            ``(advantage, target)``.
        """
        target = reward if done else reward + discount_factor * next_value
        advantage = max(min(1., target - value), 0.)
        return advantage, target

    def train_model(self):
        """Update actor and critic once from a single transition sampled from the buffer.

        Does nothing until the replay buffer is full. Each update also decays
        ``exploration_rate`` (floored at ``exploration_min``).
        """
        if not self.memory or len(self.memory) < self.memory.maxlen:
            return

        (doc_state, summ_state, state_h, state_c,
         action, reward, next_doc_state, next_summ_state,
         done) = self.memory[random.randint(0, len(self.memory) - 1)]
        # Stored summary states are plain lists of vectors; add the batch axis.
        summ_state = np.array([summ_state])
        next_summ_state = np.array([next_summ_state])

        value = float(self.critic.predict([doc_state, summ_state, state_h, state_c])[0][0])
        # The bootstrap value is only needed for non-terminal transitions.
        # Explanation: https://lilianweng.github.io/lil-log/2018/05/05/implementing-deep-reinforcement-learning-models.html#actor-critic
        next_value = 0. if done else float(
            self.critic.predict([next_doc_state, next_summ_state, state_h, state_c])[0][0])

        advantage_value, target_value = self._td_advantage_and_target(
            reward, value, next_value, done, self.discount_factor)

        self.exploration_rate = max(self.exploration_min,
                                    self.exploration_rate * self.exploration_decay)

        advantage = np.zeros((1, self.action_dim))
        advantage[0, action] = advantage_value
        target = np.array([[target_value]])

        self.actor.fit([doc_state, summ_state, state_h, state_c], advantage, epochs=1, batch_size = 1, verbose=0)
        self.critic.fit([doc_state, summ_state, state_h, state_c], target, epochs=1, batch_size = 1, verbose=0)

In [19]:
def get_reward(gen_summary, summary, normalize=False):
    """Score a generated summary against the reference summary.

    Each vector in ``gen_summary`` earns +1 if an identical vector (same shape and values)
    appears anywhere in ``summary``, and -1 otherwise. Duplicates in ``gen_summary`` are
    scored independently. Exact equality is appropriate because the agent copies document
    rows verbatim.

    Args:
        gen_summary: sequence of vectors chosen by the agent (list or 2-D array).
        summary: sequence of reference vectors.
        normalize: if True, divide the score by ``max(len(gen_summary), len(summary))``;
            returns 0.0 when both are empty. Defaults to False (raw score).

    Returns:
        The score as a float.
    """
    acum_sims = 0.
    for generated in gen_summary:
        found = any(np.array_equal(generated, reference) for reference in summary)
        acum_sims += 1 if found else -1
    if normalize:
        denom = max(len(gen_summary), len(summary))
        return acum_sims / denom if denom else 0.
    return acum_sims

In [20]:
EPISODES = 10
N_SAMPLES = n_samples
BATCH_SIZE = 1  # NOTE(review): not read by any visible cell; train_model replays one transition per call
MEMORY_SIZE = 1024  # replay-buffer capacity, read by A2CAgent when it is constructed
VERBOSE = False  # print the per-sample actions and final reward

# State and action sizes of the synthetic environment.
doc_state_dim = summ_state_dim = dims
action_dim = 2

# Build the A2C agent.
agent = A2CAgent(doc_state_dim, summ_state_dim, action_dim)


for e in range(EPISODES):

    scores = []

    for i in range(N_SAMPLES):
        document = documents[i]
        summary = summaries[i]

        # Read the whole document once; the resulting LSTM state conditions every decision.
        # The initial state is sized from the agent so it always matches its LSTM.
        init_c_state = np.zeros((1, agent.lstm_dims)) + 1e-16
        init_h_state = np.zeros((1, agent.lstm_dims)) + 1e-16
        lstm_h_state, lstm_c_state = agent.reader.predict([np.array([document]), init_h_state, init_c_state])

        # The summary starts with a near-zero placeholder row that is excluded when scoring.
        summ_state = [np.zeros(summ_state_dim) + 1e-16]
        actions = []

        for j in range(len(document)):
            doc_state = np.array([document[0 : j + 1]])
            # On the last step this slice is the whole document again.
            next_doc_state = np.array([document[0 : j + 2]])
            # Copy, because summ_state itself is stored in the replay memory below.
            next_summ_state = summ_state[:]

            action = agent.get_action(doc_state, np.array([summ_state]), lstm_h_state, lstm_c_state)
            if action == 1:
                next_summ_state.append(document[j].tolist())
            actions.append(1 if action == 1 else 0)

            reward = get_reward(np.array(next_summ_state[1:]), summary)
            done = j == len(document) - 1

            agent.remember(doc_state, summ_state, lstm_h_state, lstm_c_state,
                           action, reward, next_doc_state, next_summ_state, done)
            agent.train_model()

            summ_state = next_summ_state

        # A document's score is the reward of its final summary.
        scores.append(reward)
        if VERBOSE:
            print("Sample: %d, Actions: %s, Reward: %.4f" % (i, str(actions), reward))
    print("Avg reward on episode: %d -> %.3f" % (e, np.array(scores).mean()))

Avg reward on episode: 0 -> 1.912
Avg reward on episode: 1 -> 5.740
Avg reward on episode: 2 -> 5.956
Avg reward on episode: 3 -> 6.568
Avg reward on episode: 4 -> 6.960
Avg reward on episode: 5 -> 6.652
Avg reward on episode: 6 -> 6.766
Avg reward on episode: 7 -> 6.972
Avg reward on episode: 8 -> 6.910
Avg reward on episode: 9 -> 7.260


In [23]:
document = documents[4]
summary = summaries[4]

# Read the whole document once; the resulting LSTM state conditions every decision.
# The initial state is sized from the agent so it always matches its LSTM.
init_c_state = np.zeros((1, agent.lstm_dims)) + 1e-16
init_h_state = np.zeros((1, agent.lstm_dims)) + 1e-16
lstm_h_state, lstm_c_state = agent.reader.predict([np.array([document]), init_h_state, init_c_state])

# Same rollout as in training, but nothing is stored or trained.
# NOTE(review): get_action is still epsilon-greedy (exploration_rate never drops below
# agent.exploration_min) and samples from the policy, so repeated runs can differ.
summ_state = [np.zeros(summ_state_dim) + 1e-16]  # index 0 is a placeholder, never scored
actions = []

for j in range(len(document)):
    doc_state = np.array([document[0 : j + 1]])
    action = agent.get_action(doc_state, np.array([summ_state]), lstm_h_state, lstm_c_state)
    if action == 1:
        summ_state.append(document[j].tolist())
    actions.append(1 if action == 1 else 0)

# Reward of the final summary (the only value the per-step computation ever kept).
reward = get_reward(np.array(summ_state[1:]), summary)

In [24]:
# Decisions for the evaluated document, one per row: 1 = row copied into the summary, 0 = skipped.
actions

[0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0]

In [25]:
# The evaluated document: rows drawn around 1.0 are the ones that belong in the reference summary.
document

array([[-0.0926033 , -0.18200941, -0.13242487,  0.38948635,  0.06211748,
        -0.45572787, -0.06779616, -0.21168341,  0.08835016, -0.23462003,
        -0.1235629 , -0.05787389, -0.05209931,  0.08730754, -0.09298691,
         0.27572553],
       [ 0.32847872,  0.13556859, -0.15419601,  0.04420499,  0.0993616 ,
         0.05395936, -0.04386647, -0.20157574, -0.0125658 , -0.14496135,
         0.36515771, -0.1773753 ,  0.12540683,  0.21266409,  0.04859447,
        -0.13192555],
       [ 1.03383392,  0.97858038,  1.06356421,  0.94733702,  0.97512407,
         0.97770182,  1.0662604 ,  1.13132905,  1.09529157,  0.94120605,
         0.98215044,  0.85283281,  0.9238479 ,  1.15825165,  0.94103274,
         0.97156026],
       [ 0.98122508,  1.03821773,  0.9054462 ,  0.89277277,  0.96816157,
         1.05397734,  1.18913346,  0.84027841,  1.08830363,  0.92971897,
         1.10815857,  1.0450787 ,  1.12114833,  0.96708264,  0.96321032,
         1.01580441],
       [-0.19205855, -0.34768088, -0

In [26]:
# Reference summary for the same document: its around-1.0 rows, in document order.
summary

[array([1.03383392, 0.97858038, 1.06356421, 0.94733702, 0.97512407,
        0.97770182, 1.0662604 , 1.13132905, 1.09529157, 0.94120605,
        0.98215044, 0.85283281, 0.9238479 , 1.15825165, 0.94103274,
        0.97156026]),
 array([0.98122508, 1.03821773, 0.9054462 , 0.89277277, 0.96816157,
        1.05397734, 1.18913346, 0.84027841, 1.08830363, 0.92971897,
        1.10815857, 1.0450787 , 1.12114833, 0.96708264, 0.96321032,
        1.01580441]),
 array([0.93718471, 0.75003822, 0.94647728, 1.2178097 , 0.95452597,
        1.12183398, 0.80789445, 1.02298515, 1.11002949, 0.98113608,
        0.97309826, 1.02898817, 1.07190151, 1.01635855, 0.92094405,
        1.07269246]),
 array([1.03581015, 1.14293062, 1.22965493, 0.956355  , 0.87186639,
        1.04662282, 0.98877225, 0.89037612, 1.03423343, 1.08351568,
        1.02447254, 1.02152896, 0.98753297, 0.97952194, 1.08474729,
        1.00228203]),
 array([0.9986617 , 1.0729275 , 0.89186723, 0.8971852 , 1.08159337,
        0.78105736, 0.920450

In [27]:
# Rows the agent copied into its summary; index 0 is the near-zero placeholder, so it is skipped.
np.asarray(summ_state[1:])

array([[1.03383392, 0.97858038, 1.06356421, 0.94733702, 0.97512407,
        0.97770182, 1.0662604 , 1.13132905, 1.09529157, 0.94120605,
        0.98215044, 0.85283281, 0.9238479 , 1.15825165, 0.94103274,
        0.97156026],
       [0.98122508, 1.03821773, 0.9054462 , 0.89277277, 0.96816157,
        1.05397734, 1.18913346, 0.84027841, 1.08830363, 0.92971897,
        1.10815857, 1.0450787 , 1.12114833, 0.96708264, 0.96321032,
        1.01580441],
       [0.93718471, 0.75003822, 0.94647728, 1.2178097 , 0.95452597,
        1.12183398, 0.80789445, 1.02298515, 1.11002949, 0.98113608,
        0.97309826, 1.02898817, 1.07190151, 1.01635855, 0.92094405,
        1.07269246],
       [1.03581015, 1.14293062, 1.22965493, 0.956355  , 0.87186639,
        1.04662282, 0.98877225, 0.89037612, 1.03423343, 1.08351568,
        1.02447254, 1.02152896, 0.98753297, 0.97952194, 1.08474729,
        1.00228203],
       [0.9986617 , 1.0729275 , 0.89186723, 0.8971852 , 1.08159337,
        0.78105736, 0.92045049, 

In [28]:
# Final score: +1 per copied row found in the reference summary, -1 per row that is not.
get_reward(gen_summary=summ_state[1:], summary=summary)

6.0