<a href="https://colab.research.google.com/github/davidwhogg/EmuCosmoSim/blob/main/ipynb/group_averaging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Finding equivariant convolution operators by group averaging

## Authors:
- **David W. Hogg** (NYU) (Flatiron)
- **Soledad Villar** (JHU)

## License
Copyright 2022 the authors. All rights reserved *for now*.

## To-do
- Do the 2-tensor filters!
- Make structure so the code is agnostic about scalar/vector/tensor?

## Bugs:
- The group operators should be found by recursion; this ought to be more efficient.
- Fix 3-d plotting so it does a real projection (not just a set of incomprehensible hacks).

In [None]:
import numpy as np
import pylab as plt
import itertools as it

In [None]:
# Set the integers that configure the notebook:
D = 2 # number of image dimensions (must be 2 or 3 for plotting to work)
filter_size = 5 # side length of the filter in pixels; must be an odd integer

In [None]:
# Make all n x n (x n ...) pixel coordinates, centered on the origin.
# Raise instead of `print` + `assert False`: assertions are stripped when
# Python runs with -O, which would let an even filter size slip through.
if filter_size % 2 != 1:
    raise ValueError("Filter size must be odd")
# Symmetric integer offsets along one axis, e.g. -2..2 for filter_size = 5.
foo = range(-((filter_size - 1) // 2), ((filter_size + 1) // 2))
# One row per pixel; itertools.product varies the last coordinate fastest.
pixels = np.array(list(it.product(foo, repeat=D))).astype(int)

In [None]:
# Define hash and unhash functions for pixel names
def hash(pp):
    """Return the string key that names pixel `pp`.

    The key is the NumPy text form of the integer-cast coordinates with the
    enclosing square brackets removed.  NOTE: this name shadows the builtin
    `hash` within this file.
    """
    bracketed = str(pp.astype(int))
    # Drop the leading "[" and the trailing "]".
    return bracketed[1:len(bracketed) - 1]

def unhash(kk):
    """Return the integer pixel coordinates encoded in the key `kk`.

    `kk` is a whitespace-separated string of integers, such as the strings
    produced by `hash`.
    """
    coords = np.fromstring(kk, dtype=float, sep=" ")
    return coords.astype(int)

In [None]:
# Build the key of every pixel, in the same order as `pixels`, and show them
keys = list(map(hash, pixels))
print(keys)

In [None]:
# Make all possible group generators

# The flip operator: negate the first axis and leave the others alone
foo = np.ones(D, dtype=int)
foo[0] = -1
gg = np.diag(foo).astype(int)
generators = [gg]

# The 90-degree rotation operators: one for every pair of axes with i < j
for i, j in it.combinations(range(D), 2):
    gg = np.eye(D, dtype=int)
    gg[[i, j], [i, j]] = 0  # clear the two diagonal entries of the (i, j) plane
    gg[i, j], gg[j, i] = -1, 1
    generators.append(gg)
generators = np.array(generators)

# Look at them
for gg in generators:
    print(gg)

In [None]:
# Make all possible group operators.
# This code is very wasteful; there is a better way with recursion.

def make_all_operators(generators):
    """Return every group operator that can be built from `generators`.

    Starts from the identity matrix and calls `make_new_operators` repeatedly
    until a pass no longer changes the number of operators.

    generators: sequence or array of square integer matrices.
    Returns: the array produced by the last `make_new_operators` call.
    """
    # Read the matrix size off the generators themselves rather than the
    # global D, so the function does not depend on notebook state.
    dim = np.asarray(generators).shape[-1]
    operators = np.array([np.eye(dim, dtype=int)])
    previous_count = 0
    while len(operators) != previous_count:
        previous_count = len(operators)
        operators = make_new_operators(operators, generators)
    return operators

def make_new_operators(operators, generators):
    """Return `operators` plus every product `generator @ operator`.

    operators: array of shape (n, d, d) of integer matrices.
    generators: array of shape (g, d, d) of integer matrices.
    Returns: the unique matrices among the inputs and all products, in the
    sorted order produced by `np.unique(..., axis=0)`.
    """
    operators = np.asarray(operators)
    generators = np.asarray(generators)
    if len(operators) == 0 or len(generators) == 0:
        # Nothing to multiply: hand the operators back untouched.
        return operators
    # products[n, g] = generators[g] @ operators[n], via matmul broadcasting.
    products = (generators[None, :, :, :] @ operators[:, None, :, :]).astype(int)
    products = products.reshape((-1,) + operators.shape[1:])
    # De-duplicate once at the end instead of once per product.
    return np.unique(np.concatenate([operators, products], axis=0), axis=0)

# Build the full group and print each operator with the sign of its determinant
group_operators = make_all_operators(generators)
for gg in group_operators:
    sign = np.linalg.slogdet(gg)[0]
    print(gg, "determinant:", sign.astype(int))

In [None]:
# Check that each group operator maps the set of pixels onto itself
for gg in group_operators:
    # Row k of (pixels @ gg.T) is gg @ pixels[k]
    newpixels = (pixels @ gg.T).astype(int)
    assert {tuple(pp) for pp in newpixels} == {tuple(pp) for pp in pixels}
    print(gg, True)

In [None]:
# Check that the list of group operators is closed under multiplication.
# NOTE: `matrix in array` is not a whole-matrix membership test for NumPy
# arrays; it is true as soon as any single element matches after
# broadcasting. Compare complete matrices explicitly instead.
for gg in group_operators:
    for gg2 in group_operators:
        product = (gg @ gg2).astype(int)
        assert any(np.array_equal(product, op) for op in group_operators)
    print(gg, True)

In [None]:
# Check that gg.T is gg.inv for all gg in group (gg @ gg.T is the identity)?
for gg in group_operators:
    is_orthogonal = np.allclose(gg @ gg.T, np.eye(D))
    print(gg, is_orthogonal)

In [None]:
# Make filter manipulation functions

def rotate_scalar(filter, gg):
    """Return a copy of `filter` with each value moved to its image pixel.

    For every key in `keys`, the value stored at that pixel is written to the
    pixel obtained by applying `gg` to the pixel's coordinates.
    """
    moved = {hash(gg @ unhash(kk)): filter[kk] for kk in keys}
    rotated = filter.copy()
    rotated.update(moved)
    return rotated

def rotate_pseudoscalar(filter, gg):
    """Return a copy of `filter` transformed by `gg` as a pseudoscalar field.

    Each value moves to the pixel that `gg` maps its pixel to and is
    multiplied by the sign of det(gg); the result is cast to int.
    """
    # The determinant sign does not depend on the pixel, so compute it once
    # rather than once per key.
    sign = np.linalg.slogdet(gg)[0]
    newfilter = filter.copy()
    for kk in keys:
        newfilter[hash(gg @ unhash(kk))] = (sign * filter[kk]).astype(int)
    return newfilter

def rotate_vector(filter, gg):
    """Return a copy of `filter` transformed by `gg` as a vector field.

    Each vector moves to the pixel that `gg` maps its pixel to, and the
    vector itself is multiplied by `gg` and cast to int.
    """
    moved = {
        hash(gg @ unhash(kk)): (gg @ filter[kk]).astype(int)
        for kk in keys
    }
    rotated = filter.copy()
    rotated.update(moved)
    return rotated

def rotate_pseudovector(filter, gg):
    """Return a copy of `filter` transformed by `gg` as a pseudovector field.

    Each vector moves to the pixel that `gg` maps its pixel to, and is
    multiplied by sign(det(gg)) * gg; the result is cast to int.
    """
    # `sign * gg @ v` evaluates as `(sign * gg) @ v`; that matrix is the same
    # for every pixel, so build it once outside the loop.
    signed_gg = np.linalg.slogdet(gg)[0] * gg
    newfilter = filter.copy()
    for kk in keys:
        newfilter[hash(gg @ unhash(kk))] = (signed_gg @ filter[kk]).astype(int)
    return newfilter

def add(filter1, filter2):
    """Return a new filter holding the per-key sum of `filter1` and `filter2`.

    Only the entries listed in `keys` are summed; the result starts as a
    copy of `filter1`.
    """
    summed = filter1.copy()
    summed.update((kk, filter1[kk] + filter2[kk]) for kk in keys)
    return summed

def pack_scalar_filter(amps):
    """Return a dict mapping each key in `keys` to the matching entry of `amps`.

    `amps` must contain exactly filter_size ** D entries, in `keys` order.
    """
    assert len(amps) == filter_size ** D
    return dict(zip(keys, amps))

def unpack_scalar_filter(filter):
    """Return the values of `filter` as an array ordered like `keys`."""
    values = [filter[kk] for kk in keys]
    return np.array(values)

def make_zero_scalar_filter():
    """Return a scalar filter whose value is integer zero at every pixel."""
    zeros = np.zeros(filter_size ** D, dtype=int)
    return pack_scalar_filter(zeros)

def pack_vector_filter(vecs):
    """Return a dict mapping each key in `keys` to the matching row of `vecs`.

    `vecs` must contain exactly filter_size ** D rows, in `keys` order.
    """
    assert len(vecs) == filter_size ** D
    return dict(zip(keys, vecs))

def unpack_vector_filter(filter):
    """Return the vectors of `filter` stacked into an array ordered like `keys`."""
    rows = [filter[kk] for kk in keys]
    return np.array(rows)

def make_zero_vector_filter():
    """Return a vector filter holding an integer zero D-vector at every pixel."""
    zeros = np.zeros((filter_size ** D, D), dtype=int)
    return pack_vector_filter(zeros)

In [None]:
# Make nxn independent scalar-to-scalar filters:
# one filter per pixel, equal to 1 at that pixel and 0 everywhere else
allfilters = [{**make_zero_scalar_filter(), kk: 1} for kk in keys]

In [None]:
# Sum all the group-element-transformed scalar-to-scalar filters and make a matrix of them
n = len(allfilters)
filter_matrix = np.zeros((n, n), dtype=int)
for i, f1 in enumerate(allfilters):
    # Accumulate the group sum of basis filter i as a flat array
    total = np.zeros(n, dtype=int)
    for gg in group_operators:
        total = total + unpack_scalar_filter(rotate_scalar(f1, gg))
    filter_matrix[i] = total

In [None]:
# What are the unique scalar-to-scalar filters?
def get_unique_scalar_filters(matrix):
    """Return the scalar filters spanning the row space of `matrix`.

    Takes the SVD of `matrix` and keeps the right singular vectors whose
    singular values exceed a small threshold.  Each kept vector is scaled so
    its largest absolute amplitude is 1, sign-flipped if its amplitudes sum
    to a negative number, has near-zero entries set to exactly zero, and is
    packed with `pack_scalar_filter`.  Returns an empty list when no singular
    value exceeds the threshold.
    """
    TINY = 1.e-5
    _, s, v = np.linalg.svd(matrix)
    sbig = s > TINY
    if not sbig.any():
        return []
    kept = v[sbig]
    # normalize the amplitudes so they max out at +/- 1.
    peaks = np.max(np.abs(kept), axis=1)
    amps = kept / peaks[:, None]
    # make sure the amps are positive, generally
    for row in amps:
        if np.sum(row) < 0:
            row *= -1  # `row` is a view, so this updates `amps` in place
    # make sure that the zeros are zeros.
    amps[np.abs(amps) < TINY] = 0.
    return [pack_scalar_filter(aa) for aa in amps]

# Extract the unique scalar filters and print each one
filters = get_unique_scalar_filters(filter_matrix)
for text in map(str, filters):
    print(text)

In [None]:
# Visualize (badly) the scalar filters.

FIGSIZE = (4, 3)  # figure size handed to plt.figure
XOFF = 0.15  # x shift applied per unit of z when drawing D = 3 filters
YOFF = -0.1  # y shift applied per unit of z when drawing D = 3 filters
TINY = 1.e-5  # magnitudes not above this are treated as zero when plotting

def setup_plot():
    """Open a new figure of size FIGSIZE for the next filter plot."""
    plt.figure(figsize=FIGSIZE)

def finish_plot(title):
    """Set the title, the axis limits and an equal aspect ratio.

    The limits span the pixel range plus a margin of 0.5 for D == 2 or 0.75
    for D == 3; for any other D the limits are left alone.
    """
    plt.title(title)
    margin = {2: 0.5, 3: 0.75}.get(D)
    if margin is not None:
        low = np.min(pixels) - margin
        high = np.max(pixels) + margin
        plt.xlim(low, high)
        plt.ylim(low, high)
    plt.gca().set_aspect("equal")

def plot_boxes(xs, ys):
    """Outline a unit square centered on every (x, y) pair."""
    # Corner offsets of a closed unit square, traced from the lower left
    dx = np.array([-0.5, -0.5, 0.5, 0.5, -0.5])
    dy = np.array([-0.5, 0.5, 0.5, -0.5, -0.5])
    for x, y in zip(xs, ys):
        plt.plot(x + dx, y + dy, "k-", lw=0.5)

def fill_boxes(xs, ys, ws):
    """Shade the unit square of every pixel whose weight exceeds TINY in size."""
    for x, y, w in zip(xs, ys, ws):
        if not np.abs(w) > TINY:
            continue
        left, right = x - 0.5, x + 0.5
        bottom, top = y - 0.5, y + 0.5
        plt.fill_between([left, right], [bottom, bottom], [top, top],
                         color="k", alpha=0.1)

def plot_scalars(xs, ys, ws):
    """Draw scalar weights `ws` at positions (`xs`, `ys`).

    Every pixel gets an outlined box, and boxes with non-negligible weight
    are shaded.  Positive weights are drawn as "+" markers and negative
    weights as "_" markers, with marker area proportional to the magnitude.
    """
    plot_boxes(xs, ys)
    fill_boxes(xs, ys, ws)
    scale = 1000 / filter_size
    positive = ws > TINY
    # Use -TINY here: `ws < TINY` would also pick up zeros and tiny positive
    # weights and hand scatter zero or negative marker sizes.
    negative = ws < -TINY
    plt.scatter(xs[positive], ys[positive], marker="+", c="k", s=scale * ws[positive])
    plt.scatter(xs[negative], ys[negative], marker="_", c="k", s=-scale * ws[negative])

def plot_scalar_filter(filter, title):
    """Plot the scalar `filter` in a new figure titled `title`.

    Only D == 2 and D == 3 are supported; for any other D a message is
    printed and nothing is drawn.  For D == 3 each pixel is shifted by
    (XOFF, YOFF) per unit of its z coordinate.
    """
    if D not in [2, 3]:
        print("plot_scalar_filter(): Only works for D in [2, 3].")
        return
    setup_plot()
    n_pixels = filter_size ** D
    # Rows are x, y, z; the z row stays zero when D == 2, so the shift below
    # then leaves the positions unchanged.
    coords = np.zeros((3, n_pixels))
    ws = np.zeros(n_pixels)
    for i, kk in enumerate(keys):
        coords[:D, i] = unhash(kk)
        ws[i] = filter[kk]
    xs, ys, zs = coords
    plot_scalars(xs + XOFF * zs, ys + YOFF * zs, ws)
    finish_plot(title)

# One figure per unique scalar filter
for i, ff in enumerate(filters):
    plot_scalar_filter(ff, f"scalar {i}")

In [None]:
# Sum all the group-element-transformed scalar-to-pseudoscalar filters and make a matrix of them
n = len(allfilters)
pfilter_matrix = np.zeros((n, n), dtype=int)
for i, f1 in enumerate(allfilters):
    # Accumulate the group sum of basis filter i as a flat array
    total = np.zeros(n, dtype=int)
    for gg in group_operators:
        total = total + unpack_scalar_filter(rotate_pseudoscalar(f1, gg))
    pfilter_matrix[i] = total

In [None]:
# What are the unique scalar-to-pseudoscalar filters?
pfilters = get_unique_scalar_filters(pfilter_matrix)
for text in map(str, pfilters):
    print(text)

In [None]:
# Visualize (badly) the pseudoscalar filters, one figure each.
for i, ff in enumerate(pfilters):
    plot_scalar_filter(ff, f"pseudoscalar {i}")

In [None]:
# Make Dn x Dn x ... independent scalar-to-vector filters:
# one filter per (pixel, component) pair, with that single component set to 1
allvfilters = []
for kk, i in it.product(keys, range(D)):
    thisfilter = make_zero_vector_filter()
    thisfilter[kk][i] = 1
    allvfilters.append(thisfilter)

In [None]:
# Sum all the group-element-transformed scalar-to-vector filters and make a matrix of them
n = len(allvfilters)
vfilter_matrix = np.zeros((n, n), dtype=int)
for i, f1 in enumerate(allvfilters):
    # Accumulate the group sum of basis filter i as a flattened array
    total = np.zeros(n, dtype=int)
    for gg in group_operators:
        total = total + unpack_vector_filter(rotate_vector(f1, gg)).flatten()
    vfilter_matrix[i] = total

In [None]:
# What are the unique scalar-to-vector filters?
def get_unique_vector_filters(matrix, parity):
    """Return the vector filters spanning the row space of `matrix`.

    Takes the SVD of `matrix` and keeps the right singular vectors whose
    singular values exceed a small threshold, reshaped to one D-vector per
    pixel.  Each non-negligible per-pixel vector is scaled to unit length.

    parity: when positive, a filter is sign-flipped if its sum of dot
    products with the pixel positions is negative; when negative and D == 2,
    the same test is made against the pixel positions rotated by 90 degrees.
    No sign convention is applied otherwise.

    Returns a list of filters packed with `pack_vector_filter`, or an empty
    list when no singular value exceeds the threshold.
    """
    TINY = 1.e-5
    _, s, v = np.linalg.svd(matrix)
    sbig = s > TINY
    if not sbig.any():
        return []
    nbig = np.sum(sbig).astype(int)
    vecs = v[sbig].reshape((nbig, filter_size ** D, D))
    # scale every non-negligible per-pixel vector to unit length
    norms = np.sqrt(np.sum(vecs * vecs, axis=2))
    norms[norms < TINY] = 1.0
    vecs = vecs / norms[:, :, None]
    # choose the field against which the divergences or curls must be positive
    reference = None
    if parity > 0:
        reference = pixels
    elif D == 2 and parity < 0:
        reference = np.zeros_like(pixels)
        reference[:, 0] = -1. * pixels[:, 1]
        reference[:, 1] = pixels[:, 0]
    if reference is not None:
        for vec in vecs:
            if np.sum(vec * reference) < 0:
                vec *= -1  # `vec` is a view, so this updates `vecs` in place
    # make sure zeros are exactly zero
    vecs[np.abs(vecs) < TINY] = 0.0
    return [pack_vector_filter(vv) for vv in vecs]

# Extract the unique vector filters (positive parity) and print each one
vfilters = get_unique_vector_filters(vfilter_matrix, 1)
for text in map(str, vfilters):
    print(text)

In [None]:
# Visualize the vector filters.

def plot_vectors(xs, ys, ws):
    """Draw the vectors `ws` as arrows centered on positions (`xs`, `ys`).

    Every pixel gets an outlined box, and boxes holding a non-negligible
    vector are shaded.  Only the first two components of each vector are
    drawn.
    """
    plot_boxes(xs, ys)
    fill_boxes(xs, ys, np.sum(np.abs(ws), axis=-1))
    for x, y, w in zip(xs, ys, ws):
        if not np.sum(w * w) > TINY:
            continue
        wx, wy = w[0], w[1]
        # The arrow spans 0.6 of the vector and is centered on the pixel
        plt.arrow(x - 0.3 * wx, y - 0.3 * wy,
                  0.6 * wx, 0.6 * wy,
                  length_includes_head=True, head_width=0.1, color="k")

def plot_vector_filter(filter, title):
    """Plot the vector `filter` in a new figure titled `title`.

    Only D == 2 and D == 3 are supported; for any other D a message is
    printed and nothing is drawn.  For D == 3 each pixel is shifted by
    (XOFF, YOFF) per unit of its z coordinate.
    """
    if D not in [2, 3]:
        print("plot_vector_filter(): Only works for D in [2, 3].")
        return
    setup_plot()
    n_pixels = filter_size ** D
    # Rows are x, y, z; the z row stays zero when D == 2, so the shift below
    # then leaves the positions unchanged.
    coords = np.zeros((3, n_pixels))
    ws = np.zeros((n_pixels, D))
    for i, kk in enumerate(keys):
        coords[:D, i] = unhash(kk)
        ws[i] = filter[kk]
    xs, ys, zs = coords
    plot_vectors(xs + XOFF * zs, ys + YOFF * zs, ws)
    finish_plot(title)

# One figure per unique vector filter
for i, ff in enumerate(vfilters):
    plot_vector_filter(ff, f"vector {i}")

In [None]:
# Sum all the group-element-transformed scalar-to-pseudovector filters and make a matrix of them
n = len(allvfilters)
pvfilter_matrix = np.zeros((n, n), dtype=int)
for i, f1 in enumerate(allvfilters):
    # Accumulate the group sum of basis filter i as a flattened array
    total = np.zeros(n, dtype=int)
    for gg in group_operators:
        total = total + unpack_vector_filter(rotate_pseudovector(f1, gg)).flatten()
    pvfilter_matrix[i] = total

In [None]:
# What are the unique scalar-to-pseudovector filters (negative parity)?
pvfilters = get_unique_vector_filters(pvfilter_matrix, -1)
for text in map(str, pvfilters):
    print(text)

In [None]:
# Visualize (badly) the pseudovector filters, one figure each.
for i, ff in enumerate(pvfilters):
    plot_vector_filter(ff, f"pseudovector {i}")