# One-Hot Kernel Encoding
Based on [Structured Variationally Auto-encoded Optimization (Lu et al., 2018)](http://proceedings.mlr.press/v80/lu18c/lu18c.pdf)

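Suppose we have a set of base kernels $\mathcal{B} = \{A, B, C\}$ and a set of operations $\mathcal{O} = \{+, \times, \text{Stop}\}$. Applying each base kernel to each of the $D$ input dimensions gives the expanded set

$\hat{\mathcal{B}} = \{A_1, A_2, \ldots, A_D, B_1, B_2, \ldots, B_D, C_1, C_2, \ldots, C_D\}$,

which can be laid out with one row per dimension and one column per base kernel: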


$$
\begin{bmatrix}
A_1 & B_1 & C_1 \\
A_2 & B_2 & C_2 \\
\vdots & \vdots & \vdots \\
A_D & B_D & C_D
\end{bmatrix}
$$

We one-hot encode both kernels and operations. Representing a kernel applied to a single dimension takes $|\mathcal{B}|D$ bits, and representing an operation takes $|\mathcal{O}|$ bits.
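As a quick check of these bit counts, the sketch below builds one-hot vectors with a small helper. The helper name `one_hot` and the choice $D = 8$ are illustrative assumptions, not something taken from the paper.

```python
# Illustrative helper (an assumption, not from the paper):
# a one-hot vector represented as a list of 0/1 integers.
def one_hot(index, size):
    """Return a list of `size` bits with a single 1 at 0-based `index`."""
    if not 0 <= index < size:
        raise ValueError("index out of range")
    return [1 if i == index else 0 for i in range(size)]

base_kernels = ["A", "B", "C"]
operations = ["+", "*", "Stop"]
D = 8  # assumed number of input dimensions

kernel_bits = len(base_kernels) * D  # |B| * D bits per kernel
operation_bits = len(operations)     # |O| bits per operation

print(kernel_bits, operation_bits)
print(one_hot(operations.index("*"), operation_bits))
```

With three base kernels and $D = 8$, each kernel occupies 24 bits and each operation 3 bits.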

Any expression $S$ is transformed into a binary vector by concatenating, in order, the one-hot vectors of its kernels and operations. Once the Stop operation is reached, the rest of the vector is filled with zeros. For example, let $D=8$ and let $N_{max}$ be the maximum number of operations.

$A_1 + B_2 \times C_8$ Stop ...

100000000000000000000000 100 000000000100000000000000 010 000000000000000000000001 001 000
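The concatenation shown above can be sketched as follows. The function name `encode_expression`, the input format (a list of `(family, dim, op)` steps), and the total length of $N_{max}(|\mathcal{B}|D + |\mathcal{O}|)$ bits are assumptions made for illustration; only the ordering of the blocks is taken from the example.

```python
BASE_KERNELS = ["A", "B", "C"]
OPERATIONS = ["+", "*", "Stop"]

def one_hot_str(index, size):
    """One-hot bit string of length `size` with a 1 at 0-based `index`."""
    return "".join("1" if i == index else "0" for i in range(size))

def encode_expression(steps, D, n_max):
    """Encode steps [(family, dim, op), ...] (1-based dim) as a bit string.

    Assumption: the output is zero-padded to n_max * (|B| * D + |O|) bits.
    """
    kernel_bits = len(BASE_KERNELS) * D
    step_bits = kernel_bits + len(OPERATIONS)
    if not steps or steps[-1][2] != "Stop":
        raise ValueError("the last operation must be 'Stop'")
    if len(steps) > n_max:
        raise ValueError("expression has more than n_max steps")
    chunks = []
    for family, dim, op in steps:
        if not 1 <= dim <= D:
            raise ValueError("dim must be between 1 and D")
        kernel_index = BASE_KERNELS.index(family) * D + (dim - 1)
        chunks.append(one_hot_str(kernel_index, kernel_bits))
        chunks.append(one_hot_str(OPERATIONS.index(op), len(OPERATIONS)))
    # Everything after Stop is padded with zeros.
    return "".join(chunks).ljust(n_max * step_bits, "0")

code = encode_expression(
    [("A", 1, "+"), ("B", 2, "*"), ("C", 8, "Stop")], D=8, n_max=4
)
print(len(code))
print(code[:27])
```

Under these assumptions, each step contributes 27 bits, and the unused fourth step is all zeros.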

Kernel encoding for $D = 8$ (bit blocks ordered $A$, $B$, $C$):

$A_1: 1000 0000 0000 0000 0000 0000$

$A_2: 0100 0000 0000 0000 0000 0000$

$B_1: 0000 0000 1000 0000 0000 0000$

$B_2: 0000 0000 0100 0000 0000 0000$

$C_1: 0000 0000 0000 0000 1000 0000$

$C_2: 0000 0000 0000 0000 0100 0000$

$C_8: 0000 0000 0000 0000 0000 0001$
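Reading the table in the other direction, the position of the single set bit determines both the kernel family and the dimension. A minimal sketch, assuming the block order $A$, $B$, $C$ used above; the function name `decode_kernel` is hypothetical.

```python
BASE_KERNELS = ["A", "B", "C"]

def decode_kernel(bits, D):
    """Map a one-hot bit string (spaces ignored) to (family, 1-based dim)."""
    bits = bits.replace(" ", "")
    valid = (
        len(bits) == len(BASE_KERNELS) * D
        and set(bits) <= {"0", "1"}
        and bits.count("1") == 1
    )
    if not valid:
        raise ValueError("expected a one-hot string of length |B| * D")
    position = bits.index("1")                # 0-based, counted from the left
    family_idx, dim_idx = divmod(position, D)  # block index, offset in block
    return BASE_KERNELS[family_idx], dim_idx + 1

print(decode_kernel("0000 0000 0100 0000 0000 0000", D=8))
print(decode_kernel("0000 0000 0000 0000 0000 0001", D=8))
```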

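In [10]:
# One-hot encode the operations as fixed-width 3-bit strings.
# format(..., '03b') keeps leading zeros, which bin() would drop.
add  = format(0b100, '03b')  # '100'
mult = format(0b010, '03b')  # '010'
stop = format(0b001, '03b')  # '001'

kernel_families = ['A', 'B', 'C']
D = 4

def encode_kernel(family, dim, D=D):
    # Families occupy consecutive blocks of D bits, A first.
    # Positions are 1-based and counted from the most significant bit.
    if not 1 <= dim <= D:
        raise ValueError('dim must be between 1 and D')
    n_bits = len(kernel_families) * D
    position = kernel_families.index(family) * D + dim
    return 0b1 << (n_bits - position)

In [11]:
n_bits = len(kernel_families) * D
for family in kernel_families:
    for d in range(1, D + 1):
        kern_encoding = encode_kernel(family, d)
        print(family + str(d) + ':', format(kern_encoding, '0' + str(n_bits) + 'b'))
    print('')

A1: 100000000000
A2: 010000000000
A3: 001000000000
A4: 000100000000

B1: 000010000000
B2: 000001000000
B3: 000000100000
B4: 000000010000

C1: 000000001000
C2: 000000000100
C3: 000000000010
C4: 000000000001



In [12]:
# The worked example uses D = 8, so that C8 is a valid kernel.
D_example = 8
width = len(kernel_families) * D_example
A1 = format(encode_kernel('A', 1, D_example), '0' + str(width) + 'b')
B2 = format(encode_kernel('B', 2, D_example), '0' + str(width) + 'b')
C8 = format(encode_kernel('C', 8, D_example), '0' + str(width) + 'b')
print('A1 + B2 * C8 =', A1, add, B2, mult, C8, stop, '000')

A1 + B2 * C8 = 100000000000000000000000 100 000000000100000000000000 010 000000000000000000000001 001 000


In [13]:
add

'100'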
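Since `encode_kernel` returns integers, an alternative to joining strings is to pack the fields into a single integer with shifts. The helper `pack_fields` below is a hypothetical addition, shown as a sketch for $D = 8$ with the kernel value written out directly rather than computed by the notebook's function.

```python
def pack_fields(fields):
    """Concatenate (value, width) pairs into one int, first field most significant.

    Returns the packed integer and its total bit width.
    """
    packed, total_width = 0, 0
    for value, width in fields:
        if not 0 <= value < (1 << width):
            raise ValueError("value does not fit in the given width")
        packed = (packed << width) | value
        total_width += width
    return packed, total_width

D = 8
kernel_width = 3 * D          # three base kernels, D dimensions each
A1 = 1 << (kernel_width - 1)  # leftmost bit set, matching A_1 in the table
ADD = 0b100                   # one-hot code for "+"

packed, width = pack_fields([(A1, kernel_width), (ADD, 3)])
print(format(packed, "0{}b".format(width)))
```

Tracking the total width alongside the integer matters because leading zeros are otherwise lost when the value is printed.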