In [2]:
# !pip install torch
# !pip install torch_geometric
# !pip install -U sentence-transformers

# Imports

In [3]:
import hashlib
import os
import pickle
import re
import typing as tp
import xml.etree.ElementTree as ET
from collections import Counter, defaultdict
from pprint import pprint

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch_geometric as tg
from networkx.classes.digraph import DiGraph
from sentence_transformers import SentenceTransformer
from torch_geometric.data import Data
from tqdm import tqdm

In [4]:
# Building graph
def dfs_build_graph(root_node:ET.Element) -> tp.Tuple[DiGraph, tp.List, int]:
    """Convert an XML element tree into a directed graph.

    Every element becomes a node whose ``text`` attribute is a one-item list
    holding the element tag. Every element without child elements additionally
    gets a token node holding the element's stripped text (an empty string when
    the element has no text at all). Edges point from a child to its parent.

    Node ids are consecutive integers assigned in creation order, so the root
    element always receives id 0.

    :param root_node: root element of the (sub)tree to convert
    :return: ``(graph, leafs, root_name)`` where ``leafs`` lists the token node
        ids in traversal order and ``root_name`` is the id of the root node
    """
    graph = nx.DiGraph()
    leafs = []

    def add_node(label: str) -> int:
        # Nodes are only ever added here, so the current node count is the next free id.
        node_id = len(graph)
        graph.add_node(node_id, text=[label])
        return node_id

    def dfs(vertice: ET.Element) -> int:
        # The XML tag (e.g. <SimpleName>) is the node's label.
        node = add_node(vertice.tag)

        # Recursively traverse the child elements; edges go child -> parent.
        for child in vertice:
            graph.add_edge(dfs(child), node)

        if len(vertice) == 0:
            # `text` is None for elements such as <X/>; treat that as an empty token
            # instead of failing on None.strip().
            leaf_node = add_node((vertice.text or '').strip())
            graph.add_edge(leaf_node, node)
            leafs.append(leaf_node)

        return node

    # TODO: add sink and leafs edges here, not outside

    root_name = dfs(root_node)

    return (graph, leafs, root_name)

# Merging Single-Entry Node Sequences
def dfs_merge_sequences(graph, node: str):
    """Collapse chains of nodes that each have exactly one predecessor.

    Starting at ``node``, the walk follows the unique predecessor for as long as
    the current node has in-degree 1. When the chain holds two or more nodes they
    are contracted (in place) into the first one, whose ``text`` list is extended
    with the first ``text`` item of every contracted node. The function then
    recurses into the predecessors of the node where the walk stopped.

    :param graph: directed graph, modified in place
    :param node: key of the start node in ``graph.nodes`` (the visible caller
        passes the root id returned by ``dfs_build_graph``)
    """
    chain = []
    while graph.in_degree(node) == 1:
        chain.append(node)
        node = next(iter(graph.predecessors(node)))

    # The node that ended the walk (in-degree 0 or >= 2) is never part of the
    # chain, so nodes without predecessors are not merged.
    if len(chain) >= 2:
        head, *tail = chain
        for merged in tail:
            nx.contracted_nodes(graph, head, merged, self_loops=False, copy=False)

        # Concatenate the tokens of the merged vertices onto the surviving node.
        head_attrs = graph.nodes[head]
        contraction = head_attrs.pop('contraction')
        merged_tokens = [attrs['text'][0] for attrs in contraction.values()]
        head_attrs['text'] = head_attrs['text'] + merged_tokens

    # `node` now has either no predecessors or at least two of them.
    for child in graph.predecessors(node):
        dfs_merge_sequences(graph, child)

# Merging Aggregation Structures
# TODO: merge only vertices with specific AST node names
def dfs_merge_aggregations(graph, node: str):
    """Contract the predecessors of ``node`` into it when all of them are "thin".

    If ``node`` has at least two predecessors and none of them has in-degree
    above 1, every predecessor is contracted (in place) into ``node``. The
    node's ``text`` then becomes a list holding the ``text`` value of each
    contracted predecessor followed by the node's own previous ``text`` value
    (so the result is a list of the former ``text`` values, not a flat list).
    Afterwards the function recurses into the current predecessors of ``node``.

    :param graph: directed graph, modified in place
    :param node: key of the node in ``graph.nodes`` to process
    """
    children = list(graph.predecessors(node))
    if not children:
        # Nothing below this node.
        return

    # TODO (carried over from the original): the merge condition is marked for revisiting.
    all_thin = all(graph.in_degree(child) <= 1 for child in children)
    if len(children) >= 2 and all_thin:
        # Every predecessor has at most one predecessor itself: fold them all into `node`.
        for child in children:
            nx.contracted_nodes(graph, node, child, self_loops=False, copy=False)

        attrs = graph.nodes[node]
        contraction = attrs.pop('contraction')
        attrs['text'] = [merged['text'] for merged in contraction.values()] + [attrs['text']]

    for child in graph.predecessors(node):
        dfs_merge_aggregations(graph, child)

In [44]:
def build_graph(root_node:ET.Element) -> DiGraph:
    """Build the final graph for one XML subtree.

    Steps: convert the tree with ``dfs_build_graph``, run the two merge passes
    (``dfs_merge_sequences`` then ``dfs_merge_aggregations``), chain consecutive
    token nodes together, and finally connect every node to the root so that
    the root acts as a global sink.

    :param root_node: root element of the subtree
    :return: the resulting directed graph
    """
    graph, leafs, root_name = dfs_build_graph(root_node)

    for merge_pass in (dfs_merge_sequences, dfs_merge_aggregations):
        merge_pass(graph, root_name)

    # Link each token node to the token that follows it in the source order.
    for current_leaf, next_leaf in zip(leafs, leafs[1:]):
        graph.add_edge(current_leaf, next_leaf)

    # Make the root a sink reachable in one hop from every other node.
    non_root_nodes = [v for v in graph.nodes if v != root_name]
    for v in non_root_nodes:
        graph.add_edge(v, root_name)

    return graph


In [45]:
def graph_to_data(graph:DiGraph, target:int, cwe:int=0, cwe_full:str='') -> Data:
    """Pack a graph into a torch_geometric ``Data`` object.

    Nodes are relabelled to consecutive integers first. ``edge_index`` is the
    edge list as a ``long`` tensor reshaped to two rows, ``x`` is the plain
    Python list of the nodes' ``text`` attributes (not a tensor), ``y`` is
    ``target`` and ``cwe`` is stored as an extra attribute.

    :param graph: graph whose nodes carry a ``text`` attribute
    :param target: label stored as ``y``
    :param cwe: CWE number stored on the result
    :param cwe_full: accepted but currently not stored
    :return: the assembled ``Data`` object
    """
    relabeled = nx.convert_node_labels_to_integers(graph)

    edge_pairs = torch.tensor(list(relabeled.edges), dtype=torch.long)
    edge_index = edge_pairs.t().contiguous().view(2, -1)

    # Only nodes that actually have a 'text' attribute contribute an entry.
    node_texts = [attrs['text'] for _, attrs in relabeled.nodes(data=True) if 'text' in attrs]

    return Data(
        edge_index=edge_index,
        edge_attr=None,
        x=node_texts,
        y=target,
        num_nodes=relabeled.number_of_nodes(),
        cwe=cwe,
    )

In [68]:
def method_words(method, threshold=1e6):
    """Count elements with non-blank text inside the method's first ``Block`` child.

    Counting stops early as soon as the count exceeds ``threshold``, so the
    result is never larger than ``threshold + 1`` for integer thresholds.

    :param method: XML element of a method declaration
    :param threshold: stop counting once more than this many words were seen
    :return: number of counted elements, or 0 when the method has no ``Block`` child
    """
    block = method.find('.Block')
    if block is None:
        return 0

    words = 0
    for elem in block.iter():
        # `text` is None when nothing precedes the first child (or the element is
        # empty); such elements simply do not count as a word.
        if (elem.text or '').strip():
            words += 1
            if words > threshold:
                return words
    return words
    
def get_tree_depth(element:ET.Element, level=0, threshold=4):
    """Return the depth of an ElementTree Element, cut off just past ``threshold``.

    The recursion stops descending once ``level`` exceeds ``threshold``, so for
    trees deeper than the threshold the result is ``threshold + 1`` rather than
    the true depth.

    :param element: element to measure
    :param level: depth of ``element`` itself (0 for the initial call)
    :param threshold: depth after which the search stops descending
    :return: the (possibly capped) depth of the deepest descendant
    """
    if len(element) == 0 or level > threshold:
        return level
    # Pass `threshold` down explicitly; otherwise nested calls fall back to the default.
    return max(get_tree_depth(child, level + 1, threshold) for child in element)

def compare_nodes(node1, node2): # NOT USED
    """Tell whether two XML nodes have the same structure.

    Two nodes match when their tags are equal, they have the same number of
    children, and their children match pairwise in order. Element text and
    attributes are not compared.
    """
    same_shape = node1.tag == node2.tag and len(node1) == len(node2)
    if not same_shape:
        return False

    # Child counts are equal here, so zip pairs every child exactly once.
    return all(compare_nodes(left, right) for left, right in zip(node1, node2))
        
def _stable_name_token(class_name, name) -> str:
    """Return a decimal token for ``(class_name, name)`` that is identical in every run."""
    # The built-in hash() of strings is salted per interpreter process, which would make
    # the generated tokens (and any dataset built from them) differ between runs.
    digest = hashlib.sha1(f'{class_name}\x00{name}'.encode('utf-8')).digest()
    return str(int.from_bytes(digest[:8], byteorder='big', signed=True))


def get_roots(file_path:str, DEPTH_THRESHOLD=4, WORDS_THRESHOLD=5) -> tp.List[tp.Dict]:
    """Collect the ``good*`` / ``bad*`` method declarations of one AST XML file.

    A method is kept when its name starts with ``good`` or ``bad``, its tree
    depth exceeds ``DEPTH_THRESHOLD`` and its word count exceeds
    ``WORDS_THRESHOLD``. Inside every kept method, each ``SimpleName`` whose
    text starts with ``good`` or ``bad`` is replaced (in the parsed tree) by a
    token derived from the class name and the original text.

    The CWE number is read from the first package-name part that looks like
    ``CWE<digits>_``; ``good*`` methods always get CWE 0, ``bad*`` methods get
    the file's CWE number regardless of the order of methods in the file.

    Side effect: increments the notebook-level tally ``tmp_cnt[cwe]`` for every
    kept method, creating ``tmp_cnt`` when it does not exist yet.

    :param file_path: path of the XML file (anything ``ET.parse`` accepts)
    :param DEPTH_THRESHOLD: minimal tree depth (exclusive) of a kept method
    :param WORDS_THRESHOLD: minimal word count (exclusive) of a kept method
    :return: list of dicts {'element': ET.Element, 'words', 'depth', 'method_name', 'cwe'};
        'words' stops counting right after exceeding ``WORDS_THRESHOLD``
    """
    tree = ET.parse(file_path)

    class_name = tree.findtext('.TypeDeclaration/SimpleName')
    if not (class_name or '').startswith('CWE'):
        print(f'Class name "{class_name}" does not starts with CWE')

    # File-level CWE number; stays 0 when the package name carries none.
    file_cwe = 0
    for name_elem in tree.findall('./PackageDeclaration//QualifiedName/SimpleName'):
        package_part = name_elem.text or ''
        if not package_part.startswith('CWE'):
            continue
        cwe_match = re.search(r'CWE([0-9]+)_', package_part)
        if cwe_match:
            file_cwe = int(cwe_match.group(1))
            break

    valid_methods = []
    for method in tree.findall('.//TypeDeclaration/MethodDeclaration'):

        method_name = method.findtext('.SimpleName')
        if method_name is None:
            continue
        kind_match = re.match('^(good|bad).*$', method_name)
        if not kind_match:
            continue

        depth = get_tree_depth(method, threshold=DEPTH_THRESHOLD)
        if depth <= DEPTH_THRESHOLD:
            continue

        words = method_words(method, WORDS_THRESHOLD)
        if words <= WORDS_THRESHOLD:
            continue

        # Replace every good*/bad* identifier by a token derived from (class, NAME),
        # so the label cannot be read off the identifier text.
        for elem in method.findall('.//SimpleName'):
            if elem.text and re.search(r'^(good|bad).*', elem.text):
                elem.text = _stable_name_token(class_name, elem.text)

        # Use a per-method value: overwriting the file-level number here would label
        # every later bad* method of the same file with CWE 0.
        method_cwe = 0 if kind_match.group(1) == 'good' else file_cwe

        # `tmp_cnt` is a notebook-level debug tally that no visible cell defines;
        # create it on demand so a fresh kernel does not fail with NameError.
        globals().setdefault('tmp_cnt', Counter())[method_cwe] += 1

        valid_methods.append({
            'element': method,
            'words': words,
            'depth': depth,
            'method_name': method_name,
            'cwe': method_cwe,
        })
    return valid_methods


In [69]:
def xml_to_Data(path_to_xmls,) -> tp.List[tp.List]:
    """Convert every XML file of a directory into torch_geometric ``Data`` objects.

    Each regular file in ``path_to_xmls`` is passed to ``get_roots``; every
    returned method is turned into a graph with ``build_graph`` and packed with
    ``graph_to_data``. The label is 1 for methods whose name starts with
    ``bad`` and 0 otherwise. Files are processed in sorted name order so that
    the resulting list has the same order on every run and platform.

    :param path_to_xmls: directory holding the XML files
    :return: ``[datas, graphs, methods]``; only ``datas`` is filled, the other
        two lists are always empty and are kept so callers can still unpack
        three values
    """
    datas = []
    graphs = []
    methods = []
    files = 0

    # os.listdir returns names in arbitrary order; sort for a reproducible dataset.
    for filename in tqdm(sorted(os.listdir(path_to_xmls))):
        file_path = os.path.join(path_to_xmls, filename)
        if not os.path.isfile(file_path):
            continue
        method_roots = get_roots(file_path)
        files += 1

        for method in method_roots:
            graph = build_graph(method['element'])
            target = 1 if method['method_name'][:3] == 'bad' else 0
            datas.append(graph_to_data(graph, target, cwe=method['cwe']))

    print(f'Total files: {files}')
    return [datas, graphs, methods]

# Dataset collection


In [73]:
# Build the dataset from every XML file in the 'Juliet_AST' directory (path relative to
# the working directory). Only `datas` is populated; xml_to_Data returns `graphs` and
# `methods` as empty lists. The recorded run below took about 23 minutes for 46286 entries.
datas, graphs, methods = xml_to_Data('Juliet_AST')

100%|██████████| 46286/46286 [23:19<00:00, 33.07it/s]  

Total files: 46286





In [None]:
# Persist the collected Data objects so later steps can load them without re-parsing the XML.
with open('data_v2_4.pickle', mode='wb') as dump_file:
    pickle.dump(datas, dump_file)