# Utils

In [1]:
from scipy.signal import lfilter


def update_target_graph(from_scope, to_scope):
    """
    Build the assign ops that copy one scope's trainable variables
    into another scope's trainable variables.

    Used to set worker network parameters to those of the global network.

    The two variable collections are paired positionally with zip, so
    if one collection is longer, its trailing variables are ignored.

    Returns a list with one `assign` op per variable pair; nothing is
    copied until those ops are run.
    """
    trainable_key = tf.GraphKeys.TRAINABLE_VARIABLES
    sources = tf.get_collection(trainable_key, from_scope)
    targets = tf.get_collection(trainable_key, to_scope)

    return [target.assign(source)
            for source, target in zip(sources, targets)]


def discount(x, gamma):
    """
    Compute discounted cumulative sums along axis 0.

    Element t of the result is x[t] + gamma*x[t+1] + gamma**2*x[t+2] + ...
    The recursion y[k] = x[k] + gamma*y[k-1] is run by `lfilter` on the
    time-reversed input, and the output is reversed back.
    """
    backwards = x[::-1]
    # IIR filter with numerator [1] and denominator [1, -gamma]
    accumulated = lfilter([1], [1, -gamma], backwards, axis=0)
    return accumulated[::-1]


def normalized_columns_initializer(std=1.0):
    """
    Return a weight initializer whose columns all have L2 norm `std`.

    Used to initialize weights for policy and value output layers.
    """
    def _initializer(shape, dtype=None, partition_info=None):
        # dtype and partition_info are accepted but not used here.
        weights = np.random.randn(*shape).astype(np.float32)
        # norm of every column (axis 0), kept 2-D so it broadcasts
        column_norms = np.sqrt(np.square(weights).sum(axis=0, keepdims=True))
        weights *= std / column_norms
        return tf.constant(weights)

    return _initializer

def look_for_folder(main_folder='priors/', exp=''):
    """
    Look for a folder named `exp` anywhere under `main_folder`.

    Parameters
    ----------
    main_folder : str
        Root of the directory tree to walk.
    exp : str
        Exact name of the folder to find (last path component only).

    Returns
    -------
    str
        Path of the first matching folder found by `os.walk`, or '' if
        no folder has that name. In the latter case a message and the
        closest folder name (if any) are printed.
    """
    # Imported locally so the function works regardless of which other
    # cells/modules have already been executed.
    import difflib
    import os

    data_path = ''
    possibilities = []
    for root, _dirs, _files in os.walk(main_folder):
        # basename handles whatever separator os.walk produced
        folder_name = os.path.basename(root)
        possibilities.append(folder_name)
        if folder_name == exp:
            data_path = root
            break

    if data_path == '':
        # cutoff=0. means the best candidate is reported however poor
        candidates = difflib.get_close_matches(exp, possibilities,
                                               n=1, cutoff=0.)
        print(exp + ' NOT FOUND IN ' + main_folder)
        if len(candidates) > 0:
            print('possible candidates:')
            print(candidates)

    return data_path


def list_str(l):
    """
    Join the elements of a sequence into one '_'-separated string.

    Each element is converted with `str`. An empty sequence gives ''
    (indexing the first element used to raise IndexError).
    """
    return '_'.join(str(element) for element in l)


def num2str(num):
    """
    Express a number in thousands, e.g. 5000 -> '5K'.

    The quotient is truncated towards zero by `int`.
    """
    thousands = int(num / 1000)
    return '{}K'.format(thousands)


# Network

In [2]:
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

def RNN_UGRU(inputs, prev_rewards, a_size, num_units):
    """
    Build a recurrent core made of a single UGRNN cell (ReLU activation)
    followed by the policy/value heads created by `process_output`.

    `inputs` gets a leading axis added, so the whole input is fed to
    `dynamic_rnn` as one batch element; the sequence length is read from
    the first dimension of `prev_rewards`.

    Returns, in this order: the all-zeros initial state (numpy), the
    initial-state placeholder, the final-state tensor, and the four
    values returned by `process_output` (actions placeholder, one-hot
    actions, policy, value).
    """
    cell = tf.contrib.rnn.UGRNNCell(num_units, activation=tf.nn.relu)
    state_width = cell.state_size

    # all-zeros state fed to the network when training or acting
    st_init = np.zeros((1, state_width), np.float32)

    # placeholder through which the caller provides the initial state
    state_in = tf.placeholder(tf.float32, [1, state_width])

    # add a leading (batch) axis to the inputs
    batched_inputs = tf.expand_dims(inputs, [0])

    seq_len = tf.shape(prev_rewards)[:1]

    # final state: [batch_size, cell_state_size]
    # outputs:     [batch_size, max_time, cell_state_size]
    rnn_outputs, state_out = tf.nn.dynamic_rnn(
        cell, batched_inputs,
        initial_state=state_in,
        sequence_length=seq_len,
        dtype=tf.float32,
        time_major=False)

    # one row per time step
    flat_outputs = tf.reshape(rnn_outputs, [-1, num_units])

    heads = process_output(flat_outputs, rnn_outputs, a_size, num_units)
    actions, actions_onehot, policy, value = heads

    return st_init, state_in, state_out, actions, actions_onehot, policy, value

def process_output(rnn_out, outputs, a_size, num_units):
    """
    Create the action placeholder and the policy and value output layers.

    The policy head is a bias-free fully connected layer with softmax
    over `a_size` units; the value head is a bias-free linear layer with
    one unit. Both read `rnn_out`. `outputs` and `num_units` are accepted
    but not used in this function.

    Returns (actions placeholder, one-hot actions, policy, value).
    """
    # Actions taken, provided by the caller
    actions = tf.placeholder(dtype=tf.int32, shape=[None])
    actions_onehot = tf.one_hot(actions, a_size, dtype=tf.float32)

    # Policy head (created first, then the value head)
    policy_init = normalized_columns_initializer(0.01)
    policy = slim.fully_connected(rnn_out, a_size,
                                  activation_fn=tf.nn.softmax,
                                  weights_initializer=policy_init,
                                  biases_initializer=None)

    # Value head
    value_init = normalized_columns_initializer(1.0)
    value = slim.fully_connected(rnn_out, 1,
                                 activation_fn=None,
                                 weights_initializer=value_init,
                                 biases_initializer=None)

    return actions, actions_onehot, policy, value

  from ._conv import register_converters as _register_converters


# Environment

## Data management

In [3]:

class data():
    """
    Container for point-by-point (per time step) records that are kept
    for some trials and periodically written to an .npz file.
    """

    # names of the per-step record lists held by every instance
    _POINT_FIELDS = ('states_point', 'net_state_point', 'rewards_point',
                     'done_point', 'actions_point', 'corrects_point',
                     'new_trial_point', 'trials_point', 'stims_conf_point')

    def __init__(self, folder=''):
        # start with every record list empty
        self._clear()
        # where to save the trials data
        self.folder = folder

    def _clear(self):
        """Bind a fresh empty list to every record attribute."""
        for field_name in self._POINT_FIELDS:
            setattr(self, field_name, [])

    def reset(self):
        """
        Empty all the record lists (the folder is left untouched).
        """
        self._clear()

    def update(self, new_state=[], net_state=[], reward=None, update_net=None,
               action=None, correct=[], new_trial=None, num_trials=None,
               stim_conf=[]):
        """
        Append every piece of information that was actually provided.

        Sequence-like arguments are stored when they are non-empty;
        the other arguments are stored when they are not None.
        """
        # (destination list, value, True if presence is tested with len)
        entries = (
            (self.states_point, new_state, True),
            (self.net_state_point, net_state, True),
            (self.rewards_point, reward, False),
            (self.done_point, update_net, False),
            (self.actions_point, action, False),
            (self.corrects_point, correct, True),
            (self.new_trial_point, new_trial, False),
            (self.trials_point, num_trials, False),
            (self.stims_conf_point, stim_conf, True),
        )
        for store, item, sized in entries:
            if sized:
                provided = len(item) != 0
            else:
                provided = item is not None
            if provided:
                store.append(item)

    def save(self, num_trials):
        """
        Write all record lists to
        <folder>/all_points_<num_trials>.npz
        """
        payload = {'states': self.states_point,
                   'net_state': self.net_state_point,
                   'rewards': self.rewards_point,
                   'done_flags': self.done_point,
                   'actions': self.actions_point,
                   'corrects': self.corrects_point,
                   'new_trial_flags': self.new_trial_point,
                   'trials_saved': self.trials_point,
                   'stims_conf': self.stims_conf_point}
        target = self.folder + '/all_points_' + str(num_trials) + '.npz'
        np.savez(target, **payload)


## Task

In [4]:
class expectations():
    """
    Two-stimulus task in which the positions of the stimuli repeat from
    one trial to the next with a probability that depends on the current
    block (the block type flips every `block_dur` trials).

    The environment is driven with `new_trial` (start a trial and get the
    first observation) and `pullArm` (submit an action, get reward, next
    observation and flags). Per-trial statistics are accumulated in the
    `*_mat` lists and written by `save_trials_data`; per-step records are
    delegated to a `data` instance.
    """
    def __init__(self, update_net_step=100, trial_duration=10,
                 repeating_prob=(0.2, 0.7), rewards=(-0.1, 0.0, 1.0, -1.0),
                 block_dur=200, stim_evidence=0.5, folder=''):
        # every X trial, the RNN network is updated with the samples
        # collected by the agent
        self.update_net_step = update_net_step

        # number of stim values presented. During the presentation of the stim,
        # the net has to 'fixate'
        self.td = trial_duration

        # rewards for:
        # [stop fixating too early, correctly keep fixating,
        # get the stimulus right, get the stimulus wrong]
        self.rewards = rewards

        # this variable describes the distributions from which the
        # presented stimuli are drawn (it might vary from task to task)
        # NOTE(review): the methods below read/write `self.int_st`, not
        # this attribute — confirm whether it is still needed.
        self.internal_state = []

        # trial counter (incremented by new_trial) and number of actions
        self.num_tr = 0
        self.num_actions = 3

        # index (0 or 1) into repeating_prob for the current block
        self.curr_rep_prob = 0

        # number of trials for repeating/alternating blocks
        self.block_dur = block_dur

        # position of the stimuli
        self.stms_pos_new_trial = 0

        # stimulus evidence: the mean of one stimulus is always 1, the mean
        # of the other is drawn from a uniform distrib. U(1-stim_evidence, 1)
        # (see new_trial). The value is floored at 10e-5 (= 1e-4) so that
        # the range is never empty.
        # NOTE(review): an earlier version of this comment said the range
        # was U(stim_ev, 1) and that higher values make the task harder;
        # with the current range a larger stim_evidence allows the two
        # means to differ more — confirm the intended interpretation.
        self.stim_evidence = np.max([stim_evidence, 10e-5])

        # prob. of repeating the stimuli in the positions of previous trial
        # (one value per block type, selected by curr_rep_prob)
        self.repeating_prob = repeating_prob

        # SAVED PARAMETERS AT THE END OF THE TRIAL

        # stimulus evidence
        self.evidence_mat = []

        # mean of stimulus 2 (mean of stimulus 1 is always 1)
        self.stim2_mat = []

        # position of stimulus 1
        self.stms_pos = []

        # whether the stimulus is repeating or not
        self.repeat_mat = []

        # reward
        self.reward_mat = []

        # performance
        self.perf_mat = []

        # duration of trial
        self.dur_tr = []

        # current repeating probability
        self.rep_prob = []

        # summed activity across the trial
        self.net_smmd_act = []

        # action that ended each trial
        self.action = []

        # folder
        self.folder = folder

        # point by point parameter mats saved for some trials
        self.all_pts_data = data(folder=folder)

        # save all points step. Here I call the class data that implements
        # all the necessary functions.
        # Per-step data are recorded while
        # floor(num_tr / num_tr_svd) % sv_pts_stp == 0.
        self.sv_pts_stp = 10
        self.num_tr_svd = 1000

    def get_state(self):
        """
        Advance one time step and return the next observation.

        While timestep < trial duration the observation is two unit-variance
        normal samples centred on int_st[0] and int_st[1] plus a -1 in the
        third position; afterwards it is [0, 0, 0]. The difference between
        the first two entries is accumulated in `self.evidence`.

        Requires `new_trial` to have been called first (it creates
        `timestep`, `evidence` and `int_st`).

        Returns the observation reshaped to [1, num_actions, 1].
        """
        self.timestep += 1  # this was previously in pullArm
        if self.timestep < self.td:
            self.state = [np.random.normal(self.int_st[0], scale=1),
                          np.random.normal(self.int_st[1],scale=1), -1]
        else:
            self.state = [0, 0, 0]

        self.evidence += self.state[0]-self.state[1]
        self.state = np.reshape(self.state, [1, self.num_actions, 1])

        return self.state

    def new_trial(self):
        """
        Start a new trial: draw the stimulus means, decide their positions
        (repeat or swap with respect to the previous trial), store the trial
        parameters and return the first observation.
        """
        self.num_tr += 1
        self.timestep = 0
        self.evidence = 0
        stim1 = 1.0
        stim2 = np.random.uniform(1-self.stim_evidence, 1)
        assert stim2 != 1.0
        a = [stim1, stim2]
        # right now self.true is always = 1,
        # so you could consider just removing this variable
        self.true = a[0]
        self.choices = a

        # decide the position of the stims
        # if the block is finished update the prob of repeating
        if self.num_tr % self.block_dur == 0:
            self.curr_rep_prob = int(not self.curr_rep_prob)

        # flip a coin
        repeat = np.random.uniform() < self.repeating_prob[self.curr_rep_prob]
        if not repeat:
            self.stms_pos_new_trial = not(self.stms_pos_new_trial)

        # order the two means according to the chosen position
        aux = [self.choices[x] for x in [int(self.stms_pos_new_trial),
               int(not self.stms_pos_new_trial)]]

        # third entry (-1) is the one associated with action index 2
        self.int_st = np.concatenate((aux, np.array([-1])))

        # store some data about trial
        self.stim2_mat.append(stim2)
        self.stms_pos.append(self.stms_pos_new_trial)
        self.repeat_mat.append(repeat)
        self.rep_prob.append(self.curr_rep_prob)

        # get state
        s = self.get_state()

        # during some episodes I save all data points
        if np.floor(self.num_tr/self.num_tr_svd) % self.sv_pts_stp == 0:
            self.all_pts_data.update(new_state=s,
                                     new_trial=1,
                                     num_trials=self.num_tr)

        return s

    def pullArm(self, action, net_st=[]):
        """
        receives an action and returns a reward, a state and flag variables
        that indicate whether to start a new trial and whether to update
        the network

        `action` is used as an index into `self.int_st`; `net_st` is only
        forwarded to the point-by-point records.

        Returns (new_st, reward, update_net, new_trial); new_st is None
        when the trial has ended.
        """
        trial_dur = 0
        new_trial = True
        correct = False
        update_net = False

        # decide which reward and state (new_trial, correct) we are in
        if self.timestep < self.td:
            # stimulus period: any action whose int_st entry is not -1
            # gets rewards[0]
            if (self.int_st[action] != -1).all():
                reward = self.rewards[0]
            else:
                # don't abort the trial even if the network stops fixating
                reward = self.rewards[1]

            new_trial = False

        else:
            # decision period: correct if the chosen entry equals self.true
            if (self.int_st[action] == self.true).all():
                reward = self.rewards[2]
                correct = True
            else:
                reward = self.rewards[3]
            trial_dur = self.timestep

        if new_trial:
            # current trial info
            self.dur_tr.append(trial_dur)
            self.perf_mat.append(correct)
            self.action.append(action)
            self.reward_mat.append(reward)
            self.evidence_mat.append(self.evidence)
            new_st = None

            # check if it is time to update the network
            update_net = ((self.num_tr-1) % self.update_net_step == 0) and\
                (self.num_tr != 1)

            # point by point parameter mats saved for some periods
            if np.floor(self.num_tr / self.num_tr_svd) % self.sv_pts_stp == 0:
                self.all_pts_data.update(net_state=net_st, reward=reward,
                                         update_net=update_net,
                                         action=action, correct=[correct])

            # during some episodes I save all data points
            # (write and clear the records on the trial at which the
            # recording period ends)
            aux = np.floor((self.num_tr-1) / self.num_tr_svd)
            aux2 = np.floor(self.num_tr / self.num_tr_svd)
            if aux % self.sv_pts_stp == 0 and\
               aux2 % self.sv_pts_stp == 1:
                self.all_pts_data.save(self.num_tr)
                self.all_pts_data.reset()

        else:
            new_st = self.get_state()
            # during some episodes I save all data points
            if np.floor(self.num_tr / self.num_tr_svd) % self.sv_pts_stp == 0:
                self.all_pts_data.update(new_state=new_st, net_state=net_st,
                                         reward=reward, update_net=update_net,
                                         action=action, correct=[correct],
                                         new_trial=new_trial,
                                         num_trials=self.num_tr)

        return new_st, reward, update_net, new_trial

    def save_trials_data(self):
        """
        Write the accumulated per-trial statistics to
        <folder>/trials_stats_<num_tr>.npz whenever the trial counter is a
        multiple of 10000. The lists are not cleared afterwards.
        """
        # Periodically save model trials statistics.
        if self.num_tr % 10000 == 0:
            data = {'stim2': self.stim2_mat, 'trial_duration': self.dur_tr,
                    'stims_position': self.stms_pos, 'repeat': self.repeat_mat,
                    'reward': self.reward_mat, 'performance': self.perf_mat,
                    'evidence': self.evidence_mat, 'rep_prob': self.rep_prob,
                    'net_smmd_act': self.net_smmd_act, 'action': self.action}
            np.savez(self.folder +
                     '/trials_stats_' + str(self.num_tr) + '.npz', **data)


# Agent

## Network class

In [5]:
class AC_Network():
    """
    Actor-critic network built inside the variable scope `scope`.

    The input to the recurrent core is the concatenation of the flattened
    state, the previous reward and the one-hot previous action. The core
    itself is chosen by `network` ('relu', 'lstm', 'gru' or 'ugru').

    For every scope other than 'global', loss and gradient ops are also
    created; the gradients are computed w.r.t. this scope's variables and
    applied by `trainer` to the variables of the 'global' scope.
    """
    def __init__(self, a_size, state_size, scope, trainer, num_units, network):
        with tf.variable_scope(scope):
            # Input and visual encoding layers
            self.st = tf.placeholder(shape=[None, 1, state_size, 1],
                                     dtype=tf.float32)
            self.prev_rewards = tf.placeholder(shape=[None, 1],
                                               dtype=tf.float32)
            self.prev_actions = tf.placeholder(shape=[None],
                                               dtype=tf.int32)

            self.prev_actions_onehot = tf.one_hot(self.prev_actions, a_size,
                                                  dtype=tf.float32)

            # one row per step: [flattened state, prev reward, prev action]
            hidden = tf.concat([slim.flatten(self.st), self.prev_rewards,
                                self.prev_actions_onehot], 1)

            # call RNN network
            # The builder name is only looked up in the branch that is
            # taken. NOTE(review): RNN_ReLU, RNN and RNN_GRU are not defined
            # in the visible part of this file — confirm they exist before
            # selecting 'relu', 'lstm' or 'gru'.
            if network == 'relu':
                net = RNN_ReLU
            elif network == 'lstm':
                net = RNN
            elif network == 'gru':
                net = RNN_GRU
            elif network == 'ugru':
                net = RNN_UGRU
            else:
                raise ValueError('Unknown network')

            self.st_init, self.st_in, self.st_out, self.actions,\
                self.actions_onehot, self.policy, self.value =\
                net(hidden, self.prev_rewards, a_size, num_units)

            # Only the worker network needs ops for loss functions
            # and gradient updating.
            if scope != 'global':
                self.target_v = tf.placeholder(shape=[None], dtype=tf.float32)
                self.advantages = tf.placeholder(shape=[None],
                                                 dtype=tf.float32)

                # probability the policy assigned to the action taken
                self.resp_outputs = \
                    tf.reduce_sum(self.policy * self.actions_onehot, [1])

                # Loss functions
                self.value_loss = 0.5 * tf.reduce_sum(
                        tf.square(self.target_v -
                                  tf.reshape(self.value, [-1])))
                # 1e-7 keeps the logarithms finite for zero probabilities
                self.entropy = - tf.reduce_sum(
                        self.policy * tf.log(self.policy + 1e-7))
                self.policy_loss = -tf.reduce_sum(
                        tf.log(self.resp_outputs + 1e-7)*self.advantages)
                # total loss: value term weighted by 0.5 (on top of the 0.5
                # already inside value_loss), entropy bonus weighted by 0.05
                self.loss = 0.5 * self.value_loss +\
                    self.policy_loss -\
                    self.entropy * 0.05

                # Get gradients from local network using local losses
                local_vars = tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES, scope)
                self.gradients = tf.gradients(self.loss, local_vars)
                self.var_norms = tf.global_norm(local_vars)
                grads, self.grad_norms =\
                    tf.clip_by_global_norm(self.gradients, 999.0)

                # Apply local gradients to global network
                # (gradients and global variables are paired positionally,
                # so both scopes must create their variables in the same
                # order)
                global_vars = tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES, 'global')
                self.apply_grads = trainer.apply_gradients(
                        zip(grads, global_vars))


## Worker class

In [6]:
class Worker():
    """
    A3C worker. It owns an environment and a local copy of the network,
    collects experience with the local network, and uses it to apply
    gradients to the 'global' network.
    """
    def __init__(self, game, name, a_size, state_size, trainer,
                 model_path, global_epss, data_path, num_units, network):
        """
        game        : environment instance (provides new_trial / pullArm).
        name        : worker index; the TF scope is "worker_<name>".
        a_size      : number of actions.
        state_size  : size of the observation.
        trainer     : optimizer used to apply gradients to the global net.
        model_path  : folder in which checkpoints are written.
        global_epss : TF variable counting episodes across workers.
        data_path   : experiment folder; summaries go to
                      ./<data_path>/trains/train_<name>.
        num_units   : number of recurrent units.
        network     : network type ('relu', 'lstm', 'gru' or 'ugru').
        """
        self.name = "worker_" + str(name)
        self.number = name
        self.folder = './' + data_path + '/trains/train_' + str(self.number)
        self.model_path = model_path
        self.trainer = trainer
        self.global_epss = global_epss
        self.increment = self.global_epss.assign_add(1)
        self.network = network
        self.eps_rewards = []
        self.eps_mean_values = []

        self.summary_writer = tf.summary.FileWriter(self.folder)

        # Create the local copy of the network and the tensorflow op
        # to copy global parameters to local network
        self.local_AC = AC_Network(a_size, state_size, self.name, trainer,
                                   num_units, network)
        self.update_local_ops = update_target_graph('global', self.name)
        self.env = game

    def _input_feed(self, states, prev_rewards, prev_actions, rnn_state):
        """
        Build the part of the feed dict shared by acting and training:
        observations, previous rewards/actions and the recurrent state.

        The recurrent state placeholder of the 'lstm' network is indexed
        as a pair; the other networks use a single placeholder.
        Raises ValueError for an unknown network type.
        """
        feed_dict = {self.local_AC.st: states,
                     self.local_AC.prev_rewards: prev_rewards,
                     self.local_AC.prev_actions: prev_actions}
        if self.network == 'lstm':
            # AC_Network exposes the placeholders as `st` / `st_in`
            feed_dict[self.local_AC.st_in[0]] = rnn_state[0]
            feed_dict[self.local_AC.st_in[1]] = rnn_state[1]
        elif self.network in ('relu', 'gru', 'ugru'):
            feed_dict[self.local_AC.st_in] = rnn_state
        else:
            raise ValueError('Unknown network')
        return feed_dict

    def train(self, rollout, sess, gamma, bootstrap_value):
        """
        Run one gradient update of the global network from a rollout.

        rollout         : list of [state, action, reward, value] entries.
        sess            : TF session.
        gamma           : discount factor.
        bootstrap_value : value estimate appended after the last step.

        Returns value loss, policy loss and entropy (each divided by the
        rollout length), followed by the gradient and variable norms.
        """
        rollout = np.array(rollout)
        states = rollout[:, 0]
        actions = rollout[:, 1]
        rewards = rollout[:, 2]

        # inputs at step t are the reward/action of step t-1 (0 at t=0)
        prev_rewards = [0] + rewards[:-1].tolist()
        prev_actions = [0] + actions[:-1].tolist()
        values = rollout[:, 3]

        self.pr = prev_rewards
        self.pa = prev_actions
        # Here we take the rewards and values from the rollout, and use them to
        # generate the advantage and discounted returns.
        # The advantage function uses "Generalized Advantage Estimation"
        self.rewards_plus = np.asarray(rewards.tolist() + [bootstrap_value])
        discounted_rewards = discount(self.rewards_plus, gamma)[:-1]
        self.value_plus = np.asarray(values.tolist() + [bootstrap_value])
        advantages = rewards +\
            gamma * self.value_plus[1:] -\
            self.value_plus[:-1]
        advantages = discount(advantages, gamma)

        # Update the global network using gradients from loss
        # Generate network statistics to periodically save
        rnn_state = self.local_AC.st_init
        feed_dict = self._input_feed(np.stack(states, axis=0),
                                     np.vstack(prev_rewards),
                                     prev_actions, rnn_state)
        feed_dict[self.local_AC.target_v] = discounted_rewards
        feed_dict[self.local_AC.actions] = actions
        feed_dict[self.local_AC.advantages] = advantages

        v_l, p_l, e_l, g_n, v_n, _ = sess.run([self.local_AC.value_loss,
                                               self.local_AC.policy_loss,
                                               self.local_AC.entropy,
                                               self.local_AC.grad_norms,
                                               self.local_AC.var_norms,
                                               self.local_AC.apply_grads],
                                              feed_dict=feed_dict)
        aux = len(rollout)
        return v_l / aux, p_l / aux, e_l / aux, g_n, v_n

    def work(self, gamma, sess, coord, saver, train, exp_dur):
        """
        Main loop of the worker: repeatedly sync with the global network,
        interact with the environment until it signals a network update,
        train (if `train`), and periodically write summaries/checkpoints.

        The loop ends when the coordinator asks to stop or when more than
        exp_dur / update_net_step episodes have been run.
        """
        eps_count = sess.run(self.global_epss)
        # max(1, ...) avoids a zero modulus when update_net_step is larger
        # than the reference number of trials
        num_eps_tr_stats = max(1, int(1000/self.env.update_net_step))
        num_epss_end = int(exp_dur/self.env.update_net_step)
        num_epss_save_model = max(1, int(5000/self.env.update_net_step))
        total_steps = 0
        print("Starting worker " + str(self.number))
        with sess.as_default(), sess.graph.as_default():
            while not coord.should_stop():
                # copy the global parameters into the local network
                sess.run(self.update_local_ops)
                eps_buffer = []
                eps_values = []
                eps_reward = 0
                eps_step_count = 0
                d = False
                r = 0
                a = 0

                # get first state
                s = self.env.new_trial()

                rnn_state = self.local_AC.st_init
                net_smmd_act = np.zeros_like(rnn_state)
                while not d:
                    feed_dict = self._input_feed([s], [[r]], [a], rnn_state)

                    # Take an action using probs from policy network output
                    a_dist, v, rnn_state_new = sess.run(
                                                        [self.local_AC.policy,
                                                         self.local_AC.value,
                                                         self.local_AC.st_out],
                                                        feed_dict=feed_dict)

                    # sample the action index directly; sampling a
                    # probability value and searching for it picks the
                    # first index whenever two probabilities are equal
                    a = np.random.choice(len(a_dist[0]), p=a_dist[0])
                    rnn_state = rnn_state_new
                    net_smmd_act += rnn_state_new
                    # only hand the network state to the environment while
                    # it is recording point-by-point data
                    aux = np.floor(self.env.num_tr/self.env.num_tr_svd)
                    if aux % self.env.sv_pts_stp == 0:
                        network_activity = rnn_state_new
                    else:
                        network_activity = []
                    # new_state, reward, update_net, new_trial
                    s1, r, d, nt = self.env.pullArm(a, network_activity)
                    # save samples for training the network later
                    eps_buffer.append([s, a, r, v[0, 0]])
                    eps_values.append(v[0, 0])
                    eps_reward += r
                    total_steps += 1
                    eps_step_count += 1
                    # store the summed activity at the end of the trial
                    if nt:
                        self.env.net_smmd_act.append(net_smmd_act)
                        net_smmd_act = np.zeros_like(rnn_state)
                        self.env.save_trials_data()
                    if not d:
                        if nt:
                            s = self.env.new_trial()
                        else:
                            s = s1

                self.eps_rewards.append(eps_reward)
                self.eps_mean_values.append(np.mean(eps_values))

                # Update the network using the experience buffer
                # at the end of the episode
                if len(eps_buffer) != 0 and train:
                    v_l, p_l, e_l, g_n, v_n = \
                        self.train(eps_buffer, sess, gamma, 0.0)

                # Periodically save model parameters and summary statistics.
                if eps_count % num_eps_tr_stats == 0 and eps_count != 0:
                    if eps_count % num_epss_save_model == 0 and\
                       self.name == 'worker_0' and\
                       train and\
                       len(self.eps_rewards) != 0:
                        saver.save(sess, self.model_path +
                                   '/model-' + str(eps_count) + '.cptk')

                    mean_tr_dur = np.mean(self.env.dur_tr[-10:])
                    mean_reward = np.mean(self.eps_rewards[-10:])
                    mean_value = np.mean(self.eps_mean_values[-10:])
                    summary = tf.Summary()
                    summary.value.add(tag='Perf/trial_duration',
                                      simple_value=float(mean_tr_dur))
                    summary.value.add(tag='Perf/Reward',
                                      simple_value=float(mean_reward))
                    summary.value.add(tag='Perf/Value',
                                      simple_value=float(mean_value))

                    performance_aux = np.vstack(np.array(self.env.perf_mat))

                    for ind_crr in range(performance_aux.shape[1]):
                        mean_performance = np.mean(performance_aux[:, ind_crr])
                        summary.value.add(tag='Perf/Perf_' + str(ind_crr),
                                          simple_value=float(mean_performance))

                    if train:
                        summary.value.add(tag='Losses/Value Loss',
                                          simple_value=float(v_l))
                        summary.value.add(tag='Losses/Policy Loss',
                                          simple_value=float(p_l))
                        summary.value.add(tag='Losses/Entropy',
                                          simple_value=float(e_l))
                        summary.value.add(tag='Losses/Grad Norm',
                                          simple_value=float(g_n))
                        summary.value.add(tag='Losses/Var Norm',
                                          simple_value=float(v_n))
                    self.summary_writer.add_summary(summary, eps_count)

                    self.summary_writer.flush()

                # only worker_0 advances the shared episode counter
                if self.name == 'worker_0':
                    sess.run(self.increment)

                eps_count += 1
                if eps_count > num_epss_end:
                    break


# Call the function.

In [7]:
import threading
import multiprocessing
import os


def main_priors(load_model=False, train=True, gamma=.8, update_net_step=5,
                trial_duration=10, repeating_prob=(0.2, 0.8), exp_dur=10**6,
                rewards=(-0.1, 0.0, 1.0, -1.0), block_dur=200,
                num_units=32, stim_evidence=.3, network='ugru',
                learning_rate=1e-3, instance=0):
    """
    Set up and run an A3C experiment on the `expectations` task.

    Creates the experiment folder (its name encodes the parameters),
    saves the experiment setup, builds the global network plus one worker
    per CPU thread, and runs every worker in its own Python thread until
    all of them finish.

    load_model    : restore the latest checkpoint from the model folder
                    instead of initializing the variables.
    train         : if False, workers only run the network and '_test' is
                    appended to the folder name.
    gamma         : discount factor.
    exp_dur       : used by the workers to decide when to stop.
    learning_rate : learning rate of the Adam optimizer.
    instance      : identifier appended to the folder name.
    The remaining parameters are forwarded to `expectations`,
    `AC_Network` and `Worker`.
    """
    a_size = 3  # number of actions
    state_size = a_size  # number of inputs
    if train:
        test_flag = ''
    else:
        test_flag = '_test'
    data_path = 'priors/' + 'trial_duration_' + str(trial_duration) +\
        '_repeating_prob_' + str(list_str(repeating_prob)) +\
        '_rewards_' + str(list_str(rewards)) +\
        '_block_dur_' + str(block_dur) + '_stimEv_' + str(stim_evidence) +\
        '_gamma_' + str(gamma) + '_num_units_' + str(num_units) +\
        '_update_net_step_' + str(update_net_step) + '_network_' \
        + str(network) + '_' + str(instance) + test_flag + '/'

    setup = {'trial_duration': trial_duration,
             'repeating_prob': repeating_prob,
             'rewards': rewards, 'stim_evidence': stim_evidence,
             'block_dur': block_dur, 'gamma': gamma, 'num_units': num_units,
             'update_net_step': update_net_step, 'network': network}

    model_path = './' + data_path + '/model_meta_context'

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    np.savez(data_path + '/experiment_setup.npz', **setup)

    tf.reset_default_graph()
    with tf.device("/cpu:0"):
        global_episodes = tf.Variable(0, dtype=tf.int32,
                                      name='global_episodes',
                                      trainable=False)
        trainer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        AC_Network(a_size, state_size, 'global',
                   None, num_units, network)  # Generate global net
        # Set workers to number of available CPU threads
        num_workers = multiprocessing.cpu_count()
        workers = []
        # Create worker classes
        for i in range(num_workers):
            saving_path = './' + data_path + '/trains/train_' + str(i)
            env = expectations(update_net_step=update_net_step,
                               trial_duration=trial_duration,
                               repeating_prob=repeating_prob,
                               rewards=rewards, block_dur=block_dur,
                               stim_evidence=stim_evidence,
                               folder=saving_path)
            workers.append(Worker(env, i, a_size, state_size,
                                  trainer, model_path, global_episodes,
                                  data_path, num_units, network))
        saver = tf.train.Saver(max_to_keep=5)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        if load_model:
            print('Loading Model...')
            print(model_path)
            ckpt = tf.train.get_checkpoint_state(model_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())

        worker_threads = []
        for worker in workers:
            # Pass the bound method and its arguments explicitly. A
            # `lambda: worker.work(...)` looks `worker` up only when the
            # thread runs, by which time the loop may already have moved
            # on to another worker.
            thread = threading.Thread(
                target=worker.work,
                args=(gamma, sess, coord, saver, train, exp_dur))
            thread.start()
            worker_threads.append(thread)
        coord.join(worker_threads)


In [8]:
ls

Cosyne_priors.ipynb  gym_priors.ipynb  priors/           Untitled.ipynb
gym-priors/          home/             priors_2.0.ipynb


In [9]:
# Launch a training run from scratch with the 'ugru' network (32 units);
# exp_dur is not passed, so main_priors uses its default value.
main_priors(load_model=False, train=True, gamma=.8, update_net_step=5,
                trial_duration=10, repeating_prob=(0.2, 0.8),
                rewards=(-0.1, 0.0, 1.0, -1.0), block_dur=200,
                num_units=32, stim_evidence=.4, network='ugru',
                learning_rate=1e-3, instance=123)

Starting worker 0
Starting worker 3
Starting worker 4
Starting worker 5
Starting worker 6
Starting worker 11
Starting worker 1
Starting worker 10
Starting worker 9
Starting worker 2
Starting worker 8Starting worker 7



Exception in thread Thread-26:
Traceback (most recent call last):
  File "/home/molano/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1334, in _do_call
    return fn(*args)
  File "/home/molano/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1319, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/molano/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1407, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.CancelledError: Session has been closed.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/molano/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/molano/anaconda3/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-7-f94e668ab927>", line 69

KeyboardInterrupt: 