<a href="https://colab.research.google.com/github/milnico/InteractionRules/blob/main/ReactivePolicy_Behavior_DangerZone.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Pre-requisite

Before we start, we need to install EvoJAX and import some libraries.  
**Note:** In our [paper](https://arxiv.org/abs/2202.05008), we ran the experiments on NVIDIA V100 GPU(s), so your results may differ from ours.


In [None]:
# Notebook display helpers. Note that `Image` is re-imported from PIL in a
# later cell, which shadows this name from that point on.
from IPython.display import clear_output, Image

# Install EvoJAX and pinned versions of the libraries used below.
# `--no-dependencies` tells pip to install only the named package and not
# its dependencies (presumably to keep the runtime's pre-installed JAX
# untouched -- TODO confirm).
!pip install evojax --no-dependencies
!pip install flax==0.6.0 --no-dependencies
!pip install rich==11.2.0 --no-dependencies
!pip install cma==3.2.2
!pip install optax==0.1.3 --no-dependencies
!pip install commonmark==0.9.1 --no-dependencies
!pip install chex==0.1.5 --no-dependencies

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting evojax
  Downloading evojax-0.2.15-py3-none-any.whl (94 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.1/94.1 KB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evojax
Successfully installed evojax-0.2.15
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flax==0.6.0
  Downloading flax-0.6.0-py3-none-any.whl (180 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.1/180.1 KB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: flax
  Attempting uninstall: flax
    Found existing installation: flax 0.6.7
    Uninstalling flax-0.6.7:
      Successfully uninstalled flax-0.6.7
Successfully installed flax-0.6.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rich=

In [None]:
#@title Default title text
# Mount the user's Google Drive at /content/gdrive (only works inside Colab).
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
import os
import numpy as np
import jax
import jax.numpy as jnp

from evojax.task.cartpole import CartPoleSwingUp
from evojax.policy.mlp import MLPPolicy
from evojax.algo import PGPE
from evojax import Trainer
from evojax.util import create_logger

  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [None]:
# Copyright 2022 The EvoJAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of a multi-agents WaterWorld task.

Ref: https://cs.stanford.edu/people/karpathy/reinforcejs/waterworld.html
"""

from typing import Tuple
from functools import partial
from PIL import Image
from PIL import ImageDraw
import numpy as np

import jax
import jax.numpy as jnp
from jax import random
from jax import tree_util
from flax.struct import dataclass

from evojax.task.base import TaskState
from evojax.task.base import VectorizedTask


# Arena size. Positions wrap around at these bounds and the renderer draws an
# image of (SCREEN_W * SCALE) x (SCREEN_H * SCALE) pixels.
SCREEN_W = 1000
SCREEN_H = 1000
SCALE = 1
# Agents are rendered as circles of radius 2 * BUBBLE_RADIUS.
BUBBLE_RADIUS = 2.5
# The predator can capture agents within 2 * EATING_RADIUS (see get_reward);
# the renderer draws the predator with that same radius.
EATING_RADIUS = 50
# NOTE(review): MIN_DIST, MAX_RANGE and DELTA_ANG are only referenced from
# commented-out code in the visible part of this file.
MIN_DIST = 2 * BUBBLE_RADIUS
MAX_RANGE = 400
# get_obs builds NUM_RANGE_SENSORS sector edges, i.e. NUM_RANGE_SENSORS - 1
# angular sectors.
NUM_RANGE_SENSORS = 8
DELTA_ANG = 3.14 / NUM_RANGE_SENSORS  # previously: 2 * 3.14 / NUM_RANGE_SENSORS

# Bubble type identifiers. create_bubbles assigns TYPE_AGENT to agents and
# TYPE_FOOD to items; the other values are not used in the visible code.
TYPE_VOID = 0
TYPE_WALL = 1
TYPE_FOOD = 2
TYPE_POISON = 3
TYPE_AGENT = 4
SENSOR_DATA_DIM = 1

# NOTE(review): the ACT_* constants and SEED are not referenced in the visible
# part of this file -- presumably leftovers from the original task.
ACT_UP = 0
ACT_DOWN = 1
ACT_LEFT = 2
ACT_RIGHT = 3
SEED = 0


@dataclass
class BubbleStatus(object):
    """Kinematic state of a bubble (an agent or an item)."""
    # Position in arena coordinates.
    pos_x: jnp.float32
    pos_y: jnp.float32
    # Velocity components (agents are created with speed 5, items with 0).
    vel_x: jnp.float32
    vel_y: jnp.float32
    # One of the TYPE_* constants.
    bubble_type: jnp.int32
    # 1 while the bubble is active; get_reward sets it to 0 for eaten agents.
    valid: jnp.int32
    # Heading angle in radians (fed to cos/sin when moving).
    direction: jnp.float32



@dataclass
class State(TaskState):
    """Full state of one simulated environment."""
    # Status of all agents (the prey controlled by the policy).
    agent_state: BubbleStatus
    # Status of all items (moved toward the agents by get_reward).
    item_state: BubbleStatus
    # Per-agent observations produced by get_obs.
    obs: jnp.ndarray
    # Number of steps taken since reset.
    steps: jnp.int32
    # PRNG key carried over to the next step.
    key: jnp.ndarray


@partial(jax.vmap, in_axes=(0, None))
def create_bubbles(key: jnp.ndarray, is_agent: bool) -> BubbleStatus:
    """Sample the initial status of one bubble; vectorized over ``key``.

    Args:
        key: PRNG key for this bubble (``jax.vmap`` maps over the leading
            axis of the batch of keys passed by the caller).
        is_agent: Plain Python bool, not mapped. ``True`` creates an agent,
            ``False`` creates an item.

    Returns:
        A ``BubbleStatus`` with ``valid=1``.
    """
    # The previous version also computed unused `center_x`/`center_y` values
    # from `key.val[...]`. `.val` is a private attribute of the vmap tracer,
    # so that code only ran by accident inside vmap; it has been removed.
    # The 5-way split is kept so the sampled values are unchanged.
    k_pos_r, k_pos_theta, k_vel, k_center_x, k_center_y = random.split(key, 5)

    if is_agent:
        bubble_type = TYPE_AGENT
        # Agents start anywhere except a 10-unit border, heading in a random
        # direction with speed 5.
        pos_x = random.uniform(
            k_center_x, shape=(), minval=10, maxval=SCREEN_W - 10)
        pos_y = random.uniform(
            k_center_y, shape=(), minval=10, maxval=SCREEN_H - 10)
        direction = random.uniform(
            k_pos_theta, shape=(), minval=-3.14, maxval=3.14)
        vel_x = jnp.cos(direction) * 5.0
        vel_y = jnp.sin(direction) * 5.0
    else:
        bubble_type = TYPE_FOOD
        # Items start at rest, at least 100 units away from every border.
        vel_x = vel_y = 0.
        pos_x = random.uniform(
            k_pos_r, shape=(), minval=100, maxval=SCREEN_W - 100)
        pos_y = random.uniform(
            k_vel, shape=(), minval=100, maxval=SCREEN_H - 100)
        direction = random.uniform(
            k_pos_theta, shape=(), minval=-3.14, maxval=3.14)

    return BubbleStatus(pos_x=pos_x, pos_y=pos_y, vel_x=vel_x, vel_y=vel_y,
                        bubble_type=bubble_type, valid=1, direction=direction)

'''
def get_item_move(direction_x: jnp.ndarray,direction_y: jnp.ndarray,
               items: BubbleStatus) -> Tuple[BubbleStatus,
                                             jnp.float32]:
    dist = jnp.sqrt(jnp.square(items.pos_x - direction_x) +
                    jnp.square(items.pos_y - direction_y))
    dist_bin = jnp.where(dist==jnp.min(dist),1,0)
    nearest_prey_x = jnp.sum(jnp.multiply(dist_bin,direction_x))
    nearest_prey_y = jnp.sum(jnp.multiply(dist_bin, direction_y))


    orient = -jnp.arctan2(nearest_prey_y - items.pos_y, nearest_prey_x - items.pos_x)
    vel_x = jnp.cos(orient) * 5.0
    vel_y = jnp.sin(orient) * 5.0

    pos_x = items.pos_x + vel_x
    pos_y = items.pos_y - vel_y
    # Collide with the west wall.
    pos_x = jnp.where(pos_x < 1, pos_x + SCREEN_W - 1, pos_x)
    # Collide with the east wall.
    pos_x = jnp.where(pos_x > SCREEN_W - 1, pos_x - SCREEN_W + 1, pos_x)
    # Collide with the north wall.
    pos_y = jnp.where(pos_y < 1, pos_y + SCREEN_H - 1, pos_y)
    # Collide with the south wall.
    pos_y = jnp.where(pos_y > SCREEN_H - 1, pos_y - SCREEN_H + 1, pos_y)

    reward = 0


    items_state = BubbleStatus(
        pos_x=pos_x, pos_y=pos_y,
        vel_x=items.vel_x, vel_y=items.vel_y,
        bubble_type=items.bubble_type,
        valid=items.valid, direction=orient)
    return items_state, reward
'''
def get_reward(items: BubbleStatus,
               agent: BubbleStatus,
               key: jnp.ndarray,
               agent_distance) -> Tuple[BubbleStatus,
                                             BubbleStatus,
                                             jnp.float32]:
    """Move one predator toward its nearest visible prey and resolve captures.

    Args:
        items: Status of a single predator (one element of the item batch).
        agent: Status of all prey agents.
        key: Array of PRNG keys; only ``key[0]`` is used.
        agent_distance: Per-agent value that lowers the capture probability
            (twice this value is subtracted from the distance-based one).

    Returns:
        ``(agent_state, items_state, reward)``: the agents with ``valid``
        cleared for eaten ones, the moved predator, and minus the number of
        agents eaten in this step.
    """
    predator_visual = jnp.pi * (3 / 4)
    capture_range = EATING_RADIUS * 2

    # Distance from the predator to every prey; invalid prey collapse to 0.
    dist = jnp.sqrt(jnp.square(items.pos_x - agent.pos_x) +
                    jnp.square(items.pos_y - agent.pos_y)) * agent.valid

    # Bearing of every prey relative to the predator's previous heading,
    # wrapped back into [-pi, pi].
    angles = -jnp.arctan2(agent.pos_y - items.pos_y, agent.pos_x - items.pos_x)
    angles = (angles - items.direction).ravel()
    angles = jnp.where(angles > jnp.pi, angles - (jnp.pi * 2), angles)
    angles = jnp.where(angles < -jnp.pi, angles + (jnp.pi * 2), angles)
    angles_bin = jnp.where(
        (angles > -predator_visual) & (angles < predator_visual), 1.0, 0.0)
    # Zero for prey that are invalid or outside the field of view.
    angles_dist = angles_bin * dist

    # Pick the closest visible prey. Masked entries get +inf rather than the
    # former finite sentinel of 1000: distances on a 1000 x 1000 arena can
    # exceed 1000, in which case a masked prey used to win the argmin.
    # If nothing is visible, argmin falls back to index 0 as before.
    target = jnp.argmin(jnp.where(angles_dist == 0.0, jnp.inf, angles_dist))
    nearest_prey_x = agent.pos_x[target]
    nearest_prey_y = agent.pos_y[target]

    # Head straight for the target with speed 6.
    orient = -jnp.arctan2(nearest_prey_y - items.pos_y,
                          nearest_prey_x - items.pos_x)
    vel_x = jnp.cos(orient) * 6.0
    vel_y = jnp.sin(orient) * 6.0

    pos_x = items.pos_x + vel_x
    pos_y = items.pos_y - vel_y
    # Wrap around at the west / east / north / south borders.
    pos_x = jnp.where(pos_x < 1, pos_x + SCREEN_W - 1, pos_x)
    pos_x = jnp.where(pos_x > SCREEN_W - 1, pos_x - SCREEN_W + 1, pos_x)
    pos_y = jnp.where(pos_y < 1, pos_y + SCREEN_H - 1, pos_y)
    pos_y = jnp.where(pos_y > SCREEN_H - 1, pos_y - SCREEN_H + 1, pos_y)

    # Distances from the predator's new position.
    dist = jnp.sqrt(jnp.square(pos_x - agent.pos_x) +
                    jnp.square(pos_y - agent.pos_y)) * agent.valid
    in_range = (dist > 0.0) & (dist <= capture_range)

    # Capture probability falls off linearly with distance and is reduced by
    # twice the caller-provided `agent_distance`.
    eat_prob = jnp.where(in_range, 1 - (dist / capture_range), 0.0)
    escape_prob = agent_distance
    cumulative_prob = jnp.where(in_range, (eat_prob - 2 * escape_prob), 0.0)

    rnd_gen = random.uniform(
        key[0], shape=jnp.shape(cumulative_prob), minval=0, maxval=1)

    # Only the prey with the highest probability can be eaten in this step.
    rnd_process = jnp.where(
        cumulative_prob == jnp.amax(cumulative_prob), cumulative_prob, 0)
    eaten = (rnd_process > rnd_gen)
    reward = -jnp.sum(eaten)
    agent_valid = (1 - eaten) * agent.valid

    agent_state = BubbleStatus(
        pos_x=agent.pos_x, pos_y=agent.pos_y,
        vel_x=agent.vel_x, vel_y=agent.vel_y,
        bubble_type=agent.bubble_type,
        valid=agent_valid, direction=agent.direction)
    items_state = BubbleStatus(
        pos_x=pos_x, pos_y=pos_y,
        vel_x=items.vel_x, vel_y=items.vel_y,
        bubble_type=items.bubble_type,
        valid=items.valid, direction=orient)
    return agent_state, items_state, reward


@partial(jax.vmap, in_axes=(0, None,None,None))
def get_rewards(items: BubbleStatus,
                agents: BubbleStatus,
                key: jnp.ndarray,
                agent_distance: jnp.array) -> Tuple[BubbleStatus,
                                              BubbleStatus,
                                              jnp.ndarray]:
    """Apply ``get_reward`` to every item.

    ``jax.vmap`` maps over ``items`` only; ``agents``, ``key`` and
    ``agent_distance`` are passed unchanged to each call.
    """
    agent_state, item_state, reward = get_reward(
        items, agents, key, agent_distance)
    return agent_state, item_state, reward



@jax.vmap
def update_item_state(item: BubbleStatus) -> BubbleStatus:
    """Return a field-by-field copy of ``item`` without changing anything."""
    fields = dict(
        pos_x=item.pos_x,
        pos_y=item.pos_y,
        vel_x=item.vel_x,
        vel_y=item.vel_y,
        bubble_type=item.bubble_type,
        valid=item.valid,
        direction=item.direction,
    )
    return BubbleStatus(**fields)

@jax.vmap
def update_agent_state(agent: BubbleStatus,
                       action: jnp.ndarray,key: jnp.ndarray) -> BubbleStatus:
    """Turn one agent by its action, add heading noise and move it one step.

    Args:
        agent: Current status of the agent.
        action: Policy output; ``action[0]`` is clipped to [-1, 1] and
            subtracted from the current heading.
        key: PRNG key for this agent.

    Returns:
        The agent's new status; ``bubble_type`` and ``valid`` are unchanged.
    """
    full_turn = jnp.pi * 2
    # Keep the 4-way split of the original; only the first sub-key is used.
    noise_key = random.split(key, 4)[0]

    def _wrap_around(coord, extent):
        # Re-enter from the opposite side when crossing a border.
        coord = jnp.where(coord < 1, coord + extent - 1, coord)
        return jnp.where(coord > extent - 1, coord - extent + 1, coord)

    # Apply the turn in the shifted range [0, 2*pi], then shift back.
    heading = (agent.direction + jnp.pi) - jnp.clip(action[0], -1, 1)
    heading = jnp.where(heading > full_turn, heading - full_turn, heading)
    heading = jnp.where(heading < 0.0, heading + full_turn, heading)
    heading = heading - jnp.pi
    # Uniform heading noise of at most 0.314 rad in either direction.
    heading = heading + random.uniform(
        noise_key, shape=(), minval=-3.14, maxval=3.14) * 0.1

    speed_x = jnp.cos(heading) * 5.0
    speed_y = jnp.sin(heading) * 5.0
    new_x = agent.pos_x + speed_x
    new_y = agent.pos_y - speed_y
    # Recompute the heading from the actual displacement, taken before the
    # wrap-around so that crossing a border does not flip it.
    heading = -jnp.arctan2(new_y - agent.pos_y, new_x - agent.pos_x)

    new_x = _wrap_around(new_x, SCREEN_W)
    new_y = _wrap_around(new_y, SCREEN_H)

    return BubbleStatus(
        pos_x=new_x, pos_y=new_y, vel_x=speed_x, vel_y=speed_y,
        bubble_type=agent.bubble_type, valid=agent.valid,
        direction=heading)


@jax.vmap
def get_line_seg_intersection(x1: jnp.float32,
                              y1: jnp.float32,
                              x2: jnp.float32,
                              y2: jnp.float32,
                              x3: jnp.float32,
                              y3: jnp.float32,
                              x4: jnp.float32,
                              y4: jnp.float32) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Determine if line segment (x1, y1, x2, y2) intersects with line
    segment (x3, y3, x4, y4), and return the intersection coordinate.

    Returns:
        ``(intersected, point)``: a boolean flag and the intersection
        coordinate, or ``[SCREEN_W, SCREEN_W]`` when the segments do not
        intersect.
    """
    # The return annotation used to be `np.bool`, an alias that was removed
    # from NumPy and made this definition fail with AttributeError.
    denominator = (y4 - y3) * (x2 - x1) - (x4 - x3) * (y2 - y1)
    # A (near-)zero denominator means the segments are parallel; both line
    # parameters are then forced to 0, which counts as "no intersection".
    parallel = jnp.isclose(denominator, 0.0)
    ua = jnp.where(
        parallel, 0,
        ((x4 - x3) * (y1 - y3) - (y4 - y3) * (x1 - x3)) / denominator)
    mask1 = jnp.bitwise_and(ua > 0., ua < 1.)
    ub = jnp.where(
        parallel, 0,
        ((x2 - x1) * (y1 - y3) - (y2 - y1) * (x1 - x3)) / denominator)
    mask2 = jnp.bitwise_and(ub > 0., ub < 1.)
    # The segments cross only if both parameters lie strictly inside (0, 1).
    intersected = jnp.bitwise_and(mask1, mask2)
    x_intersection = x1 + ua * (x2 - x1)
    y_intersection = y1 + ua * (y2 - y1)
    # NOTE(review): the fallback uses SCREEN_W for both coordinates.
    up = jnp.where(intersected,
                   jnp.array([x_intersection, y_intersection]),
                   jnp.array([SCREEN_W, SCREEN_W]))
    return intersected, up


@jax.vmap
def get_line_dot_intersection(x1: jnp.float32,
                              y1: jnp.float32,
                              x2: jnp.float32,
                              y2: jnp.float32,
                              x3: jnp.float32,
                              y3: jnp.float32) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Determine if a line segment (x1, y1, x2, y2) intersects with a dot at
    (x3, y3) with radius BUBBLE_RADIUS, if so return the point of intersection.

    Returns:
        ``(intersected, point)``: a boolean flag and the candidate point
        computed from the dot centre and the segment normal.
    """
    # The return annotation used to be `np.bool`, an alias that was removed
    # from NumPy and made this definition fail with AttributeError.
    point_xy = jnp.array([x3, y3])
    # Vector perpendicular to the segment direction (x2 - x1, y2 - y1).
    v = jnp.array([y2 - y1, x1 - x2])
    v_len = jnp.linalg.norm(v)
    # Perpendicular distance from the dot centre to the infinite line.
    d = jnp.abs((x2 - x1) * (y1 - y3) - (x1 - x3) * (y2 - y1)) / v_len
    up = point_xy + v / v_len * d
    # Position of `up` along the segment, measured on the axis with the
    # larger extent.
    ua = jnp.where(jnp.abs(x2 - x1) > jnp.abs(y2 - y1),
                   (up[0] - x1) / (x2 - x1),
                   (up[1] - y1) / (y2 - y1))
    # Too far from the line: force the parameter to 0, i.e. "no hit".
    ua = jnp.where(d > BUBBLE_RADIUS, 0, ua)
    intersected = jnp.bitwise_and(ua > 0., ua < 1.)
    return intersected, up


@partial(jax.vmap, in_axes=(0, None, None, None))
def get_obs(agent: BubbleStatus,
            agents: BubbleStatus,
            items: BubbleStatus,
            walls: jnp.ndarray) -> Tuple[np.float32, jnp.ndarray]:
    """Compute the observation of one agent (vectorized over ``agent``).

    Args:
        agent: Status of the observing agent (mapped by ``jax.vmap``).
        agents: Status of all agents (not mapped).
        items: Status of all items, i.e. the predator(s) (not mapped).
        walls: Not used by this function.

    Returns:
        A pair ``(crowding, obs)``. ``crowding`` is the mean over the angular
        sectors of a 0/1 flag telling whether some neighbour lies within
        ``treshold_dist / 2`` in that sector. ``obs`` concatenates one signed
        heading difference per sector with the signed angle term computed for
        the item(s).
    """
    sensor_obs = []
    distance_obs = []
    # Sensing radius for neighbours; the predator is sensed at twice this.
    treshold_dist=100

    agent_xy = jnp.array([agent.pos_x, agent.pos_y]).ravel()

    agent_dir = jnp.array([agent.direction]).ravel()
    agents_dir = jnp.array([agents.direction]).ravel()
    # Distance to every agent; invalid agents (and the observer itself, at
    # distance 0) end up as 0 and are excluded by the `dist > 0.0` tests.
    dist = (jnp.sqrt(jnp.square(agent_xy[0] - agents.pos_x) +
                    jnp.square(agent_xy[1] - agents.pos_y))*agents.valid).ravel()
    dist_binary = jnp.where((dist > 0.0) & (dist < treshold_dist), 1, 0)
    dist_flocking = jnp.where((dist > 0.0) & (dist < treshold_dist/2), 1, 0)
    # Bearing of every agent relative to the observer's heading, wrapped back
    # into [-pi, pi].
    angles = -jnp.arctan2(agents.pos_y - agent_xy[1],agents.pos_x-agent_xy[0])
    angles = (angles - agent_dir).ravel()
    angles = jnp.where(angles > jnp.pi, angles - (jnp.pi * 2), angles)
    angles = jnp.where(angles < -jnp.pi, angles + (jnp.pi * 2), angles)
    # Predator: distance, in-range flag and absolute bearing.
    dist_pred = (jnp.sqrt(jnp.square(agent_xy[0] - items.pos_x) +
                    jnp.square(agent_xy[1] - items.pos_y))).ravel()
    dist_pred_binary = jnp.where((dist_pred>0.0)&(dist_pred<treshold_dist*2.0),1,0)
    pred_angle = -jnp.arctan2(items.pos_y - agent_xy[1], items.pos_x - agent_xy[0])

    # NOTE(review): `predator_obs_angle` is never read afterwards.
    predator_obs_angle = jnp.multiply(pred_angle, dist_pred_binary)

    # NUM_RANGE_SENSORS edges spanning [-3*pi/4, 3*pi/4], i.e.
    # NUM_RANGE_SENSORS - 1 angular sectors.
    step_theta = np.linspace(-jnp.pi*(3 / 4), jnp.pi*( 3/ 4), NUM_RANGE_SENSORS)

    for i in range(NUM_RANGE_SENSORS-1):

        # 1.0 for agents whose relative bearing falls inside sector i.
        v = jnp.where((angles > step_theta[i]) & (angles < step_theta[i + 1]), 1.0, 0.0)
        d = jnp.multiply(v, dist_binary)
        d_flock = jnp.multiply(v, dist_flocking)
        # Heading of a neighbour in this sector, or 0 if there is none.
        # `d` only holds 0/1, so argmin selects the first in-range neighbour
        # in array order, not the closest one.
        value = jnp.where(jnp.sum(d) == 0.0, 0, agents_dir[jnp.argmin(jnp.where(d == 0.0, 1000, d))])
        # 1.0 if any neighbour in this sector is within treshold_dist / 2.
        value_dist = jnp.where(jnp.sum(d_flock) == 0.0, 0, 1.0)
        # Difference between own heading and the neighbour's, in [0, 2*pi].
        clockwise_distance = (agent_dir+jnp.pi)-(value+jnp.pi)
        clockwise_distance = jnp.where(clockwise_distance > jnp.pi*2, clockwise_distance - (jnp.pi * 2), clockwise_distance)
        clockwise_distance = jnp.where(clockwise_distance < 0.0, clockwise_distance + (jnp.pi * 2), clockwise_distance)
        anticlockwise_distance = clockwise_distance-jnp.pi*2
        # Keep whichever of the two has the smaller magnitude.
        # NOTE(review): argmin is called without an axis, so `idx` indexes the
        # flattened array -- confirm this matches the intended shapes.
        distance_array= jnp.array([clockwise_distance,anticlockwise_distance])
        idx = jnp.argmin(jnp.abs(distance_array))
        # NOTE(review): a neighbour heading of exactly 0.0 is treated like
        # "no neighbour" here.
        angular_distance = jnp.where(value==0.0,0.0,distance_array[idx])
        sensor_obs.append(angular_distance)
        distance_obs.append(value_dist)
    # Same signed-difference computation for the predator bearing, which is
    # first shifted by 2*pi.
    pred_angle = (pred_angle+jnp.pi) +jnp.pi
    clockwise_distance_predator = (agent_dir+jnp.pi)-(pred_angle)
    clockwise_distance_predator = jnp.where(clockwise_distance_predator > jnp.pi * 2, clockwise_distance_predator - (jnp.pi * 2),clockwise_distance_predator)
    clockwise_distance_predator = jnp.where(clockwise_distance_predator < 0.0, clockwise_distance_predator + (jnp.pi * 2), clockwise_distance_predator)
    anticlockwise_distance_predator = clockwise_distance_predator - (jnp.pi*2)
    distance_array = jnp.array([clockwise_distance_predator, anticlockwise_distance_predator])
    idx = jnp.argmin(jnp.abs(distance_array))
    angular_distance = distance_array[idx]
    # Zeroed when the predator is out of sensing range.
    predator_angle = jnp.stack(angular_distance*dist_pred_binary)
    # NOTE(review): `predator_distance` is computed but not returned.
    predator_distance = jnp.where(angular_distance*dist_pred_binary==0.0,0.0,dist_pred/(treshold_dist*2.0))
    sensor_obs = jnp.stack(sensor_obs)
    distance_obs = jnp.stack(distance_obs)
    predator_distance = jnp.stack(predator_distance)

    return jnp.mean(distance_obs),jnp.concatenate([sensor_obs.ravel(),predator_angle.ravel()])

@jax.vmap
def select_direction(key: jnp.ndarray, action_prob: jnp.ndarray) -> jnp.int32:
    """Sample one index in ``range(4)`` according to ``action_prob``."""
    probabilities = action_prob.ravel()
    chosen = random.choice(key, 4, replace=False, p=probabilities)
    return chosen


class PredatorFlocking(VectorizedTask):
    """Multi-agent task in which agents move under the threat of a predator.

    The agents are the bubbles controlled by the policy. The items act as
    predators: at every step each one moves toward the nearest agent in its
    field of view and may eat agents that are close enough.
    """

    def __init__(self,
                 num_agents: int = 16,
                 num_items: int = 1,
                 max_steps: int = 1000,
                 test: bool = False):
        """Build the jitted, vectorized reset and step functions.

        Args:
            num_agents: Number of agents per environment.
            num_items: Number of predators per environment.
            max_steps: Number of steps after which ``done`` is raised.
            test: Stored on the instance; not used by this class itself.
        """
        self.multi_agent_training = True
        self.num_agents = num_agents

        self.max_steps = max_steps
        self.test = test
        # get_obs returns one value per angular sector plus one angle per
        # item. This equals the previously hard-coded NUM_RANGE_SENSORS for
        # the default of a single item.
        self.obs_shape = tuple([
            num_agents, (NUM_RANGE_SENSORS - 1) + num_items])
        self.act_shape = tuple([num_agents, 1])
        walls = jnp.array([[0, 0, 0, SCREEN_H],
                           [0, SCREEN_H, SCREEN_W, SCREEN_H],
                           [SCREEN_W, SCREEN_H, SCREEN_W, 0],
                           [SCREEN_W, 0, 0, 0]])

        def reset_fn(key):
            next_key, key = random.split(key)
            ks = random.split(key, num_agents + num_items)
            agents = create_bubbles(ks[:num_agents], True)
            items = create_bubbles(ks[num_agents:], False)

            _, obs = get_obs(agents, agents, items, walls)
            return State(agent_state=agents, item_state=items, obs=obs,
                         steps=jnp.zeros((), dtype=jnp.int32), key=next_key)

        self._reset_fn = jax.jit(jax.vmap(reset_fn))

        def step_fn(state, action):
            # Derive independent keys for the agent noise and the predator.
            # The previous code split the same key twice, which reuses
            # randomness (and yields identical keys whenever
            # num_items + 1 == num_agents).
            next_key, agent_key, item_key = random.split(state.key, 3)
            ks = random.split(agent_key, num_agents)
            ki = random.split(item_key, num_items + 1)

            agents = update_agent_state(state.agent_state, action, ks)

            steps = state.steps + 1
            done = jnp.where(steps >= max_steps, 1, 0)
            distances, obs = get_obs(agents, agents, state.item_state, walls)
            agents, items, rewards = get_rewards(
                state.item_state, agents, ki, distances)

            # get_rewards returns one copy of the agents per item. Positions
            # and headings are identical across copies, so take the first;
            # an agent stays valid only if no item has eaten it.
            agents = BubbleStatus(
                pos_x=agents.pos_x[0], pos_y=agents.pos_y[0],
                vel_x=agents.vel_x[0], vel_y=agents.vel_y[0],
                bubble_type=agents.bubble_type[0],
                direction=agents.direction[0],
                valid=jnp.prod(agents.valid, axis=0))
            return State(agent_state=agents, item_state=items, obs=obs,
                         steps=steps, key=next_key), rewards, done

        self._step_fn = jax.jit(jax.vmap(step_fn))

    def reset(self, key: jnp.array) -> State:
        """Reset all environments, one per PRNG key in ``key``."""
        return self._reset_fn(key)

    def step(self,
             state: State,
             action: jnp.ndarray) -> Tuple[State, jnp.ndarray, jnp.ndarray]:
        """Advance all environments by one step.

        Returns:
            ``(next_state, rewards, done)``.
        """
        return self._step_fn(state, action)

    @staticmethod
    def render(state: State, task_id: int = 0) -> Image.Image:
        """Draw environment ``task_id`` of a batched ``state`` as an image.

        Items are drawn as red discs of radius ``2 * EATING_RADIUS``; valid
        agents as green discs of radius ``2 * BUBBLE_RADIUS`` with a short
        line showing their heading.
        """
        img = Image.new(
            'RGB', (int(SCREEN_W * SCALE), int(SCREEN_H * SCALE)),
            (255, 255, 255))
        draw = ImageDraw.Draw(img)
        state = tree_util.tree_map(lambda s: s[task_id], state)

        # Draw the items.
        items = state.item_state
        item_r = EATING_RADIUS * 2
        for v, x, y in zip(np.array(items.valid, dtype=bool),
                           np.array(items.pos_x),
                           np.array(items.pos_y)):
            if v:
                draw.ellipse(
                    ((x - item_r) * SCALE, (y - item_r) * SCALE,
                     (x + item_r) * SCALE, (y + item_r) * SCALE),
                    fill=(255, 0, 0), outline=(0, 0, 0))

        # Draw the agents. Values are converted to NumPy first, as is done
        # for the items above.
        agents = state.agent_state
        agent_r = BUBBLE_RADIUS * 2
        for v, x, y, d in zip(np.array(agents.valid, dtype=bool),
                              np.array(agents.pos_x),
                              np.array(agents.pos_y),
                              np.array(agents.direction)):
            if v:
                draw.ellipse(
                    ((x - agent_r) * SCALE, (y - agent_r) * SCALE,
                     (x + agent_r) * SCALE, (y + agent_r) * SCALE),
                    fill=(0, 255, 0), outline=(0, 0, 0))

                # Heading indicator; y is subtracted as in the movement code.
                x_end = x + agent_r * np.cos(d)
                y_end = y - agent_r * np.sin(d)
                draw.line(
                    (x * SCALE, y * SCALE, x_end * SCALE, y_end * SCALE),
                    fill=(0, 0, 0), width=1)

        return img

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  y4: jnp.float32) -> Tuple[np.bool, jnp.ndarray]:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  y3: jnp.float32) -> Tuple[np.bool, jnp.ndarray]:


In [None]:
# Copyright 2022 The EvoJAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from functools import partial
from typing import Tuple
from typing import Union

import jax
import jax.numpy as jnp
from jax import random
from jax.tree_util import tree_map

from evojax.obs_norm import ObsNormalizer
from evojax.task.base import TaskState
from evojax.task.base import VectorizedTask
from evojax.policy.base import PolicyState
from evojax.policy.base import PolicyNetwork
from evojax.util import create_logger


@partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
def get_task_reset_keys(key: jnp.ndarray,
                        test: bool,
                        pop_size: int,
                        n_tests: int,
                        n_repeats: int,
                        ma_training: bool) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate the PRNG keys used to reset the tasks.

    In single-agent test mode ``n_tests * n_repeats`` distinct keys are
    produced. In every other case ``n_repeats`` keys are produced and tiled
    ``pop_size`` times, so all population members see the same resets.

    Returns:
        ``(next_key, reset_keys)``.
    """
    next_key, split_key = random.split(key=key)
    if test and not ma_training:
        return next_key, random.split(split_key, n_tests * n_repeats)
    shared_keys = random.split(split_key, n_repeats)
    return next_key, jnp.tile(shared_keys, (pop_size, 1))


@jax.jit
def split_params_for_pmap(param: jnp.ndarray) -> jnp.ndarray:
    """Split ``param`` along axis 0 into one equal chunk per local device
    and stack the chunks along a new leading axis."""
    n_devices = jax.local_device_count()
    per_device = jnp.split(param, n_devices)
    return jnp.stack(per_device)


@jax.jit
def split_states_for_pmap(
        state: Union[TaskState, PolicyState]) -> Union[TaskState, PolicyState]:
    """Apply ``split_params_for_pmap`` to every leaf of ``state``."""
    per_device_state = tree_map(split_params_for_pmap, state)
    return per_device_state


@jax.jit
def reshape_data_from_pmap(data: jnp.ndarray) -> jnp.ndarray:
    """Merge the device axis of pmap output into the job axis.

    ``data`` has shape (#device, steps, #jobs/device, *). The result has
    shape (steps, #device * #jobs/device, -1), with all trailing axes
    flattened into the last one.
    """
    axis_order = [1, 0, *range(2, data.ndim)]
    steps_first = data.transpose(axis_order)
    n_steps, n_devices, jobs_per_device = steps_first.shape[:3]
    return jnp.reshape(
        steps_first, (n_steps, n_devices * jobs_per_device, -1))


@partial(jax.jit, static_argnums=(1, 2))
def duplicate_params(params: jnp.ndarray,
                     repeats: int,
                     ma_training: bool) -> jnp.ndarray:
    """Repeat every row of ``params`` ``repeats`` times along axis 0.

    ``ma_training`` is accepted for interface compatibility; the result is
    the same for both of its values.
    """
    duplicated = jnp.repeat(params, repeats=repeats, axis=0)
    return duplicated


class MSimManager(object):
    """Simulation manager for a linear, reactive multi-agent policy.

    Rolls out batches of parameter vectors on a vectorized task.  Actions are
    not produced by ``policy_net``: each agent's action is a weighted sum of
    its own observation, with the weights read from the trailing columns of
    its parameter row (see ``step_once`` inside ``__init__``).
    """

    def __init__(self,
                 n_repeats: int,
                 test_n_repeats: int,
                 pop_size: int,
                 agents: int,
                 n_evaluations: int,
                 policy_net: PolicyNetwork,
                 train_vec_task: VectorizedTask,
                 valid_vec_task: VectorizedTask,
                 seed: int = 0,
                 obs_normalizer: ObsNormalizer = None,
                 logger: logging.Logger = None):
        """Initialization function.

        Args:
            n_repeats - Number of repeated parameter evaluations in training.
            test_n_repeats - Number of repeated rollouts when testing.
            pop_size - Population size.
            agents - Number of agents per task; every parameter row is copied
                     once per agent and per repeat before a rollout.
            n_evaluations - Number of evaluations of the best parameter
                            (raised to the local device count if smaller).
            policy_net - Policy network.  Only its ``reset`` is called here.
            train_vec_task - Vectorized tasks for training.
            valid_vec_task - Vectorized tasks for validation.
            seed - Random seed.
            obs_normalizer - Observation normalization helper.  A dummy
                             normalizer is created when omitted.
            logger - Logger.
        Raises:
            ValueError - If ``pop_size`` or ``n_evaluations`` is not a
                         multiple of the number of local devices.
        """

        if logger is None:
            self._logger = create_logger(name='SimManager')
        else:
            self._logger = logger

        self._key = random.PRNGKey(seed=seed)
        self._n_repeats = n_repeats
        self._test_n_repeats = test_n_repeats
        self._pop_size = pop_size
        self._agents = agents
        self._n_evaluations = max(n_evaluations, jax.local_device_count())
        self._ma_training = train_vec_task.multi_agent_training
        # Public attribute kept for callers that read it; nothing in this
        # class appends to it.
        self.rendering = []

        self.obs_normalizer = obs_normalizer
        if self.obs_normalizer is None:
            self.obs_normalizer = ObsNormalizer(
                obs_shape=train_vec_task.obs_shape,
                dummy=True,
            )
        self.obs_params = self.obs_normalizer.get_init_params()
        # Number of weights consumed per agent: one per observation feature
        # plus one.  step_once reads the last `num_obs` parameter columns.
        self.num_obs = train_vec_task.obs_shape[-1] + 1
        self._num_device = jax.local_device_count()
        if self._pop_size % self._num_device != 0:
            raise ValueError(
                'pop_size must be multiples of GPU/TPUs: '
                'pop_size={}, #devices={}'.format(
                    self._pop_size, self._num_device))
        if self._n_evaluations % self._num_device != 0:
            raise ValueError(
                'n_evaluations must be multiples of GPU/TPUs: '
                'n_evaluations={}, #devices={}'.format(
                    self._n_evaluations, self._num_device))

        def step_once(carry, input_data, task):
            """Advance every task by one step (body of the lax.scan loop)."""
            (task_state, policy_state, params, obs_params,
             accumulated_reward, valid_mask) = carry
            if task.multi_agent_training:
                num_tasks, num_agents = task_state.obs.shape[:2]
                # Flatten (tasks, agents, ...) to (tasks * agents, ...) so
                # that observation rows line up with parameter rows.
                task_state = task_state.replace(
                    obs=task_state.obs.reshape((-1, *task_state.obs.shape[2:])))

            # Observations are used raw: obs_params is carried along but no
            # normalization is applied before computing the action.
            org_obs = task_state.obs

            # Linear reactive controller.  All but the last observation
            # feature are weighted by the matching parameter columns and
            # summed; the last feature is weighted by the last parameter.
            action_f = jnp.sum(
                org_obs[:, :-1] * params[:, -self.num_obs:-1], axis=1)
            action_p = org_obs[:, -1] * params[:, -1]
            actions = jnp.expand_dims(action_f + action_p, axis=1)

            if task.multi_agent_training:
                # Restore the (tasks, agents, ...) layout expected by the task.
                task_state = task_state.replace(
                    obs=task_state.obs.reshape(
                        (num_tasks, num_agents, *task_state.obs.shape[1:])))
                actions = actions.reshape(
                    (num_tasks, num_agents, *actions.shape[1:]))

            task_state, reward, done = task.step(task_state, actions)

            if task.multi_agent_training:
                # Expand reward and done to one entry per parameter row.
                reward = reward.ravel()
                reward = jnp.repeat(reward, num_agents, axis=0)
                done = jnp.repeat(done, num_agents, axis=0)

            # Rewards stop accumulating once a rollout has reported done.
            accumulated_reward = accumulated_reward + reward * valid_mask
            valid_mask = valid_mask * (1 - done.ravel())

            return ((task_state, policy_state, params, obs_params,
                     accumulated_reward, valid_mask),
                    (org_obs, valid_mask))

        def rollout(task_states, policy_states, params, obs_params,
                    step_once_fn, max_steps):
            """Scan ``step_once_fn`` for ``max_steps`` steps.

            Returns the accumulated rewards (one per parameter row) together
            with the per-step observations and validity masks.
            """
            accumulated_rewards = jnp.zeros(params.shape[0])
            valid_masks = jnp.ones(params.shape[0])

            ((task_states, policy_states, params, obs_params,
              accumulated_rewards, valid_masks),
             (obs_set, obs_mask)) = jax.lax.scan(
                step_once_fn,
                (task_states, policy_states, params, obs_params,
                 accumulated_rewards, valid_masks), (), max_steps)

            return accumulated_rewards, obs_set, obs_mask

        self.policy = policy_net
        self.test_task = valid_vec_task
        self._policy_reset_fn = jax.jit(policy_net.reset)

        # Set up training functions.
        self._train_reset_fn = train_vec_task.reset
        self._train_rollout_fn = partial(
            rollout,
            step_once_fn=partial(step_once, task=train_vec_task),
            max_steps=train_vec_task.max_steps)
        if self._num_device > 1:
            self._train_rollout_fn = jax.jit(jax.pmap(
                self._train_rollout_fn, in_axes=(0, 0, 0, None)))

        # Set up validation functions.
        self._valid_reset_fn = valid_vec_task.reset
        self._valid_rollout_fn = partial(
            rollout,
            step_once_fn=partial(step_once, task=valid_vec_task),
            max_steps=valid_vec_task.max_steps)
        if self._num_device > 1:
            self._valid_rollout_fn = jax.jit(jax.pmap(
                self._valid_rollout_fn, in_axes=(0, 0, 0, None)))

    def eval_params(self, params: jnp.ndarray, test: bool) -> jnp.ndarray:
        """Evaluate population parameters or test the best parameter.

        Args:
            params - Parameters to be evaluated.  A (pop_size, n_params)
                     population in training, a single parameter vector in
                     tests.
            test - Whether we are testing the best parameter
        Returns:
            An array of fitness scores: one per individual in training, one
            per test rollout in multi-agent tests.
        """

        policy_reset_func = self._policy_reset_fn

        # Every parameter row is repeated agents * n_repeats times with the
        # copies kept adjacent.  With pop_size=2 and agents * n_repeats=3:
        #   a1, a2, ..., an  (individual 1 params)
        #   a1, a2, ..., an  (individual 1 params)
        #   a1, a2, ..., an  (individual 1 params)
        #   b1, b2, ..., bn  (individual 2 params)
        #   b1, b2, ..., bn  (individual 2 params)
        #   b1, b2, ..., bn  (individual 2 params)
        if test:
            n_repeats = self._test_n_repeats
            task_reset_func = self._valid_reset_fn
            rollout_func = self._valid_rollout_fn
            # The single candidate is shared by every agent of every test
            # rollout, so agents * n_repeats identical rows are all that is
            # needed (no pop_size-fold copy that is then sliced away).
            params = duplicate_params(
                params[None, :], self._agents * n_repeats, self._ma_training)
        else:
            n_repeats = self._n_repeats
            task_reset_func = self._train_reset_fn
            rollout_func = self._train_rollout_fn
            params = duplicate_params(
                params, self._agents * n_repeats, self._ma_training)

        self._key, reset_keys = get_task_reset_keys(
            self._key, test, self._pop_size, self._n_evaluations, n_repeats,
            self._ma_training)

        if test:
            # Only n_repeats test rollouts are run.
            reset_keys = reset_keys[:n_repeats, :]

        # Reset the tasks and the policy.
        task_state = task_reset_func(reset_keys)
        policy_state = policy_reset_func(task_state)

        if self._num_device > 1:
            params = split_params_for_pmap(params)
            task_state = split_states_for_pmap(task_state)
            policy_state = split_states_for_pmap(policy_state)

        # Do the rollouts.
        scores, all_obs, masks = rollout_func(
            task_state, policy_state, params, self.obs_params)

        if self._num_device > 1:
            all_obs = reshape_data_from_pmap(all_obs)
            masks = reshape_data_from_pmap(masks)

        if not test and not self.obs_normalizer.is_dummy:
            self.obs_params = self.obs_normalizer.update_normalization_params(
                obs_buffer=all_obs, obs_mask=masks, obs_params=self.obs_params)

        if self._ma_training:
            if not test:
                # Average over the agents of each rollout, then over the
                # repeats of each individual.
                tmp = scores.ravel().reshape((self._pop_size * n_repeats, -1))
                tmp = jnp.mean(tmp, axis=1).reshape(self._pop_size, n_repeats)
                return jnp.mean(tmp, axis=1)
            else:
                # In tests, all agents share the same parameters: report one
                # mean score per test rollout.
                tmp = scores.ravel().reshape((n_repeats, self._agents))
                mean_scores = jnp.mean(tmp, axis=1)
                print(mean_scores)
                # The output directory is hard-coded; create it first so a
                # missing folder does not abort the evaluation.
                save_dir = '/content/gdrive/My Drive/log/water_world_ma'
                os.makedirs(save_dir, exist_ok=True)
                np.save(os.path.join(save_dir, 'base_scores.npy'), mean_scores)
                return mean_scores
        else:
            return jnp.mean(scores.ravel().reshape((-1, n_repeats)), axis=-1)

In [None]:
# Copyright 2022 The EvoJAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
import numpy as np

from evojax.task.base import VectorizedTask
from evojax.policy import PolicyNetwork
from evojax.algo import NEAlgorithm
from evojax.sim_mgr import SimManager
from evojax.obs_norm import ObsNormalizer
from evojax.util import create_logger
from evojax.util import load_model
from evojax.util import save_model


class MTrainer(object):
    """A trainer that organizes the training logistics."""

    def __init__(self,
                 policy: PolicyNetwork,
                 solver: NEAlgorithm,
                 train_task: VectorizedTask,
                 test_task: VectorizedTask,
                 max_iter: int = 1000,
                 log_interval: int = 20,
                 test_interval: int = 100,
                 n_repeats: int = 1,
                 test_n_repeats: int = 1,
                 n_evaluations: int = 100,
                 num_agents: int = 64,
                 seed: int = 42,
                 debug: bool = False,
                 normalize_obs: bool = False,
                 model_dir: str = None,
                 log_dir: str = None,
                 logger: logging.Logger = None):
        """Initialization.

        Args:
            policy - The policy network to use.
            solver - The ES algorithm for optimization.
            train_task - The task for training.
            test_task - The task for evaluation.
            max_iter - Maximum number of training iterations.
            log_interval - Interval for logging.
            test_interval - Interval for tests.
            n_repeats - Number of rollout repetitions.
            test_n_repeats - Number of rollout repetitions in tests.
            n_evaluations - Number of tests to conduct.
            num_agents - Number of agents per task, forwarded to the
                         simulation manager.
            seed - Random seed to use.
            debug - Whether to turn on the debug flag.
            normalize_obs - Whether to use an observation normalizer.
            model_dir - Directory to save/load model.
            log_dir - Directory to dump logs.  The final model and the
                      fitness history are also written here after training.
            logger - Logger.
        """

        if logger is None:
            self._logger = create_logger(
                name='Trainer', log_dir=log_dir, debug=debug)
        else:
            self._logger = logger

        self._log_interval = log_interval
        self._test_interval = test_interval
        self._max_iter = max_iter
        self.model_dir = model_dir
        self._log_dir = log_dir
        self.seed = seed
        self._obs_normalizer = ObsNormalizer(
            obs_shape=train_task.obs_shape,
            dummy=not normalize_obs,
        )

        self.solver = solver
        self.task = test_task
        self.policy = policy
        self.sim_mgr = MSimManager(
            n_repeats=n_repeats,
            test_n_repeats=test_n_repeats,
            pop_size=solver.pop_size,
            agents=num_agents,
            n_evaluations=n_evaluations,
            policy_net=policy,
            train_vec_task=train_task,
            valid_vec_task=test_task,
            seed=seed,
            obs_normalizer=self._obs_normalizer,
            logger=self._logger,
        )

    def run(self, demo_mode: bool = False) -> float:
        """Start the training / test process.

        Args:
            demo_mode - If True, only evaluate the parameters loaded from
                        ``model_dir`` and return their mean test score.
        Returns:
            The mean test score in demo mode, otherwise the best score seen
            during training and the final test.
        Raises:
            ValueError - In demo mode when no parameters could be loaded.
        """

        if self.model_dir is not None:
            params, obs_params = load_model(model_dir=self.model_dir)

            self.sim_mgr.obs_params = obs_params
            self._logger.info(
                'Loaded model parameters from {}.'.format(self.model_dir))
        else:
            params = None

        if demo_mode:
            if params is None:
                raise ValueError('No policy parameters to evaluate.')
            self._logger.info('Start to test the parameters.')
            scores = np.array(
                self.sim_mgr.eval_params(params=params, test=True))
            self._logger.info(
                '[TEST] #tests={0}, max={1:.4f}, avg={2:.4f}, min={3:.4f}, '
                'std={4:.4f}'.format(scores.size, scores.max(), scores.mean(),
                                     scores.min(), scores.std()))
            return scores.mean()
        else:
            self._logger.info(
                'Start to train for {} iterations.'.format(self._max_iter))

            # Stays None until a candidate is recorded, so the final step can
            # tell "nothing recorded" apart from a real parameter vector.
            best_params = params
            if params is not None:
                # Continue training from the breakpoint.
                self.solver.best_params = params

            best_score = -float('Inf')
            fitness_ = []
            for i in range(self._max_iter):

                start_time = time.perf_counter()
                params = self.solver.ask()

                self._logger.debug('solver.ask time: {0:.4f}s'.format(
                    time.perf_counter() - start_time))

                start_time = time.perf_counter()
                scores = self.sim_mgr.eval_params(params=params, test=False)

                self._logger.debug('sim_mgr.eval_params time: {0:.4f}s'.format(
                    time.perf_counter() - start_time))

                start_time = time.perf_counter()
                self.solver.tell(fitness=scores)
                self._logger.debug('solver.tell time: {0:.4f}s'.format(
                    time.perf_counter() - start_time))

                # Index of the best individual of the CURRENT population.
                # Reset every iteration so a test never reuses an index that
                # was computed for an older population.
                best_ind = None

                if i > 0 and i % self._log_interval == 0:
                    scores = np.array(scores)

                    best_ind = np.argmax(scores)

                    if scores[best_ind] > best_score:
                        best_params = params[best_ind]
                        best_score = scores[best_ind]
                        print("best_params",best_params)

                    print("iter ",i,"size",scores.size,"max",scores.max())
                    fitness_.append(scores.max())
                    self._logger.info(
                        'Iter={0}, size={1}, max={2:.4f}, '
                        'avg={3:.4f}, min={4:.4f}, std={5:.4f}'.format(
                            i, scores.size, scores.max(), scores.mean(),
                            scores.min(), scores.std()))

                if i > 0 and i % self._test_interval == 0:
                    if best_ind is None:
                        # The logging branch did not run this iteration.
                        best_ind = np.argmax(np.array(scores))
                    test_params = params[best_ind]

                    test_scores = self.sim_mgr.eval_params(
                        params=test_params, test=True)
                    print(
                        '[TEST] Iter={0}, #tests={1}, max={2:.4f} avg={3:.4f}, '
                        'min={4:.4f}, std={5:.4f}'.format(
                            i, test_scores.size, test_scores.max(),
                            test_scores.mean(), test_scores.min(),
                            test_scores.std()))
                    mean_test_score = test_scores.mean()
                    if mean_test_score > best_score:
                        best_params = test_params
                        best_score = mean_test_score

            # Test and save the final model.
            if best_params is None:
                # No log/test interval ever fired (e.g. very small max_iter):
                # fall back to the solver's own current best estimate.
                best_params = self.solver.best_params
            self.solver.best_params = best_params
            self.best_params = self.solver.best_params

            test_scores = self.sim_mgr.eval_params(
                params=self.best_params, test=True)
            print(
                '[TEST] Iter={0}, #tests={1}, max={2:.4f}, avg={3:.4f}, '
                'min={4:.4f}, std={5:.4f}'.format(
                    self._max_iter, test_scores.size, test_scores.max(),
                    test_scores.mean(), test_scores.min(), test_scores.std()))
            mean_test_score = test_scores.mean()
            save_model(
                model_dir=self._log_dir,
                model_name='final'+str(self.seed),
                params=self.best_params,
                obs_params=self.sim_mgr.obs_params,
                best=mean_test_score > best_score,
            )
            # Training-score history: one entry per logged iteration.
            np.save(self._log_dir+"/fitness"+str(self.seed)+".npy",fitness_)
            best_score = max(best_score, mean_test_score)
            self._logger.info(
                'Training done, best_score={0:.4f}'.format(best_score))

            return best_score


In [None]:
# Logs and models for this notebook live under ./log; the logger is pointed
# at that directory.
log_dir = './log'
logger = create_logger(log_dir=log_dir, name='EvoJAX')
logger.info('Welcome to the tutorial on Neuroevolution algorithm creation!')

# Record which devices JAX can see.
logger.info(f'Jax backend: {jax.local_devices()}')

!nvidia-smi --query-gpu=name --format=csv,noheader

Tesla T4


## Introduction

EvoJAX has three major components: the *task*, the *policy network* and the *neuroevolution algorithm*. Once these components are implemented and instantiated, we can use a trainer to start the training process. The following code snippet provides an example of how we use EvoJAX.

In [None]:
import argparse
import os
import shutil
import jax
import jax.numpy as jnp

#from evojax.task.ma_waterworld import MultiAgentWaterWorld
from evojax.policy.mlp import MLPPolicy
from evojax.algo import PGPE
from evojax import Trainer
from evojax import util


def parse_args():
    """Parse the training options, ignoring arguments that are not ours.

    Unknown arguments are tolerated (``parse_known_args``) so the function
    also works when the host process passes its own command-line flags.
    """
    # (flag, type, default, help) for every option that takes a value and
    # has a default.
    valued_options = (
        ('--hidden-size', int, 16, 'Policy hidden size.'),
        ('--num-tests', int, 50, 'Number of test rollouts.'),
        ('--n-repeats', int, 10, 'Training repetitions.'),
        ('--max-iter', int, 50, 'Max training iterations.'),
        ('--test-interval', int, 20, 'Test interval.'),
        ('--log-interval', int, 1, 'Logging interval.'),
        ('--seed', int, 2, 'Random seed for training.'),
        ('--center-lr', float, 0.011, 'Center learning rate.'),
        ('--std-lr', float, 0.054, 'Std learning rate.'),
        ('--init-std', float, 0.095, 'Initial std.'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default, help_text in valued_options:
        parser.add_argument(
            flag, type=value_type, default=default, help=help_text)
    parser.add_argument('--gpu-id', type=str, help='GPU(s) to use.')
    parser.add_argument('--debug', action='store_true', help='Debug mode.')

    known_args, _unknown = parser.parse_known_args()
    return known_args



def main(config):
    """Train the linear reactive policy for several seeds, then render GIFs.

    For seeds 2..10 a fresh train/test task pair, PGPE solver and MTrainer
    are built and trained.  Afterwards the best parameters of the LAST
    trainer are rolled out on the last test task for ten different reset
    keys, and each rollout is saved as an animated GIF in the log directory.

    Args:
        config - Parsed options from ``parse_args``.  ``config.seed`` is
                 overwritten for every training run.
    """
    log_dir = '/content/gdrive/My Drive/log/reactive/evolved'
    os.makedirs(log_dir, exist_ok=True)
    logger = util.create_logger(
        name='MultiAgentWaterWorld', log_dir=log_dir, debug=config.debug)

    logger.info('EvoJAX MultiAgentWaterWorld')
    logger.info('=' * 30)
    population = 40
    num_agents = 512
    max_steps = 400
    # Length of the evolved weight vector.
    # NOTE(review): presumably observation size + 1 — confirm against
    # PredatorFlocking.obs_shape.
    num_obs = 8

    for ii in range(1, 10):
        config.seed = ii + 1
        train_task = PredatorFlocking(
            num_agents=num_agents, test=False, max_steps=max_steps)
        test_task = PredatorFlocking(
            num_agents=num_agents, test=True, max_steps=max_steps)
        # The trainer requires a policy object; the actions themselves are
        # computed from the evolved weight vector, not from this MLP.
        policy = MLPPolicy(
            input_dim=train_task.obs_shape[-1],
            hidden_dims=[config.hidden_size,],
            output_dim=train_task.act_shape[-1],
            output_act_fn='tanh',
        )
        # Start from uniform weights: the first num_obs - 1 entries sum to
        # 0.5 and the last entry is 0.5.
        init_vector = np.hstack(
            (np.ones(num_obs-1)*((1/(num_obs-1))*0.5), np.ones(1)*0.5))
        solver = PGPE(
            pop_size=population,
            param_size=num_obs,
            init_params=init_vector,
            optimizer='adam',
            center_learning_rate=config.center_lr,
            stdev_learning_rate=config.std_lr,
            init_stdev=config.init_std,
            logger=logger,
            seed=config.seed,
        )

        trainer = MTrainer(
            policy=policy,
            solver=solver,
            train_task=train_task,
            test_task=test_task,
            max_iter=config.max_iter,
            log_interval=config.log_interval,
            test_interval=config.test_interval,
            n_evaluations=config.num_tests,
            num_agents=num_agents,
            n_repeats=config.n_repeats,
            test_n_repeats=config.num_tests,
            seed=config.seed,
            log_dir=log_dir,
            logger=logger,
        )

        trainer.run()

    # Visualize the policy of the last training run.
    task_reset_fn = jax.jit(test_task.reset)
    step_fn = jax.jit(test_task.step)
    # One copy of the best weight vector per agent.
    best_params = jnp.tile(
        trainer.best_params, num_agents).reshape(num_agents, num_obs)

    for j in range(10):
        # A single reset key -> a single task in the batch.
        key = jax.random.PRNGKey(j)[None, :]
        task_state = task_reset_fn(key)

        screens = []
        rew = 0
        for step in range(200):
            org_obs = task_state.obs
            num_tasks, n_agents = org_obs.shape[:2]

            # Only task 0 exists here (one reset key).
            action_f = jnp.sum(
                org_obs[0, :, :-1] * best_params[:, :-1], axis=1)
            action_p = org_obs[0, :, -1] * best_params[:, -1]
            # NOTE(review): training (MSimManager.step_once) uses
            # action_f + action_p; here action_f is additionally scaled by
            # the last weight.  Left unchanged — confirm which is intended.
            actions = jnp.expand_dims(
                (action_f * best_params[:, -1] + action_p), axis=1)
            # Keep whatever action width was computed instead of assuming it
            # equals the number of tasks.
            actions = actions.reshape(num_tasks, n_agents, -1)

            task_state, reward, done = step_fn(task_state, actions)
            print(step, rew)
            rew += reward
            screens.append(test_task.render(task_state))
        print(rew)

        gif_file = os.path.join(log_dir, 'DangerZone_100_m' + str(j) + '.gif')
        # The animation starts at frame 0 and appends the remaining frames;
        # indexing with j here would start from the wrong frame and repeat it.
        screens[0].save(
            gif_file, save_all=True, append_images=screens[1:],
            duration=40, loop=0)

    
'''
if task.multi_agent_training:
                num_tasks, num_agents = task_state.obs.shape[:2]
                # print(num_tasks,num_agents)
                # input("www")
                task_state = task_state.replace(
                    obs=task_state.obs.reshape((-1, *task_state.obs.shape[2:])))

            org_obs = task_state.obs

            #normed_obs = self.obs_normalizer.normalize_obs(org_obs, obs_params)
            #task_state = task_state.replace(obs=normed_obs)

            actions, policy_state = policy_net.get_actions(
                task_state, params, policy_state)


            if task.multi_agent_training:
                task_state = task_state.replace(
                    obs=task_state.obs.reshape(
                        (num_tasks, num_agents, *task_state.obs.shape[1:])))
                actions = actions.reshape(
                    (num_tasks, num_agents, *actions.shape[1:]))
'''



if __name__ == '__main__':
    # Parse the options first: --gpu-id, when given, is exported as
    # CUDA_VISIBLE_DEVICES before training is launched.
    configs = parse_args()
    if configs.gpu_id is not None:
        os.environ.update(CUDA_VISIBLE_DEVICES=configs.gpu_id)
    main(configs)

In [None]:
import argparse
import os
import shutil
import jax
import jax.numpy as jnp

#from evojax.task.ma_waterworld import MultiAgentWaterWorld
from evojax.policy.mlp import MLPPolicy
from evojax.algo import PGPE
from evojax import Trainer
from evojax import util
from evojax.util import load_model

def parse_args():
    """Parse the options for re-evaluating a saved model.

    Unrecognized arguments are ignored (``parse_known_args``).  The defaults
    describe a very short run with an almost-zero initial standard deviation.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Network / rollout sizes.
    add('--hidden-size', type=int, default=16, help='Policy hidden size.')
    add('--num-tests', type=int, default=50, help='Number of test rollouts.')
    add('--n-repeats', type=int, default=2, help='Training repetitions.')

    # Schedule.
    add('--max-iter', type=int, default=2, help='Max training iterations.')
    add('--test-interval', type=int, default=1, help='Test interval.')
    add('--log-interval', type=int, default=1, help='Logging interval.')
    add('--seed', type=int, default=421, help='Random seed for training.')

    # Solver hyper-parameters.
    add('--center-lr', type=float, default=0.011,
        help='Center learning rate.')
    add('--std-lr', type=float, default=0.054, help='Std learning rate.')
    add('--init-std', type=float, default=1e-14, help='Initial std.')

    # Runtime switches.
    add('--gpu-id', type=str, help='GPU(s) to use.')
    add('--debug', action='store_true', help='Debug mode.')

    parsed, _ignored = parser.parse_known_args()
    return parsed



def main(config):
    """Replay saved reactive-policy weights and write one GIF per episode.

    A parameter vector is loaded from ``log_dir`` with ``load_model`` and is
    applied directly as per-observation weights: every agent's action is the
    sum of its observation entries multiplied element-wise by the weights.
    The same weights are shared by all agents.

    No training is performed. The policy, solver and trainer are still
    constructed exactly as for a training run, but ``trainer.run()`` is not
    called; to re-train, call it and replay ``trainer.best_params`` instead
    of the loaded vector.

    Side effects: prints progress to stdout, creates ``log_dir`` if needed
    and writes ``DangerZone_100_m<j>.gif`` there for each episode ``j``.

    Args:
        config: Namespace produced by ``parse_args``. The attributes read
            here are ``hidden_size``, ``debug``, ``center_lr``, ``std_lr``,
            ``init_std``, ``seed``, ``max_iter``, ``log_interval``,
            ``test_interval``, ``num_tests`` and ``n_repeats``.
    """
    log_dir = '/content/gdrive/My Drive/log/reactive/evolved'
    print(log_dir)
    # NOTE(review): load_model is called before the directory is created
    # below, so it presumably needs an existing saved model -- confirm.
    init_vector, _ = load_model(model_dir=log_dir+'/')
    print(init_vector)

    os.makedirs(log_dir, exist_ok=True)
    logger = util.create_logger(
        name='MultiAgentWaterWorld', log_dir=log_dir, debug=config.debug)

    logger.info('EvoJAX MultiAgentWaterWorld')
    logger.info('=' * 30)
    population = 40
    num_agents = 1024
    max_steps = 400
    num_obs = 8          # length of the per-agent weight vector
    num_episodes = 10    # one GIF is written per episode
    rollout_steps = 150  # environment steps rendered per episode

    train_task = PredatorFlocking(
        num_agents=num_agents, test=False, max_steps=max_steps)
    test_task = PredatorFlocking(
        num_agents=num_agents, test=True, max_steps=max_steps)
    policy = MLPPolicy(
        input_dim=train_task.obs_shape[-1],
        hidden_dims=[config.hidden_size],
        output_dim=train_task.act_shape[-1],
        output_act_fn='tanh',
    )
    solver = PGPE(
        pop_size=population,
        param_size=num_obs,
        init_params=init_vector,
        optimizer='adam',
        center_learning_rate=config.center_lr,
        stdev_learning_rate=config.std_lr,
        init_stdev=config.init_std,
        logger=logger,
        seed=config.seed,
    )

    # Built for parity with the training setup; it is not run here.
    trainer = MTrainer(
        policy=policy,
        solver=solver,
        train_task=train_task,
        test_task=test_task,
        max_iter=config.max_iter,
        log_interval=config.log_interval,
        test_interval=config.test_interval,
        n_evaluations=config.num_tests,
        num_agents=num_agents,
        n_repeats=config.n_repeats,
        test_n_repeats=config.num_tests,
        seed=config.seed,
        log_dir=log_dir,
        model_dir=log_dir,
        logger=logger,
    )

    # Give every agent its own copy of the weights. This is done once, outside
    # the episode loop: tiling the already tiled matrix again would no longer
    # fit the (num_agents, num_obs) reshape.
    agent_params = jnp.tile(
        init_vector, num_agents).reshape(num_agents, num_obs)

    # The jitted environment functions do not depend on the episode, so they
    # are created a single time.
    task_reset_fn = jax.jit(test_task.reset)
    step_fn = jax.jit(test_task.step)

    for j in range(num_episodes):
        # Leading axis of size one: a single task instance per episode.
        key = jax.random.PRNGKey(j)[None, :]
        task_state = task_reset_fn(key)

        screens = []
        rew = 0
        for step in range(rollout_steps):
            obs = task_state.obs
            n_tasks, n_agents = obs.shape[:2]

            # Linear reactive policy on the first task's observations: all
            # but the last observation entry are weighted and summed, then
            # the weighted last entry is added.
            action_f = jnp.sum(
                obs[0, :, :-1] * agent_params[:, -num_obs:-1], axis=1)
            action_p = obs[0, :, -1] * agent_params[:, -1]
            # One scalar action per agent; the element count only matches
            # when there is exactly one task.
            actions = (action_f + action_p).reshape(n_tasks, n_agents, 1)

            task_state, reward, done = step_fn(task_state, actions)
            print(step, rew)
            rew += reward
            screens.append(test_task.render(task_state))
        print(rew)

        gif_file = os.path.join(log_dir, 'DangerZone_100_m' + str(j) + '.gif')
        # The animation starts at the first rendered frame and the remaining
        # frames are appended in order.
        # NOTE(review): assumes render() returns PIL images -- confirm.
        screens[0].save(
            gif_file, save_all=True, append_images=screens[1:],
            duration=40, loop=0)

    
'''
if task.multi_agent_training:
                num_tasks, num_agents = task_state.obs.shape[:2]
                # print(num_tasks,num_agents)
                # input("www")
                task_state = task_state.replace(
                    obs=task_state.obs.reshape((-1, *task_state.obs.shape[2:])))

            org_obs = task_state.obs

            #normed_obs = self.obs_normalizer.normalize_obs(org_obs, obs_params)
            #task_state = task_state.replace(obs=normed_obs)

            actions, policy_state = policy_net.get_actions(
                task_state, params, policy_state)


            if task.multi_agent_training:
                task_state = task_state.replace(
                    obs=task_state.obs.reshape(
                        (num_tasks, num_agents, *task_state.obs.shape[1:])))
                actions = actions.reshape(
                    (num_tasks, num_agents, *actions.shape[1:]))
'''



if __name__ == '__main__':
    # Script entry point: read the command-line options, optionally export
    # the requested GPU id(s) through CUDA_VISIBLE_DEVICES, then run main().
    configs = parse_args()
    if configs.gpu_id is not None:
        os.environ.update(CUDA_VISIBLE_DEVICES=configs.gpu_id)
    main(configs)