In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import mpld3

# Switch the notebook to mpld3 display so matplotlib figures render as
# interactive (D3-based) plots; visiualize() relies on mpld3 plugins.
mpld3.enable_notebook()

In [2]:
from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import HtmlFormatter

In [3]:
from attention import load_model
from vocab import tokenize_nl, load_vocabs, tok_type2id, START, END, UNK, SKIP_TOKENS
from py2_tokenize import tokenize_code
import dynet as dy
import math
from itertools import chain, combinations
from asttokens import ASTTokens
from copy import deepcopy
from itertools import product
import astor
import ast
from _ast import stmt, Pass, If, TryExcept, TryFinally

In [4]:
# Load the two directional models from their dump files. Each load_model call
# returns a (model, translator) pair. bi_likelihood uses code2nl_translator with
# code as source / NL as target, and nl2code_translator the other way round.
# NOTE(review): hard-coded, relative, timestamped dump names; consider moving
# them into a config cell so the notebook can point at other checkpoints.
code2nl_model, code2nl_translator = load_model('code2nl_0419113212_model_dmp')
nl2code_model, nl2code_translator = load_model('nl2code_0419113223_model_dmp')

In [5]:
# Load the vocabularies. nl_voc2wid maps a word to its id (used by lookup_nl and
# lookup_code). By name, the other three are presumably the NL inverse map and
# the code-side pair; TODO confirm against vocab.load_vocabs.
# NOTE(review): only nl_voc2wid is referenced in the cells shown here.
nl_voc2wid, nl_wid2voc, code_voc2wid, code_wid2voc = load_vocabs('./vocab.dmp')

In [6]:
def lookup_nl(seqs):
    """Encode tokenised NL sentences as model input.

    Every word ``tok`` becomes the 1-tuple ``(nl_voc2wid[tok],)``, and each
    sentence is bracketed by ``(START,)`` and ``(END,)`` markers.

    Args:
        seqs: iterable of word sequences.

    Returns:
        A list with one encoded list per input sequence. A word missing from
        ``nl_voc2wid`` raises whatever that mapping raises on a miss (KeyError
        for a plain dict).
    """
    encoded = []
    for tokens in seqs:
        ids = [(START,)]
        ids.extend([(nl_voc2wid[tok],) for tok in tokens])
        ids.append((END,))
        encoded.append(ids)
    return encoded

In [7]:
# NOTE(review): dead definition. The next cell (In [8]) redefines lookup_code,
# so in a top-to-bottom run this version is never called; keep only one of them.
def lookup_code(seqs):
    """Encode code token sequences as ``(type_id, word_id)`` pairs.

    Each token ``tok`` is indexed as ``(token_type, literal)``. The type goes
    through ``tok_type2id`` and the literal through ``nl_voc2wid``. Every sequence
    is bracketed by ``(0, START)`` and ``(0, END)``.
    """
    def encode(tok):
        return (tok_type2id[tok[0]], nl_voc2wid[tok[1]])

    return [[(0, START)] + [encode(tok) for tok in seq] + [(0, END)] for seq in seqs]

In [8]:
def lookup_code(seqs):
    """Encode code token sequences as ``(token_type, word_id)`` pairs.

    Each token ``tok`` is indexed as ``(token_type, literal)``. The type is kept
    unchanged and the literal is mapped through ``nl_voc2wid``. Every sequence is
    bracketed by ``(0, START)`` and ``(0, END)``.

    NOTE(review): code literals are looked up in the *NL* vocabulary even though
    ``code_voc2wid`` is loaded next to it. Confirm both models share one
    vocabulary; otherwise this lookup should use ``code_voc2wid``.
    """
    encoded = []
    for seq in seqs:
        body = [(tok[0], nl_voc2wid[tok[1]]) for tok in seq]
        encoded.append([(0, START)] + body + [(0, END)])
    return encoded

In [9]:
def bi_likelihood(nl, code):
    """Score an (intent, snippet) pair in both translation directions.

    The intent is tokenised with ``tokenize_nl``. The snippet is tokenised with
    ``tokenize_code``, dropping tokens whose type is in ``SKIP_TOKENS``. Both are
    encoded, and the two translators' ``calc_loss`` are evaluated with
    ``training=False`` on a fresh DyNet computation graph.

    Args:
        nl: intent text.
        code: Python source text of the snippet.

    Returns:
        ``(nl2code_score, code2nl_score)``. Each is the ``calc_loss`` value divided
        by the number of entries in the target side, START/END markers included.
        NOTE(review): if ``calc_loss`` is a negative log-likelihood, these are
        per-token losses (lower is better), not probabilities.
    """
    # Discard any expressions left over from a previous pair.
    dy.renew_cg()

    nl_tokens = tokenize_nl(nl)
    code_tokens = [(ttype, literal)
                   for ttype, literal in tokenize_code(code)
                   if ttype not in SKIP_TOKENS]

    nl_batch = lookup_nl([nl_tokens])
    code_batch = lookup_code([code_tokens])

    forward_loss = nl2code_translator.calc_loss(nl_batch, code_batch, training=False)
    backward_loss = code2nl_translator.calc_loss(code_batch, nl_batch, training=False)

    code_len = sum(len(s) for s in code_batch)
    nl_len = sum(len(s) for s in nl_batch)
    return (forward_loss.value() / code_len, backward_loss.value() / nl_len)

In [10]:
def powerset(iterable):
    """Lazily yield every subset of ``iterable`` as a tuple.

    Subsets come in order of increasing size, and within a size in
    ``itertools.combinations`` order. The empty tuple is first and the full tuple
    is last. The input is consumed once up front, so one-shot iterators work.
    """
    items = list(iterable)
    by_size = (combinations(items, size) for size in range(len(items) + 1))
    return chain.from_iterable(by_size)

In [11]:
def subsets(s):
    """Return every subset of ``s`` as a list of lists, in ``powerset`` order
    (empty subset first, full subset last)."""
    return [list(combo) for combo in powerset(s)]

In [12]:
def sub_snippets(code_snippet):
    """Enumerate the source of every "sub-snippet" obtainable by pruning statements.

    ``code_snippet`` is parsed with ``ast.parse``. ``_subs`` then recursively
    yields variants of the tree in which every statement list is replaced by each
    subset of its statements; non-statement lists are kept whole. Each variant is
    rendered back to source with ``astor.to_source``.

    Args:
        code_snippet: Python source text (parsed by this interpreter's ``ast``;
            the ``_ast`` names used below are Python 2 only).

    Returns:
        A ``set`` of source strings. Duplicates are collapsed, and iteration order
        is not generation order.

    NOTE(review): each statement list contributes up to 2**len(list) subsets, and
    variants of sibling fields are combined with ``itertools.product``, so the
    output grows exponentially with snippet size. Intended for short snippets.
    """
    tree = ast.parse(code_snippet)
    
    def _subs(node, name=None):
        """Yield pruned variants of ``node``.

        ``name`` is the field name under which ``node`` sits in its parent
        (``'body'``, ``'orelse'``, ...). It is only read in the list branch, to
        decide what replaces an empty statement subset.
        """
        if isinstance(node, ast.AST):
            # AST node: collect the variants of each field that produced any.
            field_values = []
            field_names = []
            for subname, subnode in ast.iter_fields(node):
                subvalues = list(_subs(subnode, subname))
                if len(subvalues) > 0:
                    field_names.append(subname)
                    field_values.append(subvalues)
            count = 0
            # One deep copy of the node per combination of field variants.
            for values in product(*field_values):
                newnode = deepcopy(node)
                # NOTE(review): this loop variable shadows the ``name`` parameter;
                # harmless only because this branch never reads ``name`` afterwards.
                for name, value in zip(field_names, values):
                    setattr(newnode, name, value)
                count += 1
                # `if <test>: pass` with no else collapses to its test expression.
                # NOTE(review): the bare expression node (not wrapped in ast.Expr)
                # ends up inside a statement list; confirm astor renders it as intended.
                if isinstance(newnode, If) and len(newnode.orelse) == 0 and isinstance(newnode.body[0], Pass):
                    yield newnode.test
                    continue
                # Drop try/except (no else) whose protected body was pruned to `pass`.
                if isinstance(newnode, TryExcept) and len(newnode.orelse) == 0 and isinstance(newnode.body[0], Pass):
                    continue
                # Drop try/finally where both the body and finalbody are just `pass`.
                if isinstance(newnode, TryFinally) and isinstance(newnode.body[0], Pass) and isinstance(newnode.finalbody[0], Pass):
                    continue
                yield newnode
            # NOTE(review): appears unreachable. product() always yields at least one
            # tuple (product() of nothing yields ()), and count is incremented
            # before any `continue`.
            if count == 0:
                yield node
        elif isinstance(node, list):
            if len(node) == 0:
                yield node
                return
            # Variants of each element, computed once and reused for every subset.
            values = [list(_subs(x)) for x in node]
            stat_count = len(node)
            # Statement lists: try every subset of positions. Any other list (e.g.
            # an expression list) keeps all of its positions.
            for subset in subsets(range(stat_count)) if isinstance(node[0], stmt) else [range(stat_count)]:
                if stat_count >= len(subset) > 0:
                    # Non-empty subset: every combination of the kept elements' variants.
                    for stat_list in product(*map(lambda x:values[x], subset)):
                        yield list(stat_list)
                else:
                    # Empty subset: keep the parent valid. Mandatory blocks get a
                    # lone `pass`, and `orelse` may simply be empty.
                    if name in ('body', 'finalbody') and len(node) > 0:
                        yield [ast.parse("pass").body[0]]
                    if name == 'orelse':
                        yield []
        else:
            # Scalar field value (identifier, number, None, ...): keep as-is.
            yield node
    
    sub_snippet_set = set()
    
    # Render every variant of the whole module back to source text.
    for node in _subs(tree):
        sub_snippet_set.add(astor.to_source(node))
    
    return sub_snippet_set

In [13]:
def getHtml(code):
    """Highlight ``code`` as Python with Pygments.

    Returns the HTML fragment, with line numbers enabled.
    """
    return highlight(code, PythonLexer(), HtmlFormatter(linenos=True))

In [14]:
def visiualize(nl, snippets, likelihoods):
    """Draw an interactive (mpld3) scatter plot of scored snippets for one intent.

    Each point is one snippet. Its x value is the first element and its y value
    the second element of the matching ``likelihoods`` pair (the tuple returned
    by ``bi_likelihood``). Hovering a point shows a tooltip with both scores and
    the snippet highlighted by ``getHtml``.

    NOTE(review): ``bi_likelihood`` returns ``calc_loss`` values divided by token
    counts. If ``calc_loss`` is a negative log-likelihood, the axes show per-token
    loss (lower is better), not the probabilities the labels suggest.

    Args:
        nl: intent text, used as the plot title.
        snippets: iterable of code strings, aligned element-wise with ``likelihoods``.
        likelihoods: iterable of ``(nl2code, code2nl)`` score pairs.

    Returns:
        The value of ``mpld3.display(fig)`` for the figure created here.

    Raises:
        ValueError: if there is nothing to plot or the two inputs differ in length.
    """
    # Both inputs are traversed twice below, and the tooltip plugin needs a real
    # list of labels, so materialise them (a generator would be exhausted).
    snippets = list(snippets)
    likelihoods = list(likelihoods)
    if not likelihoods:
        raise ValueError('visiualize: no likelihoods to plot')
    if len(snippets) != len(likelihoods):
        # Otherwise tooltips would silently be attached to the wrong points.
        raise ValueError('visiualize: %d snippets but %d likelihood pairs'
                         % (len(snippets), len(likelihoods)))

    # NOTE(review): `axisbg` was deprecated in matplotlib 2.0 in favour of
    # `facecolor` and is gone in newer releases; switch once matplotlib >= 2.0.
    fig, ax = plt.subplots(subplot_kw=dict(axisbg='#EEEEEE'), figsize=(10, 5))
    ax.set_xlabel('P ( Snippet | Intent )')
    ax.set_ylabel('P ( Intent | Snippet )')
    ax.set_title(nl)
    ax.grid(color='white', linestyle='solid')

    n2c, c2n = [list(axis) for axis in zip(*likelihoods)]
    # No `c=` array is passed, so a colormap would be ignored; use uniform points.
    scatter = ax.scatter(n2c, c2n, alpha=0.3)

    labels = [getHtml('#nl2code: %f\n#code2nl: %f\n%s' % (n2c_score, c2n_score, snippet))
              for (n2c_score, c2n_score), snippet in zip(likelihoods, snippets)]
    css = HtmlFormatter().get_style_defs('.highlight')
    tooltip = mpld3.plugins.PointHTMLTooltip(scatter, labels, css=css)
    mpld3.plugins.connect(fig, tooltip)
    # Display this figure explicitly rather than whatever pyplot considers current.
    return mpld3.display(fig)

In [15]:
#http://162.243.20.48:7788/post/306400
# Example pair: intent (nl) and a candidate answer snippet (code).
# NOTE(review): each of the following example cells reassigns nl/code, so only the
# last one run before In [20] is analysed.
nl = '''How do I randomly select an item from a list using Python?'''
code = '''import random

foo = ['a', 'b', 'c', 'd', 'e']
print(random.choice(foo))'''

In [16]:
#http://162.243.20.48:7788/post/120656
# Example pair; overwrites nl/code from the previous example cell.
# Note: the outer string is not raw, so the `\\` below becomes a single backslash
# and the snippet text contains "C:\temp". The `print  filename` statement makes
# the snippet parse only under Python 2.
nl = '''Directory listing in Python'''
code = '''import os
from os import listdir

for filename in os.listdir("C:\\temp"):
    print  filename'''

In [17]:
#https://stackoverflow.com/questions/610883/how-to-know-if-an-object-has-an-attribute-in-python
# Example pair (if-statement without else); overwrites nl/code from earlier cells.
nl = '''How to know if an object has an attribute in Python'''
code = '''if hasattr(a, 'property'):
    a.property'''

In [18]:
#https://stackoverflow.com/questions/610883/how-to-know-if-an-object-has-an-attribute-in-python
# Same intent as the previous cell, with an if/else snippet (exercises `orelse`
# pruning in sub_snippets); overwrites nl/code from earlier cells.
nl = '''How to know if an object has an attribute in Python'''
code = '''if hasattr(a, 'property'):
    doStuff(a.property)
else:
    otherStuff()'''

In [19]:
#http://162.243.20.48:7788/post/306400
# NOTE(review): this link repeats the one used for the random-choice example
# (In [15]) and looks copy-pasted; confirm the source post for this intent.
# This is the last example cell, so this pair is what In [20] analyses in a
# top-to-bottom run. `print 'test'` is Python 2 syntax.
nl = '''Getting the length of an array in Python'''
code = '''my_list = [1,2,3,4,5]
print 'test'
len(my_list)'''

In [20]:
# Enumerate the pruned variants of the current example and score each one.
# The comprehension variable must NOT be called `code`: in Python 2 a list
# comprehension's loop variable leaks into the enclosing (here, global) scope.
# With that name, this cell overwrote the example snippet, and a re-run
# enumerated sub-snippets of the last variant instead of the example.
# Sorting gives a reproducible order (sub_snippets returns an unordered set).
snippets = sorted(sub_snippets(code))
likelihoods = [bi_likelihood(nl, snippet) for snippet in snippets]

In [22]:
# Plot every scored sub-snippet of the current example; hover a point for its source.
visiualize(nl=nl, snippets=snippets, likelihoods=likelihoods)