## Challenges to solve
1. Generalization of hypercomplex multiplication
2. Speed up the construction of the kernel matrix

## The breakpoint
https://www.johndcook.com/blog/2018/07/10/cayley-dickson/

1. A recursive function makes hypercomplex multiplication easy, since each algebra is a pair of elements of the one below it
    1. Complex = (Real, Real)
    2. Quaternion = (Complex, Complex)
    3. Octonion = (Quaternion, Quaternion)
    etc.
    

2. The standard multiplication and conjugation rules of complex numbers apply recursively

    z1 = a1 + j b1 = (a1, b1)
    z2 = a2 + j b2 = (a2, b2)
    
    z1' = a1 - j b1 = (a1, -b1)
    z2' = a2 - j b2 = (a2, -b2)
    
    z1 * z2 = (a1 a2 - b1 b2, a1 b2 + a2 b1)
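
For reference, the general rule behind the recursion (the convention implemented by `CayleyDickson` below) treats every number as a pair of lower-level numbers:

$$
(a, b)^* = (a^*, -b), \qquad (a, b)\,(c, d) = (ac - d^* b,\; da + b c^*)
$$

For real components the conjugate is the identity, so with $a = a_1$, $b = b_1$, $c = a_2$, $d = b_2$ this reduces to the complex rule $z_1 z_2 = (a_1 a_2 - b_2 b_1,\; b_2 a_1 + b_1 a_2)$ above. At higher levels the order of the factors and the placement of the conjugates matter, because multiplication stops being commutative (quaternions) and then associative (octonions).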

In [None]:
import numpy as np


def conj(x):
    # Conjugate: keep the real part (index 0), negate every imaginary part
    xstar = -x
    xstar[0] *= -1
    return xstar


def CayleyDickson(x, y):
    # Product of two hypercomplex numbers stored as arrays of length 2**k
    n = len(x)

    if n == 1:
        return x*y

    m = n // 2  # each half holds one lower-level number

    a, b = x[:m], x[m:]    # split x into its "real" and "imaginary" halves
    c, d = y[:m], y[m:]    # split y the same way
    z = np.zeros(n)
    z[:m] = CayleyDickson(a, c) - CayleyDickson(conj(d), b)  # "real" half: ac - d* b
    z[m:] = CayleyDickson(d, a) + CayleyDickson(b, conj(c))  # "imaginary" half: da + b c*
    return z
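
A quick numerical sanity check, offered as a sketch: the cell repeats `conj` and the recursion (renamed `cayley_dickson` so it runs standalone), then checks the quaternion relations ij = k and ji = -k, builds the left-multiplication matrix of w column by column as w·e_j, and confirms that octonion norms still multiply. `basis` and `left_mult_matrix` are helper names introduced here for illustration, not part of `fast_hypercomplex`.

```python
import numpy as np


def conj(x):
    # Conjugate: keep the real part, negate every imaginary part
    xstar = -x
    xstar[0] *= -1
    return xstar


def cayley_dickson(x, y):
    # Same recursion as CayleyDickson above, repeated so this cell runs on its own
    n = len(x)
    if n == 1:
        return x * y
    m = n // 2
    a, b = x[:m], x[m:]
    c, d = y[:m], y[m:]
    z = np.zeros(n)
    z[:m] = cayley_dickson(a, c) - cayley_dickson(conj(d), b)
    z[m:] = cayley_dickson(d, a) + cayley_dickson(b, conj(c))
    return z


def basis(n, k):
    # Hypothetical helper: k-th unit vector of an n-dimensional hypercomplex number
    e = np.zeros(n)
    e[k] = 1.0
    return e


def left_mult_matrix(w):
    # Hypothetical helper: column j holds w * e_j, so the result @ y equals w * y
    n = len(w)
    return np.stack([cayley_dickson(w, basis(n, col)) for col in range(n)], axis=1)


# Quaternion units: i * j = k and j * i = -k (multiplication is not commutative)
i, j = basis(4, 1), basis(4, 2)
print(cayley_dickson(i, j), cayley_dickson(j, i))

# Left-multiplication matrix for w = (1, 2, 3, 4)
print(left_mult_matrix(np.array([1.0, 2.0, 3.0, 4.0])))

# Octonions (n = 8) are still normed: |x y| == |x| |y| up to rounding
rng = np.random.default_rng(0)
x, y = rng.normal(size=8), rng.normal(size=8)
print(np.linalg.norm(cayley_dickson(x, y)), np.linalg.norm(x) * np.linalg.norm(y))
```

For w = (1, 2, 3, 4) the printed matrix has the same sign pattern as the `get_hmat(4)` output further down, i.e. the "matrix of components" is the left-multiplication matrix of w.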

## And the story continues

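- Why not use symbolic maths to carry the subdivision down to lower-level hypercomplex numbers?
- With this we get the matrix of components needed for a hypercomplex product. For a sedenion S (16 components), splitting into octonions O, quaternions Q, complex numbers C and finally reals R gives: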
    - S = [O, O]
        
        = [[Q, Q], [Q, Q]]
        
        = [[[C, C], [C, C]],[[C, C], [C, C]] ]
        
        = [[[[R, R], [R, R]], [[R, R], [R, R]]],[[[R, R], [R, R]], [[R, R], [R, R]]] ]
        
- No need to panic: symbolic maths covers all of this.
    - Hence our "util" functions, which perform all sorts of operations following the breakpoint technique, but symbolically.

In [4]:
from fast_hypercomplex.utils import get_hmat, get_comp_mat
from fast_hypercomplex import HyperConv2d

from torch import nn
import torch

In [13]:
get_hmat(4)

[['w0', '-w1', '-w2', '-w3'],
 ['w1', 'w0', '-w3', 'w2'],
 ['w2', 'w3', 'w0', '-w1'],
 ['w3', '-w2', 'w1', 'w0']]

In [15]:
get_comp_mat(2**8)

array([[   0,   -1,   -2, ..., -253, -254, -255],
       [   1,    0,   -3, ..., -252, -255,  254],
       [   2,    3,    0, ...,  255, -252, -253],
       ...,
       [ 253,  252, -255, ...,    0,   -3,    2],
       [ 254,  255,  252, ...,    3,    0,   -1],
       [ 255, -254,  253, ...,   -2,    1,    0]])
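
Since every entry in these matrices is a single ±w_p, the same pattern can also be obtained without a computer-algebra system, by multiplying basis elements with the Cayley-Dickson rule. The sketch below is our own alternative and not necessarily how `fast_hypercomplex.utils` computes `get_comp_mat`; `basis_product` and `signed_index_matrix` are hypothetical names. Entry [r][q] stores +p or -p when column q of the left-multiplication matrix holds +w_p or -w_p in row r, the encoding printed above. Index 0 never needs a sign, because w0 only appears on the diagonal and always with a plus.

```python
def basis_product(p, q, n):
    """Return (sign, r) such that e_p * e_q = sign * e_r in dimension n (a power of 2).

    Applies (a, b)(c, d) = (ac - d* b, da + b c*) to basis elements: e_p with
    p < n/2 is (e_p, 0) and e_p with p >= n/2 is (0, e_{p - n/2}).
    """
    if n == 1:
        return 1, 0                        # 1 * 1 = 1
    m = n // 2
    if p < m and q < m:                    # (e_p, 0)(e_q, 0) = (e_p e_q, 0)
        return basis_product(p, q, m)
    if p < m:                              # (e_p, 0)(0, d) = (0, d e_p)
        s, r = basis_product(q - m, p, m)
        return s, m + r
    if q < m:                              # (0, b)(e_q, 0) = (0, b e_q*), e_q* = -e_q unless q == 0
        s, r = basis_product(p - m, q, m)
        return (s if q == 0 else -s), m + r
    s, r = basis_product(q - m, p - m, m)  # (0, b)(0, d) = (-d* b, 0)
    return (-s if q == m else s), r        # -d* b = -(d b) if d is the real unit, else +(d b)


def signed_index_matrix(n):
    """[r][q] = +p or -p where the left-multiplication matrix of w holds +w_p or -w_p."""
    mat = [[0] * n for _ in range(n)]
    for q in range(n):
        for p in range(n):
            s, r = basis_product(p, q, n)
            mat[r][q] = s * p              # p == 0 only lands on the diagonal, with s == +1
    return mat


for row in signed_index_matrix(4):
    print(row)
# [0, -1, -2, -3]
# [1, 0, -3, 2]
# [2, 3, 0, -1]
# [3, -2, 1, 0]
```

The unsigned part works out to r XOR q, so only the signs actually need the recursion. For n = 2**8 the function reproduces the corner entries printed above, for example row 1 ending in -252, -255, 254.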

In [7]:
conv2  = nn.Conv2d(32, 64, kernel_size=3)
hconv2 = HyperConv2d(32, 64, kernel_size=3, n_divs=4)

In [8]:
# weight shape: (N_out, N_in, k, k)
conv2.weight.shape

torch.Size([64, 32, 3, 3])

In [9]:
# weight shape: (n_divs, N_out/n_divs, N_in/n_divs, k, k)
hconv2.weight.shape

torch.Size([4, 16, 8, 3, 3])
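
The two shapes are related by a block matrix: the dense (64, 32, 3, 3) kernel is a 4 × 4 grid of (16, 8, 3, 3) blocks, and block (r, q) is +weight[p] or -weight[p] according to the component matrix. Below is a minimal loop-based sketch of that assembly in NumPy, assuming the signed-index encoding printed by `get_comp_mat` (transcribed for n_divs = 4 from the `get_hmat(4)` output above). `naive_hyper_kernel` and `COMP_MAT_4` are names introduced here for illustration, and whether `HyperConv2d` assembles its kernel exactly this way is an assumption.

```python
import numpy as np

# Signed component indices for n_divs = 4, read off the get_hmat(4) output:
# entry +p means "+w_p", -p means "-w_p" (index 0 only appears as +w0)
COMP_MAT_4 = np.array([[0, -1, -2, -3],
                       [1, 0, -3, 2],
                       [2, 3, 0, -1],
                       [3, -2, 1, 0]])


def naive_hyper_kernel(weights, comp_mat):
    """Assemble a dense (n*O, n*I, k, k) kernel from (n, O, I, k, k) components, block by block."""
    n, out_c, in_c = weights.shape[:3]
    kernel = np.zeros((n * out_c, n * in_c) + weights.shape[3:], dtype=weights.dtype)
    for r in range(n):
        for q in range(n):
            p = comp_mat[r, q]
            sign = -1.0 if p < 0 else 1.0
            kernel[r * out_c:(r + 1) * out_c, q * in_c:(q + 1) * in_c] = sign * weights[abs(p)]
    return kernel


weights = np.random.default_rng(0).normal(size=(4, 16, 8, 3, 3)).astype(np.float32)
kernel = naive_hyper_kernel(weights, COMP_MAT_4)
print(kernel.shape)  # (64, 32, 3, 3), the same shape as conv2.weight
```

The component tensor stores 4 · 16 · 8 · 3 · 3 = 4608 weights, a quarter of the 64 · 32 · 3 · 3 = 18432 in `conv2.weight`; the loop spends n_divs² slice copies to expand them.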

## The real deal for speedup

In [None]:
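import torch
from einops import rearrange
from fast_hypercomplex.utils import get_comp_mat


def fast_hypercomplex(weights, n_divs=4, comp_mat=None):
    """
    The constructed 'hamilton' W is a modified version of the hypercomplex representation:
    weights of shape (n_divs, O, I, k, k) become one (n_divs*O, n_divs*I, k, k) kernel.
    comp_mat holds signed component indices (see get_comp_mat); -p selects -weights[p].
    """
    if comp_mat is None:
        comp_mat = get_comp_mat(n_divs)

    # Append -w_{n-1}, ..., -w_1 so that a negative index -p wraps around to -weights[p]
    weights_new = torch.cat([weights, -torch.flipud(weights[1:])], dim=0)
    # Gather one block per comp_mat entry, then tile the blocks into a block matrix
    kernel = rearrange(weights_new[rearrange(comp_mat, 'a b -> (a b)')], '(a b) o i ... -> (a o) (b i) ...',
                       a=comp_mat.shape[0])
    return kernel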
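
If `einops` is not available, the same construction can be written with plain PyTorch indexing. The sketch below is our reading of `fast_hypercomplex`, not a drop-in replacement shipped with the library: it builds the stacked `[w_0, ..., w_{n-1}, -w_{n-1}, ..., -w_1]` tensor, maps each signed index to a row with `torch.remainder` (the explicit form of the negative-index wraparound the original relies on), and replaces both `rearrange` calls with one `permute` plus `reshape`. `fast_hyper_kernel_torch` and `COMP_MAT_4` are names introduced here for illustration.

```python
import torch

# Signed component indices for n_divs = 4 (transcribed from the get_hmat(4) output)
COMP_MAT_4 = torch.tensor([[0, -1, -2, -3],
                           [1, 0, -3, 2],
                           [2, 3, 0, -1],
                           [3, -2, 1, 0]])


def fast_hyper_kernel_torch(weights, comp_mat):
    """einops-free sketch: one gather plus a permute/reshape, no per-block loop."""
    n = comp_mat.shape[0]
    out_c, in_c = weights.shape[1], weights.shape[2]
    # Rows 0..n-1 hold +w_0..+w_{n-1}; rows n..2n-2 hold -w_{n-1}..-w_1,
    # so row (p mod (2n-1)) is the signed component for every index p in comp_mat
    signed = torch.cat([weights, -torch.flip(weights[1:], dims=[0])], dim=0)
    blocks = signed[torch.remainder(comp_mat, 2 * n - 1)]         # (n, n, O, I, ...)
    blocks = blocks.permute(0, 2, 1, 3, *range(4, blocks.dim()))  # (n, O, n, I, ...)
    return blocks.reshape(n * out_c, n * in_c, *weights.shape[3:])


weights = torch.randn(4, 16, 8, 3, 3)
kernel = fast_hyper_kernel_torch(weights, COMP_MAT_4)
print(kernel.shape)  # torch.Size([64, 32, 3, 3])
```

One gather and one reshape stand in for the n_divs² block copies of a per-block loop, which is presumably where the speedup comes from.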