In [1]:
import os
import clang
from clang.cindex import *
from copy import deepcopy
import re
import random

In [2]:
Config.set_library_file("/home/dipu/anaconda3/lib/python3.9/site-packages/clang/native/libclang.so")

In [3]:
index = Index.create()
tu = index.parse("main.c")
root_cursor = tu.cursor

In [4]:
def get_spelling(node):
    children = list(node.get_children())
    tokens_list = list(node.get_tokens())
    
    if node.kind == CursorKind.UNARY_OPERATOR:
        tokens = tokens_list[0].spelling
        return str(tokens)
    elif node.kind == CursorKind.BINARY_OPERATOR:
        left_list = list(children[0].get_tokens())
        right_list = list(children[1].get_tokens())
        left = "".join([token.spelling for token in left_list])
        right = "".join([token.spelling for token in right_list])
        tokens = "".join([token.spelling for token in tokens_list])
        
        if len(left_list) + len(right_list) == len(tokens_list):
            tokens = tokens.replace(left, "", 1)
            tokens = tokens.replace(right, "", 1)
            return tokens
        elif len(tokens_list) == 3:
            return tokens_list[1].spelling
        elif len(left_list) < len(tokens_list):
            return tokens_list[len(left_list)].spelling
        elif len(right_list) < len(tokens_list):
            return tokens_list[len(tokens_list)-len(right_list)-1].spelling
        else:
            return node.spelling

    else:
        return node.spelling

In [5]:
def print_ast(node, indent):
    try:
        current_name = "".join([node.spelling for x in list(node.get_tokens())]) if len(list(node.get_tokens())) > 0 else node.spelling
        # print(" "*indent + node.spelling + " " + str(node.kind) + " " + str(node.type.spelling))
        print("  "*indent + get_spelling(node) + " " + str(node.kind))
        for c in node.get_children():
            print_ast(c, indent+2)
    except ValueError:
        pass

print_ast(root_cursor, 0)

main.c CursorKind.TRANSLATION_UNIT
    __u_char CursorKind.TYPEDEF_DECL
    __u_short CursorKind.TYPEDEF_DECL
    __u_int CursorKind.TYPEDEF_DECL
    __u_long CursorKind.TYPEDEF_DECL
    __int8_t CursorKind.TYPEDEF_DECL
    __uint8_t CursorKind.TYPEDEF_DECL
    __int16_t CursorKind.TYPEDEF_DECL
    __uint16_t CursorKind.TYPEDEF_DECL
    __int32_t CursorKind.TYPEDEF_DECL
    __uint32_t CursorKind.TYPEDEF_DECL
    __int64_t CursorKind.TYPEDEF_DECL
    __uint64_t CursorKind.TYPEDEF_DECL
    __int_least8_t CursorKind.TYPEDEF_DECL
        __int8_t CursorKind.TYPE_REF
    __uint_least8_t CursorKind.TYPEDEF_DECL
        __uint8_t CursorKind.TYPE_REF
    __int_least16_t CursorKind.TYPEDEF_DECL
        __int16_t CursorKind.TYPE_REF
    __uint_least16_t CursorKind.TYPEDEF_DECL
        __uint16_t CursorKind.TYPE_REF
    __int_least32_t CursorKind.TYPEDEF_DECL
        __int32_t CursorKind.TYPE_REF
    __uint_least32_t CursorKind.TYPEDEF_DECL
        __uint32_t CursorKind.TYPE_REF
    __int_least64

In [6]:
# regex for matching string containing numbers, capital and small letters, +, -, *, /, %, _, parentheses and comma only
expression_pattern = re.compile(r"^[a-zA-Z0-9_().,+\-*/%<>=!&|~\^]+$")

# regex for valid multiplication operator or & operator (not pointer)
left_pattern = re.compile(r"^[a-zA-Z0-9_)]+$")
right_pattern = re.compile(r"^[a-zA-Z0-9_(!~]+$")

# operator list (except * &)
operator_list = ["+", "-", "/", "%", "++", "--", "<", "<=", ">", ">=", "==", "!=", "&&", "||", "!", "|", "<<", ">>", "~", "^"]

In [7]:
def is_required_expression(node):
    # avoiding equals to ('=') operator
    if list(node.get_children())[0].kind == CursorKind.DECL_REF_EXPR:
        return False
    
    tokens = list(node.get_tokens())
    operator_set = set()
    
    for i in range(len(tokens)):
        spell = str(tokens[i].spelling)
        if (not expression_pattern.match(spell)) or (spell == "="):
            return False
        
        if (spell in operator_list) or \
        ((spell == "*" or spell == "&") and 0 < i < len(tokens)-1 and \
        left_pattern.match(str(tokens[i-1].spelling)) and \
        right_pattern.match(str(tokens[i+1].spelling))):
            operator_set.add(spell)
            
    
    # expression with atleast two different operators is needed
    if len(operator_set) <= 1:
        return False
    elif len(operator_set) == 2:
        # excluding operations having only '+' and '-' as both have same precedence
        if ('+' in operator_set) and ('-' in operator_set):
            return False
        # excluding operations having only '*' and '/' as both have same precedence
        elif ('*' in operator_set) and ('/' in operator_set):
            return False
                
    return True

In [8]:
precedence = {
    '++': 1, '--': 1, '!': 1, '~': 1,
    '*': 2, '/': 2, '%': 2,
    '+': 3, '-': 3,
    '<<': 4, '>>': 4,
    '<': 5, '>': 5, '<=': 5, '>=': 5,
    '==': 6, '!=': 6,
    '&': 7,
    '^': 8,
    '|': 9,
    '&&': 10,
    '||': 11
    }

# checks if removal of parenthesis results in precedence bug (i.e. different result) in the operation
def is_precedence_higher(op1_node, op2_node, parenthesis_at_right):
    op1 = get_spelling(op1_node)
    op2 = get_spelling(op2_node)
    
    # get the precedence of op1 and op2, default to 999 if not found
    precedence_op1 = precedence.get(op1, 999)
    precedence_op2 = precedence.get(op2, 999)
    
    if op1_node.kind == CursorKind.UNARY_OPERATOR and op1 == "-":
        precedence_op1 = 1
        
    if op2_node.kind == CursorKind.UNARY_OPERATOR and op2 == "-":
        precedence_op2 = 1
        
    # There is exception in case of '*', '/' and '%'
    # if '%' is done before or after '*' or '/' then result will vary unlike doing '*' before or after '/'
    # so removal of parantheses can create precedence bug due to left to right associativity
    if parenthesis_at_right and \
    ((op1 == '%' and (op2 == '*' or op2 == '/')) or ((op1 == '*' or op1 == '/') and op2 == '%')):
        return True
        
    return precedence_op1 < precedence_op2

In [9]:
def get_negative_sample(node, parent_node, grandparent_node, replacable):
    
    if node and node.kind == CursorKind.PAREN_EXPR:
        children = list(node.get_children())
        
        if len(children) > 0 and children[0].kind == CursorKind.BINARY_OPERATOR:
            if parent_node:
                if parent_node.kind == CursorKind.BINARY_OPERATOR or parent_node.kind == CursorKind.UNARY_OPERATOR:
                    grand_children = list(children[0].get_children())

                    if len(grand_children) > 0: # and grand_children[0].kind != CursorKind.BINARY_OPERATOR and grand_children[0].kind != CursorKind.PAREN_EXPR:
                        
                        parenthesis_at_left = node == list(parent_node.get_children())[0]
                        if is_precedence_higher(parent_node, children[0], not parenthesis_at_left):
                            # remove parenthesis of node (i.e. return tokens of children[0])
                            replacable.append([token.spelling for token in list(node.get_tokens())])
                elif parent_node.kind == CursorKind.UNEXPOSED_EXPR and grandparent_node and \
                (grandparent_node.kind == CursorKind.BINARY_OPERATOR or grandparent_node.kind == CursorKind.UNARY_OPERATOR):
                    grand_children = list(children[0].get_children())

                    if len(grand_children) > 0: # and grand_children[0].kind != CursorKind.BINARY_OPERATOR and grand_children[0].kind != CursorKind.PAREN_EXPR:
                        
                        parenthesis_at_left = parent_node == list(grandparent_node.get_children())[0]
                        if is_precedence_higher(grandparent_node, children[0], not parenthesis_at_left):
                            # remove parenthesis of node (i.e. return tokens of children[0])
                            replacable.append([token.spelling for token in list(node.get_tokens())])

    for child in node.get_children():
        get_negative_sample(child, node, parent_node, replacable)

In [10]:
def remove_parenthesis(full_list, remove_list):
    for i in range(len(full_list)):
        match_count = 0
        for j in range(len(remove_list)):
            if full_list[i+j] == remove_list[j]:
                match_count += 1
            else:
                break
                
        if len(remove_list) == match_count:
            return full_list[0:i] + full_list[i+1:i+len(remove_list)-1] + full_list[i+len(remove_list):]
            
    return full_list

In [11]:
def get_binary_expressions(node, exp_list):
    try:
        if (node.kind == CursorKind.BINARY_OPERATOR or node.kind == CursorKind.UNARY_OPERATOR)\
        and is_required_expression(node):
            tokens = list(node.get_tokens())
            expression = [token.spelling for token in tokens]
            exp_list.append(expression)
            
            replacable = []
            get_negative_sample(node, None, None, replacable)
            
            if len(replacable) > 0:
                replacable = sorted(replacable, key=lambda x: len(x), reverse=True)
                number_of_replacement = random.randrange(len(replacable)) + 1
                
                random_indices = random.sample(range(0, len(replacable)), number_of_replacement)
                random_indices.sort()   
                
                negative_sample = expression.copy()
                for idx in random_indices:
                    negative_sample = remove_parenthesis(negative_sample, replacable[idx])
                exp_list.append(negative_sample)
            else:
                exp_list.append([])
                
        else:
            for child in node.get_children():
                get_binary_expressions(child, exp_list)
                
    except Exception as e:
        # print("***Exception***", e)
        pass

In [12]:
exp_list = []
get_binary_expressions(root_cursor, exp_list)
print(exp_list)

[['(', 'a', '/', '(', 'b', '*', 'a', ')', ')', '%', 'b'], []]


---
## Using C Code Corpus Dataset
---

In [13]:
def generate_binary_operator_expression_dataset(root_dir):
    total_files, total_samples = 0, 0
    global current_file
    
    with open("binary_operator_expression_dataset.csv", 'a') as dataset:
        
        for root, dirs, files in os.walk(root_dir):
                for file in files:
                    if file.endswith(".c"):
                        total_files += 1

                        file_path = os.path.join(root, file)
                        current_file = file_path
                            
                        with open(file_path, 'rb') as f:
                                
                            content = str(f.read())
                            # print(current_file)

                            # ignoring .c files having more than 10,000 lines of code
                            if content.count("\\n") <= 10_000:

                                try:
                                    start_cursor = index.parse(file_path).cursor

                                    expression_list = []
                                    get_binary_expressions(start_cursor, expression_list)

                                    for expression in expression_list:
                                        positive_sample = "\t".join(expression)
                                        dataset.write(positive_sample + "\n")

                                        total_samples += 1

                                except:
                                        print("---Error occurred---")

                        current_file = file_path

                        if total_files % 1000 == 0:
                            print("Total files:", total_files, ",", "Total samples:", total_samples)

In [None]:
root_dir = '/home/dipu/Documents/AI/MinorProject/c-corpus/'

generate_binary_operator_expression_dataset(root_dir)

Total files: 1000 , Total samples: 1496 119.0
Total files: 2000 , Total samples: 3722 262.0
Total files: 3000 , Total samples: 8982 1060.0
Total files: 4000 , Total samples: 12544 1400.0
Total files: 5000 , Total samples: 15020 1706.0
Total files: 6000 , Total samples: 17604 2040.0
Total files: 7000 , Total samples: 23074 2275.0
Total files: 8000 , Total samples: 27924 2758.0
Total files: 9000 , Total samples: 32848 3229.0
Total files: 10000 , Total samples: 38638 3763.0
Total files: 11000 , Total samples: 47718 5037.0
