In [1]:
from ast import mod
from telnetlib import DM
from turtle import mode
from unicodedata import name
import gym
from copy import deepcopy
import os
import os.path as osp
import torch
from scipy import stats
from statistics import mean 
import numpy as np
from torch.optim import Adam
import itertools
import random
import torch.nn as nn
import argparse
import pickle

def get_env_name(name):
    """Translate a run or file name into the Gym environment id it refers to.

    Matching is a case-sensitive substring test against a lowercase and a
    capitalised spelling of each robot. The table is scanned in order and the
    first hit wins; 'unknown' is returned when nothing matches.
    """
    keyword_table = (
        (('humanoid', 'Humanoid'), 'Humanoid-v3'),
        (('halfcheetah', 'HalfCheetah'), 'HalfCheetah-v3'),
        (('ant', 'Ant'), 'Ant-v3'),
        (('hopper', 'Hopper'), 'Hopper-v3'),
        (('walker', 'Walker'), 'Walker2d-v3'),
    )
    for keywords, env_id in keyword_table:
        if any(keyword in name for keyword in keywords):
            return env_id
    return 'unknown'

def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an `nn.Sequential`.

    Args:
        sizes: layer widths, input first and output last.
        activation: module class instantiated after every hidden linear layer.
        output_activation: module class instantiated after the final linear layer.

    Returns:
        `nn.Sequential` alternating `nn.Linear` and activation modules
        (empty when `sizes` has fewer than two entries).
    """
    final_idx = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if idx >= final_idx else activation
        modules.append(nn.Linear(n_in, n_out))
        modules.append(act_cls())
    return nn.Sequential(*modules)

class PPO_Actor():
    """Deterministic PPO policy: observation normalisation followed by an MLP.

    The checkpoint dictionary consumed by `copy_model`/`load` must provide the
    keys 'pi' (state dict of the MLP), 'obs_mean', 'obs_std' and 'clip'.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        """Create an untrained policy with identity observation normalisation.

        Args:
            obs_dim: size of the observation vector.
            act_dim: size of the action vector.
            hidden_sizes: iterable of hidden layer widths.
            activation: module class used between hidden layers.
        """
        self.pi = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)
        # Zero mean / unit std leaves observations unchanged (up to clipping)
        # until real statistics are loaded from a checkpoint.
        self.obs_mean = np.zeros(obs_dim)
        self.obs_std = np.ones(obs_dim)
        self.clip = 10.0

    def normalize_o(self, o):
        """Standardise `o` with the stored statistics and clip to +/- `self.clip`."""
        o = o - self.obs_mean
        # The epsilon guards against division by zero for constant features.
        o = o / (self.obs_std + 1e-8)
        o = np.clip(o, -self.clip, self.clip)
        return o

    def act(self, o):
        """Return the policy output for observation `o` as a NumPy array."""
        o = self.normalize_o(o)
        o = torch.as_tensor(o, dtype=torch.float32)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            return self.pi(o).numpy()

    def copy_model(self, md):
        """Copy weights and normalisation settings from checkpoint dict `md`."""
        self.pi.load_state_dict(md['pi'])
        self.obs_mean = md['obs_mean']
        self.obs_std = md['obs_std']
        self.clip = md['clip']

    def load(self, name):
        """Load a checkpoint file written with `torch.save` and apply it.

        Tensors are mapped to the CPU because `act` converts the network
        output with `.numpy()`, which requires CPU tensors.
        """
        md = torch.load(name, map_location='cpu')
        self.copy_model(md)


def get_ppo_models(path, name):
    """Load every PPO checkpoint found in the directory `<path>/<name>`.

    Args:
        path: root data directory.
        name: sub-directory name; also used to pick the Gym environment via
            `get_env_name`.

    Returns:
        List of `(name, file_name, PPO_Actor)` tuples, one per directory entry
        whose file name contains ".pt". Empty when the directory is empty.
    """
    fpath = osp.join(path, name)
    file_names = os.listdir(fpath)
    if len(file_names) == 0:
        return []
    # The environment is only needed to read the observation/action sizes,
    # so release it as soon as they are known instead of leaking it.
    env = gym.make(get_env_name(name))
    try:
        obs_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
    finally:
        env.close()
    models = []
    for file_name in file_names:
        if ".pt" not in file_name:
            continue
        fname = osp.join(fpath, file_name)
        print(file_name)
        # Architecture is fixed to two hidden layers of 64 Tanh units.
        model = PPO_Actor(obs_dim, action_dim, (64, 64), nn.Tanh)
        model.load(fname)
        models.append((name, file_name, model))
    return models

def get_models(path, name):
    """Load all saved models of one algorithm/run group.

    Names containing 'ppo' are delegated to `get_ppo_models`. Otherwise every
    entry of `<path>/<name>` is treated as a run directory containing
    `pyt_save/model.pt`, which is read with `torch.load`.

    Returns:
        List of `(name, run_directory_name, model)` tuples.
    """
    if 'ppo' in name:
        return get_ppo_models(path, name)
    run_root = osp.join(path, name)
    loaded = []
    for run_dir in os.listdir(run_root):
        checkpoint = osp.join(run_root, run_dir, 'pyt_save', 'model.pt')
        print(checkpoint)
        loaded.append((name, run_dir, torch.load(checkpoint)))
    return loaded

def save_state(env):
    """Return whatever `env.sim.get_state()` reports for the current simulator state."""
    simulator = env.sim
    return simulator.get_state()

def restore_state(env, old_state):
    """Reset `env`, overwrite its simulator state with `old_state`, and return the observation.

    The simulator's `forward()` is called after `set_state` and before the
    observation is read through `env.get_obs()`.
    """
    env.reset()
    simulator = env.sim
    simulator.set_state(old_state)
    simulator.forward()
    observation = env.get_obs()
    return observation

def get_ppo_action(o, md):
    """Ask the PPO model `md` for an action given the raw observation `o`."""
    action = md.act(o)
    return action

def get_action(o, md, name):
    """Query model `md` for an action on observation `o`.

    Models whose `name` contains 'ppo' receive the raw observation through
    `get_ppo_action`; all others receive a float32 tensor.
    """
    if 'ppo' not in name:
        obs_tensor = torch.as_tensor(o, dtype=torch.float32)
        return md.act(obs_tensor)
    return get_ppo_action(o, md)

def print_rets(rets):
    """Print mean, max, min and standard deviation of `rets` and return the mean."""
    values = np.array(rets)
    average = np.mean(values)
    print("mean, max, min, std", average, np.max(values), np.min(values), np.std(values))
    return average



In [2]:
# Root directory that holds the saved runs (absolute, machine-specific path).
path = '/home/lclan/spinningup/data/'
# Sub-directory of `path` that holds the pickled trajectory files.
trajs_path = 'trajs'
# Run group whose models are evaluated.
name = 'Humanoid-v3_sac_base'

def run_extra_steps(env, ep_len, md, md_name, step_num=50):
    """Roll model `md` forward from the environment's current state.

    Args:
        env: environment already placed in the desired state; must expose
            `get_obs()` and `step(action)`.
        ep_len: number of steps the episode has already consumed.
        md: model passed to `get_action`.
        md_name: model name passed to `get_action` (selects the PPO path).
        step_num: maximum number of additional steps to take.

    Returns:
        `(done, total_reward)`, where `done` is the flag from the last executed
        step (False when no step was executed) and `total_reward` is the sum of
        rewards collected. The rollout stops early when the environment reports
        done or the episode length reaches 1000.
    """
    max_ep_len = 1000
    total_r = 0
    # Defined up front so that step_num <= 0 returns (False, 0) instead of
    # failing on an unbound variable at the return statement.
    d = False
    o = env.get_obs()
    for _ in range(step_num):
        a = get_action(o, md, md_name)
        o, r, d, _info = env.step(a)
        total_r += r
        ep_len += 1
        if d or (ep_len == max_ep_len):
            break
    return (d, total_r)


In [36]:

def get_all_traj_names_with_same_env(path, trajs_path, name):
    """List trajectory files in `<path>/<trajs_path>` recorded on the same environment as `name`.

    A file qualifies when its name contains "trajs.pkl" and `get_env_name`
    maps it to the same environment id as `name`. The directory path is
    printed before it is scanned.
    """
    trajs_dir = osp.join(path, trajs_path)
    print(trajs_dir)
    candidates = os.listdir(trajs_dir)
    if not candidates:
        return []
    target_env = get_env_name(name)
    return [
        candidate
        for candidate in candidates
        if "trajs.pkl" in candidate and get_env_name(candidate) == target_env
    ]

def get_mean(trajs):
    """Return the arithmetic mean of the episode returns stored in `trajs[2]`.

    Raises:
        ZeroDivisionError: if `trajs[2]` is empty.
    """
    # Summation only reads the sequence, so no defensive copy is needed.
    rets = trajs[2]
    return sum(rets) / len(rets)

def get_worst_models_names(all_trajs_names, num, base_path=None, trajs_dir=None):
    """Collect, per trajectory file, the names of the `num` models with the lowest mean return.

    Args:
        all_trajs_names: trajectory file names to inspect.
        num: how many lowest-scoring models to report per file.
        base_path: root data directory; defaults to the module-level `path`.
        trajs_dir: trajectory sub-directory; defaults to the module-level
            `trajs_path`.

    Returns:
        Flat list of model names (`trajs[1]`), at most `num` per file, ordered
        from lowest to highest mean return within each file.
    """
    if base_path is None:
        base_path = path
    if trajs_dir is None:
        trajs_dir = trajs_path
    keep = max(num, 0)
    names = []
    for trajs_name in all_trajs_names: # *4
        trajs_file_name = osp.join(base_path, trajs_dir, trajs_name)
        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(trajs_file_name, 'rb') as f:
            all_trajs = pickle.load(f)
        scored = []
        for trajs in all_trajs:
            rets = trajs[2]
            x = (trajs[1], sum(rets) / len(rets))
            print(x)
            scored.append(x)
        # Rank every model and keep the lowest `keep`. Replacing the first
        # larger slot in a fixed-size buffer can evict the wrong entry
        # (e.g. means 4, 5, 1 with num=2 would keep 1 and 5).
        worst = sorted(scored, key=lambda item: item[1])[:keep]
        print(worst)
        names.extend(model_name for model_name, _ in worst)
    return names
                
def get_trajs_thr(trajs, top_ratio):
    """Return the return value that separates the top `top_ratio` fraction of episodes.

    The returns in `trajs[2]` are sorted ascending (on a copy) and the element
    at position `int(len * (1 - top_ratio))` is returned. The position is
    clamped to the valid range so that `top_ratio <= 0` yields the maximum and
    `top_ratio >= 1` yields the minimum.

    Raises:
        IndexError: if `trajs[2]` is empty.
    """
    rets = sorted(trajs[2])
    idx = int(len(rets) * (1 - top_ratio))
    # Without clamping, top_ratio == 0 indexes one past the end and
    # top_ratio > 1 wraps around to the largest values.
    idx = max(0, min(idx, len(rets) - 1))
    return rets[idx]


def sample_midpoint_from_trajs(trajs):
    """Pick a trajectory uniformly at random, then one of its stored midpoints uniformly at random.

    Returns:
        `(traj_id, midpoint_id)` indices into `trajs[-1]`.
    """
    midpoints_per_traj = trajs[-1]
    traj_id = random.randrange(len(midpoints_per_traj))
    midpoint_id = random.randrange(len(midpoints_per_traj[traj_id]))
    return traj_id, midpoint_id
    

# total interactions will be test_num * step_num * number of agents * number of agents that generate trajs
# 20 * 500 * 10 * 10
def run_extra_steps_on_trajs(model, trajs, test_num, step_num, thr): # 1 model vs trajs of 1 model
    """Continue recorded trajectories of one model with another model.

    For each of `test_num` trials a midpoint is drawn with
    `sample_midpoint_from_trajs` until it is eligible: its trajectory's return
    (`trajs[2][traj_id]`) is at least `thr - 0.1` and the stored episode length
    plus `step_num` stays below 999. The environment is restored to the stored
    state and `model` is rolled out with `run_extra_steps`.

    Args:
        model: `(group_name, model_name, model_object)` tuple.
        trajs: trajectory record; `trajs[1]` selects the environment,
            `trajs[2]` holds per-trajectory returns, `trajs[-2]` per-step
            rewards and `trajs[-1]` lists of `(ep_len, sim_state, ...)`
            midpoints.
        test_num: number of trials.
        step_num: number of extra steps per trial.
        thr: return threshold for eligible trajectories.

    Returns:
        List of `(traj_id, midpoint_id, old_ret, (done, total_reward))`, where
        `old_ret` is the reward the original trajectory collected over the
        same window.

    Raises:
        ValueError: if trials are requested but no stored midpoint is eligible
            (the rejection sampling below would otherwise never terminate).
    """
    max_len = 999

    def _eligible(traj_id, midpoint_id):
        # A midpoint qualifies if its trajectory is good enough and there is
        # room for `step_num` more steps before the length limit.
        start_len = trajs[-1][traj_id][midpoint_id][0]
        return (trajs[2][traj_id] >= (thr-0.1)) and (start_len + step_num < max_len)

    if test_num > 0 and not any(
        _eligible(t_id, m_id)
        for t_id in range(len(trajs[-1]))
        for m_id in range(len(trajs[-1][t_id]))
    ):
        raise ValueError("no eligible midpoint for the given thr and step_num")

    ret = []
    env_name = get_env_name(trajs[1])
    env = gym.make(env_name)
    try:
        for _ in range(test_num):
            # Rejection sampling keeps the original sampling distribution.
            while True:
                traj_id, midpoint_id = sample_midpoint_from_trajs(trajs)
                if _eligible(traj_id, midpoint_id):
                    break
            midpoint = trajs[-1][traj_id][midpoint_id]
            ep_len = midpoint[0]
            restore_state(env, midpoint[1])
            tmp = run_extra_steps(env, ep_len, model[2], model[1], step_num)
            old_ret = sum(trajs[-2][traj_id][ep_len : ep_len + step_num])
            ret.append((traj_id, midpoint_id, old_ret, tmp))
    finally:
        # One environment is created per call; release it so repeated calls
        # do not accumulate simulator instances.
        env.close()
    return ret


def test_all_condinue(path, trajs_path, name, test_num, step_num, top_ratio=0.5):
    """Evaluate every model of run group `name` on trajectories from all same-environment runs.

    The two lowest-scoring models per trajectory file (as reported by
    `get_worst_models_names`) are excluded both as trajectory sources and as
    evaluated models, and a model is never evaluated on its own trajectories.

    Returns:
        List of `(trajs[0], trajs[1], model[0], model[1], rollouts)` tuples,
        where `rollouts` is the output of `run_extra_steps_on_trajs`.
    """
    models = get_models(path, name) # same algorithm
    for loaded in models:
        print("model name: ", loaded[1])
    all_trajs_names = get_all_traj_names_with_same_env(path, trajs_path, name)
    print("all traj file name: ", all_trajs_names)
    worst_model_names = get_worst_models_names(all_trajs_names, 2)
    ret = []
    for trajs_name in all_trajs_names: # *4
        trajs_file_name = osp.join(path, trajs_path, trajs_name)
        print("start runing on: ", trajs_file_name)
        with open(trajs_file_name, 'rb') as f:
            all_trajs = pickle.load(f)
        for trajs in all_trajs: # * 10
            if trajs[1] not in worst_model_names:
                source_group, source_model = trajs[0], trajs[1]
                thr = get_trajs_thr(trajs, top_ratio)
                print("running with ", source_group, source_model, thr)
                for model in models: # * 12 or 11
                    is_self = model[1] == source_model
                    if not (is_self or model[1] in worst_model_names):
                        rollouts = run_extra_steps_on_trajs(model, trajs, test_num, step_num, thr)
                        ret.append((source_group, source_model, model[0], model[1], rollouts))
                # Drop the reference before the next record is bound.
                del trajs
    return ret

# Experiment configuration (same values as the earlier configuration cell).
path = '/home/lclan/spinningup/data/'
trajs_path = 'trajs'
name = 'Humanoid-v3_sac_base'
# Number of extra environment steps per continuation trial.
step_num = 500
# Number of continuation trials per (model, trajectory source) pair.
test_num = 20
# Runs the full cross-evaluation.
results = test_all_condinue(path, trajs_path, name, test_num, step_num)
    

/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3_sac_base_s1206/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3_sac_base_s1209/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3_sac_base_s1211/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3_sac_base_s1200/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3_sac_base_s1205/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3_sac_base_s1208/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3_sac_base_s1203/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3_sac_base_s1204/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3_sac_base_s1210/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3_sac_base_s1201/pyt_save/model.pt
/home/lclan/spinningup/data/Humanoid-v3_sac_base/Humanoid-v3



running with  Humanoid-v3_td3_base Humanoid-v3_td3_base_s118 5560.746364243695
running with  Humanoid-v3_td3_base Humanoid-v3_td3_base_s119 5978.221790305104
running with  Humanoid-v3_td3_base Humanoid-v3_td3_base_s113 5575.999862139598
running with  Humanoid-v3_td3_base Humanoid-v3_td3_base_s117 5387.061180316021
running with  Humanoid-v3_td3_base Humanoid-v3_td3_base_s114 5458.346372005501
running with  Humanoid-v3_td3_base Humanoid-v3_td3_base_s112 5662.7289161820045
running with  Humanoid-v3_td3_base Humanoid-v3_td3_base_s110 5353.542318508195
running with  Humanoid-v3_td3_base Humanoid-v3_td3_base_s121 5423.005697551584
running with  Humanoid-v3_td3_base Humanoid-v3_td3_base_s111 5775.427509245709
start runing on:  /home/lclan/spinningup/data/trajs/Humanoid-v3_sac_base_400_trajs.pkl
running with  Humanoid-v3_sac_base Humanoid-v3_sac_base_s1206 6314.830792734732
running with  Humanoid-v3_sac_base Humanoid-v3_sac_base_s1209 5625.6508803571905
running with  Humanoid-v3_sac_base Human

In [37]:
# Number of (trajectory source, evaluated model) pairs collected above.
print(len(results))

380


In [75]:
def is_fail(data, min_return=500):
    """Classify one continuation trial as failed (1.0) or not (0.0).

    Args:
        data: trial record `(traj_id, midpoint_id, old_ret, (done, total_reward))`
            as produced by `run_extra_steps_on_trajs`.
        min_return: a trial fails when its total reward is strictly below this
            value. The default of 500 preserves the previously hard-coded
            threshold.

    Returns:
        1.0 if `data[3][1] < min_return`, otherwise 0.0.
    """
    return 1.0 if data[3][1] < min_return else 0.0
    
def get_fail_rate(result): # result of one agent to one agent's trajs
    """Return the fraction of trials in `result` that `is_fail` marks as failed.

    Raises:
        ZeroDivisionError: if `result` is empty.
    """
    failures = sum(is_fail(trial) for trial in result)
    return failures / len(result)
    
    
# Group fail rates by the run group that produced the trajectories (result[0]).
# NOTE: this rebinds `stats`, hiding the `scipy.stats` module imported in the first cell.
stats = {}
for result in results:
    stats.setdefault(result[0], []).append(get_fail_rate(result[4]))



In [76]:
# Mean and standard deviation of the fail rate for each trajectory source.
for key, rates in stats.items():
    np_arr = np.array(rates)
    print(key, np.mean(np_arr), np.std(np_arr))

Humanoid-v3_td3_base 0.2735 0.27581288947400556
Humanoid-v3_sac_base 0.35333333333333344 0.34863224814185567
sgld_ppo_humanoid 0.8 0.22649503305812246
vanilla_ppo_humanoid 0.44833333333333325 0.32441828281128393


In [78]:
# Root data directory (absolute, machine-specific path).
path = '/home/lclan/spinningup/data/'
# Sub-directory of `path` holding the pickled continuation-test results.
data_path = 'trajs/test_continue/'
# Environment ids whose results are loaded, in this order.
env_names = ['Humanoid-v3', 'Ant-v3',  'Walker2d-v3','HalfCheetah-v3' , 'Hopper-v3']
def get_env_name(name):
    """Translate a run or file name into the Gym environment id it refers to.

    Re-declares the mapping already defined in the first cell with the same
    rules: a case-sensitive substring test against a lowercase and a
    capitalised spelling of each robot, scanned in order, first hit wins.
    Returns 'unknown' when nothing matches.
    """
    keyword_table = (
        (('humanoid', 'Humanoid'), 'Humanoid-v3'),
        (('halfcheetah', 'HalfCheetah'), 'HalfCheetah-v3'),
        (('ant', 'Ant'), 'Ant-v3'),
        (('hopper', 'Hopper'), 'Hopper-v3'),
        (('walker', 'Walker'), 'Walker2d-v3'),
    )
    for keywords, env_id in keyword_table:
        if any(keyword in name for keyword in keywords):
            return env_id
    return 'unknown'

def load_all_same_env_results(env_name, path, data_path):
    """Unpickle every result file in `<path>/<data_path>` that belongs to `env_name`.

    A file is selected when its name contains ".pkl" and `get_env_name` maps it
    to `env_name`. The directory and each selected file name are printed.

    Returns:
        List of unpickled objects, in directory listing order.
    """
    results_dir = osp.join(path, data_path)
    print(results_dir)
    loaded = []
    for entry in os.listdir(results_dir):
        if ".pkl" not in entry or get_env_name(entry) != env_name:
            continue
        print(entry)
        with open(osp.join(results_dir, entry), 'rb') as f:
            loaded.append(pickle.load(f))
    return loaded


def load_all_results(env_names, path, data_path):
    """Map each environment id in `env_names` to the list returned by `load_all_same_env_results`."""
    return {
        env_name: load_all_same_env_results(env_name, path, data_path)
        for env_name in env_names
    }

# Dict: environment id -> list of unpickled result objects, one per matching file.
all_rets = load_all_results(env_names, path, data_path)
    



/home/lclan/spinningup/data/trajs/test_continue/
Humanoid-v3_td3_base_s500_tr50.pkl
sgld_ppo_humanoid_s500_tr50.pkl
vanilla_ppo_humanoid_s500_tr50.pkl
Humanoid-v3_sac_base_s500_tr50.pkl
/home/lclan/spinningup/data/trajs/test_continue/
Ant-v3_sac_base_s500_tr50.pkl
atla_ppo_ant_s500_tr50.pkl
vanilla_ppo_ant_s500_tr50.pkl
Ant-v3_td3_base_s500_tr50.pkl
/home/lclan/spinningup/data/trajs/test_continue/
vanilla_ppo_walker_s500_tr50.pkl
Walker2d-v3_td3_base_s500_tr50.pkl
Walker2d-v3_sac_base_s500_tr50.pkl
atla_ppo_walker_s500_tr50.pkl
/home/lclan/spinningup/data/trajs/test_continue/
atla_ppo_halfcheetah_s500_tr50.pkl
HalfCheetah-v3_sac_base_s500_tr50.pkl
HalfCheetah-v3_td3_base_s500_tr50.pkl
vanilla_ppo_halfcheetah_s500_tr50.pkl
/home/lclan/spinningup/data/trajs/test_continue/
vanilla_ppo_hopper_s500_tr50.pkl
Hopper-v3_sac_base_s500_tr50.pkl
atla_ppo_hopper_s500_tr50.pkl
Hopper-v3_td3_base_s500_tr50.pkl


In [85]:
def print_results(results):
    """Print and return fail-rate statistics grouped by (result[0], result[2]).

    For every group the mean and standard deviation of the fail rates
    (`get_fail_rate(result[4])`) are printed on one line.

    Returns:
        Dict mapping each `(result[0], result[2])` pair to its list of fail rates.
    """
    grouped = {}
    for result in results:
        pair = (result[0], result[2])
        grouped.setdefault(pair, []).append(get_fail_rate(result[4]))
    for pair, rates in grouped.items():
        np_arr = np.array(rates)
        print(pair, np.mean(np_arr), np.std(np_arr))
    return grouped

def print_all_results(rets):
    """Call `print_results` on every result set of every environment in `rets`."""
    for per_env_results in rets.values():
        for data in per_env_results:
            print_results(data)

# Print grouped fail-rate statistics for every loaded result file.
print_all_results(all_rets)

('Humanoid-v3_td3_base', 'Humanoid-v3_td3_base') 0.441388888888889 0.278616649994793
('Humanoid-v3_sac_base', 'Humanoid-v3_td3_base') 0.7292500000000001 0.2423338348229566
('sgld_ppo_humanoid', 'Humanoid-v3_td3_base') 0.9355 0.08232405480781421
('vanilla_ppo_humanoid', 'Humanoid-v3_td3_base') 0.6883333333333332 0.27154905429569975
('Humanoid-v3_td3_base', 'sgld_ppo_humanoid') 0.589 0.3325342087665568
('Humanoid-v3_sac_base', 'sgld_ppo_humanoid') 0.54625 0.3447712103700076
('sgld_ppo_humanoid', 'sgld_ppo_humanoid') 0.6863888888888889 0.2957373344501068
('vanilla_ppo_humanoid', 'sgld_ppo_humanoid') 0.5652777777777779 0.32871159354119994
('Humanoid-v3_td3_base', 'vanilla_ppo_humanoid') 0.6386111111111111 0.29031020552522563
('Humanoid-v3_sac_base', 'vanilla_ppo_humanoid') 0.6894444444444444 0.2566065345626391
('sgld_ppo_humanoid', 'vanilla_ppo_humanoid') 0.8527777777777777 0.2016406472348369
('vanilla_ppo_humanoid', 'vanilla_ppo_humanoid') 0.6659722222222222 0.27164319797516157
('Humanoid