In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from collections import deque

In [84]:
#Simple MLP
class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()
        #self.flatten = nn.Flatten(start_dim=2)
        #self.unflatten = nn.Unflatten(1,(3,3))
        self.relu_stack = nn.Sequential(
            nn.ReLU(),
            nn.ReLU(),
            nn.ReLU(),
            nn.ReLU(),
            nn.ReLU(),
            nn.Sigmoid(),
        )
    def forward(self, x):
        #x = self.flatten(x)
        x = self.relu_stack(x)
        #x = self.unflatten(x)
        return x

model = MLP()
print(model)

MLP(
  (flatten): Flatten(start_dim=2, end_dim=-1)
  (unflatten): Unflatten(dim=-1, unflattened_size=(3, 3))
  (relu_stack): Sequential(
    (0): Linear(in_features=9, out_features=9, bias=True)
    (1): ReLU()
    (2): Sigmoid()
    (3): ReLU()
    (4): Sigmoid()
    (5): ReLU()
    (6): Sigmoid()
    (7): ReLU()
    (8): Sigmoid()
    (9): ReLU()
    (10): Sigmoid()
    (11): Linear(in_features=9, out_features=9, bias=True)
  )
)


tensor([[[[-0.4223, -0.3669,  0.5406],
          [ 0.3375, -0.0517,  0.3179],
          [ 0.8510,  0.1866, -0.0635]]]], grad_fn=<ViewBackward0>)

In [85]:
# Training parameters
batch_size = 100
lr = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# When a new network is created we init empty training logs
loss_log = []
weighted_average_log = []
weighted_average = deque([], maxlen=100)


# And store the best results
best_loss = np.inf
best_model_state_dict = model.state_dict()

In [86]:
### Training Loop
for i in range(10000):

    # create random matrices
    x = torch.rand((batch_size,1,3,3))

    pred = model(x)
    print(pred.shape())
    id = torch.eye(3,3)
    inv_x = torch.inverse(x)

    # Compute loss between network and numerical solution
    loss = (pred*x - id).square().mean()

    # We store the model if it has the lowest fitness
    # (this is to avoid losing good results during a run that goes wild)
    if loss < best_loss:
        best_model_state_dict = model.state_dict()
        best_loss = loss
        print('New Best: ', loss.item())

    # Update the loss trend indicators
    weighted_average.append(loss.item())

    # Update the logs
    weighted_average_log.append(np.mean(weighted_average))
    loss_log.append(loss.item())

    # Print every i iterations
    if i % 100 == 0:
        wa_out = np.mean(weighted_average)
        print(f"It={i}\t loss={loss.item():.3e}\t  weighted_average={wa_out:.3e}\t")

    # Zeroes the gradient (necessary because of things)
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

New Best:  0.3415718376636505
It=0	 loss=3.416e-01	  weighted_average=3.416e-01	
New Best:  0.33730068802833557
New Best:  0.31803134083747864
New Best:  0.3142189681529999
New Best:  0.3023819327354431
New Best:  0.29980066418647766
New Best:  0.2942497432231903
New Best:  0.293241024017334
New Best:  0.2829401195049286
New Best:  0.28268536925315857
New Best:  0.2685202956199646
New Best:  0.262570858001709
New Best:  0.26035648584365845
New Best:  0.25847604870796204
New Best:  0.2578469514846802
New Best:  0.254860520362854
New Best:  0.2531417906284332
New Best:  0.2486710548400879
New Best:  0.24754352867603302
New Best:  0.24669446051120758
It=100	 loss=2.532e-01	  weighted_average=2.713e-01	
New Best:  0.24440936744213104
New Best:  0.24343113601207733
New Best:  0.2430737167596817
New Best:  0.24216103553771973
New Best:  0.24138802289962769
New Best:  0.2374332696199417
New Best:  0.23700568079948425
It=200	 loss=2.531e-01	  weighted_average=2.476e-01	
New Best:  0.2336723357

In [90]:
# Let's look at an example
m = torch.rand((1,1,3,3))
print(m)

tensor([[[[0.1449, 0.1482, 0.3835],
          [0.8750, 0.3568, 0.0071],
          [0.6605, 0.5584, 0.1736]]]])


In [91]:
torch.inverse(m)

tensor([[[[ 0.6938,  2.2544, -1.6248],
          [-1.7613, -2.7300,  4.0029],
          [ 3.0261,  0.2033, -0.9330]]]])

In [92]:
model(m)


tensor([[[[0.1996, 0.1979, 0.2019],
          [0.1992, 0.1996, 0.2007],
          [0.1992, 0.1987, 0.1985]]]], grad_fn=<ViewBackward0>)
