This commit was created on GitHub.com and signed with GitHub’s verified signature.
* added attn_implementation to the model arguments
* added a check on the concept_value
* set None unit to a default value N/A
* set None value in concept_values to 0.0
* set _supports_sdpa = True in BertPreTrainedModel
* implemented flash attention
* do not overwrite the attention mask when flash attention is enabled
* upgraded huggingface transformers
* updated the logic for splitting heads
* make sure we load the model using the specified torch_dtype
* set the entire model to the corresponding dtype
* removed keyword arguments from hf_cehrgpt
* updated BertSelfFlashAttention.forward to return a tuple because the BERT layer expects such output
* test gpt2 implementation
* test gpt2 implementation
* pass the attn_implementation and torch_dtype to the model during fine-tuning
* set the default value of torch_dtype to auto
* convert age_at_index to the same data type as the bert output
* added logic to convert float32 to the corresponding precision
* removed mlm_skip_values
* updated the unit test after removing mlm_skip_values
* set the default value of torch_dtype to None
* convert concept_value_masks to torch.bool before using it in torch.where
* convert tensors back to the original dtype in the flash attention implementation
* check if torch_dtype is null before trying to get it from torch
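
The torch_dtype handling mentioned in the last few entries can be sketched roughly as follows. This is a minimal illustration and not the code from this commit: the function name `resolve_torch_dtype` and its `torch_module` parameter are hypothetical, and treating "auto" as a pass-through value is an assumption based on the common Hugging Face convention.

```python
from typing import Any, Optional


def resolve_torch_dtype(torch_dtype: Optional[str], torch_module: Any) -> Any:
    """Turn a dtype name such as "bfloat16" into the matching attribute of torch.

    Hypothetical helper: `torch_module` is expected to be the imported `torch`
    package; it is passed in explicitly so the lookup is easy to test.
    """
    # None (the default) and "auto" are returned unchanged, so getattr is
    # never called with a null name.
    if torch_dtype is None or torch_dtype == "auto":
        return torch_dtype
    # Any other value is looked up by name, e.g. "float16" -> torch.float16.
    return getattr(torch_module, torch_dtype)
```

With PyTorch installed, `resolve_torch_dtype("bfloat16", torch)` would return `torch.bfloat16`, and the result could then be passed as the `torch_dtype` argument when loading the model.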