feature(whl): add tabmwp env and prompt pg policy #667

Merged: 61 commits, Sep 4, 2023

Commits
c0f0cac  wrong (May 11, 2023)
918171c  update config (May 11, 2023)
de33a2c  update config (May 11, 2023)
8ef6e1a  update command policy (May 15, 2023)
fdb74c0  debug (May 15, 2023)
353de6d  debug (May 15, 2023)
6223146  debug (May 15, 2023)
c2ecc48  debug (May 15, 2023)
6773bee  debug (May 15, 2023)
25f6b3a  debug (May 19, 2023)
9abda20  debug (May 19, 2023)
bfdc122  debug (May 19, 2023)
0510c83  debug (May 19, 2023)
2385df5  debug (May 19, 2023)
4cef99b  debug (May 19, 2023)
f18fafd  add glm (May 19, 2023)
c12b2a2  add glm (May 19, 2023)
dd9589e  add glm model (May 20, 2023)
a783416  add glm model (May 20, 2023)
0bb2df2  add glm model (May 20, 2023)
4335a3f  add glm model (May 20, 2023)
79b2598  add eval return (May 22, 2023)
61e4694  reformat (May 22, 2023)
59f4098  modify action space (May 23, 2023)
c6afc5d  modify action space (May 23, 2023)
9345de6  polish answer process (May 24, 2023)
d89e39a  update policy (May 24, 2023)
e805a0a  update rwkv (May 24, 2023)
1b3f2b4  update policy (May 24, 2023)
40b6c46  polish (May 25, 2023)
e1f7cac  polish (May 25, 2023)
0213f32  Merge branch 'main' of https://github.com/kxzxvbk/DI-engine (May 25, 2023)
c1c22fd  resolve conflict (May 25, 2023)
39e520d  debug prompt pg (Jun 15, 2023)
11bc0ad  add parse (Jun 25, 2023)
8c9c40d  update load env (Jun 26, 2023)
9cac14e  add merge files (Jul 5, 2023)
ff5ad2d  add merge files (Jul 5, 2023)
f6d6ac4  feature(whl): add internlm (Jul 10, 2023)
c716308  feature(whl): add internlm (Jul 10, 2023)
43a8168  update fix parse (Jul 11, 2023)
56068b1  add new dataset (Jul 25, 2023)
d32b64b  fix datafiles (Jul 26, 2023)
1063fbd  polish code (Jul 28, 2023)
286d976  polish env (Jul 28, 2023)
b229216  polish (Aug 1, 2023)
a8fc87b  polish (Aug 1, 2023)
c0ad294  add model wrapper (Aug 1, 2023)
eff4155  polish wrapper (Aug 1, 2023)
00a64e5  polish (Aug 1, 2023)
e17f9d5  remove redundant files (Aug 1, 2023)
73fb2ee  reformat (Aug 1, 2023)
476babf  polish (Aug 11, 2023)
b277505  Merge branch 'main' into gpt3_env (Sep 2, 2023)
8ad299f  polish (Sep 2, 2023)
5e13da9  merge main (Sep 2, 2023)
2f15217  debug (Sep 2, 2023)
f9c2e73  polish readme (Sep 3, 2023)
34044a1  reformat (Sep 3, 2023)
16b826b  polish tabmwp (Sep 4, 2023)
239ff18  test (Sep 4, 2023)
Files changed
1 change: 1 addition & 0 deletions .github/workflows/algo_test.yml
@@ -31,5 +31,6 @@ jobs:
run: |
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install transformers
./ding/scripts/install-k8s-tools.sh
make algotest
1 change: 1 addition & 0 deletions .github/workflows/envpool_test.yml
@@ -24,5 +24,6 @@ jobs:
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install ".[envpool]"
python -m pip install transformers
./ding/scripts/install-k8s-tools.sh
make envpooltest
1 change: 1 addition & 0 deletions .github/workflows/platform_test.yml
@@ -25,5 +25,6 @@ jobs:
run: |
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install transformers
python -m pip uninstall pytest-timeouts -y
make platformtest
2 changes: 2 additions & 0 deletions .github/workflows/unit_test.yml
@@ -25,6 +25,7 @@ jobs:
python -m pip install box2d-py
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install transformers
./ding/scripts/install-k8s-tools.sh
make unittest
- name: Upload coverage to Codecov
@@ -53,5 +54,6 @@ jobs:
run: |
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install transformers
./ding/scripts/install-k8s-tools.sh
make benchmark
100 changes: 51 additions & 49 deletions README.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -5,6 +5,7 @@
from .vac import VAC, DREAMERVAC
from .bc import DiscreteBC, ContinuousBC
from .pg import PG
from .language_transformer import LanguageTransformer
# algorithm-specific
from .ppg import PPG
from .qmix import Mixer, QMix
63 changes: 63 additions & 0 deletions ding/model/template/language_transformer.py
@@ -0,0 +1,63 @@
import torch

from ding.utils import MODEL_REGISTRY
from torch import nn
try:
from transformers import AutoTokenizer, AutoModelForTokenClassification
except ImportError:
import sys
from ditk import logging
logging.warning("transformers package not found; install it via: pip install transformers")
sys.exit(1)


@MODEL_REGISTRY.register('language_transformer')
class LanguageTransformer(nn.Module):

def __init__(
self,
model_name: str = "bert-base-uncased",
add_linear: bool = False,
embedding_size: int = 128,
freeze_encoder: bool = True
) -> None:
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForTokenClassification.from_pretrained(model_name)

# Freeze transformer encoder and only train the linear layer
if freeze_encoder:
for param in self.model.parameters():
param.requires_grad = False

if add_linear:
# Add an additional small, adjustable linear layer on top of BERT tuned through RL
self.embedding_size = embedding_size
self.linear = nn.Linear(
self.model.config.hidden_size, embedding_size
) # 768 for bert-base-uncased, distilbert-base-uncased
else:
self.linear = None

def _calc_embedding(self, x: list) -> torch.Tensor:
# ``truncation=True`` truncates any prompt longer than the tokenizer's ``max_length``; ``padding=True``
# pads shorter prompts up to the longest sequence in the batch. Together they give every encoded prompt
# in the batch the same length, which enables batch-wise computation.
input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
output = self.model(**input, output_hidden_states=True)
# Get last layer hidden states
last_hidden_states = output.hidden_states[-1]
# Get [CLS] hidden states
sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size

if self.linear:
sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size

return sentence_embedding

def forward(self, train_samples: list, candidate_samples: list) -> dict:
prompt_embedding = self._calc_embedding(train_samples)
cands_embedding = self._calc_embedding(candidate_samples)
scores = torch.mm(prompt_embedding, cands_embedding.t())
return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
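
For reference, a minimal usage sketch of the new model (a sketch, not part of the diff: it assumes the `bert-base-uncased` weights can be fetched by `transformers`, and the sample strings are invented):

from ding.model.template.language_transformer import LanguageTransformer

# Score one prompt against two candidate examples; both arguments are lists of raw strings.
model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=128)
out = model(["Question: what is 2 + 3?"], ["Candidate example A", "Candidate example B"])
print(out['logit'].shape)      # torch.Size([1, 2]): one score per (prompt, candidate) pair
action = out['dist'].sample()  # Categorical over candidates, built from the same scores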
25 changes: 25 additions & 0 deletions ding/model/template/tests/test_language_transformer.py
@@ -0,0 +1,25 @@
import pytest

from ding.model.template.language_transformer import LanguageTransformer


@pytest.mark.unittest
class TestNLPPretrainedModel:

def check_model(self):
test_pids = [1]
cand_pids = [0, 2, 4]
problems = [
"This is problem 0", "This is the first question", "Second problem is here", "Another problem",
"This is the last problem"
]
ctxt_list = [problems[pid] for pid in test_pids]
cands_list = [problems[pid] for pid in cand_pids]

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
output = model(ctxt_list, cands_list)
assert output['logit'].shape == (1, 3)

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
output = model(ctxt_list, cands_list)
assert output['logit'].shape == (1, 3)
49 changes: 49 additions & 0 deletions ding/model/wrapper/model_wrappers.py
@@ -411,6 +411,53 @@ def forward(self, *args, **kwargs):
return output


class CombinationArgmaxSampleWrapper(IModelWrapper):
r"""
Overview:
Used to help the model to sample combination argmax action.
"""

def forward(self, shot_number, *args, **kwargs):
output = self._model.forward(*args, **kwargs)
# Generate actions.
act = []
mask = torch.zeros_like(output['logit'])
for ii in range(shot_number):
masked_logit = output['logit'] + mask
actions = masked_logit.argmax(dim=-1)
act.append(actions)
for jj in range(actions.shape[0]):
mask[jj][actions[jj]] = -1e8
# `act` is shaped: (B, shot_number)
act = torch.stack(act, dim=1)
output['action'] = act
return output
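
The masking loop above is a greedy selection without replacement; a small standalone check (a sketch, not part of the PR) shows it picks the same indices as torch.topk when the logits are tie-free:

import torch

logit = torch.randn(4, 6)                    # (B, N) scores; ties are virtually impossible
mask = torch.zeros_like(logit)
picks = []
for _ in range(2):                           # shot_number = 2
    actions = (logit + mask).argmax(dim=-1)  # best remaining index per row
    picks.append(actions)
    mask[torch.arange(4), actions] = -1e8    # block the chosen index next round
picks = torch.stack(picks, dim=1)            # (B, shot_number)
assert torch.equal(picks, logit.topk(2, dim=-1).indices)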


class CombinationMultinomialSampleWrapper(IModelWrapper):
r"""
Overview:
Used to help the model to sample combination multinomial action.
"""

def forward(self, shot_number, *args, **kwargs):
output = self._model.forward(*args, **kwargs)
# Generate actions.
act = []
mask = torch.zeros_like(output['logit'])
for ii in range(shot_number):
dist = torch.distributions.Categorical(logits=output['logit'] + mask)
actions = dist.sample()
act.append(actions)
for jj in range(actions.shape[0]):
mask[jj][actions[jj]] = -1e8

# `act` is shaped: (B, shot_number)
act = torch.stack(act, dim=1)
output['action'] = act
return output


class HybridArgmaxSampleWrapper(IModelWrapper):
r"""
Overview:
@@ -906,6 +953,8 @@ def __init__(self, model, teacher_cfg):
# model wrapper
'target': TargetNetworkWrapper,
'teacher': TeacherNetworkWrapper,
'combination_argmax_sample': CombinationArgmaxSampleWrapper,
'combination_multinomial_sample': CombinationMultinomialSampleWrapper,
}


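A hedged usage sketch for the two newly registered wrappers (DummyScorer below is an invented stand-in for LanguageTransformer, not part of the PR):

import torch
import torch.nn as nn
from ding.model import model_wrap

class DummyScorer(nn.Module):
    # Stand-in for a policy model whose forward returns a 'logit' tensor of shape (B, N).
    def forward(self, inputs):
        return {'logit': torch.randn(4, 6)}

wrapped = model_wrap(DummyScorer(), wrapper_name='combination_multinomial_sample')
out = wrapped.forward(shot_number=2, inputs=None)
print(out['action'].shape)  # torch.Size([4, 2]); the two indices in each row are distinct

The PR's own tests below exercise the same API with ActorMLP.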
14 changes: 14 additions & 0 deletions ding/model/wrapper/test_model_wrappers.py
@@ -549,3 +549,17 @@ def test_transformer_memory_wrapper(self):
assert sum(new_memory2[:, -16:].flatten()) != 0
assert sum(new_memory2[:, :-16].flatten()) == 0
assert torch.all(torch.eq(new_memory1[:, -8:], new_memory2[:, -16:-8]))

def test_combination_argmax_sample_wrapper(self):
model = model_wrap(ActorMLP(), wrapper_name='combination_argmax_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(shot_number=2, inputs=data)
assert output['action'].shape == (4, 2)
assert (output['action'] >= 0).all() and (output['action'] < 64).all()

def test_combination_multinomial_sample_wrapper(self):
model = model_wrap(ActorMLP(), wrapper_name='combination_multinomial_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(shot_number=2, inputs=data)
assert output['action'].shape == (4, 2)
assert (output['action'] >= 0).all() and (output['action'] < 64).all()
1 change: 1 addition & 0 deletions ding/policy/__init__.py
@@ -54,3 +54,4 @@

# new-type policy
from .ppof import PPOFPolicy
from .prompt_pg import PromptPGPolicy
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
100755 → 100644
@@ -50,6 +50,7 @@
from .bdq import BDQPolicy
from .bcq import BCQPolicy
from .edac import EDACPolicy
from .prompt_pg import PromptPGPolicy


class EpsCommandModePolicy(CommandModePolicy):
@@ -438,3 +439,8 @@ def _get_setting_learn(self, command_info: dict) -> dict:

def _get_setting_eval(self, command_info: dict) -> dict:
return {}


@POLICY_REGISTRY.register('prompt_pg_command')
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
pass
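
To tie the pieces together, a hypothetical config fragment in DI-engine's usual style (the policy type string and defaults are assumptions; the actual tabmwp config shipped with this PR may differ):

from easydict import EasyDict

tabmwp_prompt_pg_config = EasyDict(dict(
    policy=dict(
        type='prompt_pg',  # assumed registry name for PromptPGPolicy
        model=dict(
            model_name='bert-base-uncased',
            add_linear=True,
            embedding_size=128,
            freeze_encoder=True,
        ),
    ),
))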