In [54]:
from bidirectional_mistral import MistralBiModel
from transformers import MistralPreTrainedModel
import torch
import numpy as np
from typing import Optional, List
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import SequenceClassifierOutputWithPast

In [50]:
class MistralForSequenceClassification(MistralPreTrainedModel):
    """Sequence classification / regression head on top of `MistralBiModel`.

    The backbone's hidden state at sequence position 0 is projected by a
    bias-free linear layer (`score`) to `config.num_labels` logits.
    """

    def __init__(self, config):
        """Build the backbone and the scoring head from `config`.

        Reads `config.num_labels` and `config.hidden_size`.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MistralBiModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Run the backbone, pool the first sequence position and score it.

        All arguments except `labels` are forwarded unchanged to the backbone.

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:
            A `SequenceClassifierOutputWithPast` when `return_dict` resolves to true. Otherwise a tuple made of
            `logits` followed by every backbone output after the first, prefixed with `loss` when `labels` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Pool by taking the hidden state at sequence position 0.
        # NOTE(review): this ignores attention_mask, so it assumes position 0 is a real
        # (non-padding) token for every row -- confirm the tokenizer's padding side.
        pooled_output = transformer_outputs[0][:, 0]
        logits = self.score(pooled_output)

        loss = None
        if labels is not None:
            # Keep labels on the same device as the logits so the loss works under
            # device placement / model parallelism.
            labels = labels.to(logits.device)

            # Infer the problem type once from the head size and label dtype, and
            # cache it on the config so later calls take the same branch.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Tuple output: replace the backbone's first element (last hidden state)
            # with the logits and keep everything after it. The previous code read an
            # undefined name `outputs` here and raised NameError.
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

In [13]:
from transformers import AutoTokenizer, AutoConfig
from transformers.trainer_utils import get_last_checkpoint

In [14]:
# Pick the checkpoint sub-directory that `get_last_checkpoint` finds under
# 'mistral-64M-mlm'; the bare name on the last line displays the resolved path.
checkpoint = get_last_checkpoint('mistral-64M-mlm')
checkpoint

'mistral-64M-mlm/checkpoint-65000'

In [15]:
# Load the configuration stored with the checkpoint and request a 6-way output head.
config = AutoConfig.from_pretrained(checkpoint)
config.num_labels = 6
# NOTE(review): nothing visible in this notebook reads `config.vocab`; it may be a
# leftover or a typo -- confirm whether `MistralBiModel` uses it before removing.
config.vocab = 6

In [45]:
# Load the checkpoint weights into the classifier. The `score` head is not in the
# checkpoint, so it is newly initialized (see the warning printed below).
model = MistralForSequenceClassification.from_pretrained(checkpoint, config = config)

Some weights of MistralForSequenceClassification were not initialized from the model checkpoint at mistral-64M-mlm/checkpoint-65000 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [46]:
# Load the tokenizer saved alongside the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [55]:
# Tokenize one toy string and attach a class label so the forward pass also returns a loss.
input_ids = tokenizer(['abdvdf'], return_tensors = 'pt')
# Build the label directly as an int64 tensor: the cross-entropy branch expects long
# class indices, and np.array([1]) is int32 where NumPy's default integer is 32-bit.
input_ids['labels'] = torch.tensor([1], dtype=torch.long)
# forward() has no `token_type_ids` parameter, so drop the key if the tokenizer emits it.
input_ids.pop('token_type_ids', None)
input_ids

{'input_ids': tensor([[1245,   71,   89,   71,   73]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]]), 'labels': tensor([1])}

In [56]:
# Forward pass on the toy batch; `labels` is included, so a loss is computed too.
o = model(**input_ids)

In [57]:
# NOTE(review): displaying the whole output object also prints every past_key_values
# tensor; showing only `o.loss, o.logits` would keep the cell output readable.
o

SequenceClassifierOutputWithPast(loss=tensor(1.6555, grad_fn=<NllLossBackward0>), logits=tensor([[-0.7810,  0.9508, -1.3591,  1.0763,  1.3662,  1.2214]],
       grad_fn=<MmBackward0>), past_key_values=((tensor([[[[-3.8209e-01, -2.0591e-01, -2.8188e-01,  ..., -1.2521e+00,
           -5.7633e-01, -1.1836e+00],
          [-3.6968e-01, -2.6678e-01,  3.9742e-01,  ..., -5.7496e-01,
           -1.1871e+00, -1.3965e+00],
          [-4.8291e-02, -9.1121e-02, -4.9391e-02,  ..., -1.3449e+00,
           -8.1308e-01, -6.4767e-01],
          [-8.4001e-02,  9.5946e-03,  4.6402e-01,  ..., -5.7602e-01,
           -1.1870e+00, -1.3966e+00],
          [ 4.3431e-01, -2.8008e-01,  1.4686e-01,  ..., -9.3143e-01,
           -7.5473e-01, -1.0023e+00]],

         [[ 2.4169e-01, -2.0167e-01,  2.7433e-01,  ...,  3.5047e-01,
           -3.0927e-01, -1.8186e-01],
          [ 4.6138e-01, -1.4437e-01,  2.8775e-01,  ...,  4.9117e-01,
            3.1849e-04,  5.8913e-02],
          [-3.3780e-01, -3.0013e-01,  1.7285e-