In [470]:
import numpy as np
import tinygrad
from tinygrad import nn, Tensor

In [471]:
def fanin_init(size, fanin=None) -> Tensor:
    """Create a tensor drawn uniformly from ``[-1/sqrt(fanin), 1/sqrt(fanin)]``.

    Args:
        size: Shape of the tensor to create; unpacked into ``Tensor.uniform``.
        fanin: Value that scales the sampling bound. When omitted (or falsy)
            the first entry of ``size`` is used instead.

    Returns:
        A ``Tensor`` of shape ``size`` with uniformly distributed entries.
    """
    if not fanin:
        fanin = size[0]
    bound = float(1. / np.sqrt(fanin))
    return Tensor.uniform(*size, low=-bound, high=bound)

In [472]:
# Smoke test for fanin_init. With fanin=0.003 the bound is 1/sqrt(0.003) (about 18.26),
# so the (256, 3) tensor is drawn from roughly [-18.26, 18.26].
fanin_init((256,3), 0.003)

<Tensor <LB METAL (256, 3) float (<BinaryOps.ADD: 1>, None)> on METAL with grad None>

In [473]:
from tinygrad import nn, Tensor, dtypes
class Actor():
    """Policy network: three linear layers mapping a state to an action vector.

    The two hidden layers use ReLU; the output layer is passed through
    ``softsign``. Hidden-layer weights are re-drawn with ``fanin_init`` and the
    output-layer weights are re-drawn uniformly from ``[-init_w, init_w]``.
    Biases keep whatever ``nn.Linear`` assigned.

    Args:
        nb_states: Input feature count of the first layer.
        nb_actions: Output feature count of the last layer.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        init_w: Half-width of the uniform range used for the output layer.
    """

    def __init__(self, nb_states, nb_actions, hidden1=256, hidden2=128, init_w=3e-3):
        super().__init__()
        self.fc1 = nn.Linear(nb_states, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, nb_actions)
        self.init_weights(init_w)

    def init_weights(self, init_w):
        """Replace the weight tensors of all three layers with fresh draws."""
        # Hidden layers: bound derived from the weight shape via fanin_init.
        for hidden_layer in (self.fc1, self.fc2):
            hidden_layer.weight = fanin_init(hidden_layer.weight.size())
        # Output layer: small symmetric range controlled by init_w.
        out_shape = self.fc3.weight.size()
        self.fc3.weight = Tensor.uniform(*out_shape, low=float(-init_w), high=float(init_w))

    def __call__(self, x: Tensor) -> Tensor:
        """Run the forward pass; the input is cast to float first."""
        hidden = self.fc1(x.cast(dtypes.float)).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden).softsign()


In [474]:
# Smoke test: build an Actor with 3 state features and 3 actions, then run one forward pass.
actor = Actor(3, 3)
actor(Tensor([1,2,3]))

<Tensor <LB METAL (3,) float (<BinaryOps.MUL: 2>, None)> on METAL with grad None>

In [475]:
class Critic():
    """State-action value network.

    ``fc1`` processes the state alone; the action is concatenated to its
    activations along the last axis before ``fc2``; ``fc3`` produces a single
    value per input row.

    Weight initialisation mirrors ``Actor``: hidden-layer weights come from
    ``fanin_init`` and the output-layer weights are uniform in
    ``[-init_w, init_w]``. Biases keep whatever ``nn.Linear`` assigned.

    Args:
        nb_states: Input feature count of the first layer.
        nb_actions: Number of action features appended before ``fc2``.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        init_w: Half-width of the uniform range used for the output layer.
    """

    def __init__(self, nb_states, nb_actions, hidden1=319, hidden2=128, init_w=3e-3):
        super().__init__()
        self.fc1 = nn.Linear(nb_states, hidden1)
        self.fc2 = nn.Linear(hidden1+nb_actions, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)
        # Previously disabled because init_weights went through `.weight.data`;
        # it now replaces `.weight` directly, exactly like Actor.init_weights.
        self.init_weights(init_w)

    def init_weights(self, init_w):
        """Replace the weight tensors of all three layers with fresh draws."""
        # Assign to `.weight` itself (not `.weight.data`) so this matches the
        # way Actor swaps in its freshly drawn tensors.
        self.fc1.weight = fanin_init(self.fc1.weight.shape)
        self.fc2.weight = fanin_init(self.fc2.weight.shape)
        self.fc3.weight = Tensor.uniform(
            *self.fc3.weight.shape,
            low=float(-init_w), high=float(init_w)
        )

    def __call__(self, xs) -> Tensor:
        """Evaluate the critic.

        Args:
            xs: A pair ``(state, action)``; it is unpacked with ``x, a = xs``,
                so any two-element iterable of tensors works.

        Returns:
            The output of ``fc3`` for the given state/action pair(s).
        """
        x, a = xs
        x = x.cast(dtypes.float)
        a = a.cast(dtypes.float)
        out = self.fc1(x).relu()
        # Concatenate the action onto the state features along the last axis.
        c_in = out.cat(a, dim=len(a.shape)-1)
        out = self.fc2(c_in).relu()
        out = self.fc3(out)
        return out

In [476]:
# Smoke test: Critic.__call__ unpacks its argument with `x, a = xs`, so the two rows of
# this tensor are used as the state and the action respectively.
critic = Critic(2,2)
critic(Tensor([[1,2], [2,1]]))

<Tensor <LB METAL (1,) float (<BinaryOps.ADD: 1>, None)> on METAL with grad None>

In [477]:
## MARK - Memory. Should be in a "memory.py" file

from __future__ import absolute_import
from collections import deque, namedtuple
import warnings
import random
import numpy as np

# [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/memory.py

# One replay-buffer transition: given `state0`, performing `action` yields `reward`
# and results in `state1`; `terminal1` records whether that step was terminal.
Experience = namedtuple('Experience', 'state0, action, reward, state1, terminal1')


def sample_batch_indexes(low, high, size):
    """Draw ``size`` integer indexes from the half-open range ``[low, high)``.

    When the range holds at least ``size`` values the indexes are drawn without
    replacement (all unique) and returned as a list. Otherwise a warning is
    emitted and the indexes are drawn with replacement and returned as a NumPy
    array, so duplicates may occur.

    Args:
        low: Inclusive lower bound.
        high: Exclusive upper bound.
        size: Number of indexes to draw.

    Returns:
        A sequence of ``size`` integers in ``[low, high)``.
    """
    if high - low >= size:
        # Enough data: sample without replacement. `random.sample` over a lazy
        # `range` is used instead of `np.random.choice`, which is very slow for
        # large populations (https://github.com/numpy/numpy/issues/2764).
        batch_idxs = random.sample(range(low, high), size)
    else:
        # Not enough data: fall back to sampling with replacement. This should be
        # avoided in practice by choosing a large enough warm-up phase.
        warnings.warn('Not enough entries to sample without replacement. Consider increasing your warm-up phase to avoid oversampling!')
        batch_idxs = np.random.randint(low, high, size=size)
    assert len(batch_idxs) == size
    return batch_idxs

class RingBuffer(object):
    """Fixed-capacity circular buffer with O(1) append and random access.

    Once ``maxlen`` items are stored, each new item overwrites the oldest one.
    Index 0 always refers to the oldest item still held.

    Args:
        maxlen: Maximum number of items retained.
    """

    def __init__(self, maxlen):
        self.maxlen = maxlen
        self.start = 0   # physical position of the oldest item
        self.length = 0  # number of items currently stored
        self.data = [None for _ in range(maxlen)]

    def __len__(self):
        """Return the number of items currently stored."""
        return self.length

    def __getitem__(self, idx):
        """Return the item at logical position ``idx`` (0 is the oldest).

        Raises:
            KeyError: If ``idx`` is negative or not smaller than ``len(self)``.
        """
        if idx < 0 or idx >= self.length:
            raise KeyError()
        return self.data[(self.start + idx) % self.maxlen]

    def append(self, v):
        """Store ``v`` as the newest item, evicting the oldest when full.

        ``v`` must be a NumPy array, a NumPy scalar, or a Python ``float``,
        ``int`` or ``bool``. NumPy scalars and ints are accepted because
        rewards and terminal flags frequently arrive as ``np.float32``,
        ``np.bool_`` or plain ints; ``None`` and other objects are rejected.
        """
        assert isinstance(v, (np.ndarray, np.generic, float, int, bool)), "v_type:{}".format(type(v))
        if self.length < self.maxlen:
            # We have space, simply increase the length.
            self.length += 1
        elif self.length == self.maxlen:
            # No space, "remove" the first item.
            self.start = (self.start + 1) % self.maxlen
        else:
            # This should never happen.
            raise RuntimeError()
        self.data[(self.start + self.length - 1) % self.maxlen] = v


def zeroed_observation(observation):
    """Return an all-zero stand-in with the same structure as ``observation``.

    * Objects exposing ``shape`` yield ``np.zeros(observation.shape)``.
    * Other iterables yield a list whose elements are zeroed recursively.
    * Anything else yields the float ``0.``.
    """
    if hasattr(observation, 'shape'):
        return np.zeros(observation.shape)
    if not hasattr(observation, '__iter__'):
        return 0.
    return [zeroed_observation(item) for item in observation]


class Memory(object):
    """Base class for replay memories.

    Keeps the last ``window_length`` observations and terminal flags so that a
    stacked "recent state" can be assembled for the current observation.
    Subclasses implement ``sample``.

    Args:
        window_length: Number of observations stacked into one state.
        ignore_episode_boundaries: When true, stacked states may span episodes.
    """

    def __init__(self, window_length, ignore_episode_boundaries=False):
        self.window_length = window_length
        self.ignore_episode_boundaries = ignore_episode_boundaries

        self.recent_observations = deque(maxlen=window_length)
        self.recent_terminals = deque(maxlen=window_length)

    def sample(self, batch_size, batch_idxs=None):
        """Return a batch of stored experiences; must be provided by subclasses."""
        raise NotImplementedError()

    def append(self, observation, action, reward, terminal, training=True):
        """Record the newest observation and terminal flag in the recent window."""
        self.recent_observations.append(observation)
        self.recent_terminals.append(terminal)

    def get_recent_state(self, current_observation):
        """Build the stacked state that ends with ``current_observation``.

        Earlier observations are prepended newest-first until the window is
        full, the history runs out, or (unless episode boundaries are ignored)
        an episode boundary is hit. Missing slots are padded at the front with
        zeroed observations.

        Returns:
            A list of ``window_length`` observations, oldest first.
        """
        window = [current_observation]
        newest = len(self.recent_observations) - 1
        for offset in range(self.window_length - 1):
            pos = newest - offset
            if pos < 0:
                # Ran out of recorded history.
                break
            preceding_terminal = self.recent_terminals[pos - 1] if pos >= 1 else False
            if preceding_terminal and not self.ignore_episode_boundaries:
                # The preceding entry was terminal; adding this observation
                # would leak into a different episode.
                break
            window.insert(0, self.recent_observations[pos])
        while len(window) < self.window_length:
            window.insert(0, zeroed_observation(window[0]))
        return window

    def get_config(self):
        """Return the constructor settings as a plain dict."""
        return {
            'window_length': self.window_length,
            'ignore_episode_boundaries': self.ignore_episode_boundaries,
        }

class SequentialMemory(Memory):
    """Replay memory that stores transitions in insertion order.

    Observations, actions, rewards and terminal flags are kept in four parallel
    ``RingBuffer`` instances of capacity ``limit``. ``sample`` rebuilds
    ``Experience`` tuples (with stacked states of ``window_length``
    observations) from those buffers.
    """

    def __init__(self, limit, **kwargs):
        """Create the buffers; ``kwargs`` are forwarded to ``Memory.__init__``."""
        super(SequentialMemory, self).__init__(**kwargs)
        
        self.limit = limit

        # Do not use deque to implement the memory. This data structure may seem convenient but
        # it is way too slow on random access. Instead, we use our own ring buffer implementation.
        self.actions = RingBuffer(limit)
        self.rewards = RingBuffer(limit)
        self.terminals = RingBuffer(limit)
        self.observations = RingBuffer(limit)

    def sample(self, batch_size, batch_idxs=None):
        """Return ``batch_size`` ``Experience`` tuples.

        Args:
            batch_size: Number of experiences to return.
            batch_idxs: Optional explicit indexes. Each one is shifted by +1
                before use, so that every index has at least one entry before it.
                When ``None``, indexes are drawn with ``sample_batch_indexes``.

        Returns:
            A list of ``Experience`` tuples whose ``state0``/``state1`` are lists
            of ``window_length`` observations.
        """
        if batch_idxs is None:
            # Draw random indexes such that we have at least a single entry before each
            # index.
            assert self.nb_entries >= 2
            batch_idxs = sample_batch_indexes(0, self.nb_entries - 1, size=batch_size)
        batch_idxs = np.array(batch_idxs) + 1
        assert np.min(batch_idxs) >= 1
        assert np.max(batch_idxs) < self.nb_entries
        assert len(batch_idxs) == batch_size

        # Create experiences
        experiences = []
        for idx in batch_idxs:
            # terminal0 is the terminal flag stored two entries before idx.
            terminal0 = self.terminals[idx - 2] if idx >= 2 else False
            while terminal0:
                # Skip this transition because the environment was reset here. Select a new, random
                # transition and use this instead. This may cause the batch to contain the same
                # transition twice.
                idx = sample_batch_indexes(1, self.nb_entries, size=1)[0]
                terminal0 = self.terminals[idx - 2] if idx >= 2 else False
            assert 1 <= idx < self.nb_entries

            # This code is slightly complicated by the fact that subsequent observations might be
            # from different episodes. We ensure that an experience never spans multiple episodes.
            # This is probably not that important in practice but it seems cleaner.
            state0 = [self.observations[idx - 1]]
            for offset in range(0, self.window_length - 1):
                current_idx = idx - 2 - offset
                # NOTE(review): this guard uses `> 0` whereas Memory.get_recent_state uses `>= 0`,
                # so terminals[0] is never consulted here -- confirm whether that is intended.
                current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False
                if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal):
                    # The previously handled observation was terminal, don't add the current one.
                    # Otherwise we would leak into a different episode.
                    break
                state0.insert(0, self.observations[current_idx])
            # Pad at the front with zeroed observations until the window is full.
            while len(state0) < self.window_length:
                state0.insert(0, zeroed_observation(state0[0]))
            action = self.actions[idx - 1]
            reward = self.rewards[idx - 1]
            terminal1 = self.terminals[idx - 1]

            # Okay, now we need to create the follow-up state. This is state0 shifted one timestep
            # to the right. Again, we need to be careful to not include an observation from the next
            # episode if the last state is terminal.
            state1 = [np.copy(x) for x in state0[1:]]
            state1.append(self.observations[idx])

            assert len(state0) == self.window_length
            assert len(state1) == len(state0)
            experiences.append(Experience(state0=state0, action=action, reward=reward,
                                          state1=state1, terminal1=terminal1))
        assert len(experiences) == batch_size
        return experiences

    def sample_and_split(self, batch_size, batch_idxs=None):
        """Sample experiences and return them as five float32 arrays.

        Returns:
            ``(state0_batch, action_batch, reward_batch, state1_batch,
            terminal1_batch)``, each reshaped to ``(batch_size, -1)``. Note that
            ``terminal1_batch`` holds 0. for terminal transitions and 1. otherwise,
            i.e. it is a "not terminal" mask.
        """
        experiences = self.sample(batch_size, batch_idxs)

        state0_batch = []
        reward_batch = []
        action_batch = []
        terminal1_batch = []
        state1_batch = []
        for e in experiences:
            state0_batch.append(e.state0)
            state1_batch.append(e.state1)
            reward_batch.append(e.reward)
            action_batch.append(e.action)
            # Inverted flag: 0. when the transition was terminal, 1. otherwise.
            terminal1_batch.append(0. if e.terminal1 else 1.)

        # Prepare and validate parameters.
        state0_batch = np.array(state0_batch).reshape(batch_size,-1).astype(np.float32)
        state1_batch = np.array(state1_batch).reshape(batch_size,-1).astype(np.float32)
        terminal1_batch = np.array(terminal1_batch).reshape(batch_size,-1).astype(np.float32)
        reward_batch = np.array(reward_batch).reshape(batch_size,-1).astype(np.float32)
        # Actions are first collected as an object array, then flattened per row and cast.
        action_batch = np.array(action_batch, dtype="object").reshape(batch_size,-1).astype(np.float32)

        return state0_batch, action_batch, reward_batch, state1_batch, terminal1_batch


    def append(self, observation, action, reward, terminal, training=True):
        """Record one step; the ring buffers are only written when ``training`` is true."""
        super(SequentialMemory, self).append(observation, action, reward, terminal, training=training)
        
        # This needs to be understood as follows: in `observation`, take `action`, obtain `reward`
        # and whether the next state is `terminal` or not.
        if training:
            self.observations.append(observation)
            self.actions.append(action)
            self.rewards.append(reward)
            self.terminals.append(terminal)

    @property
    def nb_entries(self):
        """Number of stored transitions (length of the observation buffer)."""
        return len(self.observations)

    def get_config(self):
        """Return the base-class config extended with ``limit``."""
        config = super(SequentialMemory, self).get_config()
        config['limit'] = self.limit
        return config


class EpisodeParameterMemory(Memory):
    """Memory that stores one (parameters, total reward) pair per episode.

    Rewards are accumulated step by step via ``append``; ``finalize_episode``
    sums them and stores the total together with the supplied parameters.

    Args:
        limit: Capacity of the parameter and total-reward ring buffers.
        **kwargs: Forwarded to ``Memory.__init__``.
    """

    def __init__(self, limit, **kwargs):
        super(EpisodeParameterMemory, self).__init__(**kwargs)
        self.limit = limit

        self.params = RingBuffer(limit)
        self.intermediate_rewards = []
        self.total_rewards = RingBuffer(limit)

    def sample(self, batch_size, batch_idxs=None):
        """Return ``(batch_params, batch_total_rewards)`` for ``batch_size`` episodes.

        When ``batch_idxs`` is ``None`` the indexes are drawn with
        ``sample_batch_indexes`` over all stored episodes.
        """
        if batch_idxs is None:
            batch_idxs = sample_batch_indexes(0, self.nb_entries, size=batch_size)
        assert len(batch_idxs) == batch_size

        batch_params = []
        batch_total_rewards = []
        for idx in batch_idxs:
            batch_params.append(self.params[idx])
            batch_total_rewards.append(self.total_rewards[idx])
        return batch_params, batch_total_rewards

    def append(self, observation, action, reward, terminal, training=True):
        """Record a step; the reward is accumulated only when ``training`` is true."""
        super(EpisodeParameterMemory, self).append(observation, action, reward, terminal, training=training)
        if training:
            self.intermediate_rewards.append(reward)

    def finalize_episode(self, params):
        """Store the episode's summed reward with ``params`` and reset the accumulator."""
        total_reward = sum(self.intermediate_rewards)
        self.total_rewards.append(total_reward)
        self.params.append(params)
        self.intermediate_rewards = []

    @property
    def nb_entries(self):
        """Number of finalized episodes stored."""
        return len(self.total_rewards)

    def get_config(self):
        """Return the base-class config extended with ``limit``."""
        # Must name this class in super(): the previous `super(SequentialMemory, self)`
        # raised TypeError because `self` is not a SequentialMemory instance.
        config = super(EpisodeParameterMemory, self).get_config()
        config['limit'] = self.limit
        return config

In [478]:
import numpy as np 

# [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/random.py

class RandomProcess(object):
    """Base class for noise processes; subclasses may override ``reset_states``."""

    def reset_states(self):
        """Reset internal state. The base implementation does nothing."""
        return None

class AnnealedGaussianProcess(RandomProcess):
    """Noise process whose standard deviation decays linearly with step count.

    Args:
        mu: Mean of the process.
        sigma: Initial standard deviation.
        sigma_min: Floor for the standard deviation, or ``None`` to disable
            annealing (sigma then stays constant).
        n_steps_annealing: Number of steps over which sigma moves from
            ``sigma`` to ``sigma_min``.
    """

    def __init__(self, mu, sigma, sigma_min, n_steps_annealing):
        self.mu = mu
        self.sigma = sigma
        self.n_steps = 0

        # current_sigma is the line m * n_steps + c, clipped below at sigma_min.
        self.c = sigma
        if sigma_min is None:
            self.m = 0.
            self.sigma_min = sigma
        else:
            self.m = -float(sigma - sigma_min) / float(n_steps_annealing)
            self.sigma_min = sigma_min

    @property
    def current_sigma(self):
        """Standard deviation for the current step, never below ``sigma_min``."""
        annealed = self.m * float(self.n_steps) + self.c
        return max(self.sigma_min, annealed)


# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
    """Temporally correlated noise generated by a discretised Ornstein-Uhlenbeck process.

    Args:
        theta: Strength of the pull toward ``mu``.
        mu: Mean the process reverts to.
        sigma: Initial noise scale.
        dt: Time-step size.
        x0: Starting value; zeros of shape ``size`` are used when ``None``.
        size: Shape passed to ``np.random.normal`` and ``np.zeros``.
        sigma_min: Optional floor for the annealed noise scale.
        n_steps_annealing: Annealing horizon forwarded to the base class.
    """

    def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000):
        super().__init__(mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
        self.theta = theta
        self.mu = mu
        self.dt = dt
        self.x0 = x0
        self.size = size
        self.reset_states()

    def sample(self):
        """Advance the process by one step and return the new value."""
        drift = self.theta * (self.mu - self.x_prev) * self.dt
        diffusion = self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
        x = self.x_prev + drift + diffusion
        self.x_prev = x
        self.n_steps += 1
        return x

    def reset_states(self):
        """Restart from ``x0``, or from zeros when no ``x0`` was given."""
        if self.x0 is None:
            self.x_prev = np.zeros(self.size)
        else:
            self.x_prev = self.x0

In [479]:
from tinygrad import nn, Tensor, dtypes, TinyJit
from tinygrad.nn.optim import LAMB

# might have to change this to "load_state_dict" and manually updating state_dict
def hard_update(target, source):
    """Make every parameter of ``target`` take over the matching parameter of ``source``.

    Parameters are paired positionally in the order returned by
    ``nn.state.get_parameters``. ``requires_grad`` on each source tensor is
    switched off for the copy and then set to ``True`` afterwards.
    """
    # NOTE(review): `replace` may leave target and source sharing storage rather
    # than making an independent copy -- confirm against the tinygrad version in use.
    target_params = nn.state.get_parameters(target)
    source_params = nn.state.get_parameters(source)
    for dst, src in zip(target_params, source_params):
        src.requires_grad = False
        dst.replace(src)
        src.requires_grad = True

class DDPG(object):
    """Deep Deterministic Policy Gradient agent scaffold.

    Owns the actor/critic networks and their target copies, the two LAMB
    optimizers, a ``SequentialMemory`` replay buffer and an Ornstein-Uhlenbeck
    exploration-noise process. ``update_policy`` is a no-op here and is meant
    to be overridden.

    Args:
        args: Namespace providing ``seed``, ``hidden1``, ``hidden2``, ``init_w``,
            ``p_lr``, ``c_lr``, ``weight_decay``, ``rmsize``, ``window_length``,
            ``ou_theta``, ``ou_mu``, ``ou_sigma``, ``bsize``, ``tau_update``,
            ``gamma`` and ``epsilon``.
        nb_states: Number of state features.
        nb_actions: Number of action dimensions.
    """

    def __init__(self, args, nb_states, nb_actions):
        if args.seed > 0:
            self.seed(args.seed)

        self.nb_states =  nb_states
        self.nb_actions= nb_actions

        net_cfg = {
            'hidden1': args.hidden1,
            'hidden2': args.hidden2,
            'init_w': args.init_w
        }
        self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg)
        self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg)

        self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg)
        self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg)

        hard_update(self.actor_target, self.actor) # Make sure target is with the same weight
        hard_update(self.critic_target, self.critic)

        print(f'Initialized DDPG with actor parameters: {len(nn.state.get_parameters(self.actor))}, lr={args.p_lr}')
        print(f'Initialized DDPG with critic parameters: {len(nn.state.get_parameters(self.critic))}, lr={args.c_lr}')

        self.actor_optim = LAMB(params=nn.state.get_parameters(self.actor), lr=args.p_lr, weight_decay=args.weight_decay, adam=True)
        self.critic_optim = LAMB(params=nn.state.get_parameters(self.critic), lr=args.c_lr, weight_decay=args.weight_decay, adam=True)

        # Create replay buffer and exploration noise
        self.memory = SequentialMemory(limit=args.rmsize, window_length=args.window_length)
        self.random_process = OrnsteinUhlenbeckProcess(size=self.nb_actions,
                                                       theta=args.ou_theta, mu=args.ou_mu, sigma=args.ou_sigma)

        # Hyper-parameters
        self.batch_size = args.bsize
        self.tau_update = args.tau_update
        self.gamma = args.gamma

        # Linear decay rate of exploration policy
        self.depsilon = 1.0 / args.epsilon
        # initial exploration rate
        self.epsilon = 1.0
        self.s_t = None # Most recent state
        self.a_t = None # Most recent action
        self.is_training = True

        self.continious_action_space = False

    def update_policy(self):
        """Perform one learning step. No-op in this base class."""
        pass

    def observe(self, r_t, s_t1, done):
        """Store the transition ``(s_t, a_t, r_t, done)`` and advance to ``s_t1``.

        Does nothing unless ``is_training`` is true.
        """
        if self.is_training:
            self.memory.append(self.s_t, self.a_t, r_t, done)
            self.s_t = s_t1

    def random_action(self):
        """Return an action drawn uniformly from ``[-1, 1)`` per dimension."""
        action = np.random.uniform(-1., 1., self.nb_actions)
        return action

    def select_action(self, s_t, decay_epsilon=True):
        """Return the actor's action for ``s_t`` plus exploration noise, clipped to [-1, 1].

        Args:
            s_t: The state, or an ``(observation, {})`` tuple from which the
                observation is extracted.
            decay_epsilon: When true, ``epsilon`` is reduced by ``depsilon``.
        """
        # Accept an (observation, info) tuple whose info dict is empty.
        if type(s_t) is tuple and s_t[1] == {}:
            s_t = s_t[0]
        orig_tensor = Tensor([np.array(list(s_t), dtype=np.float32)], dtype=dtypes.float, requires_grad=False)
        action = self.actor(orig_tensor).numpy().squeeze(0)
        # Noise is scaled by epsilon (floored at 0) and disabled when not training.
        action += self.is_training * max(self.epsilon, 0) * self.random_process.sample()
        action = np.clip(action, -1., 1.)

        if decay_epsilon:
            self.epsilon -= self.depsilon

        return action

    def reset(self, s_t):
        """Start a new episode at ``s_t`` and reset the noise process."""
        self.s_t = s_t
        self.random_process.reset_states()

    def load_weights(self, dir):
        """Load actor and critic weights from ``output/<dir>/*.safetensors``.

        Does nothing when ``dir`` is ``None``.
        """
        if dir is None: return

        state_dict_actor = nn.state.safe_load(
            'output/{}/actor.safetensors'.format(dir)
        )
        nn.state.load_state_dict(
            self.actor, state_dict_actor
        )

        state_dict_critic = nn.state.safe_load(
            'output/{}/critic.safetensors'.format(dir)
        )
        nn.state.load_state_dict(
            self.critic, state_dict_critic
        )
        print('model weights loaded')


    def save_model(self,output):
        """Save actor and critic weights to ``<output>/*.safetensors``."""
        state_dict_actor = nn.state.get_state_dict(self.actor)
        nn.state.safe_save(
            state_dict_actor,
            '{}/actor.safetensors'.format(output)
        )

        state_dict_critic = nn.state.get_state_dict(self.critic)
        nn.state.safe_save(
            state_dict_critic,
            '{}/critic.safetensors'.format(output)
        )

    def seed(self,seed):
        """Seed every random source the agent draws from.

        tinygrad drives weight initialisation, NumPy drives exploration noise
        and ``random_action``, and the stdlib ``random`` module drives
        replay-batch sampling, so all three must be seeded for a reproducible run.
        """
        Tensor.manual_seed(seed)
        # np.random.seed only accepts values in [0, 2**32); wrap to stay in range.
        np.random.seed(seed % (2 ** 32))
        random.seed(seed)

In [480]:
# action_space.py

import numpy as np
np.bool = np.bool_
import itertools
import pyflann

"""
    This class represents an n-dimensional unit cube with a specific number of points embedded.
    Points are distributed uniformly at initialization. A search can be made using the
    search_point function, which returns the k (given) nearest neighbors of the input point.
"""


class Space:
    """Grid of candidate actions in ``[-1, 1]^d`` with nearest-neighbour lookup.

    The grid is generated by ``init_uniform_space`` and indexed with a FLANN
    kd-tree. ``import_point`` / ``export_point`` convert between the caller's
    ``[low, high]`` range and the internal ``[-1, 1]`` range.

    Args:
        low: Per-dimension lower bounds of the external action range.
        high: Per-dimension upper bounds of the external action range.
        points: Requested number of grid points.
    """

    def __init__(self, low, high, points):

        self._low = np.array(low)
        self._high = np.array(high)
        self._range = self._high - self._low
        self._dimensions = len(low)
        self._space_low = -1
        self._space_high = 1
        # Scale factor from external units to the internal [-1, 1] range.
        self._k = (self._space_high - self._space_low) / self._range
        self.__space = init_uniform_space([self._space_low] * self._dimensions,
                                          [self._space_high] * self._dimensions,
                                          points).astype(np.float32)
        self._flann = pyflann.FLANN()
        self.rebuild_flann()

    def rebuild_flann(self):
        """(Re)build the kd-tree index over the current grid."""
        self._index = self._flann.build_index(self.__space, algorithm='kdtree')
        print(f'Rebuilding FLANN with: {self._flann}, index: {self._index}')

    def search_point(self, point, k):
        """Find the ``k`` grid points nearest to each query point.

        Args:
            point: Query point(s) in the internal ``[-1, 1]`` range. Non-array
                input is wrapped into a float32 array. A 3-D array with a
                leading singleton axis, ``(1, N, D)``, is treated as ``N``
                queries of dimension ``D``.
            k: Number of neighbours to return per query.

        Returns:
            ``(knns, exported)`` where ``knns`` are the neighbours in internal
            coordinates and ``exported`` are the same points passed through
            ``export_point``.
        """
        p_in = point
        if not isinstance(point, np.ndarray):
            p_in = np.array([p_in]).astype(np.float32)
        # Drop a leading batch axis of length 1 so the index receives a 2-D
        # (queries, dims) array. This used to be done only for the exact shape
        # (1, 64, 1), which tied the lookup to one batch size and one action dimension.
        if p_in.ndim == 3 and p_in.shape[0] == 1:
            p_in = p_in[0]
        search_res, _ = self._flann.nn_index(p_in.astype(np.float32), k)
        knns = self.__space[search_res]
        p_out = [self.export_point(p) for p in knns]

        if k == 1:
            # Keep the same nesting depth as the k > 1 case.
            p_out = [p_out]
        return knns, np.array(p_out)

    def import_point(self, point):
        """Map a point from the external range to the internal ``[-1, 1]`` range."""
        return self._space_low + self._k * (point - self._low)

    def export_point(self, point):
        """Map a point from the internal ``[-1, 1]`` range to the external range."""
        return self._low + (point - self._space_low) / self._k

    def get_space(self):
        """Return the grid array."""
        return self.__space

    def shape(self):
        """Return the shape of the grid array."""
        return self.__space.shape

    def get_number_of_actions(self):
        """Return the number of grid points."""
        return self.shape()[0]


class Discrete_space(Space):
    """One-dimensional ``Space`` holding ``n`` discrete actions.

    The external range is ``[0, n-1]`` and the grid is requested with ``n``
    points, so exported points are rounded to integer action ids.

    (gym's ``Discrete`` space has no ``high`` attribute, hence the explicit
    bounds built here.)
    """

    def __init__(self, n):  # n: the number of the discrete actions
        lower, upper = [0], [n-1]
        super().__init__(lower, upper, n)

    def export_point(self, point):
        """Convert an internal point to the nearest integer action id."""
        continuous = super().export_point(point)
        return np.round(continuous).astype(int)


def init_uniform_space(low, high, points):
    """Build a regular grid of roughly ``points`` points inside the box ``[low, high]``.

    Each axis gets ``round(points ** (1 / dims))`` evenly spaced values and the
    grid is their Cartesian product, so the actual point count is that number
    raised to the power ``dims``.

    Args:
        low: Per-dimension lower bounds.
        high: Per-dimension upper bounds.
        points: Requested total number of grid points.

    Returns:
        Array of shape ``(n_points, dims)``; for a one-dimensional space every
        row holds a single value.
    """
    dims = len(low)
    per_axis = round(points**(1.0 / dims))
    axes = [list(np.linspace(low[d], high[d], per_axis)) for d in range(dims)]
    grid = [list(combo) for combo in itertools.product(*axes)]
    return np.array(grid)

In [481]:
from tinygrad import TinyJit, Tensor

# see note above for hard_update
def soft_update(target, source, tau_update):
    """Blend each ``source`` parameter into the matching ``target`` parameter.

    Every target tensor becomes ``target * (1 - tau_update) + source * tau_update``.
    Parameters are paired positionally in the order returned by
    ``nn.state.get_parameters``. ``requires_grad`` on each source tensor is
    switched off for the update and then set to ``True`` afterwards.
    """
    target_params = nn.state.get_parameters(target)
    source_params = nn.state.get_parameters(source)
    for dst, src in zip(target_params, source_params):
        src.requires_grad = False
        blended = dst * (1.0 - tau_update) + src * tau_update
        dst.replace(blended)
        src.requires_grad = True

def criterion(input, target):
    """Return the mean squared difference between ``input`` and ``target``."""
    residual = input - target
    return residual.pow(2).mean()

class WolpertingerAgent(DDPG):
    """DDPG agent that refines the actor's proto-action via nearest-neighbour search.

    The actor proposes a continuous proto-action; the ``k`` nearest actions in
    the action space are looked up and the critic picks the one it scores highest.
    """

    def __init__(self, continuous, max_actions, action_low, action_high, nb_states, nb_actions, args, k_ratio=0.1):
        """Build the underlying DDPG agent and the searchable action space.

        ``k_nearest_neighbors`` is ``k_ratio`` times the number of actions, with
        a minimum of 1.
        """
        super().__init__(args, nb_states, nb_actions)
        self.experiment = args.id
        # according to the papers, it can be scaled to hundreds of millions
        if continuous:
            # NOTE(review): this branch reads args.max_actions while the discrete branch
            # uses the `max_actions` parameter -- confirm this is intended.
            self.action_space = Space(action_low, action_high, args.max_actions)
            self.k_nearest_neighbors = max(1, int(args.max_actions * k_ratio))
        else:
            self.action_space = Discrete_space(max_actions)
            self.k_nearest_neighbors = max(1, int(max_actions * k_ratio))


    def get_name(self):
        """Return an identifier built from action count, k and experiment id."""
        return 'Wolp3_{}k{}_{}'.format(self.action_space.get_number_of_actions(),
                                       self.k_nearest_neighbors, self.experiment)

    def get_action_space(self):
        """Return the searchable action space."""
        return self.action_space

    def wolp_action(self, s_t, proto_action):
        """Pick, among the k nearest neighbours of ``proto_action``, the one the critic rates highest.

        Returns:
            ``(raw_action, action)``: the chosen action in the action space's
            internal coordinates and the same action after ``export_point``.
        """
        # get the proto_action's k nearest neighbors
        raw_actions, actions = self.action_space.search_point(proto_action, self.k_nearest_neighbors)

        if not isinstance(s_t, np.ndarray):
           s_t = s_t.numpy().astype(np.float32)
        # make all the state, action pairs for the critic
        s_t = np.tile(s_t, [raw_actions.shape[1], 1])

        # Reshape the tiled states so they line up with raw_actions.
        # NOTE(review): assumes raw_actions is (n_queries, k, dims) when k > 1 -- confirm.
        s_t = s_t.reshape(len(raw_actions), raw_actions.shape[1], s_t.shape[-1]) if self.k_nearest_neighbors > 1 \
            else s_t.reshape(raw_actions.shape[0], s_t.shape[-1])
        # Both inputs are wrapped in a list, which adds a leading axis of length 1.
        raw_actions = Tensor([raw_actions], dtype=dtypes.float, requires_grad=False)
        s_t = Tensor([s_t], dtype=dtypes.float, requires_grad=False)

        # evaluate each pair through the critic
        actions_evaluation = self.critic([s_t, raw_actions])

        # find the index of the pair with the maximum value
        # NOTE(review): axis=2 is presumably the neighbour axis -- confirm for the k == 1 case.
        max_index = np.argmax(actions_evaluation.numpy(), axis=2)
        max_index = max_index.reshape(len(max_index.flatten()),)

        raw_actions = raw_actions.numpy().astype(np.float32)
        # Drop the leading axis added by the Tensor([...]) wrapping above.
        if len(raw_actions.shape) == 4:
            raw_actions = raw_actions[0]
        # return the best action, i.e., wolpertinger action from the full wolpertinger policy
        if self.k_nearest_neighbors > 1:
            # Fancy indexing: row i, neighbour max_index[i], component 0 only.
            return raw_actions[[i for i in range(len(raw_actions))], max_index, [0]].reshape(len(raw_actions),1), \
                   actions[[i for i in range(len(actions))], max_index, [0]].reshape(len(actions),1)
        else:
            return raw_actions[max_index], actions[max_index]

    def select_action(self, s_t, decay_epsilon=True):
        """Return the environment action for ``s_t`` and remember its raw form in ``a_t``."""
        # taking a continuous action from the actor
        proto_action = super().select_action(s_t, decay_epsilon)

        # Accept an (observation, info) tuple whose info dict is empty.
        if type(s_t) is tuple and s_t[1] == {}:
            s_t = s_t[0]
        raw_wolp_action, wolp_action = self.wolp_action(s_t, proto_action)
        assert isinstance(raw_wolp_action, np.ndarray)
        self.a_t = raw_wolp_action
        # return the best neighbor of the proto action, this is an action for env step
        return wolp_action[0]  # [i]

    def random_action(self):
        """Return the grid action nearest to a uniformly random proto-action; stores its raw form in ``a_t``."""
        proto_action = super().random_action()
        raw_action, action = self.action_space.search_point(proto_action, 1)
        raw_action = raw_action[0]
        action = action[0]
        assert isinstance(raw_action, np.ndarray)
        self.a_t = raw_action
        return action[0] # [i]

    def select_target_action(self, s_t):
        """Return the raw Wolpertinger action chosen from the target actor's proto-action."""
        proto_action = self.actor_target(s_t)
        proto_action = proto_action.clamp(-1.0, 1.0).numpy().astype(np.float32)
        # NOTE(review): this tuple check runs after s_t was already passed to actor_target.
        if type(s_t) is tuple and s_t[1] == {}:
            s_t = s_t[0]
        raw_wolp_action, wolp_action = self.wolp_action(s_t, proto_action)
        return raw_wolp_action

    def update_policy(self):
        """Run one training step: critic update, actor update, then soft target updates."""
        # Sample batch
        state_batch, action_batch, reward_batch, \
        next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size)

        # Prepare for the target q batch
        # the operation below of critic_target does not require backward_P
        next_state_batch = Tensor([next_state_batch], dtype=dtypes.float, requires_grad=False)
        next_wolp_action_batch = Tensor(self.select_target_action(next_state_batch)[0:1,:,:], dtype=dtypes.float, requires_grad=False)
        next_q_values = self.critic_target((
            next_state_batch,
            next_wolp_action_batch
        ))
        # but it requires bp in computing gradient of critic loss
        # next_q_values.volatile = False

        # terminal_batch is 0. for terminal transitions (see SequentialMemory.sample_and_split),
        # so the bootstrapped term vanishes for terminal states.
        target_q_batch = Tensor([reward_batch], dtype=dtypes.float, requires_grad=False) + \
                         self.gamma * \
                         Tensor([terminal_batch.astype(np.float32)], dtype=dtypes.float, requires_grad=False) * \
                         next_q_values

        # Critic update
        self.critic_optim.zero_grad()  # Clears the gradients of all optimized tinygrad.Tensor s.

        state_batch = Tensor([state_batch], dtype=dtypes.float, requires_grad=False)
        action_batch = Tensor([action_batch], dtype=dtypes.float, requires_grad=False)
        q_batch = self.critic([state_batch, action_batch])

        # Mean squared error between predicted and target Q-values.
        value_loss = criterion(q_batch, target_q_batch)
        # NOTE(review): requires_grad is switched off on the loss right before backward() -- confirm why.
        value_loss.requires_grad = False
        value_loss.backward()  # computes gradients
        self.critic_optim.step()  # updates the parameters

        # Actor update
        self.actor_optim.zero_grad()

        # The actor is trained to maximise the critic's value of its own proto-action.
        policy_loss = -self.critic([state_batch, self.actor(state_batch)])
        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.actor_optim.step()

        # Target update
        soft_update(self.actor_target, self.actor, self.tau_update)
        soft_update(self.critic_target, self.critic, self.tau_update)

In [482]:
import gym

# https://github.com/openai/gym/blob/master/gym/core.py
class NormalizedEnv(gym.ActionWrapper):
    """Action wrapper that rescales between ``[-1, 1]`` and the env's action bounds."""

    def action(self, action):
        """Map ``action`` linearly into the wrapped space's ``[low, high]`` range."""
        space = self.action_space
        half_range = (space.high - space.low)/ 2.
        midpoint = (space.high + space.low)/ 2.
        return half_range * action + midpoint

    def reverse_action(self, action):
        """Inverse of ``action``: map from the ``[low, high]`` range back to ``[-1, 1]``."""
        space = self.action_space
        inv_half_range = 2./(space.high - space.low)
        midpoint = (space.high + space.low)/ 2.
        return inv_half_range * (action - midpoint)

In [483]:

import argparse

def init_parser(alg):
    """Build the command-line parser for the requested algorithm.

    Args:
        alg: Algorithm name. Only ``'WOLP_DDPG'`` is defined.

    Returns:
        argparse.ArgumentParser: parser pre-populated with the WOLP_DDPG options.

    Raises:
        RuntimeError: if ``alg`` is not a known algorithm name.
    """
    if alg != 'WOLP_DDPG':
        raise RuntimeError('undefined algorithm {}'.format(alg))

    parser = argparse.ArgumentParser(description='WOLP_DDPG')

    parser.add_argument('--gamma', type=float, default=0.99, metavar='G', help='discount factor for rewards (default: 0.99)')
    # Both spellings used to be registered as two separate actions writing to
    # the same destination; argparse applies the default of the first action
    # registered, so the second default (500) never took effect. They are now
    # one action that accepts either spelling, with the effective default 1440.
    parser.add_argument('--max-episode-length', '--max_episode_length', type=int, default=1440, metavar='M',
                        help='maximum length of an episode (default: 1440)')
    # NOTE: no type is given for --load, so a value passed on the command line
    # is stored as a string, and any non-empty string (even 'False') is truthy.
    parser.add_argument('--load', default=False, metavar='L', help='load a trained model')
    parser.add_argument('--load-model-dir', default='', metavar='LMD', help='folder to load trained models from')
    parser.add_argument('--gpu-ids', type=int, default=[1], nargs='+', help='GPUs to use [-1 CPU only]')
    parser.add_argument('--gpu-nums', type=int, default=1, help='#GPUs to use (default: 1)')
    parser.add_argument('--max-episode', type=int, default=200000, help='maximum #episode.')
    parser.add_argument('--test-episode', type=int, default=1000, help='maximum testing #episode.')
    parser.add_argument('--max-actions', default=200000, type=int, help='# max actions')
    parser.add_argument('--id', default='0', type=str, help='experiment id')
    parser.add_argument('--mode', default='train', type=str, help='support option: train/test')
    parser.add_argument('--env', default='Pendulum-v0', type=str, help='Ride sharing')
    parser.add_argument('--hidden1', default=256, type=int, help='hidden num of first fully connect layer')
    parser.add_argument('--hidden2', default=128, type=int, help='hidden num of second fully connect layer')
    parser.add_argument('--c-lr', default=0.001, type=float, help='critic net learning rate')
    parser.add_argument('--p-lr', default=0.0001, type=float, help='policy net learning rate (only for DDPG)')
    parser.add_argument('--warmup', default=128, type=int, help='time without training but only filling the replay memory')
    parser.add_argument('--bsize', default=64, type=int, help='minibatch size')
    parser.add_argument('--rmsize', default=6000000, type=int, help='memory size')
    parser.add_argument('--window_length', default=1, type=int, help='')
    parser.add_argument('--tau-update', default=0.001, type=float, help='moving average for target network')
    parser.add_argument('--ou_theta', default=0.15, type=float, help='noise theta')
    parser.add_argument('--ou_sigma', default=0.2, type=float, help='noise sigma')
    parser.add_argument('--ou_mu', default=0.0, type=float, help='noise mu')
    parser.add_argument('--init_w', default=0.003, type=float, help='')
    parser.add_argument('--epsilon', default=80000, type=int, help='Linear decay of exploration policy')
    parser.add_argument('--seed', default=-1, type=int, help='')
    parser.add_argument('--weight-decay', default=0.0001, type=float, help='weight decay for L2 Regularization loss')
    parser.add_argument('--save_per_epochs', default=15, type=int, help='save model every X epochs')

    return parser

In [484]:
import gym
env = gym.make('CartPole-v1')

# Probe the action space. Indexing ``shape[0]`` raises IndexError when a shape
# is empty; that case is handled below as a discrete action space.
continuous = None
try:
    # Continuous branch: read dimensions and bounds, then wrap the env so
    # actions are rescaled by NormalizedEnv.
    nb_states = env.observation_space.shape[0]
    nb_actions = env.action_space.shape[0]
    action_high = env.action_space.high
    action_low = env.action_space.low
    continuous = True
    env = NormalizedEnv(env)
except IndexError:
    # Discrete branch: a single action dimension whose size is the number of
    # choices reported by the action space.
    nb_states = env.observation_space.shape[0]
    nb_actions = 1  # the dimension of actions, usually it is 1. Depend on the environment.
    max_actions = env.action_space.n
    continuous = False

# Hyper-parameters come from the parser defaults (no command line in a notebook).
parser = init_parser('WOLP_DDPG')
args = parser.parse_args(args=[])

# Only the names belonging to the branch taken above exist, so each
# conditional expression evaluates just the side that is defined.
agent_args = {
    'continuous': continuous,
    'max_actions': None if continuous else max_actions,
    'action_low': action_low if continuous else None,
    'action_high': action_high if continuous else None,
    'nb_states': nb_states,
    'nb_actions': nb_actions,
    'args': args,
}

agent = WolpertingerAgent(**agent_args)

Initialized DDPG with actor parameters: 6, lr=0.0001
Initialized DDPG with critic parameters: 6, lr=0.001


In [488]:
# Compatibility shim: recreate the `np.bool` alias (deprecated in NumPy 1.20, removed in 1.24) for code that still references it
import numpy as np
from tinygrad import TinyJit
np.bool = np.bool_

# Not wrapped in TinyJit: this function drives the environment, prints, logs and
# saves from Python, and a JIT replay would skip all of that on later calls.
@Tensor.train()
def train(continuous, env, agent, max_episode, warmup, save_model_dir, max_episode_length, logger, save_per_epochs):
    """Run the training loop for ``max_episode`` episodes.

    Args:
        continuous: when falsy, the agent's action is converted to a single
            ``int`` before being passed to ``env.step``.
        env: environment whose ``reset()`` returns ``(observation, info)`` and
            whose ``step()`` returns
            ``(observation, reward, terminated, truncated, info)``.
        agent: object providing ``reset``, ``random_action``, ``select_action``,
            ``observe``, ``update_policy``, ``save_model`` and a ``memory``
            with ``append``.
        max_episode: number of episodes to run.
        warmup: number of initial steps that use random actions and skip
            policy updates.
        save_model_dir: directory passed to ``agent.save_model``.
        max_episode_length: if truthy, an episode is forced to end once this
            many steps have been taken.
        logger: logger receiving per-episode reward and save messages.
        save_per_epochs: save the model every this many episodes; a falsy
            value disables periodic saving.
    """
    agent.is_training = True
    step = episode = episode_steps = 0
    episode_reward = 0.
    s_t = None
    print(f'max_episode: {max_episode}. save_per_epochs: {save_per_epochs}')
    while episode < max_episode:
        while True:
            if s_t is None:
                # reset() returns (observation, info). Keep only the observation
                # so the state later handed to select_action is not the tuple.
                s_t, _ = env.reset()
                agent.reset(s_t)

            # agent pick action ...
            # warmup: time without training but only filling the memory
            if step <= warmup:
                action = agent.random_action()
            else:
                action = agent.select_action(s_t)

            # env response with next_observation, reward, terminate_info
            if not continuous:
                action = action.reshape(1,).astype(int)[0]
            s_t1, r_t, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            s_t1 = np.array(s_t1)

            # Force the episode to end once the step budget is used up.
            if max_episode_length and episode_steps >= max_episode_length - 1:
                done = True

            # agent observe and update policy
            agent.observe(r_t, np.array(s_t1), done)
            if step > warmup:
                agent.update_policy()

            # update
            step += 1
            episode_steps += 1
            episode_reward += r_t
            s_t = s_t1

            if done:  # end of an episode
                print("Ep:{0} | R:{1:.4f}".format(episode, episode_reward))
                logger.info(
                    "Ep:{0} | R:{1:.4f}".format(episode, episode_reward)
                )

                # Store the final state with a zero reward and the terminal flag set.
                agent.memory.append(
                    s_t,
                    agent.select_action(s_t),
                    0., True
                )

                # reset
                s_t = None
                episode_steps = 0
                episode_reward = 0.
                episode += 1
                # break to next episode
                break
        # [optional] save an intermediate model every `save_per_epochs` episodes;
        # a falsy value skips saving instead of dividing by zero.
        if save_per_epochs and step > warmup and episode > 0 and episode % save_per_epochs == 0:
            agent.save_model(save_model_dir)
            logger.info("### Model Saved before Ep:{0} ###".format(episode))

# Not wrapped in TinyJit: this function loads weights, steps the environment and
# logs from Python, and a JIT replay would skip all of that on later calls.
@Tensor.test()
def test(env, agent, model_path, test_episode, max_episode_length, logger, continuous=True):
    """Evaluate a saved agent for ``test_episode`` episodes and log each reward.

    Args:
        env: environment whose ``reset()`` returns ``(observation, info)`` and
            whose ``step()`` returns
            ``(observation, reward, terminated, truncated, info)``.
        agent: object providing ``load_weights``, ``eval``, ``reset`` and
            ``select_action``.
        model_path: location passed to ``agent.load_weights``.
        test_episode: number of evaluation episodes.
        max_episode_length: if truthy, caps the number of steps per episode.
        logger: logger receiving one reward line per episode.
        continuous: when falsy, the selected action is converted to a single
            ``int`` before stepping, mirroring ``train``. The default keeps
            the previous behaviour of passing the action through unchanged.
    """
    agent.load_weights(model_path)
    agent.is_training = False
    agent.eval()

    for i in range(test_episode):
        # reset() returns (observation, info); only the observation is the state.
        s_t, _ = env.reset()
        agent.reset(s_t)

        # Per-episode counters, so the logged reward and the length cap
        # are not accumulated across episodes.
        episode_steps = 0
        episode_reward = 0.
        while True:
            action = agent.select_action(s_t, decay_epsilon=False)
            if not continuous:
                action = action.reshape(1,).astype(int)[0]
            s_t, r_t, terminated, truncated, _ = env.step(action)
            # Either flag ends the episode, as in train().
            done = terminated or truncated
            s_t = np.array(s_t)
            episode_steps += 1
            episode_reward += r_t
            if max_episode_length and episode_steps >= max_episode_length - 1:
                done = True
            if done:  # end of an episode
                logger.info(
                    "Ep:{0} | R:{1:.4f}".format(i+1, episode_reward)
                )
                break

In [489]:
import logging
import numpy as np
np.bool = np.bool_

def setup_logger(logger_name, log_file, level=logging.INFO):
    """Configure a named logger to write to ``log_file`` and to the console.

    Calling this again for the same ``logger_name`` (for example when a
    notebook cell is re-run) replaces the handlers installed by the previous
    call instead of stacking duplicates, so each record is emitted once and
    the earlier log file handle is closed. Handlers attached by other code are
    left untouched.

    Args:
        logger_name: name passed to ``logging.getLogger``.
        log_file: path of the log file; it is opened in ``'w'`` mode, so any
            existing content is truncated.
        level: logging level applied to the logger (default ``logging.INFO``).
    """
    logger = logging.getLogger(logger_name)
    formatter = logging.Formatter('%(asctime)s : %(message)s')

    # Remove and close only the handlers this function added earlier.
    for stale in list(logger.handlers):
        if getattr(stale, '_installed_by_setup_logger', False):
            logger.removeHandler(stale)
            stale.close()

    file_handler = logging.FileHandler(log_file, mode='w')
    stream_handler = logging.StreamHandler()

    logger.setLevel(level)
    for handler in (file_handler, stream_handler):
        handler.setFormatter(formatter)
        # Marker used by the cleanup loop above on a later call.
        handler._installed_by_setup_logger = True
        logger.addHandler(handler)

# Keyword arguments for train(). Values come from the environment/agent set up
# above and from the parsed defaults in `args`. The 'RS_log' logger is fetched
# as-is; setup_logger is not called here.
train_args = dict(
    continuous=continuous,
    env=env,
    agent=agent,
    max_episode=args.max_episode,
    warmup=args.warmup,
    save_model_dir="./",
    max_episode_length=args.max_episode_length,
    logger=logging.getLogger('RS_log'),
    save_per_epochs=args.save_per_epochs,
)

train(**train_args)

max_episode: 200000. save_per_epochs: 15
Ep:0 | R:15.0000
Ep:1 | R:21.0000
Ep:2 | R:20.0000
Ep:3 | R:21.0000
Ep:4 | R:28.0000
Ep:5 | R:37.0000
Ep:6 | R:10.0000
Ep:7 | R:10.0000
Ep:8 | R:10.0000
Ep:9 | R:9.0000
Ep:10 | R:10.0000
Ep:11 | R:10.0000
Ep:12 | R:10.0000
Ep:13 | R:10.0000
Ep:14 | R:10.0000
Ep:15 | R:8.0000
Ep:16 | R:10.0000
Ep:17 | R:10.0000
Ep:18 | R:10.0000
Ep:19 | R:8.0000
Ep:20 | R:10.0000
Ep:21 | R:10.0000
Ep:22 | R:10.0000
Ep:23 | R:10.0000
Ep:24 | R:10.0000
Ep:25 | R:10.0000
Ep:26 | R:9.0000
Ep:27 | R:9.0000
Ep:28 | R:9.0000
Ep:29 | R:9.0000
Ep:30 | R:9.0000
Ep:31 | R:10.0000
Ep:32 | R:11.0000
Ep:33 | R:9.0000
Ep:34 | R:9.0000
Ep:35 | R:9.0000
Ep:36 | R:10.0000
Ep:37 | R:9.0000
Ep:38 | R:9.0000
Ep:39 | R:9.0000
Ep:40 | R:9.0000
Ep:41 | R:9.0000
Ep:42 | R:10.0000
Ep:43 | R:9.0000
Ep:44 | R:9.0000
Ep:45 | R:10.0000
Ep:46 | R:8.0000
Ep:47 | R:9.0000
Ep:48 | R:10.0000
Ep:49 | R:10.0000
Ep:50 | R:9.0000
Ep:51 | R:11.0000
Ep:52 | R:10.0000
Ep:53 | R:10.0000
Ep:54 | R:10.0000


KeyboardInterrupt: 