Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
AltmanD committed Sep 18, 2023
1 parent 87294d5 commit 4eebb53
Show file tree
Hide file tree
Showing 15 changed files with 152 additions and 103 deletions.
3 changes: 1 addition & 2 deletions ding/model/template/tests/test_acer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from ding.model.template import ACER
from ding.torch_utils import is_differentiable


B = 4
obs_shape = [4, (8, ), (4, 64, 64)]
act_shape = [3, (6, )]
args = list(product(*[obs_shape, act_shape]))


@pytest.mark.unittest
class TestACER:

Expand Down Expand Up @@ -39,4 +39,3 @@ def test_ACER(self, obs_shape, act_shape):
outputs = {**outputs_a, **outputs_c}
loss = sum([v.sum() for v in outputs.values()])
is_differentiable(loss, model)

1 change: 0 additions & 1 deletion ding/model/template/tests/test_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from ding.model.template import BCQ
from ding.torch_utils import is_differentiable


B = 4
obs_shape = [4, (8, )]
act_shape = [3, (6, )]
Expand Down
20 changes: 9 additions & 11 deletions ding/model/template/tests/test_decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
args = list(product(*[action_space, state_encoder]))
args.pop(1)


@pytest.mark.unittest
@pytest.mark.parametrize('action_space, state_encoder', args)
def test_decision_transformer(action_space, state_encoder):
Expand All @@ -36,7 +37,7 @@ def test_decision_transformer(action_space, state_encoder):

is_continuous = True if action_space == 'continuous' else False
if state_encoder:
timesteps = torch.randint(0, 100, [B, 3*T-1, 1], dtype=torch.long) # B x T
timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long) # B x T
else:
timesteps = torch.randint(0, 100, [B, T], dtype=torch.long) # B x T
if isinstance(state_dim, int):
Expand Down Expand Up @@ -91,15 +92,12 @@ def test_decision_transformer(action_space, state_encoder):

if state_encoder:
is_differentiable(
action_loss, [
DT_model.transformer, DT_model.embed_action, DT_model.embed_rtg,
DT_model.state_encoder
]
)
action_loss, [DT_model.transformer, DT_model.embed_action, DT_model.embed_rtg, DT_model.state_encoder]
)
else:
is_differentiable(
action_loss, [
DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg,
DT_model.embed_state
]
)
action_loss, [
DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg,
DT_model.embed_state
]
)
3 changes: 1 addition & 2 deletions ding/model/template/tests/test_edac.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from ding.model.template import EDAC
from ding.torch_utils import is_differentiable


B = 4
obs_shape = [4, (8, )]
act_shape = [3, (6, )]
args = list(product(*[obs_shape, act_shape]))


@pytest.mark.unittest
class TestEDAC:

Expand Down Expand Up @@ -55,4 +55,3 @@ def test_EDAC(self, obs_shape, act_shape):
assert outputs_a['logit'][1].shape == (B, *act_shape)
outputs = {'mu': outputs_a['logit'][0], 'sigma': outputs_a['logit'][1]}
self.output_check(model.actor, outputs)

25 changes: 16 additions & 9 deletions ding/model/template/tests/test_ngu.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ def test_ngu(self, obs_shape, act_shape):
else:
inputs_prev_action = torch.ones(B, *act_shape).long()
inputs_prev_reward_extrinsic = torch.randn(B, H, 1)
inputs_beta = 2*torch.ones([4,4], dtype=torch.long)
inputs = {'obs': inputs_obs, 'prev_state': None,
'prev_action': inputs_prev_action, 'prev_reward_extrinsic':inputs_prev_reward_extrinsic,
'beta': inputs_beta}
inputs_beta = 2 * torch.ones([4, 4], dtype=torch.long)
inputs = {
'obs': inputs_obs,
'prev_state': None,
'prev_action': inputs_prev_action,
'prev_reward_extrinsic': inputs_prev_reward_extrinsic,
'beta': inputs_beta
}

model = NGU(obs_shape, act_shape, collector_env_num=3)
outputs = model(inputs)
Expand All @@ -48,11 +52,14 @@ def test_ngu(self, obs_shape, act_shape):
assert outputs['logit'].shape == (B, *act_shape, *act_shape)
self.output_check(model, outputs['logit'])

inputs = {'obs': inputs_obs, 'prev_state': None,
'action': inputs_prev_action,
'reward': inputs_prev_reward_extrinsic,
'prev_reward_extrinsic':inputs_prev_reward_extrinsic,
'beta': inputs_beta}
inputs = {
'obs': inputs_obs,
'prev_state': None,
'action': inputs_prev_action,
'reward': inputs_prev_reward_extrinsic,
'prev_reward_extrinsic': inputs_prev_reward_extrinsic,
'beta': inputs_beta
}
model = NGU(obs_shape, act_shape, collector_env_num=3)
outputs = model(inputs)
assert isinstance(outputs, dict)
Expand Down
6 changes: 3 additions & 3 deletions ding/model/template/tests/test_q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ def test_drqn_inference_res_link(self, obs_shape, act_shape):

@pytest.mark.tmp
def test_GTrXLDQN(self):
obs_dim, seq_len, bs, action_dim = [4,64,64], 64, 32, 4
obs_dim, seq_len, bs, action_dim = [4, 64, 64], 64, 32, 4
obs = torch.rand(seq_len, bs, *obs_dim)
model = GTrXLDQN(obs_dim, action_dim,encoder_hidden_size_list=[16,16,16])
model = GTrXLDQN(obs_dim, action_dim, encoder_hidden_size_list=[16, 16, 16])
outputs = model(obs)
assert isinstance(outputs, dict)
assert isinstance(outputs, dict)
7 changes: 4 additions & 3 deletions ding/model/template/tests/test_qtran.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
from ding.model.template import QTran
from ding.torch_utils import is_differentiable


@pytest.mark.unittest
def test_qtran():
B = 1
obs_shape = (1,64,64)
obs_shape = (1, 64, 64)
act_shape = 2
# inputs = {
# 'obs': {'agent_state': torch.randn(B, *obs_shape),
# 'global_state': torch.randn(B, *obs_shape)},
# 'prev_state': [[torch.randn(1, 1, *obs_shape) for __ in range(1)] for _ in range(1)],
# 'action': torch.randn(B, act_shape)
# }
model = QTran(1, obs_shape, 4*64*64, act_shape, [8, 8, 8], 5)
# model.forward(inputs)
model = QTran(1, obs_shape, 4 * 64 * 64, act_shape, [8, 8, 8], 5)
# model.forward(inputs)
3 changes: 2 additions & 1 deletion ding/model/template/tests/test_vac.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def model_check(model, inputs):
if model.action_space == 'continuous':
outputs = value.sum() + logit['mu'].sum() + logit['sigma'].sum()
elif model.action_space == 'hybrid':
outputs = value.sum() + logit['action_type'].sum() + logit['action_args']['mu'].sum() + logit['action_args']['sigma'].sum()
outputs = value.sum() + logit['action_type'].sum() + logit['action_args']['mu'].sum(
) + logit['action_args']['sigma'].sum()
else:
if model.multi_head:
outputs = value.sum() + sum([t.sum() for t in logit])
Expand Down
2 changes: 1 addition & 1 deletion ding/model/wrapper/test_model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def test_transformer_segment_wrapper(self):
out = model.forward(inputs1)
info = model.info('info')
info = model.info('x')

def test_transformer_memory_wrapper(self):
seq_len, bs, obs_shape = 12, 8, 32
layer_num, memory_len, emb_dim = 3, 4, 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
action_shape=1,
),
learn=dict(
batch_size=400
batch_size=400,
learning_rate=0.001,
entropy_weight=0.001,
),
Expand Down
10 changes: 6 additions & 4 deletions dizoo/dmc2gym/entry/dmc2gym_onppo_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ def wrapped_dmc2gym_env(cfg):
width=default_cfg["width"],
frame_skip=default_cfg["frame_skip"]
),
cfg={'env_wrapper': [
lambda env: Dmc2GymWrapper(env, default_cfg),
lambda env: EvalEpisodeReturnWrapper(env),
]}
cfg={
'env_wrapper': [
lambda env: Dmc2GymWrapper(env, default_cfg),
lambda env: EvalEpisodeReturnWrapper(env),
]
}
)


Expand Down
10 changes: 6 additions & 4 deletions dizoo/procgen/entry/coinrun_onppo_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ def wrapped_procgen_env(cfg):
num_levels=default_cfg.num_levels
) if default_cfg.control_level else
gym.make('procgen:procgen-' + default_cfg.env_id + '-v0', start_level=0, num_levels=1),
cfg={'env_wrapper': [
lambda env: CoinrunWrapper(env, default_cfg),
lambda env: EvalEpisodeReturnWrapper(env),
]}
cfg={
'env_wrapper': [
lambda env: CoinrunWrapper(env, default_cfg),
lambda env: EvalEpisodeReturnWrapper(env),
]
}
)


Expand Down
104 changes: 63 additions & 41 deletions dizoo/tabmwp/envs/tabmwp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def __init__(self, cfg):
openai.api_key = cfg.api_key
self.observation_space = gym.spaces.Dict()
self.action_space = gym.spaces.Discrete(self.cfg.cand_number * (self.cfg.cand_number - 1))
self.reward_space = gym.spaces.Box(
low=-1, high=1, shape=(1,), dtype=np.float32
)
self.reward_space = gym.spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32)
self.correct_num = 0

# Initialize language model if needed.
Expand Down Expand Up @@ -61,8 +59,12 @@ def get_output(self, inp: str) -> str:
inputs = TabMWP.tokenizer(inp + " [MASK].", return_tensors="pt")
inputs = TabMWP.tokenizer.build_inputs_for_generation(inputs, max_gen_length=512)
inputs = {key: value.cuda() for key, value in inputs.items()}
outputs = TabMWP.model.generate(**inputs, max_length=512, eos_token_id=TabMWP.tokenizer.eop_token_id,
pad_token_id=TabMWP.tokenizer.eos_token_id)
outputs = TabMWP.model.generate(
**inputs,
max_length=512,
eos_token_id=TabMWP.tokenizer.eop_token_id,
pad_token_id=TabMWP.tokenizer.eos_token_id
)
outputs = TabMWP.tokenizer.decode(outputs[0].tolist())

t0 = outputs.find('<|startofpiece|>') + 16
Expand All @@ -78,29 +80,37 @@ def reset(self) -> dict:
if TabMWP.model is not None:
TabMWP.model = TabMWP.model.cuda()
if self.enable_replay:
self.cand_pids = ['32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514',
'19270', '23713', '17209', '33379', '34987', '11177']
self.cand_pids = [
'32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', '19270', '23713',
'17209', '33379', '34987', '11177'
]
if self.cfg.seed == 0: # train
self.train_pids = ['14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433',
'26979', '18135', '13347', '17679', '38426', '3454', '10432', '31011', '12162',
'13063', '7812', '29661', '24482', '4970', '4405', '17405', '27781', '26724', '5993',
'16442', '30148', '15895', '6855', '29903', '18107', '29504', '11106', '32964',
'29891', '32104', '15712', '24287', '4997', '32581', '21020', '17247', '31455',
'13245', '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198',
'26039', '3791', '4909', '37056', '7144', '8185', '2131', '4398', '38199', '29520',
'37329', '21388', '28659', '15044', '28510', '12903', '11794', '37095', '32229',
'22918', '31680', '15024', '24607', '26930']
self.train_pids = [
'14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', '26979', '18135',
'13347', '17679', '38426', '3454', '10432', '31011', '12162', '13063', '7812', '29661', '24482',
'4970', '4405', '17405', '27781', '26724', '5993', '16442', '30148', '15895', '6855', '29903',
'18107', '29504', '11106', '32964', '29891', '32104', '15712', '24287', '4997', '32581', '21020',
'17247', '31455', '13245', '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198',
'26039', '3791', '4909', '37056', '7144', '8185', '2131', '4398', '38199', '29520', '37329',
'21388', '28659', '15044', '28510', '12903', '11794', '37095', '32229', '22918', '31680', '15024',
'24607', '26930'
]
model_io_path = 'dizoo/tabmwp/data/model_in_out_train.txt'
if not os.path.exists(model_io_path):
os.system(f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_train.txt -O '
+ model_io_path + ' --no-check-certificate')
os.system(
f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_train.txt -O ' +
model_io_path + ' --no-check-certificate'
)
else:
self.train_pids = ['21037', '22976', '2224', '14145', '27962', '26553', '22110', '16541', '26044',
'19492', '31882', '11991', '27594', '7637', '15394', '7666', '5177', '33761',
'13703', '29105']
self.train_pids = [
'21037', '22976', '2224', '14145', '27962', '26553', '22110', '16541', '26044', '19492', '31882',
'11991', '27594', '7637', '15394', '7666', '5177', '33761', '13703', '29105'
]
model_io_path = 'dizoo/tabmwp/data/model_in_out_eval.txt'
os.system(f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_eval.txt -O '
+ model_io_path + ' --no-check-certificate')
os.system(
f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_eval.txt -O ' + model_io_path +
' --no-check-certificate'
)

self.cfg.cand_number = len(self.cand_pids)
self.cfg.train_number = len(self.train_pids)
Expand Down Expand Up @@ -135,8 +145,19 @@ def search_answer(self, pid, pids):
raise ValueError('item does not exists.')

def parse_all_answers(self):
self.cand_pids = ['32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', '19270', '23713', '17209', '33379', '34987', '11177', '30218', '26066', '24169', '28492']
self.train_pids = ['14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', '26979', '18135', '13347', '17679', '38426', '3454', '10432', '31011', '12162', '13063', '7812', '29661', '24482', '4970', '4405', '17405', '27781', '26724', '5993', '16442', '30148', '15895', '6855', '29903', '18107', '29504', '11106', '32964', '29891', '32104', '15712', '24287', '4997', '32581', '21020', '17247', '31455', '13245', '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198', '26039', '3791', '4909', '37056', '7144', '8185', '2131', '4398', '38199', '29520', '37329', '21388', '28659', '15044', '28510', '12903', '11794', '37095', '32229', '22918', '31680', '15024', '24607', '26930']
self.cand_pids = [
'32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', '19270', '23713',
'17209', '33379', '34987', '11177', '30218', '26066', '24169', '28492'
]
self.train_pids = [
'14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', '26979', '18135',
'13347', '17679', '38426', '3454', '10432', '31011', '12162', '13063', '7812', '29661', '24482', '4970',
'4405', '17405', '27781', '26724', '5993', '16442', '30148', '15895', '6855', '29903', '18107', '29504',
'11106', '32964', '29891', '32104', '15712', '24287', '4997', '32581', '21020', '17247', '31455', '13245',
'15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198', '26039', '3791', '4909', '37056',
'7144', '8185', '2131', '4398', '38199', '29520', '37329', '21388', '28659', '15044', '28510', '12903',
'11794', '37095', '32229', '22918', '31680', '15024', '24607', '26930'
]
self.problem_id = 0
self.cfg.train_number = len(self.train_pids)
n = len(self.cand_pids)
Expand Down Expand Up @@ -221,24 +242,25 @@ def __repr__(self) -> str:

if __name__ == '__main__':
from easydict import EasyDict
env_cfg = EasyDict(dict(
cand_number=16,
train_number=20,
engine='text-davinci-002',
temperature=0.,
max_tokens=512,
top_p=1.,
frequency_penalty=0.,
presence_penalty=0.,
option_inds=["A", "B", "C", "D", "E", "F"],
api_key='xxx',
prompt_format='TQ-A',
enable_replay=True,
seed=0,
))
env_cfg = EasyDict(
dict(
cand_number=16,
train_number=20,
engine='text-davinci-002',
temperature=0.,
max_tokens=512,
top_p=1.,
frequency_penalty=0.,
presence_penalty=0.,
option_inds=["A", "B", "C", "D", "E", "F"],
api_key='xxx',
prompt_format='TQ-A',
enable_replay=True,
seed=0,
)
)
env = TabMWP(env_cfg)
env.seed(0)
env.reset()
env.parse_all_answers()
env.search_answer('22976', ['32889', '8044'])

Loading

0 comments on commit 4eebb53

Please sign in to comment.