In [1]:
import copy
import time
import pickle
from collections import Counter
from itertools import chain, product

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable as Variable

import torchvision
import torchvision.transforms as transforms

In [2]:
# Timestamp (time.perf_counter()) of the last start_time()/take_time() call; None until started.
last_time = None

def start_time():
    """Start (or restart) the stopwatch read by take_time()."""
    global last_time
    last_time = time.perf_counter()

def take_time():
    """Print the seconds elapsed since the previous start_time()/take_time() call, then restart.

    perf_counter is monotonic, so intervals are not distorted by wall-clock adjustments.
    If the stopwatch was never started, it is started now and nothing is printed
    (previously this raised a TypeError on `None`).
    """
    global last_time
    new_time = time.perf_counter()
    if last_time is not None:
        print("time:", new_time - last_time)
    last_time = new_time

In [3]:
def freq_dict(list):
    """Count how often each item occurs.

    Args:
        list: any iterable of hashable items (the name shadows the builtin but is
            kept for callers passing it by keyword).

    Returns:
        dict mapping each distinct item to its number of occurrences, with keys
        in order of first appearance.
    """
    counts = {}
    for element in list:
        counts[element] = counts.get(element, 0) + 1
    return counts

def clip(x, a=0, b=1):
    """Clamp every element of tensor `x` into the interval [a, b].

    The upper bound is applied first and the lower bound second, so if a > b every
    element ends up equal to `a` (same order as the original min-then-max form).
    Unlike the former `0*x + b` construction, infinite entries clip to the bound
    instead of becoming NaN (0 * inf is NaN).

    Args:
        x: input tensor.
        a: lower bound (number, or tensor broadcastable against `x`).
        b: upper bound (number, or tensor broadcastable against `x`).

    Returns:
        New tensor with the clipped values.
    """
    return torch.clamp(torch.clamp(x, max=b), min=a)

In [4]:
def circuit_output(b, c, print_stuff=False):
    """Evaluate a balanced binary tree of (soft) OR/AND gates on a batch of inputs.

    Adjacent rows are combined pairwise, layer by layer, until one row is left.
    Gate parameters are consumed from `c` in order: all gates of the first layer,
    then the next layer, up to the root. A gate with parameter g computes

        (1 - g) * (1 - (1 - x)(1 - y)) + g * x * y

    so g = 0 is an OR gate, g = 1 an AND gate, and fractional g interpolates.

    Args:
        b: 2-D tensor of shape (width, n_strings); each column is one input string.
            `width` must be a power of two.
        c: 1-D tensor (or sequence) with at least width - 1 gate parameters.
        print_stuff: if True, print the current width and layer values before
            each reduction step.

    Returns:
        1-D tensor of length n_strings holding the root's value for each string.
        When width > 1 it has the default floating dtype, as layer results were
        previously stored in freshly allocated default-dtype tensors.

    Raises:
        ValueError: if `b` is not 2-D, its width is not a power of two, or `c` has
            fewer than width - 1 entries.
    """
    if b.dim() != 2:
        raise ValueError(f"b must be 2-D (width, n_strings), got shape {tuple(b.shape)}")
    cur_width = b.size()[0]
    # Pairwise reduction needs an even width at every layer. The old round()-based
    # halving (banker's rounding) silently dropped rows or indexed past the end otherwise.
    if cur_width & (cur_width - 1):
        raise ValueError(f"circuit width must be a power of two, got {cur_width}")
    c = torch.as_tensor(c)
    if len(c) < cur_width - 1:
        raise ValueError(f"need at least {cur_width - 1} gate parameters, got {len(c)}")

    gate_i = 0
    while cur_width > 1:
        if print_stuff:
            print(cur_width)
            print(b)

        cur_width //= 2
        left, right = b[0::2], b[1::2]  # inputs 2k and 2k + 1 feed output k
        gates = c[gate_i:gate_i + cur_width].unsqueeze(1)  # one gate per output row
        outputs = (1 - gates)*(1 - (1 - left)*(1 - right)) + gates*left*right
        b = outputs.to(torch.get_default_dtype())
        gate_i += cur_width

    return b[0]

In [5]:
def choose_target(n_bits, n_samples=1000, low=0.4, high=0.6, max_tries=10000):
    """Draw random OR/AND gate settings whose circuit outputs 1 on roughly half of inputs.

    Gate settings are sampled uniformly from {0, 1}^(n_bits - 1). A setting is accepted
    when the fraction of `n_samples` uniformly random n_bits-bit inputs on which the
    circuit outputs 1 lies strictly between `low` and `high`. A fresh batch of inputs
    is then used to print the average output.

    Args:
        n_bits: number of circuit inputs; should be a power of two (balanced tree).
        n_samples: random inputs used for the acceptance test and for the report.
        low, high: exclusive bounds on the accepted fraction of 1 outputs.
        max_tries: number of gate settings to try before giving up.

    Returns:
        Float tensor of n_bits - 1 gate parameters (0 = OR, 1 = AND).

    Raises:
        RuntimeError: if nothing is accepted within `max_tries` draws. For example,
            n_bits = 2 has only OR (75% ones) and AND (25% ones), so the old
            unbounded loop never terminated.

    Side effect: the chosen tensor is also stored in the global `c`, as before.
    """
    global c
    for _ in range(max_tries):
        c = 1.*(torch.randn(n_bits - 1) > 0)
        # One batched evaluation: column k of `inputs` is one random input string.
        inputs = 1.*(torch.randn(n_bits, n_samples) > 0)
        frac_ones = circuit_output(inputs, c).sum().item()/n_samples
        if low < frac_ones < high:
            break
    else:
        raise RuntimeError(f"no balanced target found for n_bits={n_bits} after {max_tries} tries")

    check = circuit_output(1.*(torch.randn(n_bits, n_samples) > 0), c)
    print("average output of", check.sum().item()/n_samples)
    print(c)
    return(c)

In [6]:
def generate_all_binary_vectors(n_bits, array_so_far=None):
    """Return every 0/1 list of length `n_bits`, in binary counting order.

    Element k of the result is the binary representation of k, most-significant
    bit first, e.g. n_bits=2 -> [[0, 0], [0, 1], [1, 0], [1, 1]]. Later cells rely
    on this order, because index i of the enumeration corresponds to value i.

    Args:
        n_bits: number of bits to enumerate (>= 0); n_bits=0 yields [[]] (or [prefix]).
        array_so_far: optional prefix prepended to every generated vector. It is
            not modified.

    Returns:
        List of 2**n_bits lists of ints.
    """
    prefix = [] if array_so_far is None else list(array_so_far)
    # product((0, 1), repeat=n) enumerates with the last position varying fastest,
    # which is exactly binary counting order.
    return [prefix + list(bits) for bits in product((0, 1), repeat=n_bits)]

def generate_random_binary_vectors(n_bits, n):
    """Sample `n` independent uniformly random bit strings of length `n_bits`.

    Returns:
        Tensor of shape (n, n_bits) with 0./1. entries in the default floating dtype.
    """
    coin_flips = torch.rand(n, n_bits) > 0.5
    return coin_flips.to(torch.get_default_dtype())

## Choose target vectors

In [32]:
# One roughly balanced random target circuit per input width, keyed by width.
lengths = [4, 8, 16, 32]
cs = dict()
for n_bits in lengths:
    target_c = choose_target(n_bits)
    cs[n_bits] = target_c

average output of 0.579
tensor([0., 0., 1.])
average output of 0.435
tensor([0., 0., 0., 1., 1., 0., 1.])
average output of 0.579
tensor([1., 1., 0., 1., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 0.])
average output of 0.546
tensor([1., 1., 1., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0.,
        0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1.])


## Check for local minima

In [9]:
def bitstring_neighbors(b):
    """Return every bit string at Hamming distance 1 from the single row `b`.

    Args:
        b: 2-D tensor with one row of 0/1 values, shape (1, nb).

    Returns:
        Tensor of shape (nb, nb) whose row k is `b` with bit k flipped. Adding
        torch.eye promotes integer input to a floating dtype.
    """
    width = len(b[0])
    flipped = b.repeat(width, 1) + torch.eye(width)
    return flipped % 2

def binary_to_decimal(b):
    """Convert each row of a 0/1 matrix to its integer value (first column = MSB).

    Args:
        b: 2-D tensor of shape (n_rows, nb) with 0/1 entries (int, float or bool).

    Returns:
        int64 tensor of shape (n_rows, 1).

    Computed entirely in int64. The previous float32 matmul was only exact up to
    2**24, so for longer bit strings (e.g. 31-gate circuits) it returned wrong
    indices.
    """
    nb = len(b[0])
    weights = torch.tensor([2**k for k in range(nb - 1, -1, -1)], dtype=torch.long)
    return (b.long()*weights).sum(dim=1, keepdim=True)

def is_local_min(b, errors, strict=True):
    """Decide whether bit string `b` is a local minimum of `errors` over single-bit flips.

    Args:
        b: (1, nb) tensor holding one bit string.
        errors: 1-D tensor indexed by the integer value of a bit string
            (assumed to cover all 2**nb strings).
        strict: if True, a tie with the best neighbour does not count as a minimum.

    Returns:
        True if b's error is below every neighbour's error, or equal to the best
        neighbour's error when strict is False; otherwise False.
    """
    own_error = errors[binary_to_decimal(b).view(-1)[0]]
    neighbor_codes = binary_to_decimal(bitstring_neighbors(b)).view(-1).tolist()
    best_neighbor = min(errors[neighbor_codes])

    if own_error < best_neighbor:
        return True
    if own_error == best_neighbor:
        return not strict
    return False

def get_equivalent_neighbors_if_min(b, errors):
    """Return the equal-error plateau around `b` if `b` is a (non-strict) local minimum.

    Args:
        b: (1, nb) tensor holding one bit string.
        errors: 1-D tensor indexed by the integer value of a bit string.

    Returns:
        Set of integer codes, drawn from `b` itself and its single-bit-flip
        neighbours, whose error equals b's error. Returns an empty set if some
        neighbour has strictly lower error.
    """
    own_code = binary_to_decimal(b)
    own_error = errors[own_code.view(-1)[0]]
    # Candidates are all neighbours followed by b itself.
    candidate_codes = binary_to_decimal(bitstring_neighbors(b)).view(-1).tolist() + own_code[0].tolist()
    candidate_errors = errors[candidate_codes]

    if min(candidate_errors) < own_error:
        return set()
    return {code for code, err in zip(candidate_codes, candidate_errors) if err == own_error}

In [33]:
n_bits = 8
redundancy = 1
n_gates = redundancy*n_bits - 1

c = cs[n_bits]
# c = torch.Tensor([0, 0, 1])
# c = torch.Tensor([0, 0, 0, 0, 1, 1, 0])
# c = torch.Tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])

# All 2**n_bits input strings; candidate circuits see them duplicated `redundancy`
# times (data_r), but the target `c` has only n_bits - 1 gates and must be evaluated
# on the original strings. Previously `data` was overwritten when redundancy > 1,
# which fed the target the wrong width.
data = torch.tensor(generate_all_binary_vectors(n_bits))
data_r = torch.cat((data,)*redundancy, dim=1)

# Row i encodes integer i (binary counting order), so errors[i] is indexable by code.
circuits = torch.tensor(generate_all_binary_vectors(n_gates))

# The target outputs do not depend on the candidate circuit: compute them once.
target = circuit_output(torch.t(data), c)
errors = torch.zeros(len(circuits))
for i in range(len(circuits)):
    output = circuit_output(torch.t(data_r), circuits[i])
    errors[i] = (output - target).abs().sum()  # number of inputs where the circuits disagree

# local_min = torch.Tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
# neighbor_is = binary_to_decimal(bitstring_neighbors(local_min))
# print(errors[binary_to_decimal(local_min)])
# [errors[n_i] for n_i in neighbor_is]

### Find strict (isolated) local minima

In [34]:
# Print every circuit whose error is strictly below all of its single-gate-flip neighbours.
for i, circuit in enumerate(circuits):
    if is_local_min(circuit.unsqueeze(0), errors, strict=True):
        print(i, circuit, errors[i])

13 tensor([0, 0, 0, 1, 1, 0, 1]) tensor(0.)
30 tensor([0, 0, 1, 1, 1, 1, 0]) tensor(34.)
35 tensor([0, 1, 0, 0, 0, 1, 1]) tensor(72.)
41 tensor([0, 1, 0, 1, 0, 0, 1]) tensor(52.)
67 tensor([1, 0, 0, 0, 0, 1, 1]) tensor(72.)
73 tensor([1, 0, 0, 1, 0, 0, 1]) tensor(52.)
108 tensor([1, 1, 0, 1, 1, 0, 0]) tensor(94.)


### Find clusters (equal-error plateaus) of local minima

In [35]:
# cohort_is[i]: the codes of circuit i and its single-flip neighbours with the same
# error if i is a non-strict local minimum, otherwise the empty set.
cohort_is = [get_equivalent_neighbors_if_min(circuits[i:(i + 1)], errors) for i in range(len(circuits))]

original_cohort_is = copy.deepcopy(cohort_is)

# A plateau is a local-minimum basin only if none of its members has a strictly
# better neighbour. Clear any cohort that touches a non-minimum and repeat until
# nothing changes (a fixed number of passes can stop early on wide plateaus).
changed = True
while changed:
    changed = False
    for i in range(len(circuits)):
        if cohort_is[i] and any(not cohort_is[j] for j in cohort_is[i]):
            cohort_is[i] = set()
            changed = True

# Label each connected plateau of surviving minima with the index of its lowest
# member ('' for circuits outside every cluster). After the loop above the
# cohort relation is symmetric, so a traversal from any member finds the whole plateau.
cluster_ids = [''] * len(circuits)
for i in range(len(circuits)):
    if not cohort_is[i] or cluster_ids[i]:
        continue
    label = str(i)
    cluster_ids[i] = label
    stack = [i]
    while stack:
        node = stack.pop()
        for j in cohort_is[node]:
            if not cluster_ids[j]:
                cluster_ids[j] = label
                stack.append(j)

zero_cluster_ids = set()
min_cluster_ids = set()
for ci in set(cluster_ids):
    if ci != '':
        if errors[int(ci)] == 0:
            zero_cluster_ids.add(ci)
        else:
            min_cluster_ids.add(ci)

cluster_size_dict = freq_dict(cluster_ids)
min_cluster_errors = [errors[int(ci)].item() for ci in min_cluster_ids]
min_cluster_sizes = [cluster_size_dict[mci] for mci in min_cluster_ids]
zero_cluster_sizes = [cluster_size_dict[zci] for zci in zero_cluster_ids]

print(str(len(set(zero_cluster_ids))) + " zero clusters with sizes " + str(zero_cluster_sizes))
print(str(len(set(min_cluster_ids))) + " min clusters with sizes " + str(min_cluster_sizes) + " and errors " + str(min_cluster_errors))

print(cluster_ids)
print(cohort_is)


1 zero clusters with sizes [1]
6 min clusters with sizes [1, 1, 1, 1, 1, 1] and errors [52.0, 72.0, 52.0, 34.0, 94.0, 72.0]
['', '', '', '', '', '', '', '', '', '', '', '', '', '13', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '30', '', '', '', '', '35', '', '', '', '', '', '41', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '67', '', '', '', '', '', '73', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '108', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '']
[set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), {13}, set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), set(), {30}, set(), set(), set(), set(), {35}, set(), set(), set(), set(), set(), {41}, set(), set(), set(), set(), set(), set(), set(), set(), s

## Explore the neighborhood of a vertex

In [57]:
# x below has 15 = 2*8 - 1 gates, so this exploration is for 8 input bits duplicated
# twice. With n_bits = 16 the circuit would need 31 gates and evaluation fails.
n_bits = 8
redundancy = 2
n_gates = redundancy*n_bits - 1

c = cs[n_bits]

data = torch.tensor(generate_all_binary_vectors(n_bits))
data_r = torch.cat((data,)*redundancy, dim=1)

x = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0]])
assert x.shape[1] == n_gates, f"x has {x.shape[1]} gates but the circuit needs {n_gates}"

output = circuit_output(torch.t(data_r), x[0])
target = circuit_output(torch.t(data), c)
print((output - target).abs().sum())

# Error of every circuit that differs from x in exactly one gate.
neighbors = bitstring_neighbors(x)
neighbor_errors = []
for neighbor in neighbors:
    output = circuit_output(torch.t(data_r), neighbor)
    neighbor_errors.append((output - target).abs().sum().item())

print(neighbor_errors)

tensor(34.)
[88.0, 88.0, 48.0, 48.0, 34.0, 34.0, 34.0, 34.0, 124.0, 76.0, 34.0, 34.0, 108.0, 34.0, 116.0]


tensor([0., 0., 0., 1., 1., 0., 1.])