A follow-up repository to Jax-Journey. This repository provides a selection of notebooks for various NLP tasks, each fully transparent: you can trace the implementation down to the basic Jax/Haiku modules within a single notebook. They are meant to serve as further Jax tutorials for NLP, and as a guide to the coding style followed in this awesome article by Madison May.
These notebooks, although mostly code, also highlight the nuanced features that are often missed when using off-the-shelf models, and they let you optimize everything down to the innermost modules. Each notebook also explains how to adapt the model to your own use case.
A basic introductory notebook covering both the original RoBERTa-initialized version and a randomly initialized version.
Here we realise the need to restructure the code, and correspondingly place it component-wise in src/. The new things we implement over the original are:
- The masking function for MLM, here.
- A HuggingFace Tokenizers based tokenizer, here.
- A language embedding for the TLM task, here.
- Additionally, an option to make the transformer auto-regressive, with a corresponding mask, here. This is needed for CLM.
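The MLM masking above follows the standard BERT/RoBERTa recipe. As a rough sketch of the idea (not the repo's exact code — `MASK_ID`, `VOCAB_SIZE`, the 15% selection rate, and the `-100` ignore label are all assumptions for illustration):

```python
import jax
import jax.numpy as jnp

MASK_ID = 4         # hypothetical [MASK] token id (assumption)
VOCAB_SIZE = 50265  # RoBERTa-sized vocabulary (assumption)

def mask_tokens(rng, token_ids, mask_prob=0.15):
    """BERT-style MLM corruption: select ~15% of positions; of those,
    80% become [MASK], 10% become a random token, 10% stay unchanged."""
    select_rng, strat_rng, rand_rng = jax.random.split(rng, 3)
    # which positions participate in the MLM loss
    selected = jax.random.bernoulli(select_rng, mask_prob, token_ids.shape)
    # per-position corruption strategy
    strategy = jax.random.uniform(strat_rng, token_ids.shape)
    random_ids = jax.random.randint(rand_rng, token_ids.shape, 0, VOCAB_SIZE)
    corrupted = jnp.where(strategy < 0.8, MASK_ID,
                jnp.where(strategy < 0.9, random_ids, token_ids))
    out = jnp.where(selected, corrupted, token_ids)
    # labels: original ids at selected positions, -100 elsewhere (ignored in loss)
    labels = jnp.where(selected, token_ids, -100)
    return out, labels
```

Keeping the masking as a pure function of an explicit PRNG key, as above, is what lets it compose cleanly with `jax.jit` and `jax.vmap`.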
The final notebook can be found here.
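The auto-regressive option mentioned above boils down to intersecting the usual padding mask with a lower-triangular causal mask. A minimal sketch (the function names and the `[batch, 1, query, key]` mask layout are assumptions, not necessarily what the notebook uses):

```python
import jax.numpy as jnp

def causal_mask(seq_len):
    """Lower-triangular mask: position i may attend only to positions <= i."""
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

def combine_masks(attention_mask, autoregressive=False):
    """Turn a [batch, seq] padding mask into a [batch, 1, seq, seq]
    attention mask, optionally ANDed with the causal mask for CLM."""
    batch, seq_len = attention_mask.shape
    mask = attention_mask[:, None, None, :].astype(bool)
    if autoregressive:
        mask = mask & causal_mask(seq_len)[None, None, :, :]
    return mask
```

With `autoregressive=False` this reduces to the ordinary bidirectional (MLM-style) padding mask, so the same transformer code can serve both objectives.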
Hoping to create a kind, giving, and open community that forms deep connections by working together. Join us here: https://discord.gg/s6xSHG94u5