# Group invariant ML using fundamental domain projections

## Introduction

In this notebook we use the approach from *B. Aslan, D. Platt, D. Sheard: “Group invariant machine learning by fundamental domain projections”* for group invariant machine learning. We apply it to the dataset from *P.S. Green, T. Hubsch and C. A. Lutken: “All Hodge Numbers of All Complete Intersection Calabi-Yau Manifolds”*; namely, we learn the first Hodge number. Random row and column permutations were applied to the input matrices.

## Description of the group invariant ML architecture

We explain the approach for functions which take a matrix as their input and output a number.
We say that a vector `v` is lexicographically bigger than a vector `w` if, read from front to end, the first entry where the two vectors disagree is bigger in `v` than in `w`.
For example:

```
(1, 2, 0, -1) > (1, 1, 3, 3)
(0.52, 1, 2)  > (0.5, -1, 3)
```
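Here is a minimal sketch of this comparison in Python. It assumes both vectors have the same length, and `lex_greater` is a hypothetical helper name, not taken from the paper. Python's built-in tuple comparison already uses exactly this order, which is handy for comparing matrices after flattening them.

```python
def lex_greater(v, w):
    """Return True if v is lexicographically bigger than w.
    Assumes v and w have the same length."""
    for a, b in zip(v, w):
        if a != b:
            # The first entry where the vectors disagree decides the order.
            return a > b
    return False  # all entries agree, so v is not bigger than w

print(lex_greater((1, 2, 0, -1), (1, 1, 3, 3)))  # True
print(lex_greater((0.52, 1, 2), (0.5, -1, 3)))   # True
print((1, 2, 0, -1) > (1, 1, 3, 3))              # True: built-in tuple order agrees
```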

Let `F` be the map that applies row and column permutations to a given matrix to make it lexicographically as big as possible when read row by row.
For example:

```
   0  0  0    3  1  0
F: 0  1  3 -> 1  1  2
   2  1  1    0  0  0

because: (3, 1, 0, 1, 1, 2, 0, 0, 0) > (0, 0, 0, 0, 1, 3, 2, 1, 1)
```
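For tiny matrices, `F` can be computed exactly by brute force over all row and column permutations. The sketch below reproduces the example above. `F_exact` is a hypothetical helper of our own, and this exhaustive search is not the method used later in this notebook.

```python
import itertools
import numpy as np

def F_exact(M):
    """Exact F: try every row and column permutation of M and return the
    lexicographically biggest result, reading the matrix row by row.
    This costs rows! * cols! candidates, so it is only feasible for tiny matrices."""
    M = np.asarray(M)
    rows, cols = M.shape
    best = None
    for rp in itertools.permutations(range(rows)):
        for cp in itertools.permutations(range(cols)):
            candidate = M[np.ix_(rp, cp)]  # rows reordered by rp, columns by cp
            # Python tuples compare lexicographically, matching the order defined above.
            if best is None or tuple(candidate.ravel()) > tuple(best.ravel()):
                best = candidate
    return best

M = np.array([[0, 0, 0],
              [0, 1, 3],
              [2, 1, 1]])
print(F_exact(M))
# [[3 1 0]
#  [1 1 2]
#  [0 0 0]]
```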

The fantastic thing is: no matter how you permute rows and columns of an input matrix, the output of `F` will be the same.
That means if you apply `F` first, followed by any machine learning algorithm, the resulting model is group invariant.

This is easy to implement:
instead of making `F` part of the machine learning model, you can just apply `F` to every data point (training and test data), and then use ML as usual on it.
This works for every ML model: neural networks, SVMs, random forests...
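A small numerical check of this claim, as a sketch: we compute `F` exactly by brute force (feasible only for tiny matrices) and compose it with a deliberately non-invariant "model", namely a fixed random linear map. `F_exact`, `model` and `weights` are hypothetical names for illustration. The real pipeline below uses a greedy approximation of `F` and a neural network instead.

```python
import itertools
import numpy as np

def F_exact(M):
    """Exact F by exhaustive search over all row and column permutations.
    Only feasible for tiny matrices (rows! * cols! candidates)."""
    rows, cols = M.shape
    candidates = (M[np.ix_(rp, cp)]
                  for rp in itertools.permutations(range(rows))
                  for cp in itertools.permutations(range(cols)))
    # Tuples compare lexicographically, so this picks the biggest matrix read row by row.
    return max(candidates, key=lambda A: tuple(A.ravel()))

rng = np.random.default_rng(0)
# Hypothetical, deliberately non-invariant "model": a fixed random linear map.
weights = rng.normal(size=12)

def model(M):
    return float(weights @ M.ravel())

M = rng.integers(0, 3, size=(3, 4))
# The same matrix with randomly permuted rows and columns.
M_perm = M[rng.permutation(3)][:, rng.permutation(4)]

print(model(M), model(M_perm))                      # generally different
print(model(F_exact(M)) == model(F_exact(M_perm)))  # True
```

Because `F` maps every permuted copy of `M` to the same matrix, anything computed from `F(M)` inherits the invariance, whatever model sits on top.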

The problem is that `F` is hard to compute.
(Computing it is at least as hard as the graph isomorphism problem, for which no polynomial-time algorithm is known.)
So, in practice, we greedily apply row and column swaps to an input matrix, keeping every swap that makes it lexicographically bigger, until no single swap improves it any more.
That means we are only guaranteed to find a local maximum, not the global one, so the invariance is approximate; in practice that's still okay.
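To see what stopping at a local maximum means, here is a sketch that compares the greedy swap search with an exhaustive search on small random matrices. `key`, `greedy_maximise` and `F_exact` are hypothetical helpers of our own. The greedy one follows the same idea as the notebook's `maximise_matrix` below but is written independently. The number of misses depends on the random seed and the matrix sizes, so we do not quote a figure here.

```python
import itertools
import numpy as np

def key(M):
    # Read the matrix row by row; Python tuples compare lexicographically.
    return tuple(M.ravel())

def greedy_maximise(M):
    """Apply any row or column swap that makes M lexicographically bigger,
    until no single swap helps any more (a local maximum)."""
    M = M.copy()
    improved = True
    while improved:
        improved = False
        for axis in (0, 1):
            for i, j in itertools.combinations(range(M.shape[axis]), 2):
                swapped = M.copy()
                if axis == 0:
                    swapped[[i, j]] = swapped[[j, i]]        # swap rows i and j
                else:
                    swapped[:, [i, j]] = swapped[:, [j, i]]  # swap columns i and j
                if key(swapped) > key(M):
                    M, improved = swapped, True
    return M

def F_exact(M):
    """Exhaustive search over all row and column permutations (tiny M only)."""
    rows, cols = M.shape
    return max((M[np.ix_(rp, cp)]
                for rp in itertools.permutations(range(rows))
                for cp in itertools.permutations(range(cols))), key=key)

rng = np.random.default_rng(1)
trials, misses = 200, 0
for _ in range(trials):
    M = rng.integers(0, 3, size=(3, 4))
    if key(greedy_maximise(M)) < key(F_exact(M)):
        misses += 1  # greedy got stuck below the global maximum
print(f"greedy stopped below the global maximum in {misses}/{trials} cases")
```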


## Data loading

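```python
import numpy as np
import itertools
from tqdm import tqdm

# Input matrices (rows and columns randomly permuted) and their Hodge numbers.
X = np.load('data/matrices_permuted.npy')
# Keep only the first Hodge number (column 0) as the regression target.
y = np.load('data/hodge_numbers.npy')[:, 0]
```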

## Applying the pre-processing `F`

```python
# Pre-processing: greedy approximation of the projection F
def swap_rows(M, i, j):
    '''Returns a copy of M with the i-th and j-th rows swapped.'''
    v = M.copy()
    v[[i, j]] = v[[j, i]]
    return v

def swap_cols(M, i, j):
    '''Returns a copy of M with the i-th and j-th columns swapped.'''
    v = M.copy()
    v[:, [i, j]] = v[:, [j, i]]
    return v

def greater(A, B):
    '''Returns True if A is lexicographically greater than B (read row by row).'''
    A = np.ravel(A)
    B = np.ravel(B)
    diff = np.flatnonzero(A != B)
    if diff.size == 0:
        return False  # A and B are equal
    # The first position where A and B differ decides the order.
    idx = diff[0]
    return bool(A[idx] > B[idx])

def maximise_matrix(M):
    '''Greedily applies row and column transpositions to make the input matrix M
    lexicographically as big as possible. Stops at a local maximum, i.e. when
    no single transposition makes M any bigger.'''
    rows, cols = M.shape
    max_found = False
    while not max_found:
        max_found = True
        for i, j in itertools.combinations(range(rows), 2):
            candidate = swap_rows(M, i, j)
            if greater(candidate, M):
                M = candidate
                max_found = False
        for i, j in itertools.combinations(range(cols), 2):
            candidate = swap_cols(M, i, j)
            if greater(candidate, M):
                M = candidate
                max_found = False
    return M

# Apply the pre-processing to every data point.
X_preprocessed = [maximise_matrix(mat) for mat in tqdm(X)]
```

100%|██████████| 7890/7890 [00:18<00:00, 432.63it/s]


```python
# Flatten each pre-processed matrix into a feature vector of length 180.
X_preprocessed = np.array(X_preprocessed).reshape(-1, 180)
```

## Training the neural network

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Convert the pre-processed data to tensors; targets become a column vector.
X = torch.tensor(X_preprocessed, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)

# define the model: a small fully connected network with two hidden layers
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(180, 128)
        self.act1 = nn.ReLU()
        self.hidden2 = nn.Linear(128, 64)
        self.act2 = nn.ReLU()
        self.output = nn.Linear(64, 1)

    def forward(self, x):
        x = self.act1(self.hidden1(x))
        x = self.act2(self.hidden2(x))
        x = self.output(x)
        return x

def check_acc(model, X, y):
    '''Prints the accuracy of the model for input data X and target values y,
    i.e. the fraction of samples whose rounded prediction equals the target.'''
    with torch.no_grad():  # evaluation only, no gradients needed
        y_pred = model(X)
    accuracy = (y_pred.round() == y).float().mean().item()
    print(f"Accuracy {accuracy}")

model = MLP()
print(model)

# train the model as a regression on the Hodge number with a mean squared error loss
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

n_epochs = 50
batch_size = 10

for epoch in range(n_epochs):
    print(f'epoch: {epoch}')
    # Note: the accuracy is measured on the same data the model is trained on.
    check_acc(model, X, y)
    # Mini-batches are taken in a fixed order (no shuffling between epochs).
    for i in range(0, len(X), batch_size):
        Xbatch = X[i:i+batch_size]
        ybatch = y[i:i+batch_size]
        y_pred = model(Xbatch)
        loss = loss_fn(y_pred, ybatch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

check_acc(model, X, y)
```

MLP(
  (hidden1): Linear(in_features=180, out_features=128, bias=True)
  (act1): ReLU()
  (hidden2): Linear(in_features=128, out_features=64, bias=True)
  (act2): ReLU()
  (output): Linear(in_features=64, out_features=1, bias=True)
)
epoch: 0
Accuracy 0.002788339741528034
epoch: 1
Accuracy 0.051330797374248505
epoch: 2
Accuracy 0.04676806181669235
epoch: 3
Accuracy 0.034220531582832336
epoch: 4
Accuracy 0.04740177467465401
epoch: 5
Accuracy 0.04816222935914993
epoch: 6
Accuracy 0.049936629831790924
epoch: 7
Accuracy 0.05095057189464569
epoch: 8
Accuracy 0.05386565253138542
epoch: 9
Accuracy 0.05766793340444565
epoch: 10
Accuracy 0.06248415634036064
epoch: 11
Accuracy 0.0661596953868866
epoch: 12
Accuracy 0.05234473943710327
epoch: 13
Accuracy 0.07338403165340424
epoch: 14
Accuracy 0.07148288935422897
epoch: 15
Accuracy 0.10494296252727509
epoch: 16
Accuracy 0.11863117665052414
epoch: 17
Accuracy 0.13307984173297882
epoch: 18
Accuracy 0.15297845005989075
epoch: 19
Accuracy 0.16070975363