In [139]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from collections import deque

In [140]:
def generate_matrix(N, d=3, tensor=True, similar=True):
    """Generate N random diagonalizable d x d matrices with eigenvalues near 2.

    Every matrix is built as ``M @ diag(eigenvalues) @ inv(M)``. The
    eigenvalues are drawn as ``2 - 0.5 * randn`` so that they are of
    reasonable size and close to each other, which is intended to keep the
    matrices invertible.

    Parameters
    ----------
    N : int
        Number of matrices to generate.
    d : int, default 3
        Dimension of each (square) matrix.
    tensor : bool, default True
        If True, return a ``torch.Tensor``; otherwise a float64 numpy array.
    similar : bool, default True
        If True, all N matrices share one random eigenvector basis ``M``.
        If False, every matrix gets its own independent random basis.

    Returns
    -------
    torch.Tensor or numpy.ndarray
        Array of shape ``(N, 1, d, d)``. The singleton axis comes from the
        broadcast used to build the diagonal matrices.
    """
    # Eigenvalues of reasonable size, close to each other: shape (N, d, 1).
    eigenvalues = 2 * np.ones((N, d, 1), dtype=np.float64) - 0.5 * np.random.randn(N, d, 1)

    # Broadcast against the identity to get diagonal matrices: shape (N, 1, d, d).
    diag = np.eye(d) * eigenvalues[:, np.newaxis]

    if similar:
        # One basis shared by every matrix; broadcasts over the leading axes.
        M = np.random.randn(d, d)
    else:
        # One independent basis per matrix, shaped to line up with `diag`.
        M = np.random.randn(N, 1, d, d)

    # np.matmul (unlike np.dot) treats the leading axes as a batch, so this is
    # a per-matrix similarity transform in both cases.
    X = np.matmul(np.matmul(M, diag), np.linalg.inv(M))

    if tensor:
        X = torch.Tensor(X)
    return X


In [141]:
#Simple MLP
class MLP(nn.Module):
    """Fully connected network mapping 3x3 matrices to 3x3 matrices.

    The last two axes of the input are flattened into 9 features, passed
    through eight ``Linear(9, 9)`` layers with a ReLU between consecutive
    layers (none after the last one), and reshaped back to 3x3.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten(start_dim=2)
        self.unflatten = nn.Unflatten(-1, (3, 3))

        # Seven Linear+ReLU pairs followed by a final Linear without activation.
        layers = []
        for _ in range(7):
            layers.append(nn.Linear(9, 9))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(9, 9))
        self.relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten the trailing 3x3 axes, apply the stack, restore 3x3."""
        flat = self.flatten(x)
        hidden = self.relu_stack(flat)
        return self.unflatten(hidden)


model = MLP()


In [168]:

class UNET(nn.Module):
    """Single-level convolutional encoder/decoder.

    Only one contracting block and one expanding block are active; the deeper
    levels and the skip connections are commented out, so the forward pass is
    simply ``upconv1(conv1(x))``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = self.contract_block(in_channels, 32, 3, 1)
        # Deeper levels, currently disabled:
        # self.conv2 = self.contract_block(32, 64, 3, 1)
        # self.conv3 = self.contract_block(64, 128, 3, 1)
        # self.upconv3 = self.expand_block(128, 64, 3, 1)
        # self.upconv2 = self.expand_block(64, 32, 3, 1)
        self.upconv1 = self.expand_block(32, out_channels, 3, 1)

    def forward(self, x):
        """Run the contracting block, then the expanding block."""
        # Downsampling part.
        encoded = self.conv1(x)
        # conv2 = self.conv2(conv1)
        # conv3 = self.conv3(conv2)

        # Upsampling part (skip connections disabled).
        # upconv3 = self.upconv3(conv3)
        # upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))
        # upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))
        return self.upconv1(encoded)

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Two Conv-BatchNorm-ReLU stages followed by a 2x2 max-pool."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # padding=1 lets the pool cover odd-sized inputs (3x3 -> 2x2).
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
        ]
        return nn.Sequential(*layers)

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Two Conv-BatchNorm-ReLU stages followed by a stride-2 transposed conv."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # Mirrors the padded max-pool of contract_block (2x2 -> 3x3).
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2, padding=1, output_padding=1),
        ]
        return nn.Sequential(*layers)


model = UNET(1, 1)
print(model)

UNET(
  (conv1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (upconv1): Sequential(
    (0): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(1, 1, kernel_size=(2, 2), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  )
)


In [169]:
# Training parameters
batch_size = 100
lr = 1e-2
momentum = 0.9
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# When a new network is created we init empty training logs
loss_log = []
weighted_average_log = []
# Sliding window over the most recent 100 losses, used for the running mean.
weighted_average = deque([], maxlen=100)

# And store the best results.
best_loss = np.inf
# state_dict() hands back references to the live parameter tensors, so the
# values are cloned here to obtain a real snapshot that later optimizer steps
# cannot modify. The resulting dict can be passed to model.load_state_dict().
best_model_state_dict = {
    name: value.detach().clone() for name, value in model.state_dict().items()
}

In [170]:
## Training data to (over)fit the model: one fixed matrix.
# `x` is the name the training loop reads; `X_data` is reused by the test cell.
x = X_data = generate_matrix(1)

# Sanity check: forward pass through the freshly created network.
print(model(x))


tensor([[[[-0.0995, -0.0995, -0.0995],
          [-0.1553,  0.1004, -0.2278],
          [-0.0988,  0.1621, -0.0978]]]],
       grad_fn=<SlowConvTranspose2DBackward0>)


In [171]:
### Training Loop
# Target of `pred @ x`. NOTE: this name shadows the builtin `id`, but later
# cells refer to it, so it is kept.
id = torch.eye(3,3)
for i in range(10000):
    # create random matrices
    #x = generate_matrix(batch_size, similar = False)
    pred = model(x)

    # Loss measures how well the prediction acts as a left inverse of x:
    # pred @ x should equal the identity.
    loss = (torch.matmul(pred, x) - id).square().mean()
    loss_value = loss.item()

    # We store the model if it has the lowest loss so far
    # (this is to avoid losing good results during a run that goes wild).
    if loss_value < best_loss:
        # state_dict() returns references to the live tensors; clone them so
        # the snapshot is not overwritten by the optimizer steps that follow.
        best_model_state_dict = {
            name: value.detach().clone() for name, value in model.state_dict().items()
        }
        # Keep a plain float so the autograd graph of this step is not retained.
        best_loss = loss_value
        #print('New Best: ', loss_value)

    # Update the loss trend indicator (sliding window of recent losses)
    weighted_average.append(loss_value)
    wa_out = np.mean(weighted_average)

    # Update the logs
    weighted_average_log.append(wa_out)
    loss_log.append(loss_value)

    # Print every 1000 iterations
    if i % 1000 == 0:
        print(f"It={i}\t loss={loss_value:.3e}\t  weighted_average={wa_out:.3e}\t")

    # Reset the gradients; PyTorch accumulates them across backward() calls
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

It=0	 loss=3.523e-01	  weighted_average=3.523e-01	
It=1000	 loss=1.905e-02	  weighted_average=1.905e-02	
It=2000	 loss=1.905e-02	  weighted_average=1.905e-02	
It=3000	 loss=1.905e-02	  weighted_average=1.905e-02	
It=4000	 loss=1.905e-02	  weighted_average=1.905e-02	
It=5000	 loss=1.905e-02	  weighted_average=1.905e-02	
It=6000	 loss=1.905e-02	  weighted_average=1.905e-02	
It=7000	 loss=1.905e-02	  weighted_average=1.905e-02	
It=8000	 loss=1.905e-02	  weighted_average=1.905e-02	
It=9000	 loss=1.905e-02	  weighted_average=1.905e-02	


In [172]:
## Test to see if the model overfitted
pred = model(X_data)

# Evaluate on X_data throughout (not on the loop variable `x`, which may have
# been rebound), so the printed matrix and the printed loss always describe
# the same input.
identity_approx = torch.matmul(pred, X_data)
print("Approximation of identity matrix:\n ", identity_approx) #Should be the identity matrix.
print("loss = " , (identity_approx - id).square().mean())

Approximation of identity matrix:
  tensor([[[[ 1.0168,  0.1736,  0.1448],
          [ 0.0377,  1.0086,  0.0271],
          [-0.2741, -0.0629,  0.8030]]]], grad_fn=<UnsafeViewBackward0>)
loss =  tensor(0.0191, grad_fn=<MeanBackward0>)


In [152]:
# Let's see how it generalizes to new matrices, still generated in the same manner
x_test = generate_matrix(2, similar = True)
inv_pred = model(x_test)

# pred @ x should be close to the identity if pred approximates the inverse.
identity_approx = torch.matmul(inv_pred, x_test)
print("Approximation of identity matrix:\n ", identity_approx)
print("loss = ", (identity_approx - id).square().mean().item())

# Direct comparison against the numerically computed inverse.
inverse_error = torch.linalg.inv(x_test) - inv_pred
print("MSE = ", inverse_error.square().mean().item())

Approximation of identity matrix:
  tensor([[[[ 1.4844,  0.1928,  1.8354],
          [ 0.0063,  1.1366, -0.5787],
          [-0.0851, -0.1782,  1.0838]]],


        [[[ 1.2463, -0.1116,  0.7568],
          [ 0.0864,  1.2187,  1.1738],
          [-0.0723, -0.0809,  0.9094]]]], grad_fn=<UnsafeViewBackward0>)
loss =  0.3410756289958954
MSE =  0.06486474722623825


In [None]:
# What if we change the structure even more?
# Entries drawn uniformly from [0, 1) instead of via generate_matrix.
x_test = torch.Tensor(np.random.rand(1, 1, 3, 3))
print(x_test)
inv_pred = model(x_test)

# pred @ x should be close to the identity if pred approximates the inverse.
identity_approx = torch.matmul(inv_pred, x_test)
print("Approximation of identity matrix:\n ", identity_approx)
print("loss = ", (identity_approx - id).square().mean())

# Direct comparison against the numerically computed inverse.
inverse_error = torch.linalg.inv(x_test) - inv_pred
print("MSE = ", inverse_error.square().mean())