In [1]:
from torch.nn import Sequential
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import MSELoss
import torch
import torch.optim.optimizer as optim
from torch.optim import Adam
import random, copy
import numpy as np

import re

from sklearn import preprocessing

from collections import defaultdict
from tqdm.notebook import tqdm 
import wandb



In [2]:
'PLACEHOLDER'  # placeholder value for inform slots
'UNK'  # placeholder value for request slots
'anything'  # any value is acceptable for the requested slot
'no match available'  # no database entry matched the goal

# Intents the user simulator can take
usersim_intents = ['inform', 'request', 'thanks', 'reject', 'done']

# Main goal of the dialogue (the slot the user ultimately wants filled)
usersim_default_key = 'Model'

# Slots the user simulator informs in its initial action when they are part of the goal
usersim_required_init_inform_keys = ['Release date']


# Slots the agent can inform / request
# NOTE(review): the inform list has 'Zoom tele (T)' while the request list has 'Zoom wide (W)' - confirm this asymmetry is intended
agent_inform_slots = ['Release date','Max resolution','Low resolution','Effective pixels','Zoom tele (T)', 'Normal focus range', 'Macro focus range', 'Storage included', 'Weight (inc. batteries)', 'Dimensions', usersim_default_key]
agent_request_slots = ['Release date','Max resolution','Low resolution','Effective pixels','Zoom wide (W)', 'Normal focus range', 'Macro focus range', 'Storage included', 'Weight (inc. batteries)', 'Dimensions', 'Price']

# Actions available to the agent
agent_actions = [
    {'intent': 'done', 'inform_slots': {}, 'request_slots': {}},  # Triggers closing of conversation
    {'intent': 'match_found', 'inform_slots': {}, 'request_slots': {}}
]
# One 'inform' action per inform slot (the default key is only ever sent through 'match_found')
for slot in agent_inform_slots:
    if slot == usersim_default_key:
        continue
    agent_actions.append({'intent': 'inform', 'inform_slots': {slot: 'PLACEHOLDER'}, 'request_slots': {}})
# One 'request' action per request slot
for slot in agent_request_slots:
    agent_actions.append({'intent': 'request', 'inform_slots': {}, 'request_slots': {slot: 'UNK'}})

# Rule-based policy: slots requested, in this order, before declaring a match
rule_requests = ['Release date','Max resolution','Low resolution', 'Effective pixels']

# Keys that are never used as database query constraints
no_query_keys = ['Price', usersim_default_key]


# Episode outcome indicators
FAIL = -1
NO_OUTCOME = 0
SUCCESS = 1

# All possible intents (user and agent)
all_intents = ['inform', 'request', 'done', 'match_found', 'thanks', 'reject']

# All possible slots
all_slots = ['Release date','Max resolution','Low resolution','Effective pixels','Zoom wide (W)','Zoom tele (T)','Normal focus range','Macro focus range','Storage included','Weight (inc. batteries)','Dimensions','Price', usersim_default_key]

In [25]:
# Single configuration dictionary consumed by the simulator, agent, error model and run loop.
constants = {
  # Pickle files to load. NOTE(review): absolute Kaggle input paths - adjust when running elsewhere.
  "db_file_paths": {
    "database": "/kaggle/input/new-cam/cameras_db.pkl",
    "dict": "/kaggle/input/new-cam/cameras_dict.pkl",
    "user_goals": "/kaggle/input/000000/camera_user_goals.pkl"
  },
  "run": {
    "usersim": True,  # True -> UserSimulator is instantiated, otherwise a `User` object
    "warmup_mem": 1000 ,  # upper bound on warm-up steps (see warmup_run)
    "num_ep_run": 5000,  # read as NUM_EP_TRAIN; presumably the number of training episodes - TODO confirm
    "train_freq": 100,  # read as TRAIN_FREQ; presumably episodes between training passes - TODO confirm
    "max_round_num": 20,  # round at which the user simulator ends the episode as a failure
    "success_rate_threshold": 0.3  # read as SUCCESS_RATE_THRESHOLD; usage is outside the code shown here
  },
  "agent": {
    "save_weights_file_path": "",  # empty string disables saving
    "load_weights_file_path": "",  # empty string disables loading
    "vanilla": True,  # True -> plain DQN target, False -> double-DQN target
    "learning_rate": 1e-3,
    "batch_size": 16,
    "dqn_hidden_size": 80,
    "epsilon_init": 0.0,  # probability of picking a random action in get_action
    "gamma": 0.9,  # discount factor
    "max_mem_size": 500_000  # replay memory capacity
  },
  "emc": {
    "slot_error_mode": 0,  # 0: value noise, 1: slot noise, 2: slot removal, other: random mix of the three
    "slot_error_prob": 0.05,  # per-slot probability of corrupting an inform slot
    "intent_error_prob": 0.0  # probability of replacing the user intent with a random one
  }
}

In [4]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
def convert_list_to_dict(lst):
    """Build a lookup table from each element of ``lst`` to its position.

    Parameters:
        lst: sequence of hashable, pairwise distinct items.

    Returns:
        dict mapping every item to its zero-based index in ``lst``.

    Raises:
        ValueError: if ``lst`` contains duplicates.
    """
    if len(set(lst)) < len(lst):
        raise ValueError('List must be unique!')

    index_of = {}
    for position, item in enumerate(lst):
        index_of[item] = position
    return index_of


def remove_empty_slots(dic):
    """Strip, in place, every slot whose value is the empty string.

    Parameters:
        dic: dict of dicts (entry id -> {slot: value}); the inner dicts are
            mutated directly.

    Returns:
        None.
    """
    for entry in dic.values():
        # Collect first so the inner dict is not resized while being iterated.
        blank_slots = [slot for slot, value in entry.items() if value == '']
        for slot in blank_slots:
            del entry[slot]


def reward_function(success, max_round):
    """Compute the reward for a single dialogue round.

    Every round costs 1. A failed episode additionally costs ``max_round``
    and a successful one additionally earns ``2 * max_round``.

    Parameters:
        success: one of the FAIL / NO_OUTCOME / SUCCESS indicators.
        max_round: maximum number of rounds allowed in an episode.

    Returns:
        The numeric reward for the round.
    """
    step_penalty = -1
    if success == FAIL:
        return step_penalty - max_round
    if success == SUCCESS:
        return step_penalty + 2 * max_round
    return step_penalty


In [6]:
class UserSimulator:
    """Rule-based simulated user that answers agent actions while pursuing a sampled goal.

    The internal ``state`` dict tracks:
      * ``history_slots`` - slots already exchanged with the agent,
      * ``inform_slots`` / ``request_slots`` - content of the current user action,
      * ``rest_slots`` - goal slots not yet exchanged,
      * ``intent`` - intent of the current user action.
    """

    def __init__(self, goal_list, constants, database):
        """Store the goals, the database and the simulator settings.

        Parameters:
            goal_list: sequence of goals; each goal is indexed with
                'inform_slots' and 'request_slots'.
            constants: configuration dict; ``constants['run']['max_round_num']`` is read.
            database: mapping indexed by integer id, used by ``_response_to_done``.
        """

        self.goal_list = goal_list
        self.max_round = constants['run']['max_round_num']
        self.default_key = usersim_default_key
        self.init_informs = usersim_required_init_inform_keys
        self.no_query = no_query_keys
        self.database = database

    def reset(self):
        """Sample a new goal, clear the state and return the initial user action.

        Note: the chosen goal object from ``goal_list`` is mutated in place
        (the default key is added to its request slots).
        """

        self.goal = random.choice(self.goal_list)
        # The default key is always something the user wants to obtain.
        self.goal['request_slots'][self.default_key] = 'UNK'
        self.state = {}
        self.state['history_slots'] = {}
        self.state['inform_slots'] = {}
        self.state['request_slots'] = {}
        self.state['rest_slots'] = {}
        self.state['rest_slots'].update(self.goal['inform_slots'])
        self.state['rest_slots'].update(self.goal['request_slots'])
        self.state['intent'] = ''
        # Stays FAIL until the agent presents a match that satisfies the goal.
        self.constraint_check = FAIL

        return self._return_init_action()

    def _return_init_action(self):
        """Build the first user action: a request, accompanied by informs when the goal has any."""

        self.state['intent'] = 'request'

        if self.goal['inform_slots']:

            # Inform the required initial slots that are part of the goal.
            for inform_key in self.init_informs:
                if inform_key in self.goal['inform_slots']:
                    self.state['inform_slots'][inform_key] = self.goal['inform_slots'][inform_key]
                    self.state['rest_slots'].pop(inform_key)
                    self.state['history_slots'][inform_key] = self.goal['inform_slots'][inform_key]

            # None of the required slots was in the goal: inform a random goal slot instead.
            if not self.state['inform_slots']:
                key, value = random.choice(list(self.goal['inform_slots'].items()))
                self.state['inform_slots'][key] = value
                self.state['rest_slots'].pop(key)
                self.state['history_slots'][key] = value

        # Temporarily remove the default key so that a non-default request is preferred.
        self.goal['request_slots'].pop(self.default_key)
        if self.goal['request_slots']:
            req_key = random.choice(list(self.goal['request_slots'].keys()))
        else:
            req_key = self.default_key
        self.goal['request_slots'][self.default_key] = 'UNK'
        self.state['request_slots'][req_key] = 'UNK'

        user_response = {}
        user_response['intent'] = self.state['intent']
        user_response['request_slots'] = copy.deepcopy(self.state['request_slots'])
        user_response['inform_slots'] = copy.deepcopy(self.state['inform_slots'])

        return user_response

    def step(self, agent_action):
        """Respond to one agent action.

        Parameters:
            agent_action: dict with 'intent', 'inform_slots', 'request_slots' and 'round'.

        Returns:
            Tuple ``(user_response, reward, done, success)`` where ``success`` is a
            bool that is True only when the episode outcome equals 1.
        """

        # The agent action must already have its placeholders filled in.
        for value in agent_action['inform_slots'].values():
            assert value != 'UNK'
            assert value != 'PLACEHOLDER'

        for value in agent_action['request_slots'].values():
            assert value != 'PLACEHOLDER'


        self.state['inform_slots'].clear()
        self.state['intent'] = ''

        done = False
        success = NO_OUTCOME

        # Reaching the round limit ends the episode as a failure.
        if agent_action['round'] == self.max_round:
            done = True
            success = FAIL
            self.state['intent'] = 'done'
            self.state['request_slots'].clear()
        else:
            agent_intent = agent_action['intent']
            if agent_intent == 'request':
                self._response_to_request(agent_action)
            elif agent_intent == 'inform':
                self._response_to_inform(agent_action)
            elif agent_intent == 'match_found':
                self._response_to_match_found(agent_action)
            elif agent_intent == 'done':
                success = self._response_to_done()
                self.state['intent'] = 'done'
                self.state['request_slots'].clear()
                done = True

        # Sanity checks on the resulting state.
        if self.state['intent'] == 'request':
            assert self.state['request_slots']

        if self.state['intent'] == 'inform':
            assert self.state['inform_slots']
            assert not self.state['request_slots']
        assert 'UNK' not in self.state['inform_slots'].values()
        assert 'PLACEHOLDER' not in self.state['request_slots'].values()

        # history_slots and rest_slots must be disjoint.
        for key in self.state['rest_slots']:
            assert key not in self.state['history_slots']
        for key in self.state['history_slots']:
            assert key not in self.state['rest_slots']

        # Every goal slot must be tracked in either history_slots or rest_slots.
        for inf_key in self.goal['inform_slots']:
            assert self.state['history_slots'].get(inf_key, False) or self.state['rest_slots'].get(inf_key, False)
        for req_key in self.goal['request_slots']:
            assert self.state['history_slots'].get(req_key, False) or self.state['rest_slots'].get(req_key,
                                                                                                   False), req_key

        # Every remaining slot must come from the goal.
        for key in self.state['rest_slots']:
            assert self.goal['inform_slots'].get(key, False) or self.goal['request_slots'].get(key, False)
        assert self.state['intent'] != ''


        user_response = {}
        user_response['intent'] = self.state['intent']
        user_response['request_slots'] = copy.deepcopy(self.state['request_slots'])
        user_response['inform_slots'] = copy.deepcopy(self.state['inform_slots'])

        reward = reward_function(success, self.max_round)

        return user_response, reward, done, True if success == 1 else False

    def _response_to_request(self, agent_action):
        """Answer an agent request; only the first requested slot is considered."""
        agent_request_key = list(agent_action['request_slots'].keys())[0]

        # Case 1: the requested slot is a goal constraint -> inform its goal value.
        if agent_request_key in self.goal['inform_slots']:
            self.state['intent'] = 'inform'
            self.state['inform_slots'][agent_request_key] = self.goal['inform_slots'][agent_request_key]
            self.state['request_slots'].clear()
            self.state['rest_slots'].pop(agent_request_key, None)
            self.state['history_slots'][agent_request_key] = self.goal['inform_slots'][agent_request_key]

        # Case 2: the slot is a goal request that was already exchanged -> repeat the known value.
        elif agent_request_key in self.goal['request_slots'] and agent_request_key in self.state['history_slots']:
            self.state['intent'] = 'inform'
            self.state['inform_slots'][agent_request_key] = self.state['history_slots'][agent_request_key]
            self.state['request_slots'].clear()
            assert agent_request_key not in self.state['rest_slots']

        # Case 3: the slot is a goal request still pending -> request it back,
        # adding a random remaining inform when one exists.
        elif agent_request_key in self.goal['request_slots'] and agent_request_key in self.state['rest_slots']:
            self.state['request_slots'].clear()
            self.state['intent'] = 'request'
            self.state['request_slots'][agent_request_key] = 'UNK'
            rest_informs = {}
            for key, value in list(self.state['rest_slots'].items()):
                if value != 'UNK':
                    rest_informs[key] = value
            if rest_informs:
                key_choice, value_choice = random.choice(list(rest_informs.items()))
                self.state['inform_slots'][key_choice] = value_choice
                self.state['rest_slots'].pop(key_choice)
                self.state['history_slots'][key_choice] = value_choice

        # Case 4: the user does not care about this slot -> answer 'anything'.
        else:
            assert agent_request_key not in self.state['rest_slots']
            self.state['intent'] = 'inform'
            self.state['inform_slots'][agent_request_key] = 'anything'
            self.state['request_slots'].clear()
            self.state['history_slots'][agent_request_key] = 'anything'

    def _response_to_inform(self, agent_action):
        """React to an agent inform; only the first informed slot is considered."""

        agent_inform_key = list(agent_action['inform_slots'].keys())[0]
        agent_inform_value = agent_action['inform_slots'][agent_inform_key]

        # The default key may only be delivered through 'match_found'.
        assert agent_inform_key != self.default_key
        self.state['history_slots'][agent_inform_key] = agent_inform_value
        self.state['rest_slots'].pop(agent_inform_key, None)

        self.state['request_slots'].pop(agent_inform_key, None)

        # The agent's value contradicts a goal constraint -> correct it.
        if agent_inform_value != self.goal['inform_slots'].get(agent_inform_key, agent_inform_value):
            self.state['intent'] = 'inform'
            self.state['inform_slots'][agent_inform_key] = self.goal['inform_slots'][agent_inform_key]
            self.state['request_slots'].clear()
            self.state['history_slots'][agent_inform_key] = self.goal['inform_slots'][agent_inform_key]
        else:
            # Something is still being requested -> keep requesting it.
            if self.state['request_slots']:
                self.state['intent'] = 'request'
            elif self.state['rest_slots']:
                # Set the default key aside so it is only requested once nothing else remains.
                def_in = self.state['rest_slots'].pop(self.default_key, False)
                if self.state['rest_slots']:
                    key, value = random.choice(list(self.state['rest_slots'].items()))
                    if value != 'UNK':
                        self.state['intent'] = 'inform'
                        self.state['inform_slots'][key] = value
                        self.state['rest_slots'].pop(key)
                        self.state['history_slots'][key] = value
                    else:
                        self.state['intent'] = 'request'
                        self.state['request_slots'][key] = 'UNK'
                else:
                    self.state['intent'] = 'request'
                    self.state['request_slots'][self.default_key] = 'UNK'
                # Restore the default key if it was set aside above.
                if def_in == 'UNK':
                    self.state['rest_slots'][self.default_key] = 'UNK'
            else:
                self.state['intent'] = 'thanks'

    def _response_to_match_found(self, agent_action):
        """Check the proposed match against the goal constraints and answer 'thanks' or 'reject'."""

        agent_informs = agent_action['inform_slots']

        self.state['intent'] = 'thanks'
        self.constraint_check = SUCCESS

        assert self.default_key in agent_informs
        self.state['rest_slots'].pop(self.default_key, None)
        self.state['history_slots'][self.default_key] = str(agent_informs[self.default_key])
        self.state['request_slots'].pop(self.default_key, None)

        if agent_informs[self.default_key] == 'no match available':
            self.constraint_check = FAIL
        # Every goal constraint (except the non-queryable keys) must equal the match's value.
        for key, value in self.goal['inform_slots'].items():
            assert value != None

            if key in self.no_query:
                continue

            if value != agent_informs.get(key, None):
                self.constraint_check = FAIL
                break

        if self.constraint_check == FAIL:
            self.state['intent'] = 'reject'
            self.state['request_slots'].clear()

    def _response_to_done(self):
        """Decide the episode outcome when the agent closes the conversation.

        Returns:
            FAIL when no valid match was accepted or goal slots remain unhandled,
            otherwise SUCCESS.
        """

        if self.constraint_check == FAIL:
            return FAIL

        if not self.state['rest_slots']:
            assert not self.state['request_slots']
        if self.state['rest_slots']:
            return FAIL


        assert self.state['history_slots'][self.default_key] != 'no match available'

        # The stored default-key value is the database id of the accepted match.
        match = copy.deepcopy(self.database[int(self.state['history_slots'][self.default_key])])

        # Defensive re-check: an accepted match must satisfy all goal constraints.
        for key, value in self.goal['inform_slots'].items():
            assert value != None
            if key in self.no_query:
                continue
            if value != match.get(key, None):
                assert True is False, 'match: {}\ngoal: {}'.format(match, self.goal)
                break


        return SUCCESS

In [19]:

class DQNAgent:
    """Dialogue agent backed by a small fully connected Q-network (PyTorch).

    The agent owns a behaviour network (trained) and a target network (used for
    bootstrapped targets), a ring-buffer replay memory, an epsilon-greedy action
    selector and a simple rule-based policy used during warm-up.
    """

    def __init__(self, state_size, constants):
        """Store the main hyper-parameters and build the networks.

        Parameters:
            state_size: length of the state vector fed to the network.
            constants: configuration dict; the ``constants['agent']`` section is read.

        Raises:
            ValueError: if the replay memory capacity is smaller than the batch size.
        """
        self.C = constants['agent']
        self.memory = []
        self.memory_index = 0
        self.max_memory_size = self.C['max_mem_size']
        self.eps = self.C['epsilon_init']
        self.vanilla = self.C['vanilla']
        self.lr = self.C['learning_rate']
        self.gamma = self.C['gamma']
        self.batch_size = self.C['batch_size']
        self.hidden_size = self.C['dqn_hidden_size']

        self.load_weights_file_path = self.C['load_weights_file_path']
        self.save_weights_file_path = self.C['save_weights_file_path']
        self.device = device
        if self.max_memory_size < self.batch_size:
            raise ValueError('Max memory size must be at least as great as batch size!')

        self.state_size = state_size
        self.possible_actions = agent_actions
        self.num_actions = len(self.possible_actions)
        self.rule_request_set = rule_requests
        self.beh_model = self._build_model()
        self.tar_model = self._build_model()
        self.optim = Adam(self.beh_model.parameters(), lr=self.lr)
        self.criterion = MSELoss()

        self._load_weights()

        self.reset()

    def _build_model(self):
        """Build a network with one hidden layer mapping a state to one Q-value per action."""
        model = Sequential(
            Linear(self.state_size, self.hidden_size, dtype=torch.float64),
            ReLU(),
            Linear(self.hidden_size, self.num_actions, dtype=torch.float64)
        )

        return model.to(self.device)

    def reset(self):
        """Reset the bookkeeping of the rule-based policy for a new episode."""
        self.rule_current_slot_index = 0
        self.rule_phase = 'not done'

    def get_action(self, state, use_rule=False):
        """Select an action epsilon-greedily.

        Parameters:
            state: state vector (numpy array of length ``state_size``).
            use_rule: when True the non-random branch uses the rule-based policy
                instead of the network.

        Returns:
            Tuple ``(index, action)`` with the action's index in
            ``possible_actions`` and the action dict itself.
        """
        if self.eps > random.random():
            index = random.randint(0, self.num_actions - 1)
            action = self._map_index_to_action(index)
            return index, action
        if use_rule:
            return self._rule_action()
        return self._dqn_action(state)

    def _rule_action(self):
        """Return the next action of the rule-based policy.

        The policy requests every slot in ``rule_request_set`` in order, then
        announces a match, then closes the conversation.
        """
        if self.rule_current_slot_index < len(self.rule_request_set):
            slot = self.rule_request_set[self.rule_current_slot_index]
            self.rule_current_slot_index += 1
            rule_response = {'intent': 'request', 'inform_slots': {}, 'request_slots': {slot: 'UNK'}}
        elif self.rule_phase == 'not done':
            rule_response = {'intent': 'match_found', 'inform_slots': {}, 'request_slots': {}}
            self.rule_phase = 'done'
        elif self.rule_phase == 'done':
            rule_response = {'intent': 'done', 'inform_slots': {}, 'request_slots': {}}
        else:
            raise Exception('Should not have reached this clause')

        index = self._map_action_to_index(rule_response)
        return index, rule_response

    def _map_action_to_index(self, response):
        """Return the index of ``response`` in ``possible_actions``.

        Raises:
            ValueError: if the action is not one of the possible actions.
        """
        for (i, action) in enumerate(self.possible_actions):
            if response == action:
                return i
        raise ValueError('Response: {} not found in possible actions'.format(response))

    def _dqn_action(self, state):
        """Return ``(index, action)`` of the action with the highest predicted Q-value."""
        index = np.argmax(self._dqn_predict_one(state))
        action = self._map_index_to_action(index)
        return index, action

    def _map_index_to_action(self, index):
        """Return a deep copy of the action stored at ``index``.

        Raises:
            ValueError: if ``index`` is outside ``[0, len(possible_actions))``.
        """
        # Explicit bounds check: negative indices must not wrap around.
        if not 0 <= index < len(self.possible_actions):
            raise ValueError('Index: {} not in range of possible actions'.format(index))
        return copy.deepcopy(self.possible_actions[index])

    def _dqn_predict_one(self, state, target=False):
        """Predict the Q-values of a single state as a flat numpy array."""
        return self._dqn_predict(state.reshape(1, self.state_size), target=target).flatten()

    def _dqn_predict(self, states, target=False):
        """Return the predicted Q-values for a batch of states.

        Parameters:
            states: array-like of shape ``(batch, state_size)``.
            target: use the target network instead of the behaviour network.

        Returns:
            numpy array of shape ``(batch, num_actions)``.
        """
        model = self.tar_model if target else self.beh_model
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            q_values = model(torch.tensor(states, dtype=torch.float64, device=self.device))
        return q_values.cpu().numpy()

    def add_experience(self, state, action, reward, next_state, done):
        """Store a transition, overwriting the oldest one once the memory is full."""
        if len(self.memory) < self.max_memory_size:
            self.memory.append(None)
        self.memory[self.memory_index] = (state, action, reward, next_state, done)
        self.memory_index = (self.memory_index + 1) % self.max_memory_size

    def empty_memory(self):
        """Drop every stored transition."""
        self.memory = []
        self.memory_index = 0

    def is_memory_full(self):
        """Return True when the replay memory has reached its capacity."""
        return len(self.memory) == self.max_memory_size

    def train(self):
        """Run one training pass over the replay memory.

        ``len(memory) // batch_size`` batches are sampled; for each one the
        Q-learning targets are computed (plain DQN when ``vanilla`` is True,
        double DQN otherwise) and the behaviour network is fitted to them.
        """
        # Upper bound on the number of rows per optimizer step.
        minibatch_size = 32

        num_batches = len(self.memory) // self.batch_size
        for _ in range(num_batches):
            batch = random.sample(self.memory, self.batch_size)

            states = np.array([sample[0] for sample in batch])
            next_states = np.array([sample[3] for sample in batch])

            assert states.shape == (self.batch_size, self.state_size), 'States Shape: {}'.format(states.shape)
            assert next_states.shape == states.shape

            # Start from the current predictions so that only the entry of the
            # action actually taken contributes to the loss.
            targets = self._dqn_predict(states)
            if not self.vanilla:
                beh_next_states_preds = self._dqn_predict(next_states)
            tar_next_state_preds = self._dqn_predict(next_states, target=True)

            for i, (_, a, r, _, d) in enumerate(batch):
                if self.vanilla:
                    next_q = np.amax(tar_next_state_preds[i])
                else:
                    # Double DQN: the behaviour net picks the action, the target net rates it.
                    next_q = tar_next_state_preds[i][np.argmax(beh_next_states_preds[i])]
                targets[i][a] = r + self.gamma * next_q * (not d)

            self.beh_model.train()
            for ix in range(0, len(states), minibatch_size):
                batch_inputs = torch.tensor(states[ix:ix + minibatch_size], dtype=torch.float64,
                                            device=self.device)
                batch_target = torch.tensor(targets[ix:ix + minibatch_size], dtype=torch.float64,
                                            device=self.device)

                preds = self.beh_model(batch_inputs)
                loss = self.criterion(preds, batch_target)

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

    def copy(self):
        """Copy the behaviour network's weights into the target network."""
        self.tar_model.load_state_dict(self.beh_model.state_dict())

    @staticmethod
    def _weights_path(base_path, tag):
        """Derive the per-network file name: ``x.h5`` -> ``x_<tag>.h5``.

        When ``base_path`` contains no ``.h5`` the tag is appended instead so
        that the two networks never share a file.
        """
        tagged, replaced = re.subn(r'\.h5', '_{}.h5'.format(tag), base_path)
        return tagged if replaced else '{}_{}'.format(base_path, tag)

    def save_weights(self):
        """Save both networks' state dicts; does nothing when no save path is configured."""
        if not self.save_weights_file_path:
            return
        # torch modules have no Keras-style save_weights(); persist the state dicts.
        torch.save(self.beh_model.state_dict(), self._weights_path(self.save_weights_file_path, 'beh'))
        torch.save(self.tar_model.state_dict(), self._weights_path(self.save_weights_file_path, 'tar'))

    def _load_weights(self):
        """Load both networks' state dicts; does nothing when no load path is configured."""
        if not self.load_weights_file_path:
            return
        beh_load_file_path = self._weights_path(self.load_weights_file_path, 'beh')
        self.beh_model.load_state_dict(torch.load(beh_load_file_path, map_location=self.device))
        tar_load_file_path = self._weights_path(self.load_weights_file_path, 'tar')
        self.tar_model.load_state_dict(torch.load(tar_load_file_path, map_location=self.device))


In [20]:
import random


class ErrorModelController:
    """Injects random noise into user actions to imitate understanding errors."""

    def __init__(self, db_dict, constants):
        """Read the error probabilities and mode from ``constants['emc']``.

        Parameters:
            db_dict: mapping slot name -> list of possible values for that slot.
            constants: configuration dict holding the 'emc' section.
        """
        emc_settings = constants['emc']
        self.movie_dict = db_dict
        self.slot_error_prob = emc_settings['slot_error_prob']
        self.slot_error_mode = emc_settings['slot_error_mode']  # [0, 3]
        self.intent_error_prob = emc_settings['intent_error_prob']
        self.intents = usersim_intents

    def infuse_error(self, frame):
        """Corrupt ``frame`` in place.

        Each inform slot is corrupted with probability ``slot_error_prob`` using
        the strategy selected by ``slot_error_mode`` (0: replace value, 1: replace
        slot, 2: drop slot, anything else: one of the three at random). The
        intent is then replaced with probability ``intent_error_prob``.
        """
        slots = frame['inform_slots']
        # Iterate over a snapshot because the corruption helpers add/remove keys.
        for slot_name in tuple(slots.keys()):
            assert slot_name in self.movie_dict
            if random.random() >= self.slot_error_prob:
                continue

            mode = self.slot_error_mode
            if mode == 0:
                corrupt = self._slot_value_noise
            elif mode == 1:
                corrupt = self._slot_noise
            elif mode == 2:
                corrupt = self._slot_remove
            else:
                draw = random.random()
                if draw <= 0.33:
                    corrupt = self._slot_value_noise
                elif draw <= 0.66:
                    corrupt = self._slot_noise
                else:
                    corrupt = self._slot_remove
            corrupt(slot_name, slots)

        if random.random() < self.intent_error_prob:
            frame['intent'] = random.choice(self.intents)

    def _slot_value_noise(self, key, informs_dict):
        """Replace the value of ``key`` with a random value known for that slot."""
        candidates = self.movie_dict[key]
        informs_dict[key] = random.choice(candidates)

    def _slot_noise(self, key, informs_dict):
        """Drop ``key`` and insert a random slot with a random value instead."""
        del informs_dict[key]
        replacement = random.choice(list(self.movie_dict.keys()))
        informs_dict[replacement] = random.choice(self.movie_dict[replacement])

    def _slot_remove(self, key, informs_dict):
        """Drop ``key`` from the informs."""
        del informs_dict[key]


In [9]:
class DBQuery:
    """Answers constraint queries against the in-memory database, with result caching."""

    def __init__(self, database):
        """Keep a reference to ``database`` (id -> {slot: value}) and create empty caches."""
        self.database = database
        # Both caches are keyed by a frozenset of (slot, value) constraint pairs.
        self.cached_db_slot = defaultdict(dict)
        self.cached_db = defaultdict(dict)
        self.no_query = no_query_keys
        self.match_key = usersim_default_key

    def fill_inform_slot(self, inform_slot_to_fill, current_inform_slots):
        """Pick a value for the single slot the agent wants to inform.

        The value chosen is the most frequent one among the database entries
        matching the current constraints (the slot itself excluded).

        Parameters:
            inform_slot_to_fill: dict with exactly one key, the slot to fill.
            current_inform_slots: constraints collected so far.

        Returns:
            ``{slot: value}``; the value is 'no match available' when no matching
            entry provides the slot.
        """
        assert len(inform_slot_to_fill) == 1

        key = list(inform_slot_to_fill.keys())[0]

        # Query without the slot being filled so that it does not constrain itself.
        current_informs = copy.deepcopy(current_inform_slots)
        current_informs.pop(key, None)

        db_results = self.get_db_results(current_informs)

        filled_inform = {}
        values_dict = self._count_slot_values(key, db_results)
        if values_dict:
            filled_inform[key] = max(values_dict, key=values_dict.get)
        else:
            filled_inform[key] = 'no match available'

        return filled_inform

    def _count_slot_values(self, key, db_subdict):
        """Count how often each value of ``key`` occurs among the entries of ``db_subdict``."""
        slot_values = defaultdict(int)  # init to 0
        for option in db_subdict.values():
            if key in option:
                slot_values[option[key]] += 1
        return slot_values

    def get_db_results(self, constraints):
        """Return the database entries satisfying every constraint.

        Constraints on non-queryable keys and constraints whose value is
        'anything' are ignored. Values are compared case-insensitively as strings.

        Parameters:
            constraints: dict slot -> required value.

        Returns:
            dict id -> entry for all matches; an empty dict when nothing matches.
        """
        new_constraints = {k: v for k, v in constraints.items() if k not in self.no_query and v != 'anything'}

        inform_items = frozenset(new_constraints.items())
        cache_return = self.cached_db[inform_items]

        # None marks a query already known to have no match.
        if cache_return is None:
            return {}
        if cache_return:
            return cache_return

        available_options = {}
        for entry_id, option in self.database.items():
            # An entry lacking one of the constrained slots cannot match.
            if not new_constraints.keys() <= option.keys():
                continue
            # all() stops at the first mismatching slot.
            if all(str(v).lower() == str(option[k]).lower() for k, v in new_constraints.items()):
                available_options[entry_id] = option

        if available_options:
            self.cached_db[inform_items].update(available_options)
        else:
            self.cached_db[inform_items] = None

        return available_options

    def get_db_results_for_slots(self, current_informs):
        """Count, per constraint, how many database entries satisfy it.

        Values are compared case-insensitively as strings, consistently with
        ``get_db_results``, so non-string database values are supported.

        Parameters:
            current_informs: dict slot -> value of the current constraints.

        Returns:
            dict with one count per slot of ``current_informs`` plus
            'matching_all_constraints', the number of entries satisfying all of them.
        """
        inform_items = frozenset(current_informs.items())
        cache_return = self.cached_db_slot[inform_items]

        if cache_return:
            return cache_return

        db_results = {key: 0 for key in current_informs.keys()}
        db_results['matching_all_constraints'] = 0

        for option in self.database.values():
            all_slots_match = True
            for CI_key, CI_value in current_informs.items():
                if CI_key in self.no_query:
                    continue
                # 'anything' is satisfied by every entry.
                if CI_value == 'anything':
                    db_results[CI_key] += 1
                    continue
                if CI_key in option and str(CI_value).lower() == str(option[CI_key]).lower():
                    db_results[CI_key] += 1
                else:
                    all_slots_match = False
            if all_slots_match:
                db_results['matching_all_constraints'] += 1

        self.cached_db_slot[inform_items].update(db_results)
        assert self.cached_db_slot[inform_items] == db_results
        return db_results


In [10]:
class StateTracker:
    """Keeps the dialogue history and turns it into the state vector fed to the agent."""

    def __init__(self, database, constants):
        """Create the database helper and the intent/slot index tables.

        Parameters:
            database: database passed on to ``DBQuery``.
            constants: configuration dict; ``constants['run']['max_round_num']`` is read.
        """
        self.db_helper = DBQuery(database)
        self.match_key = usersim_default_key
        self.intents_dict = convert_list_to_dict(all_intents)
        self.num_intents = len(all_intents)
        self.slots_dict = convert_list_to_dict(all_slots)
        self.num_slots = len(all_slots)
        self.max_round_num = constants['run']['max_round_num']
        # All-zero state returned once an episode is over.
        self.none_state = np.zeros(self.get_state_size())
        self.reset()

    def get_state_size(self):
        """Return the length of the vector produced by ``get_state``."""
        intent_part = 2 * self.num_intents
        slot_part = 7 * self.num_slots
        return intent_part + slot_part + 3 + self.max_round_num

    def reset(self):
        """Forget the collected informs, the history and the round counter."""
        self.current_informs = {}
        self.history = []
        self.round_num = 0

    def print_history(self):
        """Print every recorded action, oldest first."""
        for recorded_action in self.history:
            print(recorded_action)

    def get_state(self, done=False):
        """Encode the current dialogue situation as a flat numpy vector.

        Parameters:
            done: when True the all-zero ``none_state`` is returned.

        Returns:
            1-D array of length ``get_state_size()`` made of: user intent, user
            inform slots, user request slots, last agent intent, last agent inform
            slots, last agent request slots, slots collected so far, scaled round
            number, one-hot round number, binary and scaled database match counts.
        """
        if done:
            return self.none_state

        def slot_mask(slot_names):
            # Multi-hot encoding of the given slot names.
            mask = np.zeros((self.num_slots,))
            for slot_name in slot_names:
                mask[self.slots_dict[slot_name]] = 1.0
            return mask

        def intent_one_hot(intent):
            encoded = np.zeros((self.num_intents,))
            encoded[self.intents_dict[intent]] = 1.0
            return encoded

        latest_user = self.history[-1]
        db_counts = self.db_helper.get_db_results_for_slots(self.current_informs)
        latest_agent = self.history[-2] if len(self.history) > 1 else None

        # User side.
        user_act_rep = intent_one_hot(latest_user['intent'])
        user_inform_slots_rep = slot_mask(latest_user['inform_slots'].keys())
        user_request_slots_rep = slot_mask(latest_user['request_slots'].keys())

        # Constraints collected so far.
        current_slots_rep = slot_mask(self.current_informs)

        # Agent side (all zeros when the agent has not acted yet).
        if latest_agent:
            agent_act_rep = intent_one_hot(latest_agent['intent'])
            agent_inform_slots_rep = slot_mask(latest_agent['inform_slots'].keys())
            agent_request_slots_rep = slot_mask(latest_agent['request_slots'].keys())
        else:
            agent_act_rep = np.zeros((self.num_intents,))
            agent_inform_slots_rep = np.zeros((self.num_slots,))
            agent_request_slots_rep = np.zeros((self.num_slots,))

        # Round number, scaled and one-hot.
        turn_rep = np.array([self.round_num / 5.])
        turn_onehot_rep = np.zeros((self.max_round_num,))
        turn_onehot_rep[self.round_num - 1] = 1.0

        # Database counts: every position defaults to the "all constraints" figure,
        # positions of constrained slots are overwritten with their own figure.
        total_matches = db_counts['matching_all_constraints']
        kb_count_rep = np.full((self.num_slots + 1,), total_matches / 100.)
        kb_binary_rep = np.full((self.num_slots + 1,), float(total_matches > 0.))
        for slot_name, count in db_counts.items():
            if slot_name not in self.slots_dict:
                continue
            position = self.slots_dict[slot_name]
            kb_count_rep[position] = count / 100.
            kb_binary_rep[position] = float(count > 0.)

        pieces = [
            user_act_rep, user_inform_slots_rep, user_request_slots_rep,
            agent_act_rep, agent_inform_slots_rep, agent_request_slots_rep,
            current_slots_rep, turn_rep, turn_onehot_rep,
            kb_binary_rep, kb_count_rep,
        ]
        return np.hstack(pieces).flatten()

    def update_state_agent(self, agent_action):
        """Fill in the agent action from the database and append it to the history.

        'inform' actions get their placeholder replaced by a database value;
        'match_found' actions receive the first matching entry (or
        'no match available'). ``agent_action`` is modified in place.
        """
        intent = agent_action['intent']
        if intent == 'inform':
            assert agent_action['inform_slots']
            agent_action['inform_slots'] = self.db_helper.fill_inform_slot(
                agent_action['inform_slots'], self.current_informs)
            assert agent_action['inform_slots']
            key, value = next(iter(agent_action['inform_slots'].items()))  # Only one
            assert key != 'match_found'
            assert value != 'PLACEHOLDER', 'KEY: {}'.format(key)
            self.current_informs[key] = value
        elif intent == 'match_found':
            assert not agent_action['inform_slots'], 'Cannot inform and have intent of match found!'
            matches = self.db_helper.get_db_results(self.current_informs)
            if matches:
                match_id, match_entry = next(iter(matches.items()))
                agent_action['inform_slots'] = copy.deepcopy(match_entry)
                agent_action['inform_slots'][self.match_key] = str(match_id)
            else:
                agent_action['inform_slots'][self.match_key] = 'no match available'
            self.current_informs[self.match_key] = agent_action['inform_slots'][self.match_key]

        agent_action.update({'round': self.round_num, 'speaker': 'Agent'})
        self.history.append(agent_action)

    def update_state_user(self, user_action):
        """Record the user's informs, append the action to the history and advance the round."""
        self.current_informs.update(user_action['inform_slots'])
        user_action.update({'round': self.round_num, 'speaker': 'User'})
        self.history.append(user_action)
        self.round_num += 1


In [11]:
# Authenticate with Weights & Biases; prompts interactively for an API key when none is stored.
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [33]:
# Start a W&B run in the "cource" project, recording the `constants` dict as the run configuration.
run = wandb.init(
    project="cource",
    config=constants)


0,1
avg_reward,▁▃▆▆▄▇▃▅▄▇▇▅▅▅█▅▆▅▇█▇▇▇▆▇▇▇▃▆▇▆▇▆▅▆█▆▅▇▆
avg_reward_test,█▁██████████████▁█████▁█▁█████████▁██▁██
success_rate,▁▁▅▆▃▇▃▅▄▇▆▄▅▅█▅▆▅▇█▆▆▆▆▇▇▇▃▆▆▆▆▆▅▆█▆▅▆▆
success_rate_test,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
avg_reward,8.3
avg_reward_test,-2.1
success_rate,0.6
success_rate_test,0.0


In [26]:
import pickle, argparse, json, math


if __name__ == "__main__":
    # Training setup: read the run configuration out of `constants`, load the
    # pickled data files and build the objects used by run_round / warmup_run /
    # train_run as module-level globals.

    file_path_dict = constants['db_file_paths']
    DATABASE_FILE_PATH = file_path_dict['database']
    DICT_FILE_PATH = file_path_dict['dict']
    USER_GOALS_FILE_PATH = file_path_dict['user_goals']

    run_dict = constants['run']
    USE_USERSIM = run_dict['usersim']
    WARMUP_MEM = run_dict['warmup_mem']
    NUM_EP_TRAIN = run_dict['num_ep_run']
    TRAIN_FREQ = run_dict['train_freq']
    MAX_ROUND_NUM = run_dict['max_round_num']
    SUCCESS_RATE_THRESHOLD = run_dict['success_rate_threshold']

    # SECURITY NOTE: pickle.load can execute arbitrary code while unpickling;
    # only point these paths at files you trust.
    # Each file is opened in a `with` block so the handle is closed right after
    # loading instead of being left to the garbage collector.
    with open(DATABASE_FILE_PATH, 'rb') as pickle_file:
        database = pickle.load(pickle_file, encoding='latin1')
    # NOTE(review): the test setup cell calls remove_empty_slots(database)
    # after loading, this cell does not — confirm that difference is intended.

    with open(DICT_FILE_PATH, 'rb') as pickle_file:
        db_dict = pickle.load(pickle_file, encoding='latin1')

    with open(USER_GOALS_FILE_PATH, 'rb') as pickle_file:
        user_goals = pickle.load(pickle_file, encoding='latin1')

    # A simulated user is used when configured; otherwise a `User` is built.
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    dqn_agent = DQNAgent(state_tracker.get_state_size(), constants)


def run_round(state, warmup=False):
    """Play one agent/user exchange and record the resulting transition.

    The agent picks an action for ``state``, the state tracker and the user
    are updated with it, the user's reply is (unless the dialogue just ended)
    passed through the error model controller, and the transition is handed to
    the agent via ``add_experience``.

    Args:
        state: current state as returned by ``state_tracker.get_state``.
        warmup: forwarded to ``dqn_agent.get_action`` as ``use_rule``.

    Returns:
        Tuple ``(next_state, reward, done, success)`` where the last three
        values come straight from ``user.step``.
    """
    action_index, action = dqn_agent.get_action(state, use_rule=warmup)
    state_tracker.update_state_agent(action)

    user_action, reward, done, success = user.step(action)
    # Error is only infused while the dialogue is still running.
    if not done:
        emc.infuse_error(user_action)
    state_tracker.update_state_user(user_action)

    next_state = state_tracker.get_state(done)
    transition = (state, action_index, reward, next_state, done)
    dqn_agent.add_experience(*transition)

    return next_state, reward, done, success


def warmup_run():
    """Run warmup episodes until the step budget or the agent memory is used up.

    Episodes are played with ``run_round(..., warmup=True)``. A new episode is
    started only while fewer than ``WARMUP_MEM`` rounds have been played and
    ``dqn_agent.is_memory_full()`` is false; an episode that is already running
    is always played to its end, so the final step count may exceed
    ``WARMUP_MEM``.

    The progress bar is managed by a ``with`` block so it is closed when the
    warmup ends (or raises) instead of being left open.
    """
    print('Warmup Started...')
    total_step = 0
    with tqdm(total=WARMUP_MEM) as progress_bar:
        while total_step < WARMUP_MEM and not dqn_agent.is_memory_full():
            episode_reset()
            done = False

            state = state_tracker.get_state()
            while not done:
                state, _, done, _ = run_round(state, warmup=True)
                total_step += 1
                progress_bar.update(1)

    print('...Warmup Ended')


def train_run(log_freq=10):
    """Run ``NUM_EP_TRAIN`` training episodes, logging and training periodically.

    Two independent windows are tracked:

    * a logging window of ``log_freq`` episodes, reported to wandb as
      ``success_rate`` / ``avg_reward``;
    * a training period of ``TRAIN_FREQ`` episodes, whose success rate decides
      whether the agent memory is emptied and whether weights are saved,
      followed by ``dqn_agent.copy()`` and ``dqn_agent.train()``.

    The windows use separate accumulators. Previously the logging step zeroed
    the same counters the training step reads, so whenever ``TRAIN_FREQ`` was a
    multiple of the logging window the period success rate was always 0 and
    the save / flush branches could never fire.

    Args:
        log_freq: number of episodes per wandb logging window (default 10,
            the value that used to be hard-coded).
    """
    print('Training Started...')
    episode = 0
    # Accumulators for the TRAIN_FREQ period and for the wandb logging window.
    period_success_total = 0
    log_reward_total = 0
    log_success_total = 0
    success_rate_best = 0.0
    with tqdm(total=NUM_EP_TRAIN) as progress_bar:
        while episode < NUM_EP_TRAIN:
            episode_reset()
            episode += 1
            progress_bar.update(1)
            done = False
            state = state_tracker.get_state()
            while not done:
                next_state, reward, done, success = run_round(state)
                log_reward_total += reward
                state = next_state

            # `success` is the value reported by the episode's final round.
            period_success_total += success
            log_success_total += success

            if episode % log_freq == 0:
                wandb.log({"success_rate": log_success_total / log_freq,
                           'avg_reward': log_reward_total / log_freq})
                log_success_total = 0
                log_reward_total = 0

            if episode % TRAIN_FREQ == 0:
                success_rate = period_success_total / TRAIN_FREQ
                # Empty the memory when the agent is at least as good as its
                # best so far and clears the configured threshold.
                if success_rate >= success_rate_best and success_rate >= SUCCESS_RATE_THRESHOLD:
                    dqn_agent.empty_memory()

                # Save weights only on a strictly better success rate.
                if success_rate > success_rate_best:
                    success_rate_best = success_rate
                    dqn_agent.save_weights()
                period_success_total = 0
                dqn_agent.copy()
                dqn_agent.train()
    # Successes accumulated since the last completed training period.
    print(period_success_total)
    print('...Training Ended')


def episode_reset():
    """Prepare the global state tracker, user and agent for a new episode.

    Order of operations: the state tracker is reset, the user is reset and
    yields its opening action, that action is passed through the error model
    controller and recorded by the state tracker, and finally the agent is
    reset.
    """
    state_tracker.reset()

    # The opening user action goes through the error model before the tracker
    # sees it.
    initial_user_action = user.reset()
    emc.infuse_error(initial_user_action)
    state_tracker.update_state_user(initial_user_action)

    dqn_agent.reset()

    

# Collect warmup experience first, then run the training loop. Both functions
# operate on the module-level objects built in the setup block of this cell.
warmup_run()
train_run()


Warmup Started...


  0%|          | 0/1000 [00:00<?, ?it/s]

...Warmup Ended
Training Started...


  0%|          | 0/5000 [00:00<?, ?it/s]

0
...Training Ended


In [31]:
# Configuration shared by the setup cells and passed to UserSimulator/User,
# ErrorModelController, StateTracker and DQNAgent.
# NOTE(review): this cell is In [31] while the training cell above is In [26],
# so it was executed after training; cells that read `constants` must be
# re-run after changing it.
constants = {
  # Pickle files loaded by the setup cells (database, slot dictionary, user goals).
  "db_file_paths": {
    "database": "/kaggle/input/test-cam/cameras_db_test.pkl",
    "dict": "/kaggle/input/testt-cam/cameras_dict_test.pkl",
    "user_goals": "/kaggle/input/test-cam/camera_user_goals_test.pkl"
  },
  # Run-level settings read by the setup cells.
  "run": {
    "usersim": True,  # True -> UserSimulator, False -> User
    "warmup_mem": 1000 ,  # warmup_run starts new episodes until this many rounds were played
    "num_ep_run": 5000,  # number of episodes for both train_run and test_run
    "train_freq": 100,  # train_run copies and trains the agent every this many episodes
    "max_round_num": 20,  # read into MAX_ROUND_NUM; its consumer is not visible in these cells
    "success_rate_threshold": 0.3  # minimum period success rate before train_run empties the memory
  },
  # Settings consumed inside DQNAgent (not visible here) — presumably the usual
  # DQN hyper-parameters; verify against the class.
  "agent": {
    "save_weights_file_path": "",
    "load_weights_file_path": "",
    "vanilla": True,
    "learning_rate": 1e-3,
    "batch_size": 16,
    "dqn_hidden_size": 80,
    "epsilon_init": 0.0,
    "gamma": 0.9,
    "max_mem_size": 500_000
  },
  # Settings presumably consumed by ErrorModelController (it receives the whole
  # dict) — TODO confirm the meaning of each mode/probability there.
  "emc": {
    "slot_error_mode": 0,
    "slot_error_prob": 0.05,
    "intent_error_prob": 0.0
  }
}

In [32]:


if __name__ == "__main__":
    # Test setup: read the run configuration out of `constants`, load the
    # pickled data files and rebuild the module-level objects used by
    # test_run / episode_reset.

    # parse_known_args tolerates the extra arguments a notebook kernel is
    # started with; the parsed values are kept in `params`.
    parser = argparse.ArgumentParser()
    parser.add_argument('--constants_path', dest='constants_path', type=str, default='')
    args, unknown = parser.parse_known_args()
    params = vars(args)

    file_path_dict = constants['db_file_paths']
    DATABASE_FILE_PATH = file_path_dict['database']
    DICT_FILE_PATH = file_path_dict['dict']
    USER_GOALS_FILE_PATH = file_path_dict['user_goals']

    run_dict = constants['run']
    USE_USERSIM = run_dict['usersim']
    NUM_EP_TEST = run_dict['num_ep_run']
    MAX_ROUND_NUM = run_dict['max_round_num']

    # SECURITY NOTE: pickle.load can execute arbitrary code while unpickling;
    # only point these paths at files you trust.
    # Each file is opened in a `with` block so the handle is closed right after
    # loading instead of being left to the garbage collector.
    with open(DATABASE_FILE_PATH, 'rb') as pickle_file:
        database = pickle.load(pickle_file, encoding='latin1')

    remove_empty_slots(database)

    with open(DICT_FILE_PATH, 'rb') as pickle_file:
        db_dict = pickle.load(pickle_file, encoding='latin1')

    with open(USER_GOALS_FILE_PATH, 'rb') as pickle_file:
        user_goals = pickle.load(pickle_file, encoding='latin1')

    # A simulated user is used when configured; otherwise a `User` is built.
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    # NOTE(review): this rebinds `dqn_agent` to a newly constructed DQNAgent,
    # replacing any agent trained earlier in the kernel, and
    # "load_weights_file_path" in `constants` is "". Confirm that trained
    # weights are actually restored before test_run is evaluated.
    dqn_agent = DQNAgent(state_tracker.get_state_size(), constants)


def test_run():
   

    print('Testing Started...')
    episode = 0
    while episode < NUM_EP_TEST:
        episode_reset()
        episode += 1
        ep_reward = 0
        
        success_total=0
        
        done = False
        state = state_tracker.get_state()
        while not done:
       
            agent_action_index, agent_action = dqn_agent.get_action(state)
    
            state_tracker.update_state_agent(agent_action)
   
            user_action, reward, done, success = user.step(agent_action)
            ep_reward += reward
            success_total+=success
            
            if episode % 10 == 0:
                wandb.log({"success_rate_test": success_total / 10, 'avg_reward_test': ep_reward / 10})
                ep_reward = 0
            if not done:
                emc.infuse_error(user_action)
            state_tracker.update_state_user(user_action)
            state = state_tracker.get_state(done)
#         print('Episode: {} Success: {} Reward: {}'.format(episode, success, ep_reward))
#         print('Действие пользователя:{}'.format(user_action))
#         print('Действие агента:{}'.format(agent_action))
        
    print('...Testing Ended')


def episode_reset():
    """Prepare the global state tracker, user and agent for a new episode.

    Order of operations: the state tracker is reset, the user is reset and
    yields its opening action, that action is passed through the error model
    controller and recorded by the state tracker, and finally the agent is
    reset.

    NOTE(review): a function with this name and the same statements is also
    defined in the training cell; this later definition shadows it.
    """
    state_tracker.reset()

    # The opening user action goes through the error model before the tracker
    # sees it.
    initial_user_action = user.reset()
    emc.infuse_error(initial_user_action)
    state_tracker.update_state_user(initial_user_action)

    dqn_agent.reset()


# Evaluate using the module-level objects built in the setup block of this cell.
test_run()


Testing Started...
...Testing Ended
