Add T5 model #159
Conversation
T5 also does not scale queries, so I added that layer in. It's strange, though: we match PT and Flax until softmax is applied to the attention weights, and then we start to diverge, but I don't think there's anything wrong with Axon's softmax implementation.
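For context, here is a minimal sketch (my own, not the PR's code) of where query scaling enters dot-product attention: the standard variant divides the query by sqrt(head_dim) before the dot product, and T5 simply skips that factor.

```elixir
defmodule AttentionSketch do
  import Nx.Defn

  # query, key: {batch, seq_len, head_dim}
  defn scaled_weights(query, key) do
    weights(query / Nx.sqrt(Nx.axis_size(query, -1)), key)
  end

  # T5 variant: the query is used as-is, without the 1 / sqrt(head_dim) factor
  defn unscaled_weights(query, key) do
    weights(query, key)
  end

  # numerically stable softmax over the key dimension of query . key^T
  defnp weights(query, key) do
    scores = Nx.dot(query, [2], [0], key, [2], [0])
    scores = scores - Nx.reduce_max(scores, axes: [-1], keep_axes: true)
    Nx.exp(scores) / Nx.sum(Nx.exp(scores), axes: [-1], keep_axes: true)
  end
end
```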
Awesome, a couple comments!
lib/bumblebee/text/t5.ex (outdated)

name = opts[:name]

hidden_state
|> Layers.rms_norm(name: join(name, "layer_norm"), epsilon: spec.layer_norm_epsilon)
@seanmor5 this is the output norm actually, so we don't need :output_norm, we just need to remove this here. That's the funny part of transformer implementations: they group the layers differently, but at the end of the day they are mostly the same :D

Also, I think we can add support for ffn: [use_bias: false], defaulting to true, and then we don't need the custom ffn altogether!
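As a rough sketch of what such an option could expand to (the helper and option handling below are hypothetical, not Bumblebee's actual API), the bias toggle would just be forwarded to Axon.dense/3:

```elixir
# Hypothetical helper, not Bumblebee's API: a plain two-layer FFN where the
# :use_bias option (defaulting to true) is forwarded to Axon.dense/3.
defp ffn(hidden_state, intermediate_size, output_size, opts) do
  name = opts[:name]
  use_bias = Keyword.get(opts, :use_bias, true)

  hidden_state
  |> Axon.dense(intermediate_size, use_bias: use_bias, name: join(name, "intermediate"))
  |> Axon.activation(opts[:activation] || :relu, name: join(name, "activation"))
  |> Axon.dense(output_size, use_bias: use_bias, name: join(name, "output"))
end
```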
@jonatanklosko I missed a step in the implementation which uses a custom ffn based on the activation function passed, which I am adding now, so we may still need the custom ffn in that case. Will see how generic I can make it!
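For reference, a rough sketch of what such a gated feed-forward typically looks like in T5-style models (e.g. the gated-gelu variant in T5 v1.1); the helper and layer names are illustrative, not the PR's code:

```elixir
# Illustrative sketch of a gated FFN: one projection is passed through the
# activation and gates the other, then the product is projected back down.
defp gated_ffn(hidden_state, intermediate_size, output_size, activation, opts) do
  name = opts[:name]

  gate =
    hidden_state
    |> Axon.dense(intermediate_size, use_bias: false, name: join(name, "gate"))
    |> Axon.activation(activation, name: join(name, "gate_activation"))

  up = Axon.dense(hidden_state, intermediate_size, use_bias: false, name: join(name, "intermediate"))

  gate
  |> Axon.multiply(up)
  |> Axon.dense(output_size, use_bias: false, name: join(name, "output"))
end
```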
Ah, so this is the output layer, but in this case they use the state before the normalization layer for the shortcut connection. I removed the :output_norm option and changed the shortcut to use the parent, which I think is simpler and better shows the model difference/similarity :)
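A minimal sketch of that pre-norm residual pattern (not the exact PR code): the shortcut branches off the parent state before the RMS norm and is added back after the sub-layer.

```elixir
# `sub_layer/1` stands in for the attention or feed-forward part of the block.
shortcut = hidden_state

hidden_state
|> Layers.rms_norm(name: join(name, "layer_norm"), epsilon: spec.layer_norm_epsilon)
|> sub_layer()
|> Axon.add(shortcut)
```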
@seanmor5 I made some changes, feel free to merge once everything looks good to you. The CI is failing, probably OOM or similar, since I verified that everything passes locally.
Co-authored-by: Jonatan Kłosko <jonatanklosko@gmail.com>
Google's Flan-T5 model is a really good open-source alternative to GPT, and it's based on this T5 implementation. This is still failing; I think it's something to do with the relative attention bias, but I will go back and look at it in a little bit.
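For context, a small sketch of where the relative attention bias enters (not the failing code itself): it is a learned per-head bias over relative query/key positions that is added to the raw attention scores before softmax; in T5 only the first block computes it and later blocks reuse the same tensor.

```elixir
# Sketch only: adding a {1, num_heads, query_len, key_len} relative position
# bias to the attention scores before the softmax.
defmodule RelativeBiasSketch do
  import Nx.Defn

  # scores: {batch, num_heads, query_len, key_len}
  defn biased_weights(scores, position_bias) do
    scores = scores + position_bias
    scores = scores - Nx.reduce_max(scores, axes: [-1], keep_axes: true)
    Nx.exp(scores) / Nx.sum(Nx.exp(scores), axes: [-1], keep_axes: true)
  end
end
```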
I had to change a few things:

- I added a customizable layer norm because T5 uses an RMS norm rather than a layer norm. I will eventually upstream RMS norm into Axon (see the sketch after this list).
- I added a relative_attention_bias option because some T5 blocks rely on relative attention bias.
- I added an output_norm option to control whether or not to apply the final output norm for each block. T5 does not, and instead applies a single norm globally after the blocks.
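A minimal sketch of RMS norm itself, assuming the usual formulation T5 uses (this is not the Axon or Bumblebee code): unlike layer norm it does not subtract the mean and has no bias term, only a learned scale.

```elixir
defmodule RMSNormSketch do
  import Nx.Defn

  # x: {..., hidden_size}, scale: {hidden_size}
  defn rms_norm(x, scale, opts \\ []) do
    opts = keyword!(opts, epsilon: 1.0e-6)

    # mean of squares over the last (hidden) dimension, no mean subtraction
    mean_square = Nx.mean(x * x, axes: [-1], keep_axes: true)

    x * Nx.rsqrt(mean_square + opts[:epsilon]) * scale
  end
end
```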