# core

> Tokenizers for source code that align every token with the type of its tree-sitter AST node

In [None]:
# | default_exp core

In [None]:
# | export
import builtins
import json

from transformers import AutoTokenizer
from tree_sitter import Language, Parser

import code_tokenizers

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [None]:
# | hide
from nbdev.showdoc import *

In [None]:
# | export
class ASTNode:
    """A tree-sitter node bundled with its type names, type ids and method flags."""

    def __init__(self, node, is_internal, is_builtin, node_types):
        # Keep the raw node together with the flags computed by the caller.
        self.node, self.is_internal, self.is_builtin = node, is_internal, is_builtin

        # Type names of the node and of its parent ...
        self.type, self.parent_type = node.type, node.parent.type
        # ... and their positions inside `node_types`.
        self.type_id = node_types.index(self.type)
        self.parent_type_id = node_types.index(self.parent_type)

    def __str__(self):
        # A type equal to -1 is rendered as "not available".
        if -1 in (self.type, self.parent_type):
            return "< N/A >"

        # The internal flag takes precedence over the builtin flag.
        if self.is_internal:
            suffix = " (internal)"
        elif self.is_builtin:
            suffix = " (builtin)"
        else:
            suffix = ""
        return f"<{self.parent_type} -> {self.type}{suffix}>"

In [None]:
# | export
def unroll_node_types(
    nested_node_types: list,  # parsed tree-sitter `node-types.json`: a list of dicts that each have a "type" key
) -> list:  # sorted list of unique node type names
    """Unroll nested node types into a flat list of node types. This includes subtypes as well."""
    unique_types = set()
    for node_type in nested_node_types:
        unique_types.add(node_type["type"])
        # Subtypes live on the entry dict itself, so they have to be read from
        # the original entries and not from the already-flattened type names.
        for node_subtype in node_type.get("subtypes", ()):
            unique_types.add(node_subtype["type"])
    # Sort so that the position of a type (used as its id by callers) does not
    # depend on set iteration order, which varies with string hashing.
    return sorted(unique_types)

In [None]:
# | export
# From: https://github.com/github/CodeSearchNet/tree/master/function_parser
def traverse(
    node,  # tree-sitter node
    results,  # list to append results to
) -> None:
    """Walk a tree-sitter node depth-first and append its leaves to a list.

    Nodes of type "string" are appended whole and their children are not visited.
    Results are appended in left-to-right order.
    """
    # An explicit stack is used instead of recursion so that deeply nested
    # trees do not hit Python's recursion limit.
    pending = [node]
    while pending:
        current = pending.pop()
        if current.type == "string":
            results.append(current)
            continue
        children = current.children
        if not children:
            results.append(current)
            continue
        # Push in reverse so the left-most child is popped (and visited) first.
        pending.extend(reversed(children))

In [None]:
# | export
def get_token_type(
    tok_span: tuple,  # (start, end) character offsets of a token
    nodes: list,  # list of tree-sitter nodes
    lines: list,  # list of lines in the code
    internal_methods: list,  # list of internal methods
    acceptable_ast_types: list,  # list of AST types to accept for internal methods
    node_types: list,  # list of node types
) -> "ASTNode":  # ASTNode for the first node overlapping the token, or None when no node overlaps
    """Find the first node whose span overlaps a token and wrap it in an `ASTNode`."""

    # Character offset at which every line starts. The `+ 1` accounts for the
    # newline separator that is not part of the entries of `lines`.
    line_starts = []
    running_offset = 0
    for line in lines:
        line_starts.append(running_offset)
        running_offset += len(line) + 1

    def convert_to_offset(point):
        row, byte_column = point
        # tree-sitter reports columns as byte counts within the row, while
        # token spans are character offsets: count the characters that the
        # byte prefix decodes to. For ASCII-only lines both are the same.
        prefix = lines[row].encode("utf8")[:byte_column]
        return line_starts[row] + len(prefix.decode("utf8", errors="ignore"))

    tok_start, tok_end = tok_span[0], tok_span[1]
    for node in nodes:
        # Spans are computed lazily; nodes after the first match are never needed.
        start = convert_to_offset(node.start_point)
        end = convert_to_offset(node.end_point)
        # The token overlaps the node when either of its ends falls inside it.
        if not ((start <= tok_start < end) or (start < tok_end <= end)):
            continue

        parent = node.parent
        if parent is None:
            # A node without a parent has no parent type to report; skip it
            # instead of failing on `None.type`.
            continue

        text = node.text.decode()
        is_internal = text in internal_methods and parent.type in acceptable_ast_types
        # Use the `builtins` module explicitly: `__builtins__` is a plain dict
        # inside imported modules, so `dir(__builtins__)` lists dict methods there.
        is_builtin = text in vars(builtins) and parent.type == "call"

        if not is_internal:
            # Also flag the token when the enclosing call targets an internal
            # method, i.e. the call's first named child is an internal name.
            grandparent = parent.parent
            if grandparent is not None and grandparent.type == "call":
                callee = grandparent.named_children[0].text.decode()
                if callee in internal_methods:
                    is_internal = True

        return ASTNode(node, is_internal, is_builtin, node_types)

    return None

In [None]:
# | export
class CodeTokenizer:
    """A tokenizer for code, which aligns the tokens with the AST nodes."""

    def __init__(
        self,
        tokenizer,  # transformers tokenizer
        parser,  # tree-sitter parser
        language,  # tree-sitter language
        node_types,  # list of node types
        name_or_path,  # name or path of the tokenizer
        program_lang,  # programming language of the tokenizer
        padding_token,  # padding token that was registered on the tokenizer, or None
    ):
        self.tokenizer = tokenizer
        self.parser = parser
        self.language = language
        self.node_types = node_types
        self.name_or_path = name_or_path
        self.program_lang = program_lang
        self.padding_token = padding_token

        # Parent AST types under which a name counts as an internal method.
        # Only Python has such a list; other languages get an empty one so
        # that `parse_tree` never fails on a missing attribute.
        self.acceptable_ast_types = (
            ["call", "argument_list"] if self.program_lang == "python" else []
        )

    def parse_tree(
        self,
        code,  # code to parse
        offset_mapping,  # offset mapping from the tokenizer to align the tokens with the AST nodes
        internal_methods,  # internal methods to parse the code
    ):  # returns a list with one ASTNode (or None when nothing matched) per token
        tree = self.parser.parse(bytes(code, "utf8"))
        nodes = []
        traverse(tree.root_node, nodes)
        # The collected nodes of the most recent parse stay available on the instance.
        self.nodes = nodes

        # Split once here rather than once per token.
        lines = code.split("\n")
        return [
            get_token_type(
                (start, end),
                nodes,
                lines,
                internal_methods,
                acceptable_ast_types=self.acceptable_ast_types,
                node_types=self.node_types,
            )
            for start, end in offset_mapping
        ]

    @staticmethod
    def _ast_fields(
        ast_nodes,  # list of ASTNode (or None) for one piece of code
    ):  # returns a dict mapping each per-token encoding key to its list of values
        """Split the AST nodes of one sample into the per-token columns stored in the encoding."""
        fields = {
            "ast_ids": [],
            "parent_ast_ids": [],
            "is_internal_methods": [],
            "is_builtins": [],
        }
        for ast_node in ast_nodes:
            if ast_node is None:
                # Tokens without a matching node get -1 ids and False flags.
                row = (-1, -1, False, False)
            else:
                row = (
                    ast_node.type_id,
                    ast_node.parent_type_id,
                    ast_node.is_internal,
                    ast_node.is_builtin,
                )
            for column, value in zip(fields.values(), row):
                column.append(value)
        return fields

    @staticmethod
    def _merged_ast(
        ast_nodes,  # list of ASTNode (or None) for one piece of code
    ):  # returns the string representation of every node
        """Render every AST node as a string, using "< N/A >" for tokens without a node."""
        return [
            str(ast_node) if ast_node is not None else "< N/A >"
            for ast_node in ast_nodes
        ]

    def __call__(
        self,
        code,  # code or list of code to tokenize
        internal_methods=None,  # list of internal methods to check against (one list per sample when `code` is a list)
        return_merged=True,  # whether to return the string representations of the merged ASTs and parent ASTs
        **kwargs,  # kwargs for the underlying transformers tokenizer
    ):  # returns the tokenizer encoding extended with AST ids, parent AST ids, internal/builtin flags, and optionally the merged AST strings
        encoding = self.tokenizer(code, return_offsets_mapping=True, **kwargs)

        is_batched = isinstance(code, list)
        if is_batched:
            # No internal methods given: use an empty list for every sample.
            if not internal_methods:
                internal_methods = [[] for _ in code]
            samples = [
                self.parse_tree(c, encoding["offset_mapping"][i], internal_methods[i])
                for i, c in enumerate(code)
            ]
        else:
            if internal_methods is None:
                internal_methods = []
            samples = [
                self.parse_tree(code, encoding["offset_mapping"], internal_methods)
            ]

        # Batched input stores one list per sample; single input stores the
        # flat per-token list directly.
        per_sample_fields = [self._ast_fields(ast_nodes) for ast_nodes in samples]
        for key in ("ast_ids", "parent_ast_ids", "is_internal_methods", "is_builtins"):
            column = [fields[key] for fields in per_sample_fields]
            encoding[key] = column if is_batched else column[0]

        if return_merged:
            # Merge the AST ids with their parent AST ids and use the names instead of the ids
            merged = [self._merged_ast(ast_nodes) for ast_nodes in samples]
            encoding["merged_ast"] = merged if is_batched else merged[0]

        return encoding

    def decode(self, *args, **kwargs):
        """Forward to the `decode` method of the wrapped tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)

    @staticmethod
    def from_pretrained(
        name_or_path: str,  # name or path of the tokenizer
        program_lang: str,  # language of the tokenizer
        padding_token: str = None,  # padding token to use
    ):  # CodeTokenizer for the given language
        """Create a CodeTokenizer from a pretrained tokenizer for a given language."""
        tokenizer = AutoTokenizer.from_pretrained(name_or_path)
        if padding_token:
            tokenizer.add_special_tokens({"pad_token": padding_token})

        # Grab the node types from the tree-sitter language
        language = Language(
            f"{code_tokenizers.__path__[0]}/grammars/tree-sitter-languages.so",
            program_lang,
        )
        node_path = (
            f"{code_tokenizers.__path__[0]}/grammars/{program_lang}/src/node-types.json"
        )
        # JSON is read as UTF-8 explicitly so the result does not depend on the locale.
        with open(node_path, encoding="utf-8") as f:
            node_types = json.load(f)
        node_types = unroll_node_types(node_types)
        if program_lang == "python":
            # Extra type names registered for Python on top of the grammar file.
            node_types.append("as_pattern_target")
            node_types.append("ERROR")

        # Create a parser for the language
        parser = Parser()
        parser.set_language(language)

        return CodeTokenizer(
            tokenizer,
            parser,
            language,
            node_types,
            name_or_path,
            program_lang,
            padding_token,
        )

    def _identity(self):
        """Return the values that define equality, hashing and pickling of the tokenizer."""
        return (self.name_or_path, self.program_lang, self.padding_token)

    def __reduce__(self):
        # Pickle as "rebuild with `from_pretrained`" instead of serializing the
        # wrapped tokenizer, parser and language objects.
        return (CodeTokenizer.from_pretrained, self._identity())

    def __eq__(self, other):
        if not isinstance(other, CodeTokenizer):
            return NotImplemented
        return self._identity() == other._identity()

    def __hash__(self):
        # Defined together with `__eq__` so instances stay hashable.
        return hash(self._identity())

In [None]:
# test the tokenizer
py_tokenizer = CodeTokenizer.from_pretrained("gpt2", "python")
code = "def foo():\n    print('hello world')"
encoding = py_tokenizer(code)

# The AST columns must be part of the encoding ...
for key in ("ast_ids", "parent_ast_ids", "merged_ast"):
    assert key in encoding

# ... and every added column must hold exactly one entry per token.
num_tokens = len(encoding["input_ids"])
for key in (
    "ast_ids",
    "parent_ast_ids",
    "merged_ast",
    "is_internal_methods",
    "is_builtins",
):
    assert len(encoding[key]) == num_tokens

In [None]:
# test with list of code
code = ["def foo():\n    print('hello world')", "def bar():\n    print('hello world')"]
encoding = py_tokenizer(code)

added_columns = (
    "ast_ids",
    "parent_ast_ids",
    "merged_ast",
    "is_internal_methods",
    "is_builtins",
)

for key in ("ast_ids", "parent_ast_ids", "merged_ast"):
    assert key in encoding

# Outer level: one entry per sample of the batch.
for key in added_columns:
    assert len(encoding[key]) == len(encoding["input_ids"])

# Inner level: one entry per token of the first sample.
for key in added_columns:
    assert len(encoding[key][0]) == len(encoding["input_ids"][0])

In [None]:
# test with internal methods
code = "def print():\n    print('print') #print\n    print = 1"
encoding = py_tokenizer(code, internal_methods=["print"])

# A token must be flagged as internal exactly when its merged AST string
# mentions "call" or "argument_list".
for i, _ in enumerate(encoding["input_ids"]):
    merged = encoding["merged_ast"][i]
    expected = "call" in merged or "argument_list" in merged
    assert encoding["is_internal_methods"][i] == expected, merged

In [None]:
# test with internal methods and batched
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer([code] * 2, internal_methods=[["print"], ["print"]])

# Same expectation as for a single sample, checked for every sample of the batch.
for i, sample_ids in enumerate(encoding["input_ids"]):
    for j, _ in enumerate(sample_ids):
        merged = encoding["merged_ast"][i][j]
        expected = "call" in merged or "argument_list" in merged
        assert encoding["is_internal_methods"][i][j] == expected, merged

In [None]:
# test without internal methods
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer(code)

# Without internal methods no token may be flagged as internal.
internal_flags = encoding["is_internal_methods"]
for i, _ in enumerate(encoding["input_ids"]):
    assert internal_flags[i] == False

In [None]:
# test without internal methods and batched
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer([code] * 2)

# Without internal methods no token of any sample may be flagged as internal.
for i, sample_ids in enumerate(encoding["input_ids"]):
    sample_flags = encoding["is_internal_methods"][i]
    for j, _ in enumerate(sample_ids):
        assert sample_flags[j] == False

In [None]:
# test with builtins
code = "def foo():\n    print('print') #print\n    print = 1"
encoding = py_tokenizer(code)

# A token must be flagged as builtin exactly when its merged AST string mentions "call".
for i, _ in enumerate(encoding["input_ids"]):
    merged = encoding["merged_ast"][i]
    assert encoding["is_builtins"][i] == ("call" in merged), merged

In [None]:
# test with builtins and batched
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer([code] * 2)

# Same expectation as for a single sample, checked for every sample of the batch.
for i, sample_ids in enumerate(encoding["input_ids"]):
    for j, _ in enumerate(sample_ids):
        merged = encoding["merged_ast"][i][j]
        assert encoding["is_builtins"][i][j] == ("call" in merged), merged

In [None]:
# |eval: false
# test the pickleability of the tokenizer
import pickle

# Serialize the tokenizer and load it back; the copy must compare equal to the original.
serialized = pickle.dumps(py_tokenizer)
restored = pickle.loads(serialized)
assert py_tokenizer == restored

In [None]:
# |eval: false
# test the time of multi-proc tokenization is faster than single proc tokenization
import time
from datasets import load_dataset

ds = load_dataset("codeparrot/codeparrot-clean-valid", split="train").select(range(10))

# Arguments shared by both runs; only `num_proc` differs between them.
shared_map_kwargs = {
    "batched": False,
    "batch_size": 1,
    "load_from_cache_file": False,
}

# Tokenize with a single process and measure the wall-clock time.
start = time.time()
single_proc_ds = ds.map(
    lambda x: py_tokenizer(x["content"]), num_proc=1, **shared_map_kwargs
)
total_single_proc = time.time() - start

# Tokenize the same rows with four processes and measure again.
start = time.time()
multi_proc_ds = ds.map(
    lambda x: py_tokenizer(x["content"]), num_proc=4, **shared_map_kwargs
)
total_multi_proc = time.time() - start

assert total_multi_proc < total_single_proc

Downloading readme:   0%|          | 0.00/401 [00:00<?, ?B/s]

Using custom data configuration codeparrot--codeparrot-clean-valid-826c6fd8b27e5523


Downloading and preparing dataset json/codeparrot--codeparrot-clean-valid to /work/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-826c6fd8b27e5523/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/142M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /work/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-826c6fd8b27e5523/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab. Subsequent calls will reuse this data.


  0%|          | 0/10 [00:00<?, ?ex/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1185 > 1024). Running this sequence through the model will result in indexing errors


        

#0:   0%|          | 0/3 [00:00<?, ?ex/s]

#1:   0%|          | 0/3 [00:00<?, ?ex/s]

#2:   0%|          | 0/2 [00:00<?, ?ex/s]

#3:   0%|          | 0/2 [00:00<?, ?ex/s]

In [None]:
# |eval: false
# test that the two datasets tokenized with single and multi processing are identical

# Columns that must match row by row between the two tokenized datasets.
compared_columns = (
    "input_ids",
    "attention_mask",
    "offset_mapping",
    "ast_ids",
    "parent_ast_ids",
    "merged_ast",
    "is_internal_methods",
    "is_builtins",
)

for i in range(len(ds)):
    single_row = single_proc_ds[i]
    multi_row = multi_proc_ds[i]
    for column in compared_columns:
        assert single_row[column] == multi_row[column]

In [None]:
# | hide
import nbdev

# Run nbdev's notebook export.
nbdev.nbdev_export()

ImportError: You must install black: `pip install black` if you wish to use black formatting with nbdev