In [8]:
import gym
from gym.spaces import Box
import numpy as np
import matlab
import matlab.engine
import ray

@ray.remote
class DistributedTSCSEnv(gym.Env):
    def __init__(self, config):
        ## Initialize matlab
        self.eng = matlab.engine.start_matlab()
        self.eng.addpath('../DDPG/TSCS')

        ## Env hyperparameters
        self.nCyl = config['nCyl']
        self.nFreq = config['nFreq']
        self.M = matlab.double([self.nCyl])
        self.k0amax = matlab.double([config['k0amax']])
        self.k0amin = matlab.double([config['k0amin']])
        self.F = matlab.double([self.nFreq])
        self.actionRange = config['actionRange']
        self.episodeLength = config['episodeLength']

        ## State variables
        self.config = None
        self.TSCS = None
        self.RMS = None
        self.timestep = None
        self.lowest = None

        ## Observation and action space
        self.observation_space = Box(
            low=-100.,
            high=100.,
            ## Number of cylinders + number of wavenumbers + 2 additional variables (rms, timestep)
            shape=((1, self.nCyl * 2 + self.nFreq + 2)))

        self.action_space = Box(
            low=-self.actionRange,
            high=self.actionRange,
            shape=(int(self.nCyl * 2),))

    def validConfig(self, config):
        """
        Checks if config is within bounds and does not overlap cylinders
        """
        withinBounds = False
        overlap = False
        if (-5 < config).all() and (config < 5).all():
            withinBounds = True

            coords = config.reshape(self.nCyl, 2)
            for i in range(self.nCyl):
                for j in range(self.nCyl):
                    if i != j:
                        x1, y1 = coords[i]
                        x2, y2 = coords[j]
                        d = np.sqrt((x2-x1)**2 + (y2-y1)**2)
                        if d <= 2.1:
                            overlap = True
        return withinBounds and not overlap

    def getConfig(self):
        """
        Generates a configuration which is within bounds 
        and not overlaping cylinders
        """
        while True:
            config = np.random.uniform(-5., 5., (1, self.nCyl * 2))
            if self.validConfig(config):
                break
        return config

    def getMetric(self, config):
        """
        This calculates total cross secitonal scattering across nFreq wavenumbers
        from k0amax to k0amin. Also calculates RMS of these wavenumbers.
        """
        # x = self.eng.transpose(matlab.double(*self.config.reshape(self.nCyl * 2)))
        x = self.eng.transpose(matlab.double(*config.tolist()))
        tscs = self.eng.getMetric(x, self.M, self.k0amax, self.k0amin, self.F)
        tscs = np.array(tscs).T
        rms = np.sqrt(np.power(tscs, 2).mean()).reshape(1, 1)
        return tscs, rms

    def reset(self):
        """
        Generates starting config and calculates its tscs
        """
        self.config = self.getConfig()
        self.TSCS, self.RMS = self.getMetric(self.config)
        self.timestep = np.array([[0.0]])
        self.lowest = np.asscalar(self.RMS)
        state = np.concatenate((self.config, self.TSCS, self.RMS, self.timestep), axis=-1)
        return state

    def getReward(self, RMS, isValid):
        """
        Computes reward based on change in scattering 
        proporitional to how close it is to zero
        """
        reward = -np.sqrt(RMS).item()
        if not isValid:
            reward += -1.0
        return reward

    def step(self, action):
        """
        If the config after applying the action is not valid
        we revert back to previous state and give negative reward
        otherwise, reward is calculated by a function on the next scattering.
        """
        nextConfig = self.config.copy() + action

        valid = False
        if self.validConfig(nextConfig):
            self.config = nextConfig
            valid = True

        self.TSCS, self.RMS = self.getMetric(self.config)
        reward = self.getReward(self.RMS, valid)
        self.timestep += 1/self.episodeLength

        if self.RMS < self.lowest:
            self.lowest = self.RMS.item()

        done = False
        if int(self.timestep) == 1:
            done = True

        info = {
            'meanTSCS':self.TSCS.mean(),
            'rms':self.RMS,
            'lowest':self.lowest}

        state = np.concatenate((self.config, self.TSCS, self.RMS, self.timestep), axis=-1)
        return state, reward, done, info
    
    def rollout_episode(self):
        data = {
            'states': np.ndarray(shape=(self.episodeLength, self.observation_space.shape[1])),
            'actions': np.ndarray(shape=(self.episodeLength, self.action_space.shape[0])),
            'rewards': np.ndarray(shape=(self.episodeLength, 1)),
            'next_states': np.ndarray(shape=(self.episodeLength, self.observation_space.shape[1])),
            'dones': np.ndarray(shape=(self.episodeLength, 1))}

        state = self.reset()
        done = False
        idx = 0
        while not done:
            action = self.action_space.sample()
            next_state, reward, done, info = self.step(action)

            data['states'][idx] = state
            data['actions'][idx] = action
            data['rewards'][idx] = reward
            data['next_states'][idx] = next_state
            data['dones'][idx] = done
            idx += 1
        return data

In [23]:
# from env import DistributedTSCSEnv
import ray
import torch
import numpy as np
import time

In [10]:
ray.init()

2020-10-31 00:25:24,367	INFO services.py:1164 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


{'node_ip_address': '10.0.0.12',
 'raylet_ip_address': '10.0.0.12',
 'redis_address': '10.0.0.12:6379',
 'object_store_address': '/tmp/ray/session_2020-10-31_00-25-23_139231_3070/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2020-10-31_00-25-23_139231_3070/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2020-10-31_00-25-23_139231_3070',
 'metrics_export_port': 64357}

In [70]:
config = {
    'nCyl': 4,
    'k0amax': 0.45,
    'k0amin': 0.35,
    'nFreq': 11,
    'actionRange': 0.2,
    'episodeLength': 100}

envs = [DistributedTSCSEnv.remote(config) for i in range(5)]

In [73]:
futures = [env.rollout_episode.remote() for env in envs]

In [74]:
start = time.time()
data = ray.get(futures)
time.time() - start

3.653256416320801

In [76]:
for i in range(10):
    futures = [env.rollout_episode.remote() for env in envs]
    start = time.time()
    data = ray.get(futures)
    print(time.time() - start)

3.8919758796691895
3.911433458328247
3.9060847759246826
4.022394180297852
4.050745487213135
3.936521053314209
3.9807753562927246
4.032641649246216
3.955822467803955
3.963179588317871


In [77]:
env = DistributedTSCSEnv.remote(config)

In [80]:
futures = env.rollout_episode.remote()

In [81]:
start = time.time()
future = ray.get(futures)
time.time() - start

3.4818291664123535

In [83]:
import time

In [7]:
start_time = time.time()
data = []
for _ in range(5):
    data.append(rollout_episode.remote(env))
time.time() - start_time

ray.get(data)

RayTaskError(AttributeError): [36mray::rollout_episode()[39m (pid=23772, ip=10.0.0.12)
  File "python/ray/_raylet.pyx", line 484, in ray._raylet.execute_task
  File "<ipython-input-5-10e25df8a31e>", line 4, in rollout_episode
  File "/home/tristan/anaconda3/envs/TSCSProject/lib/python3.8/site-packages/ray/actor.py", line 804, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has "
AttributeError: 'ActorHandle' object has no attribute 'episodeLength'

In [84]:
ray.shutdown()