In [114]:
from ast import mod
from telnetlib import DM
from turtle import mode
from unicodedata import name
import gym
from copy import deepcopy
import os
import os.path as osp
import torch
from scipy import stats
from statistics import mean 
import numpy as np
from torch.optim import Adam
import itertools
import random
import torch.nn as nn
import argparse
import pickle

def get_env_name(name):
    if ('humanoid' in name) or ('Humanoid' in name):
        return 'Humanoid-v3'
    if ('halfcheetah' in name) or ('HalfCheetah' in name):
        return 'HalfCheetah-v3'
    if ('ant' in name) or ('Ant' in name):
        return 'Ant-v3'
    if ('hopper' in name) or  ('Hopper' in name) :
        return 'Hopper-v3'
    if ('walker' in name) or ('Walker' in name) :
        return 'Walker2d-v3'
    return 'unknown'

def mlp(sizes, activation, output_activation=nn.Identity):
    layers = []
    for j in range(len(sizes)-1):
        act = activation if j < len(sizes)-2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j+1]), act()]
    return nn.Sequential(*layers)

class PPO_Actor():
    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        self.pi = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)
        self.obs_mean = np.ones(obs_dim)
        self.obs_std = np.ones(obs_dim)
        self.clip = 10.0
        # print(type(self.pi))
    
    def normalize_o(self, o):
        o = o - self.obs_mean
        o = o / (self.obs_std + 1e-8)
        o = np.clip(o, -self.clip, self.clip)
        return o
    
    def act(self, o):
        o = self.normalize_o(o)
        o = torch.as_tensor(o, dtype=torch.float32)
        return self.pi(o).detach().numpy()
    
    def copy_model(self, md):
        self.pi.load_state_dict(md['pi'])
        self.obs_mean = md['obs_mean']
        self.obs_std = md['obs_std']
        self.clip = md['clip']
        
    def load(self, name):
        md = torch.load(name)
        self.copy_model(md)


def get_ppo_models(path, name):
    fpath = osp.join(path, name)
    models = []
    file_names = os.listdir(fpath)
    if len(file_names) == 0:
        return []
    env = gym.make(get_env_name(name))
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    for file_name in file_names:   
        if ".pt" not in file_name:
            continue
        fname = osp.join(fpath, file_name)
        print(file_name)
        model = PPO_Actor(obs_dim, action_dim, (64, 64), nn.Tanh)
        model.load(fname)
        models.append((name, file_name, model))
    return models

def get_models(path, name):
    print("get models ", path, name)
    if 'ppo' in name:
        return get_ppo_models(path, name)
    fpath = osp.join(path, name)
    models = {}
    file_names = os.listdir(fpath)
    if len(file_names) == 0:
        return []
    for file_name in file_names:   
        # fname = osp.join(fpath, file_name ,'pyt_save', 'model0.pt')
        fname = osp.join(fpath, file_name ,'pyt_save', 'model.pt')
        print(fname)
        model = torch.load(fname)
        models[file_name] = model
        # models.append((name, file_name, model))
    return models

def save_state(env):
    return env.sim.get_state()

def restore_state(env, old_state):
    env.reset()
    env.sim.set_state(old_state)
    env.sim.forward()
    return env.get_obs()

def get_ppo_action(o, md):
    return md.act(o)

def get_action(o, md, name):
    if 'ppo' in name:
        return get_ppo_action(o, md)
    if 'train' not in name:
        o = torch.as_tensor(o, dtype=torch.float32)
        return md.act(o)
    o = torch.as_tensor(o, dtype=torch.float32)
    return md.act(o, deterministic=False)

def get_q(o, a, md):
    o = torch.as_tensor(o, dtype=torch.float32)
    a = torch.as_tensor(a, dtype=torch.float32)
    q1 = md.q1(o, a)
    q2 = md.q2(o, a)
    return torch.min(q1, q2)

def get_all_traj_names_with_same_env(path, trajs_path, name):
    all_trajs_names = []
    fpath = osp.join(path, trajs_path)
    print(fpath)
    file_names = os.listdir(fpath)
    if len(file_names) == 0:
        return []
    env_name = get_env_name(name)
    for file_name in file_names:
        if "trajs.pkl" not in file_name:
            continue
        tmp = get_env_name(file_name)
        if tmp == env_name:
            all_trajs_names.append(file_name)
    return all_trajs_names


def load_all_same_env_results(cpath, env_name):
    print('load continue results: ', cpath)
    file_names = os.listdir(cpath)
    rets = []
    for file_name in file_names:
        if ".pkl" in file_name and get_env_name(file_name) == env_name:
            print(file_name)
            file_name = osp.join(cpath, file_name)
            with open(file_name, 'rb') as f: 
                ret = pickle.load(f)
            rets.append(ret)
    return rets

def is_fail(data):
    if data[3][0]:
        return 1.0
    # if data[3][1] < 100:
    #     return 1.0
    return 0.0

def get_fail_rate(result): # result of one agent to one agent's trajs
    ret = 0
    for data in result:
        ret += is_fail(data)
    return ret/len(result)

def get_result_mean(result):
    ret = 0
    for data in result: # data = (traj_id, midpoint_id, old_ret, (d, total_r))
        ret += data[3][1]
    return ret/len(result)

def get_traj_d(path, trajs_path, all_traj_names, algo_names, test_algo_name):
    env_name = get_env_name(test_algo_name)
    trajs_d = {}
    for aname in algo_names[env_name]:
        for trajs_name in all_traj_names:
            if aname in trajs_name:
                tname = osp.join(path, trajs_path, trajs_name)
                print(tname)
                with open(tname, 'rb') as f: 
                    trajs = pickle.load(f)
                trajs_d[aname] = trajs 
                break
    return trajs_d

def get_state(trajs_d, algo_name, agent_name, traj_id, midpoint_id):
    for data in trajs_d[algo_name]:
        if data[1] == agent_name:
            return data[-1][traj_id][midpoint_id]

def get_self_ret(cresults, trajs_d):
    self_ret = {}
    for algo_results in cresults:
        for results in algo_results:
            if results[0] == results[2] and results[1] == results[3]:
                x = {}
                for data in results[4]:
                    if not data[3][0]:
                        eplen, s = get_state(trajs_d, results[0], results[1], data[0], data[1])
                        x[(data[0], data[1])] = (data[3][1],s, epl)
                        
                if results[0] not in self_ret.keys():
                    self_ret[results[0]] = {}
                    print(results[0])
                self_ret[results[0]][results[1]] = x
    return self_ret


def is_good_key(key,  remove_keys, must_keys):
    for k in remove_keys:
        if k in key:
            return False
    for k in must_keys:
        if k not in key:
            return False
    return True
def print_result(fname, test_name_keys, remove_keys=[], must_keys=[]):
    f = open(fname, 'r')
    s = f.read()
    # print(s)
    x = s.split()
    i = 0
    data = {}
    while i < len(x):
        if x[i+1] not in data.keys():
            data[x[i+1]] = {}
        data[x[i+1]][x[i]] = x[i+2] + ' ' + x[i+3] + ' ' + x[i+4]
        # print(x[i], x[i+1], data[x[i+1]][x[i]])
        i += 5
    for env_name in algo_names.keys():
        print(env_name)
        for algo_name in algo_names[env_name]:
            # print(test_name_keys.keys())
            for test_name_key in test_name_keys:
                name = ''
                for test_algo_name in data[algo_name].keys():
                    if test_name_key in test_algo_name \
                        and is_good_key(test_algo_name, remove_keys, must_keys):
                        if (test_name_key + '_') in test_algo_name or \
                            test_name_key == test_algo_name[-len(test_name_key):] :
                            name = test_algo_name
                            break
                # print(algo_name, name, test_name_key)
                if name != '':
                    print(data[algo_name][name], end=" & ")
                else:
                    print(test_name_key, "cant find")
            print('')


In [123]:
fname = './data/gsac7/results.txt'
algo_names = {}
algo_names['Humanoid-v3'] = ['Humanoid-v3_sac_base', 'Humanoid-v3_td3_base', 'vanilla_ppo_humanoid',  'sgld_ppo_humanoid']
# algo_names['Ant-v3'] = ['Ant-v3_sac_base' , 'Ant-v3_td3_base', 'vanilla_ppo_ant', 'atla_ppo_ant']
algo_names['Walker2d-v3'] = ['Walker2d-v3_sac_base', 'Walker2d-v3_td3_base', 'vanilla_ppo_walker', 'atla_ppo_walker']
# algo_names['HalfCheetah-v3'] = ['HalfCheetah-v3_sac_base', 'HalfCheetah-v3_td3_base',  'vanilla_ppo_halfcheetah', 'atla_ppo_halfcheetah']
algo_names['Hopper-v3'] = ['Hopper-v3_sac_base', 'Hopper-v3_td3_base', 'vanilla_ppo_hopper',  'atla_ppo_hopper']
remove_keys = []
must_keys = ['sn10']
test_name_keys = ['owr10','owr11','owr13','owr16']
print_result(fname, test_name_keys, remove_keys, must_keys)

Humanoid-v3
5.34 $\pm$ 2.67 & 4.02 $\pm$ 2.74 & 6.72 $\pm$ 4.25 & 8.31 $\pm$ 10.58 & 
3.38 $\pm$ 2.48 & 2.90 $\pm$ 3.07 & 3.19 $\pm$ 2.94 & 5.76 $\pm$ 4.85 & 
9.21 $\pm$ 2.37 & 7.41 $\pm$ 3.48 & 9.68 $\pm$ 4.90 & 10.58 $\pm$ 4.65 & 
20.05 $\pm$ 9.53 & 26.71 $\pm$ 10.42 & 24.10 $\pm$ 10.05 & 30.86 $\pm$ 12.73 & 
Walker2d-v3
18.86 $\pm$ 9.91 & 21.10 $\pm$ 15.34 & 16.95 $\pm$ 9.30 & 15.95 $\pm$ 10.84 & 
20.71 $\pm$ 9.22 & 17.24 $\pm$ 12.52 & 17.86 $\pm$ 9.20 & 12.95 $\pm$ 9.36 & 
13.19 $\pm$ 6.11 & 9.57 $\pm$ 8.93 & 10.71 $\pm$ 6.52 & 5.81 $\pm$ 3.19 & 
14.26 $\pm$ 6.42 & 10.65 $\pm$ 8.09 & 12.21 $\pm$ 5.87 & 6.64 $\pm$ 3.82 & 
Hopper-v3
27.07 $\pm$ 23.99 & 17.82 $\pm$ 5.87 & 27.41 $\pm$ 17.48 & 21.50 $\pm$ 10.78 & 
21.48 $\pm$ 23.15 & 16.90 $\pm$ 15.24 & 26.14 $\pm$ 19.40 & 26.43 $\pm$ 18.96 & 
22.43 $\pm$ 23.34 & 16.43 $\pm$ 8.37 & 21.00 $\pm$ 14.92 & 15.29 $\pm$ 9.84 & 
22.93 $\pm$ 24.74 & 19.18 $\pm$ 14.21 & 24.97 $\pm$ 20.87 & 26.53 $\pm$ 17.58 & 


In [122]:
remove_keys = []
must_keys = ['sn5']
test_name_keys = ['owr10','owr11','owr13','owr16']
print_result(fname, test_name_keys, remove_keys, must_keys)

Humanoid-v3
8.10 $\pm$ 8.29 & 8.57 $\pm$ 5.26 & 17.04 $\pm$ 27.57 & 7.46 $\pm$ 5.21 & 
5.38 $\pm$ 6.93 & 7.10 $\pm$ 4.98 & 14.05 $\pm$ 28.75 & 3.62 $\pm$ 3.38 & 
10.42 $\pm$ 10.16 & 15.40 $\pm$ 6.67 & 20.26 $\pm$ 26.89 & 11.85 $\pm$ 6.58 & 
25.19 $\pm$ 15.30 & 42.14 $\pm$ 10.15 & 44.29 $\pm$ 20.35 & 40.76 $\pm$ 9.32 & 
Walker2d-v3
19.76 $\pm$ 13.08 & 13.14 $\pm$ 8.99 & 9.71 $\pm$ 5.32 & 13.43 $\pm$ 13.99 & 
17.00 $\pm$ 10.07 & 9.81 $\pm$ 6.01 & 10.43 $\pm$ 6.46 & 9.90 $\pm$ 11.43 & 
8.43 $\pm$ 6.02 & 5.81 $\pm$ 4.94 & 5.24 $\pm$ 3.47 & 5.76 $\pm$ 5.59 & 
9.90 $\pm$ 6.21 & 7.47 $\pm$ 3.83 & 7.59 $\pm$ 3.14 & 6.69 $\pm$ 6.30 & 
Hopper-v3
30.34 $\pm$ 24.24 & 20.14 $\pm$ 7.95 & 24.69 $\pm$ 17.25 & 10.75 $\pm$ 3.65 & 
27.29 $\pm$ 22.12 & 13.00 $\pm$ 12.50 & 28.57 $\pm$ 19.83 & 11.29 $\pm$ 7.36 & 
22.81 $\pm$ 23.04 & 14.57 $\pm$ 7.02 & 20.67 $\pm$ 14.76 & 6.38 $\pm$ 3.93 & 
27.69 $\pm$ 27.14 & 15.71 $\pm$ 11.75 & 28.03 $\pm$ 21.25 & 8.98 $\pm$ 5.50 & 


In [121]:
remove_keys = []
must_keys = ['sn2']
test_name_keys = ['owr10','owr11','owr13','owr16']
print_result(fname, test_name_keys, remove_keys, must_keys)

Humanoid-v3
10.05 $\pm$ 4.79 & 13.97 $\pm$ 8.89 & 14.87 $\pm$ 8.52 & 18.41 $\pm$ 10.69 & 
7.19 $\pm$ 6.46 & 8.05 $\pm$ 4.91 & 6.10 $\pm$ 3.95 & 10.67 $\pm$ 8.66 & 
16.83 $\pm$ 8.51 & 21.43 $\pm$ 4.14 & 25.34 $\pm$ 13.78 & 26.61 $\pm$ 11.87 & 
41.14 $\pm$ 14.35 & 59.19 $\pm$ 11.34 & 61.81 $\pm$ 12.61 & 65.10 $\pm$ 11.93 & 
Walker2d-v3
17.38 $\pm$ 9.17 & 15.95 $\pm$ 9.74 & 10.81 $\pm$ 6.35 & 13.81 $\pm$ 2.82 & 
18.62 $\pm$ 10.90 & 10.57 $\pm$ 8.28 & 8.14 $\pm$ 4.57 & 10.90 $\pm$ 3.59 & 
10.19 $\pm$ 8.07 & 7.14 $\pm$ 4.69 & 7.95 $\pm$ 4.56 & 5.05 $\pm$ 4.56 & 
10.95 $\pm$ 8.24 & 7.87 $\pm$ 4.76 & 10.25 $\pm$ 6.15 & 7.52 $\pm$ 5.81 & 
Hopper-v3
28.10 $\pm$ 25.92 & 14.76 $\pm$ 7.12 & 18.64 $\pm$ 19.60 & 15.65 $\pm$ 6.60 & 
18.19 $\pm$ 29.62 & 11.81 $\pm$ 10.28 & 15.52 $\pm$ 22.90 & 9.76 $\pm$ 13.96 & 
21.38 $\pm$ 24.23 & 12.33 $\pm$ 11.62 & 13.48 $\pm$ 16.58 & 11.43 $\pm$ 6.25 & 
22.04 $\pm$ 29.44 & 15.37 $\pm$ 14.06 & 15.17 $\pm$ 21.17 & 11.63 $\pm$ 10.91 & 


In [124]:
remove_keys = ['sn10']
must_keys = ['sn1']
test_name_keys = ['owr10','owr11','owr13','owr16']
print_result(fname, test_name_keys, remove_keys, must_keys)

Humanoid-v3
5.34 $\pm$ 2.67 & 36.93 $\pm$ 19.56 & 25.93 $\pm$ 11.88 & 8.31 $\pm$ 10.58 & 
3.38 $\pm$ 2.48 & 23.29 $\pm$ 8.67 & 20.05 $\pm$ 10.46 & 5.76 $\pm$ 4.85 & 
9.21 $\pm$ 2.37 & 42.54 $\pm$ 11.96 & 31.32 $\pm$ 9.69 & 10.58 $\pm$ 4.65 & 
20.05 $\pm$ 9.53 & 81.52 $\pm$ 4.78 & 72.05 $\pm$ 10.87 & 30.86 $\pm$ 12.73 & 
Walker2d-v3
20.81 $\pm$ 10.41 & 21.10 $\pm$ 15.34 & 21.76 $\pm$ 11.74 & 16.90 $\pm$ 12.04 & 
17.38 $\pm$ 9.01 & 17.24 $\pm$ 12.52 & 16.86 $\pm$ 10.93 & 14.19 $\pm$ 12.41 & 
11.29 $\pm$ 7.40 & 9.57 $\pm$ 8.93 & 7.95 $\pm$ 5.01 & 5.95 $\pm$ 5.54 & 
14.24 $\pm$ 8.07 & 11.95 $\pm$ 7.40 & 10.93 $\pm$ 7.22 & 8.32 $\pm$ 7.62 & 
Hopper-v3
27.07 $\pm$ 23.99 & 14.56 $\pm$ 7.37 & 20.07 $\pm$ 15.21 & 10.07 $\pm$ 3.17 & 
21.48 $\pm$ 23.15 & 11.52 $\pm$ 9.63 & 19.38 $\pm$ 20.35 & 5.24 $\pm$ 4.81 & 
22.43 $\pm$ 23.34 & 8.52 $\pm$ 7.38 & 18.19 $\pm$ 13.47 & 5.10 $\pm$ 3.00 & 
22.93 $\pm$ 24.74 & 11.63 $\pm$ 6.87 & 23.47 $\pm$ 19.44 & 6.60 $\pm$ 2.85 & 


In [86]:
path  = '/home/lclan/spinningup/data/'
trajs_path = 'trajs'
continue_path = '/home/lclan/spinningup/data/tmp/'
algo_names = {}
algo_names['Humanoid-v3'] = ['Humanoid-v3_sac_base', 'Humanoid-v3_td3_base', 'vanilla_ppo_humanoid',  'sgld_ppo_humanoid']
algo_names['Ant-v3'] = ['Ant-v3_sac_base' , 'Ant-v3_td3_base', 'vanilla_ppo_ant', 'atla_ppo_ant']
algo_names['Walker2d-v3'] = ['Walker2d-v3_sac_base', 'Walker2d-v3_td3_base', 'vanilla_ppo_walker', 'atla_ppo_walker']
algo_names['HalfCheetah-v3'] = ['HalfCheetah-v3_sac_base', 'HalfCheetah-v3_td3_base',  'vanilla_ppo_halfcheetah', 'atla_ppo_halfcheetah']
algo_names['Hopper-v3'] = ['Hopper-v3_sac_base', 'Hopper-v3_td3_base', 'vanilla_ppo_hopper',  'atla_ppo_hopper']
env_names = list(algo_names.keys())
test_algo_name = 'Humanoid-v3_gsac_base'


In [5]:
all_traj_names = get_all_traj_names_with_same_env(path, trajs_path, test_algo_name)
print(all_traj_names)

/home/lclan/spinningup/data/trajs
['Humanoid-v3_td3_base_400_trajs.pkl', 'Humanoid-v3_sac_base_400_trajs.pkl', 'sgld_ppo_humanoid_400_trajs.pkl', 'vanilla_ppo_humanoid_400_trajs.pkl']


In [6]:
cresults = load_all_same_env_results(continue_path, get_env_name(test_algo_name))

load continue results:  /home/lclan/spinningup/data/tmp/
Humanoid-v3_sac_base_s500_tr50_tn200.pkl
sgld_ppo_humanoid_s500_tr50_tn200.pkl
vanilla_ppo_humanoid_s500_tr50_tn200.pkl
Humanoid-v3_td3_base_s500_tr50_tn200.pkl


In [7]:
trajs_d = get_traj_d(path, trajs_path, all_traj_names, algo_names, test_algo_name)

/home/lclan/spinningup/data/trajs/Humanoid-v3_sac_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/Humanoid-v3_td3_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/vanilla_ppo_humanoid_400_trajs.pkl
/home/lclan/spinningup/data/trajs/sgld_ppo_humanoid_400_trajs.pkl


In [13]:
self_ret = get_self_ret(cresults, trajs_d)

Humanoid-v3_sac_base
sgld_ppo_humanoid
vanilla_ppo_humanoid
Humanoid-v3_td3_base


In [11]:
test_models = get_models(path, test_algo_name)

get models  /home/lclan/spinningup/data/ Humanoid-v3_gsac_base
/home/lclan/spinningup/data/Humanoid-v3_gsac_base/Humanoid-v3_gsac_base_s1124/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_gsac_base/Humanoid-v3_gsac_base_s1125/pyt_save/model.pt


In [17]:
print(self_ret['sgld_ppo_humanoid'].keys())

dict_keys(['sgld_ppo_humanoid_1.pt', 'sgld_ppo_humanoid_2.pt', 'sgld_ppo_humanoid_11.pt', 'sgld_ppo_humanoid_10.pt', 'sgld_ppo_humanoid_4.pt', 'sgld_ppo_humanoid_7.pt', 'sgld_ppo_humanoid_8.pt', 'sgld_ppo_humanoid_6.pt', 'sgld_ppo_humanoid_5.pt', 'sgld_ppo_humanoid_9.pt'])


In [28]:
testing_models = get_models(path, test_algo_name)


get models  /home/lclan/spinningup/data/ Humanoid-v3_gsac_base
/home/lclan/spinningup/data/Humanoid-v3_gsac_base/Humanoid-v3_gsac_base_s1124/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_gsac_base/Humanoid-v3_gsac_base_s1125/pyt_save/model.pt


In [87]:
def run_extra_steps(env, ep_len, md, md_name, step_num=50):
    max_ep_len = 1000
    total_r = 0
    o = env.get_obs()
    for i in range(step_num):
        a = get_action(o, md, md_name)
        o, r, d, _ = env.step(a)
        total_r += r
        ep_len += 1
        if d or (ep_len == max_ep_len):
            return (d, total_r)
    return (d, total_r)

def test_models(self_ret, test_models, test_algo_name, test_num = 200, step_num = 500):
    ret = {}
    env_name = get_env_name(test_algo_name)
    env = gym.make(env_name)
    for gen_algo in self_ret.keys():
        score = []
        fail = []
        print("start testing on ", gen_algo, " with ", len(self_ret[gen_algo].keys()), " agents" )
        for k in self_ret[gen_algo].keys():
            cnt = 0
            for k2 in self_ret[gen_algo][k].keys():
                orig_ret, s, eplen = self_ret[gen_algo][k][k2]
                for key in test_models.keys():
                    md = test_models[key]
                    restore_state(env, s)
                    d, r = run_extra_steps(env, eplen, md, test_algo_name, step_num)
                    fail.append(int(d))
                    score.append(r)
                cnt += 1
                if cnt > test_num:
                    break
        print(len(score))
        print(len(fail))
        score = np.array(score)
        fail = np.array(fail)
        print('score: ', int(np.mean(score)), int(np.std(score)))
        print('fail: ', np.mean(fail), np.std(fail))
        ret[gen_algo] = (score, fail)
    return ret

# r = test_models(self_ret, testing_models, test_algo_name, test_num = 10)

In [34]:
len(r['Humanoid-v3_sac_base'][1])

198

In [37]:
for key in r.keys():
    l = np.array(r[key][1])
    print(key, np.mean(l), np.std(l))

Humanoid-v3_sac_base 0.030303030303030304 0.17141982574219336
sgld_ppo_humanoid 0.37272727272727274 0.48353040534444364
vanilla_ppo_humanoid 0.16666666666666666 0.37267799624996495
Humanoid-v3_td3_base 0.00909090909090909 0.09491187735373229


In [46]:
test_algo_path = 'ttn_2'
fpath = osp.join(path, test_algo_path)
print(fpath)
file_names = os.listdir(fpath)
file_names


/home/lclan/spinningup/data/ttn_2


['Hopper-v3_gsac_sn1_ttn1',
 'Walker2d-v3_gsac_sn1_ttn2',
 'Hopper-v3_gsac_sn1_ttn2',
 'Humanoid-v3_gsac_sn1_ttn4',
 'Walker2d-v3_gsac_sn1_ttn8',
 'Walker2d-v3_gsac_sn1_ttn4',
 'Hopper-v3_gsac_sn1_ttn4',
 'Humanoid-v3_gsac_sn1_ttn1',
 'Hopper-v3_gsac_sn1_ttn8',
 'Humanoid-v3_gsac_sn1_ttn2',
 'Walker2d-v3_gsac_sn1_ttn1',
 'Humanoid-v3_gsac_sn1_ttn8']

In [88]:
def test_gsac_agent(path, trajs_path, test_algo_path, test_algo_name, continue_path, gen_algo_names):
    all_traj_names = get_all_traj_names_with_same_env(path, trajs_path, test_algo_name)
    cresults = load_all_same_env_results(continue_path, get_env_name(test_algo_name))
    trajs_d = get_traj_d(path, trajs_path, all_traj_names, gen_algo_names, test_algo_name)
    self_ret = get_self_ret(cresults, trajs_d)
    tmp_path = osp.join(path, test_algo_path)
    print(tmp_path)
    testing_models = get_models(tmp_path, test_algo_name)
    ret = test_models(self_ret, testing_models, test_algo_name, test_num = 30)
    return ret


# r = test_gsac_agent(path, trajs_path, test_algo_path, 'Humanoid-v3_gsac_sn1_ttn4', continue_path, algo_names)



In [43]:
models = get_models('/home/lclan/spinningup/data/ttn', 'Humanoid-v3_gsac_sn1_ttn4')

get models  /home/lclan/spinningup/data/ Humanoid-v3_gsac_base
/home/lclan/spinningup/data/Humanoid-v3_gsac_base/Humanoid-v3_gsac_base_s1124/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_gsac_base/Humanoid-v3_gsac_base_s1125/pyt_save/model.pt


In [89]:
def get_models(path, name):
    print("get models ", path, name)
    if 'ppo' in name:
        return get_ppo_models(path, name)
    fpath = osp.join(path, name)
    models = {}
    file_names = os.listdir(fpath)
    if len(file_names) == 0:
        return []
    for file_name in file_names:   
        # fname = osp.join(fpath, file_name ,'pyt_save', 'model0.pt')
        fname = osp.join(fpath, file_name ,'pyt_save', 'model_13.pt')
        print(fname)
        model = torch.load(fname)
        models[file_name] = model
        # models.append((name, file_name, model))
    return models

# models = get_models('/home/lclan/spinningup/data/ttn_2', 'Humanoid-v3_gsac_sn1_ttn4')

In [48]:
r8 = test_gsac_agent(path, trajs_path, test_algo_path, 'Humanoid-v3_gsac_sn1_ttn8', continue_path, algo_names)

/home/lclan/spinningup/data/trajs
load continue results:  /home/lclan/spinningup/data/tmp/
Humanoid-v3_sac_base_s500_tr50_tn200.pkl
sgld_ppo_humanoid_s500_tr50_tn200.pkl
vanilla_ppo_humanoid_s500_tr50_tn200.pkl
Humanoid-v3_td3_base_s500_tr50_tn200.pkl
/home/lclan/spinningup/data/trajs/Humanoid-v3_sac_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/Humanoid-v3_td3_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/vanilla_ppo_humanoid_400_trajs.pkl
/home/lclan/spinningup/data/trajs/sgld_ppo_humanoid_400_trajs.pkl
Humanoid-v3_sac_base
sgld_ppo_humanoid
vanilla_ppo_humanoid
Humanoid-v3_td3_base
/home/lclan/spinningup/data/ttn_2
get models  /home/lclan/spinningup/data/ttn_2 Humanoid-v3_gsac_sn1_ttn8
/home/lclan/spinningup/data/ttn_2/Humanoid-v3_gsac_sn1_ttn8/Humanoid-v3_gsac_sn1_ttn8_s1125/pyt_save/model_13.pt
/home/lclan/spinningup/data/ttn_2/Humanoid-v3_gsac_sn1_ttn8/Humanoid-v3_gsac_sn1_ttn8_s1124/pyt_save/model_13.pt
/home/lclan/spinningup/data/ttn_2/Humanoid-v3_gsac_sn1_ttn8/Hu



1395
1395
score:  2424 917
fail:  0.1770609318996416 0.3817202618338095
start testing on  sgld_ppo_humanoid  with  10  agents
1550
1550
score:  1078 1227
fail:  0.6845161290322581 0.464708293585289
start testing on  vanilla_ppo_humanoid  with  9  agents
1395
1395
score:  2185 1103
fail:  0.2616487455197133 0.43953234179941025
start testing on  Humanoid-v3_td3_base  with  10  agents
1550
1550
score:  2640 651
fail:  0.1 0.3


In [49]:
r2 = test_gsac_agent(path, trajs_path, test_algo_path, 'Humanoid-v3_gsac_sn1_ttn2', continue_path, algo_names)

/home/lclan/spinningup/data/trajs
load continue results:  /home/lclan/spinningup/data/tmp/
Humanoid-v3_sac_base_s500_tr50_tn200.pkl
sgld_ppo_humanoid_s500_tr50_tn200.pkl
vanilla_ppo_humanoid_s500_tr50_tn200.pkl
Humanoid-v3_td3_base_s500_tr50_tn200.pkl
/home/lclan/spinningup/data/trajs/Humanoid-v3_sac_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/Humanoid-v3_td3_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/vanilla_ppo_humanoid_400_trajs.pkl
/home/lclan/spinningup/data/trajs/sgld_ppo_humanoid_400_trajs.pkl
Humanoid-v3_sac_base
sgld_ppo_humanoid
vanilla_ppo_humanoid
Humanoid-v3_td3_base
/home/lclan/spinningup/data/ttn_2
get models  /home/lclan/spinningup/data/ttn_2 Humanoid-v3_gsac_sn1_ttn2
/home/lclan/spinningup/data/ttn_2/Humanoid-v3_gsac_sn1_ttn2/Humanoid-v3_gsac_sn1_ttn2_s1128/pyt_save/model_13.pt
/home/lclan/spinningup/data/ttn_2/Humanoid-v3_gsac_sn1_ttn2/Humanoid-v3_gsac_sn1_ttn2_s1126/pyt_save/model_13.pt
/home/lclan/spinningup/data/ttn_2/Humanoid-v3_gsac_sn1_ttn2/Hu

In [50]:
r1 = test_gsac_agent(path, trajs_path, test_algo_path, 'Humanoid-v3_gsac_sn1_ttn1', continue_path, algo_names)

/home/lclan/spinningup/data/trajs
load continue results:  /home/lclan/spinningup/data/tmp/
Humanoid-v3_sac_base_s500_tr50_tn200.pkl
sgld_ppo_humanoid_s500_tr50_tn200.pkl
vanilla_ppo_humanoid_s500_tr50_tn200.pkl
Humanoid-v3_td3_base_s500_tr50_tn200.pkl
/home/lclan/spinningup/data/trajs/Humanoid-v3_sac_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/Humanoid-v3_td3_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/vanilla_ppo_humanoid_400_trajs.pkl
/home/lclan/spinningup/data/trajs/sgld_ppo_humanoid_400_trajs.pkl
Humanoid-v3_sac_base
sgld_ppo_humanoid
vanilla_ppo_humanoid
Humanoid-v3_td3_base
/home/lclan/spinningup/data/ttn_2
get models  /home/lclan/spinningup/data/ttn_2 Humanoid-v3_gsac_sn1_ttn1
/home/lclan/spinningup/data/ttn_2/Humanoid-v3_gsac_sn1_ttn1/Humanoid-v3_gsac_sn1_ttn1_s1125/pyt_save/model_13.pt
/home/lclan/spinningup/data/ttn_2/Humanoid-v3_gsac_sn1_ttn1/Humanoid-v3_gsac_sn1_ttn1_s1127/pyt_save/model_13.pt
/home/lclan/spinningup/data/ttn_2/Humanoid-v3_gsac_sn1_ttn1/Hu

In [56]:
test_algo_path = 'gsac_cs'
fpath = osp.join(path, test_algo_path)
print(fpath)
test_algo_names = os.listdir(fpath)
print(test_algo_names)
cs_results = {}
for test_algo_name in test_algo_names:
    print('start testing ', test_algo_name)
    cs_results[test_algo_name] = test_gsac_agent(path, trajs_path, test_algo_path, test_algo_name, continue_path, algo_names)


/home/lclan/spinningup/data/gsac_cs
['Humanoid-v3_gsac_sn1_cs200', 'Walker2d-v3_gsac_sn1_cs500', 'Walker2d-v3_gsac_sn1_cs50', 'Hopper-v3_gsac_sn1_cs200', 'Humanoid-v3_gsac_sn1_cs500', 'Humanoid-v3_gsac_sn1_cs50', 'Walker2d-v3_gsac_sn1_cs200', 'Hopper-v3_gsac_sn1_cs50', 'Hopper-v3_gsac_sn1_cs500']
start testing  Humanoid-v3_gsac_sn1_cs200
/home/lclan/spinningup/data/trajs
load continue results:  /home/lclan/spinningup/data/tmp/
Humanoid-v3_sac_base_s500_tr50_tn200.pkl
sgld_ppo_humanoid_s500_tr50_tn200.pkl
vanilla_ppo_humanoid_s500_tr50_tn200.pkl
Humanoid-v3_td3_base_s500_tr50_tn200.pkl
/home/lclan/spinningup/data/trajs/Humanoid-v3_sac_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/Humanoid-v3_td3_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/vanilla_ppo_humanoid_400_trajs.pkl
/home/lclan/spinningup/data/trajs/sgld_ppo_humanoid_400_trajs.pkl
Humanoid-v3_sac_base
sgld_ppo_humanoid
vanilla_ppo_humanoid
Humanoid-v3_td3_base
/home/lclan/spinningup/data/gsac_cs
get models  /home/



1395
1395
score:  2465 876
fail:  0.15770609318996415 0.36446519910784697
start testing on  sgld_ppo_humanoid  with  10  agents
1550
1550
score:  1351 1288
fail:  0.5825806451612904 0.4931332852736114
start testing on  vanilla_ppo_humanoid  with  9  agents
1395
1395
score:  2194 1085
fail:  0.26810035842293906 0.4429701527602403
start testing on  Humanoid-v3_td3_base  with  10  agents
1550
1550
score:  2484 819
fail:  0.16129032258064516 0.3677985242255284
start testing  Walker2d-v3_gsac_sn1_cs500
/home/lclan/spinningup/data/trajs
load continue results:  /home/lclan/spinningup/data/tmp/
Walker2d-v3_td3_base_s500_tr50_tn200.pkl
Walker2d-v3_sac_base_s500_tr50_tn200.pkl
atla_ppo_walker_s500_tr50_tn200.pkl
vanilla_ppo_walker_s500_tr50_tn200.pkl
/home/lclan/spinningup/data/trajs/Walker2d-v3_sac_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/Walker2d-v3_td3_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/vanilla_ppo_walker_400_trajs.pkl
/home/lclan/spinningup/data/trajs/atla_ppo_w

In [57]:
test_algo_path = 'ttn_2'
w_r1 = test_gsac_agent(path, trajs_path, test_algo_path, 'Walker2d-v3_gsac_sn1_ttn1', continue_path, algo_names)
w_r2 = test_gsac_agent(path, trajs_path, test_algo_path, 'Walker2d-v3_gsac_sn1_ttn2', continue_path, algo_names)
w_r4 = test_gsac_agent(path, trajs_path, test_algo_path, 'Walker2d-v3_gsac_sn1_ttn4', continue_path, algo_names)
w_r8 = test_gsac_agent(path, trajs_path, test_algo_path, 'Walker2d-v3_gsac_sn1_ttn8', continue_path, algo_names)

/home/lclan/spinningup/data/trajs
load continue results:  /home/lclan/spinningup/data/tmp/
Walker2d-v3_td3_base_s500_tr50_tn200.pkl
Walker2d-v3_sac_base_s500_tr50_tn200.pkl
atla_ppo_walker_s500_tr50_tn200.pkl
vanilla_ppo_walker_s500_tr50_tn200.pkl
/home/lclan/spinningup/data/trajs/Walker2d-v3_sac_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/Walker2d-v3_td3_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/vanilla_ppo_walker_400_trajs.pkl
/home/lclan/spinningup/data/trajs/atla_ppo_walker_400_trajs.pkl
Walker2d-v3_td3_base
Walker2d-v3_sac_base
atla_ppo_walker
vanilla_ppo_walker
/home/lclan/spinningup/data/ttn_2
get models  /home/lclan/spinningup/data/ttn_2 Walker2d-v3_gsac_sn1_ttn1
/home/lclan/spinningup/data/ttn_2/Walker2d-v3_gsac_sn1_ttn1/Walker2d-v3_gsac_sn1_ttn1_s1128/pyt_save/model_13.pt
/home/lclan/spinningup/data/ttn_2/Walker2d-v3_gsac_sn1_ttn1/Walker2d-v3_gsac_sn1_ttn1_s1127/pyt_save/model_13.pt
/home/lclan/spinningup/data/ttn_2/Walker2d-v3_gsac_sn1_ttn1/Walker2d-v3_gs

In [58]:
test_algo_path = 'ttn_2'
ho_r1 = test_gsac_agent(path, trajs_path, test_algo_path, 'Hopper-v3_gsac_sn1_ttn1', continue_path, algo_names)
ho_r2 = test_gsac_agent(path, trajs_path, test_algo_path, 'Hopper-v3_gsac_sn1_ttn2', continue_path, algo_names)
ho_r4 = test_gsac_agent(path, trajs_path, test_algo_path, 'Hopper-v3_gsac_sn1_ttn4', continue_path, algo_names)
ho_r8 = test_gsac_agent(path, trajs_path, test_algo_path, 'Hopper-v3_gsac_sn1_ttn8', continue_path, algo_names)

/home/lclan/spinningup/data/trajs
load continue results:  /home/lclan/spinningup/data/tmp/
vanilla_ppo_hopper_s500_tr50_tn200.pkl
atla_ppo_hopper_s500_tr50_tn200.pkl
Hopper-v3_td3_base_s500_tr50_tn200.pkl
Hopper-v3_sac_base_s500_tr50_tn200.pkl
/home/lclan/spinningup/data/trajs/Hopper-v3_sac_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/Hopper-v3_td3_base_400_trajs.pkl
/home/lclan/spinningup/data/trajs/vanilla_ppo_hopper_400_trajs.pkl
/home/lclan/spinningup/data/trajs/atla_ppo_hopper_400_trajs.pkl
vanilla_ppo_hopper
atla_ppo_hopper
Hopper-v3_td3_base
Hopper-v3_sac_base
/home/lclan/spinningup/data/ttn_2
get models  /home/lclan/spinningup/data/ttn_2 Hopper-v3_gsac_sn1_ttn1
/home/lclan/spinningup/data/ttn_2/Hopper-v3_gsac_sn1_ttn1/Hopper-v3_gsac_sn1_ttn1_s1126/pyt_save/model_13.pt
/home/lclan/spinningup/data/ttn_2/Hopper-v3_gsac_sn1_ttn1/Hopper-v3_gsac_sn1_ttn1_s1128/pyt_save/model_13.pt
/home/lclan/spinningup/data/ttn_2/Hopper-v3_gsac_sn1_ttn1/Hopper-v3_gsac_sn1_ttn1_s1127/pyt_save

In [90]:
def print_2f(*args):
    __builtins__.print(*("%.2f" % (a if isinstance(a, float) else a)
                         for a in args), sep='', end = "")
    
def get_models_num(path, name):
    fpath = osp.join(path, name)
    file_names = os.listdir(fpath)
    return len(file_names)


def print_single_result(result, path, test_algo_path, test_algo_name):
    fpath = osp.join(path, test_algo_path)
    num = get_models_num(fpath, test_algo_name)
    
    print(result.keys() , "with", num , "agents")
    for traj_name in result.keys():
        fail = np.array(result[traj_name][1])
        x = [0] * num
        l = []
        for i in range(len(fail)):
            x[i%num] += fail[i]
        for i in range(num):
            l.append(x[i]/(len(fail))*num)
        # print(l)
        print_2f(100*np.mean(l))
        print(" $\pm$ ", sep='', end = "")
        print_2f(100*np.std(l))
        print()
        
# test_algo_path = 'ttn_2'
# print_single_result(ho_r1, path, test_algo_path, 'Hopper-v3_gsac_sn1_ttn1')
# print_single_result(ho_r2, path, test_algo_path, 'Hopper-v3_gsac_sn1_ttn2')
# print_single_result(ho_r4, path, test_algo_path, 'Hopper-v3_gsac_sn1_ttn4')
# print_single_result(ho_r8, path, test_algo_path, 'Hopper-v3_gsac_sn1_ttn8')


        
        

In [71]:
test_algo_path = 'ttn_2'
print_single_result(w_r1, path, test_algo_path, 'Walker2d-v3_gsac_sn1_ttn1')
print_single_result(w_r2, path, test_algo_path, 'Walker2d-v3_gsac_sn1_ttn2')
print_single_result(w_r4, path, test_algo_path, 'Walker2d-v3_gsac_sn1_ttn4')
print_single_result(w_r8, path, test_algo_path, 'Walker2d-v3_gsac_sn1_ttn8')


dict_keys(['Walker2d-v3_td3_base', 'Walker2d-v3_sac_base', 'atla_ppo_walker', 'vanilla_ppo_walker']) with 5 agents
29.74 $\pm$ 7.13
31.23 $\pm$ 3.39
28.93 $\pm$ 12.90
24.65 $\pm$ 11.86
dict_keys(['Walker2d-v3_td3_base', 'Walker2d-v3_sac_base', 'atla_ppo_walker', 'vanilla_ppo_walker']) with 5 agents
9.42 $\pm$ 6.66
14.58 $\pm$ 7.21
8.15 $\pm$ 5.24
6.06 $\pm$ 5.22
dict_keys(['Walker2d-v3_td3_base', 'Walker2d-v3_sac_base', 'atla_ppo_walker', 'vanilla_ppo_walker']) with 5 agents
8.13 $\pm$ 7.00
13.03 $\pm$ 10.25
3.94 $\pm$ 3.76
3.03 $\pm$ 2.78
dict_keys(['Walker2d-v3_td3_base', 'Walker2d-v3_sac_base', 'atla_ppo_walker', 'vanilla_ppo_walker']) with 5 agents
7.48 $\pm$ 1.12
11.35 $\pm$ 3.55
3.19 $\pm$ 0.29
2.39 $\pm$ 0.60


In [75]:
test_algo_path = 'ttn_2'
print_single_result(r1, path, test_algo_path, 'Humanoid-v3_gsac_sn1_ttn1')
print_single_result(r2, path, test_algo_path, 'Humanoid-v3_gsac_sn1_ttn2')
print_single_result(r, path, test_algo_path,  'Humanoid-v3_gsac_sn1_ttn4')
print_single_result(r8, path, test_algo_path, 'Humanoid-v3_gsac_sn1_ttn8')



dict_keys(['Humanoid-v3_sac_base', 'sgld_ppo_humanoid', 'vanilla_ppo_humanoid', 'Humanoid-v3_td3_base']) with 5 agents
37.78 $\pm$ 19.35
81.94 $\pm$ 10.84
46.16 $\pm$ 19.82
25.42 $\pm$ 8.84
dict_keys(['Humanoid-v3_sac_base', 'sgld_ppo_humanoid', 'vanilla_ppo_humanoid', 'Humanoid-v3_td3_base']) with 5 agents
38.78 $\pm$ 19.44
80.58 $\pm$ 10.87
47.67 $\pm$ 11.23
23.29 $\pm$ 10.05
dict_keys(['Humanoid-v3_sac_base', 'sgld_ppo_humanoid', 'vanilla_ppo_humanoid', 'Humanoid-v3_td3_base']) with 5 agents
25.66 $\pm$ 18.69
62.26 $\pm$ 15.20
31.18 $\pm$ 10.70
14.52 $\pm$ 7.19
dict_keys(['Humanoid-v3_sac_base', 'sgld_ppo_humanoid', 'vanilla_ppo_humanoid', 'Humanoid-v3_td3_base']) with 5 agents
17.71 $\pm$ 6.25
68.45 $\pm$ 8.11
26.16 $\pm$ 7.78
10.00 $\pm$ 7.64


In [78]:
test_algo_path = 'gsac_cs'
l = list(cs_results.keys())
l.sort()
for x in l:
    print(x)
    print_single_result(cs_results[x], path, test_algo_path, x)

Hopper-v3_gsac_sn1_cs200
dict_keys(['vanilla_ppo_hopper', 'atla_ppo_hopper', 'Hopper-v3_td3_base', 'Hopper-v3_sac_base']) with 5 agents
12.39 $\pm$ 13.65
20.65 $\pm$ 22.68
10.32 $\pm$ 12.53
14.29 $\pm$ 17.78
Hopper-v3_gsac_sn1_cs50
dict_keys(['vanilla_ppo_hopper', 'atla_ppo_hopper', 'Hopper-v3_td3_base', 'Hopper-v3_sac_base']) with 5 agents
42.26 $\pm$ 29.18
52.35 $\pm$ 23.67
47.23 $\pm$ 21.98
43.13 $\pm$ 30.88
Hopper-v3_gsac_sn1_cs500
dict_keys(['vanilla_ppo_hopper', 'atla_ppo_hopper', 'Hopper-v3_td3_base', 'Hopper-v3_sac_base']) with 5 agents
12.19 $\pm$ 10.19
26.45 $\pm$ 18.91
18.13 $\pm$ 18.86
16.77 $\pm$ 14.38
Humanoid-v3_gsac_sn1_cs200
dict_keys(['Humanoid-v3_sac_base', 'sgld_ppo_humanoid', 'vanilla_ppo_humanoid', 'Humanoid-v3_td3_base']) with 5 agents
15.77 $\pm$ 7.33
58.26 $\pm$ 12.73
26.81 $\pm$ 7.35
16.13 $\pm$ 12.57
Humanoid-v3_gsac_sn1_cs50
dict_keys(['Humanoid-v3_sac_base', 'sgld_ppo_humanoid', 'vanilla_ppo_humanoid', 'Humanoid-v3_td3_base']) with 5 agents
31.04 $\pm$ 34.5

In [76]:
tn_result = {}
tn_result['Walker2d-v3_gsac_sn1_ttn1'] = w_r1
tn_result['Walker2d-v3_gsac_sn1_ttn2'] = w_r2
tn_result['Walker2d-v3_gsac_sn1_ttn4'] = w_r4
tn_result['Walker2d-v3_gsac_sn1_ttn8'] = w_r8
tn_result['Hopper-v3_gsac_sn1_ttn1'] = ho_r1
tn_result['Hopper-v3_gsac_sn1_ttn2'] = ho_r2
tn_result['Hopper-v3_gsac_sn1_ttn4'] = ho_r4
tn_result['Hopper-v3_gsac_sn1_ttn8'] = ho_r8
tn_result['Humanoid-v3_gsac_sn1_ttn1'] = r1
tn_result['Humanoid-v3_gsac_sn1_ttn2'] = r2
tn_result['Humanoid-v3_gsac_sn1_ttn4'] = r
tn_result['Humanoid-v3_gsac_sn1_ttn8'] = r8




In [79]:
test_algo_path = 'ttn_2'
l = list(tn_result.keys())
l.sort()
for x in l:
    print(x)
    print_single_result(tn_result[x], path, test_algo_path, x)

Hopper-v3_gsac_sn1_ttn1
dict_keys(['vanilla_ppo_hopper', 'atla_ppo_hopper', 'Hopper-v3_td3_base', 'Hopper-v3_sac_base']) with 5 agents
29.61 $\pm$ 34.10
35.02 $\pm$ 32.38
32.97 $\pm$ 34.97
35.21 $\pm$ 32.37
Hopper-v3_gsac_sn1_ttn2
dict_keys(['vanilla_ppo_hopper', 'atla_ppo_hopper', 'Hopper-v3_td3_base', 'Hopper-v3_sac_base']) with 5 agents
24.65 $\pm$ 32.96
30.69 $\pm$ 31.55
29.42 $\pm$ 33.77
25.81 $\pm$ 33.67
Hopper-v3_gsac_sn1_ttn4
dict_keys(['vanilla_ppo_hopper', 'atla_ppo_hopper', 'Hopper-v3_td3_base', 'Hopper-v3_sac_base']) with 5 agents
3.29 $\pm$ 2.31
7.65 $\pm$ 5.68
8.90 $\pm$ 9.03
3.32 $\pm$ 3.56
Hopper-v3_gsac_sn1_ttn8
dict_keys(['vanilla_ppo_hopper', 'atla_ppo_hopper', 'Hopper-v3_td3_base', 'Hopper-v3_sac_base']) with 5 agents
6.71 $\pm$ 5.18
14.47 $\pm$ 10.20
8.90 $\pm$ 4.91
7.93 $\pm$ 6.16
Humanoid-v3_gsac_sn1_ttn1
dict_keys(['Humanoid-v3_sac_base', 'sgld_ppo_humanoid', 'vanilla_ppo_humanoid', 'Humanoid-v3_td3_base']) with 5 agents
37.78 $\pm$ 19.35
81.94 $\pm$ 10.84
46.16

In [80]:
test_algo_path = 'gsac_cs'
l = list(cs_results.keys())
l.sort()
for x in l:
    print(x)
    print_single_result(cs_results[x], path, test_algo_path, x)

Hopper-v3_gsac_sn1_cs200
dict_keys(['vanilla_ppo_hopper', 'atla_ppo_hopper', 'Hopper-v3_td3_base', 'Hopper-v3_sac_base']) with 5 agents
12.39 $\pm$ 13.65
20.65 $\pm$ 22.68
10.32 $\pm$ 12.53
14.29 $\pm$ 17.78
Hopper-v3_gsac_sn1_cs50
dict_keys(['vanilla_ppo_hopper', 'atla_ppo_hopper', 'Hopper-v3_td3_base', 'Hopper-v3_sac_base']) with 5 agents
42.26 $\pm$ 29.18
52.35 $\pm$ 23.67
47.23 $\pm$ 21.98
43.13 $\pm$ 30.88
Hopper-v3_gsac_sn1_cs500
dict_keys(['vanilla_ppo_hopper', 'atla_ppo_hopper', 'Hopper-v3_td3_base', 'Hopper-v3_sac_base']) with 5 agents
12.19 $\pm$ 10.19
26.45 $\pm$ 18.91
18.13 $\pm$ 18.86
16.77 $\pm$ 14.38
Humanoid-v3_gsac_sn1_cs200
dict_keys(['Humanoid-v3_sac_base', 'sgld_ppo_humanoid', 'vanilla_ppo_humanoid', 'Humanoid-v3_td3_base']) with 5 agents
15.77 $\pm$ 7.33
58.26 $\pm$ 12.73
26.81 $\pm$ 7.35
16.13 $\pm$ 12.57
Humanoid-v3_gsac_sn1_cs50
dict_keys(['Humanoid-v3_sac_base', 'sgld_ppo_humanoid', 'vanilla_ppo_humanoid', 'Humanoid-v3_td3_base']) with 5 agents
31.04 $\pm$ 34.5

In [9]:
algo_names = {}
algo_names['Humanoid-v3'] = ['Humanoid-v3_sac_base', 'Humanoid-v3_td3_base', 'vanilla_ppo_humanoid',  'sgld_ppo_humanoid']
# algo_names['Ant-v3'] = ['Ant-v3_sac_base' , 'Ant-v3_td3_base', 'vanilla_ppo_ant', 'atla_ppo_ant']
algo_names['Walker2d-v3'] = ['Walker2d-v3_sac_base', 'Walker2d-v3_td3_base', 'vanilla_ppo_walker', 'atla_ppo_walker']
# algo_names['HalfCheetah-v3'] = ['HalfCheetah-v3_sac_base', 'HalfCheetah-v3_td3_base',  'vanilla_ppo_halfcheetah', 'atla_ppo_halfcheetah']
algo_names['Hopper-v3'] = ['Hopper-v3_sac_base', 'Hopper-v3_td3_base', 'vanilla_ppo_hopper',  'atla_ppo_hopper']

test_name_keys = ['cs50', 'cs100', 'cs200', 'cs500']

fname = './data/gsac_cs_2/results.txt'
f = open(fname, 'r')
s = f.read()
# print(s)
x = s.split()
i = 0
data = {}
while i < len(x):
    if x[i+1] not in data.keys():
        data[x[i+1]] = {}
    data[x[i+1]][x[i]] = x[i+2] + ' ' + x[i+3] + ' ' + x[i+4]
    # print(x[i], x[i+1], data[x[i]][x[i+1]])
    i += 5

    

In [10]:
test_name_keys = ['cs50', 'cs100', 'cs200', 'cs500']
for env_name in algo_names.keys():
    print(env_name)
    for algo_name in algo_names[env_name]:
        # print(data[algo_name].keys())
        for test_name_key in test_name_keys:
            name = ''
            for test_algo_name in data[algo_name].keys():
                if test_name_key in test_algo_name:
                    if (test_name_key + '_') in test_algo_name or \
                        test_name_key == test_algo_name[-len(test_name_key):] :
                        name = test_algo_name
                        break
            # print(algo_name, name, test_name_key)
            print(data[algo_name][name], end=" & ")
        print('')
                     

Humanoid-v3
30.48 $\pm$ 34.95 & 13.23 $\pm$ 6.84 & 14.34 $\pm$ 6.00 & 14.76 $\pm$ 5.08 & 
25.76 $\pm$ 37.30 & 7.81 $\pm$ 4.76 & 10.67 $\pm$ 3.74 & 11.76 $\pm$ 6.61 & 
40.63 $\pm$ 30.09 & 25.08 $\pm$ 3.36 & 25.03 $\pm$ 5.04 & 28.36 $\pm$ 7.31 & 
69.76 $\pm$ 15.71 & 58.10 $\pm$ 8.32 & 61.81 $\pm$ 12.27 & 65.43 $\pm$ 10.88 & 
Walker2d-v3
6.10 $\pm$ 2.34 & 5.38 $\pm$ 3.58 & 4.48 $\pm$ 3.15 & 6.52 $\pm$ 2.87 & 
3.43 $\pm$ 1.52 & 3.71 $\pm$ 2.48 & 3.14 $\pm$ 1.80 & 3.48 $\pm$ 2.20 & 
1.95 $\pm$ 1.35 & 1.19 $\pm$ 0.86 & 1.81 $\pm$ 1.06 & 2.81 $\pm$ 1.97 & 
3.36 $\pm$ 1.94 & 2.33 $\pm$ 0.99 & 2.88 $\pm$ 1.24 & 3.78 $\pm$ 2.24 & 
Hopper-v3
4.35 $\pm$ 4.73 & 2.59 $\pm$ 2.41 & 7.96 $\pm$ 5.68 & 7.82 $\pm$ 6.54 & 
3.76 $\pm$ 5.74 & 3.05 $\pm$ 5.03 & 5.19 $\pm$ 4.37 & 6.52 $\pm$ 9.11 & 
3.14 $\pm$ 4.57 & 2.62 $\pm$ 2.95 & 5.86 $\pm$ 5.15 & 3.71 $\pm$ 6.79 & 
3.67 $\pm$ 3.31 & 4.69 $\pm$ 4.01 & 6.05 $\pm$ 3.88 & 5.85 $\pm$ 7.50 & 


In [86]:
test_algo_name = 'Hopper-v3_gsac_sn1_cs200'
key = 'cs200'
print(test_algo_name[-len(key):])

cs200


In [91]:
fname = './data/test_algo/results.txt'
test_name_keys = ['owr13', 'owr16']

def print_result(fname, test_name_keys):
    f = open(fname, 'r')
    s = f.read()
    # print(s)
    x = s.split()
    i = 0
    data = {}
    while i < len(x):
        if x[i+1] not in data.keys():
            data[x[i+1]] = {}
        data[x[i+1]][x[i]] = x[i+2] + ' ' + x[i+3] + ' ' + x[i+4]
        # print(x[i], x[i+1], data[x[i]][x[i+1]])
        i += 5
    for env_name in algo_names.keys():
        print(env_name)
        for algo_name in algo_names[env_name]:
            # print(data[algo_name].keys())
            for test_name_key in test_name_keys:
                name = ''
                for test_algo_name in data[algo_name].keys():
                    if test_name_key in test_algo_name:
                        if (test_name_key + '_') in test_algo_name or \
                            test_name_key == test_algo_name[-len(test_name_key):] :
                            name = test_algo_name
                            break
                # print(algo_name, name, test_name_key)
                print(data[algo_name][name], end=" & ")
            print('')

print_result(fname, test_name_keys)

Humanoid-v3
4.23 $\pm$ 3.46 & 11.22 $\pm$ 6.79 & 
3.81 $\pm$ 2.79 & 5.90 $\pm$ 6.95 & 
9.84 $\pm$ 4.27 & 17.88 $\pm$ 8.88 & 
37.14 $\pm$ 10.45 & 54.38 $\pm$ 16.61 & 
Ant-v3


KeyError: 'Ant-v3_sac_base'

In [12]:
fname = './data/test_algo/results1.txt'

print_result(fname, test_name_keys)

Humanoid-v3
1.80 $\pm$ 1.32 & 11.01 $\pm$ 7.03 & 
0.86 $\pm$ 0.76 & 5.43 $\pm$ 7.23 & 
9.21 $\pm$ 3.25 & 14.50 $\pm$ 6.63 & 
33.33 $\pm$ 8.26 & 49.90 $\pm$ 14.76 & 
Walker2d-v3
16.00 $\pm$ 23.81 & 8.86 $\pm$ 4.33 & 
13.81 $\pm$ 21.23 & 8.76 $\pm$ 5.49 & 
5.81 $\pm$ 6.22 & 1.62 $\pm$ 1.23 & 
6.92 $\pm$ 5.84 & 2.21 $\pm$ 1.72 & 
Hopper-v3
14.15 $\pm$ 9.19 & 14.69 $\pm$ 7.76 & 
25.90 $\pm$ 22.86 & 21.05 $\pm$ 14.47 & 
13.81 $\pm$ 13.19 & 12.19 $\pm$ 3.40 & 
22.31 $\pm$ 19.77 & 17.96 $\pm$ 11.44 & 


In [13]:
fname = './data/ttn_3/results.txt'
test_name_keys = ['ttn1', 'ttn2', 'ttn4', 'ttn8']

print_result(fname, test_name_keys)

Humanoid-v3
55.71 $\pm$ 16.94 & 41.53 $\pm$ 19.72 & 24.02 $\pm$ 14.71 & 11.43 $\pm$ 7.21 & 
40.00 $\pm$ 13.94 & 22.95 $\pm$ 11.57 & 15.00 $\pm$ 7.96 & 7.29 $\pm$ 2.79 & 
59.31 $\pm$ 13.96 & 41.53 $\pm$ 12.53 & 30.69 $\pm$ 12.08 & 24.81 $\pm$ 3.89 & 
90.86 $\pm$ 4.31 & 78.81 $\pm$ 5.83 & 71.95 $\pm$ 11.84 & 59.90 $\pm$ 9.60 & 
Walker2d-v3
28.62 $\pm$ 6.53 & 20.00 $\pm$ 14.68 & 8.95 $\pm$ 5.99 & 8.14 $\pm$ 5.89 & 
21.05 $\pm$ 6.46 & 17.29 $\pm$ 11.37 & 7.71 $\pm$ 4.31 & 4.71 $\pm$ 3.64 & 
14.00 $\pm$ 5.83 & 10.19 $\pm$ 10.78 & 3.48 $\pm$ 1.66 & 3.05 $\pm$ 2.14 & 
19.32 $\pm$ 4.70 & 12.08 $\pm$ 9.00 & 5.46 $\pm$ 3.59 & 3.83 $\pm$ 2.24 & 
Hopper-v3
19.80 $\pm$ 15.34 & 10.20 $\pm$ 6.23 & 8.64 $\pm$ 5.50 & 4.01 $\pm$ 3.76 & 
19.14 $\pm$ 19.51 & 5.10 $\pm$ 4.26 & 11.57 $\pm$ 7.75 & 5.00 $\pm$ 4.51 & 
17.38 $\pm$ 15.77 & 8.57 $\pm$ 7.30 & 6.38 $\pm$ 4.13 & 2.81 $\pm$ 2.93 & 
22.65 $\pm$ 22.65 & 5.31 $\pm$ 3.47 & 9.66 $\pm$ 7.71 & 6.46 $\pm$ 5.77 & 


In [92]:
def is_good_key(key,  remove_keys, must_keys):
    for k in remove_keys:
        if k in key:
            return False
    for k in must_keys:
        if k not in key:
            return False
    return True
def print_result(fname, test_name_keys, remove_keys=[], must_keys=[]):
    f = open(fname, 'r')
    s = f.read()
    # print(s)
    x = s.split()
    i = 0
    data = {}
    while i < len(x):
        if x[i+1] not in data.keys():
            data[x[i+1]] = {}
        data[x[i+1]][x[i]] = x[i+2] + ' ' + x[i+3] + ' ' + x[i+4]
        # print(x[i], x[i+1], data[x[i]][x[i+1]])
        i += 5
    for env_name in algo_names.keys():
        print(env_name)
        for algo_name in algo_names[env_name]:
            # print(data[algo_name].keys())
            for test_name_key in test_name_keys:
                name = ''
                for test_algo_name in data[algo_name].keys():
                    if test_name_key in test_algo_name \
                        and is_good_key(test_algo_name, remove_keys, must_keys):
                        if (test_name_key + '_') in test_algo_name or \
                            test_name_key == test_algo_name[-len(test_name_key):] :
                            name = test_algo_name
                            break
                # print(algo_name, name, test_name_key)
                if name != '':
                    print(data[algo_name][name], end=" & ")
                else:
                    print(test_name_key, "cant find")
            print('')


In [18]:
fname = './data/gsac3_cs/results.txt'
remove_keys = ['cs', 'ssr']
must_keys = ['tttr']
test_name_keys = ['sn2', 'sn5']
print_result(fname, test_name_keys, remove_keys, must_keys)



Humanoid-v3
6.56 $\pm$ 5.03 & 4.34 $\pm$ 4.26 & 
3.62 $\pm$ 3.71 & 3.76 $\pm$ 4.05 & 
10.16 $\pm$ 6.39 & 9.63 $\pm$ 5.19 & 
36.19 $\pm$ 16.04 & 18.67 $\pm$ 7.79 & 
Walker2d-v3
14.48 $\pm$ 8.89 & 13.71 $\pm$ 7.77 & 
13.38 $\pm$ 7.67 & 10.43 $\pm$ 7.27 & 
9.00 $\pm$ 5.64 & 7.19 $\pm$ 4.37 & 
10.33 $\pm$ 7.05 & 9.97 $\pm$ 6.15 & 
Hopper-v3
16.39 $\pm$ 14.11 & 13.74 $\pm$ 6.18 & 
18.57 $\pm$ 20.70 & 24.10 $\pm$ 19.87 & 
14.14 $\pm$ 16.11 & 10.57 $\pm$ 7.27 & 
15.17 $\pm$ 18.93 & 19.39 $\pm$ 12.61 & 


In [19]:
fname = './data/gsac3_cs/results.txt'
remove_keys = ['cs', 'ssr', 'tttr']
must_keys = ['sn5']
test_name_keys = ['owr13', 'owr16']
print_result(fname, test_name_keys, remove_keys, must_keys)

Humanoid-v3
4.23 $\pm$ 3.46 & 11.22 $\pm$ 6.79 & 
3.81 $\pm$ 2.79 & 5.90 $\pm$ 6.95 & 
9.84 $\pm$ 4.27 & 17.88 $\pm$ 8.88 & 
37.14 $\pm$ 10.45 & 54.38 $\pm$ 16.61 & 
Walker2d-v3
21.33 $\pm$ 21.82 & 8.90 $\pm$ 5.52 & 
18.86 $\pm$ 19.05 & 7.76 $\pm$ 4.88 & 
10.76 $\pm$ 5.31 & 3.52 $\pm$ 3.60 & 
12.58 $\pm$ 5.35 & 4.99 $\pm$ 4.32 & 
Hopper-v3
14.15 $\pm$ 9.19 & 13.13 $\pm$ 8.51 & 
25.90 $\pm$ 22.86 & 21.38 $\pm$ 18.95 & 
13.81 $\pm$ 13.19 & 14.19 $\pm$ 10.94 & 
22.31 $\pm$ 19.77 & 19.32 $\pm$ 17.55 & 


In [21]:
fname = './data/gsac3_cs/results.txt'
remove_keys = ['ssr', 'tttr']
must_keys = ['owr11', 'sn5']
test_name_keys = ['cs50', 'cs200']
print_result(fname, test_name_keys, remove_keys, must_keys)

Humanoid-v3
14.55 $\pm$ 15.75 & 36.61 $\pm$ 13.11 & 
10.90 $\pm$ 10.00 & 23.05 $\pm$ 12.22 & 
21.59 $\pm$ 10.09 & 41.75 $\pm$ 10.92 & 
26.38 $\pm$ 11.65 & 85.57 $\pm$ 6.22 & 
Walker2d-v3
22.14 $\pm$ 7.76 & 14.43 $\pm$ 7.90 & 
19.62 $\pm$ 6.42 & 11.71 $\pm$ 7.40 & 
14.62 $\pm$ 10.16 & 6.38 $\pm$ 3.62 & 
16.52 $\pm$ 8.07 & 8.92 $\pm$ 5.00 & 
Hopper-v3
13.20 $\pm$ 8.87 & 17.41 $\pm$ 6.88 & 
12.00 $\pm$ 11.87 & 24.81 $\pm$ 13.87 & 
8.67 $\pm$ 8.49 & 17.33 $\pm$ 9.80 & 
12.72 $\pm$ 13.26 & 21.70 $\pm$ 15.65 & 


In [23]:
fname = './data/gsac4/results.txt'
remove_keys = ['ssr', 'tttr']
must_keys = ['owr10']
test_name_keys = ['sn1', 'sn2', 'sn5']
print_result(fname, test_name_keys, remove_keys, must_keys)

Humanoid-v3
24.87 $\pm$ 11.39 & 18.52 $\pm$ 11.32 & 4.59 $\pm$ 2.28 & 
18.89 $\pm$ 7.66 & 14.68 $\pm$ 21.23 & 2.46 $\pm$ 2.32 & 
32.10 $\pm$ 6.46 & 20.37 $\pm$ 9.73 & 14.81 $\pm$ 5.39 & 
71.59 $\pm$ 7.66 & 55.16 $\pm$ 11.88 & 36.27 $\pm$ 4.82 & 
Walker2d-v3
28.02 $\pm$ 15.44 & 22.62 $\pm$ 10.94 & 16.19 $\pm$ 7.44 & 
25.16 $\pm$ 14.16 & 15.87 $\pm$ 6.45 & 12.94 $\pm$ 3.75 & 
14.05 $\pm$ 7.01 & 6.83 $\pm$ 4.39 & 9.13 $\pm$ 4.33 & 
18.30 $\pm$ 8.09 & 10.48 $\pm$ 7.74 & 9.82 $\pm$ 5.20 & 
Hopper-v3
16.12 $\pm$ 7.26 & 13.20 $\pm$ 6.51 & 18.71 $\pm$ 14.69 & 
13.95 $\pm$ 13.11 & 11.00 $\pm$ 10.04 & 10.62 $\pm$ 11.84 & 
12.38 $\pm$ 7.73 & 12.43 $\pm$ 9.17 & 15.57 $\pm$ 12.36 & 
12.59 $\pm$ 9.38 & 12.93 $\pm$ 10.66 & 15.17 $\pm$ 14.55 & 


In [24]:
fname = './data/gsac4/results.txt'
remove_keys = ['ssr', 'tttr']
must_keys = ['sn5']
test_name_keys = ['owr10', 'owr11', 'owr13','owr16']
print_result(fname, test_name_keys, remove_keys, must_keys)

Humanoid-v3
4.59 $\pm$ 2.28 & 7.94 $\pm$ 8.21 & 9.44 $\pm$ 5.07 & 4.76 $\pm$ 2.80 & 
2.46 $\pm$ 2.32 & 2.94 $\pm$ 3.44 & 2.14 $\pm$ 2.38 & 0.79 $\pm$ 0.35 & 
14.81 $\pm$ 5.39 & 15.70 $\pm$ 11.69 & 15.52 $\pm$ 12.12 & 13.14 $\pm$ 3.06 & 
36.27 $\pm$ 4.82 & 38.10 $\pm$ 10.74 & 44.37 $\pm$ 18.71 & 41.51 $\pm$ 5.52 & 
Walker2d-v3
16.19 $\pm$ 7.44 & 10.48 $\pm$ 5.53 & 10.00 $\pm$ 7.17 & 6.43 $\pm$ 5.26 & 
12.94 $\pm$ 3.75 & 8.17 $\pm$ 3.63 & 8.57 $\pm$ 6.77 & 5.32 $\pm$ 3.87 & 
9.13 $\pm$ 4.33 & 8.02 $\pm$ 3.97 & 1.51 $\pm$ 1.18 & 1.35 $\pm$ 1.66 & 
9.82 $\pm$ 5.20 & 9.19 $\pm$ 5.51 & 2.72 $\pm$ 2.65 & 2.67 $\pm$ 3.51 & 
Hopper-v3
18.71 $\pm$ 14.69 & 15.44 $\pm$ 10.06 & 22.59 $\pm$ 9.99 & 20.27 $\pm$ 14.35 & 
10.62 $\pm$ 11.84 & 14.14 $\pm$ 14.84 & 18.95 $\pm$ 10.64 & 15.33 $\pm$ 9.58 & 
15.57 $\pm$ 12.36 & 9.86 $\pm$ 7.62 & 18.48 $\pm$ 7.55 & 16.19 $\pm$ 13.00 & 
15.17 $\pm$ 14.55 & 12.31 $\pm$ 11.17 & 20.34 $\pm$ 10.04 & 16.60 $\pm$ 14.03 & 


In [30]:
fname = './data/gsac4_tttr75/results.txt'
remove_keys = ['ssr' ]
must_keys = ['owr10', 'tttr']
test_name_keys = ['sn1', 'sn2', 'sn5']
print_result(fname, test_name_keys, remove_keys, must_keys)

Humanoid-v3
19.74 $\pm$ 18.14 & 8.52 $\pm$ 5.74 & 5.77 $\pm$ 3.06 & 
9.29 $\pm$ 7.94 & 3.48 $\pm$ 2.89 & 3.76 $\pm$ 4.49 & 
26.19 $\pm$ 10.61 & 15.77 $\pm$ 5.25 & 9.95 $\pm$ 5.38 & 
64.24 $\pm$ 8.57 & 48.52 $\pm$ 7.44 & 25.24 $\pm$ 6.47 & 
Walker2d-v3
20.10 $\pm$ 12.22 & 16.67 $\pm$ 9.47 & 9.62 $\pm$ 3.72 & 
16.38 $\pm$ 10.38 & 13.14 $\pm$ 6.16 & 10.05 $\pm$ 5.81 & 
12.10 $\pm$ 9.35 & 8.95 $\pm$ 6.09 & 5.90 $\pm$ 2.64 & 
15.89 $\pm$ 9.78 & 10.40 $\pm$ 7.28 & 7.67 $\pm$ 3.90 & 
Hopper-v3
13.61 $\pm$ 10.21 & 19.80 $\pm$ 16.03 & 14.15 $\pm$ 8.12 & 
22.14 $\pm$ 21.08 & 20.90 $\pm$ 17.33 & 17.67 $\pm$ 12.32 & 
12.67 $\pm$ 14.38 & 18.19 $\pm$ 18.12 & 12.81 $\pm$ 12.57 & 
24.29 $\pm$ 20.53 & 23.54 $\pm$ 21.86 & 21.22 $\pm$ 16.07 & 


In [31]:
fname = './data/gsac4_tttr75/results.txt'
remove_keys = ['ssr' ]
must_keys = ['sn5', 'tttr']
test_name_keys = ['owr10', 'owr11', 'owr13','owr16']
print_result(fname, test_name_keys, remove_keys, must_keys)

Humanoid-v3
5.77 $\pm$ 3.06 & 4.44 $\pm$ 5.34 & 8.36 $\pm$ 5.62 & 9.10 $\pm$ 5.55 & 
3.76 $\pm$ 4.49 & 2.24 $\pm$ 2.67 & 6.05 $\pm$ 4.30 & 3.62 $\pm$ 3.56 & 
9.95 $\pm$ 5.38 & 12.06 $\pm$ 6.33 & 15.82 $\pm$ 5.20 & 12.43 $\pm$ 5.61 & 
25.24 $\pm$ 6.47 & 33.29 $\pm$ 15.26 & 48.52 $\pm$ 11.47 & 43.62 $\pm$ 6.97 & 
Walker2d-v3
9.62 $\pm$ 3.72 & 11.57 $\pm$ 6.92 & 6.52 $\pm$ 3.92 & 7.33 $\pm$ 4.62 & 
10.05 $\pm$ 5.81 & 11.57 $\pm$ 7.70 & 6.38 $\pm$ 5.76 & 5.48 $\pm$ 3.35 & 
5.90 $\pm$ 2.64 & 7.00 $\pm$ 4.10 & 5.24 $\pm$ 4.69 & 2.57 $\pm$ 3.25 & 
7.67 $\pm$ 3.90 & 8.17 $\pm$ 5.16 & 7.49 $\pm$ 5.27 & 2.83 $\pm$ 2.04 & 
Hopper-v3
14.15 $\pm$ 8.12 & 23.67 $\pm$ 15.10 & 23.20 $\pm$ 10.63 & 9.80 $\pm$ 4.51 & 
17.67 $\pm$ 12.32 & 18.38 $\pm$ 19.22 & 28.52 $\pm$ 19.56 & 13.81 $\pm$ 8.20 & 
12.81 $\pm$ 12.57 & 20.24 $\pm$ 16.39 & 21.24 $\pm$ 12.68 & 8.52 $\pm$ 6.04 & 
21.22 $\pm$ 16.07 & 21.09 $\pm$ 17.25 & 28.98 $\pm$ 18.76 & 10.27 $\pm$ 9.48 & 


In [83]:
fname = './data/gsac4_tttr75/results.txt'
remove_keys = ['ssr' ]
must_keys = ['sn5', 'tttr']
test_name_keys = ['owr10', 'owr11', 'owr13','owr16']
print_result(fname, test_name_keys, remove_keys, must_keys)

AttributeError: 'list' object has no attribute 'keys'

In [32]:

algo_names = ['Humanoid-v3_sac_base', 'Walker2d-v3_sac_base', 'Hopper-v3_sac_base',
              'Humanoid-v3_gsac3_tttr75_sn5_sr3_owr10',
              'Walker2d-v3_gsac3_tttr75_sn5_sr3_owr16',
              'Hopper-v3_gsac3_tttr75_sn5_sr3_owr16',
              'Humanoid-v3_gsac_sn1_cs100',
              'Walker2d-v3_gsac_sn1_cs100',
              'Hopper-v3_gsac_sn1_cs100'
              ]
path = './data/'
algo_paths = ['./','./', './', './gsac4_tttr75/', './gsac4_tttr75/', './gsac4_tttr75/', './gsac_cs_2/', './gsac_cs_2/','./gsac_cs_2/']



In [56]:
def get_results(id, algo_names, algo_paths, path):
    fpath = osp.join(path, algo_paths[id])
    print(fpath)
    x = os.listdir(fpath)
    ret = {}
    for name in x:
        if 'play_on' in name and algo_names[id] in name:
            fname = osp.join(fpath, name)
            print(fname)
            f = open(fname, 'rb')
            r = pickle.load(f)
            for key in r.keys():
                for k2 in r[key].keys():
                    ret[k2] = r[key][k2]
    return ret

x = get_results(3, algo_names, algo_paths, path)

./data/./gsac4_tttr75/
./data/./gsac4_tttr75/Humanoid-v3_gsac3_tttr75_sn5_sr3_owr10_play_on_Humanoid-v3_td3_base.pkl
./data/./gsac4_tttr75/Humanoid-v3_gsac3_tttr75_sn5_sr3_owr10_play_on_Humanoid-v3_sac_base.pkl
./data/./gsac4_tttr75/Humanoid-v3_gsac3_tttr75_sn5_sr3_owr10_play_on_vanilla_ppo_humanoid.pkl
./data/./gsac4_tttr75/Humanoid-v3_gsac3_tttr75_sn5_sr3_owr10_play_on_sgld_ppo_humanoid.pkl


In [57]:
x.keys()

dict_keys(['Humanoid-v3_td3_base', 'Humanoid-v3_sac_base', 'vanilla_ppo_humanoid', 'sgld_ppo_humanoid'])

In [65]:
def print_2f(*args):
    __builtins__.print(*("%.2f" % (a if isinstance(a, float) else a)
                         for a in args), sep='', end = "")
def print_rets(x):
    
    
    for k in x:
        rets = np.array(x[k][0])
        print(k, end="")
        print_2f(np.mean(rets))
        print(" $\pm$ " , end='')
        print_2f(np.std(rets))
        print()

print_rets(x)
        

Humanoid-v3_td3_base2862.20 $\pm$ 504.93
Humanoid-v3_sac_base2815.20 $\pm$ 622.45
vanilla_ppo_humanoid2692.23 $\pm$ 793.35
sgld_ppo_humanoid2312.02 $\pm$ 1198.09


In [69]:
gen_algo_names = {}
gen_algo_names['Humanoid-v3'] = ['Humanoid-v3_sac_base', 'Humanoid-v3_td3_base', 'vanilla_ppo_humanoid',  'sgld_ppo_humanoid']
# algo_names['Ant-v3'] = ['Ant-v3_sac_base' , 'Ant-v3_td3_base', 'vanilla_ppo_ant', 'atla_ppo_ant']
gen_algo_names['Walker2d-v3'] = ['Walker2d-v3_sac_base', 'Walker2d-v3_td3_base', 'vanilla_ppo_walker', 'atla_ppo_walker']
# algo_names['HalfCheetah-v3'] = ['HalfCheetah-v3_sac_base', 'HalfCheetah-v3_td3_base',  'vanilla_ppo_halfcheetah', 'atla_ppo_halfcheetah']
gen_algo_names['Hopper-v3'] = ['Hopper-v3_sac_base', 'Hopper-v3_td3_base', 'vanilla_ppo_hopper',  'atla_ppo_hopper']


In [75]:

algo_names = ['Humanoid-v3_sac_base', 'Walker2d-v3_sac_base', 'Hopper-v3_sac_base',
              'Humanoid-v3_gsac_sn1_cs100',
              'Walker2d-v3_gsac_sn1_cs100',
              'Hopper-v3_gsac_sn1_cs100',
              'Humanoid-v3_gsac3_tttr75_sn5_sr3_owr10',
              'Walker2d-v3_gsac3_tttr75_sn5_sr3_owr16',
              'Hopper-v3_gsac3_tttr75_sn5_sr3_owr16',
              ]
path = './data/'
algo_paths = ['./sac_base/','./sac_base/', './sac_base/', './gsac_cs_2/', './gsac_cs_2/','./gsac_cs_2/', './gsac4_tttr75/', './gsac4_tttr75/', './gsac4_tttr75/', ]


def get_results(id, algo_names, algo_paths, path):
    fpath = osp.join(path, algo_paths[id])
    print(fpath)
    x = os.listdir(fpath)
    ret = {}
    for name in x:
        if 'play_on' in name and algo_names[id] in name:
            fname = osp.join(fpath, name)
            # print(fname)
            f = open(fname, 'rb')
            r = pickle.load(f)
            for key in r.keys():
                for k2 in r[key].keys():
                    ret[k2] = r[key][k2]
    return ret

all_resutls = {}


for k in gen_algo_names.keys():
    for name in gen_algo_names[k]:
        all_resutls[name] = []

for id in range(9):
    x =  get_results(id, algo_names, algo_paths, path)
    for k in x.keys():
        rets = np.array(x[k][0])
        ff = np.array(x[k][1])
        all_resutls[k].append((k, np.mean(rets), np.std(rets),np.mean(ff),np.std(ff)))
        



        
        

./data/./sac_base/
./data/./sac_base/
./data/./sac_base/
./data/./gsac_cs_2/
./data/./gsac_cs_2/
./data/./gsac_cs_2/
./data/./gsac4_tttr75/
./data/./gsac4_tttr75/
./data/./gsac4_tttr75/


In [82]:
def print_tuple(t):
    print(int(t[1]), end='')
    print(" $\pm$ " , end='')
    print(int(t[2]), end='')
    print(" & " , end='')
    print_2f(t[3]*100)
    print(" $\pm$ " , end='')
    print_2f(t[4]*100)
    
for k in all_resutls.keys():
    # print(k, end=' ')
    for t in all_resutls[k]:
        print_tuple(t)
        print(" & " , end='')
    print()

    

2004 $\pm$ 1193 & 32.94 $\pm$ 47.00 & 2591 $\pm$ 841 & 13.23 $\pm$ 33.88 & 2815 $\pm$ 622 & 5.77 $\pm$ 23.31 & 
2242 $\pm$ 1071 & 24.21 $\pm$ 42.83 & 2731 $\pm$ 621 & 7.81 $\pm$ 26.83 & 2862 $\pm$ 504 & 3.76 $\pm$ 19.03 & 
1675 $\pm$ 1266 & 45.46 $\pm$ 49.79 & 2266 $\pm$ 1120 & 25.08 $\pm$ 43.35 & 2692 $\pm$ 793 & 9.95 $\pm$ 29.93 & 
671 $\pm$ 996 & 83.06 $\pm$ 37.51 & 1382 $\pm$ 1343 & 58.10 $\pm$ 49.34 & 2312 $\pm$ 1198 & 25.24 $\pm$ 43.44 & 
2286 $\pm$ 1138 & 25.20 $\pm$ 43.42 & 2931 $\pm$ 658 & 5.38 $\pm$ 22.56 & 2784 $\pm$ 725 & 7.33 $\pm$ 26.07 & 
2326 $\pm$ 1078 & 23.25 $\pm$ 42.25 & 2965 $\pm$ 537 & 3.71 $\pm$ 18.91 & 2822 $\pm$ 632 & 5.48 $\pm$ 22.75 & 
2464 $\pm$ 884 & 13.89 $\pm$ 34.58 & 2931 $\pm$ 336 & 1.19 $\pm$ 10.85 & 2798 $\pm$ 473 & 2.57 $\pm$ 15.83 & 
2314 $\pm$ 1007 & 18.73 $\pm$ 39.02 & 2888 $\pm$ 449 & 2.33 $\pm$ 15.09 & 2782 $\pm$ 488 & 2.83 $\pm$ 16.59 & 
1777 $\pm$ 390 & 14.74 $\pm$ 35.45 & 1846 $\pm$ 148 & 2.59 $\pm$ 15.87 & 1818 $\pm$ 323 & 9.80 $\pm$ 29.73 &