# Modifying AST with `srcml`
Parses very OK! Can we modify??

In [296]:
from lxml import etree as et
from lxml.builder import ElementMaker
import subprocess
from pathlib import Path
import copy

srcml_exe = 'srcml/bin/srcml'
namespaces={'src': 'http://www.srcML.org/srcML/src'}
E = ElementMaker(namespace="http://www.srcML.org/srcML/src")

def srcml(filepath):
    assert filepath.exists()
    args = [srcml_exe, filepath]
    args = [str(a) for a in args]
    print('Running SrcML:', ' '.join(args))
    proc = subprocess.run(args, capture_output=True)
    if proc.returncode != 0:
        print('Error', proc.returncode)
        print(proc.stderr)
        return None
    if filepath.suffix == '.xml':
        return proc.stdout.decode('utf-8')
    elif filepath.suffix == '.c':
        xml = et.fromstring(proc.stdout)
        return xml

fname = Path('testbed.c')
root = srcml(fname)
xmldata = root.xpath('//src:unit', namespaces=namespaces)[0]
assert 4 == len(xmldata.xpath('//src:if_stmt', namespaces=namespaces))
assert 1 == len(xmldata.xpath('//src:switch', namespaces=namespaces))

Running SrcML: srcml/bin/srcml testbed.c


In [297]:
# Print XML from root
def prettyprint(node):
    print(et.tostring(node, encoding="unicode", pretty_print=True))
prettyprint(xmldata)

<unit xmlns="http://www.srcML.org/srcML/src" xmlns:cpp="http://www.srcML.org/srcML/cpp" revision="1.0.0" language="C" filename="testbed.c"><comment type="block">/*
Description: A chroot() is performed without a chdir().
Keywords: Unix C Size0 Complex0 Api Chroot

Copyright 2005 Fortify Software.

Permission is hereby granted, without written agreement or royalty fee, to
use, copy, modify, and distribute this software and its documentation for
any purpose, provided that the above copyright notice and the following
three paragraphs appear in all copies of this software.

IN NO EVENT SHALL FORTIFY SOFTWARE BE LIABLE TO ANY PARTY FOR DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF FORTIFY SOFTWARE HAS
BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMANGE.

FORTIFY SOFTWARE SPECIFICALLY DISCLAIMS ANY WARRANTIES INCLUDING, BUT NOT
LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PUR

In [356]:
from random_word import RandomWords
words = RandomWords()

def rename_variable(root, picker=lambda i: i[0]):
    root = copy.deepcopy(root)
    all_names = root.xpath(f'//src:function//src:decl_stmt/src:decl/src:name', namespaces=namespaces)
    target_name_node = picker(all_names)
    original_target_name = target_name = target_name_node.text
    all_names_str = [c.text for c in all_names]
    while target_name in all_names_str:
        target_name = words.get_random_word()
    function_name = target_name_node.xpath('./ancestor::src:function', namespaces=namespaces)[0].xpath('./src:name', namespaces=namespaces)[0].text
    targets = root.xpath(f'//src:name[text() = "{original_target_name}"][ancestor::src:function[./src:name[text() = "{function_name}"]]]', namespaces=namespaces)
    assert len(targets) > 0
    for target in targets:
        target.text = target_name
    return root

renamed_xmldata = rename_variable(xmldata)

In [299]:
def insert_noop(root, picker=lambda i: i[0]):
    root = copy.deepcopy(root)
    all_targets = root.xpath(f'//src:*[contains(local-name(), "_stmt")]', namespaces=namespaces)
    target = picker(all_targets)
    target_parent = target.getparent()
    decl = E.decl_stmt(E.decl(
        E.type(E.name('int'), ' '),
        E.name('fubar', ' '),
        E.init('= ', E.expr(E.literal('123', type='number'))), ';'))
    decl.tail = target.tail
    target_parent.insert(target_parent.index(target)+1, decl)
    return root

nooped_xmldata = insert_noop(xmldata)
assert len(xmldata.xpath('//src:name[text() = "123"]', namespaces=namespaces)) == 0
assert len(nooped_xmldata.xpath('//src:literal[text() = "123"]', namespaces=namespaces)) > 0

In [310]:
from collections import defaultdict, OrderedDict
def switch_exchange(root, picker=lambda i: i[0]):
    root = copy.deepcopy(root)
    all_switches = root.xpath(f'//src:switch', namespaces=namespaces)
    target = picker(all_switches)
    variable = copy.deepcopy(target.xpath('./src:condition/src:expr', namespaces=namespaces)[0])
    variable.tail = ''
    switch_ws = next(iter(re.findall(r'\s+$', et.tostring(target.xpath('./src:condition', namespaces=namespaces)[0], encoding='unicode'))), '')
    block_content = target.xpath('./src:block/src:block_content', namespaces=namespaces)[0]
    block_content_ws = next(iter(re.findall(r'^<block_content[^>]+>(\s+)', et.tostring(block_content, encoding='unicode'))), '')
    stmts_by_case = OrderedDict()
    cases_key = None
    cases = []
    for stmt in block_content:
        if et.QName(stmt).localname == 'case' or et.QName(stmt).localname == 'default':
            cases.append(stmt)
            cases_key = tuple(copy.deepcopy(cases))
        elif et.QName(stmt).localname == 'break' or len(stmt.xpath('.//src:break', namespaces=namespaces)) > 0:
            cases = []
        else:
            if cases_key not in stmts_by_case:
                stmts_by_case[cases_key] = []
            stmts_by_case[cases_key].append(stmt)
    
    # default must be last
    for i, (cases, _) in enumerate(stmts_by_case.items()):
        if any(et.QName(c).localname == 'default' for c in cases):
            assert i == len(stmts_by_case)-1, 'Non-standard defaults are not supported!'
    
    # NOTE: debug
    # for k, v in stmts_by_case.items():
    #     for i in k:
    #         prettyprint(i)
    #     for i in v:
    #         prettyprint(i)

    # Reorder fallthrough blocks for soundness
    def get_cases(cases):
        result = set()
        for c in cases:
            expr = c.xpath('.//src:expr', namespaces=namespaces)
            if len(expr) > 0:
                result.add(''.join(expr[0].itertext()))
            else:
                result.add('default')
        return result
    items = list(stmts_by_case.items())
    for i in range(len(items)-1):
        cases, stmts = items[i]
        assert not get_cases(cases).issubset(get_cases(items[i+1][0])), 'Fallthroughs are not supported!'

    # NOTE: Trying to handle fallthroughs. This functionality is archived.
    # new_stmts_by_case = OrderedDict()
    # while i < len(items):
    #     cases, stmts = items[i]
    #     collect = []
    #     while True:
    #         if i+1 >= len(items) or any(et.QName(c).localname == 'default' for c in items[i+1][0]):
    #             break
    #         if get_cases(cases).issubset(get_cases(items[i+1][0])):
    #             collect.insert(0, (cases, stmts))  # Insert fallthrough cases in reverse order
    #             i += 1
    #             cases, stmts = items[i]
    #         else:
    #             collect.insert(0, (cases, stmts))  # Insert fallthrough cases in reverse order
    #             i += 1
    #             break
    #     if any(collect):
    #         for k, v in collect:
    #             new_stmts_by_case[k] = v
    #     else:
    #         new_stmts_by_case[cases] = stmts
    #         i += 1
    # stmts_by_case = new_stmts_by_case
    
    def get_if(cases, stmts, if_type):
        exprs = []
        case_cases = [c for c in cases if et.QName(c).localname == 'case']
        for i, case in enumerate(case_cases):
            case_value = copy.deepcopy(case.xpath('./src:expr', namespaces=namespaces)[0])
            case_value.tail = ''
            expr = E.expr(copy.deepcopy(variable), ' ', E.operator('=='), case_value)
            exprs.append(expr)
            if i < len(case_cases) - 1:
                exprs.append(E.operator('||'))
        stmts[-1].tail = switch_ws
        condition = E.expr(*exprs)
        if_xml = None
        if if_type == 'if':
            if_xml = E.__call__('if', 'if ', E.condition('(', condition, ')'), switch_ws, E.block('{', E.block_content(block_content_ws, *stmts), '}'))
        if if_type == 'elseif':
            if_xml = E.__call__('if', 'else if ', E.condition('(', condition, ')'), switch_ws, E.block('{', E.block_content(block_content_ws, *stmts), '}'), type='elseif')
        if if_type == 'else':
            if_xml = E.__call__('else', 'else ', switch_ws, E.block('{', E.block_content(block_content_ws, *stmts)), '}', switch_ws)
        return if_xml

    items = list(stmts_by_case.items())
    ifs = []
    for i, (cases, stmts) in enumerate(items):
        ifs.append(switch_ws)
        if all(et.QName(c).localname == 'default' for c in cases):
            ifs.append(get_if(cases, stmts, 'else'))
        else:
            if i == 0:
                ifs.append(get_if(cases, stmts, 'if'))
            else:
                ifs.append(get_if(cases, stmts, 'elseif'))
    if_stmt = E.if_stmt(*ifs)
    target.getparent().replace(target, if_stmt)
    return root

exchanged_xmldata = switch_exchange(xmldata)
original_count = len(xmldata.xpath('//src:if', namespaces=namespaces) + xmldata.xpath('//src:else', namespaces=namespaces))
difference = len(exchanged_xmldata.xpath('//src:if', namespaces=namespaces) + exchanged_xmldata.xpath('//src:else', namespaces=namespaces)) - original_count
assert difference == 3
# difference

In [None]:

import re
def loop_exchange(root, picker=lambda i: i[0]):
    root = copy.deepcopy(root)
    all_loops = root.xpath(f'//src:for', namespaces=namespaces)
    loop = picker(all_loops)
    loop_parent = loop.getparent()
    loop_idx = loop_parent.index(loop)
    block = loop.xpath('./src:block', namespaces=namespaces)[0]
    block_content = block.xpath('./src:block_content', namespaces=namespaces)[0]

    loop_control = loop.xpath('./src:control', namespaces=namespaces)[0]
    init, cond, incr = loop_control
    init = init[0]
    cond = cond[0]
    incr = incr[0]

    init.tail = ';' + loop.tail
    cond.tail = ''

    incr_stmt = E.expr_stmt(incr)
    whitespace_before_content = next(iter(re.findall(r'^<block_content[^>]+>(\s+)', et.tostring(block_content, encoding='unicode'))), '')
    orig_tail = block_content[-1].tail
    block_content[-1].tail = whitespace_before_content
    incr_stmt.tail = ';' + orig_tail
    block_content.insert(len(block_content)+1, incr_stmt)

    loop_parent.insert(loop_idx, init)
    loop.tag = f'{{{namespaces["src"]}}}while'
    loop.text = 'while '
    loop.replace(loop_control, E.condition('(', cond, ')'))
    return root

exchanged_xmldata = loop_exchange(xmldata)
assert len(xmldata.xpath('//src:while', namespaces=namespaces)) == 0
assert len(exchanged_xmldata.xpath('//src:while', namespaces=namespaces)) > 0

In [357]:
import random

def c2c(c_filename, transforms):
    xml = srcml(c_filename)
    t = random.choice(transforms)
    xml = t(xml)
    # for t in transforms:
        # xml = t(xml)
    dst_filename = c_filename.parent / (c_filename.name + '.xml')
    tree = et.ElementTree(xml)
    tree.write(str(dst_filename))
    return srcml(dst_filename)

c_file = Path('testbed.c')
transforms = [
    # insert_noop,
    # switch_exchange,
    # loop_exchange,
    rename_variable,
]
new_c_code = c2c(c_file, transforms)
import difflib
diff = list(difflib.unified_diff(c_file.open().readlines(), new_c_code.splitlines(keepends=True)))
print(''.join(diff))

Running SrcML: srcml/bin/srcml testbed.c
Running SrcML: srcml/bin/srcml testbed.c.xml
--- 
+++ 
@@ -32,7 +32,7 @@
 
 int switchtest(char a)
 {
-    char *x;
+    char *photoswitch;
     int y = 1;
     int z = 0;
     switch(a)
@@ -41,18 +41,18 @@
         case 'a':
         case 'b':
         y = 10;
-        if (y == 10 && y > 4 && x == 5) {
-            x = "5";
+        if (y == 10 && y > 4 && photoswitch == 5) {
+            photoswitch = "5";
         }
         break;
         case 'c':
-        x = "10";
+        photoswitch = "10";
         break;
         default:
-        x = "1";
+        photoswitch = "1";
         break;
     }
-    return strlen(x) * y + z;
+    return strlen(photoswitch) * y + z;
 }
 
 int looptest()



# Modifying AST with `pycparser`
This is not suitable because pycparser destroys whitespace.n

In [205]:
#-----------------------------------------------------------------
# pycparser: rewrite_ast.py
#
# Tiny example of rewriting a AST node
#
# Eli Bendersky [https://eli.thegreenplace.net/]
# License: BSD
#-----------------------------------------------------------------
from __future__ import print_function

import pycparser
import os
import pycparser_fake_libc

fake_libc_arg = "-I" + pycparser_fake_libc.directory

generator = pycparser.c_generator.CGenerator()

ast = pycparser.parse_file('x42/c/X42.c', use_cpp=True, cpp_args=fake_libc_arg)
print("Before:")
# ast.show(offset=2)

def print_ast(ast):
  text = generator.visit(ast)
  print(text[text.find('typedef uint32_t xcb_visualid_t;'):])
print_ast(ast)

def recurse(node, fn):
  fn(node)
  for child_node in node.children():
    recurse(node)

for e in ast.ext:
  if isinstance(e, pycparser.c_ast.FuncDef):
    # e.show(offset=2)
    def change_x_to_y_2(node):
      if isinstance(node, pycparser.c_ast.Assignment) and node.lvalue.name == 'x':
        assign.lvalue.name = "y"
        assign.rvalue.value = 2

print("After:")
# ast.show(offset=2)
print_ast(ast)

AttributeError: module 'pycparser' has no attribute 'c_generator'

# Modifying AST with `libclang`
`libclang doesn't support modifying AST directly.`

In [None]:
import os
os.environ['LD_LIBRARY_PATH'] = '/home/benjis/work/transform/llvm-project-11.1.0.src/build/lib'

In [None]:
#!/usr/bin/env python
""" Usage: call with <filename> <typename>
"""

import clang.cindex
from clang.cindex import CursorKind

clang.cindex.Config.set_library_file('llvm-project-11.1.0.src/build/lib/libclang.so')

def get(node, l):
    if l(node):
        return node
    for c in node.get_children():
        result = get(c, l)
        if result is not None:
            return result
    return None

filename = 'chroot1.c'
index = clang.cindex.Index.create()
tu = index.parse(filename)
print('Translation unit:', tu.spelling)
test = get(tu.cursor, lambda n: n.kind == CursorKind.FUNCTION_DECL and n.spelling == 'test')

Translation unit: chroot1.c


In [None]:
def get_variable_references(node, old_name):
    collect = []
    if node.kind == CursorKind.VAR_DECL or node.kind == CursorKind.DECL_REF_EXPR:
        if node.spelling == old_name:
            collect.append(node)
    for c in node.get_children():
        result = get_variable_references(c, old_name)
        if result:
            collect += result
    return collect

var_ref = get_variable_references(test, 'fd')
for v in var_ref:
    print(v.spelling)

fd
fd
fd
fd


In [None]:
def print_internal(node, replacements, output, section):
    for c in node.get_children():
        print_node(c, replacements)

def print_node(node, replacements, sections=''):
    sections = (' ' * (node.extent.end.line - node.extent.start.line))
    print(len(sections))
print_node(test, [])

10
