In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt

In [None]:
def generate_matrix(N,mu = 1,sigma = 0.2,d = 3,similar = True, seed = None):
    """Generate N random d x d matrices with controlled eigenvalues.

    Each matrix is X = M @ diag(lam) @ inv(M): the eigenvalues `lam` are drawn
    i.i.d. from Normal(mu, sigma**2), and the similarity transform by a random
    Gaussian basis M preserves them. With the defaults (mu=1, sigma=0.2), the
    eigenvalues are close to each other and, with overwhelming probability,
    away from zero, so the matrices are invertible in practice.

    Parameters
    ----------
    N : int
        Number of matrices.
    mu, sigma : float
        Mean and standard deviation of the eigenvalues.
    d : int
        Matrix dimension.
    similar : bool
        If True, all N matrices of this call share one basis M. A new M is drawn
        on every call, so two calls do not share a basis. If False, every
        matrix gets its own basis.
    seed : None, int or numpy.random.Generator
        None (default) draws from numpy's global RNG, exactly as before, so
        ``np.random.seed(...)`` controls it. An int or a Generator makes the
        draw reproducible independently of the global state.

    Returns
    -------
    torch.Tensor
        Tensor of shape (N, 1, d, d) in torch's default float dtype. The
        singleton axis acts as a channel axis for the models in this notebook.
    """
    # NOTE: the stdlib random.seed(...) does NOT seed numpy; use np.random.seed or `seed`.
    rng = np.random if seed is None else np.random.default_rng(seed)

    # Eigenvalues: shape (N, d, 1), i.i.d. Normal(mu, sigma^2).
    eigvals = mu*np.ones((N,d,1), dtype = np.float64) + sigma*rng.standard_normal((N,d,1))

    # (d, d) * (N, 1, d, 1) broadcasts to (N, 1, d, d) with diag[n, 0] = diag(eigvals[n, :, 0]).
    diag = np.eye(d)*eigvals[:,np.newaxis]

    # Basis for the similarity transform.
    if similar:
        # One basis shared by every matrix generated in this call.
        M = rng.standard_normal((d,d))
    else:
        # An independent basis per matrix (broadcast over the N, 1 axes).
        M = rng.standard_normal((N,1,d,d))
    X = M @ diag @ np.linalg.inv(M)

    return torch.Tensor(X)

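### Comment on matrix generation
It may be better to start from deterministic eigenvalues, to have complete control over the spectrum, and to introduce diversity in the data set only through the similarity transform.
Alternatively, and especially for similar matrices, one could reuse the same fixed transformation matrix every time matrices are generated. At present, `similar=True` shares the basis only within a single call to `generate_matrix`, and each new call draws a fresh basis. A fixed basis would therefore likely make online sampling (a new batch every iteration) easier.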



In [None]:
#Simple MLP
class MLP(nn.Module):
    """Fully connected network mapping each d x d matrix to a d x d matrix.

    Expects a 4-D input of shape (N, C, d, d), e.g. the (N, 1, d, d) output of
    generate_matrix. The last two axes are flattened to d*d features and passed
    through `n_layers` Linear(d*d, d*d) layers, with ReLU between consecutive
    layers and none after the last. The result is reshaped back to (N, C, d, d).

    Parameters
    ----------
    d : int
        Matrix dimension (default 3).
    n_layers : int
        Number of Linear layers (default 8). The defaults build exactly the
        original hand-written stack. The state_dict keys
        ``relu_stack.{0,2,...,14}.weight/bias`` and the init order are
        unchanged, so saved checkpoints still load.

    Raises
    ------
    ValueError
        If n_layers < 1.
    """
    def __init__(self, d=3, n_layers=8):
        super(MLP, self).__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        width = d * d
        self.flatten = nn.Flatten(start_dim=2)
        self.unflatten = nn.Unflatten(-1,(d,d))
        # Linear, ReLU, Linear, ..., Linear -- same ordering as the original Sequential.
        layers = [nn.Linear(width, width)]
        for _ in range(n_layers - 1):
            layers += [nn.ReLU(), nn.Linear(width, width)]
        self.relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu_stack(x)
        return self.unflatten(x)
model = MLP()

In [None]:
#Simple CNN
class ConvNet(nn.Module):
    """Six-layer CNN that preserves spatial size: (N, 1, H, W) -> (N, 1, H, W).

    Every convolution has a 3x3 kernel with padding="same", so the spatial size
    is unchanged. Conv layers 1-4 are each followed by ReLU and then
    BatchNorm2d, conv 5 by ReLU only, and conv 6 is the linear output layer.

    Parameters
    ----------
    channels : int
        Width of the hidden layers (default 64, as before).
    shared_norm : bool
        Legacy layout. The original version reused ONE BatchNorm2d after all
        four hidden layers. That shares the affine parameters between layers,
        and its running mean/variance are updated four times per forward pass
        with different layers' activations. In eval mode, the layers are then
        normalised with statistics that do not match them. The default (False)
        gives each hidden layer its own BatchNorm2d. Pass True only to load
        checkpoints saved with the old layout, which has a single ``norm.*``
        state_dict entry.
    """
    def __init__(self, channels=64, shared_norm=False):
        super(ConvNet, self).__init__()
        self.shared_norm = shared_norm
        # Norms are registered before the convs, as in the original, so parameter order
        # (and therefore optimizer state ordering) is unchanged in legacy mode.
        if shared_norm:
            self.norm = nn.BatchNorm2d(channels)
        else:
            self.norms = nn.ModuleList([nn.BatchNorm2d(channels) for _ in range(4)])
        self.conv1 = nn.Conv2d(1, channels, 3,padding="same")
        self.conv2 = nn.Conv2d(channels, channels, 3,padding="same")
        self.conv3 = nn.Conv2d(channels, channels, 3,padding="same")
        self.conv4 = nn.Conv2d(channels, channels, 3,padding="same")
        self.conv5 = nn.Conv2d(channels, channels, 3,padding="same")
        self.conv6 = nn.Conv2d(channels, 1, 3,padding="same")

    def _norm(self, k):
        """BatchNorm applied after hidden conv layer k+1 (k in 0..3)."""
        return self.norm if self.shared_norm else self.norms[k]

    def forward(self, x):
        for k, conv in enumerate((self.conv1, self.conv2, self.conv3, self.conv4)):
            x = self._norm(k)(F.relu(conv(x)))
        x = F.relu(self.conv5(x))
        return self.conv6(x)

#model = ConvNet()

In [None]:
#UNET-style encoder/decoder reduced to a single level (per the original note: not really working for 3x3).
class UNET(nn.Module):
    """One contracting block followed by one expanding block, with no skip connections.

    Shape trace for a 3x3 input:
    - The contracting block keeps 3x3 through its two padded convolutions.
    - MaxPool2d(kernel 2, stride 2, padding 1) reduces the input to 2x2.
    - The expanding block's ConvTranspose2d(kernel 2, stride 2, padding 1,
      output_padding 1) maps 2x2 back to 3x3.

    Deeper levels (64/128 channels with concatenated skip connections) were
    sketched, but they are not used.

    Parameters
    ----------
    in_channels, out_channels : int
        Number of channels of the input and of the output.
    """
    def __init__(self, in_channels, out_channels):
        super(UNET, self).__init__()
        hidden_channels = 32
        self.conv1 = self.contract_block(in_channels, hidden_channels, 3, 1)
        self.upconv1 = self.expand_block(hidden_channels, out_channels, 3, 1)

    def forward(self, x):
        encoded = self.conv1(x)        # downsampling path
        return self.upconv1(encoded)   # upsampling path

    @staticmethod
    def _double_conv(in_channels, out_channels, kernel_size, padding):
        """Two stride-1 (Conv2d -> BatchNorm2d -> ReLU) stages, returned as a flat list."""
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                torch.nn.Conv2d(stage_in, out_channels, kernel_size, stride=1, padding=padding),
                torch.nn.BatchNorm2d(out_channels),
                torch.nn.ReLU(),
            ]
        return stages

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Double convolution followed by max pooling (kernel 2, stride 2, padding 1)."""
        return nn.Sequential(
            *self._double_conv(in_channels, out_channels, kernel_size, padding),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
        )

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Double convolution followed by a stride-2 transposed convolution."""
        return nn.Sequential(
            *self._double_conv(in_channels, out_channels, kernel_size, padding),
            torch.nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2, padding=1, output_padding=1),
        )
#model = UNET(1,1)
#print(model)

In [None]:
# Training parameters
batch_size = 100   # not used by the fixed-training-set (over-fitting) setup
lr = 1e-1
momentum = 0.9
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum = momentum)

# When a new network is created we init empty training logs
loss_log = []
weighted_average_log = []
# Despite the name, this is an unweighted moving average over the last 100 losses.
weighted_average = deque([], maxlen=100)

# And store the best results.
best_loss = np.inf
# model.state_dict() returns references to the live parameter tensors, which
# optimizer.step() updates in place; clone them so this is a real snapshot.
best_model_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}

In [None]:
## Fixed training set used to (over)fit the model to a family of similar matrices
x_train = generate_matrix(1000, similar = True)  # shape (1000, 1, 3, 3)
x_zero = torch.zeros(1,1,3,3)  # all-zero probe input; not used by the cells below
first_sample, first_sample_inv = x_train[0], torch.linalg.inv(x_train)[0]
print(first_sample)
print(first_sample_inv)

In [None]:
### Training Loop
# NOTE: `id` shadows Python's built-in id(); the name is kept because later cells may use it.
id = torch.eye(3,3)

for i in range(10000):
    x = x_train  # over-fitting: the same fixed training set every iteration
    #x = generate_matrix(batch_size) #for "full training set"
    pred = model(x)

    # Loss: how far pred @ x is from the identity, i.e. how well pred acts as an inverse of x.
    loss = (torch.matmul(pred,x) - id).square().mean()
    # Alternatives tried (LOOK IF THE LOSS SHOULD BE REDEFINED):
    #   Frobenius norm per matrix: (torch.matmul(pred,x) - id).square().sum((2,3)).mean()
    #   regression on the true inverse: (pred - torch.linalg.inv(x)).square().mean()
    #   The earlier note says convergence was difficult / very unstable without dividing the
    #   loss by torch.linalg.cond(x).mean(). That normalisation is currently disabled, so the
    #   condition number is not computed here (it costs one SVD per matrix per iteration).
    loss_value = loss.item()

    # Keep a snapshot of the parameters with the lowest loss seen so far, so good results
    # survive a run that diverges. The loss was computed with the current parameters,
    # before the optimizer step below, so the snapshot matches the loss. Tensors are
    # cloned because state_dict() returns references that optimizer.step() mutates in place.
    if loss_value < best_loss:
        best_model_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
        best_loss = loss_value  # a float, so the autograd graph is not kept alive
        #print('New Best: ', loss_value)

    # Update the moving average and the logs
    weighted_average.append(loss_value)
    weighted_average_log.append(np.mean(weighted_average))
    loss_log.append(loss_value)

    # Print every 100 iterations
    if i % 100 == 0:
        wa_out = np.mean(weighted_average)
        print(f"It={i}\t loss={loss_value:.3e}\t  weighted_average={wa_out:.3e}\t")

    # Gradients accumulate across backward() calls, so clear them before each step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

### Preliminary training results
For a more thorough evaluation, see `Model_Evaluation.ipynb`.

In [None]:
## Inspect how well the model fits (over-fits) the training data.
# Uses x_train explicitly. The previous version used the training-loop variable `x`,
# which breaks (shape mismatch) as soon as the loop samples fresh batches.
with torch.no_grad():
    pred = model(x_train)
    identity_approx = torch.matmul(pred, x_train)  # should be close to the identity
    train_loss = (identity_approx - torch.eye(x_train.shape[-1])).square().mean()
print(pred.shape)
# Show one product instead of dumping the whole batch.
print("Approximation of identity matrix:\n ", identity_approx[0])
print("loss = " , train_loss)

fig, ax = plt.subplots()
# Use 1-based iterations: x=0 cannot be drawn on a log-scaled x-axis.
ax.plot(np.arange(1, len(loss_log) + 1), loss_log)
#ax.set_yscale('log')
ax.set_xscale('log')
fig.suptitle('Loss log', fontsize=18)
ax.set_xlabel('Iteration', fontsize=18)
ax.set_ylabel('Training loss', fontsize=16);

In [None]:
# Check generalisation on fresh matrices drawn the same way.
# NOTE: each generate_matrix call draws a new basis, so with similar=True these test
# matrices share a basis with each other but NOT with x_train.
x_test = generate_matrix(100, similar = True)
with torch.no_grad():
    inv_pred = model(x_test)
    identity_approx = torch.matmul(inv_pred, x_test)  # computed once, reused below
    residual = identity_approx - torch.eye(x_test.shape[-1])
    # Squared Frobenius norm of (inv_pred @ x - I) per matrix, shape (100, 1)
    test_loss = residual.square().sum((2,3)).numpy()

print("Approximation of identity matrix:\n ", identity_approx[1])
print("loss = ", residual.square().mean().item())
print("MSE = ", (torch.linalg.inv(x_test) - inv_pred).square().mean().item())

fig, ax = plt.subplots()
ax.scatter(torch.linalg.cond(x_test).numpy(), test_loss)
ax.set_yscale('log')
ax.set_xscale('log')
fig.suptitle('Preliminary evaluation', fontsize=18)
ax.set_xlabel('Condition number', fontsize=18)
ax.set_ylabel('Test loss', fontsize=16);

In [None]:
#torch.save(model.state_dict(), "CNN_mu1sigma02_similar.pt")
#torch.save(model.state_dict(),"MLP_mu1sigma02_similar.pt")