In [1]:
import torch
import torch.nn as nn

from typing import List, Optional, Tuple, Union

from transformers import AutoTokenizer, AutoModelForSequenceClassification

DEVICE = "cuda"

# textattack/facebook-bart-large-RTE
# textattack/roberta-base-RTE

In [18]:
class RobertaNeuralByPass(nn.Module):

    def __init__(self, model_name):
        super(RobertaNeuralByPass, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.encoder = self.model.roberta.encoder
        self.layer = self.encoder.layer
        self.classifier = self.model.classifier
        self.model.to(DEVICE)
    
    def tokenize(self, premise, hypothesis):
        encoded = self.tokenizer(
            premise, hypothesis,
            return_tensors="pt"
        )
        return encoded
    
    def full_forward(self, encoded):
        output = self.model(**encoded, output_hidden_states=True, output_attentions=True)
        hidden_states = output.hidden_states
        attentions = output.attentions
        logits = output.logits
        return logits, hidden_states, attentions
    
    def forward(
        self, 
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor], 
        layer_id
    ):
        all_hidden_states = ()
        for i, layer_module in enumerate(self.layer[layer_id-1:]):
            layer_outputs = layer_module(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask
                )
            hidden_states = layer_outputs[0]
            all_hidden_states += (hidden_states,)
        
        logits = self.classifier(hidden_states)
        return logits, all_hidden_states


In [19]:
roberta_bypass = RobertaNeuralByPass("roberta-large-mnli")

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
encoded = roberta_bypass.tokenize("All dogs are running", "Some animals are moving")
encoded.to(DEVICE)

{'input_ids': tensor([[   0, 3684, 3678,   32,  878,    2,    2, 6323, 3122,   32, 1375,    2]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [28]:
logits, hidden_states, attentions = roberta_bypass.full_forward(encoded)
logits

tensor([[-3.4080,  0.8012,  2.9936]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [29]:
logits, all_hidden_states = roberta_bypass.forward(hidden_states[20], encoded['attention_mask'], 21)
logits

tensor([[-3.4079,  0.8012,  2.9936]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [33]:
hidden_states[21] == all_hidden_states[0]

tensor([[[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [False, False, False,  ..., False, False, False]]], device='cuda:0')

In [31]:
hidden_states[1]

tensor([[[ 0.1519, -0.0937,  0.1447,  ..., -0.0300, -0.2084,  0.1826],
         [-1.3912, -1.7311, -0.0307,  ..., -0.3199,  1.0177, -1.2630],
         [-0.7188, -0.1330,  0.3086,  ..., -0.9791,  1.1518, -0.5150],
         ...,
         [-0.0531, -0.4782, -0.0294,  ..., -1.3526,  0.0126, -0.3265],
         [-0.3398, -1.6078,  0.7697,  ...,  0.3044,  1.5741,  0.2432],
         [-0.1567, -0.2220, -0.4250,  ..., -0.5397,  0.0649, -0.4106]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)

In [32]:
all_hidden_states[0]

tensor([[[-0.0528, -0.0299,  0.0852,  ..., -0.1457,  0.0252, -0.3061],
         [-0.5096, -0.4977, -1.7494,  ...,  0.5328,  0.1078, -0.7403],
         [-0.1172,  0.1241, -1.5735,  ...,  0.6089,  0.3702, -0.4649],
         ...,
         [-0.9167, -1.3176, -0.0505,  ..., -0.6160, -0.1814,  0.4936],
         [-0.8177, -0.7924,  0.2921,  ..., -0.7807,  0.8872,  0.9065],
         [-0.4522, -0.4103, -0.2238,  ..., -0.1327,  0.1258,  1.1909]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)