Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
1 contributor

Users who have contributed to this file

226 lines (197 sloc) 6.59 KB
extern crate serde_json;
extern crate hyper;
extern crate rand;
use std::collections::BTreeMap;
use std::io::Read;
use serde_json::Value;
use serde_json::value::{ToJson, from_value};
use hyper::client::Client;
use hyper::header::Headers;
use rand::{thread_rng, Rng};
pub type GymResult<T> = Result<T, hyper::Error>;
/// An action or observation space as described by the gym HTTP server.
#[derive(Debug, Clone)]
pub enum Space {
    /// `n` discrete choices; `Space::sample` draws an integer below `n`.
    DISCRETE {
        n: u64,
    },
    /// A box of real values with per-element bounds.
    ///
    /// NOTE(review): `high`/`low` are presumably stored flattened with one
    /// entry per element of `shape` — confirm against the server.
    BOX {
        shape: Vec<u64>,
        high: Vec<f64>,
        low: Vec<f64>,
    },
    /// A composite of other spaces, sampled by concatenation.
    TUPLE {
        spaces: Vec<Box<Space>>,
    },
}
impl Space {
fn from_json(info: &Value) -> Space {
match info.find("name").unwrap().as_str().unwrap() {
"Discrete" => {
let n = info.find("n").unwrap().as_u64().unwrap();
Space::DISCRETE{n: n}
},
"Box" => {
let shape = info.find("shape").unwrap().as_array().unwrap()
.into_iter().map(|x| x.as_u64().unwrap())
.collect::<Vec<_>>();
let high = info.find("high").unwrap().as_array().unwrap()
.into_iter().map(|x| x.as_f64().unwrap())
.collect::<Vec<_>>();
let low = info.find("low").unwrap().as_array().unwrap()
.into_iter().map(|x| x.as_f64().unwrap())
.collect::<Vec<_>>();
Space::BOX{shape: shape, high: high, low: low}
},
"Tuple" => panic!("Parsing for Tuple spaces is not yet implemented"),
e @ _ => panic!("Unrecognized space name: {}", e)
}
}
pub fn sample(&self) -> Vec<f64> {
let mut rng = thread_rng();
match *self {
Space::DISCRETE{n} => {
vec![(rng.gen::<u64>()%n) as f64]
},
Space::BOX{ref shape, ref high, ref low} => {
let mut ret = Vec::with_capacity(shape.iter().map(|x| *x as usize).product());
let mut index = 0;
for &i in shape {
for _ in 0..i {
ret.push(rng.gen_range(low[index], high[index]));
index += 1;
}
}
ret
},
Space::TUPLE{ref spaces} => {
let mut ret = Vec::new();
for space in spaces {
ret.extend(space.sample());
}
ret
}
}
}
}
/// Result of a single `Environment::step` call, decoded from the server's JSON reply.
#[derive(Debug, Clone, PartialEq)]
pub struct State {
    /// Observation after the action, flattened to a list of numbers.
    pub observation: Vec<f64>,
    /// Reward reported for the action.
    pub reward: f64,
    /// The server's `done` flag for this step.
    pub done: bool,
    /// The server's `info` object, passed through as raw JSON.
    pub info: Value,
}
/// A handle to one environment instance hosted by a gym HTTP server.
///
/// Created by `GymClient::make`, which moves the client into this handle.
pub struct Environment {
    /// Client used for every request made on behalf of this instance.
    client: GymClient,
    /// Server-assigned id, embedded in the `/v1/envs/<id>/...` routes.
    instance_id: String,
    /// Action space fetched when the instance was created.
    act_space: Space,
    /// Observation space fetched when the instance was created.
    obs_space: Space,
}
impl Environment {
pub fn action_space<'a>(&'a self) -> &'a Space {
&self.act_space
}
pub fn observation_space<'a>(&'a self) -> &'a Space {
&self.obs_space
}
pub fn reset(&mut self) -> GymResult<Vec<f64>> {
let path = "/v1/envs/".to_string() + &self.instance_id + "/reset/";
let observation = try!(self.client.post(path, Value::Null));
let ret: Vec<_> = observation.find("observation").unwrap().as_array().unwrap()
.into_iter().map(|x| x.as_f64().unwrap())
.collect();
Ok(ret)
}
pub fn step(&mut self, action: Vec<f64>, render: bool) -> GymResult<State> {
let mut req = BTreeMap::new();
req.insert("render", Value::Bool(render));
match self.act_space {
Space::DISCRETE{..} => {
assert_eq!(action.len(), 1);
req.insert("action", Value::U64(action[0] as u64));
},
Space::BOX{ref shape, ..} => {
assert_eq!(action.len(), shape[0] as usize);
req.insert("action", action.to_json());
},
Space::TUPLE{..} => panic!("Actions for Tuple spaces not implemented yet")
}
let path = "/v1/envs/".to_string() + &self.instance_id + "/step/";
let state = try!(self.client.post(path, req.to_json()));
Ok(State {
observation: from_value(state.find("observation").unwrap().clone()).unwrap(),
reward: state.find("reward").unwrap().as_f64().unwrap(),
done: state.find("done").unwrap().as_bool().unwrap(),
info: state.find("info").unwrap().clone()
})
}
pub fn monitor_start(&mut self, directory: String, force: bool, resume: bool) -> GymResult<()> {
let mut req = BTreeMap::new();
req.insert("directory", Value::String(directory));
req.insert("force", Value::Bool(force));
req.insert("resume", Value::Bool(resume));
let path = "/v1/envs/".to_string() + &self.instance_id + "/monitor/start/";
try!(self.client.post(path, req.to_json()));
Ok(())
}
pub fn monitor_stop(&mut self) -> GymResult<()> {
let path = "/v1/envs/".to_string() + &self.instance_id + "/monitor/close/";
try!(self.client.post(path, Value::Null));
Ok(())
}
pub fn upload(&mut self, training_dir: String, api_key: String, algorithm_id: String) -> GymResult<()> {
let mut req = BTreeMap::new();
req.insert("training_dir", training_dir);
req.insert("api_key", api_key);
req.insert("algorithm_id", algorithm_id);
try!(self.client.post("/v1/upload/".to_string(), req.to_json()));
Ok(())
}
}
/// Thin HTTP client for a gym HTTP server.
///
/// Every request URL is `address` followed by a route such as `/v1/envs/`.
pub struct GymClient {
    // Base URL of the server; routes (which start with `/`) are appended verbatim.
    address: String,
    // Underlying hyper client used for all requests.
    handle: Client,
    // Headers attached to every POST (set to a JSON Content-Type in `new`).
    headers: Headers,
}
impl GymClient {
    /// Creates a client for the server at `addr`.
    ///
    /// Routes (which begin with `/`) are appended to `addr` verbatim, so `addr`
    /// should normally not end with a slash.
    pub fn new(addr: String) -> GymClient {
        let mut headers = Headers::new();
        headers.set_raw("Content-Type", vec![b"application/json".to_vec()]);
        GymClient {
            address: addr,
            handle: Client::new(),
            headers: headers,
        }
    }

    /// Creates an instance of `env_id` on the server and fetches its action and
    /// observation spaces. The client is moved into the returned `Environment`.
    ///
    /// # Errors
    ///
    /// Returns `Err` if any HTTP request fails or a reply body cannot be read.
    ///
    /// # Panics
    ///
    /// Panics if the creation reply has no string `instance_id` field, presumably
    /// because `env_id` is unknown to the server. Also panics if a space reply
    /// lacks an `info` field or describes a space that `Space::from_json` rejects.
    pub fn make(mut self, env_id: &str) -> GymResult<Environment> {
        let mut req: BTreeMap<&str, &str> = BTreeMap::new();
        req.insert("env_id", env_id);
        let created = try!(self.post("/v1/envs/".to_string(), req.to_json()));
        let instance_id = match created.find("instance_id").and_then(Value::as_str) {
            Some(id) => id.to_string(),
            None => panic!("Unrecognized environment id: {}", env_id),
        };
        let obs_space = try!(self.get(format!("/v1/envs/{}/observation_space/", instance_id)));
        let act_space = try!(self.get(format!("/v1/envs/{}/action_space/", instance_id)));
        let act_info = act_space.find("info").expect("action space reply has no \"info\" field");
        let obs_info = obs_space.find("info").expect("observation space reply has no \"info\" field");
        Ok(Environment {
            client: self,
            instance_id: instance_id,
            act_space: Space::from_json(act_info),
            obs_space: Space::from_json(obs_info),
        })
    }

    /// Fetches the server's `all_envs` listing from `/v1/envs/`.
    ///
    /// NOTE(review): presumably maps instance ids to env ids — confirm against the server.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the HTTP request fails or the body cannot be read.
    ///
    /// # Panics
    ///
    /// Panics if the reply has no `all_envs` field or it is not a string-to-string map.
    pub fn get_envs(&mut self) -> GymResult<BTreeMap<String, String>> {
        let json = try!(self.get("/v1/envs/".to_string()));
        let all_envs = json.find("all_envs").expect("reply has no \"all_envs\" field").clone();
        Ok(from_value(all_envs).expect("\"all_envs\" is not a string-to-string map"))
    }

    /// POSTs `request` as JSON to `route` and returns the parsed reply.
    ///
    /// A body that is not valid JSON yields `Value::Null`.
    fn post(&mut self, route: String, request: Value) -> GymResult<Value> {
        let url = format!("{}{}", self.address, route);
        let mut resp = try!(self.handle.post(&url)
            .body(&request.to_string())
            .headers(self.headers.clone())
            .send());
        GymClient::read_json(&mut resp)
    }

    /// GETs `route` and returns the parsed reply.
    ///
    /// A body that is not valid JSON yields `Value::Null`.
    fn get(&mut self, route: String) -> GymResult<Value> {
        let url = format!("{}{}", self.address, route);
        let mut resp = try!(self.handle.get(&url).send());
        GymClient::read_json(&mut resp)
    }

    /// Reads a whole reply body and parses it as JSON (`Value::Null` if unparsable).
    ///
    /// Read failures are propagated instead of silently ignored: previously a
    /// truncated body became `Null` and callers then panicked on missing fields.
    fn read_json<R: Read>(reader: &mut R) -> GymResult<Value> {
        let mut body = String::new();
        try!(reader.read_to_string(&mut body));
        Ok(serde_json::from_str(&body).unwrap_or(Value::Null))
    }
}
You can’t perform that action at this time.