In [1]:
# !pip install torch
# !pip install torch_geometric
# !pip install -U sentence-transformers

In [2]:
import os
import pickle
import re

from pprint import pprint
import typing as tp
from collections import Counter
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch_geometric as tg
from torch_geometric.data import Data

import networkx as nx
from networkx.classes.digraph import DiGraph

import xml.etree.ElementTree as ET

from sentence_transformers import SentenceTransformer

In [267]:
# Building graph
def dfs_build_graph(root_node:ET.Element) -> tp.Tuple[DiGraph, tp.List, int]:
    """Convert an XML element tree into a directed graph with child -> parent edges.

    Every XML element becomes a node whose ``text`` attribute is a one-item list
    holding the element tag.  Every element without child elements additionally
    gets a separate leaf node, attached below it, whose ``text`` holds the
    element's stripped text content (an empty string when there is no text).
    Nodes are keyed by consecutive integers in creation order, so the node
    created for ``root_node`` always has key 0.

    :param root_node: root element of the (sub)tree to convert
    :return: ``(graph, leafs, root_name)`` - the graph, the keys of the text leaf
        nodes in the order they were visited, and the key of the root node (0)
    """
    graph = nx.DiGraph()
    leafs = []
    root_name = 0

    def add_node(label: str) -> int:
        # Keys are handed out consecutively and nothing is removed while
        # building, so the current node count is always the next unused key.
        key = len(graph)
        graph.add_node(key, text=[label])
        return key

    def dfs(vertice:ET.Element) -> int:
        # The element's own node is created before its children are visited.
        node = add_node(vertice.tag)  # XML tag like <SimpleName>

        # Recursively traverse the child nodes
        for child in vertice:
            graph.add_edge(dfs(child), node)

        if len(vertice) == 0:
            # ElementTree reports the text of an empty element (<Block/>) as
            # None; treat that as an empty token instead of crashing.
            leaf_node = add_node((vertice.text or '').strip())
            graph.add_edge(leaf_node, node)
            leafs.append(leaf_node)

        return node

    # TODO: add sink and leafs edges here, not outside

    dfs(root_node)

    return (graph, leafs, root_name)

# Merging Single-Entry Node Sequences
def dfs_merge_sequences(graph, node: str):
    """Collapse every chain of single-child nodes into the topmost node of the chain.

    Edges point child -> parent, so ``graph.in_degree(n)`` is the number of
    children of ``n`` and ``graph.predecessors(n)`` yields them.  The graph is
    modified in place.

    :param graph: directed graph whose nodes carry a list-valued ``text`` attribute
    :param node: key in ``graph.nodes`` of the node where the walk starts
        (``build_graph`` passes the integer root key)
    """
    # Follow the run of nodes that have exactly one child.
    chain = []
    current = node
    while graph.in_degree(current) == 1:
        chain.append(current)
        current = next(iter(graph.predecessors(current)))

    # ``current`` now has either no children or at least two, so it is not part
    # of the chain: a childless (leaf) node is never merged away.
    if len(chain) > 1:
        head, *tail = chain
        for member in tail:
            nx.contracted_nodes(graph, head, member, self_loops=False, copy=False)

        # Append the first token of every absorbed vertex to the head's tokens.
        head_attrs = graph.nodes[head]
        absorbed = head_attrs.pop('contraction')
        head_attrs['text'] = head_attrs['text'] + [attrs['text'][0] for attrs in absorbed.values()]

    # Continue below the node that ended the chain.
    for child in list(graph.predecessors(current)):
        dfs_merge_sequences(graph, child)

# Merging Aggregation Structures
# TODO: merge only vertices with specific AST node names
def dfs_merge_aggregations(graph, node: str):
    """Fold groups of sibling nodes into their parent, walking down from ``node``.

    Edges point child -> parent.  When ``node`` has at least two children and none
    of them has more than one child of its own, all children are contracted into
    ``node`` and its ``text`` becomes the list of the children's ``text`` values
    followed by its own previous ``text``.  The walk then continues with whatever
    children ``node`` has afterwards.  The graph is modified in place.

    :param graph: directed graph whose nodes carry a ``text`` attribute
    :param node: key in ``graph.nodes`` of the node to process
    """
    children = list(graph.predecessors(node))
    if not children:
        # node is a leaf
        return

    fan_in = [graph.in_degree(child) for child in children]
    if len(fan_in) >= 2 and max(fan_in) <= 1: # TODO
        # every child is a leaf or has exactly one child: absorb them all
        for child in children:
            nx.contracted_nodes(graph, node, child, self_loops=False, copy=False)

        attrs = graph.nodes[node]
        absorbed = attrs.pop('contraction')
        attrs['text'] = [child_attrs['text'] for child_attrs in absorbed.values()] + [attrs['text']]

    # After a contraction these are the former grandchildren.
    for child in list(graph.predecessors(node)):
        dfs_merge_aggregations(graph, child)

In [145]:
def build_graph(root_node:ET.Element) -> DiGraph:
    """Build the simplified graph for one XML subtree.

    Steps: convert the element tree with ``dfs_build_graph``, run both merge
    passes from the root, link the collected leaf nodes into a chain, and
    finally add an edge from every other node to the root.

    :param root_node: root element of the subtree to convert
    :return: the resulting directed graph
    """
    graph, leafs, root_name = dfs_build_graph(root_node)

    dfs_merge_sequences(graph, root_name)
    dfs_merge_aggregations(graph, root_name)

    # Chain the leaves: each one points at the next in the order they were collected.
    graph.add_edges_from(zip(leafs, leafs[1:]))

    # Make the root a global sink by pointing every other node at it.
    sink_edges = [(v, root_name) for v in graph.nodes if v != root_name]
    graph.add_edges_from(sink_edges)

    return graph


In [154]:
def graph_to_data(graph:DiGraph, target:int) -> Data:
    """Pack a graph and its label into a ``Data`` object.

    Nodes are relabelled to integers first.  ``edge_index`` is a long tensor
    reshaped to two rows, ``x`` is the plain Python list of the nodes' ``text``
    attributes (not a tensor), and ``y`` is ``target`` as given.

    :param graph: graph whose nodes carry a ``text`` attribute
    :param target: label stored as ``y``
    :return: the assembled ``Data`` object
    """
    relabeled = nx.convert_node_labels_to_integers(graph)

    edge_pairs = torch.tensor(list(relabeled.edges), dtype=torch.long)
    edge_index = edge_pairs.t().contiguous().view(2, -1)

    node_texts = list(nx.get_node_attributes(relabeled, 'text').values())

    return Data(edge_index=edge_index, edge_attr=None, x=node_texts, y=target)

In [155]:
def method_words(method, threshold=1e6):
    """Count the elements inside a method's first ``Block`` child that carry text.

    An element counts when its ``text`` is non-blank after stripping whitespace.
    Counting stops early as soon as the count exceeds ``threshold``.

    :param method: element whose direct ``Block`` child is inspected
    :param threshold: stop counting once the count is greater than this value
    :return: 0 when there is no ``Block`` child, otherwise the number of
        text-bearing elements (at most one above ``threshold`` if cut short)
    """
    blocks = method.findall('.Block')
    if not blocks:
        return 0

    words = 0
    for elem in blocks[0].iter():
        # ``text`` is None for an element that has no text before its first
        # child (or no content at all), so it must not be stripped directly.
        if (elem.text or '').strip():
            words += 1
            if words > threshold:
                return words
    return words
    
def get_tree_depth(element:ET.Element, level=0, threshold=4):
    """Return the depth of ``element``'s subtree, giving up once it passes ``threshold``.

    :param element: root of the subtree to measure
    :param level: depth of ``element`` itself; callers normally leave it at 0
    :param threshold: the descent stops at the first non-leaf element whose
        level is greater than this value, and that level is returned
    :return: the level of the deepest element reached
    """
    if len(element) == 0:
        return level
    if level > threshold:
        return level
    # Pass the threshold down; otherwise nested calls fall back to the default.
    return max(get_tree_depth(child, level + 1, threshold) for child in element)

def compare_nodes(node1, node2): # NOT USED
    """Tell whether two XML nodes have the same structure.

    Two nodes match when their tags are equal, they have the same number of
    children, and their children match pairwise in order.  Element text and
    XML attributes are not compared.
    """
    same_shape = node1.tag == node2.tag and len(node1) == len(node2)
    # all() stops at the first pair of children that does not match
    return same_shape and all(
        compare_nodes(left, right) for left, right in zip(node1, node2)
    )

def get_roots(file_path:str, DEPTH_THRESHOLD=4, WORDS_THRESHOLD=5) -> tp.List[tp.Dict]:
    '''
    Collect the ``good*`` / ``bad*`` methods of one AST XML file.

    A method is returned only if ``get_tree_depth`` reports more than
    ``DEPTH_THRESHOLD`` and ``method_words`` reports more than ``WORDS_THRESHOLD``.
    Inside every returned method, each ``SimpleName`` whose text starts with
    ``good`` or ``bad`` is overwritten in place with a number derived from the
    class name and the original text.

    :param file_path: path of the XML file to parse
    :param DEPTH_THRESHOLD: methods whose measured depth is not above this are skipped
    :param WORDS_THRESHOLD: methods whose word count is not above this are skipped
    :return: list of dicts {'element': ET.Element, 'method_name': str}; 'method_name'
        is the name read before any text was overwritten
    '''
    import hashlib

    tree = ET.parse(file_path)
    class_name = tree.findtext('.TypeDeclaration/SimpleName')
    # findtext returns None when the file has no such element.
    if class_name is None or not class_name.startswith('CWE'):
#         raise ValueError(f'Class name "{class_name}" does not starts with CWE')
        print(f'Class name "{class_name}" does not starts with CWE')

    def stable_hash(name: str) -> str:
        # The built-in hash() of a str is salted per interpreter process, so it
        # would produce different tokens on every kernel restart.  A digest gives
        # the same signed 64-bit number for the same (class, name) on every run.
        payload = f'{class_name}\0{name}'.encode('utf-8')
        digest = hashlib.blake2b(payload, digest_size=8).digest()
        return str(int.from_bytes(digest, 'big', signed=True))

    valid_methods = []

    for method in tree.findall('.//TypeDeclaration/MethodDeclaration'):

        method_name = method.findtext('.SimpleName')
        if method_name is None or not re.match('^(good|bad).*$', method_name):
            continue

        depth = get_tree_depth(method, threshold=DEPTH_THRESHOLD)
        if depth <= DEPTH_THRESHOLD:
            continue

        words = method_words(method, WORDS_THRESHOLD)
        if words <= WORDS_THRESHOLD:
            continue

        # changing all <SimpleName> NAME </SimpleName> to a stable hash of (class, NAME)
        for elem in method.findall('.//SimpleName'):
            if elem.text is not None and re.search(r'^(good|bad).*', elem.text):
                elem.text = stable_hash(elem.text)

        valid_methods.append({
            'element':method,
            'method_name':method_name,
        })

    return valid_methods


In [271]:
def xml_to_Data(path_to_xmls, keep_graphs: bool = False) -> tp.Tuple[tp.List[Data], tp.List[DiGraph]]:
    """Convert every XML file directly inside a directory into ``Data`` objects.

    Each regular file is parsed with ``get_roots``; every method it returns is
    turned into a graph and then into a ``Data`` object labelled 1 when the
    method name starts with ``bad`` and 0 otherwise.  The number of files
    processed is printed at the end.

    :param path_to_xmls: directory containing the XML files (sub-directories are skipped)
    :param keep_graphs: also collect the intermediate graphs; off by default,
        in which case the second returned list is empty
    :return: ``(datas, graphs)``
    """
    datas = []
    graphs = []
    files = 0

    # os.listdir() gives entries in arbitrary order; sorting makes the order of
    # ``datas`` (and of any slice taken from it) the same on every run.
    for filename in tqdm(sorted(os.listdir(path_to_xmls))):
        file_path = os.path.join(path_to_xmls, filename)
        if not os.path.isfile(file_path):
            continue
        method_roots = get_roots(file_path)
        files += 1

        for method in method_roots:
            graph = build_graph(method['element'])
            target = 1 if method['method_name'].startswith('bad') else 0
            data = graph_to_data(graph, target)

            if keep_graphs:
                graphs.append(graph)
            datas.append(data)

    print(f'Total files: {files}')
    return datas, graphs

# Start here


In [272]:
# Directory with the Juliet AST XML files.  Set the JULIET_AST_DIR environment
# variable to point somewhere else; the fallback is the original local path.
JULIET_AST_DIR = os.environ.get('JULIET_AST_DIR', 'C:/Users/Arkady/Downloads/Juliet_AST/Juliet_AST')
datas, graphs = xml_to_Data(JULIET_AST_DIR)

100%|████████████████████████████████████████████████████████████████████████████| 46286/46286 [24:39<00:00, 31.28it/s]

Total files: 46286





In [279]:
# Number of Data objects produced (xml_to_Data appends one per kept method).
len(datas)

128039

In [273]:
# Save the full list of Data objects to disk.
with open('data_v2_0.pickle', 'wb') as f:
    pickle.dump(datas, f)

In [274]:
# Text labels of the first 20 nodes of the first graph.
datas[0].x[:20]

[['MethodDeclaration'],
 ['Modifier'],
 ['public'],
 ['PrimitiveType'],
 ['void'],
 ['SimpleName'],
 ['-6391724344061429216'],
 ['SimpleType', 'SimpleName'],
 ['IOException'],
 ['Block'],
 ['VariableDeclarationStatement'],
 ['SimpleType', 'SimpleName'],
 ['InputStreamReader'],
 [['SimpleName'], ['NullLiteral'], ['VariableDeclarationFragment']],
 ['readerInputStream'],
 ['null'],
 ['VariableDeclarationStatement'],
 ['SimpleType', 'SimpleName'],
 ['BufferedReader'],
 [['SimpleName'], ['NullLiteral'], ['VariableDeclarationFragment']]]

In [275]:
# Save only the first 1000 Data objects as a smaller file.
with open('data_v2_0_short.pickle', 'wb') as f:
    pickle.dump(datas[:1000], f)

# Development

In [276]:
def create_embedding(MODEL_NAME='all-MiniLM-L6-v2'):
    '''
    Load a sentence-transformers model and build a node-text embedding function.

    The returned function averages token embeddings over the nesting of each
    node's text: a plain string is encoded directly, a list is the mean of the
    embeddings of its items (which may themselves be strings or lists).  An
    empty list yields a zero vector.  Encoded tokens are cached for the
    lifetime of the returned function.

    :param MODEL_NAME: model name, appended to 'sentence-transformers/'
    :return: returns callable function(data.x) -> torch.Tensor[num_nodes, EMBEDDING_SIZE]
    '''
    model = SentenceTransformer('sentence-transformers/' + MODEL_NAME,
                                device = 'cuda' if torch.cuda.is_available() else 'cpu')
    EMBEDDING_SIZE = model.get_sentence_embedding_dimension()
    embeddings = dict()

    def encode_token(token:str) -> torch.Tensor:
        if token not in embeddings:
            # Convert once to a tensor so later arithmetic does not depend on
            # implicit tensor/ndarray interoperation.
            embeddings[token] = torch.as_tensor(model.encode(token), dtype=torch.float)
        return embeddings[token]

    def embed(item) -> torch.Tensor:
        # item = 'SimpleType' | ['SimpleType', 'SimpleName'] | [['SimpleName'], ['NullLiteral']]
        if isinstance(item, str):
            return encode_token(item)

        total = torch.zeros(EMBEDDING_SIZE)
        if len(item) == 0:
            # avoid 0/0 -> NaN for a node without tokens
            return total
        for part in item:
            # ``total`` is a fresh tensor, so cached embeddings are never modified
            total += embed(part)
        return total / len(item)

    def f(args:tp.Optional[tp.List] = None):
        # args =  [[['SimpleType', 'SimpleName'], ['SimpleName'], ['SingleVariableDeclaration']], ...]
        # :return: torch.Tensor[len(args), EMBEDDING_SIZE]
        if args is None:
            args = []
        rows = torch.zeros([len(args), EMBEDDING_SIZE], dtype=torch.float)
        for i, node_text in enumerate(args):
            rows[i] = embed(node_text)
        return rows

    return f


# Quick check: embed the node texts of one graph and look at the result's shape.
emb = create_embedding()
s = datas[0].x  # NOTE(review): `s` is not used again in this cell
e = emb(datas[3].x)  # one row per node of datas[3]
e.shape

torch.Size([53, 384])

In [248]:
# The factory defined in this notebook is create_embedding; no get_embedding
# is defined in it, so the old name raises NameError on a fresh kernel.
calc_emb = create_embedding()

# NOTE: pickle.load can execute arbitrary code - only load files you created yourself.
with open('data_v2_0_short.pickle', 'rb') as f:
    d = pickle.load(f)

In [277]:
e

tensor([[ 0.0016,  0.1044,  0.0081,  ...,  0.0169,  0.0683,  0.0132],
        [-0.0481,  0.0820, -0.0293,  ...,  0.0727, -0.0073, -0.0837],
        [-0.0344,  0.0216, -0.0430,  ...,  0.0538, -0.0877,  0.0188],
        ...,
        [-0.0376,  0.0746, -0.0152,  ...,  0.0045,  0.0392, -0.0632],
        [-0.0895, -0.0681, -0.0692,  ..., -0.0215,  0.0204, -0.0557],
        [-0.0973, -0.0501, -0.0441,  ...,  0.0514, -0.0049, -0.0276]])

In [None]:
# Experiment: sum the embeddings of 1000 strings produced by
# generate_random_string(10) into a single vector and display it.
# NOTE(review): generate_random_string is not defined in the visible part of
# this notebook - confirm it is defined or imported before this cell runs.
MODEL_NAME = 'all-mpnet-base-v2'
# MODEL_NAME = 'all-MiniLM-L6-v2' # ~ 2 times faster, slightly worse quality
model = SentenceTransformer('sentence-transformers/' + MODEL_NAME, 
                                device = 'cuda' if torch.cuda.is_available() else 'cpu')
x = torch.zeros(model.get_sentence_embedding_dimension())
for i in tqdm(range(1000)):
  s = generate_random_string(10)
  x += model.encode(s)
x

100%|██████████| 1000/1000 [00:15<00:00, 64.89it/s]


tensor([ 2.3247e+01, -1.7360e+01,  1.6009e+01,  1.4016e+01, -3.5395e+01,
         4.6722e+01, -3.1937e+01,  2.9005e+01,  1.5794e+00,  6.3628e+00,
         4.5635e+01, -1.1611e+01,  3.3276e+01,  4.9039e+01, -1.5933e+01,
         6.1741e+00, -1.4937e+01, -9.2308e+00,  1.4600e+01, -2.0914e+01,
        -4.6614e+01,  2.2514e+01, -1.9158e+01, -5.9237e+00, -2.9356e+01,
        -6.1191e+00, -2.4338e+00, -1.5150e+01, -1.5961e+01,  2.2583e+01,
         1.8625e+01,  2.8759e+01, -3.0815e+01,  1.5404e+01,  2.0227e-03,
        -2.1144e+01, -1.8557e+01, -5.6416e+00, -2.4728e+01,  3.2059e+01,
        -3.4298e+01,  1.9746e+01, -2.0892e+01, -1.5802e+01, -9.2973e+00,
        -1.4842e+01,  1.8361e+01,  1.0652e+01, -1.7804e+01,  8.0665e+00,
         1.8526e+00, -3.1671e+01,  1.6687e+01, -3.6816e-01,  6.3781e+01,
        -4.7884e+01,  2.7504e+01,  2.8632e+01, -1.8908e+01,  2.8738e+01,
         1.2851e+01,  2.6709e+01,  1.6451e+01, -8.5416e+00,  1.3588e+01,
         1.4132e+01,  5.8880e+00,  1.3587e+00,  2.8

In [28]:
# model = SentenceTransformer('sentence-transformers/' + MODEL_NAME, 
#                                 device = 'cuda' if torch.cuda.is_available() else 'cpu')
# x = torch.zeros(384)
# for i in tqdm(range(10000)):
#   s = generate_random_string(20)
#   x += model.encode(s)
# x

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import typing as tp  # `tp` is the alias for typing used elsewhere in this notebook, not a module of its own

from embedding import create_embedding   # !

class GCN(torch.nn.Module):
    """Two-layer GCN that embeds the nodes' text before the first convolution.

    :param embedding: zero-argument factory returning a function that maps
        ``data.x`` to a ``[num_nodes, features]`` tensor
    :param in_channels: input feature size of the first layer; when omitted it
        is read from the module-level ``dataset.num_node_features``
    :param hidden_channels: output size of the first layer
    :param num_classes: output size of the second layer; when omitted it is
        read from the module-level ``dataset.num_classes``
    """
    def __init__(self, embedding=create_embedding, in_channels=None, hidden_channels=16, num_classes=None):
        super().__init__()
        # Fall back to the global ``dataset`` only for sizes not given explicitly.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

        # Use the factory that was passed in instead of always the default one.
        self.embedding = embedding()  # !

    def forward(self, data):
        """Return log-softmax scores with one row per node of ``data``."""
        node_texts = data.x                              # !
        x: torch.Tensor = self.embedding(node_texts)     # !
        # NOTE(review): the embedding output may live on a different device
        # than the layers - confirm before training on GPU.

        edge_index = data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        # NOTE(review): graph_to_data stores a single label per graph, while
        # this returns one row per node - confirm whether pooling is intended.
        return F.log_softmax(x, dim=1)

'123'