In [1]:
import pandas as pd

In [2]:
class QLearner:
    def __init__(self,
                 state_space_parameters,
                 epsilon,
                 WeightInitializer=None,
                 device=None,
                 args=None,
                 save_path=None,
                 state=None,
                 qstore=None,
                 replaydict=None,
                 replay_dictionary=None):
        """Set up the Q-learning agent and its replay table.

        Args:
            state_space_parameters: Forwarded to ``se.StateEnumerator`` and
                ``StateStringUtils``; also stored on the instance.
            epsilon: Exploration probability used by ``transition_q_learning``.
            WeightInitializer: Stored as-is on the instance.
            device: Stored as-is; compared against ``torch.device('cuda')``
                to construct ``GPUMem``.
            args: Options object. ``args.continue_ite`` is always read, and
                ``args.patch_size`` is read when no ``state`` is supplied, so
                in practice this must not be ``None``.
            save_path: Stored as-is on the instance.
            state: Initial state; when falsy a fresh ``'start'`` state is built.
            qstore: When not ``None``, passed to ``QValues.load_q_values`` and
                the replay table is read from ``replaydict`` instead of
                ``replay_dictionary``.
            replaydict: CSV source handed to ``pd.read_csv`` (first column is
                the index). Only used when ``qstore`` is given.
            replay_dictionary: Replay table used when ``qstore`` is ``None``.
                Defaults to a new empty DataFrame with the columns ``net``,
                ``spp_size``, ``reward``, ``epsilon`` and ``train_flag``.
        """
        self.state_list = []
        self.state_space_parameters = state_space_parameters
        self.args = args
        self.enum = se.StateEnumerator(state_space_parameters, args)
        self.stringutils = StateStringUtils(state_space_parameters, args)
        self.state = state if state else se.State(
            'start', 0, 1, 0, 0, args.patch_size, 0, 0)

        self.qstore = QValues()
        if qstore is not None:
            # Resuming: restore the Q-table and the replay table saved with it.
            self.qstore.load_q_values(qstore)
            self.replay_dictionary = pd.read_csv(replaydict, index_col=0)
        elif replay_dictionary is not None:
            self.replay_dictionary = replay_dictionary
        else:
            # Build the empty table per instance; a DataFrame in the signature
            # would be created once and shared by every QLearner.
            self.replay_dictionary = pd.DataFrame(columns=['net',
                                                           'spp_size',
                                                           'reward',
                                                           'epsilon',
                                                           'train_flag'])

        self.epsilon = epsilon
        self.WeightInitializer = WeightInitializer
        self.device = device
        self.gpu_mem_0 = GPUMem(torch.device('cuda') == self.device)
        self.save_path = save_path

        self.count = args.continue_ite - 1
        # 137 (hard-coded no. for epsilon < 1)

    def generate_net(self, epsilon=None, dataset=None):
        if epsilon != None:
            self.epsilon = epsilon
        self.reset_for_new_walk()
        state_list = self.run_agent()

        net_string = self.stringutils.state_list_to_string(
            state_list, num_classes=len(dataset.val_loader.dataset.class_to_idx))

        train_flag = True
        if net_string in self.replay_dictionary['net'].values:
            spp_size = self.replay_dictionary[self.replay_dictionary['net']
                                              == net_string['spp_size'].values[0]]

    def reset_for_new_walk(self):
        """Clear the recorded walk and put the agent back on a 'start' state."""
        self.state_list = list()
        start_fields = ('start', 0, 1, 0, 0, self.args.patch_size, 0, 0)
        self.state = se.State(*start_fields)

    def run_agent(self):
        while self.state.terminate == 0:
            self.transition_q_learning()
        return self.state_list

    def transition_q_learning(self):
        """Advance the agent by one epsilon-greedy step.

        If the current state has no entry in the Q-table it is enumerated
        first. With probability ``self.epsilon`` a uniformly random action is
        taken; otherwise one of the actions with the highest utility is chosen,
        with ties broken uniformly at random. ``self.state`` is replaced by the
        resulting state and the step is recorded via
        ``post_transition_updates``.
        """
        if self.state.as_tuple() not in self.qstore.q:
            self.enum.enumerate_state(self.state, self.qstore.q)
        action_values = self.qstore.q[self.state.as_tuple()]
        actions = action_values['actions']

        if np.random.random() < self.epsilon:
            # Explore: any available action.
            chosen = actions[np.random.randint(len(actions))]
        else:
            # Exploit: pick among all actions that share the best utility.
            utilities = action_values['utilities']
            best_utility = max(utilities)
            best_actions = [candidate
                            for candidate, utility in zip(actions, utilities)
                            if utility == best_utility]
            chosen = best_actions[np.random.randint(len(best_actions))]
        action = se.State(state_list=chosen)

        self.state = self.enum.state_action_transition(self.state, action)
        # Call the method by the name it is defined under; a double-underscore
        # call is name-mangled inside the class and does not resolve.
        self.post_transition_updates()

    def post_transition_updates(self):
        non_bucketed_state = self.state.copy()
        self.state_list.append(non_bucketed_state)

    def sample_replay_for_update(self):
        """Update Q-values from the newest replay entry and random past ones.

        The last row of ``self.replay_dictionary`` is always replayed. After
        that, ``state_space_parameters.replay_number - 1`` nets are drawn with
        ``np.random.choice`` from the ``net`` column and replayed using the
        reward of the first row matching the drawn net.

        The stored ``reward`` is divided by 100 before being passed to
        ``accuracy_to_reward`` (presumably it is a percentage; confirm against
        whatever writes the replay table).

        The ``train_flag`` column is not consulted: every sampled net triggers
        an update.

        Raises:
            IndexError: If the replay table is empty.
        """
        def replay_update(net, reward_best_val):
            # Rebuild the state sequence from the net string and propagate the
            # reward along it.
            state_list = self.stringutils.convert_model_string_to_states(
                cnn_parse('net', net))
            self.update_q_value_sequence(
                state_list, self.accuracy_to_reward(reward_best_val / 100.))

        latest = self.replay_dictionary.iloc[-1]
        replay_update(latest['net'], latest['reward'])

        for _ in range(self.state_space_parameters.replay_number - 1):
            net = np.random.choice(self.replay_dictionary['net'])
            matches = self.replay_dictionary[self.replay_dictionary['net'] == net]
            replay_update(net, matches['reward'].values[0])

    def accuracy_to_reward(self, acc):
        return acc

    def update_q_value_sequence(self, states, termination_reward):
        self.update_q_value(states[-2], states[-1], termination_reward)
        for i in reversed(range(len(states) - 2)):
            
            # TODO: q-learning update (set proper q-learning rate in cmdparser.py)
            self.update_q_value(states[i], states[i+1], 0)

            # TODO: modified update for shorter search schedules (doesn't use q-learning rate in computation)
            # self.__update_q_value(states[i], states[i+1], termination_reward)