In [455]:
import copy
import time
import pickle
from collections import Counter 
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable as Variable

import torchvision
import torchvision.transforms as transforms

In [456]:
def freq_dict(list):
    """Count how many times each distinct item occurs in an iterable.

    Args:
        list: any iterable of hashable items (name kept for existing callers).

    Returns:
        A plain dict mapping each item to its count, keys in first-seen order.
    """
    counts = {}
    for element in list:
        counts[element] = counts.get(element, 0) + 1
    return counts

def clip(x, a=0, b=1):
    """Clamp every element of tensor ``x`` into the closed interval [a, b].

    Uses ``torch.clamp`` rather than ``max(min(x, 0*x + b), 0*x + a)``. The old
    form builds two temporary tensors, and for infinite entries ``0*inf`` is NaN,
    so ``inf`` / ``-inf`` came out as NaN instead of ``b`` / ``a``.

    Args:
        x: tensor to clamp.
        a: lower bound (default 0).
        b: upper bound (default 1).

    Returns:
        A new tensor with the same shape as ``x``.
    """
    return torch.clamp(x, min=a, max=b)

In [458]:
# Number of circuit inputs (leaves of the gate tree); the gate vectors below have n_bits - 1 entries.
n_bits = 16

def circuit_output_old(b, c):
    """Evaluate a binary tree of soft AND/OR gates using a flat heap-style array.

    Rows ``0 .. n-1`` of the working array hold the leaves (the rows of ``b``).
    Gate ``k`` reads rows ``2k`` and ``2k + 1`` and writes row ``n + k``, so the
    root ends up in the last row ``2n - 2``. Each gate computes

        (1 - c[k]) * OR(l, r) + c[k] * AND(l, r)

    with OR(l, r) = 1 - (1 - l)(1 - r) and AND(l, r) = l * r. So ``c[k] = 0``
    is OR, ``c[k] = 1`` is AND, and fractional values interpolate.

    The number of leaves is taken from ``b`` itself rather than from the
    notebook-global ``n_bits``. Later cells rebind that global, which would
    otherwise break or silently change this function.

    Args:
        b: 2-D tensor indexed as (leaf, input string).
        c: gate weights with at least ``b.shape[0] - 1`` entries.

    Returns:
        The root row, one value per input string.
    """
    n_leaves = b.shape[0]
    vals = torch.zeros(2*n_leaves - 1, len(b[0]))
    vals[0:n_leaves] = b

    for out_i in range(n_leaves - 1):
        in_i1 = out_i*2
        in_i2 = out_i*2 + 1

        vals[out_i + n_leaves] = (1 - c[out_i])*(1 - (1 - vals[in_i1])*(1 - vals[in_i2])) + c[out_i]*vals[in_i1]*vals[in_i2]
        # Disabled alternative: OR - c*AND, which turns the c = 1 gate into XOR instead of AND:
#         vals[out_i + n_leaves] = (1 - (1 - vals[in_i1])*(1 - vals[in_i2])) - c[out_i]*vals[in_i1]*vals[in_i2]

    return vals[2*n_leaves - 2]

# Draw random AND/OR gate assignments c until the target circuit outputs 1 on
# between 400 and 600 of 1000 random input strings, i.e. it is far from constant.
while True:
    c = 1.*(torch.randn(n_bits - 1) > 0)
    outputs = [circuit_output_old(1.*(torch.randn(n_bits, 1) > 0), c) for _ in range(1000)]
    n_ones = sum(outputs)  # computed once instead of three times per iteration
    print(n_ones)
    if 400 < n_ones < 600:
        break

print(c)

tensor([149.])
tensor([938.])
tensor([468.])
tensor([1., 0., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 0., 1.])


In [1]:
# c = 1.*(torch.randn(n_bits - 1) > 0)

def circuit_output(b, c, print_stuff=False):
    """Evaluate a binary tree of soft AND/OR gates on a batch of inputs.

    The leaves are the rows of ``b``. Each level pairs adjacent rows (0 with 1,
    2 with 3, ...) and combines every pair with the next gate weight from ``c``:

        out = (1 - g) * OR(l, r) + g * AND(l, r)
        OR(l, r) = 1 - (1 - l) * (1 - r),   AND(l, r) = l * r

    So ``g = 0`` is OR, ``g = 1`` is AND, and fractional ``g`` interpolates.
    Gates are consumed level by level, left to right. When a level has an odd
    number of rows, the last row is passed up unchanged, so a tree over ``n``
    leaves always uses exactly ``n - 1`` gates.

    Args:
        b: 2-D tensor indexed as (leaf, input string); ``len(b[0])`` strings.
        c: indexable gate weights with at least ``b.size()[0] - 1`` entries.
        print_stuff: if true, print the width and the values of every level.

    Returns:
        The root row after reduction, one value per input string.
    """
    n_strings = len(b[0])

    gate_i = 0
    cur_width = b.size()[0]

    while cur_width > 1:
        if print_stuff:
            print(cur_width)
            print(b)

        # Floor division plus an explicit leftover row, instead of round(cur_width/2).
        # Python's round() rounds halves to even, so odd widths either dropped the
        # last row (5 -> 2) or read past the end of b (3 -> 2).
        n_pairs = cur_width // 2
        has_leftover = cur_width % 2 == 1
        next_width = n_pairs + int(has_leftover)
        outputs = torch.zeros(next_width, n_strings)

        for out_i in range(n_pairs):
            left = b[2*out_i]
            right = b[2*out_i + 1]
            gate = c[gate_i]
            outputs[out_i] = (1 - gate)*(1 - (1 - left)*(1 - right)) + gate*left*right
            gate_i += 1

        if has_leftover:
            outputs[n_pairs] = b[cur_width - 1]

        b = outputs
        cur_width = next_width

    return b[0]

In [460]:
def generate_all_binary_vectors(n_bits, array_so_far=None):
    """Return every 0/1 list of length ``n_bits``, in lexicographic order.

    Args:
        n_bits: number of free bits to enumerate (``0`` yields a single vector).
        array_so_far: optional prefix prepended to every generated vector. It is
            not modified.

    Returns:
        A list of ``2**n_bits`` independent lists, each ``prefix + bits``. The
        first bit varies slowest and 0 comes before 1
        (e.g. ``[[0, 0], [0, 1], [1, 0], [1, 1]]``).
    """
    from itertools import product

    prefix = [] if array_so_far is None else list(array_so_far)
    return [prefix + list(bits) for bits in product((0, 1), repeat=n_bits)]

## Initialize and optimize on the hypercube (exact loss over all inputs)

In [449]:
all_data = torch.tensor(generate_all_binary_vectors(n_bits))

# Gate weights to learn, one per gate (n_bits - 1), initialized uniformly in [0, 1).
x = torch.rand(n_bits - 1, requires_grad=True)
optimizer = optim.SGD([x], lr=.01, momentum=.999)
# optimizer = optim.Adam([x], lr=.1)

# Exact loss over every input string. Earlier variants sampled random data per
# epoch instead. Data and target outputs do not change between epochs, so they
# are transposed / evaluated once here.
data = all_data
data_t = torch.t(data)
output_c = circuit_output(data_t, c)

for epoch in range(3000):
    optimizer.zero_grad()
    output_x = circuit_output(data_t, x)

    loss = torch.mean(torch.abs((output_x - output_c)))
    # Total distance of x outside [0, 1]. It is reported only; x is clamped after
    # every step, so this is normally 0. Use `loss = loss + overshoot` to penalize it.
    overshoot = torch.sum(torch.abs(torch.min(x, 0*x)) + torch.max(x - 1, 0*x))
    loss.backward()
    optimizer.step()

    # Project back onto the unit hypercube.
    with torch.no_grad():
        x.clamp_(0, 1)

    if epoch % 100 == 0:
        print('*'*100)
        print("loss: ", loss.item())

        # Loss of the hard circuit obtained by thresholding each gate at 0.5.
        x_snapped = 1.*(x > .5)
        output_s = circuit_output(data_t, x_snapped)
        loss_snapped = torch.mean(torch.abs(output_s - output_c))
        print("loss snapped: ", loss_snapped.item())

        print("abs(x - c): ", torch.sum(torch.abs(x - c)).item())
        # Sum over gates of the distance to the nearest of {0, 1}.
        extremeness = torch.sum(torch.min(torch.abs(x), torch.abs(x - 1))).item()
        print("extremeness of x: ", extremeness)
        print("overshoot of x: ", overshoot.item())

****************************************************************************************************
loss:  0.49959075450897217
loss snapped:  0.684234619140625
abs(x - c):  7.632452964782715
extremeness of x:  3.66811203956604
overshoot of x:  0.0
****************************************************************************************************
loss:  0.12908747792243958
loss snapped:  0.129119873046875
abs(x - c):  6.9884419441223145
extremeness of x:  0.7579243183135986
overshoot of x:  0.0
****************************************************************************************************
loss:  0.129119873046875
loss snapped:  0.129119873046875
abs(x - c):  7.0
extremeness of x:  0.0
overshoot of x:  0.0
****************************************************************************************************
loss:  0.129119873046875
loss snapped:  0.129119873046875
abs(x - c):  7.0
extremeness of x:  0.0
overshoot of x:  0.0
***********************************************************

KeyboardInterrupt: 

In [450]:
# Display the gate weights x left by the (interrupted) exact-loss run above.
x

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       requires_grad=True)

## Initialize and optimize on the hypercube with a redundant parameterization

In [453]:
redundancy = 4

all_data = torch.tensor(generate_all_binary_vectors(n_bits))
# Tile each input string `redundancy` times along the bit axis (dim=1). The
# learned circuit then has redundancy*n_bits leaves and redundancy*n_bits - 1
# gates, but still computes a function of the same n_bits inputs as the target c.
all_data_redundant = torch.cat((all_data,)*redundancy, dim=1)

x = torch.rand(redundancy*n_bits - 1, requires_grad=True)
optimizer = optim.SGD([x], lr=.003, momentum=.999)
# optimizer = optim.Adam([x], lr=.1)

# Exact loss over every input string. Earlier variants sampled random data per
# epoch instead. Data and target outputs do not change between epochs, so they
# are transposed / evaluated once here.
data = all_data
data_redundant = all_data_redundant
data_t = torch.t(data)
data_redundant_t = torch.t(data_redundant)
output_c = circuit_output(data_t, c)

for epoch in range(10000):
    optimizer.zero_grad()
    output_x = circuit_output(data_redundant_t, x)
    loss = torch.mean(torch.abs(output_x - output_c))
    # Total distance of x outside [0, 1]. It is reported only; x is clamped after
    # every step. Use `loss = loss + overshoot` to penalize it.
    overshoot = torch.sum(torch.abs(torch.min(x, 0*x)) + torch.max(x - 1, 0*x))

    loss.backward()
    optimizer.step()

    # Project back onto the unit hypercube.
    with torch.no_grad():
        x.clamp_(0, 1)

    if epoch % 100 == 0:
        print('*'*100)
        print("loss: ", loss.item())

        # Loss of the hard circuit obtained by thresholding each gate at 0.5.
        x_snapped = 1.*(x > .5)
        output_s = circuit_output(data_redundant_t, x_snapped)
        loss_snapped = torch.mean(torch.abs(output_s - output_c))
        print("loss snapped: ", loss_snapped.item())

        # Sum over gates of the distance to the nearest of {0, 1}.
        extremeness = torch.sum(torch.min(torch.abs(x), torch.abs(x - 1))).item()
        print("extremeness of x: ", extremeness)
        print("overshoot of x: ", overshoot.item())

****************************************************************************************************
loss:  0.3848706781864166
loss snapped:  0.4671630859375
extremeness of x:  15.480570793151855
overshoot of x:  0.0
****************************************************************************************************
loss:  0.12878073751926422
loss snapped:  0.12908935546875
extremeness of x:  12.951181411743164
overshoot of x:  0.0
****************************************************************************************************
loss:  0.12910234928131104
loss snapped:  0.129119873046875
extremeness of x:  10.33227825164795
overshoot of x:  0.0
****************************************************************************************************
loss:  0.12911519408226013
loss snapped:  0.129119873046875
extremeness of x:  8.234146118164062
overshoot of x:  0.0
****************************************************************************************************
loss:  0.1291180700063705

KeyboardInterrupt: 

In [454]:
# Display the redundant-parameterization gate weights x (redundancy*n_bits - 1 entries).
x

tensor([0.0368, 0.0000, 0.0000, 0.0068, 0.0000, 0.0000, 0.0000, 0.0000, 0.3309,
        0.0000, 0.0000, 0.5618, 0.0000, 0.0103, 0.0000, 0.5727, 0.0000, 0.1635,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4231, 0.1861, 0.3262,
        0.3925, 0.0000, 0.8413, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       requires_grad=True)

In [106]:
# Compare the learned x and a hand-picked circuit y against the target c.
# NOTE(review): `data_doubled` is not defined in any visible cell, so this cell relies
# on state from a deleted or earlier cell (presumably `data` tiled twice along dim=1;
# confirm). The printed x has 7 entries and y has 8, which only fit an older, smaller
# configuration; it will not run as-is with the current n_bits.
print(x)
y = torch.Tensor([1, 0, 0, 0, 0, 0, 0, 0])

output_x = circuit_output(torch.t(data_doubled), x)
output_y = circuit_output(torch.t(data_doubled), y)
output_c = circuit_output(torch.t(data), c)

# Total absolute error over all input strings.
print(sum(abs(output_x - output_c)))
print(sum(abs(output_y - output_c)))

# print(data_doubled)
# print(output_y)
# print(data_doubled[4:5])
# print('*'*100)
# print(circuit_output(torch.t(data_doubled[4:5]), y, print_stuff=True))

tensor([ 1.0000e+00, -5.4456e-06,  4.9505e-04,  7.5124e-01,  6.1709e-01,
         9.3323e-01,  8.1097e-01], requires_grad=True)
tensor(5.7244, grad_fn=<AddBackward0>)
tensor(6.)


In [135]:
# Total absolute error of a hand-picked 3-gate circuit against the target c.
# NOTE(review): 3 gates only match a 4-leaf circuit. With the 16-bit `data`/`c` from
# the earlier cells, circuit_output would index past the end of x, so this cell is
# presumably stale (run under an older n_bits).
x = torch.Tensor([0, 1, 0])

output_x = circuit_output(torch.t(data), x)
output_c = circuit_output(torch.t(data), c)

print(sum(abs(output_x - output_c)))

tensor(4.)


In [778]:
# Show the learned gate weights next to the target gates (one per line).
print(x, c, sep='\n')

tensor([6.8733e-04, 1.0004e+00, 1.3043e-02], requires_grad=True)
tensor([0., 0., 1.])


In [385]:
# Snapshot of the optimized gate weights. A plain `x_orig = x` only aliases the
# same tensor, so any later in-place update of x would silently change x_orig too.
x_orig = x.detach().clone()
# Threshold each weight at 0.5 to get a hard circuit (1 = AND, 0 = OR in circuit_output).
x_clipped = 1.*(x > .5)
x

tensor([0.9105, 0.6759, 1.0000, 0.9095, 0.0000, 0.0000, 1.0000, 1.0000, 0.0236,
        0.0000, 0.0000, 0.0000, 0.9949, 0.9949, 0.9748, 0.9748, 0.3941, 1.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.9849, 0.0000, 1.0000, 1.0000, 0.0000,
        0.0000, 0.0000, 0.8900, 1.0000], requires_grad=True)

In [699]:
# Measure how the loss of the thresholded circuit changes when individual gates
# are perturbed by hand (an all-zero perturbation must give a difference of 0).
data = all_data

# Edit entries of this vector to tweak single gates of x_clipped. Entries should keep
# every gate in {0, 1}: use +1 to flip a 0 (OR) and -1 to flip a 1 (AND).
x_tweaked = x_clipped + torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
print(x_tweaked)

output_x = circuit_output(torch.t(data), x_clipped)
output_x_t = circuit_output(torch.t(data), x_tweaked)
output_c = circuit_output(torch.t(data), c)

loss = torch.mean(torch.abs((output_x - output_c)))
loss_t = torch.mean(torch.abs((output_x_t - output_c)))

# Positive: the tweak made the circuit worse; negative: it helped.
print(loss_t - loss)

tensor([0., 1., 0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1.])
tensor(0.)


In [667]:
# Gate-by-gate difference between target c and the thresholded x: +1 where c is AND (1)
# but x_clipped is OR (0), -1 for the reverse, 0 where they agree.
c - x_clipped

tensor([ 0.,  0.,  0., -1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,
         0.])

In [None]:
# -> there are local minima besides the global minimum

In [852]:
# Check whether gates other than pure AND/OR (here: weighted blends of AND and OR)
# can help escape a local minimum, using a 4-input toy circuit.

# NOTE(review): this rebinds the global `all_data` from the torch tensor built in the
# training cells to a plain Python list of 4-bit lists; cells that expect the tensor
# will break if re-run after this one.
all_data = generate_all_binary_vectors(4)

def target_f(x):
    """Target toy circuit: AND(OR(x[0], x[1]), OR(x[2], x[3])).

    Uses Python's short-circuit ``or``/``and``, so for 0/1 inputs it returns 0 or 1.
    """
    left = x[0] or x[1]
    right = x[2] or x[3]
    return left and right

def current_f(x):
    """Candidate toy circuit with blended gates, compared against ``target_f``.

    The second OR gate is replaced by 0.8*AND + 0.2*OR. The root is 0.99 of the
    first OR output plus 0.01 of its AND with the blended gate. ``first and blended``
    equals ``first * blended`` here because ``first`` is 0 or 1 for 0/1 inputs.
    """
    first = x[0] or x[1]
    blended = .8*(x[2] and x[3]) + .2*(x[2] or x[3])
    return .99*first + .01*(first and blended)

# Total absolute error of the blended circuit against the target over all 16 inputs.
sum([abs(target_f(x) - current_f(x)) for x in all_data])

3.018

## Check for local minima

In [339]:
def bitstring_neighbors(b):
    """Return every bitstring at Hamming distance 1 from ``b``.

    ``b`` is indexed as a single row ``b[0]`` of ``n`` bits (e.g. shape (1, n)).
    Row ``i`` of the result is ``b`` with bit ``i`` flipped. Adding the float
    identity matrix promotes the result to a floating-point tensor.
    """
    n = len(b[0])
    flipped = b.repeat(n, 1) + torch.eye(n)
    return flipped % 2

def binary_to_decimal(b):
    """Convert each row of a 2-D 0/1 tensor to its integer value, MSB first.

    Uses int64 arithmetic. The previous float32 ``torch.mm`` only represents
    integers exactly up to 2**24, so rows of more than 24 bits could map to the
    wrong index.

    Args:
        b: tensor indexed as (row, bit); int, float or bool entries in {0, 1}.
            Presumably at most 63 bits per row, since the weights are int64.

    Returns:
        An int64 tensor of shape (n_rows, 1).
    """
    nb = len(b[0])
    weights = torch.tensor([2**k for k in range(nb - 1, -1, -1)], dtype=torch.long)
    return (b.long() * weights).sum(dim=1, keepdim=True)

def is_local_min(b, errors, strict=True):
    """Decide whether bitstring ``b`` is a local minimum of ``errors`` under single-bit flips.

    Args:
        b: a single bitstring as a row (``b[0]`` holds the bits).
        errors: tensor of errors indexed by the integer value of each bitstring.
        strict: if true, ties with the best neighbour do not count as a minimum.

    Returns:
        True if b's error is below every neighbour's error. On a tie with the best
        neighbour, returns ``not strict``; otherwise False.
    """
    own_error = errors[binary_to_decimal(b).view(-1)[0]]
    neighbor_idx = binary_to_decimal(bitstring_neighbors(b)).view(-1).tolist()
    best_neighbor = min(errors[neighbor_idx])

    if own_error < best_neighbor:
        return True
    if own_error == best_neighbor:
        return not strict
    return False

def get_equivalent_neighbors_if_min(b, errors):
    """Return the equal-error neighbourhood of ``b`` if no neighbour is strictly better.

    Args:
        b: a single bitstring as a row (``b[0]`` holds the bits).
        errors: tensor of errors indexed by the integer value of each bitstring.

    Returns:
        An empty set if some single-bit-flip neighbour has a strictly lower error.
        Otherwise, the set of integer indices (b itself plus its neighbours) whose
        error equals b's.
    """
    own_code = binary_to_decimal(b)
    own_error = errors[own_code.view(-1)[0]]

    # Candidates: every Hamming-1 neighbour, then b itself.
    candidate_is = binary_to_decimal(bitstring_neighbors(b)).view(-1).tolist()
    candidate_is += own_code[0].tolist()
    candidate_errors = errors[candidate_is]

    if min(candidate_errors) < own_error:
        return set()

    return {idx for idx, err in zip(candidate_is, candidate_errors) if err == own_error}

In [354]:
n_bits = 4
redundancy = 4
n_gates = redundancy*n_bits - 1

# c = torch.Tensor([0, 0, 0, 0, 1, 1, 0])
c = torch.Tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])
assert len(c) == n_gates, "target c must have redundancy*n_bits - 1 gates"

# All input strings, tiled `redundancy` times along the bit axis.
data = torch.tensor(generate_all_binary_vectors(n_bits))
if redundancy > 1:
    data = torch.cat((data,)*redundancy, dim=1)

# Every possible hard circuit (0 = OR, 1 = AND per gate): 2**n_gates rows.
circuits = torch.tensor(generate_all_binary_vectors(n_gates))

# The target output does not depend on the candidate circuit, so evaluate it once.
data_t = torch.t(data)
target = circuit_output(data_t, c)

# errors[i] = total absolute error of circuit i over all input strings.
errors = torch.zeros(len(circuits))
for i in range(len(circuits)):
    output = circuit_output(data_t, circuits[i])
    errors[i] = torch.sum(torch.abs(output - target))

print(errors)

tensor([6., 6., 6.,  ..., 8., 8., 8.])


### Find isolated (strict) local minima

In [355]:
# Print every circuit whose error is strictly below all of its single-flip neighbours.
for idx, circuit in enumerate(circuits):
    if is_local_min(circuit.unsqueeze(0), errors, strict=True):
        print(idx, circuit, errors[idx])

### Find clusters of local minima (equal-error plateaus)

In [359]:
def _prune_non_minimal_cohorts(cohorts):
    """Empty, in place, every cohort that contains a member whose own cohort is empty.

    An empty cohort marks a circuit that is not part of a local minimum (initially:
    it has a strictly better neighbour), so no equal-error plateau containing it is
    a minimum either. Emptiness may need arbitrarily many sweeps to spread across a
    large plateau, so sweep until nothing changes; a fixed sweep count can stop early.
    """
    changed = True
    while changed:
        changed = False
        for i, members in enumerate(cohorts):
            if members and any(not cohorts[j] for j in members):
                cohorts[i] = set()
                changed = True
    return cohorts


def _label_clusters(cohorts):
    """Label the connected components of the non-empty cohorts.

    Returns one string per circuit: '' for an empty cohort, otherwise
    str(smallest circuit index in its component).
    """
    labels = [''] * len(cohorts)
    visited = [False] * len(cohorts)
    for start, members in enumerate(cohorts):
        if visited[start] or not members:
            continue
        # Indices are scanned in increasing order, so `start` is the smallest
        # index of a component not yet visited.
        label = str(start)
        visited[start] = True
        stack = [start]
        while stack:
            node = stack.pop()
            labels[node] = label
            for nbr in cohorts[node]:
                if not visited[nbr] and cohorts[nbr]:
                    visited[nbr] = True
                    stack.append(nbr)
    return labels


cohort_is = [get_equivalent_neighbors_if_min(circuits[i:(i + 1)], errors) for i in range(len(circuits))]

original_cohort_is = copy.deepcopy(cohort_is)

_prune_non_minimal_cohorts(cohort_is)
cluster_ids = _label_clusters(cohort_is)

# Split clusters into global minima (error 0) and other local minima. All members
# of a cluster share the same error, so the labelled index is a representative.
zero_cluster_ids = set()
min_cluster_ids = set()
for ci in set(cluster_ids):
    if ci != '':
        if errors[int(ci)] == 0:
            zero_cluster_ids.add(ci)
        else:
            min_cluster_ids.add(ci)

cluster_size_dict = freq_dict(cluster_ids)
min_cluster_errors = [errors[int(ci)].item() for ci in min_cluster_ids]
min_cluster_sizes = [cluster_size_dict[mci] for mci in min_cluster_ids]
zero_cluster_sizes = [cluster_size_dict[zci] for zci in zero_cluster_ids]

print(str(len(zero_cluster_ids)) + " zero clusters with sizes " + str(zero_cluster_sizes))
print(str(len(min_cluster_ids)) + " min clusters with sizes " + str(min_cluster_sizes) + " and errors " + str(min_cluster_errors))

print(cluster_ids)
print(cohort_is)


15 zero clusters with sizes [146, 146, 92, 92, 229, 92, 394, 394, 536, 146, 229, 146, 229, 92, 229]
0 min clusters with sizes [] and errors []
['', '', '', '', '', '', '', '', '', '', '', '11', '', '', '', '11', '', '', '', '19', '', '', '', '19', '', '25', '', '25', '', '25', '', '25', '', '', '', '', '', '37', '', '37', '', '', '', '11', '', '37', '46', '46', '', '', '', '19', '', '37', '54', '54', '', '25', '', '25', '60', '60', '60', '60', '', '', '', '', '', '69', '', '69', '', '', '', '11', '', '69', '78', '78', '', '', '', '19', '', '69', '86', '86', '', '25', '', '25', '92', '92', '92', '92', '', '97', '', '97', '', '97', '', '97', '', '97', '106', '106', '', '97', '106', '106', '', '97', '114', '114', '', '97', '114', '114', '120', '120', '120', '120', '120', '120', '120', '120', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '19', '', '', '', '19', '', '25', '', '', '', '25', '', '', '', '', '', '', '', '37', '', '37', '', '', '', '', '', '37', '4

In [358]:
# Cohorts from get_equivalent_neighbors_if_min before pruning; each non-empty set
# includes the circuit's own index.
original_cohort_is

[{0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384},
 {0, 1, 3, 5, 9, 17, 33, 65, 129, 257, 513, 1025, 2049, 4097, 8193, 16385},
 {0, 2, 3, 6, 10, 18, 34, 66, 130, 258, 514, 1026, 2050, 4098, 8194, 16386},
 set(),
 {0, 4, 5, 6, 12, 20, 36, 68, 132, 260, 516, 1028, 2052, 4100, 8196, 16388},
 set(),
 {2, 4, 6, 7, 14, 22, 38, 70, 134, 262, 518, 1030, 2054, 4102, 8198, 16390},
 set(),
 {0, 8, 9, 10, 12, 24, 40, 72, 136, 264, 520, 1032, 2056, 4104, 8200, 16392},
 set(),
 set(),
 {11, 15, 27, 43, 75, 523, 1035, 2059, 4107, 8203, 16395},
 {4, 8, 12, 13, 14, 28, 44, 76, 140, 268, 524, 1036, 2060, 4108, 8204, 16396},
 set(),
 set(),
 {11, 15, 31, 47, 79, 527, 1039, 2063, 4111, 8207, 16399},
 {0, 16, 17, 18, 20, 24, 48, 80, 144, 272, 528, 1040, 2064, 4112, 8208, 16400},
 set(),
 set(),
 {19, 23, 27, 51, 83, 147, 275, 2067, 4115, 8211, 16403},
 {4, 16, 20, 21, 22, 28, 52, 84, 148, 276, 532, 1044, 2068, 4116, 8212, 16404},
 set(),
 set(),
 {19, 23, 31, 55, 87, 151, 279, 2071