# Modifying AST with `srcml`
Parses very OK! Can we modify??

In [8]:
# Imports
from lxml import etree as et
from lxml.builder import ElementMaker
import subprocess
from pathlib import Path
import copy
import re

srcml_exe = 'srcml/bin/srcml'
namespaces={'src': 'http://www.srcML.org/srcML/src'}
E = ElementMaker(namespace="http://www.srcML.org/srcML/src")

In [9]:
# Print XML from root
def prettyprint(node):
    print(et.tostring(node, encoding="unicode", pretty_print=True))
# prettyprint(xmldata)

def xp(node, xpath):
    return node.xpath(xpath, namespaces=namespaces)

def start_pos(node):
    # prettyprint(node)
    return node.get('{http://www.srcML.org/srcML/position}start')

def get_space(node, front_back):
    if front_back == 'front':
        regex = rf'^<{et.QName(node).localname}[^>]+>(\s+)'
    elif front_back == 'back':
        regex = r'(\s+)$'
    else:
        raise
    m = re.search(regex, et.tostring(node, encoding='unicode'))
    if m:
        return m.group(1)
    else:
        return ''

In [39]:
# Functions for running srcml command.
def srcml(filepath):
    """Run srcml.
    If the filepath is a .c file, return xml tree as lxml ElementTree.
    If the filepath is an .xml file, return source code as a string."""
    assert filepath.exists()
    args = [srcml_exe, filepath]
    args = [str(a) for a in args]
    if filepath.suffix == '.c':
        args += ['--position']
    print('Running SrcML:', ' '.join(args))
    proc = subprocess.run(args, capture_output=True)
    if proc.returncode != 0:
        print('Error', proc.returncode)
        print(proc.stderr)
        return None
    if filepath.suffix == '.xml':
        return proc.stdout.decode('utf-8')
    elif filepath.suffix == '.c':
        with open(str(filepath) + '.xml', 'wb') as f:
            f.write(proc.stdout)
        xml = et.fromstring(proc.stdout)
        return xml

fname = Path('tests/testbed/testbed.c')
root = srcml(fname)
xmldata = xp(root, '//src:unit')[0]
# assert 4 == len(xp(xmldata, '//src:if_stmt'))
# assert 1 == len(xp(xmldata, '//src:switch'))

Running SrcML: srcml/bin/srcml tests/testbed/testbed.c --position


In [40]:
# refactoring: Permute Statement
from dataclasses import dataclass
import import_ipynb
import pdg
import importlib
import itertools
importlib.reload(pdg)

basic = ['expr_stmt', 'decl_stmt']
skip = []

@dataclass
class Statement:
    line: int
    column: int
    code: str
    node: et.Element

    def __str__(self):
        return f'({self.line}:{self.column}) {self.code}'

    def __repr__(self):
        return str(self)

def get_basic_blocks(root):
    blocks = []

    for block_content in xp(root, './/src:block_content'):
        b = []
        # BFS
        q = [block_content]
        visited = set()
        while len(q) > 0:
            u = q.pop(0)
            if et.QName(u).localname in skip:
                # print('skip', "".join(u.itertext()))
                pass
            elif et.QName(u).localname in basic and len(xp(u, './/src:condition')) == 0:
                line, column = map(int, start_pos(u).split(':'))
                text = "".join(u.itertext())
                # print('adding', text)
                b.append(Statement(line, column, text, u))  # Add stmt to current block
            else:
                # print("".join(u.itertext()), 'store', [s[2] for s in b])
                if len(b) > 0:
                    blocks.append(b)  # Add current block to result
                    b = []
            visited.add(u)
            if et.QName(u).localname == 'block_content' or len(xp(u, './/src:block_content')) == 0:
                for v in u:
                    if v not in visited:
                        q.append(v)
    return blocks

def get_pdg_node(stmt, info):
    methodname = xp(stmt.node, './ancestor::src:function/src:name')[0].text
    g = pdg.load_annotated_ast_nodes(info["project"], methodname)
    match = [n for n in g.values() if n["lineNumber"] == stmt.line]
    if len(match) > 1:
        return sorted(match, key=lambda n: abs(n["columnNumber"] - stmt.column))[0]
    elif len(match) == 1:
        return match[0]
    else:
        return None

def independent_stmts(basic_block):
    independent = []
    for i, b in enumerate(basic_block):
        for c in basic_block[i+1:]:
            if c[1]["id"] not in b[1]["dependencies"] and b[1]["id"] not in c[1]["dependencies"]:
                independent.append((b[0].node, c[0].node))
    return independent

def permute_stmt(root, picker=lambda i: i[0], info=None):
    assert info is not None
    root = copy.deepcopy(root)
    basic_blocks = get_basic_blocks(root)
    # print('starting basic blocks:', '\n'.join(map(str, basic_blocks)))
    # return
    candidate_blocks = []
    for b in basic_blocks:
        # if all(get_pdg_node(s) is not None for s in b):
        pdg_nodes = [get_pdg_node(s, info) for s in b]
        # print(str(b), [n["code"] if n is not None else None for n in pdg_nodes])
        new_b = [(s, p) for s, p in zip(b, pdg_nodes) if p is not None]
        if len(new_b) > 1:
            candidate_blocks.append(new_b)
    # print('candidate basic blocks:', '\n'.join(str([str(s[0]) for s in b]) for b in candidate_blocks))

    independent = list(itertools.chain(*(independent_stmts(block) for block in candidate_blocks)))
    # print(len(independent))
    picked = picker(independent)
    # list(map(prettyprint, picked))
    a, b = picked
    a_parent = a.getparent()
    a_idx = a.getparent().index(a)
    del a_parent[a_idx]
    new_b = copy.deepcopy(b)
    a_parent.insert(a_idx, new_b)
    # new_b.tail = '\n'
    b_parent = b.getparent()
    b_idx = b.getparent().index(b)
    del b_parent[b_idx]
    new_a = copy.deepcopy(a)
    b_parent.insert(b_idx, new_a)
    # new_a.tail = '\n'
    
    return root

my_root = srcml(Path('tests/blocktest/blocktest.c'))
test_xmldata = permute_stmt(xp(my_root, '//src:unit')[0], info={"project": 'tests/blocktest'})

importing Jupyter notebook from pdg.ipynb
Running SrcML: srcml/bin/srcml tests/blocktest/blocktest.c --position


In [41]:
# Refactoring: rename variable 
from random_word import RandomWords
words = RandomWords()

def rename_variable(root, picker=lambda i: i[0], info=None):
    root = copy.deepcopy(root)
    all_names = xp(root, f'//src:function//src:decl_stmt/src:decl/src:name')
    target_name_node = picker(all_names)
    original_target_name = target_name = target_name_node.text
    target_name = words.get_random_word()
    # NOTE: This is not necessary for now
    # all_names_str = [c.text for c in all_names]
    # while target_name in all_names_str:
    #     target_name = words.get_random_word()
    function_name = xp(target_name_node, './ancestor::src:function')[0].xpath('./src:name', namespaces=namespaces)[0].text
    targets = xp(root, f'//src:name[text() = "{original_target_name}"][ancestor::src:function[./src:name[text() = "{function_name}"]]]')
    assert len(targets) > 0
    for target in targets:
        target.text = target_name
    return root

renamed_xmldata = rename_variable(xmldata)

In [42]:
# Refactoring: insert noop
def insert_noop(root, picker=lambda i: i[0], info=None):
    root = copy.deepcopy(root)
    all_targets = xp(root, f'//src:*[contains(local-name(), "_stmt")]')
    target = picker(all_targets)
    target_parent = target.getparent()
    decl = E.decl_stmt(E.decl(
        E.type(E.name('int'), ' '),
        E.name('fubar', ' '),
        E.init('= ', E.expr(E.literal('123', type='number'))), ';'))
    # decl.tail = target.tail
    target_parent.insert(target_parent.index(target)+1, decl)
    return root

nooped_xmldata = insert_noop(xmldata)
assert len(xp(xmldata, '//src:name[text() = "123"]')) == 0
assert len(xp(nooped_xmldata, '//src:literal[text() = "123"]')) > 0

In [43]:
# Refactoring: exchange switch with if/else
from collections import OrderedDict

def get_stmts_by_case(switch):
    """
    Return an ordered mapping of executable statements to their respective case ranges.
    Switch cases have peculiar control flow because of default and fallthrough.
    The only fallthroughs we handle are ones where all cases in the fallthrough have the
    same executable statements. For example, this switch is not handled:
    We only handle defaults if they come after all cases.
    """
    block_content = xp(switch, './src:block/src:block_content')[0]
    stmts_by_case = OrderedDict()
    cases_key = None
    cases = []

    # Parse all executable statements (stmt) from the switch.
    for node in block_content:
        if et.QName(node).localname == 'case' or et.QName(node).localname == 'default':
            cases.append(node)
            cases_key = tuple(copy.deepcopy(cases))
        else:
            if cases_key not in stmts_by_case:
                stmts_by_case[cases_key] = []
            stmts_by_case[cases_key].append(node)
            if et.QName(node).localname in ['break', 'return']:
                cases = []

    # All blocks of executable statements must end with a "break;"
    for cases, stmts in stmts_by_case.items():
        tag_name = et.QName(stmts[-1]).localname
        if tag_name == 'break':
            stmts.pop()
        elif tag_name == 'return':
            pass
        else:
            raise

    # Disallow all fallthrough blocks because they are not sound
    def get_case_text(cases):
        """Get the text inside a collection of cases"""
        result = set()
        for c in cases:
            expr = xp(c, './/src:expr')
            if len(expr) > 0:
                result.add(''.join(expr[0].itertext()))
            else:
                result.add('default')
        return result
    items = list(stmts_by_case.items())
    for i in range(len(items)-1):
        cases = items[i][0]
        assert not get_case_text(cases).issubset(get_case_text(items[i+1][0])), 'Fallthroughs are not supported!'

    return stmts_by_case

def gen_if_stmt(switch):
    """Generate a big if_stmt (if/elif/else) from a switch statement"""

    stmts_by_case = get_stmts_by_case(switch)
    narrow_ws = get_space(xp(switch, './src:condition')[0], 'back')
    wide_ws = get_space(xp(switch, './src:block/src:block_content')[0], 'front')
    condition_variable = copy.deepcopy(xp(switch, './src:condition/src:expr')[0])
    # condition_variable.tail = ''
    IF, ELIF, ELSE = range(3)  # Type of conditional to generate

    def gen_conditional(cases, stmts, if_type):
        """Generate and return a conditional (if/elif/else for a switch case)"""

        if if_type in [IF, ELIF]:
            # Generate a boolean condition expression or'ing together all cases
            sub_exprs = []
            for i, case in enumerate(cases):
                if et.QName(case).localname == 'case':
                    case_value = xp(case, './src:expr')[0]
                    # case_value.tail = ''
                    sub_expr = E.expr(copy.deepcopy(condition_variable), ' ', E.operator('=='), case_value)
                    sub_exprs.append(sub_expr)
                if i < len(cases) - 1:
                    sub_exprs.append(E.operator('||'))
            # stmts[-1].tail = narrow_ws
            condition = E.expr(*sub_exprs)

        # We have to use __call__ because calling a funcition named "if" is a syntax error
        if if_type == IF:
            return E.__call__('if', 'if ', E.condition('(', condition, ')'), narrow_ws, E.block('{', E.block_content(wide_ws, *stmts), '}'))
        elif if_type == ELIF:
            return E.__call__('if', 'else if ', E.condition('(', condition, ')'), narrow_ws, E.block('{', E.block_content(wide_ws, *stmts), '}'), type='elseif')
        elif if_type == ELSE:
            return E.__call__('else', 'else ', narrow_ws, E.block('{', E.block_content(wide_ws, *stmts)), narrow_ws, '}')
        else:
            raise

    items = list(stmts_by_case.items())
    ifs = []
    # Move default to last if it is alone
    for i, (cases, stmts) in enumerate(items):
        if any(et.QName(c).localname == 'default' for c in cases):
            default = items.pop(i)
            items.append(default)
            break
    for i, (cases, stmts) in enumerate(items):
        ifs.append(narrow_ws)
        if any(et.QName(c).localname == 'default' for c in cases):
            ifs.append(gen_conditional(cases, stmts, ELSE))
        elif i == 0:
            ifs.append(gen_conditional(cases, stmts, IF))
        else:
            ifs.append(gen_conditional(cases, stmts, ELIF))
    if_stmt = E.if_stmt(*ifs, narrow_ws)
    return if_stmt

def switch_exchange(root, picker=lambda i: i[0], info=None):
    root = copy.deepcopy(root)
    all_switches = xp(root, f'//src:switch')
    target = picker(all_switches)
    if_stmt = gen_if_stmt(target)
    target.getparent().replace(target, if_stmt)
    return root

exchanged_xmldata = switch_exchange(xmldata)
original_count = len(xp(xmldata, '//src:if') + xp(xmldata, '//src:else'))
difference = len(xp(exchanged_xmldata, '//src:if') + xp(exchanged_xmldata, '//src:else')) - original_count
assert difference == 3
# difference

In [44]:
# Refactoring: exchange for loop with while

import re
def loop_exchange(root, picker=lambda i: i[0], info=None):
    root = copy.deepcopy(root)
    all_loops = xp(root, f'//src:for')
    loop = picker(all_loops)
    loop_parent = loop.getparent()
    loop_idx = loop_parent.index(loop)
    block = xp(loop, './src:block')[0]
    block_content = xp(block, './src:block_content')[0]

    # Deconstruct loop control node
    loop_control = xp(loop, './src:control')[0]
    init, cond, incr = loop_control
    init = init[0]  # "int i = 0"
    cond = cond[0]  # "i < n"
    incr = incr[0]  # "i ++"
    # init.tail = ';' + loop.tail
    init.tail = ';'
    # cond.tail = ''

    # Insert loop initializer
    loop_parent.insert(loop_idx, init)
    
    # Insert increment statement
    incr_stmt = E.expr_stmt(incr)
    whitespace_before_content = get_space(block_content, 'front')
    # Adjust whitespace of the increment statement and the last line in the block
    # block_content[-1].tail, incr_stmt.tail = whitespace_before_content, ';' + block_content[-1].tail
    incr_stmt.tail = ';'
    block_content.insert(len(block_content)+1, incr_stmt)

    # Replace for loop with while inplace (preserves most whitespace automatically)
    loop.tag = f'{{{namespaces["src"]}}}while'
    loop.text = 'while '
    loop.replace(loop_control, E.condition('(', cond, ')'))
    return root

exchanged_xmldata = loop_exchange(xmldata)
assert len(xp(xmldata, '//src:while')) == 0
assert len(xp(exchanged_xmldata, '//src:while')) > 0

In [56]:

import random
import traceback

def c2c(c_filename, transforms, picker, num_iterations):
    """Do C source-to-source translation"""
    xml = srcml(c_filename)
    mod_filename = c_filename.parent / (c_filename.stem + '.new.c')
    dst_filename = c_filename.parent / (c_filename.stem + '.new.xml')
    info = {"project": str(Path(c_filename).parent)}
    
    # Apply all transforms
    # for t in transforms:
    #     try:
    #         xml = t(xml, picker=picker, info={"project": str(Path(c_filename).parent)})
    #     except:
    #         traceback.print_exc()
    #         pass
    #     et.ElementTree(xml).write(str(dst_filename))
    #     modified_c = srcml(dst_filename)
    #     with open(mod_filename, 'w') as f:
    #         f.write(modified_c)
    #     xml = srcml(mod_filename)

    # Apply num_iterations transforms, chosen randomly
    i = 0
    while i < num_iterations:
        t = random.choice(transforms)
        try:
            xml = t(xml, picker=picker, info=info)
            i += 1
        except:
            traceback.print_exc()
            pass
        et.ElementTree(xml).write(str(dst_filename))
        modified_c = srcml(dst_filename)
        with open(mod_filename, 'w') as f:
            f.write(modified_c)
        xml = srcml(mod_filename)
    
    with open(mod_filename) as f:
        return f.read(), mod_filename

c_file = Path('tests/testbed2/testbed2.c')
transforms = [
    insert_noop,
    switch_exchange,
    loop_exchange,
    rename_variable,
    permute_stmt,
]
picker = random.choice
num_iterations = 3
new_c_code, new_c_file = c2c(c_file, transforms, picker, num_iterations)
import difflib
differ = difflib.Differ()
diffs = differ.compare(c_file.open().readlines(), new_c_code.splitlines(keepends=True))
r = []
line_nums = []
lineno = 0
for line in diffs:
    if line[0] in (' ', '+'):
        lineno += 1
    print(str(lineno).ljust(3), ' ', line, end='')
    if line[0] == '+':
        if len(r) == 0 or r[-1] == lineno-1:
            r.append(lineno)
        else:
            line_nums.append(r)
            r = [lineno]
line_nums.append(r)
lines = []
for r in line_nums:
    lines.append(f'--lines={r[0]}:{r[-1]}')
subprocess.run(f'clang-format {" ".join(lines)} {new_c_file} > {str(new_c_file).replace(".new.", ".formatted.")}', shell=True, check=True)
# diff = list(difflib.unified_diff(c_file.open().readlines(), new_c_code.splitlines(keepends=True)))
# print(''.join(diff))

Running SrcML: srcml/bin/srcml tests/testbed2/testbed2.c --position
Running SrcML: srcml/bin/srcml tests/testbed2/testbed2.new.xml
Running SrcML: srcml/bin/srcml tests/testbed2/testbed2.new.c --position
Running SrcML: srcml/bin/srcml tests/testbed2/testbed2.new.xml
Running SrcML: srcml/bin/srcml tests/testbed2/testbed2.new.c --position
Running SrcML: srcml/bin/srcml tests/testbed2/testbed2.new.xml
Running SrcML: srcml/bin/srcml tests/testbed2/testbed2.new.c --position
1       /*
2       Description: A chroot() is performed without a chdir().
3       Keywords: Unix C Size0 Complex0 Api Chroot
4       
5       Copyright 2005 Fortify Software.
6       
7       Permission is hereby granted, without written agreement or royalty fee, to
8       use, copy, modify, and distribute this software and its documentation for
9       any purpose, provided that the above copyright notice and the following
10      three paragraphs appear in all copies of this software.
11      
12      IN NO EVENT SHAL

CompletedProcess(args='clang-format --lines=36:36 --lines=38:39 --lines=41:42 --lines=46:50 --lines=53:55 --lines=57:57 --lines=59:59 --lines=61:61 --lines=95:95 --lines=98:98 tests/testbed2/testbed2.new.c > tests/testbed2/testbed2.formatted.c', returncode=0)

# Modifying AST with `pycparser`
This is not suitable because pycparser destroys whitespace.n

In [46]:
#-----------------------------------------------------------------
# pycparser: rewrite_ast.py
#
# Tiny example of rewriting a AST node
#
# Eli Bendersky [https://eli.thegreenplace.net/]
# License: BSD
#-----------------------------------------------------------------
from __future__ import print_function

import pycparser
import os
import pycparser_fake_libc

fake_libc_arg = "-I" + pycparser_fake_libc.directory

generator = pycparser.c_generator.CGenerator()

ast = pycparser.parse_file('x42/c/X42.c', use_cpp=True, cpp_args=fake_libc_arg)
print("Before:")
# ast.show(offset=2)

def print_ast(ast):
  text = generator.visit(ast)
  print(text[text.find('typedef uint32_t xcb_visualid_t;'):])
print_ast(ast)

def recurse(node, fn):
  fn(node)
  for child_node in node.children():
    recurse(node)

for e in ast.ext:
  if isinstance(e, pycparser.c_ast.FuncDef):
    # e.show(offset=2)
    def change_x_to_y_2(node):
      if isinstance(node, pycparser.c_ast.Assignment) and node.lvalue.name == 'x':
        assign.lvalue.name = "y"
        assign.rvalue.value = 2

print("After:")
# ast.show(offset=2)
print_ast(ast)

AttributeError: module 'pycparser' has no attribute 'c_generator'

# Modifying AST with `libclang`
`libclang doesn't support modifying AST directly.`

In [None]:
import os
os.environ['LD_LIBRARY_PATH'] = '/home/benjis/work/transform/llvm-project-11.1.0.src/build/lib'

In [None]:
#!/usr/bin/env python
""" Usage: call with <filename> <typename>
"""

import clang.cindex
from clang.cindex import CursorKind

clang.cindex.Config.set_library_file('llvm-project-11.1.0.src/build/lib/libclang.so')

def get(node, l):
    if l(node):
        return node
    for c in node.get_children():
        result = get(c, l)
        if result is not None:
            return result
    return None

filename = 'chroot1.c'
index = clang.cindex.Index.create()
tu = index.parse(filename)
print('Translation unit:', tu.spelling)
test = get(tu.cursor, lambda n: n.kind == CursorKind.FUNCTION_DECL and n.spelling == 'test')

Translation unit: chroot1.c


In [None]:
def get_variable_references(node, old_name):
    collect = []
    if node.kind == CursorKind.VAR_DECL or node.kind == CursorKind.DECL_REF_EXPR:
        if node.spelling == old_name:
            collect.append(node)
    for c in node.get_children():
        result = get_variable_references(c, old_name)
        if result:
            collect += result
    return collect

var_ref = get_variable_references(test, 'fd')
for v in var_ref:
    print(v.spelling)

fd
fd
fd
fd


In [None]:
def print_internal(node, replacements, output, section):
    for c in node.get_children():
        print_node(c, replacements)

def print_node(node, replacements, sections=''):
    sections = (' ' * (node.extent.end.line - node.extent.start.line))
    print(len(sections))
print_node(test, [])

10
