# In [31]:
import time
import numpy as np
import scipy, itertools
import torch
from sklearn.linear_model import LinearRegression
import numpy as np
import os
import string
from neural_verification import *
import matplotlib.pyplot as plt
from sklearn.cluster import SpectralClustering, KMeans
from itertools import permutations, product

'''
MIPS contains three regression methods: linear regression, boolean regression and symbolic regression.
This file implements boolean regression (BR).
Taking in an integer data table (output from the integer autoencoder), BR aims to find boolean relations among its entries.
'''


def bitstring2int(s):
    """Interpret the bit string `s` (e.g. "101") as a base-2 integer."""
    return int(s, base=2)

def int2bitstring(n,i):
    """Return the binary digits of `i`, left-padded with "0" to at least `n` characters.

    If `i` needs more than `n` digits the digits are returned unpadded (never truncated).
    """
    digits = bin(i)[2:]  # drop the "0b" prefix
    return digits.rjust(n, "0")

def str2int(lst):
    """Convert every element of `lst` to an int and return the results as a list."""
    return [int(item) for item in lst]

# Computes all 2^n bitstrings of length n:
def allstrings(n):
    """Return the list of all 2**n bit strings of length `n`, in increasing numeric order.

    The original body built the list but never returned it, so callers always got None.
    The strings are produced the same way int2bitstring does: binary digits without the
    "0b" prefix, zero-padded on the left to width `n`.
    """
    return [bin(i)[2:].zfill(n) for i in range(2**n)]

# Computes B-matrix whose rows are all 2^n bit strings of length n:
def Bmatrix(n):
    """Return the integer array whose row i holds the binary digits of i, zero-padded to width n.

    There is one row for every i in range(2**n), so the rows enumerate all bit strings of
    length n in increasing numeric order.
    """
    rows = []
    for i in range(2**n):
        digits = bin(i)[2:].rjust(n, "0")
        rows.append([int(ch) for ch in digits])
    return np.array(rows)

def list2nN(lst):
    """Return (n, N) where N = len(lst) and N == 2**n.

    Raises:
        ValueError: if the length is zero or not a power of two.

    The exponent is computed with integer arithmetic (bit_length) rather than a ratio of
    floating-point logarithms, so it cannot be thrown off by rounding. The original error
    path referenced an undefined name and then called exit(); it now raises instead so the
    caller (or the notebook) can decide how to react.
    """
    N = len(lst)
    n = N.bit_length() - 1  # floor(log2(N)) for N >= 1
    if N < 1 or 2**n != N:
        raise ValueError(f"List length not power of 2: N={N}")
    return n,N

# Input:   a list of 2^n vectors
# Output:  fitting error to model where these vectors form an n-dimensional parallelogram in canonical order
# Demo:  vectorlist =[[10,20],[10,22],[11,21],[11,23]]
def parallelogramFit(vectorlist):
    """Return the relative least-squares error of fitting the vectors to an n-dimensional parallelogram.

    vectorlist: a sequence of 2**n equal-length vectors, in the canonical order given by the
    rows of Bmatrix(n) (vector i should equal the first vector plus the sum of the edge
    vectors selected by the binary digits of i).

    Returns a value between 0 and 1, where 0 means a perfect fit. When all vectors coincide
    the parallelogram is degenerate and fits trivially, so 0.0 is returned (the original
    divided 0 by 0 and produced nan in that case).

    Raises whatever list2nN raises when the number of vectors is not a power of two.
    """
    n,N = list2nN(vectorlist)
    A = np.array(vectorlist)
    A = A - A[0] # WLOG 1st point is at the origin
    total = np.trace(A.T @ A)  # squared Frobenius norm of the shifted data
    if total == 0:
        return 0.0
    B = Bmatrix(n)
    # Normal-equation solution of the least-squares problem B @ X ~ A; rows of X are the edge vectors.
    X = np.linalg.inv(B.T @ B) @ B.T @ A
    E = A - B @ X # Fitting error
    return np.trace(E.T @ E)/total # Between 0 & 1, where 0 = perfect


# Input:   a list of 2^n vectors
# Output:  a list of 2^n bitstrings of length n, labeling these vectors
# Demo:  vectorlist =[[10,20],[10,22],[11,21],[11,23]]
# Calls parallelogramFit for all permutations of the vectors and returns best fit.

# Computes inverse permutation:
def invperm(perm):
    """Return the inverse permutation as a list: result[perm[i]] == i for every position i."""
    inverse = [0] * len(perm)
    for position, target in enumerate(perm):
        inverse[target] = position
    return inverse
# Demo: invperm([1,2,3,0])

def vecs2bits(vectorlist):
    """Label 2**n vectors with n-bit strings so that they best form a parallelogram.

    Every ordering of the vectors is scored with parallelogramFit and the ordering with the
    smallest error is kept.

    Returns (bestB, besterror): bestB is an integer array whose row i is the bit label
    assigned to vectorlist[i]; besterror is the fitting error of the winning ordering.

    Fixes relative to the original: the best error was never recorded (the assignment was
    written backwards as `error = besterror`), so the last permutation always won and the
    sentinel 666 was returned; and an unused list of all N! permutations was materialized.
    """
    n,N = list2nN(vectorlist)
    besterror = float("inf")
    bestperm = tuple(range(N))  # fallback if no ordering yields a comparable (non-nan) error
    for perm in itertools.permutations(range(N)):
        error = parallelogramFit([vectorlist[i] for i in perm])
        if error < besterror:
            besterror = error
            bestperm = perm
    B = Bmatrix(n)
    # bestperm[j] is the vector placed at canonical slot j, so vector i sits at slot invperm[i].
    p = invperm(bestperm)
    bestB = np.array([B[p[i]] for i in range(N)]).astype(int)
    return bestB,besterror


# three symbolic functions to be learned
# out from h[0], h[1]
# h[0] from last h[0], h[1], current x[0], x[1]
# h[1] from last h[0], h[1], current x[0], x[1]


# Tools for symbolic regression of boolean functions
# Max Tegmark Aug 24-25 2023

import time
from math import log

def bitstring2int(s):
    # Parse a string of binary digits such as "0110" into its integer value.
    value = int(s, 2)
    return value

def int2bitstring(n,i):
    # Binary digits of i (without the "0b" prefix), zero-padded on the left up to width n.
    # A value needing more than n digits is returned in full, never truncated.
    bits = bin(i)[2:]
    padding = "0" * (n - len(bits))
    return padding + bits

# A boolean function f is defined as a string f of length 2^n, say "00010001"
# The argument list is defined as a string x length n ,say "011"
# Returns char "0" or "1"
def booleval(f,x):
    """Evaluate the boolean function `f` at the argument bit string `x`.

    f: truth table of length 2**n (a string such as "00010001", or any indexable sequence).
    x: argument values as a bit string of length n, most significant variable first.

    Returns the table entry at the row whose index is `x` read as a binary number.
    A zero-variable (constant) function is evaluated with x == "" and yields f[0].

    Raises:
        ValueError: if len(f) != 2**len(x). The original error path referenced an
        undefined name and then called exit().
    """
    n = len(x)
    if len(f) != 2**n:
        raise ValueError(f"String length mismatch error: n={n}, expected {2**n}, got {len(f)}")
    # int("", 2) is invalid, so the zero-variable case is handled explicitly.
    i = int(x, 2) if n else 0
    return f[i]
# DEMO: booleval("11111111","111")

# Flip the ith bit in the bitstring n
def flip_bit(x,i):
    """Return a copy of the bit string `x` with the character at position `i` inverted."""
    bits = [*x]
    flipped = 1 - int(bits[i])
    bits[i] = str(flipped)
    return "".join(bits)
# DEMO: flip_bit("11111111",2)

def f2nN(f):
    """Return (n, N) for a truth table `f`, where N = len(f) and N == 2**n.

    Raises:
        ValueError: if the length is zero or not a power of two.

    The exponent comes from integer arithmetic (bit_length) instead of a ratio of
    floating-point logarithms, and an invalid length raises instead of calling exit(),
    which would terminate the whole interpreter from inside a helper.
    """
    N = len(f)
    n = N.bit_length() - 1  # floor(log2(N)) for N >= 1
    if N < 1 or 2**n != N:
        raise ValueError(f"String length not power of 2: f={f!r}, N={N}")
    return n,N

def bitsum(s):
    """Return the sum of the digits in `s` (the number of ones for a bit string)."""
    total = 0
    for ch in s:
        total += int(ch)
    return total

def find_variables_used(f):
    """Return a mask string marking which variables the truth table `f` depends on.

    Character k of the result is "1" if flipping x_k changes the value of f for at least one
    argument, otherwise "0"; e.g. "101" means f depends on x_0 and x_2 but not x_1.
    """
    n,N = f2nN(f)
    used = ["0"] * n
    for k in range(n):
        # x_k is the k-th character from the left of the argument string, i.e. bit n-1-k of the row index.
        mask = 1 << (n - 1 - k)
        if any(f[i] != f[i ^ mask] for i in range(N)):
            used[k] = "1"
    return "".join(used)
# DEMO: find_variables_used("01100110") # x[1] XOR x[2]

# Return the function f restricted to the only variables it depends on:
def subfunc(f):
    """Restrict `f` to the variables it actually depends on.

    Returns (f1, vars): `vars` lists the indices of the variables used by f, and `f1` is the
    truth table over just those variables, as a list of "0"/"1" characters of length
    2**len(vars). Unused variables are held at "0" while the reduced table is sampled.
    """
    used = find_variables_used(f)
    n = len(used)
    kept = [k for k, flag in enumerate(used) if flag == "1"]
    n1 = len(kept)
    f1 = []
    for i in range(2**n1):
        x1 = int2bitstring(n1,i)
        x = ["0"] * n
        # Scatter the reduced argument bits back into the positions of the kept variables.
        for position, bit in zip(kept, x1):
            x[position] = bit
        f1.append(booleval(f,"".join(x)))
    return f1,kept
# DEMO: subfunc("01100110") # x[1] XOR x[2]

def parallelogram():
    """Placeholder with no behavior; always returns None."""
    return None

def symmetricQ(f):
    """Test whether `f` depends only on the number of ones among its arguments.

    If f is symmetric under all permutations of its variables, returns a string of length
    n+1 whose k-th character is the value f takes when exactly k arguments are 1.
    Otherwise returns "".
    2**(n+1) out of the 2**N functions are symmetric.
    """
    n,N = f2nN(f)
    by_sum = ["-"] * (n + 1)  # "-" marks a bit sum that has not been seen yet
    for i in range(N):
        x = int2bitstring(n,i)
        k = bitsum(x)
        value = booleval(f,x)
        if by_sum[k] == "-":
            by_sum[k] = value
        elif by_sum[k] != value:
            # Two arguments with the same bit sum disagree: not symmetric.
            return ""
    return "".join(by_sum)

# Write x as boolean condition"
def x2dnf(x,varnames):
    """Spell the bit string `x` as a conjunction: variable i appears negated when x[i] is "0"."""
    prefixes = ("not ", "")  # indexed by the bit value
    literals = []
    for i, bit in enumerate(x):
        literals.append(prefixes[int(bit)] + varnames[i])
    return " and ".join(literals)
# Demo: x2dnf("110","abc") 
# returns "a and b and not c"

# Write f in disjunctive normal form:
def f2dnf(f,varnames):
    """Write the truth table `f` in disjunctive normal form.

    One parenthesised conjunction (built by x2dnf) is emitted for every row where f is "1",
    and the conjunctions are joined with " or ". A function that is never "1" gives "".
    """
    n,N = f2nN(f)
    clauses = []
    for i in range(N):
        if f[i] == "1":
            clauses.append("(" + x2dnf(int2bitstring(n,i),varnames) + ")")
    return " or ".join(clauses)
# Demo: f2dnf("01100110","abc") 
# returns "(not a and not b and c) or (not a and b and not c) or (a and not b and c) or (a and b and not c)"

# Checks if string of type "010101010":
def parityQ(s):
    """Return True if the digit string `s` starts with 0 and strictly alternates 0,1,0,1,..."""
    bits = list(map(int, s))
    if bits[0] != 0:
        return False
    # Neighbouring digits must be one 0 and one 1, i.e. sum to exactly 1.
    return all(a + b == 1 for a, b in zip(bits, bits[1:]))
# DEMO: parityQ("01010")
    
# Checks if string is sorted, like e.g. "0000111":
def sortedQ(s):
    """Return True if no element of `s` is greater than its successor (non-decreasing order)."""
    return not any(left > right for left, right in zip(s, s[1:]))

# Returns 3 if s="0000111" (index of the last "0" before the step), returns -1 if not a step function
def stepupQ(s):
    """Locate the step of a 0...01...1 string.

    Returns the index just before the first "1" when `s` is non-decreasing and contains a
    "1" after at least one leading element; returns -1 when `s` is not non-decreasing,
    contains no "1", or starts with "1".
    """
    if any(left > right for left, right in zip(s, s[1:])):
        return -1
    return next((i - 1 for i, ch in enumerate(s) if ch == "1"), -1)

# Returns 3 if s="1111000" (index of the last "1" before the step), returns -1 if not a step function
def stepdownQ(s):
    """Locate the step of a 1...10...0 string.

    Returns the index just before the first "0" when `s` is non-increasing and contains a
    "0" after at least one leading element; returns -1 when `s` is not non-increasing,
    contains no "0", or starts with "0".
    """
    mirrored = "".join(reversed(s))
    if any(left > right for left, right in zip(mirrored, mirrored[1:])):
        return -1
    return next((i - 1 for i, ch in enumerate(s) if ch == "0"), -1)

def symmfunc(s,varsum):
    """List the bit sums at which a symmetric function is true.

    s: string whose i-th character is "1" when the function is true for bit sum i.
    varsum: text of the sum expression, used verbatim on the left of each "==".
    """
    hits = []
    for total, flag in enumerate(s):
        if flag == "1":
            hits.append(varsum + "==" + str(total))
    return " or ".join(hits)
# Demo: symmfunc("1001","a+b+c") 
# returns "a+b+c==0 or a+b+c==3"

# Given a string s specifying how function depends on bit sum, return the formula:
def symmetric_formula(s,varnames):
    """Turn the bit-sum profile `s` of a symmetric function into a formula string.

    Tried in order: single-variable negation, xor of all variables (alternating profile
    starting at 0), a ">" threshold on the variable sum, a "<" threshold on the variable
    sum, and finally an explicit list of the sums at which the function is true.
    """
    if s == "10" and len(varnames) == 1:
        return "not " + varnames[0]
    if parityQ(s):
        return " xor ".join(varnames) # f is xor of all variables
    varsum = "+".join(varnames)
    threshold = stepupQ(s)
    if threshold >= 0:
        return varsum + ">" + str(threshold)
    threshold = stepdownQ(s)
    if threshold >= 0:
        return varsum + "<" + str(threshold + 1)
    return symmfunc(s,varsum)

def formula(f):
    """Return a formula string for the truth table `f`.

    The function is first reduced to the variables it depends on; variable with index i is
    named chr(97+i) ("a", "b", ...). A constant function gives "True" or "False". Otherwise
    the disjunctive normal form is returned, unless the reduced function is symmetric and
    its symmetric formula is strictly shorter.
    """
    f1,used_vars = subfunc(f)
    varnames = [chr(97+i) for i in used_vars]
    if not varnames:
        return str(bool(int(f[0]))) # Function is a constant
    dnf = f2dnf(f1,varnames)
    s = symmetricQ(f1)
    if s == "":
        return dnf
    compact = symmetric_formula(s,varnames)
    return compact if len(compact) < len(dnf) else dnf # Pick shortest formula


'''def load_tensors(task):
    # Construct the file paths
    a_path = f"./tasks/{task}/A_best.pt"
    b_path = f"./tasks/{task}/b_best.pt"
    z_path = f"./tasks/{task}/Z_best.pt"
    z2_path = f"./tasks/{task}/Z2_best.pt"
    hidden = f"./tasks/{task}/hidden.pt"
    hidden2 = f"./tasks/{task}/hidden2.pt"

    # Load the tensors
    A = torch.load(a_path)
    b = torch.load(b_path)
    Z = torch.load(z_path)
    Z2 = torch.load(z2_path)

    hidden = torch.load(hidden)
    hidden2 = torch.load(hidden2)
    return A, b, Z, Z2, hidden, hidden2'''

'''def get_data(task_name, batch=1000):
    A, b, Z, Z2, hidden, hidden2 = load_tensors(task_name)
    Z = np.round(Z); Z2 = np.round(Z2)
    data = torch.load(f"../tasks/{task_name}/data.pt")
    inputs = data[0].detach().numpy()
    outputs = data[1].detach().numpy()

    if len(data[0].shape) == 2:
        inputs_last = data[0][:,[-1]].detach().numpy()
        input_dim_larger_than_one_flag = False
    else:
        input_dim = data[0].shape[2]
        if input_dim > 1:
            input_dim_larger_than_one_flag = True
        inputs_last = data[0][:,-1].detach().numpy()
    outputs_last = data[1][:,[-1]].detach().numpy()
    if batch == None:
        return A, b, Z, Z2, hidden, hidden2, inputs, inputs_last, outputs, outputs_last, input_dim_larger_than_one_flag
    else:
        return A, b, Z[:batch], Z2[:batch], hidden[:batch], hidden2[:batch], inputs[:batch], inputs_last[:batch], outputs[:batch], outputs_last[:batch], input_dim_larger_than_one_flag'''


def _truth_table_with_spread(indices, targets, table_size):
    """Tabulate a sampled boolean function and measure how self-consistent the samples are.

    indices:    integer array; indices[k] is the truth-table row that sample k falls into.
    targets:    array of the same length holding the observed output of each sample.
    table_size: number of rows of the truth table.

    Returns (table, spread). `table` is a string of `table_size` digits in which rows that
    were never observed stay "0". `spread` is the sum, over the observed rows, of the
    standard deviation of the targets that landed in that row; it is 0 when every observed
    row has a single well-defined output.
    """
    spread = 0.0
    for row in range(table_size):
        group = targets[indices == row]
        # np.std of an empty selection is nan, which would poison the sum (and make the
        # "spread > tolerance" test silently false), so unobserved rows are skipped.
        if group.size:
            spread += np.std(group)
    table = np.zeros(table_size, dtype=int)
    table[indices] = targets
    return "".join(str(bit) for bit in table), spread


def get_boolean_formula(task, eff_hidden_dim_increase=0, ckpt='tasks'):
    """Search for the shortest boolean description of a trained RNN's state machine.

    The hidden states returned by get_data are clustered into 2**(hidden_dim +
    eff_hidden_dim_increase) clusters. Every assignment of bit strings to clusters is tried;
    for each one, truth tables are fitted for (previous hidden bits, last input) -> hidden
    bit and for hidden bits -> output, and converted to formulas with formula(). Among the
    assignments whose samples are self-consistent, the one with the smallest total formula
    length is returned (assignment 0 if none is consistent).

    Args:
        task: task name, used to locate the data and the checkpoint files on disk.
        eff_hidden_dim_increase: extra bits added to the model's hidden dimension when
            choosing the number of clusters.
        ckpt: checkpoint family; one of 'tasks', 'models_regularized', 'models_hammered'.

    Returns:
        (next_hidden_formulas, next_out_formulas, cluster_centers, c2bits), where c2bits[k]
        is the bit vector assigned to cluster k.

    NOTE(review): presumably `hidden` holds the hidden state after the last step and
    `hidden2` the one before it, both of shape (samples, hidden_dim), and `inputs_last` /
    `outputs_last` are 2-D arrays of 0/1 values -- confirm against get_data.

    Fixes relative to the original: the consistency check of the hidden transition only
    looked at the first 2**hidden_eff_dim rows of a table with 2**ih_combined_eff_dim rows;
    unobserved rows turned the consistency measure into nan, which always passed the check;
    the inner loops reused the outer loop variable; loop-invariant work was recomputed.
    """
    A, b, Z, Z2, hidden, hidden2, inputs, inputs_last, outputs, outputs_last, _ = get_data(task, batch=1000)

    dic4config = {'tasks':'model_config', 'models_regularized': 'regularized_model_config', 'models_hammered': 'post_hammer_retrain_model_config'}
    dic4model = {'tasks':'model_perfect', 'models_regularized': 'regularized_model_best', 'models_hammered': 'post_hammer_retrain_model_best'}

    config = torch.load(f"../{ckpt}/{task}/{dic4config[ckpt]}.pt", map_location=torch.device('cpu'))
    model = GeneralRNN(config, device=torch.device('cpu'))
    model.load_state_dict(torch.load(f"../{ckpt}/{task}/{dic4model[ckpt]}.pt", map_location=torch.device('cpu')))
    n_cluster = 2**(model.hidden_dim+eff_hidden_dim_increase)

    # concat hidden and hidden2 so both are clustered with one shared set of labels
    n_sample = hidden.shape[0]
    hidden_concat = np.concatenate([hidden, hidden2], axis=0)
    clustering = SpectralClustering(n_clusters=n_cluster,
                                    assign_labels='discretize',
                                    random_state=0).fit(hidden_concat)

    # get cluster centers
    # NOTE(review): only the first n_sample rows (the `hidden` half) contribute to the
    # centers, as in the original; a cluster populated only by `hidden2` rows ends up with
    # a nan center. Left unchanged -- confirm whether all rows should be used.
    cluster_centers = np.zeros((n_cluster, model.hidden_dim))
    cluster_size = np.zeros(n_cluster,)
    for k in range(n_sample):
        cluster_centers[clustering.labels_[k]] += hidden_concat[k]
        cluster_size[clustering.labels_[k]] += 1
    cluster_centers = cluster_centers/cluster_size[:,np.newaxis]

    hidden_cl = clustering.labels_[:n_sample]
    hidden2_cl = clustering.labels_[n_sample:]

    hidden_eff_dim = int(np.log2(n_cluster))
    ih_combined_eff_dim = model.config.input_dim + hidden_eff_dim

    # Loop-invariant pieces: all bit vectors in canonical order, and the place values that
    # turn a row of bits (most significant first) into a truth-table row index.
    all_bits = np.array(list(product([0, 1], repeat=hidden_eff_dim)))
    in_weights = 2**np.arange(ih_combined_eff_dim)[::-1]
    out_weights = 2**np.arange(hidden_eff_dim)[::-1]

    lengths = []
    next_hidden_formulas_candidates = []
    next_out_formulas_candidates = []
    consistency = []

    perm = np.array(list(permutations(np.arange(n_cluster))))

    for assignment in perm:
        c2bits = all_bits[list(assignment)]
        hidden_bits = c2bits[hidden_cl]
        hidden2_bits = c2bits[hidden2_cl]
        length = 0

        # (hidden2, input) -> hidden
        input_hidden2_combined = np.concatenate([hidden2_bits, inputs_last], axis=1)
        # astype(int) is a no-op for integer data and keeps the values usable as indices otherwise.
        in_indices = np.sum(input_hidden2_combined * in_weights, axis=1).astype(int)

        next_hidden_formulas = []
        hidden_std = 0
        for col in range(hidden_eff_dim):
            table, spread = _truth_table_with_spread(in_indices, hidden_bits[:,col], 2**ih_combined_eff_dim)
            hidden_std += spread
            next_formula = formula(table)
            length += len(next_formula)
            next_hidden_formulas.append(next_formula.replace('xor','^'))
        hidden_consistency = not (hidden_std > 1e-2)

        # hidden -> output
        out_indices = np.sum(hidden_bits * out_weights, axis=1).astype(int)

        next_out_formulas = []
        out_std = 0
        for col in range(model.config.output_dim):
            table, spread = _truth_table_with_spread(out_indices, outputs_last[:,col], 2**hidden_eff_dim)
            out_std += spread
            next_formula = formula(table)
            length += len(next_formula)
            next_out_formulas.append(next_formula.replace('xor','^'))
        out_consistency = not (out_std > 1e-2)

        next_hidden_formulas_candidates.append(next_hidden_formulas)
        next_out_formulas_candidates.append(next_out_formulas)
        lengths.append(length)
        consistency.append(hidden_consistency and out_consistency)

    lengths = np.array(lengths)
    consistent_indices = np.flatnonzero(np.array(consistency))
    if consistent_indices.shape[0] == 0:
        # No assignment is self-consistent: fall back to the first one, as before.
        best_id = 0
    else:
        best_id = consistent_indices[np.argmin(lengths[consistent_indices])]

    next_hidden_formulas = next_hidden_formulas_candidates[best_id]
    next_out_formulas = next_out_formulas_candidates[best_id]
    c2bits = all_bits[list(perm[best_id])]
    return next_hidden_formulas, next_out_formulas, cluster_centers, c2bits


# produce program

def produce_program_boolean_regression(task_name, effective_hidden_dim_increase = 0, print_code=False, ckpt="tasks"):
    """Generate a Python program that reproduces a task's RNN with boolean formulas.

    The formulas come from get_boolean_formula. The generated function `f` keeps the hidden
    bits in variables a, b, c, ..., reads the input(s) into the following letters, and
    appends the output `y` at every step. The function text is written to
    ./programs/<task-name>.txt (underscores in the path replaced by dashes).

    Args:
        task_name: name of the task; passed to get_data and get_boolean_formula.
        effective_hidden_dim_increase: extra hidden bits on top of Z.shape[1].
        print_code: if truthy, print the full generated code.
        ckpt: checkpoint family forwarded to get_boolean_formula.

    Returns:
        The full generated code (data loading, function, verification loop) as a string.
        When more than 3 hidden bits would be needed or the last inputs are not all 0/1,
        a one-line failure program is written and returned instead.

    Raises:
        ValueError: if the input width is neither 1 nor 2 (no template exists for it).

    Fixes relative to the original: the body referred to an undefined name `task` instead
    of the `task_name` parameter in three places, and an unsupported input width fell
    through to an unbound local variable.
    """
    data = get_data(task_name)
    A, b, Z, Z2, hidden, hidden2, inputs, inputs_last, outputs, outputs_last, _ = data

    num_seq = inputs.shape[1]
    dim_hidden = Z.shape[1]
    if len(inputs_last.shape) == 1:
        dim_input = 1
    else:
        dim_input = inputs_last.shape[1]
    n_state = dim_hidden + effective_hidden_dim_increase
    program_path = f'./programs/{task_name}.txt'.replace('_','-')

    # inputs_last should be 0 or 1
    has_non_binary_input = np.any((inputs_last != 0) & (inputs_last != 1))

    if n_state > 3 or has_non_binary_input:
        print("dim too high (perm > 16!), skip...")
        code = "success_flag = 0 # dim too high for boolean regression"
        with open(program_path,"w") as fh:
            fh.write(code)
        return code

    if dim_input not in (1, 2):
        raise ValueError(f"boolean regression supports 1 or 2 input bits, got {dim_input}")

    next_hidden_formulas, next_out_formulas, cluster_centers, c2bits = get_boolean_formula(task_name, eff_hidden_dim_increase=effective_hidden_dim_increase, ckpt=ckpt)
    # The hidden state is initialized to the bits of the cluster whose center is closest to the origin.
    h = c2bits[np.argmin(np.linalg.norm(cluster_centers, axis=1))]

    variables = string.ascii_lowercase[:14] # hidden variables denoted as a,b,c,d,...; the input variable(s) take the next letters; output variable denoted as y.
    # initialize_code (initialize hidden states)
    if len(h.shape) == 2:
        h = h[0]
    initialize_code = "".join(f"{variables[i]} = {h[i]};" for i in range(n_state))

    # hidden code: compute every next_* from the current state first, then commit them all at once
    hidden_code = "\n        ".join(f"next_{variables[i]} = {next_hidden_formulas[i]}" for i in range(n_state))
    hidden_code += "\n        "
    hidden_code += "".join(f"{variables[i]} = next_{variables[i]};" for i in range(n_state))

    # output code
    output_code = f"y = {next_out_formulas[0]}"

    if dim_input == 2:
        predict_call = "f(inputs[i,:,0], inputs[i,:,1])"
        function_code = f"""
def f(s,t):
    {initialize_code}
    ys = []
    for i in range({num_seq}):
        {variables[n_state]} = s[i]; {variables[n_state+1]} = t[i];
        {hidden_code}
        {output_code}
        ys.append(y)
    return ys
    """
    else:
        predict_call = "f(inputs[i])"
        function_code = f"""
def f(s):
    {initialize_code}
    ys = []
    for i in range({num_seq}):
        {variables[n_state]} = s[i]
        {hidden_code}
        {output_code}
        ys.append(y)
    return ys
    """

    preprocess_code = f"""
import numpy as np
data = get_data("{task_name}")
A, b, Z, Z2, hidden, hidden2, inputs, inputs_last, outputs, outputs_last, _ = data
num_example = inputs.shape[0]
    """

    verify_code = f"""
wrong = 0
for i in range(int(num_example*0.01)):
#for i in range(int(num_example)):
    out_pred = {predict_call}
    wrong += np.sum(1 - (np.array(outputs[i]) == np.array(out_pred)))
    
if wrong == 0:
    success_flag = 1
    print('verification success')
else:
    success_flag = 0
    print('verification failure')
"""

    code = f"""
{preprocess_code}
{function_code}
{verify_code}
"""
    if print_code:
        print(code)

    # Only the function itself is saved; the returned string also carries loading and verification.
    with open(program_path,"w") as fh:
        fh.write(function_code)
    return code
    