Add T5 model #159

Merged: seanmor5 merged 23 commits into main from sm-t5 on Feb 20, 2023
Conversation

seanmor5 (Contributor) commented on Feb 13, 2023:

Google's Flan-T5 model is a really good open-source alternative to GPT, and it's based on T5, which this PR adds. This is still failing; I think it has something to do with the relative attention bias, but I will go back and take a look in a little bit.

I had to change a few things:

  1. I added a customizable layer norm, because T5 uses an RMS norm rather than a regular layer norm (see the sketch after this list). I will eventually upstream RMS norm into Axon.

  2. I added a relative_attention_bias option, because some T5 blocks rely on a relative attention bias.

  3. I added an output_norm option to control whether or not to apply the final output norm for each block. T5 does not apply it per block and instead applies a single norm globally.
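
For reference, here is a minimal sketch of what RMS norm computes, written in plain Nx; the module and function names are illustrative, not the actual Bumblebee/Axon implementation:

```elixir
defmodule RMSNormSketch do
  import Nx.Defn

  # RMS norm scales by the reciprocal root-mean-square of the features.
  # Unlike layer norm, there is no mean subtraction and no bias term.
  defn rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, epsilon: 1.0e-6)

    variance = Nx.mean(x * x, axes: [-1], keep_axes: true)

    x * Nx.rsqrt(variance + opts[:epsilon]) * weight
  end
end
```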

seanmor5 (Contributor, Author):

T5 also does not scale queries, so I added that layer in. It's strange, though: we match PyTorch and Flax until the application of softmax to the attention weights, and then we start to diverge, but I don't think there's anything wrong with Axon's softmax implementation.
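
For context, this is roughly the difference, sketched with plain Nx/Axon on {seq_len, head_dim} tensors (illustrative names, not the Bumblebee code): standard attention divides the scores by sqrt(head_dim) before the softmax, while T5 simply omits that factor.

```elixir
defmodule AttentionScalingSketch do
  import Nx.Defn

  # Standard scaled dot-product attention weights for {seq_len, head_dim} inputs.
  defn scaled_weights(query, key) do
    depth = Nx.axis_size(query, 1)

    query
    |> Nx.dot([1], key, [1])
    |> Nx.divide(Nx.sqrt(depth))
    |> Axon.Activations.softmax(axis: -1)
  end

  # T5-style weights: same dot product, but no 1 / sqrt(head_dim) factor.
  defn unscaled_weights(query, key) do
    query
    |> Nx.dot([1], key, [1])
    |> Axon.Activations.softmax(axis: -1)
  end
end
```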

jonatanklosko (Member) left a comment:

Awesome, a couple comments!

Inline review comments (since resolved) were left on lib/bumblebee/text/t5.ex, lib/bumblebee/layers/transformer.ex, and test/bumblebee/text/t5_test.exs. The thread below refers to this snippet:
name = opts[:name]

hidden_state
|> Layers.rms_norm(name: join(name, "layer_norm"), epsilon: spec.layer_norm_epsilon)

jonatanklosko (Member):

@seanmor5 this is the output norm actually, so we don't need :output_norm, we just need to remove this here. That's the funny part of transformer implementations: they group the layers differently, but at the end of the day they are mostly the same :D

Also, I think we can add support for ffn: [use_bias: false], defaulting to true, and then we don't need the custom ffn altogether!
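
As a rough sketch of what such an option could expand to (hypothetical helper, assuming `join/2` as used elsewhere and a spec with :intermediate_size, :hidden_size, and :activation fields), the dense layers in the T5 feed-forward simply drop their bias terms:

```elixir
defp ffn(hidden_state, spec, name) do
  hidden_state
  |> Axon.dense(spec.intermediate_size, use_bias: false, name: join(name, "intermediate"))
  |> Axon.activation(spec.activation, name: join(name, "activation"))
  |> Axon.dense(spec.hidden_size, use_bias: false, name: join(name, "output"))
end
```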

seanmor5 (Contributor, Author):

@jonatanklosko I missed a step in the implementation that uses a custom ffn based on the activation function passed, which I am adding now, so we may still need the custom ffn in that case. I will see how generic I can make it!
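
For reference, this is roughly what that activation-dependent variant looks like: newer T5 checkpoints (such as Flan-T5) use a gated feed-forward, where one projection goes through the activation and is multiplied with a second, ungated projection. The sketch below uses hypothetical names and a spec shape that may differ from the final code:

```elixir
defp gated_ffn(hidden_state, spec, name) do
  gate =
    hidden_state
    |> Axon.dense(spec.intermediate_size, use_bias: false, name: join(name, "gate"))
    |> Axon.activation(spec.activation, name: join(name, "activation"))

  up = Axon.dense(hidden_state, spec.intermediate_size, use_bias: false, name: join(name, "up"))

  gate
  |> Axon.multiply(up)
  |> Axon.dense(spec.hidden_size, use_bias: false, name: join(name, "output"))
end
```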

jonatanklosko (Member) commented on Feb 16, 2023:

Ah, so this is the output layer, but in this case they use the state before the normalization layer for the shortcut connection. I removed the :output_norm option and changed the shortcut to use the parent, which I think is simpler and better shows the model difference/similarity :)
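
In other words (a minimal sketch with hypothetical helper names, inside the block builder): the residual branches off before the norm and is added back after the sub-layer.

```elixir
shortcut = hidden_state

hidden_state =
  hidden_state
  |> Layers.rms_norm(name: join(name, "layer_norm"), epsilon: spec.layer_norm_epsilon)
  |> ffn(spec, join(name, "ffn"))
  |> Axon.add(shortcut)
```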

jonatanklosko (Member) commented:

@seanmor5 I made some changes, feel free to merge once everything looks good to you. The CI failure is probably OOM or something similar, since I verified that everything passes locally.

seanmor5 merged commit a2df872 into main on Feb 20, 2023
seanmor5 deleted the sm-t5 branch on February 20, 2023 at 19:53