In [4]:
from transformers import AutoModel, AutoTokenizer, BatchEncoding
import torch

In [5]:
model = AutoModel.from_pretrained("microsoft/codebert-base")
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
tokenizer.is_fast

True

In [10]:
code = "def max(a,b): if a>b: return a else return b"
code_tokens: BatchEncoding = tokenizer(code, return_tensors='pt')
print(tokenizer.tokenize(code))
print(code_tokens)
print(code_tokens.word_ids())

['def', 'Ġmax', '(', 'a', ',', 'b', '):', 'Ġif', 'Ġa', '>', 'b', ':', 'Ġreturn', 'Ġa', 'Ġelse', 'Ġreturn', 'Ġb']
{'input_ids': tensor([[    0,  9232, 19220,  1640,   102,     6,   428,  3256,   114,    10,
         15698,   428,    35,   671,    10,  1493,   671,   741,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, None]


In [15]:
output = model(**code_tokens, output_hidden_states=True, output_attentions=True)
output

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.1685,  0.3331,  0.0392,  ..., -0.2262, -0.3359,  0.3277],
         [-1.0436,  0.3191,  0.3959,  ..., -0.4708, -0.1289,  0.5579],
         [-0.9022,  0.5009,  0.1820,  ..., -0.4935, -0.5855,  0.6971],
         ...,
         [-0.4663,  0.2088,  0.5154,  ..., -0.1752, -0.3702,  0.5890],
         [-0.4513,  0.4893,  0.4857,  ..., -0.3150, -0.6229,  0.3867],
         [-0.1703,  0.3353,  0.0404,  ..., -0.2282, -0.3384,  0.3300]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 4.9750e-01, -3.6809e-01, -5.7454e-01,  1.0966e-01,  3.4282e-01,
          7.0422e-02,  4.8956e-01, -3.5818e-01,  1.0147e-01, -3.2847e-01,
          4.4498e-01,  2.5206e-02, -2.6024e-01,  1.1288e-01, -6.0998e-02,
          5.9084e-01,  4.4048e-01, -5.2328e-01,  9.5334e-02,  3.2977e-01,
         -3.4737e-01,  5.2274e-01,  3.5409e-01,  8.1201e-04, -4.9782e-02,
          2.3763e-01,  1.2317e-01,  9.8236e-02,  4.9072e-01,  1.491

In [37]:
type(output)

transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions

In [58]:
import inspect
inspect.getmro(type(model))

(transformers.models.roberta.modeling_roberta.RobertaModel,
 transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel,
 transformers.modeling_utils.PreTrainedModel,
 torch.nn.modules.module.Module,
 transformers.modeling_utils.ModuleUtilsMixin,
 transformers.generation_utils.GenerationMixin,
 transformers.utils.hub.PushToHubMixin,
 object)

In [56]:
model.encoder.layer[0].output.dense.out_features

768

In [23]:
len(output.hidden_states), set([output.hidden_states[i].shape for i in range(len(output.hidden_states))])

(13, {torch.Size([1, 19, 768])})

In [31]:
torch.stack(output.hidden_states).squeeze(1)[[0,1,4], 0, :].shape

torch.Size([3, 768])

In [35]:
torch.Tensor([1,2,3,4,5])[[0,1,4]]

tensor([1., 2., 5.])

In [27]:
code_tokens.to

[0;31mSignature:[0m
[0mcode_tokens[0m[0;34m.[0m[0mchar_to_token[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mbatch_or_char_index[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mchar_index[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mint[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0msequence_index[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;36m0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0mint[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Get the index of the token in the encoded output comprising a character in the original string for a sequence
of the batch.

Can be called as:

- `self.char_to_token(char_index)` if batch size is 1
- `self.char_to_token(batch_index, char_index)` if batch size is greater or equal to 1

This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
are defined by the user). In 

In [1]:
from featurizers.parser_utils import parse

In [2]:
parse_result = parse(
    """
public void reduce(UTF8 key, Iterator<UTF8> values,
                    OutputCollector<UTF8, UTF8> output, Reporter reporter) throws IOException 
{
    int a = 10;
    while (a > 0) {
        a = a - 1;
    }
}
    """,
    'java'
)

In [8]:
root = parse_result.tree.root_node
root

<Node kind=program, start_point=(1, 0), end_point=(9, 4)>

True