# core

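> Tokenize source code with a Hugging Face tokenizer and align every token with the tree-sitter AST node (and parent node) it overlaps.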

In [None]:
#| default_exp core

In [1]:
#| export
import code_tokenizers
import json

from transformers import AutoTokenizer
from tree_sitter import Language, Parser

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
def unroll_node_types(
    nested_node_types: list, # node_types from tree-sitter (parsed node-types.json)
) -> list: # list of node types
    """Unroll nested node types into a flat list of node types. This includes subtypes as well.

    `nested_node_types` is a list of entries, each with a "type" key and optionally a
    "subtypes" list whose entries carry their own "type" key. Duplicates are removed and
    the order is deterministic: top-level types first, then subtypes, in first-seen order.
    A deterministic order keeps the node-type ids derived from this list stable across
    processes.
    """
    node_types = [entry["type"] for entry in nested_node_types]
    # Subtypes live on the original dict entries, not on the extracted type strings.
    node_subtypes = [
        subtype["type"]
        for entry in nested_node_types
        if "subtypes" in entry
        for subtype in entry["subtypes"]
    ]
    # dict.fromkeys dedupes while keeping first-seen order; set() order for strings
    # depends on the per-process hash seed.
    return list(dict.fromkeys(node_types + node_subtypes))

In [4]:
#| export
# From: https://github.com/github/CodeSearchNet/tree/master/function_parser
def traverse(
    node,       # tree-sitter node
    results,    # list to append results to
) -> None:
    """Collect the leaves of a tree-sitter tree into `results`, left to right.

    A node counts as a leaf when it has no children or when its type is 'string', so a
    whole string literal is kept as a single node. An explicit stack is used instead of
    recursion so that deeply nested code cannot exceed Python's recursion limit.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        if current.type == 'string' or not current.children:
            results.append(current)
        else:
            # Push children reversed so they are popped (and emitted) in source order.
            stack.extend(reversed(current.children))

In [5]:
#| export
def get_token_type(
    tok_span: tuple,            # (start, end) character offsets of a token in the code
    nodes: list,                # list of tree-sitter nodes (checked in order)
    lines: list,                # list of lines in the code (the code split on "\n")
    internal_methods: list,     # list of internal methods
    acceptable_ast_types: list, # list of parent AST types under which a node naming an internal method counts
) -> tuple:                 # (parent_type, token_type, is_internal) of the first overlapping node, or None
    """Get the parent AST type, the token AST type and the internal-method flag of a token.

    The first node whose span contains the token's start or end offset is used. The token
    is flagged as internal when that node's text is in `internal_methods` and its parent's
    type is in `acceptable_ast_types`, or when its grandparent is a "call" node whose first
    named child's text is in `internal_methods`. Returns None if no node overlaps the token.
    """
    # Character offset at which each line starts (+1 for the "\n" removed by split).
    line_starts = []
    running = 0
    for line in lines:
        line_starts.append(running)
        running += len(line) + 1

    def convert_to_offset(point):
        # tree-sitter points are (row, column) with the column counted in bytes of the
        # UTF-8 encoded source, while tokenizer offsets count characters; convert.
        row, column = point
        byte_prefix = lines[row].encode("utf8")[:column]
        return line_starts[row] + len(byte_prefix.decode("utf8", errors="ignore"))

    tok_start, tok_end = tok_span
    for node in nodes:
        span_start = convert_to_offset(node.start_point)
        span_end = convert_to_offset(node.end_point)
        overlaps = (span_start <= tok_start < span_end) or (span_start < tok_end <= span_end)
        if not overlaps:
            continue
        parent = node.parent
        is_internal = node.text.decode() in internal_methods and parent.type in acceptable_ast_types
        if not is_internal:
            grandparent = parent.parent
            # Exact match: a substring test (`in "call"`) would also accept e.g. "al".
            if grandparent is not None and grandparent.type == "call":
                if grandparent.named_children[0].text.decode() in internal_methods:
                    is_internal = True
        return parent.type, node.type, is_internal
    return None

In [6]:
#| export
class CodeTokenizer():
    """A tokenizer for code, which aligns the tokens with the AST nodes.

    Wraps a `transformers` tokenizer and a tree-sitter parser. Calling the instance
    tokenizes the code and, for every token, records the index (into `node_types`) of the
    tree-sitter node it overlaps, the index of that node's parent type, and whether the
    token belongs to a call of one of the given internal methods.
    """
    def __init__(
        self,
        tokenizer,          # transformers tokenizer
        parser,             # tree-sitter parser
        language,           # tree-sitter language
        node_types,         # list of node types
        name_or_path,       # name or path of the tokenizer
        program_lang,       # programming language of the tokenizer
        padding_token,      # padding token added to the tokenizer (falsy if none was added)
    ):
        self.tokenizer = tokenizer
        self.parser = parser
        self.language = language
        self.node_types = node_types
        self.name_or_path = name_or_path
        self.program_lang = program_lang
        self.padding_token = padding_token

        # Parent node types under which a token naming an internal method is flagged.
        # Only Python is configured; other languages get an empty list so `parse_tree`
        # still works (that parent-type check then never passes).
        if self.program_lang == "python":
            self.acceptable_ast_types = ["call", "argument_list"]
        else:
            self.acceptable_ast_types = []

    def parse_tree(
        self,
        code,               # code to parse
        offset_mapping,     # offset mapping from the tokenizer to align the tokens with the AST nodes
        internal_methods,   # internal methods to parse the code
    ):                      # returns lists of AST ids, parent AST ids and internal-method flags
        """Align each token span in `offset_mapping` with the tree-sitter node it overlaps.

        Returns three lists with one entry per token: the index of the token's node type in
        `self.node_types`, the index of its parent's node type, and whether the token is part
        of an internal-method call. Tokens with a missing span or no overlapping node get
        -1, -1 and False, so the three lists always stay aligned.

        Raises ValueError if a matched node type is not in `self.node_types`.
        """
        tree = self.parser.parse(bytes(code, "utf8"))
        nodes = []
        traverse(tree.root_node, nodes)
        self.nodes = nodes

        lines = code.split("\n")  # loop-invariant: split once, not once per token
        # Constant-time id lookup; first occurrence wins, matching `list.index`.
        type_to_id = {}
        for idx, node_type in enumerate(self.node_types):
            type_to_id.setdefault(node_type, idx)

        ast_ids = []
        parent_ast_ids = []
        is_internal_methods = []
        for start, end in offset_mapping:
            type_info = None
            if start is not None and end is not None:
                type_info = get_token_type(
                    (start, end),
                    nodes,
                    lines,
                    internal_methods,
                    acceptable_ast_types=self.acceptable_ast_types,
                )
            if type_info is None:
                ast_ids.append(-1)
                parent_ast_ids.append(-1)
                is_internal_methods.append(False)
                continue

            parent_node_type, node_type, is_internal = type_info
            if node_type not in type_to_id or parent_node_type not in type_to_id:
                raise ValueError(
                    f"AST node type(s) {type_info!r} not found in node_types while parsing:\n{code}"
                )
            ast_ids.append(type_to_id[node_type])
            parent_ast_ids.append(type_to_id[parent_node_type])
            is_internal_methods.append(is_internal)

        return ast_ids, parent_ast_ids, is_internal_methods

    @staticmethod
    def _per_code_internal_methods(internal_methods, n_codes):
        """Return one list of internal methods per code of a batch of `n_codes` codes."""
        if not internal_methods:
            return [[] for _ in range(n_codes)]
        if all(isinstance(method, str) for method in internal_methods):
            # A flat list of method names applies to every code in the batch.
            return [list(internal_methods) for _ in range(n_codes)]
        return internal_methods

    def _merge_ast(self, parent_ast_ids, ast_ids, is_internals):
        """Render aligned parent/AST ids as "<parent -> node>" strings ("< N/A >" when unknown)."""
        merged_ast = []
        for parent_ast_id, ast_id, is_internal in zip(parent_ast_ids, ast_ids, is_internals):
            if parent_ast_id == -1 or ast_id == -1:
                merged_ast.append("< N/A >")
                continue
            suffix = " (internal)" if is_internal else ""
            merged_ast.append(
                f"<{self.node_types[parent_ast_id]} -> {self.node_types[ast_id]}{suffix}>"
            )
        return merged_ast

    def __call__(
        self,
        code,                   # code or list of code to tokenize
        internal_methods=None,  # internal methods to check against; for a batch, one list per code or a single flat list for all
        return_merged=True,     # whether to add string representations of the merged ASTs and parent ASTs
        **kwargs                # kwargs for the underlying transformers tokenizer
    ):                          # returns a dictionary of token ids, attention masks, AST ids, parent AST ids, and optionally the string representations of the merged ASTs and parent ASTs
        """Tokenize `code` and add "ast_ids", "parent_ast_ids", "is_internal_methods" (and "merged_ast")."""
        encoding = self.tokenizer(code, return_offsets_mapping=True, **kwargs)
        if isinstance(code, list):
            per_code_methods = self._per_code_internal_methods(internal_methods, len(code))
            all_ast_ids, all_parent_ast_ids, all_is_internal = [], [], []
            for i, c in enumerate(code):
                ast_ids, parent_ast_ids, is_internal_methods = self.parse_tree(
                    c,
                    encoding["offset_mapping"][i],
                    per_code_methods[i],
                )
                all_ast_ids.append(ast_ids)
                all_parent_ast_ids.append(parent_ast_ids)
                all_is_internal.append(is_internal_methods)
            encoding["ast_ids"] = all_ast_ids
            encoding["parent_ast_ids"] = all_parent_ast_ids
            encoding["is_internal_methods"] = all_is_internal
        else:
            methods = internal_methods if internal_methods is not None else []
            encoding["ast_ids"], encoding["parent_ast_ids"], encoding["is_internal_methods"] = self.parse_tree(
                code, encoding["offset_mapping"], methods
            )

        if return_merged:
            # Merge the AST ids with their parent AST ids and use the names instead of the ids
            if isinstance(code, list):
                encoding["merged_ast"] = [
                    self._merge_ast(parent_ids, ids, flags)
                    for parent_ids, ids, flags in zip(
                        encoding["parent_ast_ids"],
                        encoding["ast_ids"],
                        encoding["is_internal_methods"],
                    )
                ]
            else:
                encoding["merged_ast"] = self._merge_ast(
                    encoding["parent_ast_ids"],
                    encoding["ast_ids"],
                    encoding["is_internal_methods"],
                )

        return encoding

    def decode(self, *args, **kwargs):
        """Delegate decoding to the underlying transformers tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)

    @staticmethod
    def from_pretrained(
        name_or_path: str,          # name or path of the tokenizer
        program_lang: str,          # language of the tokenizer
        padding_token: str = None,  # padding token to use
    ):                              # CodeTokenizer for the given language
        """Create a CodeTokenizer from a pretrained tokenizer for a given language."""
        tokenizer = AutoTokenizer.from_pretrained(name_or_path)
        if padding_token:
            tokenizer.add_special_tokens({"pad_token": padding_token})

        # Grab the node types from the tree-sitter language
        grammars_dir = f"{code_tokenizers.__path__[0]}/grammars"
        language = Language(f"{grammars_dir}/tree-sitter-languages.so", program_lang)
        node_path = f"{grammars_dir}/{program_lang}/src/node-types.json"
        with open(node_path, encoding="utf-8") as f:
            node_types = unroll_node_types(json.load(f))
        if program_lang == "python":
            # Extra types added for Python; only append if not already listed so ids stay unique.
            for extra_type in ("as_pattern_target", "ERROR"):
                if extra_type not in node_types:
                    node_types.append(extra_type)

        # Create a parser for the language
        parser = Parser()
        parser.set_language(language)

        return CodeTokenizer(tokenizer, parser, language, node_types, name_or_path, program_lang, padding_token)

    def __reduce__(self):
        # Pickle by constructor arguments; unpickling rebuilds via `from_pretrained`.
        return (CodeTokenizer.from_pretrained, (self.name_or_path, self.program_lang, self.padding_token))

    def _key(self):
        """The constructor arguments that identify this tokenizer (used by ==, hash)."""
        return (self.name_or_path, self.program_lang, self.padding_token)

    def __eq__(self, other):
        if not isinstance(other, CodeTokenizer):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable; keep hash consistent with ==.
        return hash(self._key())

In [7]:
# test the tokenizer: every per-token output must line up with input_ids
py_tokenizer = CodeTokenizer.from_pretrained("gpt2", "python")
code = "def foo():\n    print('hello world')"

encoding = py_tokenizer(code)

for key in ("ast_ids", "parent_ast_ids", "merged_ast"):
    assert key in encoding
n_tokens = len(encoding["input_ids"])
for key in ("ast_ids", "parent_ast_ids", "merged_ast", "is_internal_methods"):
    assert len(encoding[key]) == n_tokens

In [8]:
# test with list of code: every output has one row per code and one entry per token
code = ["def foo():\n    print('hello world')", "def bar():\n    print('hello world')"]
encoding = py_tokenizer(code)

for key in ("ast_ids", "parent_ast_ids", "merged_ast", "is_internal_methods"):
    assert key in encoding, key
    assert len(encoding[key]) == len(encoding["input_ids"]), key
    # check every row of the batch, not just the first one
    for row, input_ids in zip(encoding[key], encoding["input_ids"]):
        assert len(row) == len(input_ids), key

In [9]:
# test with internal methods: only tokens inside a call of `print` are flagged
code = "def print():\n    print('print') #print\n    print = 1"
encoding = py_tokenizer(code, internal_methods=["print"])

for idx in range(len(encoding["input_ids"])):
    merged = encoding["merged_ast"][idx]
    inside_call = "call" in merged or "argument_list" in merged
    assert encoding["is_internal_methods"][idx] == inside_call, merged

In [11]:
# test with internal methods and batched: same rule applied to every row
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer([code] * 2, internal_methods=[["print"], ["print"]])

for row_idx in range(len(encoding["input_ids"])):
    merged_row = encoding["merged_ast"][row_idx]
    flags_row = encoding["is_internal_methods"][row_idx]
    for tok_idx in range(len(encoding["input_ids"][row_idx])):
        merged = merged_row[tok_idx]
        inside_call = "call" in merged or "argument_list" in merged
        assert flags_row[tok_idx] == inside_call, merged

In [12]:
# test without internal methods: no token may be flagged
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer(code)

n_tokens = len(encoding["input_ids"])
assert all(encoding["is_internal_methods"][idx] == False for idx in range(n_tokens))

In [13]:
# test without internal methods and batched: no token in any row may be flagged
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer([code] * 2)

for row_idx, row_ids in enumerate(encoding["input_ids"]):
    row_flags = encoding["is_internal_methods"][row_idx]
    assert all(row_flags[idx] == False for idx in range(len(row_ids)))

In [14]:
# test the pickleability of the tokenizer
import pickle

restored_tokenizer = pickle.loads(pickle.dumps(py_tokenizer))
assert py_tokenizer == restored_tokenizer

# `==` only compares the constructor arguments; also check the restored copy encodes identically.
sample_code = "def foo():\n    print('hello world')"
original_encoding = py_tokenizer(sample_code)
restored_encoding = restored_tokenizer(sample_code)
for key in ("input_ids", "ast_ids", "parent_ast_ids", "is_internal_methods", "merged_ast"):
    assert original_encoding[key] == restored_encoding[key], key

In [15]:
#|eval: false
# test the time of multi-proc tokenization is faster than single proc tokenization
import time
from datasets import load_dataset

ds = load_dataset("codeparrot/codeparrot-clean-valid", split="train").select(range(10))

def _timed_tokenize(dataset, num_proc):
    "Tokenize `dataset` with `py_tokenizer` using `num_proc` processes; return (result, seconds)."
    start = time.perf_counter()  # monotonic clock intended for measuring elapsed time
    result = dataset.map(
        lambda x: py_tokenizer(x["content"]),
        batched=False,
        batch_size=1,
        num_proc=num_proc,
        load_from_cache_file=False
    )
    return result, time.perf_counter() - start

single_proc_ds, total_single_proc = _timed_tokenize(ds, num_proc=1)
multi_proc_ds, total_multi_proc = _timed_tokenize(ds, num_proc=4)

# NOTE(review): with only 10 examples, worker start-up overhead may dominate, so this
# timing assertion can be flaky.
assert total_multi_proc < total_single_proc, (total_multi_proc, total_single_proc)

Downloading readme:   0%|          | 0.00/401 [00:00<?, ?B/s]

Using custom data configuration codeparrot--codeparrot-clean-valid-826c6fd8b27e5523


Downloading and preparing dataset json/codeparrot--codeparrot-clean-valid to /home/nathan/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-826c6fd8b27e5523/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/142M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /home/nathan/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-826c6fd8b27e5523/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab. Subsequent calls will reuse this data.


  0%|          | 0/10 [00:00<?, ?ex/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1185 > 1024). Running this sequence through the model will result in indexing errors


        

#0:   0%|          | 0/3 [00:00<?, ?ex/s]

#1:   0%|          | 0/3 [00:00<?, ?ex/s]

#2:   0%|          | 0/2 [00:00<?, ?ex/s]

#3:   0%|          | 0/2 [00:00<?, ?ex/s]

In [16]:
#|eval: false
# test that the two datasets tokenized with single and multi processing are identical

compared_columns = [
    "input_ids",
    "attention_mask",
    "offset_mapping",
    "ast_ids",
    "parent_ast_ids",
    "merged_ast",
    "is_internal_methods",
]
for i in range(len(ds)):
    single_row, multi_row = single_proc_ds[i], multi_proc_ds[i]
    for column in compared_columns:
        assert single_row[column] == multi_row[column], (i, column)

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()