Skip to content

Commit

Permalink
test(luyd): add model test code (#728)
Browse files Browse the repository at this point in the history
* Fix test files

* Add vac test and fix dt test

* Add qtrain test and GTrXLDQN test

* Fix ngu test

* Add transformer_segment_wrapper test

* Reformat
  • Loading branch information
AltmanD committed Sep 19, 2023
1 parent f131c36 commit 0401412
Show file tree
Hide file tree
Showing 21 changed files with 466 additions and 96 deletions.
2 changes: 1 addition & 1 deletion ding/entry/tests/test_serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def test_discrete_dt():
from ding.utils import set_pkg_seed
from ding.data import create_dataset
from ding.config import compile_config
from ding.model.template.dt import DecisionTransformer
from ding.model.template.decision_transformer import DecisionTransformer
from ding.policy import DTPolicy
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
OfflineMemoryDataFetcher, offline_logger, termination_checker
Expand Down
2 changes: 1 addition & 1 deletion ding/example/dt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gym
from ditk import logging
from ding.model.template.dt import DecisionTransformer
from ding.model.template.decision_transformer import DecisionTransformer
from ding.policy import DTPolicy
from ding.envs import DingEnvWrapper, BaseEnvManager, BaseEnvManagerV2
from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
Expand Down
2 changes: 1 addition & 1 deletion ding/model/template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .maqac import DiscreteMAQAC, ContinuousMAQAC
from .madqn import MADQN
from .vae import VanillaVAE
from .dt import DecisionTransformer
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .bcq import BCQ
from .edac import EDAC
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None):
action_preds = self.predict_action(h[:, 1]) # predict action given r, s
else:
state_embeddings = self.state_encoder(
states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()
states.reshape(-1, *self.state_dim).type(torch.float32).contiguous()
) # (batch * block_size, h_dim)
state_embeddings = state_embeddings.reshape(B, T, self.h_dim) # (batch, block_size, h_dim)
returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32))
Expand Down
2 changes: 1 addition & 1 deletion ding/model/template/q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ def forward(self, x: torch.Tensor) -> Dict:
>>> # Init input's Keys:
>>> obs_dim, seq_len, bs, action_dim = 128, 64, 32, 4
>>> obs = torch.rand(seq_len, bs, obs_dim)
>>> model = GTrXLDiscreteHead(obs_dim, action_dim)
>>> model = GTrXLDQN(obs_dim, action_dim)
>>> outputs = model(obs)
>>> assert isinstance(outputs, dict)
"""
Expand Down
41 changes: 41 additions & 0 deletions ding/model/template/tests/test_acer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
import pytest
from itertools import product

from ding.model.template import ACER
from ding.torch_utils import is_differentiable

# Batch size of every randomly generated input tensor in this test module.
B = 4
# Observation shapes to cover: a bare int, a 1-element tuple and a 3-element tuple.
obs_shape = [4, (8, ), (4, 64, 64)]
# Action shapes to cover: a bare int and a 1-element tuple.
act_shape = [3, (6, )]
# Every (obs_shape, act_shape) combination; consumed by ``pytest.mark.parametrize`` below.
args = list(product(*[obs_shape, act_shape]))


@pytest.mark.unittest
class TestACER:
    """Shape and differentiability checks for the ``ACER`` model template."""

    @staticmethod
    def _assert_head_shape(outputs, key, act_shape):
        """
        Overview:
            Assert that ``outputs[key]`` has shape ``(B, action_dim)``. Nothing is asserted (and ``key`` is not \
            looked up) when ``act_shape`` is a sequence whose length is not 1.
        """
        if isinstance(act_shape, int):
            assert outputs[key].shape == (B, act_shape)
        elif len(act_shape) == 1:
            assert outputs[key].shape == (B, *act_shape)

    @pytest.mark.parametrize('obs_shape, act_shape', args)
    def test_ACER(self, obs_shape, act_shape):
        """
        Overview:
            Run ``ACER`` in ``compute_critic`` and ``compute_actor`` mode on one random batch, check the output \
            types and shapes, then check that the sum of all outputs is differentiable w.r.t. the model.
        """
        # An int observation shape is treated as a single trailing dimension.
        obs_dims = (obs_shape, ) if isinstance(obs_shape, int) else obs_shape
        inputs = torch.randn(B, *obs_dims)
        model = ACER(obs_shape, act_shape)

        critic_out = model(inputs, mode='compute_critic')
        assert isinstance(critic_out, dict)
        self._assert_head_shape(critic_out, 'q_value', act_shape)

        actor_out = model(inputs, mode='compute_actor')
        assert isinstance(actor_out, dict)
        self._assert_head_shape(actor_out, 'logit', act_shape)

        # Merge both heads so that a single backward pass covers actor and critic outputs.
        merged = {**actor_out, **critic_out}
        loss = sum([tensor.sum() for tensor in merged.values()])
        is_differentiable(loss, model)
75 changes: 75 additions & 0 deletions ding/model/template/tests/test_bcq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest
from itertools import product
import torch
from ding.model.template import BCQ
from ding.torch_utils import is_differentiable

# Batch size of every randomly generated input tensor in this test module.
B = 4
# Observation shapes to cover: a bare int and a 1-element tuple.
obs_shape = [4, (8, )]
# Action shapes to cover: a bare int and a 1-element tuple.
act_shape = [3, (6, )]
# Every (obs_shape, act_shape) combination; consumed by ``pytest.mark.parametrize`` below.
args = list(product(*[obs_shape, act_shape]))


@pytest.mark.unittest
class TestBCQ:
    """Shape and differentiability checks for the ``BCQ`` model template."""

    def output_check(self, model, outputs):
        """
        Overview:
            Reduce ``outputs`` to a scalar by summation and check that it is differentiable w.r.t. ``model``.
        Arguments:
            - model: object forwarded unchanged to ``is_differentiable``.
            - outputs: a ``torch.Tensor`` or a dict whose values are tensors.
        Raises:
            - TypeError: if ``outputs`` is neither a tensor nor a dict. The original fell through here and \
                failed later with an unrelated ``UnboundLocalError`` on ``loss``.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, dict):
            loss = sum([v.sum() for v in outputs.values()])
        else:
            raise TypeError("unsupported outputs type: {}".format(type(outputs).__name__))
        is_differentiable(loss, model)

    @pytest.mark.parametrize('obs_shape, act_shape', args)
    def test_BCQ(self, obs_shape, act_shape):
        """
        Overview:
            Run ``BCQ`` in ``compute_critic``, ``compute_actor``, ``compute_vae`` and ``compute_eval`` mode on \
            one random (obs, action) batch and check output types, shapes and (for critic/actor) gradients.
        """
        # Normalise int shapes to 1-element tuples so that tensor creation and shape checks share one path.
        obs_dims = (obs_shape, ) if isinstance(obs_shape, int) else tuple(obs_shape)
        act_dims = (act_shape, ) if isinstance(act_shape, int) else tuple(act_shape)
        # Action-shaped assertions are only made for 1-D action spaces (int or 1-element tuple).
        check_action_shape = len(act_dims) == 1

        inputs = {'obs': torch.randn(B, *obs_dims), 'action': torch.randn(B, *act_dims)}
        model = BCQ(obs_shape, act_shape)

        outputs_c = model(inputs, mode='compute_critic')
        assert isinstance(outputs_c, dict)
        # The test expects exactly two q_value entries of shape (B, ), whatever the action shape is.
        q_value = torch.stack(outputs_c['q_value'])
        assert q_value.shape == (2, B)
        self.output_check(model.critic, q_value)

        outputs_a = model(inputs, mode='compute_actor')
        assert isinstance(outputs_a, dict)
        if check_action_shape:
            assert outputs_a['action'].shape == (B, *act_dims)
        self.output_check(model.actor, outputs_a)

        outputs_vae = model(inputs, mode='compute_vae')
        assert isinstance(outputs_vae, dict)
        if check_action_shape:
            # The latent tensors are expected to be twice as wide as the action.
            latent_shape = (B, act_dims[0] * 2)
            assert outputs_vae['recons_action'].shape == (B, *act_dims)
            assert outputs_vae['mu'].shape == latent_shape
            assert outputs_vae['log_var'].shape == latent_shape
            assert outputs_vae['z'].shape == latent_shape
        assert outputs_vae['prediction_residual'].shape == (B, *obs_dims)

        outputs_eval = model(inputs, mode='compute_eval')
        assert isinstance(outputs_eval, dict)
        if check_action_shape:
            assert outputs_eval['action'].shape == (B, *act_dims)
59 changes: 42 additions & 17 deletions ding/model/template/tests/test_decision_transformer.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,49 @@
import pytest
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F

from ding.model.template import DecisionTransformer
from ding.torch_utils import is_differentiable

args = ['continuous', 'discrete']
action_space = ['continuous', 'discrete']
state_encoder = [None, nn.Sequential(nn.Flatten(), nn.Linear(8, 8), nn.Tanh())]
args = list(product(*[action_space, state_encoder]))
args.pop(1)


@pytest.mark.unittest
@pytest.mark.parametrize('action_space', args)
def test_decision_transformer(action_space):
@pytest.mark.parametrize('action_space, state_encoder', args)
def test_decision_transformer(action_space, state_encoder):
B, T = 4, 6
state_dim = 3
if state_encoder:
state_dim = (2, 2, 2)
else:
state_dim = 3
act_dim = 2
DT_model = DecisionTransformer(
state_dim=state_dim,
act_dim=act_dim,
state_encoder=state_encoder,
n_blocks=3,
h_dim=8,
context_len=T,
n_heads=2,
drop_p=0.1,
continuous=(action_space == 'continuous')
)
DT_model.configure_optimizers(1.0, 0.0003)

is_continuous = True if action_space == 'continuous' else False
timesteps = torch.randint(0, 100, [B, T], dtype=torch.long) # B x T
states = torch.randn([B, T, state_dim]) # B x T x state_dim
if state_encoder:
timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long) # B x T
else:
timesteps = torch.randint(0, 100, [B, T], dtype=torch.long) # B x T
if isinstance(state_dim, int):
states = torch.randn([B, T, state_dim]) # B x T x state_dim
else:
states = torch.randn([B, T, *state_dim]) # B x T x state_dim
if action_space == 'continuous':
actions = torch.randn([B, T, act_dim]) # B x T x act_dim
action_target = torch.randn([B, T, act_dim])
Expand All @@ -51,12 +66,19 @@ def test_decision_transformer(action_space):
state_preds, action_preds, return_preds = DT_model.forward(
timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
)
assert state_preds.shape == (B, T, state_dim)
if state_encoder:
assert state_preds == None
assert return_preds == None
else:
assert state_preds.shape == (B, T, state_dim)
assert return_preds.shape == (B, T, 1)
assert action_preds.shape == (B, T, act_dim)
assert return_preds.shape == (B, T, 1)

# only consider non padded elements
action_preds = action_preds.view(-1, act_dim)[traj_mask.view(-1, ) > 0]
if state_encoder:
action_preds = action_preds.reshape(-1, act_dim)
else:
action_preds = action_preds.view(-1, act_dim)[traj_mask.view(-1, ) > 0]

if is_continuous:
action_target = action_target.view(-1, act_dim)[traj_mask.view(-1, ) > 0]
Expand All @@ -68,11 +90,14 @@ def test_decision_transformer(action_space):
else:
action_loss = F.cross_entropy(action_preds, action_target)

# print(action_loss)
# is_differentiable(action_loss, DT_model)
is_differentiable(
action_loss, [
DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg,
DT_model.embed_state
]
) # pass
if state_encoder:
is_differentiable(
action_loss, [DT_model.transformer, DT_model.embed_action, DT_model.embed_rtg, DT_model.state_encoder]
)
else:
is_differentiable(
action_loss, [
DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg,
DT_model.embed_state
]
)
57 changes: 57 additions & 0 deletions ding/model/template/tests/test_edac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
import pytest
from itertools import product

from ding.model.template import EDAC
from ding.torch_utils import is_differentiable

# Batch size of every randomly generated input tensor in this test module.
B = 4
# Observation shapes to cover: a bare int and a 1-element tuple.
obs_shape = [4, (8, )]
# Action shapes to cover: a bare int and a 1-element tuple.
act_shape = [3, (6, )]
# Every (obs_shape, act_shape) combination; consumed by ``pytest.mark.parametrize`` below.
args = list(product(*[obs_shape, act_shape]))


@pytest.mark.unittest
class TestEDAC:
    """Shape and differentiability checks for the ``EDAC`` model template."""

    def output_check(self, model, outputs):
        """
        Overview:
            Reduce ``outputs`` to a scalar by summation and check that it is differentiable w.r.t. ``model``.
        Arguments:
            - model: object forwarded unchanged to ``is_differentiable``.
            - outputs: a ``torch.Tensor``, a list of tensors, or a dict whose values are tensors.
        Raises:
            - TypeError: if ``outputs`` is none of the supported containers. The original fell through here \
                and failed later with an unrelated ``UnboundLocalError`` on ``loss``.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, list):
            loss = sum([t.sum() for t in outputs])
        elif isinstance(outputs, dict):
            loss = sum([v.sum() for v in outputs.values()])
        else:
            raise TypeError("unsupported outputs type: {}".format(type(outputs).__name__))
        is_differentiable(loss, model)

    @pytest.mark.parametrize('obs_shape, act_shape', args)
    def test_EDAC(self, obs_shape, act_shape):
        """
        Overview:
            Run ``EDAC`` in ``compute_critic`` mode on a random (obs, action) batch and in ``compute_actor`` \
            mode on a random obs batch, checking output types, shapes and gradients.
        """
        # Named so that the leading q_value dimension asserted below is visibly tied to the constructor argument.
        ensemble_num = 2
        # Normalise int shapes to 1-element tuples so that tensor creation and shape checks share one path.
        obs_dims = (obs_shape, ) if isinstance(obs_shape, int) else tuple(obs_shape)
        act_dims = (act_shape, ) if isinstance(act_shape, int) else tuple(act_shape)

        inputs = {'obs': torch.randn(B, *obs_dims), 'action': torch.randn(B, *act_dims)}
        model = EDAC(obs_shape, act_shape, ensemble_num=ensemble_num)

        outputs_c = model(inputs, mode='compute_critic')
        assert isinstance(outputs_c, dict)
        assert outputs_c['q_value'].shape == (ensemble_num, B)
        self.output_check(model.critic, outputs_c)

        # The actor is fed a bare observation tensor (not a dict); a fresh random batch is drawn for it.
        inputs = torch.randn(B, *obs_dims)
        outputs_a = model(inputs, mode='compute_actor')
        assert isinstance(outputs_a, dict)
        # Shape assertions are only made for 1-D action spaces (int or 1-element tuple).
        if len(act_dims) == 1:
            assert outputs_a['logit'][0].shape == (B, *act_dims)
            assert outputs_a['logit'][1].shape == (B, *act_dims)
        outputs = {'mu': outputs_a['logit'][0], 'sigma': outputs_a['logit'][1]}
        self.output_check(model.actor, outputs)
70 changes: 70 additions & 0 deletions ding/model/template/tests/test_ngu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
from itertools import product
import torch
from ding.model.template import NGU
from ding.torch_utils import is_differentiable

# Size of the first dimension of the random inputs built in this test module.
B = 4
# Size of the second dimension of the observation and reward tensors
# (presumably a sequence/time length -- TODO confirm against the NGU model).
H = 4
# Observation shapes to cover: a bare int, a 1-element tuple and a 3-element tuple.
obs_shape = [4, (8, ), (4, 64, 64)]
# Action shapes to cover: a bare int and a 1-element tuple.
act_shape = [4, (4, )]
# Every (obs_shape, act_shape) combination; consumed by ``pytest.mark.parametrize`` below.
args = list(product(*[obs_shape, act_shape]))


@pytest.mark.unittest
class TestNGU:
    """Shape and differentiability checks for the ``NGU`` model template."""

    def output_check(self, model, outputs):
        """
        Overview:
            Reduce ``outputs`` to a scalar by summation and check that it is differentiable w.r.t. ``model``.
        Arguments:
            - model: object forwarded unchanged to ``is_differentiable``.
            - outputs: a ``torch.Tensor``, a list of tensors, or a dict whose values are tensors.
        Raises:
            - TypeError: if ``outputs`` is none of the supported containers. The original fell through here \
                and failed later with an unrelated ``UnboundLocalError`` on ``loss``.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, list):
            loss = sum([t.sum() for t in outputs])
        elif isinstance(outputs, dict):
            loss = sum([v.sum() for v in outputs.values()])
        else:
            raise TypeError("unsupported outputs type: {}".format(type(outputs).__name__))
        is_differentiable(loss, model)

    @pytest.mark.parametrize('obs_shape, act_shape', args)
    def test_ngu(self, obs_shape, act_shape):
        """
        Overview:
            Forward two input dicts through freshly built ``NGU`` models and check, for each, that the output \
            is a dict, that ``logit`` has the expected shape and that it is differentiable w.r.t. the model. \
            The second dict carries the additional ``action`` and ``reward`` keys in place of ``prev_action``.
        """
        # Normalise int shapes to 1-element tuples so that tensor creation and shape checks share one path.
        obs_dims = (obs_shape, ) if isinstance(obs_shape, int) else tuple(obs_shape)
        act_dims = (act_shape, ) if isinstance(act_shape, int) else tuple(act_shape)

        inputs_obs = torch.randn(B, H, *obs_dims)
        inputs_prev_action = torch.ones(B, *act_dims).long()
        inputs_prev_reward_extrinsic = torch.randn(B, H, 1)
        # NOTE(review): the hard-coded [4, 4] coincides with (B, H); confirm which axes NGU expects for
        # ``beta`` before replacing the literals with the module constants.
        inputs_beta = 2 * torch.ones([4, 4], dtype=torch.long)

        inputs_basic = {
            'obs': inputs_obs,
            'prev_state': None,
            'prev_action': inputs_prev_action,
            'prev_reward_extrinsic': inputs_prev_reward_extrinsic,
            'beta': inputs_beta
        }
        inputs_extended = {
            'obs': inputs_obs,
            'prev_state': None,
            'action': inputs_prev_action,
            'reward': inputs_prev_reward_extrinsic,
            'prev_reward_extrinsic': inputs_prev_reward_extrinsic,
            'beta': inputs_beta
        }

        for inputs in (inputs_basic, inputs_extended):
            # A fresh model per input dict, as in the original copy-pasted sequence.
            model = NGU(obs_shape, act_shape, collector_env_num=3)
            outputs = model(inputs)
            assert isinstance(outputs, dict)
            # Shape assertions are only made for 1-D action spaces (int or 1-element tuple).
            if len(act_dims) == 1:
                assert outputs['logit'].shape == (B, *act_dims, *act_dims)
            self.output_check(model, outputs['logit'])
Loading

0 comments on commit 0401412

Please sign in to comment.