Skip to content

Commit

Permalink
allow marking idiomatic code
Browse files Browse the repository at this point in the history
  • Loading branch information
manishshettym committed Feb 19, 2024
1 parent 42136b2 commit 5ba196f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
21 changes: 16 additions & 5 deletions codescholar/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import os.path as osp
import openai
import json
import networkx as nx
import torch
from networkx.readwrite import json_graph

from codescholar.utils.graph_utils import nx_to_sast
from codescholar.sast.sast_utils import sast_to_prog
from codescholar.utils.search_utils import read_prog
from codescholar.utils.search_utils import read_prog, read_graph

from codescholar.constants import DATA_DIR

Expand Down Expand Up @@ -98,17 +100,26 @@ def get_result_from_dir(api, api_cache, select_size):
data = json.load(f)
idx = data["index"]
prog_path = f"{DATA_DIR}/pnosmt/source/example_{idx}.py"
graph_path = f"{DATA_DIR}/pnosmt/graphs/data_{idx}.pt"

with open(prog_path, "r") as f:
prog = f.read()
p_graph = torch.load(graph_path, map_location=torch.device("cpu"))

# highlight idiom in prog
i_graph = json_graph.node_link_graph(data["graph"])
i_subg = p_graph.subgraph(i_graph).copy()
i_subg.remove_edges_from(nx.selfloop_edges(i_subg))
for v in p_graph.nodes:
p_graph.nodes[v]["is_idiom"] = 1 if v in i_subg.nodes else 0

graph = json_graph.node_link_graph(data["graph"])
sast = nx_to_sast(graph)
idiom = sast_to_prog(sast).replace("#", "_")
# convert to sast and extract highlighted code
idiom_prog = sast_to_prog(nx_to_sast(p_graph), mark_idiom=True)

results.update(
{
count: {
"idiom": idiom,
"idiom": idiom_prog,
"size": size,
"cluster": cluster,
"freq": nhood_count,
Expand Down
12 changes: 11 additions & 1 deletion codescholar/sast/sast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def replace_nonterminals(node, child_spans):
return node


def sast_to_prog(sast: ProgramGraph):
def sast_to_prog(sast: ProgramGraph, mark_idiom=False):
"""perform a DFS traversal and regenerate prog"""

def dfs_util(sast: ProgramGraph, node, visited):
Expand All @@ -270,6 +270,16 @@ def dfs_util(sast: ProgramGraph, node, visited):
for node in sast.all_nodes():
visited[node.id] = False

if mark_idiom and node.is_idiom:
parts = node.span.split("#")
marked_parts = []
for i, part in enumerate(parts):
if part.strip():
marked_parts.append(f"<mark>{part}</mark>")
if i < len(parts) - 1:
marked_parts.append("#")
node.span = "".join(marked_parts)

for node in sast.all_nodes():
if not visited[node.id]:
dfs_util(sast, node, visited)
Expand Down

0 comments on commit 5ba196f

Please sign in to comment.