-
Notifications
You must be signed in to change notification settings - Fork 0
/
PPO_trainer.py
47 lines (40 loc) · 1.11 KB
/
PPO_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gym
import gym_ur3
import time
from stable_baselines3 import PPO
import imageio
import pybullet as p
def test():
    """Smoke-test the UR3 environment without a policy.

    Resets the env a handful of times and manually advances the
    pybullet simulation, sleeping between ticks so the GUI (if any)
    runs at a watchable pace. No observations or rewards are used.
    """
    env = gym.make('ur3-v0')
    for _episode in range(10):
        env.reset()
        for _tick in range(20):
            # Step raw physics directly; we only want to see the sim move.
            p.stepSimulation()
            time.sleep(1 / 24)
def train(total_timesteps=300000, save_path='../result/ppo.zip'):
    """Train a PPO agent on the UR3 environment and save the model.

    Args:
        total_timesteps: number of environment steps to train for
            (default 300000, the original hard-coded budget).
        save_path: file path the trained model is written to.
            SB3's ``model.save`` appends ``.zip`` if missing; the
            default already carries the extension.
    """
    env = gym.make('ur3-v0')
    model = PPO('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=total_timesteps)
    model.save(save_path)
def result(episodes=10, max_steps=50,
           model_path='../result/ppo.zip',
           gif_path='../UJI-3D/result/ur3.gif'):
    """Roll out a trained PPO policy and save the rendered frames as a GIF.

    Args:
        episodes: number of evaluation episodes to run.
        max_steps: per-episode step cap before forcing a reset.
        model_path: path of the saved PPO model to load
            (matches the default written by ``train``).
        gif_path: output path for the rendered GIF.
    """
    frames = []
    env = gym.make('ur3-v0')
    model = PPO.load(model_path)
    for _episode in range(episodes):
        obs = env.reset()
        for _step in range(max_steps):
            action, _state = model.predict(obs, deterministic=True)
            obs, _, done, _ = env.step(action)
            frame = env.render()
            # Old-style gym envs can return None from render() (e.g.
            # human mode); a None frame would crash imageio, so skip it.
            if frame is not None:
                frames.append(frame)
            if done:
                # Brief pause so a live viewer sees the episode end.
                time.sleep(1 / 30)
                break
    imageio.mimsave(gif_path, frames, 'GIF', duration=0.1)
if __name__ == '__main__':
    # Entry point: train a PPO policy on the UR3 environment.
    # (Call test() to smoke-test the env, or result() to render a rollout.)
    train()