# Modifying AST with `srcml`
Parses very OK! Can we modify??

In [198]:
from lxml import etree as et
from lxml.builder import ElementMaker
import subprocess
from pathlib import Path
import copy

srcml_exe = 'bin/srcml'
namespaces={'src': 'http://www.srcML.org/srcML/src'}
E = ElementMaker(namespace="http://www.srcML.org/srcML/src")

def srcml(filepath):
    assert filepath.exists()
    args = [srcml_exe, filepath]
    args = [str(a) for a in args]
    print('Running SrcML:', ' '.join(args))
    proc = subprocess.run(args, capture_output=True)
    if proc.returncode != 0:
        print('Error', proc.returncode)
        print(proc.stderr)
        return None
    if filepath.suffix == '.xml':
        return proc.stdout.decode('utf-8')
    elif filepath.suffix == '.c':
        xml = et.fromstring(proc.stdout)
        return xml

fname = Path('testbed.c')
root = srcml(fname)
xmldata = root.xpath('//src:unit', namespaces=namespaces)[0]
assert 4 == len(xmldata.xpath('//src:if_stmt', namespaces=namespaces))
assert 1 == len(xmldata.xpath('//src:switch', namespaces=namespaces))

Running SrcML: bin/srcml testbed.c


In [199]:
# Print XML from root
def prettyprint(node):
    print(et.tostring(node, encoding="unicode", pretty_print=True))
prettyprint(xmldata)

<unit xmlns="http://www.srcML.org/srcML/src" xmlns:cpp="http://www.srcML.org/srcML/cpp" revision="1.0.0" language="C" filename="testbed.c"><comment type="block">/*
Description: A chroot() is performed without a chdir().
Keywords: Unix C Size0 Complex0 Api Chroot

Copyright 2005 Fortify Software.

Permission is hereby granted, without written agreement or royalty fee, to
use, copy, modify, and distribute this software and its documentation for
any purpose, provided that the above copyright notice and the following
three paragraphs appear in all copies of this software.

IN NO EVENT SHALL FORTIFY SOFTWARE BE LIABLE TO ANY PARTY FOR DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF FORTIFY SOFTWARE HAS
BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMANGE.

FORTIFY SOFTWARE SPECIFICALLY DISCLAIMS ANY WARRANTIES INCLUDING, BUT NOT
LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PUR

In [200]:
def rename_variable(root, var_name, new_name):
    root = copy.deepcopy(root)
    for r in root.xpath(f'//src:name[text() = "{var_name}"]', namespaces=namespaces):
        r.text = new_name
    return root

xmldata_test = xmldata.xpath('//src:function[./src:name[text() = "test"]]', namespaces=namespaces)[0]
renamed_xmldata = rename_variable(xmldata_test, 'fd', 'file_descriptor')
assert len(xmldata.xpath('//src:name[text() = "file_descriptor"]', namespaces=namespaces)) == 0
assert len(renamed_xmldata.xpath('//src:name[text() = "file_descriptor"]', namespaces=namespaces)) > 0

In [201]:
def insert_noop(root):
    root = copy.deepcopy(root)
    all_targets = root.xpath(f'//src:decl_stmt', namespaces=namespaces)
    target = next(iter(all_targets))
    target_parent = target.getparent()
    decl = E.decl_stmt(E.decl(
        E.type(E.name('int'), ' '),
        E.name('fubar', ' '),
        E.init('= ', E.expr(E.literal('123', type='number'))), ';'))
    decl.tail = target.tail
    target_parent.insert(target_parent.index(target)+1, decl)
    return root

nooped_xmldata = insert_noop(xmldata)
assert len(xmldata.xpath('//src:name[text() = "123"]', namespaces=namespaces)) == 0
assert len(nooped_xmldata.xpath('//src:literal[text() = "123"]', namespaces=namespaces)) > 0

In [202]:
from collections import defaultdict
def switch_exchange(root):
    root = copy.deepcopy(root)
    all_switches = root.xpath(f'//src:switch', namespaces=namespaces)
    target = all_switches[0]
    variable = copy.deepcopy(target.xpath('./src:condition/src:expr', namespaces=namespaces)[0])
    # prettyprint(variable)
    block_content = target.xpath('./src:block/src:block_content', namespaces=namespaces)[0]
    # prettyprint(block_content)
    stmts_by_case = defaultdict(list)
    cases = []
    for stmt in block_content:
        # print(et.QName(stmt).localname)
        if et.QName(stmt).localname == 'case' or et.QName(stmt).localname == 'default':
            cases.append(stmt)
        elif et.QName(stmt).localname == 'break' or len(stmt.xpath('//src:break', namespaces=namespaces)) == 0:
            cases = []
        else:
            stmts_by_case[tuple(cases)].append(stmt)
    def get_if(cases, stmts, if_type):
        exprs = []
        # print(cases)
        case_cases = [c for c in cases if et.QName(c).localname == 'case']
        for i, case in enumerate(case_cases):
            case_value = copy.deepcopy(case.xpath('./src:expr', namespaces=namespaces)[0])
            case_value.tail = ''
            expr = E.expr(variable, ' ', E.operator('=='), case_value)
            # prettyprint(expr)
            exprs.append(expr)
            if i < len(case_cases) - 1:
                exprs.append(E.operator('||'))
        condition = E.expr(*exprs)
        # prettyprint(condition)
        if_xml = None
        if if_type == 'if':
            if_xml = E.__call__('if', 'if ', E.condition('(', condition, ')'), E.block('{', E.block_content('\n', *stmts), '}}'))
        if if_type == 'elseif':
            if_xml = E.__call__('if', 'if ', E.condition('(', condition, ')'), E.block('{', E.block_content('\n', *stmts), '}}'), type='elseif')
        if if_type == 'else':
            if_xml = E.__call__('else', 'else ', E.block('{', E.block_content('\n', *stmts)), '}')
        return if_xml

    ifs = []
    for i, (cases, stmts) in enumerate(stmts_by_case.items()):
        if all(c.tag == 'default' for c in cases):
            ifs.append(get_if(cases, stmts, 'else'))
        else:
            if i == 0:
                ifs.append(get_if(cases, stmts, 'if'))
            else:
                ifs.append(get_if(cases, stmts, 'elseif'))
    if_stmt = E.if_stmt(*ifs)
    target.getparent().replace(target, if_stmt)
    return root

exchanged_xmldata = switch_exchange(xmldata)
original_count = len(xmldata.xpath('//src:if', namespaces=namespaces) + xmldata.xpath('//src:else', namespaces=namespaces))
difference = len(exchanged_xmldata.xpath('//src:if', namespaces=namespaces) + exchanged_xmldata.xpath('//src:else', namespaces=namespaces)) - original_count
assert difference == 4
# difference

In [203]:

import re
def loop_exchange(root):
    root = copy.deepcopy(root)
    all_loops = root.xpath(f'//src:for', namespaces=namespaces)
    loop = next(iter(all_loops))
    loop_parent = loop.getparent()
    loop_idx = loop_parent.index(loop)
    block = loop.xpath('./src:block', namespaces=namespaces)[0]
    block_content = block.xpath('./src:block_content', namespaces=namespaces)[0]

    init, cond, incr = loop.xpath('./src:control', namespaces=namespaces)[0]
    init = init[0]
    cond = cond[0]
    incr = incr[0]

    init.tail = ';' + loop.tail
    cond.tail = ''

    incr_stmt = E.expr_stmt(incr)
    whitespace_before_content = next(iter(re.findall(r'^<block_content[^>]+>(\s+)', et.tostring(block_content, encoding='unicode'))), '')
    orig_tail = block_content[-1].tail
    block_content[-1].tail = whitespace_before_content
    incr_stmt.tail = ';' + orig_tail
    block_content.insert(len(block_content)+1, incr_stmt)

    loop_parent.insert(loop_idx, init)
    whitespace = next(iter(re.findall(r'\s+$', et.tostring(loop, encoding='unicode'))), '')
    loop_parent.insert(loop_idx+1, E.__call__('while', 'while ', E.condition('(', cond, ')'), ' ', block, whitespace))
    loop_parent.remove(loop)
    return root

exchanged_xmldata = loop_exchange(xmldata)
assert len(xmldata.xpath('//src:while', namespaces=namespaces)) == 0
assert len(exchanged_xmldata.xpath('//src:while', namespaces=namespaces)) > 0

In [207]:
import functools

def c2c(c_filename, transforms):
    xml = srcml(c_filename)
    for t in transforms:
        xml = t(xml)
    dst_filename = c_filename.parent / (c_filename.name + '.xml')
    tree = et.ElementTree(xml)
    tree.write(str(dst_filename))
    return srcml(dst_filename)

c_file = Path('testbed.c')
transforms = [
    insert_noop,
    loop_exchange,
    functools.partial(rename_variable, var_name='fd', new_name='ficus'),
]
new_c_code = c2c(c_file, transforms)
import difflib
diff = list(difflib.unified_diff(c_file.open().readlines(), new_c_code.splitlines(keepends=True)))
print(''.join(diff))

Running SrcML: bin/srcml testbed.c
Running SrcML: bin/srcml testbed.c.xml
--- 
+++ 
@@ -32,6 +32,7 @@
 int switchtest(char a)
 {
     int x;
+    int fubar = 123;
     int y = 1;
     int z = 0;
     switch(a)
@@ -58,8 +59,10 @@
 int looptest()
 {
     int x = 0;
-    for (int i = 0; i < 10; i ++) {
+    int i = 0;
+    while (i < 10) {
         x += 1;
+        i ++;
     }
     return x;
 }
@@ -67,13 +70,13 @@
 void
 test(char *str)
 {
-	int fd;
+	int ficus;
 
 	if(chroot(DIR) < 0)			/* BAD */
 		return;
-	fd = open(FILE, O_RDONLY);		/* BAD */
-	if(fd >= 0)
-		close(fd);
+	ficus = open(FILE, O_RDONLY);		/* BAD */
+	if(ficus >= 0)
+		close(ficus);
 }
 
 int



# Modifying AST with `pycparser`
This is not suitable because pycparser destroys whitespace.n

In [205]:
#-----------------------------------------------------------------
# pycparser: rewrite_ast.py
#
# Tiny example of rewriting a AST node
#
# Eli Bendersky [https://eli.thegreenplace.net/]
# License: BSD
#-----------------------------------------------------------------
from __future__ import print_function

import pycparser
import os
import pycparser_fake_libc

fake_libc_arg = "-I" + pycparser_fake_libc.directory

generator = pycparser.c_generator.CGenerator()

ast = pycparser.parse_file('x42/c/X42.c', use_cpp=True, cpp_args=fake_libc_arg)
print("Before:")
# ast.show(offset=2)

def print_ast(ast):
  text = generator.visit(ast)
  print(text[text.find('typedef uint32_t xcb_visualid_t;'):])
print_ast(ast)

def recurse(node, fn):
  fn(node)
  for child_node in node.children():
    recurse(node)

for e in ast.ext:
  if isinstance(e, pycparser.c_ast.FuncDef):
    # e.show(offset=2)
    def change_x_to_y_2(node):
      if isinstance(node, pycparser.c_ast.Assignment) and node.lvalue.name == 'x':
        assign.lvalue.name = "y"
        assign.rvalue.value = 2

print("After:")
# ast.show(offset=2)
print_ast(ast)

AttributeError: module 'pycparser' has no attribute 'c_generator'

# Modifying AST with `libclang`
`libclang doesn't support modifying AST directly.`

In [None]:
import os
os.environ['LD_LIBRARY_PATH'] = '/home/benjis/work/transform/llvm-project-11.1.0.src/build/lib'

In [None]:
#!/usr/bin/env python
""" Usage: call with <filename> <typename>
"""

import clang.cindex
from clang.cindex import CursorKind

clang.cindex.Config.set_library_file('llvm-project-11.1.0.src/build/lib/libclang.so')

def get(node, l):
    if l(node):
        return node
    for c in node.get_children():
        result = get(c, l)
        if result is not None:
            return result
    return None

filename = 'chroot1.c'
index = clang.cindex.Index.create()
tu = index.parse(filename)
print('Translation unit:', tu.spelling)
test = get(tu.cursor, lambda n: n.kind == CursorKind.FUNCTION_DECL and n.spelling == 'test')

Translation unit: chroot1.c


In [None]:
def get_variable_references(node, old_name):
    collect = []
    if node.kind == CursorKind.VAR_DECL or node.kind == CursorKind.DECL_REF_EXPR:
        if node.spelling == old_name:
            collect.append(node)
    for c in node.get_children():
        result = get_variable_references(c, old_name)
        if result:
            collect += result
    return collect

var_ref = get_variable_references(test, 'fd')
for v in var_ref:
    print(v.spelling)

fd
fd
fd
fd


In [None]:
def print_internal(node, replacements, output, section):
    for c in node.get_children():
        print_node(c, replacements)

def print_node(node, replacements, sections=''):
    sections = (' ' * (node.extent.end.line - node.extent.start.line))
    print(len(sections))
print_node(test, [])

10
