In [1]:
import sys
sys.path.insert(0, '..')

import equation_vae

%matplotlib inline
import numpy as np
from numpy import sin, exp, cos
from matplotlib import pyplot as plt

from collections import Counter, defaultdict

Using TensorFlow backend.


In [2]:
# Number of points sampled along each interpolation path (endpoints included).
STEPS = 9

# Pretrained weights for the two VAEs; both models are constructed with a
# 25-dimensional latent space (latent_rep_size=25).
grammar_weights = "../eq_vae_grammar_h100_c234_L25_E50_batchB.hdf5"
char_weights = "../eq_vae_str_h100_c234_L25_E50_batchB.hdf5"

grammar_model = equation_vae.EquationGrammarModel(grammar_weights, latent_rep_size=25)
char_model = equation_vae.EquationCharacterModel(char_weights, latent_rep_size=25)



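# We want to interpolate between two equations in each model.

For each of 10 random seeds we draw two latent points and decode every point on the straight line between them 200 times with the character model. At each step we keep the most frequent decoding.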

In [3]:
# For each seed: draw two random latent points and walk a straight line between
# them in STEPS steps. Decode every point CHAR_REP times with the character
# model and keep the most frequent (modal) non-empty decoding per step.
# NOTE: with t = linspace(0, 1, STEPS) each path point is t*z0 + (1-t)*z1, so the
# path STARTS at z1 and ENDS at z0. Each entry of `pairs` is therefore
# (mode at z1, mode at z0).
CHAR_REP = 200

pairs = []
char_seq = []

for seed in range(10):
    np.random.seed(seed)
    # randn is drawn before uniform (left operand first), as before, so the
    # sampled endpoints are unchanged for each seed.
    endpoints = np.random.randn(2, 25) * np.random.uniform(0.05, 0.2)
    distance = np.sqrt(np.sum((endpoints[0] - endpoints[-1]) ** 2))
    print("SEED %d DISTANCE %0.4f" % (seed, distance))

    path = (np.outer(np.linspace(0, 1, STEPS), endpoints[0])
            + np.outer(np.linspace(1, 0, STEPS), endpoints[1]))
    # One batched decode call: row k*CHAR_REP + r of the input is path point k.
    decoded = np.array(
        char_model.decode(np.repeat(path, CHAR_REP, axis=0))
    ).reshape((STEPS, CHAR_REP))

    modes = []
    for row in decoded:
        counts = Counter(eq for eq in row if eq != '')
        # If every decode at this step was empty there is no mode. Record the
        # 'None' placeholder (as get_grammar_mode does) instead of crashing on
        # argmax of an empty list.
        modes.append(counts.most_common(1)[0][0] if counts else 'None')

    pairs.append((modes[0], modes[-1]))
    char_seq.append(modes)

SEED 0 DISTANCE 0.5238
SEED 1 DISTANCE 1.0171
SEED 2 DISTANCE 0.7144
SEED 3 DISTANCE 1.0359
SEED 4 DISTANCE 0.6471
SEED 5 DISTANCE 0.9055
SEED 6 DISTANCE 1.3074
SEED 7 DISTANCE 0.9990
SEED 8 DISTANCE 0.3888
SEED 9 DISTANCE 1.2363


In [4]:
def is_valid(eq):
    """Return True if `eq` parses under the grammar model's grammar.

    The string is tokenized with `equation_vae.tokenize` and passed to the
    grammar model's parser. The equation counts as valid when the parser yields
    at least one parse tree.
    """
    try:
        next(grammar_model._parser.parse(equation_vae.tokenize(eq)))
    except Exception:
        # An empty parse generator raises StopIteration from next(). Any other
        # tokenizer/parser error presumably means the input is outside the
        # grammar. Catching Exception (not a bare except) keeps
        # KeyboardInterrupt/SystemExit from being silently swallowed.
        return False
    return True

# A pair is usable only if both of its endpoint equations parse under the grammar.
valid_pair_ix = np.array([all(is_valid(eq) for eq in pair) for pair in pairs])
valid_pairs = np.array(pairs)[valid_pair_ix]

# Show all sampled pairs, a blank line, then the grammatically valid subset.
print(np.array(pairs))
print("")
print(valid_pairs)

[['3*x+(x)+exp(3*x)' 'x*3+exp(3+exp(3)']
 ['1*3+(x)+ex)(x*3)' '2*2+(x)+exp(2/3)']
 ['2*x+(1)+exp(x*x)' '2*2+exp(1)+exp(x)']
 ['1/1+eep(33+sin11/1)' 'x*2+(3)+exp(2*2)']
 ['2*2+3+exp(x*3)' '3*2+exp(x)+(x/2)']
 ['3*3+sin(3)+(3*3)' 'x+2+exp(1)+sin(1*2)']
 ['3*3+sxp(2)+(x*1)' '1*3+exp(2)+(i*1)']
 ['3*x+sin(3)+sin(1)' '3*1+3+)(1*3)']
 ['3*x+exp(2)+(2*2)' 'x*1+sin(2)+(x*x)']
 ['2*3+exp33+1' '2*x+(x)++exp(x)3)']]

[['2*x+(1)+exp(x*x)' '2*2+exp(1)+exp(x)']
 ['2*2+3+exp(x*3)' '3*2+exp(x)+(x/2)']
 ['3*3+sin(3)+(3*3)' 'x+2+exp(1)+sin(1*2)']
 ['3*x+exp(2)+(2*2)' 'x*1+sin(2)+(x*x)']]


In [5]:
# Encode every valid equation with the grammar model. Flattening row by row keeps
# each pair's two equations adjacent (rows 2k and 2k+1 of the result).
z = grammar_model.encode([eq for pair in valid_pairs.tolist() for eq in pair])

def get_grammar_mode(z, rep=100, model=None):
    """Decode each latent point repeatedly and return the modal equations in pairs.

    Every row of `z` is decoded `rep` times and the most frequent non-empty
    decoding is kept. The per-point results are then grouped into consecutive
    pairs, matching the layout produced by encoding `valid_pairs.ravel()`.

    Args:
        z: array of shape (n_points, latent_dim). `n_points` must be even,
            because consecutive rows are grouped into pairs.
        rep: number of decodes per latent point (default 100, as before).
        model: object whose `decode` returns one string per input row.
            Defaults to the notebook-global `grammar_model`, looked up at call time.

    Returns:
        A list of `n_points // 2` two-element lists of str. A point whose
        decodes were all empty strings is reported as the string 'None'.

    Raises:
        ValueError: if `z` has an odd number of rows.
    """
    if model is None:
        model = grammar_model
    # dtype=float mirrors the float64 upcast the original got from `* np.ones(...)`.
    points = np.asarray(z, dtype=float)
    steps = points.shape[0]
    if steps % 2:
        raise ValueError(
            "get_grammar_mode expects an even number of latent points, got %d" % steps)

    # One batched decode call: row k*rep + r of the input is latent point k.
    decoded = np.array(model.decode(np.repeat(points, rep, axis=0))).reshape((steps, rep))

    modes = []
    for row in decoded:
        counts = Counter(eq for eq in row if eq != '')
        # No non-empty decoding at all -> keep the historical 'None' placeholder.
        modes.append(counts.most_common(1)[0][0] if counts else 'None')
    return np.array(modes).reshape((steps // 2, 2)).tolist()


## Keep only the pairs the grammar model reconstructs exactly (encode each equation, then take the most frequent decoding).

In [6]:
# Keep only the pairs whose two equations the grammar model recovers exactly
# (as its modal decoding) from their own encodings.
dec = get_grammar_mode(z)
# Row indices into `pairs` for each row of `valid_pairs`, computed once
# instead of on every loop iteration.
valid_pair_rows = np.flatnonzero(valid_pair_ix)
# Latent codes grouped as (pair, endpoint, latent_dim), in the same row-major
# order in which valid_pairs was flattened before encoding.
z_pairs = z.reshape((-1, 2, z.shape[-1]))

matched_pairs = []
matched_ix = []
matched_grammar_z = []
for i, pair in enumerate(valid_pairs):
    # Compare as plain lists of str. Under Python 3, np.array(map(...)) is a
    # 0-d object array, which made the old elementwise comparison always fail.
    if [str(eq) for eq in dec[i]] == [str(eq) for eq in pair]:
        matched_pairs.append(pair)
        matched_ix.append(valid_pair_rows[i])
        matched_grammar_z.append(z_pairs[i])

print(np.array(matched_pairs))

[['2*x+(1)+exp(x*x)' '2*2+exp(1)+exp(x)']
 ['2*2+3+exp(x*3)' '3*2+exp(x)+(x/2)']
 ['3*3+sin(3)+(3*3)' 'x+2+exp(1)+sin(1*2)']
 ['3*x+exp(2)+(2*2)' 'x*1+sin(2)+(x*x)']]


### Results for the character model (consecutive duplicate decodings are printed once):

In [7]:
# For each matched pair, print the character-model decodings along the
# interpolation path, skipping consecutive repeats so only transitions show.
for ix in matched_ix:
    path = char_seq[ix]
    print(path[0])
    for prev_eq, eq in zip(path, path[1:]):
        if eq != prev_eq:
            print(eq)
    print("")

2*x+(1)+exp(x*x)
2*x+(2)+exp(x*3)
2*x+(x)+exp(x*x)
2*1+(x)+exp(x*3)
2*2+exp3)+exp(2)
3*3+exp(3)+exp(2)
2*2+exp(3)+exp(2)
2*x+exp(3)+exp(2)
2*2+exp(1)+exp(x)

2*2+3+exp(x*3)
2*1+exexxexp(3)
2*3+exp(3+ex*3)
2*x+exp(3)+(x*1)
3*1+exp(3)+(2*1)
3*x+exp(3)+(2*1)
3*x+exp(3)+(2*3)
3*3+exp(3)+(x/3)
3*2+exp(x)+(x/2)

3*3+sin(3)+(3*3)
x*3+exp(3)+(3*2)
x+2+exp(x)+(2*3)
x+2+exp(2)+(2*1)
x+1+exp(1)+(1*2)
x+2+exp(1)+(1*1)
x+1+exp(x)+sin(1)
x+x+exp(2)+sin(1*3)
x+2+exp(1)+sin(1*2)

3*x+exp(2)+(2*2)
x*3+exp(1)+(2*3)
x*3+exp(x)+(2*3)
x*2+exp(x)+(x*1)
x*x+exp(x)+(x*3)
x*1+exp(x)+ex*3)
2*1+exp(1)+ex*x)
3*1+eip(x)+ex*2)
x*1+sin(2)+(x*x)



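### Results for the grammar model, interpolating between the latent encodings of the same equation pairs:

Note that each grammar-model path is printed in the opposite direction to its character-model counterpart above. It starts at the second equation of the pair and ends at the first.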

In [8]:
def interp_grammar(z_list, steps=None, rep=100, model=None):
    """Print the grammar model's decodings along straight-line latent paths.

    For each pair of latent codes (z0, z1) in `z_list`, the function takes
    `steps` evenly spaced points on the segment between them. Each point is
    decoded `rep` times and the most frequent non-empty decoding is kept. The
    resulting sequence is printed with consecutive duplicates collapsed,
    followed by a blank line.

    The path point at t in linspace(0, 1, steps) is t*z0 + (1-t)*z1, so each
    printed path runs from z1 (first line) to z0 (last line).

    Args:
        z_list: iterable of (2, latent_dim) arrays.
        steps: points per path, endpoints included. Defaults to the
            notebook-global `STEPS`, looked up at call time rather than at
            definition time, so re-running the config cell takes effect.
        rep: decodes per point (default 100, as before).
        model: object whose `decode` returns one string per input row.
            Defaults to the notebook-global `grammar_model`.
    """
    if steps is None:
        steps = STEPS
    if model is None:
        model = grammar_model
    # The interpolation weights are the same for every pair: compute them once.
    w0 = np.linspace(0, 1, steps)
    w1 = np.linspace(1, 0, steps)

    for pair in z_list:
        # dtype=float mirrors the float64 upcast of the original expression.
        pair = np.asarray(pair, dtype=float)
        path = np.outer(w0, pair[0]) + np.outer(w1, pair[1])
        # One batched decode call: row k*rep + r of the input is path point k.
        decoded = np.array(model.decode(np.repeat(path, rep, axis=0))).reshape((steps, rep))

        modes = []
        for row in decoded:
            counts = Counter(eq for eq in row if eq != '')
            # All-empty step: no mode exists. Use the same 'None' placeholder
            # as get_grammar_mode instead of crashing on argmax of [].
            modes.append(counts.most_common(1)[0][0] if counts else 'None')

        print(modes[0])
        for prev_eq, eq in zip(modes, modes[1:]):
            if eq != prev_eq:
                print(eq)
        print("")
    
# Print the grammar-model interpolation path for every matched pair.
interp_grammar(z_list=matched_grammar_z)

2*2+exp(1)+exp(x)
2*2+exp(1)+exp(x*2)
2*3+exp(1)+exp(x*x)
2*x+(1)+exp(x*x)

3*2+exp(x)+(x/2)
2*2+exp(x)+(x)
2*2+exp(3)+(x)
2*2+3+(x/x)
2*2+3+(x/3)
2*2+3+(x*3)
2*2+3+exp(x*3)

x+2+exp(1)+sin(1*2)
x+2+exp(x)+sin(1*2)
x/2+sin(x)+exp(x*2)
3/2+sin(x)+exp(x*3)
3*2+sin(x)+exp(x*3)
3*3+sin(3)+(x*3)
3*3+sin(3)+(3*3)

x*1+sin(2)+(x*x)
x*1+exp(2)+(3*3)
3*1+exp(2)+(3*3)
3*x+exp(2)+(3*2)
3*x+exp(2)+(2*2)

