In [None]:
import itertools

import sympy as sm
import numpy as np

# Turn on sympy's pretty-printing so expressions shown as cell results are
# rendered rather than printed as plain repr strings.
sm.init_printing()


class GeometricAlgebra(object):
    """
    Symbolic geometric (Clifford) algebra on non-commutative sympy symbols.

    Multivectors are plain sympy expressions in the basis vectors. Products
    are brought to a canonical form by repeatedly applying the rewrite rules
    stored in `cayley_table` (see `simp`).

    Attributes:
        cayley_table (dict): rewrite rules {product of bases: replacement}.
        blades (list of lists): blades[k] holds the grade-k basis blades.
        blades_list (list): every basis blade, flattened in grade order.
    """

    def __init__(self, bases, cayley_table):
        """
        Args:
            bases (sequence): basis vectors (non-commutative sympy symbols).
            cayley_table (dict): multiplication table applied by `simp`.
        """
        # Multiplication table (used for automatic simplification)
        self.cayley_table = cayley_table

        # Generate blades within this dimension
        self.blades = self.create_blades(bases)

        # Flat list of all basis blades, lowest grade first
        self.blades_list = sum(self.blades, [])

    @classmethod
    def generate(cls, p, m, z, start_inx=0):
        """
        See section (5.4) Sylvester signature theorem
            https://bivector.net/PROJECTIVE_GEOMETRIC_ALGEBRA.pdf

        Generate a symmetric bilinear form of dimension n = p + m + z with an
        orthogonal basis:

        e_i.dot(e_j) = 0   for i != j

        The basis vectors are named e<start_inx> ... e<start_inx + n - 1>.
        The first z of them square to 0, the next p to +1 and the last m
        to -1.

        Args:
            p (int): number of basis vectors with e_i.dot(e_i) = +1
            m (int): number of basis vectors with e_i.dot(e_i) = -1
            z (int): number of basis vectors with e_i.dot(e_i) =  0
            start_inx (int): index used in the name of the first basis vector

        Returns:
            An instance of `cls` whose Cayley table encodes the squares above
            and rewrites out-of-order pairs e_i * e_j (i > j) to -e_j * e_i.
        """
        assert (p >= 0) and (m >= 0) and (z >= 0)
        n = p + m + z
        e = sm.symbols('e{}:{}'.format(start_inx, start_inx + n),
                       commutative=False)

        subs = {}
        squares = [0] * z + [1] * p + [-1] * m
        for i in range(n):
            # A basis vector times itself collapses to its (scalar) square
            subs[e[i] * e[i]] = squares[i]
            for j in range(i):
                # Distinct basis vectors anticommute; sort into ascending order
                subs[e[i] * e[j]] = -e[j] * e[i]

        return cls(e, subs)

    @staticmethod
    def create_blades(bases):
        """
        Build the basis blades of every grade from the given basis vectors.

        Returns:
            A list indexed by grade; entry `dim` is the list of products of
            each `dim`-element combination of `bases`, in the order produced
            by itertools.combinations. Entry 0 is the empty product.
        """
        blades = []
        for dim in range(len(bases) + 1):
            # Blades are products of all unique combinations of dim # of bases
            inds_list = itertools.combinations(range(len(bases)), dim)
            blades.append(
                [sm.S(sm.prod([bases[i] for i in inds])) for inds in inds_list])
        return blades

    def simp(self, a, collect=True):
        """
        Reduce an expression to canonical form with the Cayley table.

        The expression is expanded and the table applied (as a simultaneous
        substitution) over and over until nothing changes any more.

        Args:
            a: expression to simplify.
            collect (bool): if True, group the terms of the result by basis
                blade (blades are handed to sympy highest grade first).

        Returns:
            The simplified expression.
        """
        while True:
            b = sm.expand(a).subs(self.cayley_table, simultaneous=True)
            if b == a:
                break
            a = b

        if collect:
            b = sm.collect(b, sum(reversed(self.blades), []))

        return b

    def blade(self, dim, coeffs):
        """
        Linear combination of the grade-`dim` basis blades.

        Args:
            dim (int): grade, strictly between 0 and the number of grades.
            coeffs (sequence): one coefficient per grade-`dim` basis blade,
                in the order of `self.blades[dim]`.
        """
        assert (dim > 0) and (dim < len(self.blades))
        assert len(coeffs) == len(self.blades[dim])
        return sum(basis * coeff
                   for basis, coeff in zip(self.blades[dim], coeffs))

    @property
    def pseudoscalar(self):
        """The single basis blade of highest grade."""
        assert len(self.blades[-1]) == 1
        return self.blades[-1][0]

    def coeffs(self, a):
        """
        Split an expression into one coefficient per basis blade.

        Each additive term of `a` is attributed to the highest-grade basis
        blade whose basis vectors all occur in the term.

        NOTE: presumably `a` must already be simplified and expanded so that
        every term holds exactly one basis blade - confirm before passing raw
        products.

        Returns:
            A list parallel to `self.blades_list`.
        """
        blade_bases = [{e for e in blade.free_symbols}
                       for blade in self.blades_list]

        coeffs = [sm.S(0)] * len(blade_bases)
        for k, v in sm.S(a).as_coefficients_dict().items():
            # Search from the highest grade down so that e.g. a term in e1*e2
            # is not mistaken for a term in e1.
            for blade_inx in reversed(range(len(blade_bases))):
                if blade_bases[blade_inx].issubset(k.free_symbols):
                    coeffs[blade_inx] += k.subs(self.blades_list[blade_inx], 1) * v
                    break

        return coeffs

    def from_coeffs(self, coeffs):
        """Inverse of `coeffs`: rebuild an expression from per-blade coefficients."""
        assert len(coeffs) == len(self.blades_list)
        return sum(k * v for k, v in zip(self.blades_list, coeffs))

    def dual(self, a):
        """
        Map each blade coefficient onto the blade at the mirrored position.

        The i-th coefficient of `a` is paired with the i-th blade counted from
        the end of `self.blades_list`; no sign is applied.
        """
        return sum(k * v for k, v in zip(reversed(self.blades_list), self.coeffs(a)))

    def reverse(self, a):
        """
        Reverse the order of the basis vectors in every term of `a`.

        A term containing g distinct basis vectors picks up the sign
        (-1) ** (g * (g - 1) / 2), the parity of the swaps needed to reverse
        g anticommuting factors.
        """
        result = 0
        for k, v in sm.S(a).as_coefficients_dict().items():
            grade = len(k.free_symbols.intersection(self.blades[1]))
            # Number of pairwise swaps needed to reverse `grade` factors
            num_flips = grade * (grade - 1) // 2
            if num_flips % 2 == 1:
                result += -k * v
            else:
                result += k * v

        return self.simp(result)

    def wedge(self, a, b):
        """
        Wedge (upward hat) operator, outer product, (exterior product).

        Both operands are split into basis blades. Two basis blades that
        share a basis vector contribute nothing; otherwise their wedge equals
        their geometric product, which `simp` brings to canonical form.

        NOTE: this relies on distinct basis vectors anticommuting, as in the
        tables built by `generate`.
        """
        # The expanded (uncollected) form keeps one basis blade per term,
        # which is what `coeffs` works on.
        a_coeffs = self.coeffs(self.simp(a, collect=False))
        b_coeffs = self.coeffs(self.simp(b, collect=False))
        bases = [blade.free_symbols for blade in self.blades_list]

        result = sm.S(0)
        for blade_a, bases_a, coeff_a in zip(self.blades_list, bases, a_coeffs):
            # Explicit comparison: coefficients may be symbolic expressions
            if coeff_a == 0:
                continue
            for blade_b, bases_b, coeff_b in zip(self.blades_list, bases, b_coeffs):
                if coeff_b == 0 or not bases_a.isdisjoint(bases_b):
                    continue
                result += coeff_a * coeff_b * blade_a * blade_b

        return self.simp(result)

    def regressive(self, a, b):
        """Regressive product: the dual of the wedge of the duals."""
        return self.dual(self.wedge(self.dual(a), self.dual(b)))

    def commutator(self, a, b):
        """
        Commutator product (a * b - b * a) / 2.

        For two vectors this is their bivector (plane) part.
        """
        return self.simp(((a * b) - (b * a)) / sm.S(2))

    # aliases
    outer = wedge
    join = regressive

    # NOTE: an alternative definition, meet = wedge, was sketched here earlier
    def meet(self, a, b):
        """Dot product of the dual of `a` with `b`."""
        return self.dot(self.dual(a), b)

    # TODO(hayk): Inner/outer are grade selecting?

    def sandwich(self, a, b):
        """Sandwich product a * b * reverse(a)."""
        return self.simp(a * b * self.reverse(a))

    # -----------------------------------------------------------------------
    # WIP
    # -----------------------------------------------------------------------

    def dot(self, a, b):
        """Symmetric product (a * b + b * a) / 2."""
        # TODO hayk assert a and b are vectors.
        # For two vectors, this is the inner product
        return self.simp(((a * b) + (b * a)) / sm.S(2))

    # TODO replaced by sandwich
    def reflect(self, a, r):
        """Compute -r * a * r."""
        return self.simp(-r * a * r)

    # TODO replaced by sandwich
    # rot: ba * v * ab
    def rotate(self, a, b, v):
        """Compute a * b * v * b * a."""
        return self.simp(a * b * v * b * a)


class ProjectiveGeometry3D(GeometricAlgebra):
    """
    Geometric algebra with four basis vectors e0..e3, as built by `create`:
    e0 squares to 0 and e1, e2, e3 square to +1.
    """

    @classmethod
    def create(cls):
        """Build the algebra with signature p=3, m=0, z=1."""
        # Go through cls so that a subclass gets an instance of its own type
        return cls.generate(3, 0, 1)

    # -----------------------------------------------------------------------
    # Pretty sure are correct
    # -----------------------------------------------------------------------

    # A plane is defined using its homogeneous equation ax + by + cz + d = 0
    def plane(self, a, b, c, d):
        """Grade-1 element d*e0 + a*e1 + b*e2 + c*e3."""
        return self.blade(1, (d, a, b, c))

    # Vector is the same as a plane through the origin
    def vec(self, x, y, z):
        """Grade-1 element x*e1 + y*e2 + z*e3 (a plane with d = 0)."""
        return self.plane(x, y, z, 0)

    # Homogeneous point
    def point(self, x, y, z):
        """
        Grade-3 element with coefficients (x, y, z, 1) on the grade-3 basis
        blades, taken in the order of `self.blades[3]`.
        """
        return self.blade(3, (x, y, z, 1))

    # -----------------------------------------------------------------------
    # WIP
    # -----------------------------------------------------------------------

    def inverse(self, a):
        """
        Compute reverse(a) / (reverse(a) * a).

        NOTE: presumably only meaningful when reverse(a) * a simplifies to a
        non-zero scalar - confirm for the inputs you use.
        """
        rev = self.reverse(a)
        return self.simp(rev / self.simp(rev * a))

#     def dual(self, a):
#         # NOTE this is not the pseudoscalar e0123 but e123..
#         return self.simp(a * self.inverse(self.blades[3][-1]))

#     def dual4(self, a):
#         # NOTE needs inverse, but in 3d it's the same
#         return self.simp(a * self.pseudoscalar)

    def cross(self, a, b):
        """Dual of the wedge of `a` and `b`."""
        # See equation 4.5 (page 31):
        #    http://www.jaapsuter.com/geometric-algebra.pdf
        return self.dual(self.wedge(a, b))

    def line_from_vec(self, vec):
        """Multiply `vec` from the left by e1 * e2 * e3 and simplify."""
        _, e1, e2, e3 = self.blades[1]
        return self.simp(e1 * e2 * e3 * vec)

    def vec_from_line(self, line):
        """Multiply `line` from the left by -e1 * e2 * e3 and simplify."""
        _, e1, e2, e3 = self.blades[1]
        return self.simp(-e1 * e2 * e3 * line)

    def rotate_axis_angle(self, axis, angle, v):
        """
        Apply (c + s*L) * v * (c - s*L), where c and s are the cosine and
        sine of half of `angle` and L = line_from_vec(axis).

        Args:
            axis: grade-1 element (see `vec`) giving the rotation axis.
            angle: rotation angle; presumably in radians, since it is fed
                straight to sympy's cos/sin.
            v: element to transform.
        """
        half_angle = angle / sm.S(2)
        scalar = sm.cos(half_angle)
        bivector = sm.sin(half_angle) * self.line_from_vec(axis)
        return self.simp((scalar + bivector) * v * (scalar - bivector))

    # Axis must be a LINE here
    def rotor(self, axis, angle):
        """Return cos(angle / 2) + sin(angle / 2) * axis (not simplified)."""
        half_angle = sm.S(angle) / 2
        return sm.cos(half_angle) + sm.sin(half_angle) * axis


In [None]:
# Build the four-basis-vector projective algebra and bind every basis blade,
# grade by grade, to a short name used by the cells below.
A = ProjectiveGeometry3D.create()
e0, e1, e2, e3 = A.blades[1]
e01, e02, e03, e12, e13, e23 = A.blades[2]
e012, e013, e023, e123 = A.blades[3]
# The highest grade holds a single blade (the pseudoscalar)
e0123, = A.blades[4]

In [None]:
# Plane with symbolic coefficients: a = 3x, b = 0, c = 1, d = y**2
expr = A.plane(sm.Symbol('x') * 3, 0, 1, sm.Symbol('y')**2)
expr

In [None]:
# Grade-3 basis blades
A.blades[3]

In [None]:
# Two orderings of the same three basis vectors; e2*e3*e1 is a cyclic
# (even) reordering of e1*e2*e3, so the difference is expected to cancel.
A.simp(e1 * e2 * e3 - e2 * e3 * e1)

In [None]:
# Algebra with three basis vectors named e1..e3 (start_inx=1), all squaring
# to +1.
# NOTE(review): the e1, e2, e3 used below are the names unpacked from A
# earlier; presumably they equal B's symbols because sympy symbols with the
# same name and assumptions compare equal - confirm.
B = GeometricAlgebra.generate(3, 0, 0, start_inx=1)
B.blades

In [None]:
# Three symbolic vectors; the e3 component of each is commented out, so they
# only have e1 and e2 components.
a = sm.Symbol('alpha_1') * e1 + sm.Symbol('alpha_2') * e2# + sm.Symbol('alpha_3') * e3
b = sm.Symbol('beta_1') * e1 + sm.Symbol('beta_2') * e2# + sm.Symbol('beta_3') * e3
c = sm.Symbol('gamma_1') * e1 + sm.Symbol('gamma_2') * e2# + sm.Symbol('gamma_3') * e3
a, b, c

In [None]:
# Outer product of two basis vectors
B.outer(e2, e3)

In [None]:
# Outer product of a basis vector with a product of the other two
B.outer(e1, e2*e3)

In [None]:
# Sum of two cyclic orderings of the same three basis vectors
B.simp(e1 * e2 * e3 + e3 * e1 * e2)

In [None]:
# Outer product of the three symbolic vectors, grouped as (a ^ b) ^ c
B.outer(B.outer(a, b), c)

In [None]:
# NOTE(review): this rebinds A to a plain GeometricAlgebra with five basis
# vectors (p=4, z=1). The ProjectiveGeometry3D helpers (plane, vec, point,
# line_from_vec, ...) do not exist on this object.
A = GeometricAlgebra.generate(4, 0, 1)
A.blades

In [None]:
# Outer product of the symbolic vectors a and b defined in an earlier cell
A.outer(a, b)

In [None]:
# Split the plane expression into one coefficient per basis blade
coeffs = A.coeffs(expr)
coeffs

In [None]:
# Round trip: rebuild the expression from its coefficients
A.from_coeffs(coeffs)

In [None]:
# The same sum as from_coeffs, written out by hand (without its length check)
sum(k * v for k, v in zip(A.blades_list, coeffs))

In [None]:
# NOTE(review): the cells from here on call helpers that only
# ProjectiveGeometry3D defines (vec, plane, point, line_from_vec, ...), so
# they presumably ran with A bound to ProjectiveGeometry3D.create() rather
# than to the generate(4, 0, 1) instance - confirm.
A.line_from_vec(A.vec(0, 0, 1))

In [None]:
# Symbolic vector with components alpha1..alpha3
a = A.vec(*sm.symbols('alpha1:4'))
a

In [None]:
# Symbolic vector with components beta1..beta3
b = A.vec(*sm.symbols('beta1:4'))
b

In [None]:
# Wedge of the two symbolic vectors
val = A.wedge(a, b)
val

In [None]:
# Dual of that wedge (this is what `cross` computes)
A.dual(val)

In [None]:
# meet of the planes x = 0 and z = 0
A.meet(A.plane(1, 0, 0, 0), A.plane(0, 0, 1, 0))

In [None]:
# join of the points (1, 0, 0) and (0, 0, 1)
A.join(A.point(1, 0, 0), A.point(0, 0, 1))

In [None]:
# Two fully symbolic points
P1 = A.point(*sm.symbols('a b c'))
P2 = A.point(*sm.symbols('l m n'))
P1, P2

In [None]:
# Wedge of the two points
A.wedge(P1, P2)

In [None]:
# Regressive product of the two points (join is an alias of regressive)
A.regressive(P1, P2)

In [None]:
# NOTE(review): dual4 only exists as commented-out code in
# ProjectiveGeometry3D, so the dual4 cells below cannot run against the
# class definitions as written.
A.dual4(1)

In [None]:
# Raw (unsimplified) product of e0 with the pseudoscalar
e0 * e0123

In [None]:
A.dual4(e0)

In [None]:
# Regressive product of the symbolic points P1 and P2 (repeat of an earlier cell)
A.regressive(P1, P2)

In [None]:
# Multiply the earlier wedge result by the inverse of e123
foo = A.simp(val * A.inverse(e123))
foo

In [None]:
# Rotor about the line built from the vector (0, 0, 1); the angle is the
# numeric float np.deg2rad(45), so the result has floating-point coefficients.
rotor = A.rotor(A.line_from_vec(A.vec(0, 0, 1)), np.deg2rad(45))
rotor

In [None]:
# Apply the rotor to the vector (1, 0, 0): rotor * v * reverse(rotor)
A.sandwich(rotor, A.vec(1, 0, 0))

In [None]:
# Dual of a single basis vector
A.dual(e1)

In [None]:
# Reverse of a grade-3 basis blade
A.reverse(e012)

In [None]:
# e1 * e2 * reverse(e1)
A.sandwich(e1, e2)

In [None]:
# -e2 * e1 * e2
A.reflect(e1, e2)

In [None]:
# Grade-2 element from six coefficients; blade() asserts that the count
# matches the number of grade-2 basis blades (six for four basis vectors).
A.blade(2, (2, 3, 4, 1, 5, 6))

In [None]:
# Plane x + 1 = 0
A.plane(1, 0, 0, 1)

In [None]:
# Dual of that plane
A.dual(A.plane(1, 0, 0, 1))

In [None]:
A.vec(1, 2, 3)

In [None]:
A.dual(A.vec(0, 1, 0))

In [None]:
# NOTE(review): plane() takes four coefficients (a, b, c, d) but only three
# are passed here, so this call raises TypeError as written.
A.dual(A.plane(1, 0, 1))

In [None]:
A.dual(A.vec(1, 0, 1))

In [None]:
# Dual of the point (1, 0, 0)
A.dual(A.point(1, 0, 0))

In [None]:
A.dual(e0)

In [None]:
A.vec(1, 2, 3)

In [None]:
# Symmetric product of a basis vector with a bivector it does not share a
# basis vector with
A.dot(e3, e12)

In [None]:
# e1 * e2 * e3 times the vector (1, 0, 0)
A.line_from_vec(A.vec(1, 0, 0))

In [None]:
# Inverse direction: -e1 * e2 * e3 times the bivector e2 * e3
A.vec_from_line(e2 * e3)

In [None]:
# lie groups = confusing
# quaternion math = confusing
# handling of infinity in projective geometry = confusing
# cross products = confusing

In [None]:
# Compute -r * a * r with a = (-0.5, 2, 0) and r = (1, 0, 0)
A.reflect(A.vec(-0.5, 2, 0), A.vec(1, 0, 0))

In [None]:
# Compute e1 * e2 * v * e2 * e1 with v = (0, 1, 0)
A.rotate(e1, e2, A.vec(0, 1, 0))

In [None]:
# NOTE(review): plane_from_normal is not defined in the visible class
# definitions; presumably left over from an earlier version.
A.plane_from_normal(e1)

In [None]:
# NOTE(review): normal_from_plane is not defined in the visible class
# definitions either.
A.normal_from_plane(e1 * e2)

In [None]:
# a * b = |a||b| * (cos(ang) + sin(ang) * B)
# where B is the unit bivector of the plane spanned by a and b (the product
# of two orthogonal unit vectors in that plane).
# For unit vectors a, b this reduces to:
# a * b = cos(ang) + sin(ang) * B

In [None]:
# Rotate the vector (1, 0, 0) by 1 degree about the axis (0, 0, 1)
A.rotate_axis_angle(A.vec(0, 0, 1), np.deg2rad(1), A.vec(1, 0, 0))

In [None]:
# Fully symbolic axis-angle rotation (disabled)
# A.rotate_axis_angle(
#     A.vec(*sm.symbols('ax ay az')),
#     sm.Symbol('theta'),
#     A.vec(*sm.symbols('x y z'))
# )

In [None]:
# Two planes: y = 0 and z + 1 = 0
p1 = A.plane(0, 1, 0, 0)
p2 = A.plane(0, 0, 1, 1)

In [None]:
# Wedge of the two planes
A.wedge(p1, p2)

In [None]:
# Homogeneous point (1, 20, 0)
A.point(1, 20, 0)

In [None]:
# Rotate the vector (1, 0, 0) by 45 degrees about the axis (0, 0, 1)
A.rotate_axis_angle(A.vec(0, 0, 1), np.deg2rad(45), A.vec(1, 0, 0))

In [None]:
# Same rotation applied to the point (1, 0, 0) instead of a vector
A.rotate_axis_angle(A.vec(0, 0, 1), np.deg2rad(45), A.point(1, 0, 0))