# [Solved] Lab 7: **Geometric Deep Learning** and *permutation equivariance*

Advanced Topics in Machine Learning -- Fall 2023, UniTS

<a target="_blank" href="https://colab.research.google.com/github/ganselmif/adv-ml-units/blob/main/solutions/AdvML_UniTS_2023_Lab_07_Permutation_Equivariance_Solved.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>


#### Geometric Deep Learning

*Geometric Deep Learning* is a popular and growing subfield concerned with the study of deep learning models explicitly designed to learn representations that are invariant or equivariant to certain transformations. This goal is very often achieved through a *deeper* understanding of the mathematical properties of the data and of the representations, and of how a given model is able to capture them.

#### *Deep Sets* and *permutation equivariance*

In this lab, we will focus on a very simple mathematical construction, introduced in the [*Deep Sets* paper (2017)](https://arxiv.org/abs/1703.06114), that is *permutation equivariant* with respect to its input data by design, at least in theory. Reading the paper is strongly recommended before starting this lab.

Recall the *permutation equivariance* property, as stated in the paper:

![](./img/perm_equi.png)

with $f$ being the model, $x$ a set of input data, and $\pi$ a permutation of the indices.
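
For reference, the property can also be written out in text form. The following is our transcription of the definition (a sketch of the statement, with $M$ denoting the number of elements in the set and $f_i$ the $i$-th output component; please check it against the paper):

$$
f\left([x_{\pi(1)}, \dots, x_{\pi(M)}]\right) = \left[f_{\pi(1)}(\mathbf{x}), \dots, f_{\pi(M)}(\mathbf{x})\right]
$$

In words: permuting the inputs permutes the outputs in exactly the same way.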

1. Implement such a function (you can use *plain* `NumPy`!), as described in the following snippet from the paper:
![](./img/perm_equi_th.png)

2. Check that the function is indeed permutation equivariant by applying it to synthetic data (you can generate it yourself), as we did in the previous lab(s) with images.
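
For step 1, the construction can be written out as follows. This is our reading of the result in the paper, with $\sigma$ an elementwise nonlinearity, $\mathbf{I}$ the $M \times M$ identity matrix, and $\mathbf{1}$ the all-ones vector of length $M$:

$$
f_\Theta(\mathbf{x}) = \sigma(\Theta \mathbf{x}), \qquad \Theta = \lambda \mathbf{I} + \gamma \, (\mathbf{1}\mathbf{1}^\top), \qquad \lambda, \gamma \in \mathbb{R}
$$

Expanding the product shows why this form is equivariant:

$$
\Theta \mathbf{x} = \lambda \mathbf{x} + \gamma \, (\mathbf{1}^\top \mathbf{x}) \, \mathbf{1}
\quad \Longrightarrow \quad
\left[f_\Theta(\mathbf{x})\right]_i = \sigma\!\left(\lambda x_i + \gamma \sum_{j=1}^{M} x_j\right)
$$

The sum over $j$ does not depend on the order of the elements, so each output depends only on its own input $x_i$ and on an order-independent summary of the whole set.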


In [1]:
import functools
import numpy as np

In [2]:
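def theta(_xshape, _lambda, _gamma):
    # Tied-weight matrix: _lambda * I + _gamma * (all-ones matrix)
    return _lambda * np.eye(_xshape) + _gamma * np.ones((_xshape, _xshape))


def mish(_x):
    # Mish: x * tanh(softplus(x)); logaddexp(0, x) is a stable log(1 + exp(x))
    return _x * np.tanh(np.logaddexp(0, _x))


def f(_x, _lambda=1, _gamma=1):
    # theta is symmetric, so x @ theta equals theta @ x for a 1-D x
    out = np.matmul(_x, theta(_x.shape[0], _lambda, _gamma))
    return mish(out)  # Any elementwise nonlinearity suffices! :)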
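
Since the matrix is just a scaled identity plus a scaled all-ones matrix, the same map can be computed without ever building the $M \times M$ matrix. The sketch below is an illustration, not part of the original solution: the names `f_dense` and `f_fast` are ours, and the seed and sizes are arbitrary. It compares the explicit-matrix version against the shortcut $\lambda x + \gamma \sum_j x_j$.

```python
import numpy as np


def mish(x):
    # Mish nonlinearity: x * tanh(softplus(x)), with a numerically stable softplus.
    return x * np.tanh(np.logaddexp(0.0, x))


def f_dense(x, lam=1.0, gam=1.0):
    # Reference version: build Theta = lam * I + gam * (1 1^T) explicitly.
    m = x.shape[0]
    theta = lam * np.eye(m) + gam * np.ones((m, m))
    return mish(theta @ x)


def f_fast(x, lam=1.0, gam=1.0):
    # Same map without the M x M matrix: Theta x = lam * x + gam * sum(x) * 1.
    # The scalar gam * x.sum() is broadcast to every entry.
    return mish(lam * x + gam * x.sum())


rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
x = rng.random(200)

# Both versions should agree up to floating-point rounding.
same = np.allclose(f_dense(x, 0.25, 0.4), f_fast(x, 0.25, 0.4))
print("dense and fast versions agree:", same)
```

The shortcut costs $O(M)$ time and memory instead of $O(M^2)$, which matters once the sets become large.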

In [3]:
# Generate one random vector and nperms - 1 random permutations of it (one per row)
xsize = 200
nperms = 200

xlist = [
    np.random.rand(xsize),
]
for _ in range(nperms - 1):
    xlist.append(np.random.permutation(xlist[0]))

x = np.stack(xlist)

In [4]:
# Apply f independently to each row of x
newf = functools.partial(f, _lambda=0.25, _gamma=0.4)
y = np.apply_along_axis(newf, 1, x)

In [5]:
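# If f is equivariant, every row of y is a permutation of y[0]: sorted rows must match
ysorted = np.sort(y, axis=1)
if np.isclose(ysorted[0], ysorted).all():
    print("Success!")
else:
    print("Failure: the outputs are not permutations of each other.")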

Success!
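
Comparing sorted outputs confirms that all outputs contain the same values, which is a necessary consequence of equivariance. A stricter test keeps track of the permutation and checks, position by position, that `f(x[perm])` equals `f(x)[perm]`. The sketch below is an optional extension rather than part of the original solution: the helper `layer`, the seed, and the sizes are our own choices. It also runs the same test on a generic dense matrix with untied weights, which is expected to fail.

```python
import numpy as np


def mish(x):
    # Mish nonlinearity with a numerically stable softplus.
    return x * np.tanh(np.logaddexp(0.0, x))


def layer(x, theta):
    # One layer sigma(Theta x) with an arbitrary square matrix Theta.
    return mish(theta @ x)


rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
m = 50
x = rng.random(m)
perm = rng.permutation(m)

# Tied weights as in the lab: Theta = lam * I + gam * (1 1^T).
theta_tied = 0.25 * np.eye(m) + 0.4 * np.ones((m, m))
# Counterexample: a generic dense matrix with no weight tying.
theta_free = rng.standard_normal((m, m))

# Equivariance means f(x[perm]) == f(x)[perm], entry by entry.
tied_ok = np.allclose(layer(x[perm], theta_tied), layer(x, theta_tied)[perm])
free_ok = np.allclose(layer(x[perm], theta_free), layer(x, theta_free)[perm])

print("tied weights equivariant:", tied_ok)
print("free weights equivariant:", free_ok)
```

The contrast between the two matrices illustrates that the equivariance comes from the weight tying, not from the choice of nonlinearity.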
