In [None]:
import torch
import numpy as np
from env import DroneEnv
from BC import BC, BCAgent
from PPO import PPO
from PPO import RolloutBuffer
from Trainer import Trainer

# PPO

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the drone environment and record the shapes of its observation and
# action spaces.
env = DroneEnv()
state_shape, action_shape = env.observation_space.shape, env.action_space.shape

In [None]:
# PPO learner built directly on the drone environment.
# rollout_length and batch_size are both set to 512.
ppo = PPO(
    env=env,
    state_shape=env.observation_space.shape,
    action_shape=env.action_space.shape,
    device=device,
    seed=123,
    rollout_length=512,
    batch_size=512,
)

# Trainer for the PPO run: 100000 steps, an evaluation every 1000 steps over
# 10 episodes, and results written under /save_dir/ppo.
trainer = Trainer(
    algo=ppo,
    env=env,
    seed=123,
    num_steps=100000,
    eval_interval=1000,
    num_eval_episodes=10,
    max_episode_steps=30,
    save_dir="/save_dir/ppo",
)

In [None]:
# Start training with the Trainer configured in the previous cell
# (algo=ppo, num_steps=100000).
trainer.train()

In [None]:
# Plot the results of the PPO run via Trainer.plot().
# Presumably this draws the evaluation returns recorded during training — TODO confirm.
trainer.plot()

In [None]:
# Save the PPO weights twice, passing 50000 and 100000 to save_weights
# (presumably the step count used in the checkpoint file name, cf. the
# 'actor_30000.pth' file loaded later in this notebook — confirm).
# NOTE(review): both calls run after train() has already finished, so unless
# save_weights looks up a stored snapshot for the given step, both files
# presumably contain the final weights — confirm against Trainer.save_weights.
trainer.save_weights(50000)
trainer.save_weights(100000)

# 模倣学習 + PPO

In [None]:
def collect_data(env, num_episodes=500, max_episode_steps=500):
    """Gather random-action transitions that moved the drone closer to its target.

    Every episode is driven by actions drawn from ``env.action_space.sample()``.
    A transition is kept only when the distance to the target after the step
    (measured from ``env.drone_position``) is strictly smaller than the distance
    measured from the first three entries of the pre-step observation.

    Args:
        env: Environment exposing ``reset()``, a ``step(action)`` that returns a
            4-tuple, ``action_space.sample()``, ``target_position`` and
            ``drone_position``.
        num_episodes: Number of episodes to roll out.
        max_episode_steps: Upper bound on the number of steps per episode.

    Returns:
        list: ``(state, action, reward, next_state, done)`` tuples from all
        episodes, in the order they were encountered.
    """
    transitions = []

    for _ in range(num_episodes):
        observation = env.reset()
        finished = False
        steps_taken = 0

        while steps_taken < max_episode_steps and not finished:
            chosen_action = env.action_space.sample()
            following_observation, step_reward, finished, _ = env.step(chosen_action)
            steps_taken += 1

            # Distance to the target after the step, from the env's own bookkeeping.
            distance_after = np.linalg.norm(env.target_position - env.drone_position)
            # Distance before the step; assumes the observation's first three
            # entries are the drone position — confirm against DroneEnv.
            distance_before = np.linalg.norm(env.target_position - observation[:3])

            # Keep the transition only if the step reduced the distance.
            if distance_before > distance_after:
                transitions.append(
                    (observation, chosen_action, step_reward, following_observation, finished)
                )

            observation = following_observation

    return transitions

In [None]:
# Fresh environment instance for the imitation-learning stage (rebinds `env`).
env = DroneEnv()

# Data collection: random-action transitions, kept only when the step moved
# the drone closer to the target (see collect_data).
data = collect_data(env, num_episodes=30000, max_episode_steps=200)

# Copy every collected transition into a RolloutBuffer sized to hold all of them.
# next_state is not passed on; the trailing 0 fills append's fifth argument —
# presumably a log-probability placeholder; confirm against RolloutBuffer.append.
# NOTE(review): unlike the PPO constructor, no device is passed to
# RolloutBuffer or BC here — confirm their defaults match `device`.
buffer_exp = RolloutBuffer(buffer_size=len(data), state_shape=env.observation_space.shape, action_shape=env.action_space.shape)
for state, action, reward, next_state, done in data:
    buffer_exp.append(state, action, reward, done, 0)


# Behaviour-cloning learner; buffer_exp is handed over as its first argument
# (presumably the demonstration data it fits to — confirm against BC).
bc = BC(buffer_exp, 
        env.observation_space.shape, 
        env.action_space.shape, 
        seed=123, 
        batch_size=512)


# Build the trainer for the BC run (30000 steps, an evaluation every 100 steps,
# output under /save_dir/BC) and start training in the same cell.
trainer = Trainer(env=env,
                  algo=bc,
                  seed=123,
                  num_steps=30000,
                  eval_interval=100,
                  num_eval_episodes=10,
                  max_episode_steps= 30,
                  save_dir = "/save_dir/BC"
                  )
trainer.train()

In [None]:
# Plot the results of the behaviour-cloning run via Trainer.plot().
# Presumably this draws the evaluation returns recorded during training — TODO confirm.
trainer.plot()

In [None]:
# An error is raised here, but it is not a problem.
# NOTE(review): a later cell loads '/save_dir/BC/actor_30000.pth', so the actor
# checkpoint is presumably written before the error occurs — confirm, and
# consider fixing the cause so the notebook survives "Restart & Run All".
trainer.save_weights(30000)

In [None]:
# Fresh PPO learner for the "imitation learning + PPO" run, built with the same
# constructor arguments as the from-scratch PPO run earlier in the notebook.
ppo = PPO(
    state_shape=env.observation_space.shape,
    action_shape=env.action_space.shape,
    seed=123,
    env=env,
    rollout_length=512,
    batch_size=512,
    device=device,
)

# Load the weights: initialise the PPO actor from the behaviour-cloning
# checkpoint saved by the BC trainer.
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU machine still loads when the notebook runs on a CPU-only host
# (without it torch.load tries to restore tensors to the device they were
# saved from and fails when CUDA is unavailable).
ppo.actor.load_state_dict(
    torch.load('/save_dir/BC/actor_30000.pth', map_location=device)
)

# Trainer for the BC-initialised PPO run: 50000 steps, an evaluation every
# 1000 steps over 10 episodes, and results written under /save_dir/ppo_bc.
trainer = Trainer(
    env=env,
    algo=ppo,
    seed=123,
    num_steps=50000,
    eval_interval=1000,
    num_eval_episodes=10,
    max_episode_steps=30,
    save_dir="/save_dir/ppo_bc",
)

In [None]:
# Start training with the Trainer configured in the previous cell
# (algo=ppo with the BC-loaded actor, num_steps=50000).
trainer.train()

In [None]:
# Plot the results of the BC-initialised PPO run via Trainer.plot().
# Presumably this draws the evaluation returns recorded during training — TODO confirm.
trainer.plot()

In [None]:
# Save the weights of the BC-initialised PPO run, passing 50000 (the num_steps
# this Trainer was configured with) to save_weights.
trainer.save_weights(50000)