In [4]:
import torch
from sklearn.linear_model import LinearRegression
import numpy as np
import os
import string

def load_tensors(task):
    # Construct the file paths
    a_path = f"./tasks/{task}/A_best.pt"
    b_path = f"./tasks/{task}/b_best.pt"
    z_path = f"./tasks/{task}/Z_best.pt"
    z2_path = f"./tasks/{task}/Z2_best.pt"

    # Load the tensors
    A = torch.load(a_path)
    b = torch.load(b_path)
    Z = torch.load(z_path)
    Z2 = torch.load(z2_path)

    return A, b, Z, Z2

def linear(input, output, show_eqn=True, rounding=True):
    
    reg = LinearRegression().fit(input, output)
    score = reg.score(input, output)
    coeff = reg.coef_
    intercept = reg.intercept_

    if show_eqn:
        if rounding == False:
            equation = "y = " + " + ".join([f"{coef}*x{idx}" for idx, coef in enumerate(coeff)]) + f" + {intercept}"
        else:
            equation = "y = " + " + ".join([f"{np.round(coef).astype(int)}*x{idx}" for idx, coef in enumerate(coeff)]) + f" + {np.round(intercept).astype(int)}"
        #print(f"Linear Equation:{equation}")
        
    if rounding == False:
        return score, coeff, intercept, reg
    else:
        return score, np.round(coeff).astype(int), np.round(intercept).astype(int), reg

def polynomial_fit(input, output, degree=2, show_eqn=True, rounding=True):
    poly = PolynomialFeatures(degree=degree)
    input_poly = poly.fit_transform(input)

    reg = LinearRegression().fit(input_poly, output)
    score = reg.score(input_poly, output)

    if show_eqn:
        # Formatting coefficients with corresponding powers
        if rounding == False:
            equation_terms = [f"{coef}*x^{i}" for i, coef in enumerate(reg.coef_)]
        else:
            equation_terms = [f"{np.round(coef)}*x^{i}" for i, coef in enumerate(reg.coef_)]
        equation = "y = " + " + ".join(equation_terms) + f" + {reg.intercept_}"
        print(f"Polynomial Equation (degree {degree}): {equation}")
        
    if rounding == False:
        return score, reg.coef_, reg.intercept_, reg
    else:
        print(np.round(reg.coef_))
        return score, np.round(reg.coef_), np.round(reg.intercept_), reg

# Example usage
#task_name = "rnn_identity_numerical"  # replace with your task name


def get_data(task_name):
    A, b, Z, Z2 = load_tensors(task_name)
    Z = np.round(Z); Z2 = np.round(Z2)
    data = torch.load(f"../tasks/{task_name}/data.pt")
    inputs = data[0].detach().numpy()
    outputs = data[1].detach().numpy()

    inputs_last = data[0][:,[-1]].detach().numpy()
    outputs_last = data[1][:,[-1]].detach().numpy()
    return A, b, Z, Z2, inputs, inputs_last, outputs, outputs_last

In [43]:
# produce program

def produce_program(task_name, print_code=False):
    data = get_data(task_name)
    A, b, Z, Z2, inputs, inputs_last, outputs, outputs_last = data
    
    # linear regression (input, hidden) => hidden
    combine_input_Z2 = np.concatenate([inputs_last, Z2], axis=1)
    _, coeff_h, intercept_h, _ = linear(combine_input_Z2, Z)
    coeff_i2h, coeff_h2h = coeff_h[:,:inputs_last.shape[1]], coeff_h[:,inputs_last.shape[1]:]
    
    # linear regression hidden => output
    _, coeff_o, intercept_o, _ = linear(Z, outputs_last)
    
    num_seq = inputs.shape[1]
    num_data = inputs_last.shape[0]
    dim_hidden = Z.shape[1]

    # note that here hidden is not zero! it should be translated to lattice coordinate
    # Ah + b = 0
    h = - np.round(np.matmul(b, np.linalg.inv(A))).astype('int')
    
    variables = string.ascii_lowercase[:14] # hidden variables denoted as a,b,c,d,...; input variable denoted as x; output variable denoted as y.
    # initialize_code (initialize hidden states)
    initialize_code = ""
    if len(h.shape) == 2:
        h = h[0]
    for i in range(dim_hidden):
        initialize_code += f"{variables[i]} = {h[i]};"

    # hidden code
    hidden_code = ""
    for i in range(dim_hidden):
        hidden_i_code = f"{variables[i]} = "
        for j in range(dim_hidden):
            hidden_i_code += f"+{coeff_h2h[i,j]}*{variables[j]}"
        hidden_i_code += f"+{coeff_i2h[i,0]}*x+{intercept_h[i]}"
        if i == 0:
            hidden_code = hidden_i_code
        else:
            hidden_code += "\n        " + hidden_i_code
            
    
    # output code
    output_code = f"y = "
    for j in range(dim_hidden):
        output_code += f"+{coeff_o[0,j]}*{variables[j]}"
    output_code += f"+{intercept_o[0]}"
    
    
    function_code = f"""
def f(s):
    {initialize_code}
    ys = []
    for i in range({num_seq}):
        x = s[i]
        {hidden_code}
        {output_code}
        ys.append(y)
    return ys
    """
    
    preprocess_code = f"""
import numpy as np
data = get_data(\"{task_name}\")
A, b, Z, Z2, inputs, inputs_last, outputs, outputs_last = data
num_example = inputs.shape[0]
    """
    
    verify_code = f"""
wrong = 0
for i in range(num_example):
    if i % 100000 == 0:
        print(i)
    out_pred = f(inputs[i])
    wrong += np.sum(1 - (np.array(outputs[i]) == np.array(out_pred)))
    
if wrong == 0:
    print('verification success')
else:
    print('verification failure')
"""
    
    code = f"""
{preprocess_code}
{function_code}
{verify_code}
"""
    if print_code == True:
        print(code)
    return code 
    

In [45]:
#code = produce_program('rnn_sum_numerical', print_code=True);
code = produce_program('rnn_prev3_numerical', print_code=True)
#code = produce_program('rnn_parity_last2_numerical', print_code=True)



import numpy as np
data = get_data("rnn_prev3_numerical")
A, b, Z, Z2, inputs, inputs_last, outputs, outputs_last = data
num_example = inputs.shape[0]
    

def f(s):
    a = 0;b = 0;c = 99;d = 99;
    ys = []
    for i in range(10):
        x = s[i]
        a = +0*a+1*b+0*c+0*d+0*x+0
        b = +0*a+0*b+-1*c+0*d+0*x+99
        c = +0*a+0*b+0*c+1*d+0*x+0
        d = +0*a+0*b+0*c+0*d+-1*x+99
        y = +1*a+0*b+0*c+0*d+0
        ys.append(y)
    return ys
    

wrong = 0
for i in range(num_example):
    if i % 100000 == 0:
        print(i)
    out_pred = f(inputs[i])
    wrong += np.sum(1 - (np.array(outputs[i]) == np.array(out_pred)))
    
if wrong == 0:
    print('verification success')
else:
    print('verification failure')




In [46]:
tasks = [f for f in os.listdir('./tasks') if not f.startswith('.')]
for task_name in tasks:
    code = produce_program(task_name, print_code=False)
    print(task_name)
    exec(code)

rnn_sum_numerical
0
100000
200000
300000
400000
500000
600000
700000
800000
verification success
rnn_parity_last3_numerical
0
100000
200000
300000
400000
500000
600000
700000
800000
verification failure
rnn_prev3_numerical
0
100000
200000
300000
400000
500000
600000
700000
800000
verification success
rnn_identity_numerical
0
100000
200000
300000
400000
500000
600000
700000
800000
verification success
rnn_parity_last2_numerical
0
100000
200000
300000
400000
500000
600000
700000
800000
verification failure
rnn_prev2_numerical
0
100000
200000
300000
400000
500000
600000
700000
800000
verification failure
rnn_prev4_numerical
0
100000
200000
300000
400000
500000
600000
700000
800000
verification failure
rnn_parity_last4_numerical
0
100000
200000
300000
400000
500000
600000
700000
800000
verification failure
rnn_prev1_numerical
0
100000
200000
300000
400000
500000
600000
700000
800000
verification success
rnn_sum_last4_numerical
0
100000
200000
300000
400000
500000
600000
700000
800000
verif