In [None]:
# default_exp sigmoid

# Sigmoid

In [None]:
# export
import numpy as np

import torch.nn as nn
import torch

from sklearn.tree import BaseDecisionTree

from typing import List, Tuple
from functools import partial

## Bit matcher

In [None]:
# export
def shift_bit_eps(bit: int, eps: float = 0.5) -> float:
    """Move a single bit away from its value by a margin controlled by `eps`.

    A 0 is mapped to ``1 - eps`` and a 1 is mapped to ``eps``.

    Raises:
        AssertionError: if `bit` is neither 0 nor 1.
    """
    assert bit in (0, 1), "Bit must be 0 or 1"
    slope = 2 * eps - 1
    # Same left-to-right evaluation as ``slope * bit + 1 - eps`` to keep float results identical
    return slope * bit + 1 - eps

def create_base_vectors(circuit: List[int], eps: float = 0.5) -> List[List[float]]:
    """Build one vector per position of `circuit`, each with exactly that bit shifted.

    The i-th returned list equals ``list(circuit)`` except that element i is
    replaced by ``shift_bit_eps(circuit[i], eps=eps)``.
    """
    bits = list(circuit)
    return [
        bits[:pos] + [shift_bit_eps(bit, eps=eps)] + bits[pos + 1:]
        for pos, bit in enumerate(bits)
    ]

In [None]:
from fastcore.test import test_eq, test_fail

# Default eps=0.5 maps both bit values to the midpoint
test_eq(shift_bit_eps(0), 0.5)
test_eq(shift_bit_eps(1), 0.5)
# Non-default eps (chosen exactly representable in binary): 0 -> 1-eps, 1 -> eps
test_eq(shift_bit_eps(0, eps=0.25), 0.75)
test_eq(shift_bit_eps(1, eps=0.25), 0.25)
# Values other than 0/1 are rejected
test_fail(lambda: shift_bit_eps(2), contains="Bit must be 0 or 1")

# Each base vector is the circuit with exactly one bit shifted
bits = [0,0]
test_eq(create_base_vectors(bits), [[0.5, 0], [0, 0.5]])
test_eq(create_base_vectors([1, 0, 1], eps=0.25), [[0.25, 0, 1], [1, 0.75, 1], [1, 0, 0.25]])

In [None]:
# export
def create_linear_system(vectors: List[List[float]]) -> Tuple:
    """Turn base vectors into a square system ``X @ W = y`` for a hyperplane.

    The hyperplane is parameterised with its last coefficient fixed to 1, so for
    each vector ``v`` the unknowns satisfy ``W[:-1] @ v[:-1] + W[-1] = -v[-1]``.

    Returns:
        ``(X, y)``: ``X`` is the vectors with their last column replaced by ones
        (the bias column); ``y`` is the negated original last column.
    """
    points = np.array(vectors)
    rhs = -points[:, -1]
    bias_column = np.ones((points.shape[0], 1), dtype=points.dtype)
    design = np.hstack([points[:, :-1], bias_column])
    return design, rhs

In [None]:
# export
class BitComparison(nn.Module):
    """Linear model whose output is positive only for one specific binary circuit.

    For a target bit sequence ``a`` of length ``n`` the weights ``w`` and bias ``c``
    are chosen so that ``w @ x + c > 0`` for ``x == a`` and ``w @ x + c < 0`` for every
    other ``x`` in ``{0, 1}^n``. The separating hyperplane is the one passing through
    the ``n`` points obtained from ``a`` by shifting a single bit with ``shift_bit_eps``;
    its last weight is normalised to +/-1.

    Args:
        target: bit sequence (0/1 values) that should produce a positive output.
        eps: shift applied to each bit when building the hyperplane; must lie strictly
            between 0 and 1 (at 0 neighbours land on the boundary, at 1 the system is
            singular, outside the interval the orientation is wrong).

    Raises:
        ValueError: if `target` is empty or `eps` is not strictly between 0 and 1.
    """
    def __init__(self, target: List[int], eps: float = 0.5):
        super().__init__()

        n = len(target)
        if n == 0:
            raise ValueError("target must contain at least one bit")
        if not 0 < eps < 1:
            raise ValueError(f"eps must be strictly between 0 and 1, got {eps}")

        vectors = create_base_vectors(target, eps=eps)
        X, y = create_linear_system(vectors)
        # Solution holds the first n-1 weights followed by the bias; the last weight is fixed to 1
        W = np.linalg.solve(X, y)
        w = np.concatenate([W[:-1], np.ones(1)])
        c = W[-1]

        # Fixing the last weight to +1 puts the target on the negative side whenever its
        # last bit is 0, so flip the whole hyperplane in that case.
        if not target[-1]:
            w = -w
            c = -c

        self.n = n
        self.linear = nn.Linear(n, 1)

        self.linear.weight.data = torch.tensor(w.reshape(1, -1)).float()
        self.linear.bias.data = torch.tensor(c).unsqueeze(0).float()

    def forward(self, x):
        """Return the raw linear score; `x` must have size `n` in its last dimension."""
        return self.linear(x)

    def __repr__(self):
        """Render the hyperplane as ``w_0*x_0 + ... + w_{n-1}*y + c = 0``.

        The last input coordinate is printed as ``y`` (it plays the role of the
        dependent variable in the linear system used to build the weights).
        """
        output = ""
        for i in range(self.n):
            if i < self.n - 1:
                output += f"{self.linear.weight.data[0][i]}*x_{i} + "
            else:
                output += f"{self.linear.weight.data[0][i]}*y + "
        output += f"{self.linear.bias.data[0]} = 0"
        return output

Given a bit sequence $a \in \{0,1\}^n$, we are interested in finding a linear separator $(W,b)$ such that $Wx + b > 0$ for $x = a$, and $Wx + b < 0$ for all $x \in \{0,1\}^n \setminus \{a\}$.

For the target $a=(1,1)$, one can check that the following linear function outputs a positive number if and only if the input bit circuit is $(1,1)$ (the printed equation names the last input coordinate $y$):

In [None]:
BitComparison([1,1])

1.0*x_0 + 1.0*y + -1.5 = 0

In [None]:
import itertools

def create_test_cases_x(n):
    """Enumerate every 0/1 vector of length `n`, one per row.

    For ``n >= 0`` the result is an integer array of shape ``(2**n, n)`` in
    ``itertools.product`` order (last coordinate varies fastest).
    """
    axes = [(0, 1)] * n
    return np.array(list(itertools.product(*axes)))

def create_test_cases_y(x, target):
    """Label each row of `x` with 1 if it equals `target` element-wise, else 0.

    Args:
        x: 2-D array-like of candidate bit vectors, one per row.
        target: bit vector with the same length as a row of `x`.

    Returns:
        1-D integer array with one label per row of `x`.
    """
    # Require every column to match instead of comparing a match count to a global `n`,
    # so the function is self-contained and works for any row width.
    return np.all(np.asarray(x) == np.asarray(target), axis=1).astype(int)

In [None]:
eps = 0.5
n = 7

# Every possible circuit of length n must be matched exactly by its own BitComparison
x = create_test_cases_x(n)
x_tensor = torch.tensor(x).float()  # loop-invariant input batch, built once

for target in x:
    y = torch.tensor(create_test_cases_y(x, target))
    bitcomparison = BitComparison(target, eps=eps)
    predicted = (bitcomparison.linear(x_tensor) > 0).view(-1)
    accuracy = (y == predicted).float().mean().item()
    test_eq(accuracy, 1)

## Sigmoid utils

In [None]:
# export
def sigmoid_path_to_weight(path, nodes2idx, eps=0.5):
    """Embed the BitComparison for a path into a weight vector over all nodes.

    Args:
        path: iterable of ``(node, bit)`` pairs; ``bit`` (0 or 1) is the value the
            BitComparison should match at that node. Presumably decision-tree node ids
            (the file imports ``BaseDecisionTree``) -- TODO confirm against caller.
        nodes2idx: mapping from node to its index in the full weight vector; its
            length ``K`` sets the size of the returned weights.
        eps: shift forwarded to ``BitComparison``.

    Returns:
        ``(w, b)``: ``w`` is a float64 array of shape ``(K,)`` holding the
        BitComparison weights at the path's node indices and zeros elsewhere;
        ``b`` is the BitComparison bias as a numpy array of shape ``(1,)``.
    """
    # Materialise once: `path` is traversed twice below, which would silently
    # yield nothing the second time for a one-shot iterable such as a generator.
    steps = list(path)

    # Target bit sequence of the BitComparison, in path order
    bits = [bit for _, bit in steps]
    bit_comparison = BitComparison(bits, eps=eps)

    # Positions in the full vector that receive the comparison weights
    idx = [nodes2idx[node] for node, _ in steps]

    K = len(nodes2idx)
    w = np.zeros(K)

    w[idx] = bit_comparison.linear.weight.data.numpy().reshape(-1)
    b = bit_comparison.linear.bias.data.numpy()

    return w, b