
Add LlamaForQuestionAnswering #28265

Closed
Nkluge-correa opened this issue Dec 27, 2023 · 8 comments

Comments

@Nkluge-correa

Feature request

Add a LlamaForQuestionAnswering class to modeling_llama.py so that Llama models have AutoModelForQuestionAnswering support (by also adding Llama-style models to MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES in the modeling_auto.py file).

Motivation

1 - Evaluation benchmarks like SQuAD or FaQUAD are commonly used to evaluate language models.
2 - Many decoder-only transformers (BLOOM, Falcon, OpenAI GPT-2, GPT Neo, GPT NeoX, GPT-J, etc.) already have AutoModelForQuestionAnswering support.
3 - Building a fine-tuning/evaluation procedure out of AutoModelForQuestionAnswering and evaluate.load('squad') is very simple, which makes these features very helpful and desirable (see the metric sketch after this list).
4 - Conversely, if one cannot use AutoModelForQuestionAnswering, as is the case for Llama-style models, everything becomes more difficult.
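For reference, here is a minimal sketch of how the squad metric from the evaluate library is typically used; the prediction/reference pair below is made up purely for illustration:

import evaluate

# The "squad" metric expects predictions with an id and a prediction_text,
# and references with an id and an answers dict (text + answer_start).
squad_metric = evaluate.load("squad")
predictions = [{"id": "0", "prediction_text": "Denver Broncos"}]
references = [{"id": "0", "answers": {"text": ["Denver Broncos"], "answer_start": [177]}}]
print(squad_metric.compute(predictions=predictions, references=references))
# -> {'exact_match': 100.0, 'f1': 100.0}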

Hence, I would like to request the addition of a LlamaForQuestionAnswering class to the modeling_llama.py file, so that we could all easily run experiments with Llama models on SQuAD-style Q&A benchmarks:

from transformers import AutoTokenizer, AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v0.3")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v0.3")
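
Once such a class exists, a rough sketch of extractive-QA inference could look like the following (the question/context strings are placeholders, and the checkpoint above has not been fine-tuned for extractive QA, so the decoded span is only meaningful after training):

import torch

question = "Who discovered penicillin?"
context = "Penicillin was discovered in 1928 by Alexander Fleming."

inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Take the most likely start/end token indices and decode the span between them
start_index = outputs.start_logits.argmax(dim=-1).item()
end_index = outputs.end_logits.argmax(dim=-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start_index : end_index + 1])
print(answer)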

Your contribution

I think, as suggested by @nielsr on the forum, we can use GPTJForQuestionAnswering as a starting point, adding a LlamaForQuestionAnswering class to the modeling_llama.py file:

@add_start_docstrings(
    """
    The Llama 2 Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = LlamaModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    # Note: the GPT-J version also carries an `add_code_sample_docstrings` decorator, but the
    # checkpoint constants it references are not defined in modeling_llama.py, so it is omitted here.
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Note: unlike the GPT-J backbone, LlamaModel.forward does not accept
        # token_type_ids or head_mask, so those arguments are not forwarded here.
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If the labels carry an extra dimension (e.g. from multi-GPU gather), squeeze it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1).to(start_logits.device)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1).to(end_logits.device)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
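
As a sanity check, a minimal smoke test of the head above could look as follows (tiny, arbitrary config values; this assumes the class has been added to modeling_llama.py and is importable):

import torch
from transformers import LlamaConfig

config = LlamaConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_labels=2,
)
model = LlamaForQuestionAnswering(config)

input_ids = torch.randint(0, 128, (1, 16))
outputs = model(
    input_ids=input_ids,
    start_positions=torch.tensor([3]),
    end_positions=torch.tensor([7]),
)
# loss is a scalar; start/end logits have shape (batch_size, sequence_length)
print(outputs.loss, outputs.start_logits.shape, outputs.end_logits.shape)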

Then, we add the Llama models to MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES in the modeling_auto.py file:

MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("open-llama", "OpenLlamaModel"),
        ("llama", "LlamaModel"),
        ("code_llama", "LlamaModel"),
...
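
With the mapping in place, the question-answering pipeline should also be able to load Llama checkpoints directly; a hedged usage sketch (checkpoint name reused from above, and the answer is only sensible once the QA head has been fine-tuned):

from transformers import pipeline

qa = pipeline(
    "question-answering",
    model="TinyLlama/TinyLlama-1.1B-Chat-v0.3",
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v0.3",
)
result = qa(
    question="Who discovered penicillin?",
    context="Penicillin was discovered in 1928 by Alexander Fleming.",
)
print(result)  # {'score': ..., 'start': ..., 'end': ..., 'answer': ...}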

I can try to make these changes if no one more qualified wants to take the job 😅.

imaditya123 added a commit to imaditya123/transformers that referenced this issue Dec 27, 2023
Proposed resolution for huggingface#28265.
@Tanmaypatil123
Contributor

Tanmaypatil123 commented Dec 28, 2023

Hey @NielsRogge, I would like to work on this issue.

@nakranivaibhav
Contributor

@ArthurZucker @NielsRogge Is this feature still requested?
I can work on it.

@ArthurZucker
Collaborator

ArthurZucker commented Jan 15, 2024

Hey @nakranivaibhav, as you can see, @Tanmaypatil123 has already started working on it; let's not duplicate work! 🤗 If the PR is not updated in a week or so, feel free to take over, starting from the review I did 😉

@nakranivaibhav
Contributor

@ArthurZucker Alright, I'll keep an 👀 on it.

@nakranivaibhav
Contributor

@ArthurZucker Can I take the issue now?

@ArthurZucker
Collaborator

Sure, just feel free to open a PR and take into account my reviews!

@ghost

ghost commented Feb 23, 2024

Can I have that?

@fasterinnerlooper

Looks like the change has been merged. Can it be closed?
