In [1]:
# Early Exit NLP

In [2]:
!pip install pip install --upgrade-strategy eager "optimum[openvino,nncf]"



In [3]:
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from transformers.modeling_outputs import QuestionAnsweringModelOutput
from transformers.modeling_utils import PreTrainedModel

2023-09-27 20:19:48.239759: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-09-27 20:19:48.241588: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-09-27 20:19:48.280591: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-09-27 20:19:48.281343: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
def entropy(x):
    """Return the Shannon entropy (in nats) of ``softmax(x)`` taken along dim 1.

    The log-probabilities come from ``log_softmax`` rather than from
    ``log(softmax(x))``: when a probability underflows to exactly 0 the latter
    evaluates ``0 * log(0) = 0 * -inf = NaN``, which makes every later
    ``entropy < threshold`` comparison False for very peaked distributions.

    Args:
        x: Tensor of logits with at least two dimensions; dim 1 is the axis
            the distribution is normalised over.

    Returns:
        Tensor with dim 1 reduced away, holding one entropy per distribution.
    """
    log_probs = nn.functional.log_softmax(x, dim=1)
    # exp(log_softmax) is the softmax itself; an underflowed probability now
    # multiplies a finite log-probability and contributes exactly 0.
    return -torch.sum(log_probs.exp() * log_probs, dim=1)


def activate_ee(logits, threshold):
    """Decide whether an early exit should fire for these logits.

    Args:
        logits: Logits handed to ``entropy``.
        threshold: Entropy value that every entry must stay strictly below.

    Returns:
        A scalar boolean tensor, True only when every computed entropy is
        below ``threshold``.
    """
    entropies = entropy(logits)
    is_confident = entropies < threshold
    return is_confident.all()


# Taken from HuggingFace transformers
def create_extended_attention_mask(attention_mask, dtype):
    """Convert a 0/1 attention mask into a 4-D additive mask.

    Args:
        attention_mask: Mask tensor where 1 marks a position to keep. A 3-D
            mask gains one axis after dim 0; anything else gains two.
        dtype: Floating dtype of the result (the mask is cast to it, which
            keeps fp16 usable).

    Returns:
        Tensor of ``dtype`` that is 0 where the mask was 1 and
        ``torch.finfo(dtype).min`` where it was 0.
    """
    if attention_mask.dim() == 3:
        expanded = attention_mask[:, None, :, :]
    else:
        expanded = attention_mask[:, None, None, :]
    keep = expanded.to(dtype=dtype)
    # Kept positions (1) become 0; dropped positions (0) become the most negative value.
    most_negative = torch.finfo(dtype).min
    return (1.0 - keep) * most_negative


def compute_qa_loss(logits, start_positions, end_positions):
    """Average the start and end cross-entropy losses for extractive QA.

    Args:
        logits: Tensor whose last dimension has size 2 (start logit, end logit).
        start_positions: Target start indices, or None.
        end_positions: Target end indices, or None.

    Returns:
        ``0`` when either position tensor is None; otherwise the mean of the
        start and end cross-entropy losses as a scalar tensor.
    """
    start_logits, end_logits = (
        part.squeeze(-1).contiguous() for part in logits.split(1, dim=-1)
    )

    if start_positions is None or end_positions is None:
        return 0

    # Position tensors may carry an extra trailing dimension (e.g. from multi-GPU splitting).
    if start_positions.dim() > 1:
        start_positions = start_positions.squeeze(-1)
    if end_positions.dim() > 1:
        end_positions = end_positions.squeeze(-1)

    # Targets outside the input are clamped to one-past-the-end, which the
    # loss is then told to ignore.
    out_of_range = start_logits.size(1)
    criterion = nn.CrossEntropyLoss(ignore_index=out_of_range)
    start_loss = criterion(start_logits, start_positions.clamp(0, out_of_range))
    end_loss = criterion(end_logits, end_positions.clamp(0, out_of_range))
    return (start_loss + end_loss) / 2

In [5]:
class RampClassifier(nn.Module):
    """Linear head that maps each hidden vector to two logits."""

    def __init__(self, config):
        """Create the head; only ``config.hidden_size`` is read from ``config``."""
        super().__init__()
        # Two outputs per hidden vector.
        self.classifier = nn.Linear(in_features=config.hidden_size, out_features=2)

    def forward(self, hidden_state):
        """Apply the linear layer to ``hidden_state`` and return the logits."""
        return self.classifier(hidden_state)


# Trainable model wrapper
class EEQATrainableModel(PreTrainedModel):
    """Extractive-QA model with an early-exit ramp classifier after every transformer layer.

    The embeddings and encoder layers are taken from the pretrained checkpoint
    named by ``config._name_or_path``, and one ``RampClassifier`` is attached
    per hidden layer. Outside training, the forward pass stops at the first
    layer whose ramp logits satisfy ``activate_ee`` for the current entropy
    threshold.
    """
    config_class = AutoConfig

    def __init__(self, config):
        """Load the backbone described by ``config`` and build one ramp head per layer."""
        super().__init__(config)
        orig_model = AutoModel.from_pretrained(config._name_or_path)

        self.embeddings = orig_model.embeddings
        self.transformer_layers = orig_model.encoder.layer
        self.ramp_classifiers = nn.ModuleList([RampClassifier(config) for _ in range(config.num_hidden_layers)])

        # Early-exit entropy threshold. The exit test is a strict "<" and
        # entropy is never negative, so the initial 0.0 never triggers an exit.
        self.ee_entropy = 0.0

    def get_entropy_threshold(self):
        """Return the entropy threshold currently used for early exit."""
        return self.ee_entropy

    def set_entropy_threshold(self, ee_entropy):
        """Set the entropy threshold used for early exit at inference time."""
        self.ee_entropy = ee_entropy

    def forward(self, input_ids=None, attention_mask=None, start_positions=None, end_positions=None, training=False, training_phase=1, token_type_ids=None, **kwargs):
        """Run the encoder layer by layer, computing a loss or exiting early.

        Args:
            input_ids: Token ids passed to the embedding module.
            attention_mask: 0/1 mask (1 = attend). When None, every position
                is attended.
            start_positions: Target answer start indices; only used for the loss.
            end_positions: Target answer end indices; only used for the loss.
            training: When True, accumulate the loss selected by
                ``training_phase`` and run every layer. When False, stop at the
                first layer whose ramp passes the entropy test.
            training_phase: ``1`` uses only the last layer's ramp loss; ``2``
                sums the ramp losses of every layer except the last.
            token_type_ids: Optional segment ids. Forwarded to the embedding
                module only when provided, so embedding modules without that
                argument keep working.
            **kwargs: Accepted for call compatibility and otherwise ignored.

        Returns:
            ``QuestionAnsweringModelOutput`` whose ``start_logits`` /
            ``end_logits`` come from the ramp of the last layer that was run,
            and whose ``loss`` is the accumulated loss (``0.0`` if none was
            computed).
        """
        # Previously segment ids were swallowed by **kwargs and never reached
        # the embeddings; pass them through when the caller supplies them.
        embedding_inputs = {"input_ids": input_ids}
        if token_type_ids is not None:
            embedding_inputs["token_type_ids"] = token_type_ids
        embeddings_output = self.embeddings(**embedding_inputs)

        # The signature advertises attention_mask=None, so fall back to
        # attending everywhere instead of crashing on None.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Additive float32 mask: 0 for kept positions, dtype-min for masked ones.
        attention_mask = create_extended_attention_mask(attention_mask, torch.float32)

        last_layer_index = len(self.transformer_layers) - 1
        layer_input = embeddings_output
        total_loss = 0.0

        for i, transformer_layer in enumerate(self.transformer_layers):
            layer_output = transformer_layer(hidden_states=layer_input, attention_mask=attention_mask)
            # Element 0 of the layer output is the hidden state fed to the ramp and the next layer.
            layer_logits = self.ramp_classifiers[i](layer_output[0])

            if training:
                if training_phase == 2 and i < last_layer_index:
                    # Phase 2: accumulate the loss of every intermediate ramp.
                    total_loss += compute_qa_loss(layer_logits, start_positions, end_positions)
                elif training_phase == 1 and i == last_layer_index:
                    # Phase 1: only the final ramp contributes.
                    total_loss = compute_qa_loss(layer_logits, start_positions, end_positions)
            else:
                if activate_ee(layer_logits, self.ee_entropy):
                    print(f'Exit layer: {i+1}')
                    break

            layer_input = layer_output[0]

        # Logits of the last executed ramp, split into start / end scores.
        start_logits, end_logits = layer_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        return QuestionAnsweringModelOutput(loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=None,
            attentions=None,
        )

In [None]:
# Demo: wrap a pretrained BERT checkpoint in the early-exit model and query it
# through the question-answering pipeline (`pipeline` comes from the imports cell).
model_id = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = EEQATrainableModel(AutoConfig.from_pretrained(model_id))

# questions[k] is asked against contexts[k].
questions = [
    "Which name is also used to describe the Amazon rainforest in English?",
    "Where do I live?",
]
contexts = [
    "The Amazon rainforest (Portuguese: Floresta Amazônica or Amazônia; Spanish: Selva Amazónica, Amazonía or usually Amazonia; French: Forêt amazonienne; Dutch: Amazoneregenwoud), also known in English as Amazonia or the Amazon Jungle, is a moist broadleaf forest that covers most of the Amazon basin of South America. This basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres (2,100,000 sq mi) are covered by the rainforest. This region includes territory belonging to nine nations. The majority of the forest is contained within Brazil, with 60% of the rainforest, followed by Peru with 13%, Colombia with 10%, and with minor amounts in Venezuela, Ecuador, Bolivia, Guyana, Suriname and French Guiana. States or departments in four nations contain \"Amazonas\" in their names. The Amazon represents over half of the planet's remaining rainforests, and comprises the largest and most biodiverse tract of tropical rainforest in the world, with an estimated 390 billion individual trees divided into 16,000 species.",
    "My name is Clara and I live in Berkeley.",
]

# Use the first GPU when one is available, otherwise the CPU (-1).
answerer = pipeline("question-answering", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)

for question, context in zip(questions, contexts):
    response = answerer(question=question, context=context)
    print(response)