Hi,
I'm trying ILQL training with a GPT-J model that was fine-tuned with this code. I don't get this error with the original pre-trained model, nor with a Flan-XL model.
Traceback (most recent call last):
  File "/home/jupyter/trlx/examples/summarize_rlhf/ilql_gptj.py", line 118, in <module>
    main()
  File "/home/jupyter/trlx/examples/summarize_rlhf/ilql_gptj.py", line 109, in main
    trlx.train(
  File "/home/jupyter/trlx/trlx/trlx.py", line 126, in train
    trainer.learn()
  File "/home/jupyter/trlx/trlx/trainer/accelerate_base_trainer.py", line 539, in learn
    results = self.evaluate()
  File "/home/jupyter/trlx/trlx/trainer/accelerate_base_trainer.py", line 384, in evaluate
    samples = self.generate_eval(prompts["input_ids"], prompts["attention_mask"])
  File "/home/jupyter/trlx/trlx/trainer/accelerate_base_trainer.py", line 276, in generate_eval
    return self.accelerator.unwrap_model(self.model).generate(
  File "/home/jupyter/trlx/trlx/models/modeling_ilql.py", line 307, in generate
    out = self.forward(
  File "/home/jupyter/trlx/trlx/models/modeling_ilql.py", line 263, in forward
    outputs = self.base_model(**forward_kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 854, in forward
    transformer_outputs = self.transformer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 689, in forward
    outputs = block(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 309, in forward
    attn_outputs = self.attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 257, in forward
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 183, in _attn
    attn_output = torch.matmul(attn_weights, value)
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [256, 101] but got: [256, 1].
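The mismatch is between attn_weights and value in GPT-J's attention during evaluation generation: the value tensor arrives with a key/value length of 1 where 101 is expected. Below is a minimal sketch (shapes are hypothetical stand-ins for the attention tensors, not taken from the run) that reproduces the same RuntimeError:

import torch

# Hypothetical shapes standing in for GPT-J's attention tensors
# (batch * heads = 256, query/key length = 101, head_dim = 64).
attn_weights = torch.randn(256, 101, 101)  # (batch*heads, q_len, kv_len)
value = torch.randn(256, 1, 64)            # kv_len is 1 instead of 101

try:
    torch.matmul(attn_weights, value)
except RuntimeError as e:
    # Expected size for first two dimensions of batch2 tensor to be: [256, 101] but got: [256, 1]
    print(e)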
@maxreciprocate Regarding the dataset and training, this is the trlx.train() call I use:
trlx.train(
    samples=[(text, output) for text, output in zip(ttv_ds['train']['text'], ttv_ds['train']['output'])],
    rewards=labels,
    eval_prompts=ttv_ds['validation']['text'][:16],
    config=config,
)
Where:
samples = [(string, string), (string, string), ...]  # list of (prompt, output) tuples
labels = [0, 1, 0, 1, ...]  # list of 0/1 reward labels, one per sample
eval_prompts = [string, string, ...]  # list of strings
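For completeness, here is a self-contained sketch of the same call with toy placeholder data in that format (the strings and labels are hypothetical, and the config path is a placeholder for however the config is loaded in the actual script):

import trlx
from trlx.data.configs import TRLConfig

# Toy stand-ins matching the shapes described above (not the real dataset):
samples = [("prompt 1", "completion 1"), ("prompt 2", "completion 2")]
labels = [0, 1]                          # one 0/1 reward per sample
eval_prompts = ["prompt 1", "prompt 2"]  # list of strings

# Assumed: the config is loaded from a YAML file as in the trlx examples.
config = TRLConfig.load_yaml("configs/ilql_config.yml")

trlx.train(
    samples=samples,
    rewards=labels,
    eval_prompts=eval_prompts,
    config=config,
)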
This is my config:
Thanks.