just make encoder / decoder work, cleanup later
lucidrains committed Jan 12, 2023
1 parent 05f6f01 commit 6860360
Showing 5 changed files with 193 additions and 35 deletions.
2 changes: 1 addition & 1 deletion palm_rlhf_pytorch/__init__.py
@@ -1,4 +1,4 @@
 from palm_rlhf_pytorch.palm import PaLM
 from palm_rlhf_pytorch.palm_enc_dec import PaLMEncDec
 from palm_rlhf_pytorch.reward import RewardModel
-from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic
+from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic, ActorCriticEncDec
18 changes: 10 additions & 8 deletions palm_rlhf_pytorch/palm.py
@@ -479,14 +479,13 @@ def finetune_parameters(self, scope = 'default'):
 
     # default tokens
 
-    @property
-    def default_token_ids(self):
+    def default_token_ids(self, batch = 1):
         device = self.device
 
         if exists(self.default_start_token_id):
-            return torch.full((1, 1), self.default_start_token_id, device = device)
+            return torch.full((batch, 1), self.default_start_token_id, device = device)
 
-        return torch.randint(0, self.num_tokens, (1, 1), device = device)
+        return torch.randint(0, self.num_tokens, (batch, 1), device = device)
 
     # generate function
 
@@ -503,12 +502,15 @@ def generate(
         eos_token = None,
         return_seq_without_prompt = True,
         use_tqdm = False,
+        context = None,
+        batch_size = 1,
         **kwargs
     ):
         assert self.causal
 
         if not exists(prompt):
-            prompt = self.default_token_ids
+            batch_size = context.shape[0] if exists(context) else batch_size
+            prompt = self.default_token_ids(batch = batch_size)
             return_seq_without_prompt = False
 
         prompt, leading_dims = pack([prompt], '* n')
@@ -519,7 +521,7 @@
         sample_num_times = max(1, seq_len - prompt.shape[-1])
 
         for _ in wrapper_fn(range(sample_num_times)):
-            logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
+            logits, embeds = self.forward(out, return_logits_with_embedding = True, context = context, **kwargs)
             logits, embeds = logits[:, -1], embeds[:, -1]
 
             if exists(filter_logits_fn):
@@ -558,7 +560,7 @@ def forward(
         return_only_embedding = False,
         return_logits_with_embedding = False
     ):
-        x = default(prompt, lambda: self.default_token_ids)
+        x = default(prompt, lambda: self.default_token_ids(batch = (context.shape[0] if exists(context) else 1)))
 
         assert not (exists(context) and not self.cross_attend)
 
@@ -585,7 +587,7 @@
         if exists(context):
             context = self.to_cross_attn_key_values(context)
             context = rearrange(context, 'b n (l r d) -> b n l r d', r = 2, l = len(self.layers))
-            cross_attn_key_values = tuple(tuple(tensor.unbind(dim = -2) for tensor in context.unbind(dim = -3)))
+            cross_attn_key_values = tuple(tensor.unbind(dim = -2) for tensor in context.unbind(dim = -3))
 
         # parallel attention / ff blocks, passing in finetuning loras
 
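The palm.py changes thread an encoder context through sampling: generate gains context and batch_size keywords, sizes the default start tokens to the context batch when no prompt is given, and passes context into forward at every decoding step. A minimal usage sketch follows; only the context and batch_size keywords on generate come from this diff, while the constructor arguments (including a cross_attend flag guessed from the self.cross_attend check in forward), seq_len as a generate argument, and all sizes are illustrative assumptions.

import torch
from palm_rlhf_pytorch import PaLM

# assumed constructor arguments; cross_attend is a guess at the flag behind the
# self.cross_attend check asserted in forward above
decoder = PaLM(
    dim = 512,
    num_tokens = 20000,
    depth = 6,
    cross_attend = True
)

# stand-in for encoder output: (batch, source_len, dim) float embeddings
encoder_embeds = torch.randn(2, 128, 512)

# no prompt is passed, so generate builds default start tokens; per the diff,
# batch_size is overridden by context.shape[0] whenever context is given
sampled = decoder.generate(
    seq_len = 64,
    context = encoder_embeds,
    batch_size = 2
)

If context were passed to a model built without cross attention, the assert added in forward above would reject it.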
2 changes: 1 addition & 1 deletion palm_rlhf_pytorch/palm_enc_dec.py
@@ -25,7 +25,6 @@ def __init__(
         num_tokens,
         depth,
         enc_depth = None,
-        dec_default_start_token_id = None,
         causal = True,
         dim_head = 64,
         heads = 8,
@@ -36,6 +35,7 @@
         rotary_xpos_scale_base = 512,
         finetune_scopes = tuple(),
         cross_entropy_ignore_index = 0,
+        dec_default_start_token_id = None,
     ):
         super().__init__()
         self.dim = dim
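In palm_enc_dec.py the only change is moving dec_default_start_token_id to the end of the keyword list, which is cosmetic for keyword callers (only positional calls would notice). A construction sketch under stated assumptions: only the keyword names visible in the two hunks above are assumed to exist, dim is inferred from self.dim = dim, and every value is illustrative.

from palm_rlhf_pytorch import PaLMEncDec

# values are placeholders; keyword names are taken from the hunks above
enc_dec = PaLMEncDec(
    dim = 512,
    num_tokens = 20000,
    depth = 6,                          # decoder depth
    enc_depth = 6,                      # encoder depth, defaults to None in the signature
    causal = True,
    cross_entropy_ignore_index = 0,
    dec_default_start_token_id = 0      # now declared after the other keyword arguments
)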
