In [373]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [239]:
n = 2    # number of inputs
h = 2    # width of hidden layer
l = n*h + h + h + 1    # number of network hyperparameters
k = 3    # base (e.g. 3 if hyperparameters take values -1,0,1)

In [381]:
# Define functions

def binary(n, length=1):
    if n == 0:
        return length*'0'
    nums = []
    while n:
        n, r = divmod(n, 2)
        nums.append(str(r))
    if length > len(nums):
        for i in range(length-len(nums)):
            nums.append('0')
    return ''.join(reversed(nums))

def ternary(n, length=1):
    if n == 0:
        return length*'0'
    nums = []
    while n:
        n, r = divmod(n, 3)
        nums.append(str(r))
    if length > len(nums):
        for i in range(length-len(nums)):
            nums.append('0')
    return ''.join(reversed(nums))

def hamming_pair(string1, string2):
    if len(string1) != len(string2):
        raise Exception('Input string sizes do not match')
        
    if string1 == string2:
        return False
    
    count_diffs = 0
    for a, b in zip(string1, string2):
        if a!=b:
            if count_diffs:
                return False
            else:
                count_diffs += 1
    return True

def assign_weights(func_num_tern, n, h, l, k):
    
    if len(func_num_tern) != l:
        raise Exception('Function length and number of hyperparameters do not match')
    
    W1 = torch.zeros((h,n), dtype=int)
    b1 = torch.zeros((h,1), dtype=int)
    W2 = torch.zeros((1,h), dtype=int)
    b2 = torch.zeros((1,1), dtype=int)
    
    digit_ind = 0
    for i in range(h):
        for j in range(n):
            W1[i,j] = int(func_num_tern[digit_ind]) - (k-2)
            digit_ind += 1
    for i in range(h):
        b1[i,0] = int(func_num_tern[digit_ind]) - (k-2)
        digit_ind += 1
    for i in range(h):
        W2[0,i] = int(func_num_tern[digit_ind]) - (k-2)
        digit_ind += 1
    b2[0,0] = int(func_num_tern[digit_ind]) - (k-2)
    
    return [W1, b1, W2, b2]

In [389]:
# Define neural network

class Net(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.W1 = params[0]
        self.b1 = params[1]
        self.W2 = params[2]
        self.b2 = params[3]
        self.relu = nn.ReLU()
    
    def forward(self, x):
        out = self.W1@x + self.b1
        out = self.relu(out)
        out = self.W2@out + self.b2
        out = self.relu(torch.sign(out))
        return out

In [374]:
func_dict = {}

for func_num in tqdm(range(k**l)):
    func_num_tern = ternary(func_num, l)
    params = assign_weights(func_num_tern, n, h, l, k)
    bool_func = Net(params)
    func_true = []
    
    for func_in in range(2**n):
        
        x = torch.zeros((n,1), dtype=int)
        func_in_bin = binary(func_in, n)
        
        digit_ind = 0
        for i in range(n):
            x[i,0] = int(func_in_bin[digit_ind])
            digit_ind += 1
       
        func_true.append(str(bool_func.forward(x).item()))
    
    func_dict[func_num] = ''.join(func_true)

100%|███████████████████████████████████████████████████████████████████████████| 19683/19683 [00:24<00:00, 799.97it/s]


In [399]:
freq = []
rho = []
LZ = []

NC = []
for i in range(2**(2**n)):
    NC.append({})

count = 0
for i in tqdm(func_dict):
    for func_true in range(2**(2**n)):
        if func_dict[i] == binary(func_true, 2**n):
            com_ind = []
            for key in NC[func_true]:
                if hamming_pair(ternary(i, l), ternary(key, l)):
                    com_ind.append(NC[func_true][key])
                else:
                    pass
            if len(com_ind)>0:
                NC[func_true][i] = min(com_ind)
                for key in NC[func_true]:
                    if NC[func_true][key] in com_ind:
                        NC[func_true][key] = min(com_ind)
            else:
                NC[func_true][i] = count
                count += 1
        else:
            pass

100%|████████████████████████████████████████████████████████████████████████████| 19683/19683 [14:10<00:00, 23.14it/s]


KeyboardInterrupt: 

In [401]:
G_N = []
for i in range(2**(2**n)):
    G_N.append({})
        
ind = []
for func_true in range(2**(2**n)):
    ind.append(sorted(set(NC[func_true].values())))

for true_func in range(2**(2**n)):
    for val in ind[true_func]:
        G_N[true_func][val] = [0,0]
        for key_i in tqdm(NC[true_func]):
            if NC[true_func][key_i] == val:
                G_N[true_func][val][0] += 1
                for key_j in NC[true_func]:
                    if NC[true_func][key_j] == val and key_j > key_i:
                        G_N[true_func][val][1] = G_N[true_func][val][1] + hamming_pair(ternary(key_i, l), ternary(key_j, l))
                    else:
                        pass
            else:
                pass

for func_true in range(2**(2**n)):
    for key in G_N[true_func]:
        rho_val = (2*float(G_N[func_true][key][1]))/(l*(k-1)*float(G_N[func_true][key][0]))
        rho.append(rho_val)
        freq_val = G_N[func_true][key][0]
        freq.append(freq_val)

  1%|▋                                                                              | 93/10752 [00:21<41:39,  4.26it/s]


KeyboardInterrupt: 