<a href="https://colab.research.google.com/github/davidwhogg/EmuCosmoSim/blob/main/ipynb/group_averaging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Finding equivariant convolution operators by group averaging

## Authors:
- **David W. Hogg** (NYU) (Flatiron)
- **Soledad Villar** (JHU)

## License
Copyright 2022 the authors. All rights reserved *for now*.

## To-do
- Make the code so it doesn't have a set of `scalar_filter` functions and `vector_filter` functions and so on. The functions should just take `k` as an input.
- Make structure so the code is agnostic about scalar/vector/tensor? That is, such that the objects know their own transformation properties? And norms? And visualization methods?

## Bugs:
- This code transmits `D, M, pixels, keys` as global variables, not by reading them off inputs (or getting them as inputs).
- Haven't figured out yet how to visualize the 2-tensor filters; maybe plot eigenvalues and eigenvectors?
- The group operators should be found by recursion; this ought to be more efficient.
- Fix 3-d plotting so it does a real projection (not just a set of incomprehensible hacks).
- We switched from `M` to `2m+1` for the filter size; make the notation consistent throughout.

In [None]:
import numpy as np
import pylab as plt
import itertools as it
import scipy.signal as sig

In [None]:
# Set global integers used throughout the notebook:
D = 2 # D-dimensional image (must be 2 or 3 for plotting to work)
M = 5 # must be an odd integer; filter side length, so make_zero_filter builds M ** D pixels

# Get the group ready and test it

In [None]:
# Make all possible group generators: one reflection plus one 90-degree
# rotation for every coordinate plane (i, j) with i < j.

# Make the flip operator: diag(-1, 1, ..., 1), a reflection of axis 0
foo = np.ones(D).astype(int)
foo[0] = -1
gg = np.diag(foo).astype(int)
generators = [gg, ]

# Make the 90-degree rotation operators: identity except in the (i, j)
# plane, where basis vector e_i maps to e_j and e_j maps to -e_i
for i in range(D):
    for j in range(i + 1, D):
        gg = np.eye(D).astype(int)
        gg[i, i] = 0
        gg[j, j] = 0
        gg[i, j] = -1
        gg[j, i] = 1
        generators.append(gg)
generators = np.array(generators)  # shape (1 + D * (D - 1) // 2, D, D)

# Look at them
for gg in generators:
    print(gg)

In [None]:
# Make all possible group operators by closing the generator set under
# matrix multiplication (fixed-point iteration starting from the identity).

def make_all_operators(generators):
    """Return every group element generated by `generators`.

    Starting from the identity, the known elements are repeatedly multiplied
    (on the left) by every generator until no new element appears.

    Parameters
    ----------
    generators : array-like of int, shape (n_generators, D, D)

    Returns
    -------
    np.ndarray of int, shape (n_operators, D, D)
        The distinct group elements, sorted as by ``np.unique(..., axis=0)``.
    """
    generators = np.asarray(generators).astype(int)
    # Read the dimension off the generators instead of the notebook-global D.
    dim = generators.shape[-1]
    operators = np.eye(dim).astype(int)[None, :, :]
    count = 0
    while len(operators) != count:
        count = len(operators)
        operators = make_new_operators(operators, generators)
    return operators

def make_new_operators(operators, generators):
    """Return the union of `operators` with every product ``gg @ op``.

    Parameters
    ----------
    operators : array-like of int, shape (n_operators, D, D)
    generators : array-like of int, shape (n_generators, D, D)

    Returns
    -------
    np.ndarray of int
        The distinct matrices, sorted as by ``np.unique(..., axis=0)``.
    """
    operators = np.asarray(operators).astype(int)
    generators = np.asarray(generators).astype(int)
    # Form all products gg @ op at once and deduplicate a single time,
    # rather than calling np.unique after every individual append.
    products = np.einsum("gij,ojk->goik", generators, operators)
    products = products.reshape((-1, ) + operators.shape[1:])
    return np.unique(np.concatenate([operators, products], axis=0), axis=0)

# Close the generators into the full group, then print every element with the
# sign of its determinant (+1: proper rotation, -1: includes a reflection).
group_operators = make_all_operators(generators)
print("I found", len(group_operators), "group operators; here are their determinants:")
for gg in group_operators:
    print(gg, "determinant:", np.linalg.slogdet(gg)[0].astype(int))

In [None]:
# Check that the list of group operators is closed under multiplication.
# NOTE: `matrix in ndarray` evaluates `(ndarray == matrix).any()`, which is True
# as soon as ANY single entry matches, so it cannot test membership of a whole
# matrix. Instead, compare full matrices.
for gg in group_operators:
    for gg2 in group_operators:
        product = (gg @ gg2).astype(int)
        assert np.any(np.all(group_operators == product, axis=(1, 2))), \
            "group operators are not closed under multiplication"
    print(gg, True)

In [None]:
# Check that every group element is orthogonal, i.e. gg.T is the inverse of gg
# (prints each element followed by True/False).
for gg in group_operators:
    print(gg, np.allclose(gg @ gg.T, np.eye(D)))

# Define the geometric objects and classes

In [None]:
class ktensor:
    """One k-tensor (k=0 scalar, k=1 vector, k=2 two-tensor, ...) with a parity.

    Attributes
    ----------
    D : int
        Dimension of the space; must be at least 2.
    parity : int
        +1 or -1. Under a group element gg, parity -1 objects pick up an extra
        factor of sign(det(gg)) (see times_group_element).
    data : scalar or np.ndarray of shape (D,) * k
    k : int
        Tensor order, read off the shape of ``data`` (0 for scalars).
    """

    def __init__(self, data, parity, D=D):
        self.D = D
        assert self.D > 1, \
        "ktensor: geometry makes no sense if D<2."
        self.parity = parity
        assert np.abs(self.parity) == 1, \
        "ktensor: parity must be 1 or -1."
        if len(np.atleast_1d(data)) == 1:
            # A single number is a scalar (k=0); store it as given.
            self.data = data
            self.k = 0
        else:
            self.data = np.array(data)
            # Read k off the converted array so list/tuple input also works.
            self.k = self.data.ndim
            assert np.all(np.array(self.data.shape) == D), \
            "ktensor: shape must be (D, D, D, ...)."

    def __add__(self, other):
        """Return the sum of two k-tensors with equal k, parity and D."""
        assert self.k == other.k, \
        "ktensor: can't add objects of different k"
        assert self.parity == other.parity, \
        "ktensor: can't add objects of different parity"
        assert self.D == other.D, \
        "ktensor: can't add objects of different D"
        return ktensor(self.data + other.data, self.parity, self.D)

    def __mul__(self, other):
        """Return the tensor (outer) product: orders add, parities multiply.

        np.multiply.outer keeps every index separate. np.outer would flatten
        inputs with k >= 2 into a matrix. When either factor is a scalar,
        np.multiply.outer reduces to ordinary multiplication.
        """
        assert self.D == other.D, \
        "ktensor: can't multiply objects of different D"
        return ktensor(np.multiply.outer(self.data, other.data),
                       self.parity * other.parity, self.D)

    def __str__(self):
        return "<k-tensor object with k={} and parity={}>".format(self.k,
                                                                 self.parity)

    def times_group_element(self, gg):
        """Return the image of this tensor under the matrix gg.

        Every one of the k indices is contracted with gg, so a vector v maps
        to gg @ v and a 2-tensor T maps to gg @ T @ gg.T. Pseudo-tensors
        (parity -1) are additionally multiplied by sign(det(gg)).

        Parameters
        ----------
        gg : array-like, shape (D, D), with |det(gg)| == 1

        Returns
        -------
        ktensor with the same k, parity and D.
        """
        # einsum needs a distinct pair of letters per tensor index.
        assert self.k < 14
        gg = np.asarray(gg)
        assert gg.shape == (self.D, self.D)
        sign, logdet = np.linalg.slogdet(gg)
        # |det gg| must be 1; compare with a tolerance, not exact float equality.
        assert np.isclose(logdet, 0.), "ktensor: gg must have |det| == 1"
        if self.k == 0:
            newdata = 1. * self.data
        else:
            firstletters  = "abcdefghijklm"
            secondletters = "nopqrstuvwxyz"
            inputs = firstletters[:self.k]
            outputs = secondletters[:self.k]
            # e.g. k=2: "ab,na,ob->no", i.e. new[n, o] = gg[n, a] gg[o, b] data[a, b]
            einstr = (inputs + ","
                      + ",".join(o + i for o, i in zip(outputs, inputs))
                      + "->" + outputs)
            operands = (self.data, ) + self.k * (gg, )
            newdata = np.einsum(einstr, *operands)
        if self.parity < 0:
            # Not in place: in-place multiply by a float would fail on int data.
            newdata = newdata * sign
        return ktensor(newdata, self.parity, self.D)

In [None]:
class geometric_filter:
    """A filter on an M ** D pixel grid that holds one ktensor per pixel.

    Attributes
    ----------
    D : int
        Spatial dimension.
    M, m : int
        Side length M = 2 m + 1. Pixel offsets run from -m to m along each axis.
    pixels : np.ndarray of int, shape (M ** D, D)
        Pixel offsets, in itertools.product order.
    keys : list of tuple
        The same offsets as tuples, used as the keys of ``data``.
    parity : int
        +1 or -1, shared by every pixel's ktensor.
    data : dict
        Maps each key to a ktensor.
    k : int
        Tensor order of the per-pixel ktensors.
    """

    def make_pixels_and_keys(self):
        """Set ``self.pixels`` and ``self.keys`` from ``self.m`` and ``self.D``."""
        offsets = range(-self.m, self.m + 1)
        # Use this filter's own dimension rather than the notebook-global D.
        self.pixels = np.array(list(it.product(offsets, repeat=self.D))).astype(int)
        self.keys = [tuple(pp) for pp in self.pixels]
        return

    def __init__(self, data, parity, D=D):
        """Build a filter from per-pixel tensors.

        Parameters
        ----------
        data : array-like of length M ** D
            data[i] is the tensor at pixel offset ``self.keys[i]``.
        parity : int
            +1 or -1.
        D : int
            Spatial dimension.
        """
        self.D = D
        self.M = np.round(len(data) ** (1. / D)).astype(int)
        assert len(data) == self.M ** self.D, \
        "geometric_filter: data doesn't seem to be the right length?"
        self.m = (self.M - 1) // 2
        assert self.M == 2 * self.m + 1, \
        "geometric_filter: M needs to be odd."
        self.make_pixels_and_keys()
        self.parity = parity
        self.data = {kk: ktensor(ff, self.parity, self.D)
                     for kk, ff in zip(self.keys, data)}
        self.k = self.data[self.keys[0]].k
        return

    def copy(self):
        """Return an independent copy, rebuilt from the unpacked data."""
        return geometric_filter(self.unpack(), self.parity, self.D)

    def __add__(self, other):
        """Return the pixel-wise sum of two filters with the same D and M."""
        assert self.D == other.D
        assert self.M == other.M
        newfilter = self.copy()
        for kk in self.keys:
            newfilter.data[kk] = newfilter.data[kk] + other.data[kk]
        return newfilter

    def times_group_element(self, gg):
        """Return the image of this filter under the group element gg.

        The result's tensor at offset pp is gg applied to this filter's tensor
        at offset gg.T @ pp (gg.T is the inverse of gg when gg is orthogonal).
        """
        newfilter = self.copy()
        for pp, kk in zip(self.pixels, self.keys):
            newfilter.data[kk] = self.data[tuple(gg.T @ pp)].times_group_element(gg)
        return newfilter

    def unpack(self):
        """Return the per-pixel tensors stacked in key order, shape (M ** D,) + (D,) * k."""
        return np.array([self.data[kk].data for kk in self.keys])

In [None]:
# Visualize (badly) a scalar filter.

FIGSIZE = (4, 3)  # figure size used by setup_plot()
XOFF, YOFF = 0.15, -0.1  # per-unit-z plot offsets that fake a 3-d projection when D == 3
TINY = 1.e-5  # plotting threshold: fill_boxes skips weights whose magnitude does not exceed this

def setup_plot():
    """Start a new figure of size FIGSIZE and return it, so callers can save or close it."""
    return plt.figure(figsize=FIGSIZE)

def finish_plot(title, pixels):
    """Title the current axes, set equal-aspect limits around `pixels`, hide ticks.

    Parameters
    ----------
    title : str
    pixels : array-like of int, shape (n_pixels, D)
        The dimension D is read from this array, not from the notebook-global
        D. For D == 3 the limits are widened to leave room for the faked
        projection offsets. For other D the limits are left unchanged.
    """
    pixels = np.asarray(pixels)
    dim = pixels.shape[-1]
    plt.title(title)
    pad = {2: 0.5, 3: 0.75}.get(dim)
    if pad is not None:
        plt.xlim(np.min(pixels) - pad, np.max(pixels) + pad)
        plt.ylim(np.min(pixels) - pad, np.max(pixels) + pad)
    plt.gca().set_aspect("equal")
    plt.gca().set_xticks([])
    plt.gca().set_yticks([])

def plot_boxes(xs, ys):
    """Outline a unit square (thin black line) centred on each point (x, y)."""
    half = 0.5
    for cx, cy in zip(xs, ys):
        left, right = cx - half, cx + half
        bottom, top = cy - half, cy + half
        plt.plot([left, left, right, right, left],
                 [bottom, top, top, bottom, bottom], "k-", lw=0.5)

def fill_boxes(xs, ys, ws):
    """Shade the unit square around each (x, y) in black with opacity 0.1 * |w|.

    Squares whose |w| does not exceed TINY are left unshaded.
    """
    for cx, cy, weight in zip(xs, ys, ws):
        strength = np.abs(weight)
        if not strength > TINY:
            continue
        plt.fill_between([cx - 0.5, cx + 0.5], [cy - 0.5, cy - 0.5], [cy + 0.5, cy + 0.5],
                         color="k", alpha=0.1 * strength)

def plot_scalars(xs, ys, ws):
    """Draw pixel boxes, shade them by |w|, and mark each clearly nonzero weight.

    Weights above TINY get a "+" marker and weights below -TINY get a "_"
    marker. Both are sized proportionally to |w| (scaled by the global M).
    Weights with |w| <= TINY get no marker.
    """
    xs, ys, ws = np.asarray(xs), np.asarray(ys), np.asarray(ws)
    plot_boxes(xs, ys)
    fill_boxes(xs, ys, ws)
    positive = ws > TINY
    # Strictly below -TINY, so zero and tiny positive weights do not get a
    # "_" marker (they would otherwise be drawn with non-positive sizes).
    negative = ws < -TINY
    plt.scatter(xs[positive], ys[positive], marker="+", c="k", s=(1000/M)*ws[positive])
    plt.scatter(xs[negative], ys[negative], marker="_", c="k", s=(-1000/M)*ws[negative])

def plot_scalar_filter(filter, title):
    """Draw a scalar (k=0) geometric_filter in a new figure.

    Only D in [2, 3] is supported. For any other D, a message is printed and
    nothing is drawn. For D == 3 each pixel is shifted by (XOFF, YOFF) per
    unit of z, to fake a projection.
    """
    assert filter.k == 0
    if filter.D not in [2, 3]:
        print("plot_scalar_filter(): Only works for D in [2, 3].")
        return
    setup_plot()
    coords = np.array(filter.pixels, dtype=float)
    weights = np.array([filter.data[kk].data for kk in filter.keys],
                       dtype=float).reshape(-1)
    xs, ys = coords[:, 0], coords[:, 1]
    if filter.D == 3:
        zs = coords[:, 2]
        xs, ys = xs + XOFF * zs, ys + YOFF * zs
    plot_scalars(xs, ys, weights)
    finish_plot(title, filter.pixels)

In [None]:
# Demo: a random scalar filter, 3 pixels on a side (3 ** D pixels in total),
# and its images under every group element.
foo = geometric_filter(np.random.normal(size=3 ** D), 1)
plot_scalar_filter(foo, "foo")
for i, gg in enumerate(group_operators):
    # Raw string: "\c" is not a valid Python escape. Braces keep multi-digit
    # subscripts such as g_{10} together in the mathtext.
    plot_scalar_filter(foo.times_group_element(gg), r"$g_{{{}}}\cdot$foo".format(i))

In [None]:
# Visualize the vector filters.

def plot_vectors(xs, ys, ws):
    """Draw pixel boxes shaded by each vector's L1 norm, plus a centred arrow
    for every vector whose squared norm exceeds TINY.

    Only the first two components of each vector are drawn.
    """
    plot_boxes(xs, ys)
    l1_norms = np.sum(np.abs(ws), axis=-1)
    fill_boxes(xs, ys, l1_norms)
    arrow_style = dict(length_includes_head=True, head_width=0.1, color="k")
    for cx, cy, vec in zip(xs, ys, ws):
        if not np.sum(vec * vec) > TINY:
            continue
        dx, dy = vec[0], vec[1]
        plt.arrow(cx - 0.3 * dx, cy - 0.3 * dy, 0.6 * dx, 0.6 * dy, **arrow_style)

def plot_vector_filter(filter, title):
    """Draw a vector (k=1) geometric_filter in a new figure.

    Only D in [2, 3] is supported. For any other D, a message is printed and
    nothing is drawn. For D == 3 each pixel is shifted by (XOFF, YOFF) per
    unit of z, to fake a projection.
    """
    assert filter.k == 1
    if filter.D not in [2, 3]:
        print("plot_vector_filter(): Only works for D in [2, 3].")
        return
    setup_plot()
    coords = np.array(filter.pixels, dtype=float)
    vectors = np.array([filter.data[kk].data for kk in filter.keys], dtype=float)
    xs, ys = coords[:, 0], coords[:, 1]
    if filter.D == 3:
        zs = coords[:, 2]
        xs, ys = xs + XOFF * zs, ys + YOFF * zs
    plot_vectors(xs, ys, vectors)
    finish_plot(title, filter.pixels)

In [None]:
# Demo: a random vector filter, 3 pixels on a side (3 ** D pixels of D-vectors),
# and its images under the proper rotations (group elements with det > 0).
foo = geometric_filter(np.random.normal(size=(3 ** D, D)), 1)
plot_vector_filter(foo, "foo")
for i, gg in enumerate(group_operators):
    s = np.linalg.slogdet(gg)[0]
    if s > 0:
        # Raw string: "\c" is not a valid Python escape; braces keep g_{10} together.
        plot_vector_filter(foo.times_group_element(gg), r"$g_{{{}}}\cdot$foo".format(i))

# Now start the process of making the invariant filters

In [None]:
# Make geometric filter manipulation functions

def make_zero_filter(k, parity):
    """Return an all-zeros geometric_filter of tensor order k and the given parity.

    The filter uses the notebook-global size: M ** D pixels in D dimensions.
    """
    data = np.zeros((M ** D, ) + k * (D, ))
    return geometric_filter(data, parity, D)

def _transform_with_parity(filter, gg, parity):
    """Apply group element gg to `filter`, treating it as having parity `parity`.

    filter.times_group_element already applies the filter's own parity. When
    the requested parity differs from it, one extra factor of det(gg) is applied.
    """
    newfilter = filter.times_group_element(gg)
    if parity * filter.parity < 0:
        return scalar_multiply(np.linalg.slogdet(gg)[0], newfilter)
    return newfilter

def rotate_vector(filter, gg, parity=1):
    """Apply gg to a vector filter (pixel positions and vector components).

    Pass parity=-1 to transform the filter as a pseudovector.
    """
    return _transform_with_parity(filter, gg, parity)

def rotate_pseudovector(filter, gg):
    """Apply gg to a vector filter, transforming it as a pseudovector."""
    return rotate_vector(filter, gg, -1)

def rotate_2_tensor(filter, gg, parity=1):
    """Apply gg to a 2-tensor filter (each tensor T maps to gg @ T @ gg.T).

    Pass parity=-1 to transform the filter as a pseudo-2-tensor.
    """
    return _transform_with_parity(filter, gg, parity)

def scalar_multiply(scalar, filter):
    """Return a copy of `filter` with every pixel's tensor multiplied by `scalar`."""
    newfilter = filter.copy()
    for kk in filter.keys:
        newfilter.data[kk] = ktensor(scalar * filter.data[kk].data,
                                     filter.parity, filter.D)
    return newfilter

In [None]:
# Make M ** D independent scalar-to-scalar "delta" filters, each with a single
# pixel set to 1. `keys` (the pixel-offset tuples) stays global because later
# cells iterate over it.
keys = make_zero_filter(0, 1).keys
allfilters = []
for kk in keys:
    thisfilter = make_zero_filter(0, 1)
    thisfilter.data[kk].data = 1
    allfilters.append(thisfilter)

In [None]:
# Sum all the group-element-transformed scalar filters and make a matrix of them:
# row i is the unpacked group sum of allfilters[i].
n = len(allfilters)
filter_matrix = np.zeros((n, n)).astype(int)  # sums of 0s and 1s here, so the int cast is exact
for i, f1 in enumerate(allfilters):
    ff = make_zero_filter(0, 1)
    for gg in group_operators:
        ff = ff + f1.times_group_element(gg)
    filter_matrix[i] = ff.unpack()

In [None]:
# What are the unique (group-invariant) filters of a given order k and parity?
def get_unique_filters(k, parity):
    """Return a basis of group-averaged geometric filters of order k and parity.

    Every seed filter (one pixel and one tensor component set to 1) is summed
    over all elements of the global ``group_operators``. The SVD of the matrix
    of flattened sums then gives a basis of the invariant filters.

    Parameters
    ----------
    k : int
        Tensor order (0 scalar, 1 vector, 2 two-tensor, ...).
    parity : int
        +1 or -1.

    Returns
    -------
    list of geometric_filter
        Possibly empty. Each filter is scaled so that its largest-magnitude
        component is +/-1, and is sign-flipped so that its components have a
        non-negative sum.
    """
    # make the seed filters: one per (pixel, tensor component)
    keys = make_zero_filter(k, parity).keys
    allfilters = []
    for kk in keys:
        # for k == 0 this yields a single empty index tuple
        for indices in it.product(range(D), repeat=k):
            thisfilter = make_zero_filter(k, parity)
            if k == 0:
                thisfilter.data[kk].data = 1
            else:
                thisfilter.data[kk].data[indices] = 1
            allfilters.append(thisfilter)
    # do the group averaging, flattening each sum into one matrix row
    filter_matrix = np.zeros((len(allfilters), len(keys) * D ** k))
    for i, f1 in enumerate(allfilters):
        ff = make_zero_filter(k, parity)
        for gg in group_operators:
            ff = ff + f1.times_group_element(gg)
        filter_matrix[i] = ff.unpack().ravel()
    # do the SVD
    _, s, v = np.linalg.svd(filter_matrix)
    TINY = 1.e-5
    sbig = s > TINY
    if not np.any(sbig):
        return []
    # right singular vectors paired with the large singular values
    vbig = v[:len(s)][sbig]
    # normalize the amplitudes so they max out at +/- 1.
    amps = vbig / np.max(np.abs(vbig), axis=1)[:, None]
    # make sure the amps are positive, generally
    for i in range(len(amps)):
        if np.sum(amps[i]) < 0:
            amps[i] *= -1
    # make sure that the zeros are zeros.
    amps[np.abs(amps) < TINY] = 0.
    # reshape each row back to one k-tensor per pixel
    return [geometric_filter(aa.reshape((len(keys), ) + k * (D, )), parity, D)
            for aa in amps]

In [None]:
# Group-invariant scalar (k=0, parity +1) filters, one figure each.
scalar_filters = get_unique_filters(0, 1)
for i, ff in enumerate(scalar_filters):
    plot_scalar_filter(ff, "scalar {}".format(i))

In [None]:
# Group-invariant pseudoscalar (k=0, parity -1) filters, one figure each
# (the list may be empty, in which case nothing is plotted).
pseudoscalar_filters = get_unique_filters(0, -1)
for i, ff in enumerate(pseudoscalar_filters):
    plot_scalar_filter(ff, "pseudoscalar {}".format(i))

In [None]:
# Make M ** D x D independent scalar-to-vector seed filters: one per
# (pixel, vector component), each with a single 1.
allvfilters = []
for kk in keys:
    for i in range(D):
        thisfilter = make_zero_filter(1, 1)
        thisfilter.data[kk].data[i] = 1
        allvfilters.append(thisfilter)

In [None]:
# Sum all the group-element-transformed scalar-to-vector filters and make a matrix of them.
# Assumes allvfilters holds k=1, parity +1 geometric_filter objects.
n = len(allvfilters)
vfilter_matrix = np.zeros((n, n)).astype(int)
for i, f1 in enumerate(allvfilters):
    ff = make_zero_filter(1, 1)
    for gg in group_operators:
        # times_group_element moves the pixels and rotates each vector
        ff = ff + f1.times_group_element(gg)
    vfilter_matrix[i] = ff.unpack().flatten()

In [None]:
# What are the unique scalar-to-vector filters?
def get_unique_vector_filters(matrix, parity):
    """Extract a basis of invariant vector filters from a group-averaged matrix.

    Parameters
    ----------
    matrix : np.ndarray, shape (n, M ** D * D)
        Row i is a flattened, group-averaged seed filter, as built in the
        vfilter_matrix / pvfilter_matrix cells.
    parity : int
        +1 for vector filters, -1 for pseudovector filters.

    Returns
    -------
    list of geometric_filter
        k=1 filters with the given parity, in order of decreasing singular
        value. Empty if no singular value exceeds TINY. Each filter is scaled
        so that its longest vector has unit length.
    """
    _, s, v = np.linalg.svd(matrix)
    TINY = 1.e-5
    sbig = s > TINY
    if not np.any(sbig):
        return []
    nbig = np.sum(sbig).astype(int)
    vecs = v[:len(s)][sbig].reshape((nbig, M ** D, D))
    # pixel offsets, shape (M ** D, D), in the same order as the filter rows
    pixels = make_zero_filter(1, parity).pixels
    # normalize so the biggest vectors are unit vectors
    for i in range(nbig):
        vecs[i] /= np.max([np.linalg.norm(vv) for vv in vecs[i]])
    # make sure the divergences or curls are positive
    if parity > 0:
        for i in range(nbig):
            if np.sum(vecs[i] * pixels) < 0:
                vecs[i] *= -1
    if D == 2 and parity < 0:
        # pixel offsets rotated by 90 degrees: (x, y) -> (-y, x)
        rpixels = np.zeros_like(pixels)
        rpixels[:, 0], rpixels[:, 1] = -1. * pixels[:, 1], pixels[:, 0]
        for i in range(nbig):
            if np.sum(vecs[i] * rpixels) < 0:
                vecs[i] *= -1
    # make sure zeros are exactly zero
    vecs[np.abs(vecs) < TINY] = 0.0
    # NOTE(review): an earlier version passed this list through an
    # `order_filters` helper that is not defined in this notebook.
    return [geometric_filter(vv, parity, D) for vv in vecs]

# Compute and print the invariant vector (parity +1) filters.
vfilters = get_unique_vector_filters(vfilter_matrix, 1)
for ff in vfilters:
    print(ff)

In [None]:
# Sum all the group-element-transformed scalar-to-pseudovector filters and make a matrix of them.
# Assumes allvfilters holds k=1 geometric_filter objects.
n = len(allvfilters)
pvfilter_matrix = np.zeros((n, n)).astype(int)
for i, f1 in enumerate(allvfilters):
    # Re-declare the seed as a pseudovector (parity -1), so that
    # times_group_element applies the extra factor of det(gg).
    pf1 = geometric_filter(f1.unpack(), -1, f1.D)
    ff = make_zero_filter(1, -1)
    for gg in group_operators:
        ff = ff + pf1.times_group_element(gg)
    pvfilter_matrix[i] = ff.unpack().flatten()

In [None]:
# What are the unique scalar-to-pseudovector filters? (parity -1; may be empty)
pvfilters = get_unique_vector_filters(pvfilter_matrix, -1)
for ff in pvfilters:
    print(ff)

In [None]:
# Visualize (badly) the pseudovector filters, one figure each.
for i, ff in enumerate(pvfilters):
    plot_vector_filter(ff, "pseudovector " + str(i))

In [None]:
# Make M ** D x D x D independent scalar-to-2-tensor seed filters: one per
# (pixel, tensor component), each with a single 1.
allttfilters = []
for kk in keys:
    for i, j in it.product(range(D), repeat=2):
        thisfilter = make_zero_filter(2, 1)
        thisfilter.data[kk].data[i, j] = 1
        allttfilters.append(thisfilter)

In [None]:
# Sum all the group-element-transformed scalar-to-2-tensor filters and make a matrix of them.
# Assumes allttfilters holds k=2, parity +1 geometric_filter objects.
n = len(allttfilters)
ttfilter_matrix = np.zeros((n, n)).astype(int)
for i, f1 in enumerate(allttfilters):
    ff = make_zero_filter(2, 1)
    for gg in group_operators:
        # each 2-tensor T maps to gg @ T @ gg.T, at the moved pixel
        ff = ff + f1.times_group_element(gg)
    ttfilter_matrix[i] = ff.unpack().flatten()

In [None]:
# What are the unique scalar-to-2-tensor filters?
def get_unique_2_tensor_filters(matrix, parity):
    """Extract a basis of invariant 2-tensor filters from a group-averaged matrix.

    Parameters
    ----------
    matrix : np.ndarray, shape (n, M ** D * D * D)
        Row i is a flattened, group-averaged seed filter (see ttfilter_matrix).
    parity : int
        +1 or -1.

    Returns
    -------
    list of geometric_filter
        k=2 filters with the given parity, in order of decreasing singular
        value. Empty if no singular value exceeds TINY. Each filter is scaled
        so that its largest per-pixel spectral norm is 1. As a side effect,
        a per-pixel symmetry check is printed for each filter.
    """
    _, s, v = np.linalg.svd(matrix)
    TINY = 1.e-5
    sbig = s > TINY
    if not np.any(sbig):
        return []
    nbig = np.sum(sbig).astype(int)
    tens = v[:len(s)][sbig].reshape((nbig, M ** D, D, D))
    # change signs so the tensors are largely positive (non-negative summed trace)
    for i in range(nbig):
        tracesum = np.sum([np.trace(tt) for tt in tens[i]])
        if tracesum < 0:
            tens[i] *= -1.
    # normalize so the largest tensors have unit spectral norm
    for i in range(nbig):
        tens[i] /= np.max([np.linalg.norm(tt, ord=2) for tt in tens[i]])
    # make sure zeros are exactly zero
    tens[np.abs(tens) < TINY] = 0.0
    # check whether they are symmetric: prints 1/0 per pixel
    for i in range(nbig):
        print(i, [int(np.allclose(tt, tt.T)) for tt in tens[i]])
    # NOTE(review): an earlier version passed this list through an
    # `order_filters` helper that is not defined in this notebook.
    return [geometric_filter(tt, parity, D) for tt in tens]

# Invariant 2-tensor (parity +1) filters; printing them is currently disabled.
ttfilters = get_unique_2_tensor_filters(ttfilter_matrix, 1)
# for ff in ttfilters:
#     print(ff)

In [None]:
# Visualize the 2-tensor filters
# HOGG TBD

# Now try convolution

In [None]:
# We need this to make fake data
!pip install finufft
import finufft

In [None]:
# make a sensible smooth scalar image on a 2-torus: Gaussian noise, reduced to
# its 3 x 3 lowest Fourier modes (finufft type-1 transform), evaluated back on
# the grid (type-2 transform), and scaled to unit RMS.
N = 16
np.random.seed(42)
image = np.random.normal(size=(N, N))
# N pixel-centre coordinates spanning [-pi, pi)
foo = np.pi * np.arange(-1. + 1. / N, 1., 2. / N)
ys, xs = np.meshgrid(foo, foo) # default "xy" indexing: xs[i, j] = foo[i], ys[i, j] = foo[j]
ft = finufft.nufft2d1(xs.flatten(), ys.flatten(), image.flatten().astype(complex), (3, 3))
scalar_image = finufft.nufft2d2(xs.flatten(), ys.flatten(), ft).reshape(N, N).real
scalar_image /= np.sqrt(np.mean(scalar_image ** 2))
print(scalar_image.shape, ft.shape)

In [None]:
# Show the scalar image; origin="lower" puts row 0 at the bottom.
plt.imshow(scalar_image, interpolation="nearest", origin="lower", cmap="gray")
plt.title("scalar image")
plt.colorbar()

In [None]:
# Make a sensible smooth vector image on a 2-torus: each component is
# low-pass-filtered Gaussian noise, built the same way as the scalar image.
# NOTE(review): re-seeding with 42 makes imagex identical to `image` above, so
# the x component is proportional to scalar_image. Confirm this is intended.
np.random.seed(42)
imagex = np.random.normal(size=(N, N))
imagey = np.random.normal(size=(N, N))
ftx = finufft.nufft2d1(xs.flatten(), ys.flatten(), imagex.flatten().astype(complex), (3, 3))
fty = finufft.nufft2d1(xs.flatten(), ys.flatten(), imagey.flatten().astype(complex), (3, 3))
vector_image = np.zeros((N, N, 2))
vector_image[:, :, 0] = finufft.nufft2d2(xs.flatten(), ys.flatten(), ftx).reshape(N, N).real
vector_image[:, :, 1] = finufft.nufft2d2(xs.flatten(), ys.flatten(), fty).reshape(N, N).real
vector_image /= np.sqrt(np.mean(vector_image ** 2))  # unit RMS over both components
print(vector_image.shape, ftx.shape)

In [None]:
# Quiver plot of the vector image.
# NOTE(review): with default "xy" meshgrid indexing, pxs[i, j] = j and
# pys[i, j] = i, so the arrow at (x=j, y=i) uses vector_image[i, j]. Component 0
# is drawn horizontally even though array axis 0 is plotted vertically, while
# the filter plots and the convolution treat axis 0 as x. Confirm the convention.
pxs, pys = np.meshgrid(np.arange(N), np.arange(N))
plt.gca().set_aspect("equal", adjustable="box")
plt.quiver(pxs.flatten(), pys.flatten(),
           vector_image[:, :, 0].flatten(), vector_image[:, :, 1].flatten())
plt.xlim(-0.5, N-0.5)
plt.ylim(-0.5, N-0.5)
plt.title("vector image")
plt.show()

In [None]:
# now turn this into D-dimensional data in the D=3 case: extrude the 2-d images
# along a new third axis (N identical slices), and give each vector a zero
# third component so that it has D components.
assert D in [2, 3]
if D == 3:
    foo, bar = scalar_image, vector_image
    scalar_image = np.zeros((N, ) * D)
    vector_image = np.zeros((N, ) * D + (D, ))
    # Fill the allocated arrays; rebinding the names would discard them and
    # leave depth-1 images whose vectors have only 2 components.
    scalar_image[...] = foo[:, :, None]
    vector_image[..., :2] = bar[:, :, None, :]

In [None]:
# Make d-torus convolution operator

def _looks_like_geometric_filter(filter):
    """True if `filter` exposes the geometric_filter attributes used below."""
    return all(hasattr(filter, name) for name in ("D", "m", "keys", "data"))

def reformat_filter_into_block(filter):
    """Return the filter as a dense array of shape (2m+1,) * D + tensor shape.

    Entry ``[pp + m]`` (pp a pixel offset in [-m, m] ** D) holds the tensor at
    that offset.

    Parameters
    ----------
    filter : geometric_filter or dict
        A geometric_filter (D, m, keys and per-pixel ``.data`` are read from
        it), or a plain dict that maps pixel-offset tuples to scalars or
        arrays. For a dict, the global M sets the filter size.
    """
    if _looks_like_geometric_filter(filter):
        dim, m = filter.D, filter.m
        values = {kk: filter.data[kk].data for kk in filter.keys}
    else:
        values = dict(filter)
        dim = len(next(iter(values)))
        m = (M - 1) // 2
    first = np.asarray(next(iter(values.values())))
    ff = np.zeros((2 * m + 1, ) * dim + first.shape)
    for kk, val in values.items():
        # Index with a tuple: an ndarray index would select whole slices
        # along axis 0 instead of a single pixel.
        ff[tuple(int(k) + m for k in kk)] = val
    return ff

def geometric_convolve_torus(image, filter):
    """Convolve a k-tensor image on a periodic D-torus with a k'-tensor filter.

    out[ii] = sum over jj in [-m, m] ** D of image[(ii - jj) mod Ns] (outer)
    C[jj + m], where C is the dense block from reformat_filter_into_block.

    Parameters
    ----------
    image : array-like, shape Ns + (D,) * k
        The first D axes are spatial (periodic); the remaining axes are
        tensor indices.
    filter : geometric_filter or dict
        See reformat_filter_into_block.

    Returns
    -------
    np.ndarray, shape Ns + (D,) * k + (D,) * k'
    """
    image = np.asarray(image)
    if _looks_like_geometric_filter(filter):
        dim = filter.D
    else:
        dim = len(next(iter(filter)))
    C = reformat_filter_into_block(filter)
    size = C.shape[0]
    assert size % 2 == 1
    assert C.shape[:dim] == (size, ) * dim
    assert image.ndim >= dim
    m = (size - 1) // 2
    spatial_axes = tuple(range(dim))
    outimage = np.zeros(image.shape + C.shape[dim:])
    # Loop over filter offsets only; each step is one vectorized array operation.
    for jj in it.product(range(-m, m + 1), repeat=dim):
        # shifted[ii] == image[(ii - jj) % Ns]: np.roll performs the torus wrap
        shifted = np.roll(image, shift=jj, axis=spatial_axes)
        # outer product keeps the image's and the filter's tensor indices
        outimage += np.multiply.outer(shifted, C[tuple(j + m for j in jj)])
    return outimage

In [None]:
# Now compute convolutions of images with filters
# - for example: "(scalar image) \star (pseudovector 0)"
sstarv = geometric_convolve_torus(scalar_image, vfilters[1])
sstars = geometric_convolve_torus(scalar_image, scalar_filters[0])
vstars = geometric_convolve_torus(vector_image, scalar_filters[0])
vstarv = geometric_convolve_torus(vector_image, vfilters[1])
# pvfilters[1] exists only if at least two pseudovector filters were found.
if len(pvfilters) > 1:
    vstarpv = geometric_convolve_torus(vector_image, pvfilters[1])

In [None]:
# Scalar image (imshow) overlaid with the vector field sstarv (red arrows).
# NOTE(review): imshow plots array axis 0 vertically, while sstarv[..., 0]
# comes from filter offsets along axis 0 and is drawn as the horizontal arrow
# component. Confirm that this axis convention is intended.
plt.gca().set_aspect("equal", adjustable="box")
plt.imshow(scalar_image, origin="lower")
plt.quiver(pxs.flatten(), pys.flatten(),
           sstarv[:, :, 0].flatten(), sstarv[:, :, 1].flatten(),
           color="r")
plt.xlim(-0.5, N-0.5)
plt.ylim(-0.5, N-0.5)
plt.title("scalar image STAR vector filter")
plt.show()

In [None]:
# Now show a nonlinear, local, scalar function of the scalar image:
# the squared Euclidean norm of the vector field sstarv at each pixel.
plt.imshow(np.sum(sstarv * sstarv, axis=-1), origin="lower")
plt.title("squared norm of the above vectors")

In [None]:
# Contract the vector-image STAR vector-filter output (one D x D tensor per
# pixel) by summing its [0, 0] and [1, 1] entries, which is the full trace when
# D == 2. The original vector image is overlaid in red.
plt.gca().set_aspect("equal", adjustable="box")
plt.imshow(vstarv[:, :, 0, 0] + vstarv[:, :, 1, 1], origin="lower")
plt.quiver(pxs.flatten(), pys.flatten(),
           vector_image[:, :, 0].flatten(), vector_image[:, :, 1].flatten(),
           color="r")
plt.xlim(-0.5, N-0.5)
plt.ylim(-0.5, N-0.5)
plt.title("vec image STAR vec filter, contracted")
plt.show()

In [None]:
# Contracted vector-image STAR pseudovector-filter output, with the vector
# image overlaid in red. vstarpv exists only if at least two pseudovector
# filters were found, so the plot is guarded on the same condition.
if len(pvfilters) > 1:
    plt.gca().set_aspect("equal", adjustable="box")
    plt.imshow(vstarpv[:, :, 0, 0] + vstarpv[:, :, 1, 1], origin="lower")
    plt.quiver(pxs.flatten(), pys.flatten(),
               vector_image[:, :, 0].flatten(), vector_image[:, :, 1].flatten(),
               color="r")
    plt.xlim(-0.5, N-0.5)
    plt.ylim(-0.5, N-0.5)
    plt.title("vec image STAR pvec filter, contracted")
    plt.show()
else:
    print("fewer than two pseudovector filters; skipping the vec STAR pvec plot")