# Count how many distinct linear and quadratic operations there are

## Authors:
- **David W. Hogg** (NYU) (MPIA) (Flatiron)
- **Soledad Villar** (JHU)

## To-do items and bugs:
- Do something.

In [None]:
import itertools as it
import numpy as np
import geometric as geom
import finufft
import pylab as plt
%load_ext autoreload
%autoreload 2

In [None]:
D = 2  # number of spatial dimensions
N = 3  # image side length in pixels (also used as the filter size M below)

## Set up the group

In [None]:
# Enumerate the group operators geom provides for dimension D and report
# how many there are.
group_operators = geom.make_all_operators(D)
operator_count = len(group_operators)
print(operator_count)

In [None]:
# Run geom's consistency check on the group operators built above
# (what exactly it verifies is defined in geom).
geom.test_group(group_operators)

In [None]:
# Run geom's check of how the group operators act (see geom.test_group_actions).
geom.test_group_actions(group_operators)

## Make the invariant geometric filters

In [None]:
# For each filter size M (currently just N), build every family of
# group-invariant filters keyed by (D, M, k, parity), give each filter a
# display name, and record in maxn[(D, M)] the size of the largest family.
allfilters = {}
names = {}
maxn = {}
tensor_orders = (0, 1, 2)
parities = (0, 1)
for M in [N, ]:
    largest = 0
    for k, parity in it.product(tensor_orders, parities):
        key = (D, M, k, parity)
        family = geom.get_unique_invariant_filters(M, k, parity, D, group_operators)
        allfilters[key] = family
        family_size = len(family)
        largest = max(largest, family_size)
        names[key] = ["{} {}".format(geom.ktensor.name(k, parity), i) for i in range(family_size)]
    maxn[(D, M)] = largest

In [None]:
# Plot each filter family. Only the (D, M) prefix of the key is needed to
# look up maxn, so the key is not unpacked into D, M, k, parity: doing that
# silently overwrote those notebook-level globals.
for key, family in allfilters.items():
    fig = geom.plot_filters(family, names[key], maxn[key[:2]])

## Make an input vector image

In [None]:
# Make a sensible vector image on a D-torus: i.i.d. Gaussian components
# (optionally smoothed through finufft when D == 2), rescaled to unit RMS,
# then wrapped with geom.geometric_image(package, 0, D) and normalized.
np.random.seed(42)
if D == 2:
    imagex = np.random.normal(size=(N, N))
    imagey = np.random.normal(size=(N, N))
    package = np.zeros((N, N, D))
    filtered = False  # True if you want the image to be "smooth".
    if filtered:
        # N evenly spaced pixel-centre coordinates in (-pi, pi). linspace
        # guarantees exactly N samples; np.arange with a float step does not.
        foo = np.pi * np.linspace(-1. + 1. / N, 1. - 1. / N, N)
        # indexing="ij": xs varies along axis 0 and ys along axis 1, matching
        # how the images are indexed. These are the same arrays the former
        # `ys, xs = np.meshgrid(foo, foo)` produced.
        xs, ys = np.meshgrid(foo, foo, indexing="ij")
        x_pts, y_pts = xs.flatten(), ys.flatten()
        n_modes = (12, 12)  # Fourier modes requested from the type-1 NUFFT
        ftx = finufft.nufft2d1(x_pts, y_pts, imagex.flatten().astype(complex), n_modes)
        fty = finufft.nufft2d1(x_pts, y_pts, imagey.flatten().astype(complex), n_modes)
        package[:, :, 0] = finufft.nufft2d2(x_pts, y_pts, ftx).reshape(N, N).real
        package[:, :, 1] = finufft.nufft2d2(x_pts, y_pts, fty).reshape(N, N).real
    else:
        package[:, :, 0] = imagex
        package[:, :, 1] = imagey
elif D == 3:
    package = np.random.normal(size=(N, N, N, D))
else:
    # Previously an unsupported D fell through and failed later with a
    # NameError on `package`.
    raise ValueError("only D = 2 and D = 3 are supported, got D = {}".format(D))
package /= np.sqrt(np.mean(package ** 2))
print(package.shape)
vector_image = geom.geometric_image(package, 0, D).normalize()
print(vector_image)

In [None]:
# Plot the input vector image; this is only done in the 2D case.
if D == 2:
    fig = geom.plot_image(vector_image)

## Consider the linear case

In [None]:
# Which vector images can be built linearly from this vector image?
# Route 1: convolve it with each scalar (k = 0, parity 0) filter.
M = N  # filter size; the cells below reuse M
key = (D, M, 0, 0)  # D M k parity
v_images = []
for scalar_filter in allfilters[key]:
    v_images.append(vector_image.convolve_with(scalar_filter).normalize())
print(len(v_images))

In [None]:
# 2. Convolve with pseudoscalar filters and Levi-Civita contract.
# A single-index Levi-Civita contraction maps the resulting pseudovector back
# to a vector only for D == 2 (the same restriction route 4.A uses). In D == 3
# it would presumably produce a rank-2 image, which does not belong in
# v_images and would break the datablock below.
# At M = 3 there are no pseudoscalar filters, so this may add nothing.
if D == 2:
    key = (D, M, 0, 1)  # D M k parity
    v_images += [vector_image.convolve_with(ff).levi_civita_contract(0).normalize() for ff in allfilters[key]]
print(len(v_images))

In [None]:
# 3. Convolve with 2-tensor filters, then contract index 0 with index 1
#    (all filters), then index 0 with index 2 (all filters).
key = (D, M, 2, 0)  # D M k parity
tensor_filters = allfilters[key]
for other_index in (1, 2):
    v_images += [
        vector_image.convolve_with(ff).contract(0, other_index).normalize()
        for ff in tensor_filters
    ]
print(len(v_images))

In [None]:
# 4.A (D == 2 only) Convolve with 2-pseudotensor filters, Levi-Civita
#     contract index 0, then contract index 0 with index 1 (all filters),
#     then index 0 with index 2 (all filters).
if D == 2:
    key = (D, M, 2, 1)  # D M k parity
    for other_index in (1, 2):
        v_images += [
            vector_image.convolve_with(ff).levi_civita_contract(0).contract(0, other_index).normalize()
            for ff in allfilters[key]
        ]
    print(len(v_images))

In [None]:
# 4.B (D == 3 only) Convolve with pseudovector filters, then Levi-Civita
#     contract indices 0 and 1 together (i.e., a cross product).
if D == 3:
    key = (D, M, 1, 1)  # D M k parity
    pseudovector_filters = allfilters[key]
    v_images.extend(
        vector_image.convolve_with(ff).levi_civita_contract([0, 1]).normalize()
        for ff in pseudovector_filters
    )
    print(len(v_images))

In [None]:
# Stack every linear output image as one flattened row and count how many
# are linearly independent (singular values above geom.TINY). Only the
# singular values are needed, so U and V^T are not computed.
datablock = np.array([im.unpack().flatten() for im in v_images])
print(datablock.shape)
s = np.linalg.svd(datablock, compute_uv=False)
print("there are", np.sum(s > geom.TINY), "different images at N = M =", N, "to linear order")

## Now consider the quadratic case

In [None]:
# Gather every filter family into one flat list (in allfilters order), then
# show each filter with its tensor order k and parity.
flat_list_of_filters = [ff for family in allfilters.values() for ff in family]
for ff in flat_list_of_filters:
    print(ff, ff.k, ff.parity)

In [None]:
def make_all_quadratic_outputs(A, filters):
    """
    Return every order-1 image built quadratically from the image A.

    For each ordered triple (C1, C2, C3) drawn from `filters`, form
    B3 = (A.convolve_with(C1) * A.convolve_with(C2)).convolve_with(C3).
    If B3 can reach order 1 (its order is odd) and has even parity, apply
    every sequence of pairwise index contractions (all pairs i < j at each
    stage) until it has order 1, and collect the results. Results are not
    normalized and duplicates are kept.

    Parameters
    ----------
    A : geometric image
        Input image. It appears twice in each product, so its order enters
        as 2 * A.k (A.k == 1 for the vector image used in this notebook).
    filters : iterable of geometric filters
        Each must expose `.k` and `.parity`. Consumed once.

    Returns
    -------
    list
        Images of order 1, in triple order, then contraction order.

    ## Bugs:
    - Not tested properly.
    - Does this need to consider all Levi-Civita contractions too?
    """
    filters = list(filters)  # indexed below; also accepts a generator

    # Each A.convolve_with(C) is reused across many triples. Compute it at
    # most once per filter, and lazily, so filters that never appear in a
    # surviving triple are never convolved.
    convolved = {}

    def _convolve_A(idx):
        if idx not in convolved:
            convolved[idx] = A.convolve_with(filters[idx])
        return convolved[idx]

    output = []
    for i1, i2, i3 in it.product(range(len(filters)), repeat=3):
        C1, C2, C3 = filters[i1], filters[i2], filters[i3]
        # Orders add under convolution and under `*`; A enters twice.
        k = 2 * A.k + C1.k + C2.k + C3.k
        # A also enters twice, so its own parity cancels mod 2.
        p = (C1.parity + C2.parity + C3.parity) % 2
        if k % 2 == 0:
            continue  # pairwise contractions can never reach order 1
        if p == 1:
            continue  # would be a pseudovector; not handled (see Bugs)
        B3 = (_convolve_A(i1) * _convolve_A(i2)).convolve_with(C3)
        assert B3.k == k
        assert B3.parity == p
        current = [B3]
        while k > 1:
            # For each index pair (i < j, lexicographic), contract every
            # image at the current order; each stage lowers the order by 2.
            current = [B.contract(i, j)
                       for i, j in it.combinations(range(k), 2)
                       for B in current]
            k -= 2
        assert k == 1
        output += current
    return output

In [None]:
# Build every quadratic order-1 output of the input vector image over all
# filter triples, and report how many were produced (duplicates included;
# make_all_quadratic_outputs does not deduplicate).
v_images_2 = make_all_quadratic_outputs(vector_image, flat_list_of_filters)
print(len(v_images_2))

In [None]:
# Count how many of the quadratic outputs are linearly independent.
# np.linalg.svd's default full_matrices=True would allocate U with shape
# (n_images, n_images), which is impractically large for the number of
# images produced above. Only the singular values are needed for the rank
# count, so compute_uv=False keeps memory proportional to the datablock.
# NOTE(review): unlike the linear outputs, these images are not normalized,
# so the absolute geom.TINY threshold may need revisiting.
datablock = np.array([im.unpack().flatten() for im in v_images_2])
print(datablock.shape)
s = np.linalg.svd(datablock, compute_uv=False)
print("there are", np.sum(s > geom.TINY),
      "different images at N = M =", N,
      "to quadratic order")