In [1]:
import ast
import json
import numpy as np
import pandas as pd
import os
import seaborn as sns

In [2]:
import warnings
# hide FutureWarning messages so they do not clutter the cell outputs below
warnings.simplefilter(action='ignore', category=FutureWarning)

# Normalise AST

## Identify self-defined names

Self-defined functions and methods

In [3]:
def is_funcdef(node: ast.AST) -> bool:
    '''
    Check whether `node` is a function definition statement.

    Both synchronous (`def`) and asynchronous (`async def`) definitions
    count; every other node yields False.
    '''
    funcdef_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    return isinstance(node, funcdef_types)

In [4]:
def get_funcdef_name(node: ast.AST) -> str:
    '''
    Name of the function defined by `node`.

    Returns None (despite the `str` annotation) when `node` is not a
    function definition according to `is_funcdef`.
    '''
    return node.name if is_funcdef(node) else None

In [5]:
def is_instance_method_def(node: ast.AST) -> bool:
    '''
    Check whether `node` looks like an instance-method definition.

    Heuristic: a function definition whose first positional parameter is
    named `self`. Positional-only parameters (those before `/`) come first
    in a signature, so `def m(self, /, x)` is recognised as well.

    Returns False for anything that is not a function definition.
    '''
    if not is_funcdef(node): return False

    # posonlyargs precede the regular args; both are positional parameters
    positional = node.args.posonlyargs + node.args.args

    # self is first argument of an instance method
    return len(positional) > 0 and positional[0].arg == 'self'

In [6]:
def get_method_def(node: ast.AST, classname: str) -> tuple:
    '''
    Describe a method defined inside class `classname`.

    Returns:
        (classname, classname)             for an instance `__init__`,
        (method_name, classname)           for any other instance method,
        (f'{classname}.{name}', classname) for a method without a leading
                                           `self` parameter (class-level),
        None                               if `node` is not a function definition.
    '''
    if not is_funcdef(node):
        return None

    method_name = get_funcdef_name(node)

    if not is_instance_method_def(node):
        return (f'{classname}.{method_name}', classname)

    # constructor calls look like `ClassName(...)`, so __init__ is keyed by the class name
    key = classname if method_name == '__init__' else method_name
    return (key, classname)

In [7]:
def get_methoddefs_in_classdef(node: ast.ClassDef) -> set:
    '''
    Collect the method descriptors (see `get_method_def`) for every function
    defined directly in the body of a class definition.

    Only the class body's own statements are inspected: nested classes and
    functions nested inside methods are not visited.
    Returns an empty set when `node` is not an `ast.ClassDef`.
    '''
    if not isinstance(node, ast.ClassDef):
        return set()

    descriptors = (get_method_def(stmt, node.name) for stmt in node.body)
    return {d for d in descriptors if d is not None}

Self-defined variables (assignment targets)

In [8]:
def unpack_ast_tup_or_list(node_list: list) -> list:
    '''
    Flatten a list of assignment-target nodes.

    Tuple/list targets are expanded recursively, starred targets (`*rest`)
    are replaced by the expression they star, and every other node is kept
    as is. E.g. the targets of `a, (b, [c, *d]) = ...` give the Name nodes
    for a, b, c and d.
    '''
    def unpack(node):
        if isinstance(node, ast.Starred):
            # `*d` binds `d`; without this the target would be recorded as '*d'
            return unpack(node.value)

        if not isinstance(node, (ast.Tuple, ast.List)):
            return [node]

        unpacked = []
        for x in node.elts:
            unpacked.extend(unpack(x))
        return unpacked

    res = []
    for i in node_list:
        res.extend(unpack(i))
    return res

In [9]:
def get_assigned_vars(node: ast.Assign) -> set:
    '''
    Source text of every target on the left-hand side of an `ast.Assign`.

    Tuple/list targets are flattened first, so `a, (b, c) = ...` yields
    {'a', 'b', 'c'}; non-Name targets keep their unparsed form (e.g. 'x.y').
    Returns an empty set for any other node type.
    '''
    if not isinstance(node, ast.Assign):
        return set()

    return {ast.unparse(target) for target in unpack_ast_tup_or_list(node.targets)}

Identifying self-defined names in the AST

In [10]:
def get_self_defs(root: ast.AST):
    '''
    Collect the names defined at the scope of `root`.

    Function bodies and class bodies are not entered, but other compound
    statements (for/if/while/...) are, so assignments nested in them count.
    The result mixes two kinds of entries:
      * plain strings: function names, class names, `=` assignment targets
      * tuples from `get_methoddefs_in_classdef` describing class methods

    Returns None when `root` is None.
    '''
    if root is None: return None

    found = set()
    pending = list(ast.iter_child_nodes(root))

    # visiting order is irrelevant: everything is accumulated in a set
    while pending:
        current = pending.pop()

        if is_funcdef(current):
            found.add(get_funcdef_name(current))
        elif isinstance(current, ast.ClassDef):
            found.add(current.name)
            found |= get_methoddefs_in_classdef(current)
        else:
            if isinstance(current, ast.Assign):
                found |= get_assigned_vars(current)
            pending.extend(ast.iter_child_nodes(current))

    return found

Example

In [11]:
# Sanity check for get_self_defs on a module with a class, a function and assignments.
# Expected: class/function names, assignment targets (the tuple target yields 'c'
# and 'd'), and one (name, classname) tuple per method: __init__ is keyed by the
# class name and `hi` (no self parameter) is keyed as 'Person.hi'.
testcase_code = """
class Person:
    def __init__(self, name):
        self.name = name
    
    def hi(name):
        print(f'hi {name}')
    
    def introduce(self):
        print(f'hi my name is {self.name}')

def bye(name):
    print(f'bye {name}')
    
a = print

c, d = [1,2]
"""

testcase_root = ast.parse(testcase_code)

assert(get_self_defs(testcase_root) ==
       {('Person', 'Person'),
        ('Person.hi', 'Person'),
        ('introduce', 'Person'),
        'Person',
        'a',
        'bye',
        'c',
        'd'})

## Normalise self-defined names

In [12]:
def get_func_attrs(node):
    """
    For an `ast.Call`, return a triple (funcnode, obj_attr, func_attr):

      * funcnode  -- the `node.func` expression
      * obj_attr  -- field of funcnode holding the object the function is
                     called on ("value"), or None for a bare name
      * func_attr -- field of funcnode holding the function name
                     ("id" for `f(...)`, "attr" for `obj.f(...)`)

    i.e. getattr(funcnode, func_attr) is the called function's name.
    Returns None for non-calls and for calls whose callee is neither an
    `ast.Name` nor an `ast.Attribute` (e.g. `fs[0]()`).
    """
    if isinstance(node, ast.Call):
        callee = node.func
        if isinstance(callee, ast.Name):
            return (callee, None, "id")
        if isinstance(callee, ast.Attribute):
            return (callee, "value", "attr")
    return None

In [13]:
def get_func_called(node: ast.Call):
    '''
    Return (object of calling function, function called) for a call node.

    * `f(x)`      -> (None, 'f')
    * `a.b.f(x)`  -> ('a.b', 'f')   (the object is its unparsed source text)
    * anything else (non-call, or a callee that is not a Name/Attribute)
      -> (None, None)
    '''
    if not isinstance(node, ast.Call):
        return None, None

    attrs = get_func_attrs(node)
    if attrs is None:
        return None, None

    callee, obj_field, name_field = attrs
    owner = None if obj_field is None else ast.unparse(getattr(callee, obj_field))
    name = None if name_field is None else getattr(callee, name_field)
    return owner, name

In [14]:
def mask_call(node: ast.AST, self_defs: set):
    '''
    Mask a call to a self-defined function/method, in place.

    `self_defs` holds plain names (functions, classes, assigned variables)
    and (name, classname) method tuples, as built by `get_self_defs`.

    * `f(...)` with `f` in self_defs               -> `self_def_func(...)`
    * `Cls.m(...)` with ('Cls.m', 'Cls') known     -> `self_def_class.self_def_func(...)`
    * `obj.m(...)` with some method tuple named `m` -> `obj.self_def_func(...)`

    Any other node is left untouched. Returns None.
    '''
    if not isinstance(node, ast.Call): return

    # computed once; the previous version called get_func_attrs twice
    func_attrs = get_func_attrs(node)
    if func_attrs is None:
        return

    funcnode, obj_attr, func_attr = func_attrs
    calling_obj, func = get_func_called(node)

    if func is None:
        return

    if calling_obj is None:
        if func in self_defs:
            setattr(funcnode, func_attr, "self_def_func")
        return

    # case 1: method defined within a classdef, called on the class itself
    if (f'{calling_obj}.{func}', calling_obj) in self_defs:
        setattr(funcnode, obj_attr, ast.Name(id='self_def_class', ctx=ast.Load()))
        setattr(funcnode, func_attr, 'self_def_func')
        return

    # case 2: instance method -- matched on the method name only; the
    # class stored in the tuple is not compared with `calling_obj`
    if any(isinstance(x, tuple) and x[0] == func for x in self_defs):
        setattr(funcnode, func_attr, 'self_def_func')

In [15]:
def mask_calls_in_AST(root: ast.AST, outer_scope_self_defs):
    '''
    Mask, in place, every call under `root` to a self-defined
    function/method/class (see `mask_call`).

    Scoping: the names defined at `root`'s own scope (`get_self_defs`) are
    added to `outer_scope_self_defs`. Each nested function, and each method
    of a nested class, is then processed recursively with that combined set
    plus its own name / method descriptor.

    NOTE(review): only function definitions inside a class body are
    visited; other class-body statements (e.g. `x = helper()`) are skipped.
    '''
    if root is None: return

    self_defs = outer_scope_self_defs.union(get_self_defs(root))
    queue = [x for x in ast.iter_child_nodes(root)]

    while queue:
        node = queue.pop(0)

        if is_funcdef(node):
            # wrap the name in a set: `set.union('name')` would add every
            # character of the string as a separate entry
            mask_calls_in_AST(node, self_defs | {node.name})
            continue

        if isinstance(node, ast.ClassDef):
            for n_x in node.body:
                if is_funcdef(n_x):
                    mask_calls_in_AST(n_x, self_defs | {get_method_def(n_x, node.name)})
            continue

        if isinstance(node, ast.Call):
            mask_call(node, self_defs)

        for child in ast.iter_child_nodes(node):
            queue.append(child)
    return

In [16]:
def mask_selfdefs(root: ast.AST):
    '''
    Rename, in place, every function/method definition under `root`
    (including `root` itself) to 'self_def_func' and every class
    definition to 'self_def_class'.

    Only definition headers are renamed here; call sites are handled by
    `mask_calls_in_AST`. Returns None.
    '''
    if root is None:
        return

    for n in ast.walk(root):
        if isinstance(n, ast.ClassDef):
            n.name = 'self_def_class'
        elif is_funcdef(n):
            n.name = 'self_def_func'

In [17]:
def mask_args(node: ast.FunctionDef):
    '''
    Rename, in place, the parameters of a function definition to
    arg0, arg1, ... in signature order (positional-only, regular, *args,
    keyword-only, **kwargs). A parameter named 'self' keeps its name.

    Accepts both `ast.FunctionDef` and `ast.AsyncFunctionDef`, consistent
    with `is_funcdef`.

    Returns the mapping {original name: masked name}, or None when `node`
    is not a function definition.
    '''
    if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): return
    argmap = dict()
    arg_count = 0

    # iter_child_nodes also yields default-value expressions; keep only ast.arg
    for arg in ast.iter_child_nodes(node.args):
        if not isinstance(arg, ast.arg):
            continue

        if arg.arg == 'self':
            argmap[arg.arg] = arg.arg # keep 'self' as 'self'
            continue

        argmap[arg.arg] = f'arg{arg_count}' # record new masked name
        arg.arg = argmap[arg.arg] # update argument name to the masked name
        arg_count += 1

    return argmap

In [18]:
def mask_variables(root: ast.AST, varmap: dict) -> ast.AST:
    '''
    Rename, in place, the variables under `root` and return `root`.

    Nodes are visited in pre-order, left to right, using an explicit stack:
      * every `ast.Name` is renamed through `varmap`; a name seen for the
        first time gets the next 'var<count>' (count starts at 0 in every
        call). Load/Store context is not checked, so builtins used as values
        (e.g. `a = print`) are renamed too.
      * the callee of a plain call `f(...)` is not renamed.
      * for `obj.m(...)` with `obj` a Name: if `obj` is 'self_def_class' the
        whole callee is left as is; if `obj` already has an entry in
        `varmap` it is renamed and the callee is not traversed further.
        Otherwise the callee is traversed, so `obj` gets a new name while
        the method name `m` is kept.
      * any other attribute access `x.y` has its attribute renamed to 'attr'.
      * an `ast.FunctionDef` has its parameters renamed by `mask_args` and
        its subtree processed by a recursive call that uses the returned
        argument map instead of `varmap`.

    `varmap` ({original name: masked name}) is mutated.

    NOTE(review): `ast.AsyncFunctionDef` is not routed through `mask_args`
    here; its parameters (ast.arg nodes) keep their names while Names in
    its body are renamed. Lambda parameters behave the same way.
    NOTE(review): because `count` restarts at 0, passing a `varmap` that
    already holds 'var<k>' values could produce clashing names.
    '''
    # reversed so that pop(-1) yields the children left to right
    stack = [x for x in ast.iter_child_nodes(root)][::-1]
    count = 0
    
    # Attribute nodes that are call targets: their method name is kept
    call_attrs = set()
    
    while stack:
        skip = None # child of `node` that must not be traversed
        node = stack.pop(-1)
        
        if isinstance(node, ast.Call):
            # plain function call: keep the function name
            if isinstance(node.func, ast.Name):
                skip = node.func
            
            if isinstance(node.func, ast.Attribute):
                call_attrs.add(node.func)
                
            if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
                # already-masked class-level method call: keep it intact
                if node.func.value.id == 'self_def_class':
                    skip = node.func
                
                # a known variable that calls a function (e.g., a.lower())
                if isinstance(node.func.value, ast.Name) and varmap.get(node.func.value.id) is not None:
                    node.func.value.id = varmap[node.func.value.id]
                    skip = node.func
            
        
        # special masking for args: nested function gets its own name map
        if isinstance(node, ast.FunctionDef):
            argmap = mask_args(node)
            mask_variables(node, argmap)
            continue
            
        if isinstance(node, ast.Name):
            if varmap.get(node.id) is None:
                varmap[node.id] = f'var{count}'
                count += 1
            node.id = varmap[node.id]
        
        if node not in call_attrs and isinstance(node, ast.Attribute):
            node.attr = 'attr'
        
        # push children (minus `skip`) reversed to keep left-to-right order
        temp = []
        for child in ast.iter_child_nodes(node):
            if child == skip:
                continue
            temp.append(child)
        stack.extend(temp[::-1])
    
    return root

In [19]:
class MaskSubscript(ast.NodeTransformer):
    '''
    AST transformer that replaces every subscript expression (e.g. `a[i]`,
    `m[i][j]`) with the placeholder name `dummySubscript`.

    The outermost subscript is replaced as a whole, so inner subscripts are
    not visited separately. The placeholder always has a Load context.
    '''

    def visit_Subscript(self, node):
        placeholder = ast.Name(id='dummySubscript', ctx=ast.Load())
        return placeholder

def cpy_with_subcript_masked(i):
    '''
    Return a copy of the AST subtree `i` in which every subscript is
    replaced by the name 'dummySubscript'.

    The copy is made by round-tripping through source code, so `i` itself
    is not modified.
    '''
    transformer = MaskSubscript()
    masked_copy = ast.parse(ast.unparse(i))
    transformer.visit(masked_copy)
    return masked_copy

Example

In [20]:
# Full masking pipeline on a loop: calls, self-defined names, then variables.
testcase_code = """
ls = ['A', 'B']
for a in ls:
    b = a.lower()
    print(b)
c = ls
"""
testcase_root = ast.parse(testcase_code)

mask_calls_in_AST(testcase_root, set())
mask_selfdefs(testcase_root)
mask_variables(testcase_root, dict())
print(ast.unparse(testcase_root))

var0 = ['A', 'B']
for var1 in var0:
    var2 = var1.lower()
    print(var2)
var3 = var0


In [21]:
# Full masking pipeline on a class with an instance, a class-level and an __init__ method.
testcase_root = ast.parse("""
class Person:
    def __init__(self, name):
        self.name = name
    
    def hi(name):
        print(f'hi {name}')
    
    def introduce(self):
        print(f'hi my name is {self.name}')

def bye(name):
    print(f'bye {name}')

Person.hi('Robin')
obj1 = Person('Cassandra')
obj2 = Person('Robin')
obj1.introduce()
""")
mask_calls_in_AST(testcase_root, set())
mask_selfdefs(testcase_root)
mask_variables(testcase_root, dict())
print(ast.unparse(testcase_root))

class self_def_class:

    def self_def_func(self, arg0):
        self.attr = arg0

    def self_def_func(arg0):
        print(f'hi {arg0}')

    def self_def_func(self):
        print(f'hi my name is {self.attr}')

def self_def_func(arg0):
    print(f'bye {arg0}')
self_def_class.self_def_func('Robin')
var0 = self_def_func('Cassandra')
var1 = self_def_func('Robin')
var0.self_def_func()


# Extract Logic Blocks From AST

In [22]:
def is_logic_block(node):
    '''
    True if `node` is a compound statement treated as a logic block:
    a for/while loop, an if statement, a class definition or a
    (sync/async) function definition.
    '''
    if is_funcdef(node):
        return True
    return isinstance(node, (ast.For, ast.While, ast.If, ast.ClassDef))

In [23]:
def has_test_statement(node):
    '''
    True if `node` is a statement whose header expression controls its
    body: `for`, `while` or `if`.
    '''
    header_statements = (ast.For, ast.While, ast.If)
    return isinstance(node, header_statements)

In [24]:
def get_test_logic(node):
    '''
    Given an AST node with test logic, extract the test logic header.

    Returns a fresh `ast.Module` holding a copy of the for/while/if
    statement whose body is replaced by `...` and whose else/elif branch
    is dropped. The input node is not modified. Returns None if `node`
    has no test statement.
    '''
    if not has_test_statement(node):
        return None

    header_only = ast.parse(ast.unparse(node))
    # the first node of the same type in walk order is the copied statement itself
    target = next((n for n in ast.walk(header_only) if isinstance(n, type(node))), None)
    if target is not None:
        target.body = [ast.Expr(value=ast.Constant(value=Ellipsis))]
        target.orelse = []
    return header_only

`if` statements

In [25]:
def get_if_block(node):
    '''
    For an `if` statement that has an elif/else branch, return a copy
    (as an `ast.Module`) that keeps only the test and its body.

    Returns None when `node` is not an `ast.If` or has no elif/else branch.
    In that case the whole statement is already its own if-block.
    The input node is not modified.
    '''
    if not isinstance(node, ast.If) or not node.orelse:
        return None

    copy_tree = ast.parse(ast.unparse(node))
    first_if = next((n for n in ast.walk(copy_tree) if isinstance(n, ast.If)), None)
    if first_if is not None:
        first_if.orelse = []
    return copy_tree

In [26]:
def get_else_block(node):
    '''
    For an `if` statement with a real `else:` branch, return a copy (as an
    `ast.Module`) whose test and body are masked to `...`:

        if ...:
            ...
        else:
            <else body>

    Returns None when `node` is not an `ast.If`, has no else branch, or
    its orelse is a single nested `ast.If` (an elif, handled separately).
    The input node is not modified.
    '''
    if not isinstance(node, ast.If):
        return None

    branch = node.orelse
    is_elif = len(branch) == 1 and isinstance(branch[0], ast.If)
    if is_elif or not branch:
        return None

    copy_tree = ast.parse(ast.unparse(node))
    first_if = next((n for n in ast.walk(copy_tree) if isinstance(n, ast.If)), None)
    if first_if is not None:
        first_if.test = ast.Constant(value=Ellipsis)
        first_if.body = [ast.Expr(value=ast.Constant(value=Ellipsis))]
    return copy_tree

In [27]:
def subblocks_in_If(node):
    '''
    Split an `if` statement into its sub-blocks, in this order:
    the if-part without the else branch (`get_if_block`) and the else-part
    with a masked header (`get_else_block`). A part that does not apply is
    omitted. Returns None for non-`ast.If` nodes.
    '''
    if not isinstance(node, ast.If):
        return None

    candidates = (get_if_block(node), get_else_block(node))
    return [block for block in candidates if block is not None]

Example

In [28]:
# Print the if-part / else-part sub-blocks of every if statement (including the elif).
testcase_code = """
if x == 7:
    print(x)
elif x < 7:
    print(7-x)
else:
    print(7)
    a = 3+2
"""
testcase_root = ast.parse(testcase_code)
for i in ast.walk(testcase_root):
    if subblocks_in_If(i) is not None:
        for k in subblocks_in_If(i):
            print(ast.unparse(k))
            print('='*100)

if x == 7:
    print(x)
if x < 7:
    print(7 - x)
if ...:
    ...
else:
    print(7)
    a = 3 + 2


Extract all logic blocks

In [29]:
def get_nodes_within(main: ast.AST, attr: str):
    '''
    Return the value of field `attr` on `main` (e.g. the statement list in
    `body` or `orelse`), or [] when `main` has no such attribute.

    The node's own list is returned, not a copy.
    '''
    # getattr's default covers only a missing attribute; unlike the previous
    # bare `except:` it does not swallow unrelated errors (e.g. KeyboardInterrupt)
    return getattr(main, attr, [])

In [30]:
def get_unit_blocks_within(node):
    '''
    Return the smaller blocks contained in the logic block `node`, each as
    a fresh `ast.Module`, in this order:

      1. the header-only form of a for/while/if (`get_test_logic`),
      2. for an `if`, its if-part and else-part (`subblocks_in_If`),
      3. every statement directly in `body`/`orelse` that is not itself a
         logic block (nested logic blocks are collected by the caller).
    '''
    units = []

    if has_test_statement(node): # header which is a test statement
        units.append(get_test_logic(node))

    if isinstance(node, ast.If):
        units.extend(subblocks_in_If(node))

    children = get_nodes_within(node, 'body') + get_nodes_within(node, 'orelse')
    units.extend(ast.parse(ast.unparse(stmt)) for stmt in children if not is_logic_block(stmt))
    return units

In [31]:
def get_contained_logic_blocks(root: ast.AST):
    '''
    Breadth-first collection of the logic blocks under `root`, each as a
    fresh `ast.Module` copy (the original nodes are not modified).

    * every direct child statement of `root` is recorded whole;
    * every logic block (see `is_logic_block`) at any depth is recorded
      whole, immediately followed by its smaller blocks from
      `get_unit_blocks_within`.
    '''
    pending = list(ast.iter_child_nodes(root))
    top_level = set(pending)
    collected = []

    # use `pending` as a FIFO queue by advancing an index instead of popping the front
    cursor = 0
    while cursor < len(pending):
        node = pending[cursor]
        cursor += 1

        block_like = is_logic_block(node)
        if block_like or node in top_level:
            collected.append(ast.parse(ast.unparse(node)))
            if block_like:
                collected.extend(get_unit_blocks_within(node))

        pending.extend(ast.iter_child_nodes(node))
    return collected

In [32]:
def get_logic_blocks_in_AST(root: ast.AST):
    '''
    input: python submission, as a parsed AST (modified in place by the
           call/definition masking steps)
    output: list of masked logic subtrees (each an `ast.Module`)

    Steps: mask calls to self-defined functions, mask self-defined
    function/class names, extract the logic blocks, mask the variable names
    of each block independently, then append the subscript-masked variant
    of each block whose dump differs from every unmasked block.
    Returns None if `root` is None.
    '''
    if root is None: return None

    mask_calls_in_AST(root, set()) # mask all calls to self defined functions
    mask_selfdefs(root) # mask all self defined functions and self defined class names

    masked_blocks = [mask_variables(block, dict()) for block in get_contained_logic_blocks(root)]
    seen_dumps = {ast.dump(block) for block in masked_blocks}

    subscript_variants = [cpy_with_subcript_masked(block) for block in masked_blocks]
    masked_blocks.extend(v for v in subscript_variants if ast.dump(v) not in seen_dumps)
    return masked_blocks

In [33]:
def extract_logic_blocks_from_submission(code):
    '''
    Parse a submission's source code and return its masked logic blocks
    (see `get_logic_blocks_in_AST`).

    Returns None when the source cannot be turned into an AST: a syntax
    error, null bytes in the source, source nested too deeply for the
    parser, or a non-string value (e.g. a NaN entry in a pandas column).
    '''
    try:
        ast_root = ast.parse(code)
    except (SyntaxError, ValueError, TypeError, RecursionError, MemoryError):
        # unparsable submission; callers treat None as "AST not generated".
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        return None

    return get_logic_blocks_in_AST(ast_root)

## Logic block extraction examples

In [34]:
# Logic blocks of a for loop with a subscript (note the extra dummySubscript variant).
testcase_code = """
a = int(input('Enter a number <= 4: '))

for i in range(0, min(a, 4)):
    b = ['c', 'o', 'd', 'e']
    print(b[i].upper())
"""

testcase_logic_blocks = extract_logic_blocks_from_submission(testcase_code)
for i in testcase_logic_blocks:
    print(ast.unparse(i))
    print('-'*110)

var0 = int(input('Enter a number <= 4: '))
--------------------------------------------------------------------------------------------------------------
for var0 in range(0, min(var1, 4)):
    var2 = ['c', 'o', 'd', 'e']
    print(var2[var0].upper())
--------------------------------------------------------------------------------------------------------------
for var0 in range(0, min(var1, 4)):
    ...
--------------------------------------------------------------------------------------------------------------
var0 = ['c', 'o', 'd', 'e']
--------------------------------------------------------------------------------------------------------------
print(var0[var1].upper())
--------------------------------------------------------------------------------------------------------------
for var0 in range(0, min(var1, 4)):
    var2 = ['c', 'o', 'd', 'e']
    print(dummySubscript.upper())
--------------------------------------------------------------------------------------------------------

In [35]:
# Logic blocks of an if/elif/else chain (the elif is extracted as its own if block).
testcase_code = """
if x == 7:
    print(x)
elif x < 7:
    print(7-x)
else:
    print(7)
"""
testcase_logic_blocks = extract_logic_blocks_from_submission(testcase_code)
for i in testcase_logic_blocks:
    print(ast.unparse(i))
    print('-'*110)

if var0 == 7:
    print(var0)
elif var0 < 7:
    print(7 - var0)
else:
    print(7)
--------------------------------------------------------------------------------------------------------------
if var0 == 7:
    ...
--------------------------------------------------------------------------------------------------------------
if var0 == 7:
    print(var0)
--------------------------------------------------------------------------------------------------------------
print(var0)
--------------------------------------------------------------------------------------------------------------
if var0 < 7:
    print(7 - var0)
else:
    print(7)
--------------------------------------------------------------------------------------------------------------
if var0 < 7:
    ...
--------------------------------------------------------------------------------------------------------------
if var0 < 7:
    print(7 - var0)
--------------------------------------------------------------------------------

In [36]:
# Logic blocks of a self-defined function called from a while loop.
testcase_code = """
def check_number_is_three(num):
    if num == 3:
        print('it is 3')
    else:
        print(num)

i = 0
while i < 9:
    check_number_is_three(i)
    i += 1
"""
testcase_logic_blocks = extract_logic_blocks_from_submission(testcase_code)
for i in testcase_logic_blocks:
    print(ast.unparse(i))
    print('-'*110)

def self_def_func(arg0):
    if arg0 == 3:
        print('it is 3')
    else:
        print(arg0)
--------------------------------------------------------------------------------------------------------------
var0 = 0
--------------------------------------------------------------------------------------------------------------
while var0 < 9:
    self_def_func(var0)
    var0 += 1
--------------------------------------------------------------------------------------------------------------
while var0 < 9:
    ...
--------------------------------------------------------------------------------------------------------------
self_def_func(var0)
--------------------------------------------------------------------------------------------------------------
var0 += 1
--------------------------------------------------------------------------------------------------------------
if var0 == 3:
    print('it is 3')
else:
    print(var0)
--------------------------------------------------------------

In [37]:
# Logic blocks of a while loop nested in a for loop.
testcase_code = """
ls = [1,2,3]
for i in ls:
    while i < 9:
        print(i)
        i += 1
"""
testcase_logic_blocks = extract_logic_blocks_from_submission(testcase_code)
for i in testcase_logic_blocks:
    print(ast.unparse(i))
    print('-'*110)

var0 = [1, 2, 3]
--------------------------------------------------------------------------------------------------------------
for var0 in var1:
    while var0 < 9:
        print(var0)
        var0 += 1
--------------------------------------------------------------------------------------------------------------
for var0 in var1:
    ...
--------------------------------------------------------------------------------------------------------------
while var0 < 9:
    print(var0)
    var0 += 1
--------------------------------------------------------------------------------------------------------------
while var0 < 9:
    ...
--------------------------------------------------------------------------------------------------------------
print(var0)
--------------------------------------------------------------------------------------------------------------
var0 += 1
--------------------------------------------------------------------------------------------------------------


In [38]:
# Logic blocks of an if/else nested in a while loop.
testcase_code = """
i = 0
while i < 5:
    if i % 2 == 0:
        print(str(i) + ' is even')
    else:
        print(str(i) + ' is odd')
    i += 1
"""
testcase_logic_blocks = extract_logic_blocks_from_submission(testcase_code)
for i in testcase_logic_blocks:
    print(ast.unparse(i))
    print('-'*110)

var0 = 0
--------------------------------------------------------------------------------------------------------------
while var0 < 5:
    if var0 % 2 == 0:
        print(str(var0) + ' is even')
    else:
        print(str(var0) + ' is odd')
    var0 += 1
--------------------------------------------------------------------------------------------------------------
while var0 < 5:
    ...
--------------------------------------------------------------------------------------------------------------
var0 += 1
--------------------------------------------------------------------------------------------------------------
if var0 % 2 == 0:
    print(str(var0) + ' is even')
else:
    print(str(var0) + ' is odd')
--------------------------------------------------------------------------------------------------------------
if var0 % 2 == 0:
    ...
--------------------------------------------------------------------------------------------------------------
if var0 % 2 == 0:
    print(str(var0)

In [39]:
# Submission with a class, a function and method/attribute calls (used by the next cell).
testcase_code = """
class Person:
    def __init__(self, name):
        self.name = name
    
    def hi(name):
        print(f'hi {name}')
    
    def introduce(self):
        print(f'hi my name is {self.name}')

def bye(name):
    print(f'bye {name}')

Person.hi('Robin')
obj1 = Person('Cassandra')
obj2 = Person('Peter')
obj1.introduce()

print(obj1.name.split())
"""

In [40]:
# Logic blocks extracted from the class example defined in the previous cell.
testcase_logic_blocks = extract_logic_blocks_from_submission(testcase_code)
for i in testcase_logic_blocks:
    print(ast.unparse(i))
    print('-'*110)

class self_def_class:

    def self_def_func(self, arg0):
        self.attr = arg0

    def self_def_func(arg0):
        print(f'hi {arg0}')

    def self_def_func(self):
        print(f'hi my name is {self.attr}')
--------------------------------------------------------------------------------------------------------------
def self_def_func(arg0):
    print(f'bye {arg0}')
--------------------------------------------------------------------------------------------------------------
print(f'bye {var0}')
--------------------------------------------------------------------------------------------------------------
self_def_class.self_def_func('Robin')
--------------------------------------------------------------------------------------------------------------
var0 = self_def_func('Cassandra')
--------------------------------------------------------------------------------------------------------------
var0 = self_def_func('Peter')
---------------------------------------------------------

# ML

## BoW

In [41]:
def get_bow(submissions):
    '''
    Build the bag-of-words vocabulary from a collection of submissions.

    Each "word" is the `ast.dump` of a masked logic block; unparsable
    submissions are skipped. Returns (bow, bow_asts): the unique dumps in
    first-seen order and, aligned with them, the AST of each first
    occurrence.
    '''
    seen = set()
    vocab, vocab_asts = [], []

    for code in submissions:
        blocks = extract_logic_blocks_from_submission(code)
        for block in blocks or []:
            key = ast.dump(block)
            if key in seen:
                continue
            seen.add(key)
            vocab.append(key)
            vocab_asts.append(block)

    return vocab, vocab_asts

In [42]:
def get_freq_dict(extracted_logic_blocks: []):
    '''
    Count how often each distinct logic block (keyed by its `ast.dump`)
    occurs. Keys appear in first-seen order.
    '''
    counts = {}
    for block in extracted_logic_blocks:
        key = ast.dump(block)
        counts[key] = counts.get(key, 0) + 1
    return counts

In [43]:
def vectorise(row, bow: []):
    '''
    Turn one submission into a feature vector of length len(bow) + 1.

    Index 0 is 1 when the submission cannot be parsed into an AST (all
    other entries are then 0). Otherwise index 0 is 0 and index i + 1 holds
    how many of the submission's logic blocks equal vocabulary entry bow[i].
    '''
    blocks = extract_logic_blocks_from_submission(row)

    if blocks is None:
        # syntax error (unparsable to ast)
        return [1] + [0] * len(bow)

    counts = get_freq_dict(blocks)
    return [0] + [counts.get(word, 0) for word in bow]

Example

In [44]:
# Three equivalent counting programs: build a shared BoW and vectorise each one.
testcase_code1 = """
for i in range(1,4):
    print(i)
"""
testcase_code2 = """
i = 0
while i < 3:
    print(i)
    i += 1
"""

testcase_code3 = """
m = 0
while m < 3:
    m += 1
    print(m)
"""

testcase_code_list = [testcase_code1, testcase_code2, testcase_code3]
testcase_bow, testcase_bow_asts = get_bow(testcase_code_list)
testcase_result = [vectorise(x, testcase_bow) for x in testcase_code_list]

print(testcase_result)

[[0, 1, 1, 1, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1, 1, 1, 0], [0, 0, 0, 1, 1, 0, 1, 1, 1]]


In [45]:
# Print each non-zero feature of one example vector together with its logic block.
testcase_code_idx = 1
for i in range(len(testcase_result[testcase_code_idx])):
    if testcase_result[testcase_code_idx][i] > 0:
        print(i)
        # index 0 is the "ast not generated" flag, not a BoW entry;
        # testcase_bow_asts[i-1] would wrap around to the last entry
        if i == 0:
            print('ast not generated')
        else:
            print(ast.unparse(testcase_bow_asts[i-1]))
        print('='*100)

3
print(var0)
4
var0 = 0
5
while var0 < 3:
    print(var0)
    var0 += 1
6
while var0 < 3:
    ...
7
var0 += 1


## Train the decision tree (DT)

In [46]:
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV
from sklearn import metrics

In [47]:
def train(X, y, test_size=0.25, random_state=None):
    '''
    Train a decision tree on bag-of-logic-block features.

    Parameters:
        X: submissions (source-code strings); presumably a pandas Series,
           since `.apply` is used below -- TODO confirm
        y: labels; the default binary metrics below treat label 1 as the
           positive class
        test_size: fraction of the data held out for evaluation
        random_state: seed passed to the train/test split and to the
            decision trees so a run can be reproduced; None keeps the
            previous unseeded behaviour

    Returns (best_clf, scores, features): the tree refitted with the best
    grid-search parameters, a dict of test-set metrics (percentages) plus
    the vocabulary size, and the feature names (column 0 flags
    submissions whose AST could not be generated).
    '''
    # split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=random_state)

    # get bag of words from X_train only, so the test set does not shape the vocabulary
    bow, bow_asts = get_bow(X_train)
    features = ['ast not generated'] + [ast.unparse(x) for x in bow_asts]

    X_train_vec = X_train.apply(vectorise, args=(bow,))
    X_train_vec = pd.DataFrame(X_train_vec.tolist(), index=X_train_vec.index, columns=features)

    X_test_vec = X_test.apply(vectorise, args=(bow,))
    X_test_vec = pd.DataFrame(X_test_vec.tolist(), index=X_test_vec.index, columns=features)

    param_grid = {'max_depth': [2, 3, 4],
                  'min_samples_leaf': [20, 40, 60]
                 }

    clf = DecisionTreeClassifier(random_state=random_state)
    gcv = GridSearchCV(estimator=clf, param_grid=param_grid, cv=5)
    gcv.fit(X_train_vec, y_train)

    # best_params_ holds exactly the keys of param_grid
    best_clf = DecisionTreeClassifier(random_state=random_state, **gcv.best_params_)
    best_clf.fit(X_train_vec, y_train)

    y_pred = best_clf.predict(X_test_vec)

    scores = {
        'accuracy': metrics.accuracy_score(y_test, y_pred) * 100,
        'f1 score': metrics.f1_score(y_test, y_pred) * 100,
        'precision': metrics.precision_score(y_test, y_pred) * 100,
        'recall': metrics.recall_score(y_test, y_pred) * 100,
        'BoW size': len(bow)
    }

    return best_clf, scores, features

## Display the decision tree

In [48]:
from sklearn.tree import export_text

def display_DT_txt(clf, feature_names):
    '''Print the text rendering of a fitted decision tree's rules (sklearn `export_text`).'''
    print(export_text(clf, feature_names=feature_names))

In [49]:
from sklearn.tree import export_graphviz
import graphviz

def display_DT_graph(clf, feature_names, out_file="dt", class_names=("fail", "pass")):
    '''
    Render a fitted decision tree inline with graphviz.

    The DOT source is written to `out_file` (kept on disk, as before) and
    read back for display. `class_names` label the target classes in
    ascending label order; the default presumably means 0 = fail,
    1 = pass -- confirm against the labels used for training.
    '''
    export_graphviz(clf, out_file=out_file, class_names=list(class_names),
                    feature_names=feature_names, impurity=False, filled=True)

    # export_graphviz writes the file as UTF-8; read it back the same way so
    # non-ASCII feature names (e.g. string literals from submissions) survive
    # on platforms whose default encoding is not UTF-8
    with open(out_file, encoding="utf-8") as f:
        dt_graph = f.read()
    display(graphviz.Source(dt_graph))