Arguments for a simple H3 model that could learn Causal or Masked language modeling? #3
Comments
That is mostly correct, but you would also need to pass in an attn_cfg to specify how many heads you want in the two attention layers, as specified here:
https://github.com/HazyResearch/safari/blob/9ecfaf0e49630b5913fce19adec231b41c2e0e39/configs/experiment/pile/h3.yaml#L17
You can use this as an example of how to construct the model in Python:
https://github.com/HazyResearch/H3/blob/8ebedd61275770b1fca6e0f8a31e642529d8aa97/examples/generate_text_h3.py#L37
The SSMLMHeadModel there has the same constructor as ConvLMHeadModel in this repo.
We haven’t worked too much with masked LM yet, but let us know if you get
something working!
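For instance, a construction along those lines might look like the following. This is a minimal sketch, not verified against the repo: the import path and the num_heads key inside attn_cfg follow the linked H3 example script and config, so treat them as assumptions.
from transformers import GPT2Tokenizer

# Assumed module path for ConvLMHeadModel in this repo; adjust if it lives elsewhere.
from src.models.sequence.long_conv_lm import ConvLMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

model = ConvLMHeadModel(
    d_model=768, n_layer=12, d_inner=768 * 4,
    vocab_size=tokenizer.vocab_size, resid_dropout=0.0, embed_dropout=0.1,
    layer="h3", attn_layer_idx=[1, 8],
    # attn_cfg supplies the head count for the two attention layers; the key name
    # (num_heads) is taken from the H3 generate_text_h3.py example and may differ here.
    attn_cfg=dict(num_heads=12),
    pad_vocab_size_multiple=8,
)
The fused_mlp, fused_dropout_add_ln, and residual_in_fp32 flags from the ConvLMHeadModel snippet in the quoted question below are left at their defaults here, since the fused kernels require the corresponding CUDA extensions to be installed.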
…On Sun, Mar 12, 2023 at 5:56 AM Brendan King ***@***.***> wrote:
Hi all! I've been trying to get the training experiments to run and
struggling with some errors in which hydra cannot parse the configs given
at /experiment/pile/h3 (I was following the instructions in experiments.md for the Pile). I'm actually hoping to train on
a different dataset entirely though, for which I already have a working
pipeline. Given correct installation of the dependencies in this repo, is
there a best way to instantiate an H3 model that would be suitable for
Causal language modeling, and/or for Masked LM?
For example, hoping to come up with something comparable to this, and just
using it in my existing pipeline:
config = AutoConfig.from_pretrained(
    "roberta-base",
    vocab_size=tokenizer.vocab_size,
    random_init=True,
    is_decoder=True
)
model = RobertaForCausalLM(config)
Here is what I have for Causal LM, though I am still trying to sort out
the dependencies to get it to run.
model = ConvLMHeadModel(
    d_model=768, n_layer=12, d_inner=768 * 4,
    vocab_size=tokenizer.vocab_size, resid_dropout=0.0, embed_dropout=0.1,
    layer="h3", attn_layer_idx=[1, 8],
    attn_cfg=None,
    fused_mlp=True,
    fused_dropout_add_ln=True,
    residual_in_fp32=True,
    pad_vocab_size_multiple=8,
)
Are these reasonable choices for these arguments, and/or are there others I would need to specify? Are there important pieces of the training or data preparation specified in the configs that I would need to replicate in another pipeline designed for a HF transformer LM? Sorry if these are obvious/documented; I was just having a hard time reading through the configs.
Awesome, thanks! This worked great. Just for my understanding, is the ConvLMHeadModel a synonym for SSMLMHeadModel, or do they just share the same constructor signature?
I also wanted to ask if there are any necessary steps to apply attention masks/prevent forward leakage of information in causal LM training for this model. Based on my experiments so far, results would suggest there is no such leakage, but I was surprised that no attention_mask or similar arguments are accepted in forward. Apologies if I'm missing something that should be obvious for SSMs/H3; I'm not very familiar with how they work and am mostly trying to assist some students in setting up/debugging the model for one of their experiments.
ConvLMHeadModel is copied from SSMLMHeadModel, with some adaptations to work in the safari repo (e.g., working with this repo's interface into the layers).
For causal language modeling -- we make all the SSMs and convolutions
causal by padding the kernels with zeros, so it's a naturally causal model.
That's why we don't need an attention mask!
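As a toy illustration of that point (generic PyTorch, not the repo's actual code): because the kernel is only defined for non-negative lags and is implicitly zero elsewhere, output position t can only depend on inputs at positions up to t, so there is nothing for a mask to hide.
import torch

L = 8
x = torch.randn(1, L)   # input sequence
k = torch.randn(L)      # kernel for lags 0..L-1; zero for negative lags by construction

# FFT convolution over length 2L avoids circular wrap-around, so this is a causal
# linear convolution: y[t] = sum over s <= t of k[t - s] * x[s].
k_f = torch.fft.rfft(k, n=2 * L)
y = torch.fft.irfft(torch.fft.rfft(x, n=2 * L) * k_f, n=2 * L)[..., :L]

# Causality check: perturbing the input at position 5 leaves outputs before position 5 unchanged.
x2 = x.clone()
x2[0, 5] += 1.0
y2 = torch.fft.irfft(torch.fft.rfft(x2, n=2 * L) * k_f, n=2 * L)[..., :L]
assert torch.allclose(y[0, :5], y2[0, :5], atol=1e-5)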
And we make the attention layers causal as well by default :)
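In generic terms (this is not the repo's attention implementation, just what "causal by default" means), the attention layers apply a lower-triangular mask internally, so no attention_mask argument is needed from the caller:
import torch

T = 6
scores = torch.randn(T, T)  # raw attention scores for a length-T sequence
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))
weights = torch.softmax(scores, dim=-1)  # row t puts zero weight on positions after t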
Hi, I had forgotten to come back to this, but this was very helpful and everything works as expected! Thanks for the help and the insights.