feature(whl): add tabmwp env and prompt pg policy (#667)
* wrong

* update config

* update command policy

* debug

* debug

* add glm

* add glm

* add glm model

* add eval return

* reformat

* modify action space

* modify action space

* polish answer process

* update policy

* update rwkv

* update policy

* polish

* polish

* debug prompt pg

* add parse

* update load env

* add merge files

* add merge files

* feature(whl): add internlm

* feature(whl): add internlm

* update fix parse

* add new dataset

* fix datafiles

* polish code

* polish env

* polish

* polish

* add model wrapper

* polish wrapper

* polish

* remove redundant files

* reformat

* polish

* debug

* polish readme

* reformat

* polish tabmwp

* test
kxzxvbk committed Sep 4, 2023
1 parent efa59b2 commit 3659d81
Showing 21 changed files with 1,107 additions and 49 deletions.
1 change: 1 addition & 0 deletions .github/workflows/algo_test.yml
@@ -31,5 +31,6 @@ jobs:
        run: |
          python -m pip install .
          python -m pip install ".[test,k8s]"
          python -m pip install transformers
          ./ding/scripts/install-k8s-tools.sh
          make algotest
1 change: 1 addition & 0 deletions .github/workflows/envpool_test.yml
@@ -24,5 +24,6 @@ jobs:
          python -m pip install .
          python -m pip install ".[test,k8s]"
          python -m pip install ".[envpool]"
          python -m pip install transformers
          ./ding/scripts/install-k8s-tools.sh
          make envpooltest
1 change: 1 addition & 0 deletions .github/workflows/platform_test.yml
@@ -25,5 +25,6 @@ jobs:
        run: |
          python -m pip install .
          python -m pip install ".[test,k8s]"
          python -m pip install transformers
          python -m pip uninstall pytest-timeouts -y
          make platformtest
2 changes: 2 additions & 0 deletions .github/workflows/unit_test.yml
@@ -25,6 +25,7 @@ jobs:
          python -m pip install box2d-py
          python -m pip install .
          python -m pip install ".[test,k8s]"
          python -m pip install transformers
          ./ding/scripts/install-k8s-tools.sh
          make unittest
      - name: Upload coverage to Codecov
@@ -53,5 +54,6 @@ jobs:
        run: |
          python -m pip install .
          python -m pip install ".[test,k8s]"
          python -m pip install transformers
          ./ding/scripts/install-k8s-tools.sh
          make benchmark
100 changes: 51 additions & 49 deletions README.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -5,6 +5,7 @@
from .vac import VAC, DREAMERVAC
from .bc import DiscreteBC, ContinuousBC
from .pg import PG
from .language_transformer import LanguageTransformer
# algorithm-specific
from .ppg import PPG
from .qmix import Mixer, QMix
63 changes: 63 additions & 0 deletions ding/model/template/language_transformer.py
@@ -0,0 +1,63 @@
import torch

from ding.utils import MODEL_REGISTRY
from torch import nn
try:
    from transformers import AutoTokenizer, AutoModelForTokenClassification
except ImportError:
    import sys
    from ditk import logging
    logging.warning("transformers not found, please install it using: pip install transformers")
    sys.exit(1)


@MODEL_REGISTRY.register('language_transformer')
class LanguageTransformer(nn.Module):

    def __init__(
            self,
            model_name: str = "bert-base-uncased",
            add_linear: bool = False,
            embedding_size: int = 128,
            freeze_encoder: bool = True
    ) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)

        # Freeze the transformer encoder and only train the linear layer.
        if freeze_encoder:
            for param in self.model.parameters():
                param.requires_grad = False

        if add_linear:
            # Add a small, adjustable linear layer on top of BERT, tuned through RL.
            self.embedding_size = embedding_size
            self.linear = nn.Linear(
                self.model.config.hidden_size, embedding_size
            )  # 768 for bert-base-uncased, distilbert-base-uncased
        else:
            self.linear = None

    def _calc_embedding(self, x: list) -> torch.Tensor:
        # ``truncation=True`` means that if the length of a prompt exceeds the tokenizer's ``max_length``,
        # the exceeding part is truncated. ``padding=True`` pads shorter prompts in the batch, so every
        # encoded sequence in the batch has the same length, which enables batch-wise computation.
        input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
        output = self.model(**input, output_hidden_states=True)
        # Get the last layer's hidden states.
        last_hidden_states = output.hidden_states[-1]
        # Get the [CLS] hidden state as the sentence embedding.
        sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size

        if self.linear:
            sentence_embedding = self.linear(sentence_embedding)  # len(input_list) x embedding_size

        return sentence_embedding

    def forward(self, train_samples: list, candidate_samples: list) -> dict:
        prompt_embedding = self._calc_embedding(train_samples)
        cands_embedding = self._calc_embedding(candidate_samples)
        scores = torch.mm(prompt_embedding, cands_embedding.t())
        return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
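
Note: a minimal usage sketch of the new model (not part of the commit; it assumes transformers is installed and the bert-base-uncased weights can be downloaded, and the prompt strings are purely illustrative). The model embeds the context prompts and the candidate prompts separately and scores every (context, candidate) pair with a dot product, so ``logit`` has shape ``(len(contexts), len(candidates))``:

from ding.model.template.language_transformer import LanguageTransformer

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
contexts = ["Q: There are 3 red and 5 blue balls. How many balls in total?"]  # prompts to solve
candidates = ["Example problem A", "Example problem B", "Example problem C"]  # candidate in-context examples
output = model(contexts, candidates)
assert output['logit'].shape == (1, 3)  # one context scored against three candidates
choice = output['dist'].sample()        # index of the candidate picked for the prompt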
25 changes: 25 additions & 0 deletions ding/model/template/tests/test_language_transformer.py
@@ -0,0 +1,25 @@
import pytest

from ding.model.template.language_transformer import LanguageTransformer


@pytest.mark.unittest
class TestNLPPretrainedModel:

    def check_model(self):
        test_pids = [1]
        cand_pids = [0, 2, 4]
        problems = [
            "This is problem 0", "This is the first question", "Second problem is here", "Another problem",
            "This is the last problem"
        ]
        ctxt_list = [problems[pid] for pid in test_pids]
        cands_list = [problems[pid] for pid in cand_pids]

        model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
        # forward returns a dict with 'dist' and 'logit'; the logits have shape (len(ctxt_list), len(cands_list)).
        output = model(ctxt_list, cands_list)
        assert output['logit'].shape == (1, 3)

        model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
        output = model(ctxt_list, cands_list)
        assert output['logit'].shape == (1, 3)
49 changes: 49 additions & 0 deletions ding/model/wrapper/model_wrappers.py
@@ -411,6 +411,53 @@ def forward(self, *args, **kwargs):
        return output


class CombinationArgmaxSampleWrapper(IModelWrapper):
    r"""
    Overview:
        Used to help the model sample a combination of actions: ``shot_number`` actions are picked greedily by
        repeated argmax, masking out already-selected actions so none is chosen twice.
    """

    def forward(self, shot_number, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        # Generate actions.
        act = []
        mask = torch.zeros_like(output['logit'])
        for ii in range(shot_number):
            masked_logit = output['logit'] + mask
            actions = masked_logit.argmax(dim=-1)
            act.append(actions)
            for jj in range(actions.shape[0]):
                mask[jj][actions[jj]] = -1e8
        # `act` is shaped: (B, shot_number)
        act = torch.stack(act, dim=1)
        output['action'] = act
        return output


class CombinationMultinomialSampleWrapper(IModelWrapper):
    r"""
    Overview:
        Used to help the model sample a combination of actions: ``shot_number`` actions are drawn one by one from
        a categorical distribution over the logits, masking out already-selected actions so the final combination
        contains no duplicates.
    """

    def forward(self, shot_number, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        # Generate actions.
        act = []
        mask = torch.zeros_like(output['logit'])
        for ii in range(shot_number):
            dist = torch.distributions.Categorical(logits=output['logit'] + mask)
            actions = dist.sample()
            act.append(actions)
            for jj in range(actions.shape[0]):
                mask[jj][actions[jj]] = -1e8

        # `act` is shaped: (B, shot_number)
        act = torch.stack(act, dim=1)
        output['action'] = act
        return output


class HybridArgmaxSampleWrapper(IModelWrapper):
    r"""
    Overview:
@@ -906,6 +953,8 @@ def __init__(self, model, teacher_cfg):
    # model wrapper
    'target': TargetNetworkWrapper,
    'teacher': TeacherNetworkWrapper,
    'combination_argmax_sample': CombinationArgmaxSampleWrapper,
    'combination_multinomial_sample': CombinationMultinomialSampleWrapper,
}
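
Note: a standalone sanity-check sketch (not part of the commit; toy logits, assumes no ties). The repeated masked argmax in ``CombinationArgmaxSampleWrapper`` picks, for each row, the ``shot_number`` highest-scoring actions without replacement, i.e. the same indices a per-row top-k would return:

import torch

logits = torch.tensor([[0.1, 2.0, -1.0, 0.5], [1.5, 0.0, 3.0, -2.0]])
shot_number = 2

# Repeated masked argmax, as in CombinationArgmaxSampleWrapper.
mask = torch.zeros_like(logits)
act = []
for _ in range(shot_number):
    actions = (logits + mask).argmax(dim=-1)
    act.append(actions)
    for jj in range(actions.shape[0]):
        mask[jj][actions[jj]] = -1e8
act = torch.stack(act, dim=1)

# Same result as a per-row top-k without replacement (ignoring ties).
assert torch.equal(act, logits.topk(shot_number, dim=-1).indices)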


14 changes: 14 additions & 0 deletions ding/model/wrapper/test_model_wrappers.py
@@ -549,3 +549,17 @@ def test_transformer_memory_wrapper(self):
        assert sum(new_memory2[:, -16:].flatten()) != 0
        assert sum(new_memory2[:, :-16].flatten()) == 0
        assert torch.all(torch.eq(new_memory1[:, -8:], new_memory2[:, -16:-8]))

    def test_combination_argmax_sample_wrapper(self):
        model = model_wrap(ActorMLP(), wrapper_name='combination_argmax_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(shot_number=2, inputs=data)
        # ``action`` stacks ``shot_number`` picks per sample, so its shape is (B, shot_number).
        assert output['action'].shape == (4, 2)
        assert (output['action'] >= 0).all() and (output['action'] < 64).all()

    def test_combination_multinomial_sample_wrapper(self):
        model = model_wrap(ActorMLP(), wrapper_name='combination_multinomial_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(shot_number=2, inputs=data)
        assert output['action'].shape == (4, 2)
        assert (output['action'] >= 0).all() and (output['action'] < 64).all()
1 change: 1 addition & 0 deletions ding/policy/__init__.py
@@ -54,3 +54,4 @@

# new-type policy
from .ppof import PPOFPolicy
from .prompt_pg import PromptPGPolicy
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
100755 → 100644
@@ -50,6 +50,7 @@
from .bdq import BDQPolicy
from .bcq import BCQPolicy
from .edac import EDACPolicy
from .prompt_pg import PromptPGPolicy


class EpsCommandModePolicy(CommandModePolicy):
@@ -438,3 +439,8 @@ def _get_setting_learn(self, command_info: dict) -> dict:

    def _get_setting_eval(self, command_info: dict) -> dict:
        return {}


@POLICY_REGISTRY.register('prompt_pg_command')
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
    pass
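
Note: a hypothetical config fragment (not part of this commit; the actual TabMWP configs added by the PR are not shown in the rendered diff, and every field below is an illustrative assumption except 'language_transformer', which is registered in this diff) showing how the registered pieces would typically be wired together in a DI-engine experiment:

# Hypothetical sketch only -- not the config shipped with this commit.
policy = dict(
    # Presumably registered as 'prompt_pg' in ding/policy/prompt_pg.py; only the
    # 'prompt_pg_command' command-mode registration is visible in this diff.
    type='prompt_pg',
    model=dict(
        type='language_transformer',  # registered in MODEL_REGISTRY by this commit
        model_name='bert-base-uncased',
        add_linear=True,
        embedding_size=256,
    ),
)
# At collect/eval time the model would be wrapped with one of the new wrappers, e.g.:
# model = model_wrap(model, wrapper_name='combination_multinomial_sample')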