feature(whl): add PC+MCTS code #603

Open · wants to merge 28 commits into base: main
1 change: 1 addition & 0 deletions ding/entry/__init__.py
@@ -27,3 +27,4 @@
from .application_entry_drex_collect_data import drex_collecting_data
from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream
from .serial_entry_bco import serial_pipeline_bco
from .serial_entry_pc_mcts import serial_pipeline_pc_mcts
173 changes: 173 additions & 0 deletions ding/entry/serial_entry_pc_mcts.py
@@ -0,0 +1,173 @@
from typing import Union, Optional, Tuple
import os
import torch
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy
from torch.utils.data import DataLoader, Dataset
import pickle

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed


class MCTSPCDataset(Dataset):

    def __init__(self, data_dic, seq_len=4, hidden_state_noise=0):
        self.observations = data_dic['obs']
        self.actions = data_dic['actions']
        self.hidden_states = data_dic['hidden_state']
        self.seq_len = seq_len
        # Each sample needs ``seq_len`` subsequent hidden states/actions, so the last steps cannot start a sequence.
        self.length = len(self.observations) - seq_len - 1
        self.hidden_state_noise = hidden_state_noise

    def __getitem__(self, idx):
        """
        Assume the trajectory is: o1, h2, h3, h4.
        A sample consists of the observation at ``idx``, the following ``seq_len`` hidden states in
        reversed (latest-first) order, and the corresponding ``seq_len`` actions, also reversed.
        Optional Gaussian noise scaled by ``hidden_state_noise`` is added to the hidden states.
        """
        hidden_states = list(reversed(self.hidden_states[idx + 1:idx + self.seq_len + 1]))
        actions = torch.tensor(list(reversed(self.actions[idx: idx + self.seq_len])))
        if self.hidden_state_noise > 0:
            for i in range(len(hidden_states)):
                hidden_states[i] += self.hidden_state_noise * torch.randn_like(hidden_states[i])
        return {
            'obs': self.observations[idx],
            'hidden_states': hidden_states,
            'action': actions
        }

    def __len__(self):
        return self.length
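
# Illustrative layout of the expected pickle file (key names follow the accesses above;
# exact dtypes/shapes depend on how the MCTS expert data was collected):
#   {
#       'obs': [obs_0, obs_1, ...],           # image observations, e.g. HWC arrays normalized later
#       'actions': [a_0, a_1, ...],           # expert / MCTS actions per step
#       'hidden_state': [h_0, h_1, ...],      # per-step MCTS hidden-state tensors
#   }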


def load_mcts_datasets(path, seq_len, batch_size=32, hidden_state_noise=0):
    # Load the pickled expert data and hold out the last 10% of samples as a test split.
    with open(path, 'rb') as f:
        dic = pickle.load(f)
    tot_len = len(dic['obs'])
    train_dic = {k: v[:-tot_len // 10] for k, v in dic.items()}
    test_dic = {k: v[-tot_len // 10:] for k, v in dic.items()}
    train_set = MCTSPCDataset(train_dic, seq_len=seq_len, hidden_state_noise=hidden_state_noise)
    test_set = MCTSPCDataset(test_dic, seq_len=seq_len, hidden_state_noise=hidden_state_noise)
    return DataLoader(train_set, shuffle=True, batch_size=batch_size), \
        DataLoader(test_set, shuffle=True, batch_size=batch_size)
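
# Usage sketch (path and hyper-parameters are illustrative):
#   train_loader, test_loader = load_mcts_datasets('./pong_mcts_data.pkl', seq_len=4, batch_size=32)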


def serial_pipeline_pc_mcts(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        model: Optional[torch.nn.Module] = None,
        max_iter: int = int(1e6),
) -> Tuple['Policy', bool]:  # noqa
    r"""
    Overview:
        Serial pipeline entry of procedure cloning with MCTS.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_iter (:obj:`int`): Maximum number of training iterations.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
        - convergence (:obj:`bool`): Whether the imitation-learning training converged.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)
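    # Config fields consumed below (assumed to be provided by the experiment config):
    #   cfg.policy.expert_data_path, cfg.policy.seq_len, cfg.policy.learn.batch_size,
    #   cfg.policy.learn.hidden_state_noise, cfg.policy.learn.train_epoch,
    #   cfg.policy.learn.learner and cfg.policy.eval.evaluator.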

    # Env, Policy
    env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    # Random seed
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])

    # Main components
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    dataloader, test_dataloader = load_mcts_datasets(
        cfg.policy.expert_data_path,
        seq_len=cfg.policy.seq_len,
        batch_size=cfg.policy.learn.batch_size,
        hidden_state_noise=cfg.policy.learn.hidden_state_noise
    )
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    # ==========
    # Main loop
    # ==========
    learner.call_hook('before_run')
    stop = False
    epoch_per_test = 10
    criterion = torch.nn.CrossEntropyLoss()
    # MSE criterion for hidden-state prediction; only exercised by the commented-out diagnostic block below.
    hidden_state_criterion = torch.nn.MSELoss()
    for epoch in range(cfg.policy.learn.train_epoch):
        # train
        for i, train_data in enumerate(dataloader):
            # Convert HWC image observations to normalized CHW float tensors before the learner step.
            train_data['obs'] = train_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.
            learner.train(train_data)
            if learner.train_iter >= max_iter:
                stop = True
                break
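        # Step-decay the learning rate: divide by 10 every 69 epochs (note this also fires at epoch 0).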
        if epoch % 69 == 0:
            policy._optimizer.param_groups[0]['lr'] /= 10
        if stop:
            break

        if epoch % epoch_per_test == 0:
            # Periodically report loss/accuracy on the held-out test split.
            losses = []
            acces = []
            for _, test_data in enumerate(test_dataloader):
                logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.)
                loss = criterion(logits, test_data['action'][:, -1].cuda()).item()
                preds = torch.argmax(logits, dim=-1)
                acc = torch.sum((preds == test_data['action'][:, -1].cuda())).item() / preds.shape[0]

                losses.append(loss)
                acces.append(acc)
            tb_logger.add_scalar('learner_iter/recurrent_test_loss', sum(losses) / len(losses), learner.train_iter)
            tb_logger.add_scalar('learner_iter/recurrent_test_acc', sum(acces) / len(acces), learner.train_iter)

            # Report the same metrics on the training split for comparison.
            losses = []
            acces = []
            for _, test_data in enumerate(dataloader):
                logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.)
                loss = criterion(logits, test_data['action'][:, -1].cuda()).item()
                preds = torch.argmax(logits, dim=-1)
                acc = torch.sum((preds == test_data['action'][:, -1].cuda())).item() / preds.shape[0]

                losses.append(loss)
                acces.append(acc)
            tb_logger.add_scalar('learner_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter)
            tb_logger.add_scalar('learner_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter)

            # Test for forward eval function.
            # losses = []
            # mse_losses = []
            # acces = []
            # for _, test_data in enumerate(dataloader):
            #     test_hidden_states = torch.stack(test_data['hidden_states'], dim=1).float().cuda()
            #     logits, pred_hidden_states, hidden_state_embeddings = policy._model.test_forward_eval(
            #         test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.,
            #         test_hidden_states
            #     )
            #     loss = criterion(logits, test_data['action'].cuda()).item()
            #     mse_loss = hidden_state_criterion(pred_hidden_states, hidden_state_embeddings).item()
            #     preds = torch.argmax(logits, dim=-1)
            #     acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0]
            #
            #     losses.append(loss)
            #     acces.append(acc)
            #     mse_losses.append(mse_loss)
            # tb_logger.add_scalar('learner_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter)
            # tb_logger.add_scalar('learner_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter)
            # tb_logger.add_scalar('learner_iter/recurrent_train_mse_loss', sum(mse_losses) / len(mse_losses), learner.train_iter)
    # Final evaluation after the training epochs complete.
    stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)
    learner.call_hook('after_run')
    print('final reward is: {}'.format(reward))
    return policy, stop
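
# Example invocation (config path is illustrative; a real run would point to a PC+MCTS experiment config
# that defines the fields listed in the pipeline above):
#   if __name__ == "__main__":
#       policy, stop = serial_pipeline_pc_mcts('path/to/pc_mcts_config.py', seed=0)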
2 changes: 1 addition & 1 deletion ding/model/template/__init__.py
@@ -22,4 +22,4 @@
from .madqn import MADQN
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloning
from .procedure_cloning import ProcedureCloningMCTS