In [24]:
from stable_baselines3 import SAC
import pybullet_envs
import numpy as np
import torch

In [41]:
from src.pybullet_utils import CnfndWrapper, rollout, noisy_rollout, T, dynamics

In [3]:
from src.models import Model
from src.learners.bc import BC
from src.learners.doubil_pybullet import DoubIL
from src.learners.residuil import ResiduIL

In [6]:
import warnings
warnings.filterwarnings('ignore')
from copy import deepcopy
import gym

In [86]:
# Experiment tag; it is interpolated into every data/ file name written and read below.
e = 'ant'
# Bullet Ant environment (presumably registered by the `pybullet_envs` import — confirm).
env = gym.make("AntBulletEnv-v0")

In [87]:
from stable_baselines3 import SAC

# Pretrained SAC policy used as the demonstrator throughout this notebook.
expert_net = SAC.load("./src/experts/ant_expert")


def expert(s):
    """Query the SAC expert with ``deterministic=True`` on observation(s) ``s``.

    Returns exactly what ``expert_net.predict`` returns; callers in this notebook
    take element ``[0]`` of the result to obtain the actions.
    """
    greedy_prediction = expert_net.predict(s, deterministic=True)
    return greedy_prediction

In [88]:
# Roll out the expert for 10 episodes with plain `rollout`, keeping each
# (states, actions) trajectory and the episode score J that rollout reports.
env.reset()
denoised_expert_trajs, Js = [], []
for _ in range(10):
    s_traj, a_traj, J = rollout(expert, env)
    Js.append(J)
    denoised_expert_trajs.append((s_traj, a_traj))

In [89]:
print(Js, np.asarray(Js).mean())  # per-episode scores, then their mean

[2610.5423115705216, 2610.332346919023, 2600.592855553799, 2515.4703539535344, 2582.8462316244977, 2604.585722396691, 2592.4285408974274, 2629.381407214912, 2638.66988116291, 2509.9612106718205] 2589.481086196514


In [92]:
# Collect 25 expert episodes through `noisy_rollout` (sigma=2 — noise semantics live in
# src.pybullet_utils), storing (states, actions) pairs in expert_trajs2 and scores in Js.
env.reset()
expert_trajs2, Js = [], []
for _ in range(25):
    s_traj, a_traj, J = noisy_rollout(expert, env, sigma=2)
    Js.append(J)
    expert_trajs2.append((s_traj, a_traj))

In [93]:
print(np.asarray(Js).mean())  # mean score of the noisy expert episodes

-1873.018392506913


In [94]:
# Generate the clean test set: 100 expert rollouts (full_state=True also returns P, V, C).
# Each stored action trajectory is the expert's own prediction on the visited states.
env.reset()
denoised_test_s, denoised_test_a = [], []
P, V, C = [], [], []
for _ in range(100):
    s_traj, _, J, p, v, c = rollout(expert, env, full_state=True)
    print(J)
    expert_actions = expert(s_traj)[0]
    denoised_test_s.append(s_traj)
    denoised_test_a.append(expert_actions)
    for store, item in ((P, p), (V, v), (C, c)):
        store.append(item)
np.savez(
    "data/denoised_test_{0}".format(e),
    s=denoised_test_s,
    a=denoised_test_a,
    P=P,
    V=V,
    C=C,
)

2544.945648855827
2593.468374573578
2615.621506275431
2602.64008701095
2545.436373113281
2597.5418054103484
2602.9585005406157
2561.0710072102684
2557.1138655724494
2610.476785356322
2601.5879726877565
2635.863999330733
2609.609505511137
2562.0125403224342
2590.8638059027876
2603.677988491415
2636.56603061404
596.0115787341166
2611.95926523942
2615.5010516192538
2507.9657553745847
2619.4387009665384
2506.6339257827917
2632.0990325376365
2645.0465274765
2613.6644817375213
2609.879151392323
2588.786393671973
2616.314265591755
2573.0184763649613
2602.5730310341523
2487.026566818345
2560.9267189738093
2618.7609077070656
2641.9758308744135
2446.649730112676
2534.6396131687798
2603.421665650239
2558.0048171205826
2625.903365280014
2639.059894379724
2570.843795886615
2514.16225757487
2600.4148680642293
2599.3240803730478
2474.202240528496
2511.412608942964
2543.878153299522
2549.0910409293883
2558.281172280432
2586.508882670907
2225.90753346523
2619.310279997673
2562.0518235268632
2474.471173

In [95]:
# Generate the noisy test set: 100 rollouts through `noisy_rollout` (sigma=2, full_state=True).
# States come from the noisy rollouts; stored actions are the expert's predictions on them.
# NOTE(review): no env.reset() before this loop — presumably the rollout helpers reset; confirm.
noisy_test_s, noisy_test_a = [], []
P, V, C = [], [], []
for _ in range(100):
    s_traj, _, J, p, v, c = noisy_rollout(expert, env, sigma=2, full_state=True)
    print(J)
    expert_actions = expert(s_traj)[0]
    noisy_test_s.append(s_traj)
    noisy_test_a.append(expert_actions)
    for store, item in ((P, p), (V, v), (C, c)):
        store.append(item)
np.savez(
    "data/noisy_test_{0}".format(e),
    s=noisy_test_s,
    a=noisy_test_a,
    P=P,
    V=V,
    C=C,
)

-1739.2496679221674
-1884.7960371097422
-2082.1914232119793
-2080.236299186714
-1618.2266911935576
-1712.5033019162363
-1972.7175239836251
-1928.6110347552103
-1999.775606290051
-1417.0341926317078
-1460.6778430203415
-2233.9297839542196
-2110.3911973635663
-1935.1535048310438
-1843.8175272322148
-2141.0365833221967
-1677.3689736282702
-1629.0364215292093
-2086.1825592227915
-1651.0032674557851
-1866.263228251897
-1882.495035261848
-1725.7457258488696
-1733.0626621930492
-2097.3696528829737
-1818.5919847378948
-1789.2087841876207
-1831.9506775540003
-1807.2953576307643
-1687.8785477932763
-2072.7230389397378
-2012.731917603367
-1695.0432429474884
-1592.2542118405981
-1684.3174838323432
-2030.4771270318083
-2076.052762553642
-1946.9940578481176
-1634.2958068729843
-1958.7602760048262
-1916.110943994758
-1713.9925679716619
-1947.7262477866336
-2083.8619284877727
-1947.4525221366896
-1530.4981051747134
-1680.5620393826541
-1700.125762570866
-1681.7342681690725
-1513.7300363230217
-1646.22

In [96]:
# Generate training sets: for every dataset index i and every size, collect `size`
# noisy expert rollouts (sigma=2, full_state=True) and save them to
# data/train_<size>_<i>_<e>.npz with keys s, a, P, V, C.
for i in range(0, 5):
    for size in [10, 20, 30, 40, 50]:
        print(i, size)
        s_trajs, a_trajs = [], []
        P, V, C = [], [], []
        for j in range(size):
            # Compact progress: rollout indices share one line instead of one line each.
            print(j, end=" ", flush=True)
            s_traj, a_traj, _, p, v, c = noisy_rollout(expert, env, sigma=2, full_state=True)
            s_trajs.append(s_traj)
            a_trajs.append(a_traj)
            P.append(p)
            V.append(v)
            C.append(c)
        print()  # terminate the progress line
        np.savez("data/train_{0}_{1}_{2}".format(size, i, e), s=s_trajs, a=a_trajs, P=P, V=V, C=C)

0 10
0
1
2
3
4
5
6
7
8
9
0 20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
0 30
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
0 40
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
0 50
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
1 10
0
1
2
3
4
5
6
7
8
9
1 20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
1 30
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
1 40
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
1 50
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
2 10
0
1
2
3
4
5
6
7
8
9
2 20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2 30
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24


In [97]:
# Reload both test sets from disk. Each .npz is opened exactly once and closed by the
# `with` block; the arrays are fully read into memory before the file closes.
with np.load("./data/denoised_test_{0}.npz".format(e), allow_pickle=True) as denoised_npz:
    denoised_test_s = denoised_npz["s"]
    denoised_test_a = denoised_npz["a"]
with np.load("./data/noisy_test_{0}.npz".format(e), allow_pickle=True) as noisy_npz:
    noisy_test_s = noisy_npz["s"]
    noisy_test_a = noisy_npz["a"]

# Pair each state trajectory with its expert-labelled action trajectory. The lengths
# must agree; fail loudly rather than silently dropping or mis-indexing trajectories.
assert len(denoised_test_s) == len(denoised_test_a), "denoised test set: s/a length mismatch"
assert len(noisy_test_s) == len(noisy_test_a), "noisy test set: s/a length mismatch"
denoised_test = [[s, a] for s, a in zip(denoised_test_s, denoised_test_a)]
noisy_test = [[s, a] for s, a in zip(noisy_test_s, noisy_test_a)]

In [98]:
def mse(pi, dataset):
    """Mean, over trajectories, of the L2 norm of the policy's action error.

    Args:
        pi: callable mapping a state trajectory to predicted actions (same shape as
            the stored action trajectory).
        dataset: sized iterable of ``(s_traj, a_traj)`` pairs.

    Returns:
        ``sum(||pi(s_traj) - a_traj||) / len(dataset)``, where the norm is
        ``np.linalg.norm`` with default arguments, i.e. taken over the whole error
        array of a trajectory at once.

    NOTE(review): despite the name this is not a mean *squared* error. The norm is
    unsquared and is not averaged per time step. Kept unchanged so values remain
    comparable with previously saved results. An empty dataset raises
    ZeroDivisionError.
    """
    per_traj_errors = (np.linalg.norm(pi(s_traj) - a_traj) for s_traj, a_traj in dataset)
    return sum(per_traj_errors) / len(dataset)

In [99]:
def eval_policy(pi, env, noisy=False, n_episodes=100, sigma=2):
    """Average episode score J of policy ``pi`` in ``env``.

    Args:
        pi: policy callable passed straight to ``rollout`` / ``noisy_rollout``.
        env: environment to roll out in.
        noisy: if True, use ``noisy_rollout(pi, env, sigma=sigma)``; otherwise
            ``rollout(pi, env)``.
        n_episodes: number of episodes to average over (default 100, as before).
        sigma: noise scale forwarded to ``noisy_rollout`` (default 2, as before);
            ignored when ``noisy`` is False.

    Returns:
        ``np.mean`` of the J values reported by the rollout helper.
    """
    returns = []
    for _ in range(n_episodes):
        if noisy:
            _, _, J = noisy_rollout(pi, env, sigma=sigma)
        else:
            _, _, J = rollout(pi, env)
        returns.append(J)
    return np.mean(returns)

In [100]:
# Per-method result accumulators; each receives one list (one entry per training-set
# size) per dataset index in the experiment loop.
bc_mse_noisy, bc_mse_denoised, bc_J_noisy, bc_J_denoised = [], [], [], []

doubil_mse_noisy, doubil_mse_denoised, doubil_J_noisy, doubil_J_denoised = [], [], [], []

residuil_mse_noisy, residuil_mse_denoised, residuil_J_noisy, residuil_J_denoised = [], [], [], []

In [None]:
def _as_numpy_policy(pi):
    """Adapt a torch policy: numpy states -> float32 tensor -> detached numpy actions."""
    return lambda s: pi(torch.from_numpy(s).float()).detach().numpy()


def _evaluate(label, size, pi, mse_noisy, mse_denoised, J_noisy, J_denoised):
    """Append ``pi``'s four metrics to the given lists and print them on one line.

    Order: ``mse`` on the noisy then denoised test set, then ``eval_policy`` with
    and without noise. The last two step ``env``, so the order is kept fixed to
    keep environment interaction in the same sequence.
    """
    policy = _as_numpy_policy(pi)
    mse_noisy.append(mse(policy, noisy_test))
    mse_denoised.append(mse(policy, denoised_test))
    J_noisy.append(eval_policy(policy, env, noisy=True))
    J_denoised.append(eval_policy(policy, env, noisy=False))
    print(label, size, mse_noisy[-1], mse_denoised[-1], J_noisy[-1], J_denoised[-1])


def _save_metrics(prefix, **metrics):
    """Save each accumulated metric list to data/<prefix>_<metric>_<e>.npz (stored as arr_0)."""
    for metric, values in metrics.items():
        np.savez("data/{0}_{1}_{2}".format(prefix, metric, e), values)


obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# Train BC, DoubIL and ResiduIL on every saved training set and evaluate each one.
# NOTE(review): training sets were generated for indices 0..4, but only 0..3 are used here.
for i in range(0, 4):
    bc_mse_noisy_i, bc_mse_denoised_i, bc_J_noisy_i, bc_J_denoised_i = [], [], [], []
    doubil_mse_noisy_i, doubil_mse_denoised_i, doubil_J_noisy_i, doubil_J_denoised_i = [], [], [], []
    residuil_mse_noisy_i, residuil_mse_denoised_i, residuil_J_noisy_i, residuil_J_denoised_i = [], [], [], []

    for size in [10, 20, 30, 40, 50]:
        # Open the training file once and read every field from it.
        with np.load("./data/train_{0}_{1}_{2}.npz".format(size, i, e), allow_pickle=True) as train_npz:
            s_trajs = train_npz["s"]
            a_trajs = train_npz["a"]
            p_trajs = train_npz["P"]
            v_trajs = train_npz["V"]
            c_trajs = train_npz["C"]
        # One [s, a, P, V, C] record per trajectory (zip avoids shadowing the outer `i`).
        expert_trajs = [list(traj) for traj in zip(s_trajs, a_trajs, p_trajs, v_trajs, c_trajs)]

        pi_BC = BC(expert_trajs, Model(obs_dim, act_dim), wd=5e-3)
        _evaluate('BC', size, pi_BC,
                  bc_mse_noisy_i, bc_mse_denoised_i, bc_J_noisy_i, bc_J_denoised_i)

        pi_DoubIL = DoubIL(expert_trajs,
                           Model(obs_dim, act_dim),
                           lambda a, b, c, d: dynamics(a, b, c, d, env), pi_BC=pi_BC, nsamp=8, wd=5e-3)
        _evaluate('doubil', size, pi_DoubIL,
                  doubil_mse_noisy_i, doubil_mse_denoised_i, doubil_J_noisy_i, doubil_J_denoised_i)

        pi_ResiduIL = ResiduIL(expert_trajs,
                               Model(obs_dim, act_dim),
                               Model(obs_dim, act_dim), wd=5e-3, bc_reg=0)
        _evaluate('residuil', size, pi_ResiduIL,
                  residuil_mse_noisy_i, residuil_mse_denoised_i, residuil_J_noisy_i, residuil_J_denoised_i)

    # Checkpoint cumulative results after every dataset index so an interrupted run keeps them.
    bc_mse_noisy.append(bc_mse_noisy_i)
    bc_mse_denoised.append(bc_mse_denoised_i)
    bc_J_noisy.append(bc_J_noisy_i)
    bc_J_denoised.append(bc_J_denoised_i)
    _save_metrics('bc', mse_noisy=bc_mse_noisy, mse_denoised=bc_mse_denoised,
                  J_noisy=bc_J_noisy, J_denoised=bc_J_denoised)

    doubil_mse_noisy.append(doubil_mse_noisy_i)
    doubil_mse_denoised.append(doubil_mse_denoised_i)
    doubil_J_noisy.append(doubil_J_noisy_i)
    doubil_J_denoised.append(doubil_J_denoised_i)
    _save_metrics('doubil', mse_noisy=doubil_mse_noisy, mse_denoised=doubil_mse_denoised,
                  J_noisy=doubil_J_noisy, J_denoised=doubil_J_denoised)

    residuil_mse_noisy.append(residuil_mse_noisy_i)
    residuil_mse_denoised.append(residuil_mse_denoised_i)
    residuil_J_noisy.append(residuil_J_noisy_i)
    residuil_J_denoised.append(residuil_J_denoised_i)
    _save_metrics('residuil', mse_noisy=residuil_mse_noisy, mse_denoised=residuil_mse_denoised,
                  J_noisy=residuil_J_noisy, J_denoised=residuil_J_denoised)

BC Data (10000, 28) (10000, 8)
BC 10 189.17252990722656 337.0806532287598 -3098.8924585378845 -158.5860793952529
Done w/ BC
IV Data (9990, 28) (9990, 8)
doubil 10 183.85316482543945 158.70012229919433 -3089.2649673967967 -134.7565873914581
residuil 10 103.69670715332032 156.75333480834962 -2256.6802740353673 293.74834198103184
BC Data (20000, 28) (20000, 8)
BC 20 135.63109382629395 144.6160302734375 -2855.5249583624945 30.97797753622766
Done w/ BC
IV Data (19980, 28) (19980, 8)
doubil 20 89.93051605224609 84.2939697265625 -2316.320652110159 268.0262068710541
residuil 20 76.53827896118165 83.61533576965331 -2083.978501692212 538.6482931653599
BC Data (30000, 28) (30000, 8)
BC 30 129.76204193115234 109.56245193481445 -2993.172406919557 99.02259471853695
Done w/ BC
IV Data (29970, 28) (29970, 8)
doubil 30 61.732873878479005 70.47049655914307 -2040.2022866530722 653.785070992511
residuil 30 63.87045913696289 86.09256423950195 -1870.6138270357683 596.9400083686988
BC Data (40000, 28) (40000