# initializations

In [1]:
import networkx as nx
from graph_embeddings.data_loader import DataLoader
import os
import pickle
import numpy as np
from random import choice
import torch
import torch.nn as nn
import torch.nn.functional as F

import gym
import matplotlib.pyplot as plt
from copy import copy
from enum import Enum
from collections import deque
import torch.optim as optim
from random import random
from copy import deepcopy
from tqdm import tqdm
from operator import add


# observing dataset 

In [2]:
data_loader = DataLoader(dataset='MetaQA', reverse_rel=True)

In [3]:
entities_voc, relations_voc = data_loader.load_entity_relations_vocab()
entities_inv_voc = {v: k for k, v in entities_voc.items()}

In [4]:
relations_inv_voc = {v: k for k, v in relations_voc.items()}
relations_voc


{'in_language_reverse': 0,
 'release_year': 1,
 'has_genre': 2,
 'has_tags_reverse': 3,
 'directed_by_reverse': 4,
 'written_by': 5,
 'written_by_reverse': 6,
 'has_imdb_rating_reverse': 7,
 'starred_actors': 8,
 'in_language': 9,
 'release_year_reverse': 10,
 'starred_actors_reverse': 11,
 'has_imdb_votes_reverse': 12,
 'has_tags': 13,
 'directed_by': 14,
 'has_genre_reverse': 15,
 'has_imdb_votes': 16,
 'has_imdb_rating': 17}

In [5]:
#entities_voc

# creating embeddings from KGE

## loading trained kge model

In [6]:
from kg_env import load_kge_model
import torch

model = load_kge_model(
    dataset_name='MetaQA',
    model_name='TuckER',
    ent_vec_dim=200,
    rel_vec_dim=200,
    loss_type='CE',
    device=('cuda' if torch.cuda.is_available() else 'cpu'),
    path='TuckER_MetaQA',
    input_dropout=0.3,
    hidden_dropout1=0.4,
    hidden_dropout2=0.5,
    l3_reg=0.2,
)

operating on cpu
building tucker model for embedding generation


In [14]:
# 'Adrian Moat': 970,
#'Adrian Pasdar': 971,
# 'Adrian Rawlins': 972,
# 'Adrian Shergold': 973

model.E(torch.Tensor([970, 971, 972]).long()).shape


torch.Size([3, 200])

## generating embeddings

In [24]:
from kg_env import generate_entity_embeddings
import pickle

ent_emb_dict = generate_entity_embeddings(model, entities_voc)
rel_emb_dict = generate_entity_embeddings(model, relations_voc)

with open('data/MetaQA/entity_emb.pickle', 'wb') as f:
    pickle.dump(ent_emb_dict, f)
    

with open('data/MetaQA/relation_emb.pickle', 'wb') as f:
    pickle.dump(rel_emb_dict, f)
    

    


## load embeddings

In [8]:
with open('data/MetaQA/entity_emb.pickle', 'rb') as f:
    ent_emb_dict = pickle.load(f)
    

with open('data/MetaQA/relation_emb.pickle', 'rb') as f:
    rel_emb_dict = pickle.load(f)
    

In [9]:
ent_emb_dict[list(ent_emb_dict.keys())[0]].shape

(200,)

## generating question embeddings

## analyze qa dataset

In [10]:
def load_raw_qa(path, mode='train_1hop'):
    with open(os.path.join(path, f'qa_{mode}.txt'), 'r') as f:
        text = f.read().strip().split('\n')
    return text

qa_train_raw = load_raw_qa('./data/QA_data/MetaQA/')

In [11]:
import re


def extract_question_entity_target(raw_questions):
    all_questions = []
    all_entities = []
    all_targets = []

    for raw_q in raw_questions:
        question, targets = raw_q.split('\t')
        entity = re.findall('\[.*?\]', question)[0] \
            .replace('[', '') \
            .replace(']', '')

        # todo: should I replace entity with some special token?
        question = question.replace(']', '').replace('[', '')
        targets = targets.strip().split('|')
        all_questions.append(question)
        all_targets.append(targets)
        all_entities.append(entity)

    return all_questions, all_entities, all_targets


qa_train_raw = load_raw_qa('./data/QA_data/MetaQA/', mode='train_3hop')
questions, _, _ = extract_question_entity_target(qa_train_raw)
print(max([len(s) for s in questions]))

151


In [47]:
from typing import List, Dict
from transformers import BertTokenizer, BertModel
import torch
from tqdm import tqdm


device = 'cuda' if torch.cuda.is_available() else 'cpu'
MODELS_CACHE_PATH = '/projects/academic/kjoseph/navid/models/cache'


def tokenize_sentences(sentences: List[str], tokenizer_path=MODELS_CACHE_PATH, max_len=160, batch_size=128):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=MODELS_CACHE_PATH)
    tokenizer_results = []
    pbar = tqdm(total=len(sentences)//batch_size)
    i = 0
    while i < len(sentences):
        batch = sentences[i:i+batch_size]
        tokenizer_res = tokenizer(
            batch, return_tensors='pt', max_length=max_len, truncation=True, padding='max_length')
        tokenizer_results.append(tokenizer_res)
        i += batch_size
        pbar.update(1)
    return tokenizer_results

def get_batch_embeddings(tokenizer_results):
    i = 0
    embeddings = []
    bert = BertModel.from_pretrained('bert-base-uncased', cache_dir=MODELS_CACHE_PATH).to(device)
    bert.eval()
    
    for tokenizer_dict in tqdm(tokenizer_results):
        for key in tokenizer_dict:
            tokenizer_dict[key] = tokenizer_dict[key].to(device)
        
        with torch.no_grad():
            outputs = bert(**tokenizer_dict)
        embeddings.append(outputs.pooler_output.detach().cpu().numpy())
    
    return embeddings



In [None]:
tokenizer_res = tokenize_sentences(questions)
print(tokenizer_res['input_ids'].shape)

In [48]:
embeddings = get_batch_embeddings(tokenizer_res, batch_size=1)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).





  0%|                                                                                                                                                                             | 0/14274 [00:00<?, ?it/s][A[A[A[A[A




  0%|                                                               

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.0486,  0.0761, -0.0288,  ..., -0.1312, -0.0418, -0.0443],
         [ 0.0131, -0.1777, -0.0614,  ..., -0.3306,  0.1801,  0.0568],
         [ 0.0847,  0.1649,  0.0865,  ...,  0.2064, -0.0259, -0.0863],
         ...,
         [ 0.0565,  0.0920,  0.0201,  ..., -0.1220, -0.0156, -0.0208],
         [ 0.0565,  0.0920,  0.0201,  ..., -0.1220, -0.0156, -0.0208],
         [ 0.0565,  0.0920,  0.0201,  ..., -0.1220, -0.0156, -0.0208]]],
       grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[-1.1188e-02, -2.0919e-01, -2.2745e-01, -1.1578e-01,  1.6160e-01,
          1.9390e-01,  2.5668e-01, -7.2665e-02, -5.9200e-02, -1.7820e-01,
          2.3287e-01, -1.2546e-03, -8.6979e-02,  8.9313e-02, -1.3310e-01,
          5.0997e-01,  2.2478e-01, -4.6840e-01,  7.0964e-03, -3.6931e-02,
         -2.4895e-01,  5.7112e-02,  4.6074e-01,  2.9635e-01,  1.3442e-01,
          7.6506e-02, -1.2433e-01, -3.5329e-02,  1.8925e-01,  2.2929






  0%|                                                                                                                                                                   | 2/14274 [00:01<3:06:13,  1.28it/s][A[A[A[A[A

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.0483,  0.0658, -0.0396,  ..., -0.0925, -0.0512, -0.0151],
         [ 0.0408, -0.3234, -0.2024,  ..., -0.2671,  0.1068, -0.0395],
         [ 0.0705, -0.1757, -0.1514,  ..., -0.5613, -0.1411, -0.0526],
         ...,
         [ 0.0405, -0.0021,  0.0159,  ..., -0.0421, -0.0428,  0.0964],
         [ 0.0405, -0.0021,  0.0159,  ..., -0.0421, -0.0428,  0.0964],
         [ 0.0405, -0.0021,  0.0159,  ..., -0.0421, -0.0428,  0.0964]]],
       grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[-3.0601e-03, -2.1779e-01, -2.0223e-01, -9.0303e-02,  1.1932e-01,
          1.8957e-01,  2.5946e-01, -7.7788e-02, -7.1099e-02, -1.6559e-01,
          2.2084e-01, -1.8129e-02, -1.0156e-01,  8.3201e-02, -1.3346e-01,
          4.9150e-01,  2.0551e-01, -4.5526e-01,  2.3696e-02, -3.3467e-02,
         -2.4801e-01,  6.0791e-02,  4.5524e-01,  3.0210e-01,  1.1699e-01,
          7.6753e-02, -1.2701e-01, -1.6016e-02,  1.9760e-01,  2.1011






  0%|                                                                                                                                                                   | 3/14274 [00:02<3:01:31,  1.31it/s][A[A[A[A[A

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.0443,  0.0958, -0.0258,  ..., -0.1095, -0.0524, -0.0792],
         [ 0.0289, -0.2903, -0.0484,  ..., -0.3300,  0.1950, -0.2918],
         [ 0.3139,  0.2377,  0.0223,  ..., -0.1324,  0.3695, -0.1022],
         ...,
         [-0.0213, -0.0294,  0.0015,  ...,  0.0400, -0.0248, -0.0330],
         [-0.0213, -0.0294,  0.0015,  ...,  0.0400, -0.0248, -0.0330],
         [-0.0213, -0.0294,  0.0015,  ...,  0.0400, -0.0248, -0.0330]]],
       grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[-1.1580e-02, -2.2915e-01, -2.1842e-01, -1.1189e-01,  1.5850e-01,
          2.0361e-01,  2.5357e-01, -1.0500e-01, -8.0442e-02, -1.5308e-01,
          2.3498e-01, -4.6535e-02, -9.5518e-02,  9.9269e-02, -1.4516e-01,
          4.9914e-01,  2.1731e-01, -4.7995e-01,  1.4769e-02, -3.3811e-02,
         -2.4594e-01,  4.8498e-02,  4.6983e-01,  3.0698e-01,  1.1170e-01,
          7.3765e-02, -1.1803e-01, -2.7452e-02,  2.2611e-01,  2.2226






  0%|                                                                                                                                                                   | 4/14274 [00:03<3:00:54,  1.31it/s][A[A[A[A[A

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.0427,  0.0858, -0.0220,  ..., -0.1157, -0.0512, -0.0309],
         [ 0.0114, -0.2390, -0.0557,  ..., -0.3332,  0.1852,  0.0751],
         [ 0.0879,  0.1554,  0.1251,  ...,  0.1847, -0.0294, -0.0489],
         ...,
         [-0.0561,  0.0998, -0.0271,  ..., -0.2057,  0.0201, -0.1525],
         [-0.0561,  0.0998, -0.0271,  ..., -0.2057,  0.0201, -0.1525],
         [-0.0561,  0.0998, -0.0271,  ..., -0.2057,  0.0201, -0.1525]]],
       grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[-1.3738e-02, -2.0784e-01, -2.1963e-01, -1.1118e-01,  1.5323e-01,
          1.8332e-01,  2.5949e-01, -7.3172e-02, -5.0565e-02, -1.7314e-01,
          2.2505e-01, -3.7901e-05, -8.0733e-02,  7.8247e-02, -1.2294e-01,
          5.1120e-01,  2.2906e-01, -4.6926e-01,  1.0668e-02, -3.9428e-02,
         -2.3985e-01,  7.2482e-02,  4.5569e-01,  2.9660e-01,  1.3645e-01,
          7.0573e-02, -1.2183e-01, -3.9758e-02,  1.9024e-01,  2.2503




In [49]:
embeddings[0].shape

(1, 768)

## gather all questions for embeddings generation

In [10]:
all_questions = []

for ds in ['train', 'dev', 'test']:
    for nhop in ['1', '2', '3']:
        raw = load_raw_qa('./data/QA_data/MetaQA/', mode=f'{ds}_{nhop}hop')
        questions, _, _ = extract_question_entity_target(raw)
        all_questions.extend(questions)

print(len(all_questions))

746105


In [11]:
all_questions[100]

'which movies can be described by robert schwentke'

In [56]:
with open('all-questions.pickle', 'wb') as f:
    pickle.dump(all_questions, f)

## tokenize and generate embeddings

In [None]:
tokenizer_res = tokenize_sentences(all_questions, batch_size=256)
embeddings = get_batch_embeddings(tokenizer_res)


with open('./embeddings.pickle', 'wb') as f:
    pickle.dump(embeddings, f)

## load question embeddings

In [12]:
with open('q-emb-dict.pickle', 'rb') as f:
    question_embs = pickle.load(f)

In [13]:
question_embs['which movies can be described by robert schwentke'].shape

(768,)

In [74]:
np.min(np.array(list(question_embs.values())))

-1.0

# Knowledge graph

In [19]:
graph = data_loader.build_graph()

In [20]:
graph.adj[4]

AtlasView({767: {'relation_id': 0, 'relation_text': 'directed_by'}, 7125: {'relation_id': 10, 'relation_text': 'has_genre'}, 184: {'relation_id': 2, 'relation_text': 'release_year'}})

In [21]:
# same as graph.adj
graph[1024]

AtlasView({1121: {'relation_id': 12, 'relation_text': 'written_by'}, 167: {'relation_id': 2, 'relation_text': 'release_year'}, 15595: {'relation_id': 6, 'relation_text': 'in_language'}, 31150: {'relation_id': 0, 'relation_text': 'directed_by'}, 40514: {'relation_id': 14, 'relation_text': 'has_tags'}})

In [22]:
graph.nodes[1024]

{'entity_text': 'has_tags_reverse'}

In [23]:
for nbr, eattr in graph[5].items():
    print(nbr, eattr)

37824 {'relation_id': 12, 'relation_text': 'written_by'}
165 {'relation_id': 2, 'relation_text': 'release_year'}
28874 {'relation_id': 0, 'relation_text': 'directed_by'}


## moving in graph

In [24]:
def traverse(entity_name, relation):
    entity_id = entities_voc[entity_name]
    neighbors = graph[entity_id]
    for nbr_id, eattr in neighbors.items():
        if eattr['relation_text'] == relation:
            return entities_inv_voc[nbr_id]
        
traverse('After the Rain', 'written_by')

'Akira Kurosawa'

## creating knowledge graph environment

In [36]:
import gym
from gym import spaces

class KGEnv(gym.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self, n_relations, observation_space, q_embeddings, ent_embeddings, rel_embeddings,
                 entities_vocab, entities_inv_vocab, rel_inv_voc, rel_voc, mode='train', nhops=[1,2,3]):
        super(KGEnv, self).__init__()
        
        self.entities_vocab = entities_vocab
        self.entities_inv_vocab = entities_inv_vocab
        self.relations_inv_voc = rel_inv_voc
        self.action_space = spaces.Discrete(n_relations)
        self.rel_embeddings = rel_embeddings
        self.observation_space = spaces.Box(low=-3, high=3, shape=(observation_space,))
        self.q_embeddings = q_embeddings
        self.ent_embeddings = ent_embeddings
        self.relations_voc = rel_voc
        self.questions = []
        self.entities = []
        self.targets = []
        
        for nhop in [str(h) for h in nhops]:
            raw = load_raw_qa('./data/QA_data/MetaQA/', mode=f'{mode}_{nhop}hop')
            questions, entities, targets = extract_question_entity_target(raw)
            self.questions.extend(questions)
            self.entities.extend(entities)
            self.targets.extend(targets)
            
        print(f"Total env size in mode {mode}: {len(self.targets)}")

        self.entities_idx = [entities_voc[t] for t in self.entities]
        self.targets_idx = [[entities_voc[t] for t in ts] for ts in self.targets]
        self.env_size = len(self.targets)
        self.goal = None
        self.hops = 0
        self.current_q = self.current_ent = None
        self.ent_hist = []
        self.rel_hist = []
    
    def step(self, action):
        self.hops += 1
        next_ent_id = self.traverse(self.current_ent, action)
        if next_ent_id == -1:
            next_ent_id = self.entities_vocab[self.current_ent]
            
        reward = 0
        done = False
        
        #TODO:
        if next_ent_id in self.goal:
            done = True
            reward = 1
            
        if self.hops > 3:
            done = True
            
        e2_emb = self.ent_embeddings[next_ent_id]
        e1_emb = self.ent_embeddings[self.entities_vocab[self.current_ent]]
        r_emb = self.rel_embeddings[action]
        
#         self.rel_hist = r_emb if self.rel_hist is None else (r_emb + self.rel_hist)/2
#         self.ent_hist = (e1_emb + e2_emb)/2
        
        self.rel_hist.append(r_emb)
        self.ent_hist.append(e2_emb)
        
        self.current_ent = self.entities_inv_vocab[next_ent_id]
        next_state = self.build_state_vec(self.current_q, next_ent_id, self.ent_hist, self.rel_hist)
        
        return next_state, reward, done, {}
    
    
    def build_state_vec(self, question: str, entity: int, ent_hist, rel_hist):
        q_emb = self.q_embeddings[question]
        e_emb = self.ent_embeddings[entity]
        
        ent_track = len(ent_hist)
        ent_hist_emb = sum(ent_hist)/ent_track
        
        rel_track = len(rel_hist)
        rel_hist_emb = sum(rel_hist)/rel_track
        
        return np.concatenate([q_emb, e_emb, ent_hist_emb, rel_hist_emb])
        
    def traverse(self, entity_name, relation):
        entity_id = entities_voc[entity_name]
        neighbors = graph[entity_id]
        possible_moves = []
        for nbr_id, eattr in neighbors.items():
            if eattr['relation_id'] == relation:
                possible_moves.append(nbr_id)
        
        if len(possible_moves):
            return choice(possible_moves)
        return -1
        
    
    def reset(self):
        self.hops = 0
        self.rel_hist = [np.zeros(200)]
        self.ent_hist = []
        init_state = np.random.randint(low=0, high=self.env_size-1)
        self.goal = self.targets_idx[init_state]
        self.current_q = self.questions[init_state]
        self.current_ent = self.entities[init_state]
        e_emb = self.ent_embeddings[self.entities_idx[init_state]]
        self.ent_hist.append(e_emb)
        return self.build_state_vec(self.current_q, self.entities_idx[init_state], self.ent_hist, self.rel_hist)
    
    
    def render(self, mode='human', close=False):
        print(self.current_q)
        print(f"current entity: {self.current_ent}")
        entity_id = entities_voc[self.current_ent]
        neighbors = graph[entity_id]
        
        print("available moves:")
        for nbr_id, eattr in neighbors.items():
            print(f"--{eattr['relation_text']}({eattr['relation_id']})--> {self.entities_inv_vocab[nbr_id]}")
    
    
kg_env = KGEnv(len(relations_voc), observation_space=768+200+200+200, q_embeddings=question_embs,
               ent_embeddings=ent_emb_dict, rel_embeddings=rel_emb_dict, entities_vocab=entities_voc, entities_inv_vocab=entities_inv_voc,
                rel_inv_voc=relations_inv_voc, rel_voc=relations_voc)
s = kg_env.reset()
s.shape

Total env size in mode train: 667874


(1368,)

In [371]:
def manual_play_in_env():
    state = kg_env.reset()
    done = False
    
    while not done:
        kg_env.render()        
        action = int(input())
        state, reward, done, _ = kg_env.step(action)
        
manual_play_in_env()

when was the film Be with Me released
current entity: Be with Me
available moves:
--written_by(11)--> Eric Khoo
--written_by(11)--> Theresa Poh Lin Chan
--release_year(16)--> 2005
--has_genre(13)--> Drama
--in_language(8)--> English
16


# RL agent definition

## QActor-critic

In [37]:

class QActorCritic:
    def __init__(self, tau, buffer_size, gamma, actor_lr, critic_lr, state_size, action_size, batch_size,
                 device, warmup_step=1000, verbose=0, critic_warmup=5000, eps=0.5, min_eps=0.05):

        self.max_buffer_size = buffer_size
        self.gamma = torch.tensor(gamma)
        
        self.buffer = deque()
        self.action_size = action_size
        self.state_size = state_size
        self.critic_warmup = critic_warmup
        self.batch_size = batch_size
        self.device = device
        self.verbose = verbose
        self.tau = tau
        self.min_eps = min_eps
        self.eps = eps
        
        self.q = FCNet(state_size, action_size)
        self.q_target = FCNet(state_size, action_size)
        self.q.to(self.device)
        self.q_target.to(self.device)
        self.q_target.load_state_dict(self.q.state_dict())
        self.critic_opt = optim.Adam(self.q.parameters(), lr=critic_lr)
        
        self.actor = FCNet(state_size, action_size, last_activation='softmax')
        self.actor_target = FCNet(state_size, action_size, last_activation='softmax')
        self.actor.to(self.device)
        self.actor_target.to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=actor_lr)
        
        
        self.update_step = 0
        self.warmup_step = warmup_step
        
        
        
    def get_q_val(self, qnet, state, action=None):
        if action is None:
            return qnet(state)
        return qnet(state).gather(1, action.reshape(-1,1).long())
        
    def add_to_buffer(self, experience):
        if len(self.buffer) >= self.max_buffer_size:
            self.buffer.pop()
        self.buffer.appendleft(experience)
        
    def _np_to_tensor(self, nparrs):
        return [torch.from_numpy(arr).float().to(self.device) for arr in nparrs]
        
    def sample_minibatch(self):
        idxs = np.random.randint(0, len(self.buffer), self.batch_size)
        sample = [self.buffer[i] for i in idxs]
        
        states, actions, rewards, next_states, dones = list(map(lambda x: np.array(x, dtype=np.float64), zip(*sample)))
        return tuple(self._np_to_tensor([states, actions, rewards, next_states, dones]))
        
    def get_action(self, state, det=False):
        """
        if det=False it will sample based on output probabilities of policy function
        """
#         state = torch.from_numpy(state.reshape(-1, self.state_size)).float().to(self.device)
        with torch.no_grad():
            if not det:
                dist = Categorical(self.policy(state))
                actions = dist.sample()
            else:
                actions = torch.argmax(self.policy(state), dim=1)
            
            return actions
        
    def get_target_action(self, state, det=False):
        with torch.no_grad():
            if not det:
                dist = Categorical(self.actor_target(state))
                actions = dist.sample()
            else:
                actions = torch.argmax(self.actor_target(state), dim=1)
            
            return actions
        
    def get_single_action(self, state, det=False):
        if self.update_step < self.critic_warmup:
            return np.random.randint(self.action_size)
        
        if (random() < self.eps and not det):
            return np.random.randint(self.action_size)
        
        state = torch.from_numpy(state.reshape(-1, self.state_size)).float().to(self.device)
        with torch.no_grad():
            if not det:
                dist = Categorical(self.policy(state))
                action = dist.sample()
            else:
                action = torch.argmax(self.policy(state), dim=1)
            
        return action.cpu().numpy()[0]
    
    def policy(self, state):
        return self.actor(state)
    
    def synchronize(self):
        with torch.no_grad():
            params = self.q.parameters()
            targ_params = self.q_target.parameters()
            for p, p_targ in zip(params, targ_params):
                p_targ.data.mul_(1 - self.tau)
                p_targ.data.add_(self.tau * p.data)
                
            params = self.actor.parameters()
            targ_params = self.actor_target.parameters()
            for p, p_targ in zip(params, targ_params):
                p_targ.data.mul_(1 - self.tau)
                p_targ.data.add_(self.tau * p.data)
            
        
    def update_critic(self, batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones):
        with torch.no_grad():
            last_step_target = batch_dones * batch_rewards
                        
            next_actions = self.get_target_action(batch_next_states, det=False)
            
            middle_step_target = (1-batch_dones)*(batch_rewards + self.gamma*self.get_q_val(self.q_target, batch_next_states, next_actions).squeeze())
            targets =  last_step_target + middle_step_target
        
        predictions = self.get_q_val(self.q, batch_states, batch_actions).squeeze()
        loss = F.mse_loss(predictions, targets)
        self.critic_opt.zero_grad()
        loss.backward()
        self.critic_opt.step()
        
        self.synchronize()
        
        return loss.detach().cpu().numpy()
    

    def update_actor(self, batch_states):
#         with torch.no_grad():
#             q_vals = self.get_q_val(self.q_target, batch_states, batch_actions).squeeze()
        dist = self.policy(batch_states)
        actions = dist.sample()
        with torch.no_grad():
            q_vals = self.get_q_val(self.q, batch_states, actions).squeeze()
            
        action_probs = self.policy(batch_states)
        action_logprobs = torch.clip(dist.log_prob(actions), min=-10, max=10)
        
        loss = -torch.mean(q_vals*action_logprobs)
        
        self.actor_opt.zero_grad()
        loss.backward()
        self.actor_opt.step()
        
        return loss.detach().cpu().numpy()
    
    
    def update(self, state, action, reward, next_state, done):
        self.update_step += 1
    
        self.add_to_buffer([state, action, reward, next_state, done])
        
        #TODO: warmup. should be configurable.
        if len(self.buffer) < self.warmup_step:
            return None, None         


        batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones = self.sample_minibatch()

        if self.update_step < self.critic_warmup:
            critic_loss = self.update_critic(batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones)
            return None, critic_loss

        if done:
            self.eps = max(self.min_eps, self.eps*0.999)
        
        critic_loss = self.update_critic(batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones)
        actor_loss = self.update_actor(batch_states)
        
        return actor_loss, critic_loss



## Double DQN

In [38]:
class FCNet(nn.Module):

    def __init__(self, input_size, output_size, last_activation=None):
        super(FCNet, self).__init__()
        
        self.fc1 = nn.Linear(input_size, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 256)
        self.fc5 = nn.Linear(256, output_size)
        self.last_activation = last_activation
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        
        if self.last_activation == 'softmax':
            x = F.softmax(x, dim=1)
        
        return x
    
class DoubleDQN:
    def __init__(self, buffer_size, gamma, lr, eps, state_size, action_size, batch_size,
                 device, tau=0.99, eps_decay=0.999, min_eps=0.05, warmup_step=1000):
        self.max_buffer_size = buffer_size
        self.gamma = torch.tensor(gamma)
        self.lr = lr
        self.eps = eps
        
        #self.buffer = deque()
        self.buffer = [None]*self.max_buffer_size
        self.buf_idx = 0
        
        self.action_size = action_size
        self.device = device
        self.q = FCNet(state_size, action_size)
        self.q_target = FCNet(state_size, action_size)
        self.q.to(self.device)
        self.q_target.to(self.device)
        self.batch_size = batch_size
        self.optimizer = optim.Adam(self.q.parameters(), lr=lr)
        self.update_step = 0
        self.tau = tau
        self.min_eps = min_eps
        self.eps_decay = eps_decay
        self.warmup_steps = warmup_step
        
       
        
    def get_q_val(self, qnet, state, action=None):
        if action is None:
            return qnet(state)
        return qnet(state).gather(1, action.reshape(-1,1).long())
        
    def add_to_buffer(self, experience):
#         if len(self.buffer) >= self.max_buffer_size:
#             self.buffer.pop()
#         self.buffer.appendleft(experience)
        self.buffer[self.buf_idx%self.max_buffer_size] = experience
        self.buf_idx += 1
        
    def _np_to_tensor(self, nparrs):
        return [torch.from_numpy(arr).float().to(self.device) for arr in nparrs]
        
    def sample_minibatch(self):
        #idxs = np.random.randint(0, len(self.buffer), self.batch_size)
        idxs = np.random.randint(0, min(self.max_buffer_size, self.buf_idx), self.batch_size)
        sample = [self.buffer[i] for i in idxs]
        
        states, actions, rewards, next_states, dones = list(map(lambda x: np.array(x, dtype=np.float64), zip(*sample)))
        return tuple(self._np_to_tensor([states, actions, rewards, next_states, dones]))
    
    def synchronize(self):
        with torch.no_grad():
            params = self.q.parameters()
            targ_params = self.q_target.parameters()
            for p, p_targ in zip(params, targ_params):
                p_targ.data.mul_(1 - self.tau)
                p_targ.data.add_(self.tau * p.data)
        
    def get_action(self, state, det=False):
        state = torch.from_numpy(state).float().to(self.device)
        with torch.no_grad():
            act = np.argmax(self.get_q_val(self.q, state).cpu().numpy())
        if det:
            return act
        if random() < self.eps:
            return np.random.randint(0, self.action_size)
        else:
            #TODO: sample from q dist?
            return act
        
    def update(self, state, action, reward, next_state, done):
        self.update_step += 1
        
        self.add_to_buffer([state, action, reward, next_state, done])
        
        if self.buf_idx < self.warmup_steps:
            return
        
        if done:
            self.eps = max(self.min_eps, self.eps * self.eps_decay)
        
        
        batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones = self.sample_minibatch()
        
        with torch.no_grad():
            last_step_target = batch_dones * batch_rewards
            
            random_vals = torch.rand(size=(self.batch_size, ), device=self.device)
            rand_mask = random_vals < self.eps
            det_mask = random_vals >= self.eps
            random_actions = torch.randint(low=0, high=self.action_size, size=(self.batch_size, ), device=self.device)
            random_actions[det_mask] = 0
            
            next_actions = torch.argmax(self.get_q_val(self.q_target, batch_next_states), dim=1)
            next_actions[rand_mask] = 0
            
            next_actions = next_actions + random_actions
            middle_step_target = (1-batch_dones)*(batch_rewards + self.gamma*self.get_q_val(self.q, batch_next_states, next_actions).squeeze())
            targets =  last_step_target + middle_step_target
        
        predictions = self.get_q_val(self.q, batch_states, batch_actions).squeeze()
        loss = F.mse_loss(predictions, targets)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        self.synchronize()

        


## training

In [None]:
def log(_str, path='./log.txt'):
    print(_str)
    with open(path, 'a') as f:
        f.write(f"{_str}\n")


def evaluate_agent(env, agent, k=10):
    returns = []
    for _ in range(k):
        state = env.reset()
        done = False
        ret = 0
        while not done:
            action = agent.get_action(state)
            state, reward, done, _ = env.step(action)
            ret += reward
        returns.append(ret)
    return returns
            
def train_and_evaluate(env, agent, epochs, eval_env, eval_step=100, eval_it=50):
    
    returns = []
    epsilons = []
    eval_rets = []
    for i in tqdm(range(1, epochs+1)):
        done = False
        state = env.reset()
        
        ret = 0
        while not done:
            a = agent.get_action(state)
            next_state, r, done, _ = env.step(a)
            agent.update(state, a, r, next_state, done)
            ret += r
            state = next_state
        done = False
        epsilons.append(agent.eps)
        returns.append(ret)
        
        
        if i%eval_step == 0:
            avg_ret = np.mean(evaluate_agent(eval_env, agent, k=eval_it))
            log(f"epoch:{i} average return={avg_ret} current_eps:{agent.eps} buffer_size:{agent.buf_idx}")
            eval_rets.append(avg_ret)
            
    return returns, epsilons, eval_rets



env = KGEnv(
    len(relations_voc),
    observation_space=768+200+200+200,
    q_embeddings=question_embs,
    ent_embeddings=ent_emb_dict,
    rel_embeddings=rel_emb_dict,
    entities_vocab=entities_voc, 
    entities_inv_vocab=entities_inv_voc,
    rel_inv_voc=relations_inv_voc, rel_voc=relations_voc,
    mode='train'
)


eval_env = KGEnv(
    len(relations_voc),
    observation_space=768+200+200+200,
    q_embeddings=question_embs,
    ent_embeddings=ent_emb_dict,
    rel_embeddings=rel_emb_dict,
    entities_vocab=entities_voc, 
    entities_inv_vocab=entities_inv_voc,
    rel_inv_voc=relations_inv_voc, rel_voc=relations_voc,
    mode='dev'
)

ddqn = DoubleDQN(
    buffer_size=150000,
    gamma=0.99,
    lr=3e-5,
    eps=0.8,
    state_size=env.observation_space.shape[0],
    action_size=env.action_space.n,
    batch_size=128,
    eps_decay=0.9995,
    warmup_step=50000,
    min_eps=0.05,
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)


returns, epsilons, eval_rets = train_and_evaluate(env, ddqn, epochs=800000, eval_env=eval_env, eval_step=1000, eval_it=100)

Total env size in mode train: 667874
Total env size in mode dev: 39138


  0%|          | 1070/800000 [00:01<29:56, 444.61it/s]

epoch:1000 average return=0.07 current_eps:0.8 buffer_size:3779


  0%|          | 2070/800000 [00:03<28:21, 468.88it/s]

epoch:2000 average return=0.07 current_eps:0.8 buffer_size:7541


  0%|          | 3066/800000 [00:05<28:43, 462.37it/s]

epoch:3000 average return=0.07 current_eps:0.8 buffer_size:11317


  1%|          | 4108/800000 [00:07<30:09, 439.95it/s]

epoch:4000 average return=0.02 current_eps:0.8 buffer_size:15075


  1%|          | 5112/800000 [00:09<28:08, 470.67it/s]

epoch:5000 average return=0.1 current_eps:0.8 buffer_size:18857


  1%|          | 6124/800000 [00:11<27:55, 473.96it/s]

epoch:6000 average return=0.03 current_eps:0.8 buffer_size:22672


  1%|          | 7103/800000 [00:12<29:19, 450.53it/s]

epoch:7000 average return=0.05 current_eps:0.8 buffer_size:26464


  1%|          | 8118/800000 [00:14<28:31, 462.60it/s]

epoch:8000 average return=0.11 current_eps:0.8 buffer_size:30254


  1%|          | 9083/800000 [00:16<28:44, 458.70it/s]

epoch:9000 average return=0.08 current_eps:0.8 buffer_size:34012


  1%|▏         | 10103/800000 [00:18<27:43, 474.96it/s]

epoch:10000 average return=0.08 current_eps:0.8 buffer_size:37809


  1%|▏         | 11125/800000 [00:19<26:51, 489.51it/s]

epoch:11000 average return=0.1 current_eps:0.8 buffer_size:41592


  2%|▏         | 12075/800000 [00:21<28:04, 467.84it/s]

epoch:12000 average return=0.1 current_eps:0.8 buffer_size:45348


  2%|▏         | 13106/800000 [00:23<28:08, 466.10it/s]

epoch:13000 average return=0.04 current_eps:0.8 buffer_size:49118


  2%|▏         | 14009/800000 [00:40<5:35:53, 39.00it/s] 

epoch:14000 average return=0.15 current_eps:0.5459438850462682 buffer_size:52632


  2%|▏         | 15012/800000 [01:01<4:16:42, 50.96it/s] 

epoch:15000 average return=0.33 current_eps:0.3310903020863681 buffer_size:55413


  2%|▏         | 16012/800000 [01:18<4:04:28, 53.45it/s]

epoch:16000 average return=0.23 current_eps:0.20079131049586138 buffer_size:57988


  2%|▏         | 17000/800000 [01:37<4:25:52, 49.08it/s] 

epoch:17000 average return=0.31 current_eps:0.12177085863459788 buffer_size:60499


  2%|▏         | 18000/800000 [01:57<3:30:08, 62.02it/s] 

epoch:18000 average return=0.31 current_eps:0.07384852450033136 buffer_size:63036


  2%|▏         | 19013/800000 [02:13<4:10:21, 51.99it/s]

epoch:19000 average return=0.39 current_eps:0.05 buffer_size:65559


  3%|▎         | 20010/800000 [02:27<3:58:20, 54.54it/s] 

epoch:20000 average return=0.43 current_eps:0.05 buffer_size:68039


  3%|▎         | 21009/800000 [02:41<3:34:30, 60.53it/s]

epoch:21000 average return=0.48 current_eps:0.05 buffer_size:70496


  3%|▎         | 22010/800000 [02:57<8:33:58, 25.23it/s] 

epoch:22000 average return=0.45 current_eps:0.05 buffer_size:73082


  4%|▎         | 29016/800000 [05:00<3:19:18, 64.47it/s] 

epoch:29000 average return=0.67 current_eps:0.05 buffer_size:90882


  4%|▍         | 30000/800000 [05:16<4:21:06, 49.15it/s]

epoch:30000 average return=0.62 current_eps:0.05 buffer_size:93437


  4%|▍         | 31003/800000 [05:36<5:34:02, 38.37it/s] 

epoch:31000 average return=0.61 current_eps:0.05 buffer_size:96069


  4%|▍         | 32012/800000 [05:53<3:36:45, 59.05it/s]

epoch:32000 average return=0.64 current_eps:0.05 buffer_size:98607


  4%|▍         | 33013/800000 [06:15<3:56:26, 54.07it/s] 

epoch:33000 average return=0.63 current_eps:0.05 buffer_size:101173


  4%|▍         | 34009/800000 [06:38<4:17:12, 49.64it/s] 

epoch:34000 average return=0.5 current_eps:0.05 buffer_size:103796


  4%|▍         | 35016/800000 [06:56<3:49:52, 55.46it/s] 

epoch:35000 average return=0.64 current_eps:0.05 buffer_size:106543


  5%|▍         | 36009/800000 [07:19<3:45:41, 56.42it/s] 

epoch:36000 average return=0.65 current_eps:0.05 buffer_size:109157


  5%|▍         | 37010/800000 [07:36<4:08:20, 51.21it/s]

epoch:37000 average return=0.62 current_eps:0.05 buffer_size:111792


  5%|▍         | 38006/800000 [07:52<4:04:42, 51.90it/s]

epoch:38000 average return=0.56 current_eps:0.05 buffer_size:114436


  5%|▍         | 39008/800000 [08:08<3:42:19, 57.05it/s]

epoch:39000 average return=0.62 current_eps:0.05 buffer_size:116989


  5%|▌         | 40010/800000 [08:25<3:18:40, 63.76it/s]

epoch:40000 average return=0.61 current_eps:0.05 buffer_size:119638


  5%|▌         | 41014/800000 [08:46<3:51:08, 54.73it/s] 

epoch:41000 average return=0.67 current_eps:0.05 buffer_size:122324


  5%|▌         | 42008/800000 [09:04<4:40:55, 44.97it/s] 

epoch:42000 average return=0.57 current_eps:0.05 buffer_size:124938


  5%|▌         | 43002/800000 [09:25<10:04:23, 20.88it/s]

epoch:43000 average return=0.64 current_eps:0.05 buffer_size:127560


  6%|▌         | 44002/800000 [09:46<15:15:44, 13.76it/s]

epoch:44000 average return=0.6 current_eps:0.05 buffer_size:130167


  6%|▌         | 45003/800000 [10:03<3:28:10, 60.44it/s] 

epoch:45000 average return=0.73 current_eps:0.05 buffer_size:132771


  6%|▌         | 46011/800000 [10:24<6:48:40, 30.75it/s] 

epoch:46000 average return=0.59 current_eps:0.05 buffer_size:135403


  6%|▌         | 47001/800000 [10:46<10:16:37, 20.35it/s]

epoch:47000 average return=0.67 current_eps:0.05 buffer_size:138026


  6%|▌         | 48007/800000 [11:05<4:54:53, 42.50it/s] 

epoch:48000 average return=0.64 current_eps:0.05 buffer_size:140729


  7%|▋         | 54003/800000 [13:05<4:00:44, 51.65it/s]]

epoch:54000 average return=0.7 current_eps:0.05 buffer_size:156695


  7%|▋         | 55014/800000 [13:24<3:22:13, 61.40it/s]

epoch:55000 average return=0.61 current_eps:0.05 buffer_size:159306


  7%|▋         | 56014/800000 [13:48<4:13:58, 48.82it/s] 

epoch:56000 average return=0.62 current_eps:0.05 buffer_size:161984


  7%|▋         | 57007/800000 [14:11<7:23:42, 27.91it/s] 

epoch:57000 average return=0.7 current_eps:0.05 buffer_size:164645


  7%|▋         | 58009/800000 [14:32<3:53:08, 53.04it/s] 

epoch:58000 average return=0.63 current_eps:0.05 buffer_size:167269


  7%|▋         | 59002/800000 [14:54<13:26:05, 15.32it/s]

epoch:59000 average return=0.71 current_eps:0.05 buffer_size:169888


  8%|▊         | 60009/800000 [15:16<4:21:29, 47.16it/s] 

epoch:60000 average return=0.67 current_eps:0.05 buffer_size:172548


  8%|▊         | 61013/800000 [15:40<3:51:45, 53.14it/s] 

epoch:61000 average return=0.74 current_eps:0.05 buffer_size:175155


  8%|▊         | 62011/800000 [16:00<4:07:05, 49.78it/s] 

epoch:62000 average return=0.66 current_eps:0.05 buffer_size:177844


  8%|▊         | 63013/800000 [16:19<3:30:49, 58.26it/s] 

epoch:63000 average return=0.68 current_eps:0.05 buffer_size:180472


  8%|▊         | 64011/800000 [16:42<4:33:46, 44.80it/s] 

epoch:64000 average return=0.62 current_eps:0.05 buffer_size:183065


  8%|▊         | 65001/800000 [17:01<4:10:32, 48.89it/s] 

epoch:65000 average return=0.73 current_eps:0.05 buffer_size:185669


  8%|▊         | 66011/800000 [17:24<3:36:30, 56.50it/s] 

epoch:66000 average return=0.74 current_eps:0.05 buffer_size:188297


  8%|▊         | 67012/800000 [17:45<3:36:06, 56.53it/s] 

epoch:67000 average return=0.59 current_eps:0.05 buffer_size:190880


  9%|▊         | 68008/800000 [18:08<4:55:56, 41.22it/s] 

epoch:68000 average return=0.61 current_eps:0.05 buffer_size:193481


  9%|▊         | 69015/800000 [18:34<4:11:25, 48.46it/s] 

epoch:69000 average return=0.59 current_eps:0.05 buffer_size:196122


  9%|▉         | 70003/800000 [18:56<4:17:47, 47.20it/s] 

epoch:70000 average return=0.72 current_eps:0.05 buffer_size:198718


  9%|▉         | 71005/800000 [19:21<4:16:02, 47.45it/s] 

epoch:71000 average return=0.69 current_eps:0.05 buffer_size:201328


  9%|▉         | 72012/800000 [19:44<3:44:46, 53.98it/s] 

epoch:72000 average return=0.57 current_eps:0.05 buffer_size:203966


  9%|▉         | 73010/800000 [20:05<2:54:01, 69.63it/s] 

epoch:73000 average return=0.69 current_eps:0.05 buffer_size:206608


  9%|▉         | 74005/800000 [20:23<3:40:09, 54.96it/s] 

epoch:74000 average return=0.67 current_eps:0.05 buffer_size:209178


  9%|▉         | 75014/800000 [20:42<3:27:59, 58.09it/s] 

epoch:75000 average return=0.69 current_eps:0.05 buffer_size:211787


  9%|▉         | 75116/800000 [20:44<7:17:52, 27.59it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 19%|█▉        | 154000/800000 [45:14<2:56:03, 61.15it/s]

epoch:154000 average return=0.74 current_eps:0.05 buffer_size:413662


 19%|█▉        | 155009/800000 [45:31<3:28:30, 51.56it/s]

epoch:155000 average return=0.71 current_eps:0.05 buffer_size:416162


 20%|█▉        | 156009/800000 [45:47<3:02:07, 58.93it/s]

epoch:156000 average return=0.71 current_eps:0.05 buffer_size:418686


 20%|█▉        | 157011/800000 [46:04<3:01:36, 59.01it/s]

epoch:157000 average return=0.73 current_eps:0.05 buffer_size:421226


 20%|█▉        | 158000/800000 [46:20<3:04:39, 57.94it/s]

epoch:158000 average return=0.76 current_eps:0.05 buffer_size:423754


 20%|█▉        | 159013/800000 [46:35<3:04:44, 57.83it/s]

epoch:159000 average return=0.66 current_eps:0.05 buffer_size:426339


 20%|██        | 160008/800000 [46:51<3:07:54, 56.77it/s]

epoch:160000 average return=0.7 current_eps:0.05 buffer_size:428853


 20%|██        | 161013/800000 [47:06<3:42:45, 47.81it/s]

epoch:161000 average return=0.58 current_eps:0.05 buffer_size:431391


 20%|██        | 162012/800000 [47:23<3:40:28, 48.23it/s]

epoch:162000 average return=0.64 current_eps:0.05 buffer_size:433965


 20%|██        | 163011/800000 [47:38<4:38:10, 38.17it/s]

epoch:163000 average return=0.71 current_eps:0.05 buffer_size:436492


 21%|██        | 164008/800000 [47:56<2:55:42, 60.33it/s] 

epoch:164000 average return=0.57 current_eps:0.05 buffer_size:439027


 21%|██        | 165007/800000 [48:11<3:40:24, 48.02it/s]

epoch:165000 average return=0.68 current_eps:0.05 buffer_size:441601


 21%|██        | 166010/800000 [48:27<3:16:26, 53.79it/s]

epoch:166000 average return=0.66 current_eps:0.05 buffer_size:444135


 21%|██        | 167011/800000 [48:42<2:35:06, 68.02it/s]

epoch:167000 average return=0.72 current_eps:0.05 buffer_size:446650


 21%|██        | 168006/800000 [48:57<3:45:00, 46.81it/s]

epoch:168000 average return=0.69 current_eps:0.05 buffer_size:449217


 21%|██        | 169013/800000 [49:13<2:51:28, 61.33it/s]

epoch:169000 average return=0.65 current_eps:0.05 buffer_size:451758


 21%|██▏       | 170005/800000 [49:29<3:26:50, 50.76it/s]

epoch:170000 average return=0.67 current_eps:0.05 buffer_size:454253


 21%|██▏       | 171008/800000 [49:44<3:10:49, 54.94it/s]

epoch:171000 average return=0.66 current_eps:0.05 buffer_size:456721


 22%|██▏       | 172008/800000 [50:02<4:19:35, 40.32it/s]

epoch:172000 average return=0.61 current_eps:0.05 buffer_size:459160


 22%|██▏       | 172999/800000 [50:20<2:52:32, 60.57it/s]

epoch:173000 average return=0.65 current_eps:0.05 buffer_size:461675


 22%|██▏       | 173999/800000 [50:37<2:12:49, 78.55it/s]

epoch:174000 average return=0.77 current_eps:0.05 buffer_size:464188


 22%|██▏       | 175010/800000 [50:54<3:47:38, 45.76it/s]

epoch:175000 average return=0.75 current_eps:0.05 buffer_size:466745


 22%|██▏       | 176012/800000 [51:09<2:47:02, 62.26it/s]

epoch:176000 average return=0.69 current_eps:0.05 buffer_size:469315


 22%|██▏       | 177009/800000 [51:26<3:30:41, 49.28it/s]

epoch:177000 average return=0.63 current_eps:0.05 buffer_size:471826


 22%|██▏       | 178007/800000 [51:41<4:03:28, 42.58it/s]

epoch:178000 average return=0.74 current_eps:0.05 buffer_size:474338


 22%|██▏       | 179012/800000 [51:56<3:03:31, 56.39it/s]

epoch:179000 average return=0.64 current_eps:0.05 buffer_size:476906


 22%|██▏       | 179232/800000 [52:00<2:28:56, 69.46it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 33%|███▎      | 261997/800000 [1:15:50<2:07:40, 70.23it/s]

epoch:262000 average return=0.72 current_eps:0.05 buffer_size:682904


 33%|███▎      | 263014/800000 [1:16:07<2:38:08, 56.59it/s]

epoch:263000 average return=0.71 current_eps:0.05 buffer_size:685316


 33%|███▎      | 264001/800000 [1:16:25<3:22:50, 44.04it/s]

epoch:264000 average return=0.62 current_eps:0.05 buffer_size:687767


 33%|███▎      | 265010/800000 [1:16:43<3:10:44, 46.75it/s]

epoch:265000 average return=0.76 current_eps:0.05 buffer_size:690256


 33%|███▎      | 266000/800000 [1:17:02<4:53:11, 30.36it/s]

epoch:266000 average return=0.78 current_eps:0.05 buffer_size:692790


 33%|███▎      | 267005/800000 [1:17:23<3:50:14, 38.58it/s]

epoch:267000 average return=0.72 current_eps:0.05 buffer_size:695319


 34%|███▎      | 268007/800000 [1:17:41<2:56:34, 50.22it/s]

epoch:268000 average return=0.67 current_eps:0.05 buffer_size:697814


 34%|███▎      | 269009/800000 [1:17:57<2:40:32, 55.12it/s]

epoch:269000 average return=0.63 current_eps:0.05 buffer_size:700309


 34%|███▍      | 270016/800000 [1:18:12<2:34:57, 57.00it/s]

epoch:270000 average return=0.62 current_eps:0.05 buffer_size:702785


 34%|███▍      | 271010/800000 [1:18:29<2:16:57, 64.37it/s]

epoch:271000 average return=0.76 current_eps:0.05 buffer_size:705160


 34%|███▍      | 272014/800000 [1:18:45<2:22:55, 61.57it/s]

epoch:272000 average return=0.7 current_eps:0.05 buffer_size:707625


 34%|███▍      | 273013/800000 [1:19:00<2:27:04, 59.72it/s]

epoch:273000 average return=0.62 current_eps:0.05 buffer_size:710090


 34%|███▍      | 274010/800000 [1:19:18<2:51:42, 51.05it/s]

epoch:274000 average return=0.69 current_eps:0.05 buffer_size:712517


 34%|███▍      | 275013/800000 [1:19:33<2:04:05, 70.51it/s]

epoch:275000 average return=0.55 current_eps:0.05 buffer_size:714980


 34%|███▍      | 276000/800000 [1:19:48<3:04:41, 47.29it/s]

epoch:276000 average return=0.65 current_eps:0.05 buffer_size:717432


 35%|███▍      | 276996/800000 [1:20:04<2:16:27, 63.88it/s]

epoch:277000 average return=0.61 current_eps:0.05 buffer_size:719966


 35%|███▍      | 278015/800000 [1:20:22<2:03:23, 70.51it/s]

epoch:278000 average return=0.68 current_eps:0.05 buffer_size:722453


 35%|███▍      | 279008/800000 [1:20:37<2:34:09, 56.33it/s]

epoch:279000 average return=0.71 current_eps:0.05 buffer_size:724888


 35%|███▌      | 280005/800000 [1:20:53<3:02:21, 47.52it/s]

epoch:280000 average return=0.61 current_eps:0.05 buffer_size:727336


 35%|███▌      | 281012/800000 [1:21:10<2:03:49, 69.85it/s]

epoch:281000 average return=0.6 current_eps:0.05 buffer_size:729860


 35%|███▌      | 282002/800000 [1:21:25<2:37:43, 54.73it/s]

epoch:282000 average return=0.62 current_eps:0.05 buffer_size:732238


 35%|███▌      | 283012/800000 [1:21:40<2:11:27, 65.55it/s]

epoch:283000 average return=0.63 current_eps:0.05 buffer_size:734716


 36%|███▌      | 284012/800000 [1:21:56<2:06:18, 68.08it/s]

epoch:284000 average return=0.62 current_eps:0.05 buffer_size:737149


 36%|███▌      | 285013/800000 [1:22:13<2:39:19, 53.87it/s]

epoch:285000 average return=0.64 current_eps:0.05 buffer_size:739663


 36%|███▌      | 286009/800000 [1:22:30<2:33:45, 55.71it/s]

epoch:286000 average return=0.74 current_eps:0.05 buffer_size:742125


 36%|███▌      | 286026/800000 [1:22:30<2:12:30, 64.64it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 38%|███▊      | 300012/800000 [1:26:26<2:37:39, 52.86it/s]

epoch:300000 average return=0.71 current_eps:0.05 buffer_size:776502


 38%|███▊      | 300999/800000 [1:26:42<1:44:45, 79.39it/s]

epoch:301000 average return=0.69 current_eps:0.05 buffer_size:778954


 38%|███▊      | 302005/800000 [1:26:59<3:00:55, 45.87it/s]

epoch:302000 average return=0.6 current_eps:0.05 buffer_size:781324


 38%|███▊      | 303010/800000 [1:27:16<2:39:32, 51.92it/s]

epoch:303000 average return=0.7 current_eps:0.05 buffer_size:783817


 38%|███▊      | 304003/800000 [1:27:33<2:30:13, 55.03it/s]

epoch:304000 average return=0.72 current_eps:0.05 buffer_size:786256


 38%|███▊      | 305014/800000 [1:27:49<2:39:08, 51.84it/s]

epoch:305000 average return=0.67 current_eps:0.05 buffer_size:788598


 38%|███▊      | 306005/800000 [1:28:04<2:59:29, 45.87it/s]

epoch:306000 average return=0.7 current_eps:0.05 buffer_size:791032


 38%|███▊      | 307005/800000 [1:28:22<4:23:46, 31.15it/s]

epoch:307000 average return=0.62 current_eps:0.05 buffer_size:793462


 39%|███▊      | 308001/800000 [1:28:37<2:17:44, 59.53it/s]

epoch:308000 average return=0.63 current_eps:0.05 buffer_size:795834


 39%|███▊      | 309009/800000 [1:28:56<2:28:11, 55.22it/s]

epoch:309000 average return=0.68 current_eps:0.05 buffer_size:798309


 39%|███▉      | 310010/800000 [1:29:13<2:52:35, 47.32it/s]

epoch:310000 average return=0.71 current_eps:0.05 buffer_size:800720


 39%|███▉      | 311010/800000 [1:29:31<2:12:05, 61.70it/s]

epoch:311000 average return=0.57 current_eps:0.05 buffer_size:803178


 39%|███▉      | 311805/800000 [1:29:43<1:26:42, 93.83it/s]

In [42]:
eval_env = KGEnv(
    len(relations_voc),
    observation_space=768+200+200+200,
    q_embeddings=question_embs,
    ent_embeddings=ent_emb_dict,
    rel_embeddings=rel_emb_dict,
    entities_vocab=entities_voc, 
    entities_inv_vocab=entities_inv_voc,
    rel_inv_voc=relations_inv_voc, rel_voc=relations_voc,
    mode='dev',
    nhops=[2]
)

avg_ret = np.mean(evaluate_agent(eval_env, ddqn, k=1000))
print(avg_ret)


Total env size in mode dev: 14872
0.716


In [361]:
def play_in_env(ddqn):
    state = kg_env.reset()
    done = False
    
    while not done:
        kg_env.render()        
        action = ddqn.get_action(state, det=True)
        state, reward, done, _ = kg_env.step(action)
        
play_in_env(ddqn)

The Spanish Earth has_genre
current entity: The Spanish Earth
available moves:
--directed_by(6)--> Joris Ivens
--release_year(16)--> 1937
--written_by(11)--> Ernest Hemingway
--in_language(8)--> Spanish
--written_by(11)--> John Dos Passos
--has_genre(13)--> War
--has_tags(7)--> spanish civil war
The Spanish Earth has_genre
current entity: The Spanish Earth
available moves:
--directed_by(6)--> Joris Ivens
--release_year(16)--> 1937
--written_by(11)--> Ernest Hemingway
--in_language(8)--> Spanish
--written_by(11)--> John Dos Passos
--has_genre(13)--> War
--has_tags(7)--> spanish civil war
The Spanish Earth has_genre
current entity: The Spanish Earth
available moves:
--directed_by(6)--> Joris Ivens
--release_year(16)--> 1937
--written_by(11)--> Ernest Hemingway
--in_language(8)--> Spanish
--written_by(11)--> John Dos Passos
--has_genre(13)--> War
--has_tags(7)--> spanish civil war
The Spanish Earth has_genre
current entity: The Spanish Earth
available moves:
--directed_by(6)--> Joris Iven

In [404]:

from d3rlpy.algos import DQN, DiscreteSAC
from d3rlpy.online.buffers import ReplayBuffer
from d3rlpy.online.explorers import LinearDecayEpsilonGreedy
from d3rlpy.metrics.scorer import evaluate_on_environment

env=kg_env

dqn = DQN(
    batch_size=32,
    learning_rate=2.5e-4,
    target_update_interval=100,
    use_gpu=False
)

sac = DiscreteSAC(
    use_batch_norm=True,
    target_update_interval=1000,
    use_gpu=False,
)

# setup replay buffer
buffer = ReplayBuffer(maxlen=100000, env=env)

# setup explorers
explorer = LinearDecayEpsilonGreedy(start_epsilon=1.0,
                                    end_epsilon=0.1,
                                    duration=10000)

# start training
sac.fit_online(
    env,
    buffer,
    n_steps=100000,
    n_steps_per_epoch=100,
    callback=evaluate_on_environment(eval_env, n_trials=100)
)

2021-12-05 12:38.55 [info     ] Directory is created at d3rlpy_logs/DiscreteSAC_online_20211205123855
2021-12-05 12:38.55 [debug    ] Building model...
2021-12-05 12:38.55 [debug    ] Model has been built.
2021-12-05 12:38.55 [info     ] Parameters are saved to d3rlpy_logs/DiscreteSAC_online_20211205123855/params.json params={'action_scaler': None, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 0.0003, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 0.0001, 'weight_decay': 0, 'amsgrad': False}, 'batch_size': 64, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_learning_rate': 0.0003, 'critic_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 0.0001, 'weight_decay': 0, 'amsgrad': False}, 'gamma': 0.99, 'generated_maxlen': 100000, 'initial_temperature': 1.0,

  0%|▏                                                                                                                                                                | 82/100000 [00:09<3:19:36,  8.34it/s]


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn