In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from collections import deque

In [None]:
def generate_matrix(N, d = 3):
    """Generate N random invertible d x d matrices with eigenvalues clustered around 2.

    Each matrix is built as ``B @ diag(lam) @ inv(B)``:
    - the eigenvalues ``lam`` are drawn i.i.d. from Normal(mean=2, std=0.5);
    - ``B`` is a Gaussian change-of-basis matrix drawn independently per sample.

    Because the eigenvalues sit around 2, zero eigenvalues are very unlikely.
    However, the condition number also depends on the random basis ``B`` and is
    NOT bounded. Filter with ``np.linalg.cond`` if a hard limit is required.

    Args:
        N: number of matrices to generate.
        d: side length of each square matrix.

    Returns:
        float64 array of shape (N, 1, d, d). The singleton axis matches the
        (batch, 1, d, d) layout used for the model inputs.
    """
    # Eigenvalues of every matrix, shape (N, d).
    eigvals = 2.0 - 0.5 * np.random.randn(N, d)
    # Batched diagonal matrices diag(eigvals[n]), shape (N, 1, d, d).
    diag = eigvals[:, np.newaxis, :, np.newaxis] * np.eye(d)
    # One basis per sample. A single shared basis would give every matrix the same
    # eigenvectors, so all samples would commute (a degenerate training set).
    basis = np.random.randn(N, 1, d, d)
    # The similarity transform preserves the eigenvalues; matmul broadcasts over (N, 1).
    return basis @ diag @ np.linalg.inv(basis)

# Build the matrix dataset, then sanity-check one sample by printing the matrix,
# its inverse and its condition number.
X_data = generate_matrix(N=10000)
X_test = X_data[0]
print(X_test, *(check(X_test) for check in (np.linalg.inv, np.linalg.cond)))

In [None]:
# Simple MLP mapping a matrix to (an approximation of) its inverse.
class MLP(nn.Module):
    """Fully connected network from (batch, 1, dim, dim) inputs to (batch, 1, dim, dim) outputs.

    Each matrix is processed in three steps:
    1. flatten to a dim*dim vector;
    2. pass through ``depth`` hidden Linear+ReLU layers of width ``hidden``;
    3. project back to dim*dim and reshape to a matrix.

    Args:
        dim: side length of the square matrices (default 3).
        hidden: width of each hidden layer.
        depth: number of hidden Linear+ReLU layers (0 gives a single Linear map).

    Raises:
        ValueError: if ``dim`` or ``hidden`` is < 1, or ``depth`` is < 0.
    """

    def __init__(self, dim=3, hidden=64, depth=3):
        super().__init__()
        if dim < 1 or hidden < 1 or depth < 0:
            raise ValueError(f"invalid MLP sizes: dim={dim}, hidden={hidden}, depth={depth}")
        n_features = dim * dim
        # (batch, 1, dim, dim) -> (batch, 1, dim*dim)
        self.flatten = nn.Flatten(start_dim=2)
        # (batch, 1, dim*dim) -> (batch, 1, dim, dim)
        self.unflatten = nn.Unflatten(-1, (dim, dim))
        # Each activation needs a trainable Linear layer in front of it. Stacking
        # ReLU/Sigmoid alone learns nothing and, since sigmoid(relu(.)) converges to
        # sigmoid's fixed point (~0.659), erases the information in the input.
        layers = []
        width = n_features
        for _ in range(depth):
            layers += [nn.Linear(width, hidden), nn.ReLU()]
            width = hidden
        layers.append(nn.Linear(width, n_features))
        # Attribute name kept so existing references to model.relu_stack still work.
        self.relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, 1, dim, dim) tensor to a tensor of the same shape."""
        x = self.flatten(x)
        x = self.relu_stack(x)
        return self.unflatten(x)
model = MLP()
# Sanity check: the model maps a (batch, 1, 3, 3) input to an output of the same shape.
# no_grad avoids building an autograd graph for this throwaway forward pass.
with torch.no_grad():
    assert model(torch.rand((1, 1, 3, 3))).shape == (1, 1, 3, 3)

In [None]:
# Training parameters
batch_size = 100
lr = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# When a new network is created we init empty training logs
loss_log = []
# One running-mean value per iteration.
weighted_average_log = []
# Sliding window holding the last 100 losses. Despite the name, the training loop
# takes a plain (unweighted) mean over it.
weighted_average = deque([], maxlen=100)


# And store the best results.
best_loss = np.inf
# state_dict() returns references to the live parameter tensors, which
# optimizer.step() updates in place. Clone them so this is a real snapshot
# rather than an alias that silently follows training.
best_model_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}

In [None]:
### Training Loop
# NOTE(review): batches are drawn fresh from torch.rand (entries uniform in [0, 1)),
# not from X_data produced by generate_matrix. Such matrices can be close to
# singular, so their true inverses can be very large. Confirm this is intended.

# Target for pred @ x. It is constant, so build it once instead of every iteration.
identity = torch.eye(3, 3)

for i in range(10000):

    # Fresh batch of random 3x3 matrices, shape (batch_size, 1, 3, 3).
    x = torch.rand((batch_size, 1, 3, 3))

    pred = model(x)

    # Loss: mean squared distance of pred @ x from the identity, i.e. how well the
    # network output serves as a (left) inverse of x. No explicit inverse is needed.
    loss = (torch.matmul(pred, x) - identity).square().mean()
    loss_value = loss.item()

    # Keep the model with the lowest loss seen so far
    # (this is to avoid losing good results during a run that goes wild).
    if loss_value < best_loss:
        # state_dict() returns references to the live parameters, which
        # optimizer.step() overwrites. Clone so the snapshot stays frozen.
        best_model_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
        # Store a plain float so this step's autograd graph is not kept alive.
        best_loss = loss_value
        print('New Best: ', loss_value)

    # Update the loss trend indicators (sliding window of recent losses).
    weighted_average.append(loss_value)

    # Update the logs
    weighted_average_log.append(np.mean(weighted_average))
    loss_log.append(loss_value)

    # Print every 100 iterations
    if i % 100 == 0:
        wa_out = weighted_average_log[-1]
        print(f"It={i}\t loss={loss_value:.3e}\t  weighted_average={wa_out:.3e}\t")

    # PyTorch accumulates gradients across backward() calls, so clear them first.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model parameters
    loss.backward()

    # Apply one optimizer update to the model parameters.
    optimizer.step()

In [None]:
# Example input: one random 3x3 matrix in the (batch=1, 1, 3, 3) layout the model expects.
m = torch.rand(1, 1, 3, 3)
print(m)

In [None]:
# Reference answer: the exact inverse of the example (torch.inverse is an alias of torch.linalg.inv).
torch.linalg.inv(m)

In [None]:
# Run the network on the example, then show pred and pred @ m.
# The product should be (approximately) the identity matrix.
pred = model(m)
print(pred, pred @ m, sep="\n")
