In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from collections import deque

In [None]:
class Net(nn.Module):
    """Six-layer fully convolutional network used to regress a matrix inverse.

    Every layer is a 3x3 convolution with ``padding="same"``, so an input of
    shape ``(N, 1, H, W)`` produces an output of shape ``(N, 1, H, W)``.
    The five hidden layers have 64 channels and ReLU activations; the output
    layer is linear.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, padding="same")
        self.conv2 = nn.Conv2d(64, 64, 3, padding="same")
        self.conv3 = nn.Conv2d(64, 64, 3, padding="same")
        self.conv4 = nn.Conv2d(64, 64, 3, padding="same")
        self.conv5 = nn.Conv2d(64, 64, 3, padding="same")
        self.conv6 = nn.Conv2d(64, 1, 3, padding="same")

    def forward(self, x):
        """Run the network on ``x`` (shape ``(N, 1, H, W)``) and return a same-shaped tensor."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        # No activation on the output layer: the regression target (a matrix
        # inverse) contains negative entries, which a final ReLU would clamp
        # to zero and make impossible to fit.
        return self.conv6(x)

model = Net()
print(model)

In [None]:
# Training parameters
batch_size = 100
lr = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# When a new network is created we init empty training logs
loss_log = []                             # loss of every iteration
weighted_average_log = []                 # running mean of the loss at every iteration
weighted_average = deque([], maxlen=100)  # sliding window holding the last 100 losses


# And store the best results.
# state_dict() returns references to the live parameter tensors, so the
# values are cloned; otherwise this "snapshot" would silently follow every
# later optimizer step.
best_loss = np.inf
best_model_state_dict = {
    name: tensor.detach().clone()
    for name, tensor in model.state_dict().items()
}

In [None]:
### Training Loop
for i in range(10000):

    # Create a fresh batch of random 3x3 matrices (entries uniform in [0, 1))
    x = torch.rand((batch_size,1,3,3))

    pred = model(x)

    # Numerical reference solution
    inv_x = torch.inverse(x)

    # Mean absolute error between network and numerical solution,
    # scaled by the largest entry found in this batch of inverses
    loss = (pred - inv_x).abs().mean() / inv_x.max()
    loss_value = loss.item()

    # We store the model if it has the lowest loss so far
    # (this is to avoid losing good results during a run that goes wild).
    # The tensors are cloned because state_dict() hands back references to the
    # live parameters, which optimizer.step() below would otherwise overwrite.
    if loss < best_loss:
        best_model_state_dict = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        # Detach so the stored best loss does not hold on to the autograd graph
        best_loss = loss.detach()
        print('New Best: ', loss_value)

    # Update the loss trend indicator (mean over the last 100 iterations)
    weighted_average.append(loss_value)
    wa_out = np.mean(weighted_average)

    # Update the logs
    weighted_average_log.append(wa_out)
    loss_log.append(loss_value)

    # Print every 100 iterations
    if i % 100 == 0:
        print(f"It={i}\t loss={loss_value:.3e}\t  weighted_average={wa_out:.3e}\t")

    # Reset the gradients: PyTorch accumulates them across backward() calls
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

In [None]:
# Example: draw one random 3x3 matrix, laid out as (batch, channel, height, width)
m = torch.rand(1, 1, 3, 3)
print(m)

In [None]:
# Numerical inverse of the example matrix, to compare against the network output
m.inverse()

In [None]:
# Network prediction for the example matrix. Gradients are not needed for
# inspection, so skip building the autograd graph.
with torch.no_grad():
    m_pred = model(m)
m_pred