In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
from ConvNet import ConvNet
from reluMLP import MLP
from UNET import UNET
# Seed every RNG the notebook imports so a fresh "Restart & Run All" draws the
# same matrices and reproduces the same training run.
# (Assigning a number to `torch.random.seed` seeds nothing: it only replaces
# that function with an int. `torch.manual_seed` is the call that seeds.)
SEED = 1234
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
def generate_matrix(N,eigen_values = None, mu = 1,sigma = 0.2,d = 3,similar = True):
    """Generate N diagonalizable matrices X = M @ diag(eigenvalues) @ inv(M).

    Parameters
    ----------
    N : int
        Number of matrices to generate.
    eigen_values : 1-D tensor or sequence of numbers, optional
        Eigenvalues shared by all N matrices. When given, the matrix
        dimension is taken from its length and `mu`, `sigma` and `d` are
        ignored.
    mu, sigma : float
        When `eigen_values` is None, every eigenvalue is drawn independently
        as ``mu + sigma * u`` with ``u`` uniform on [0, 1).
    d : int
        Matrix dimension (only used when `eigen_values` is None).
    similar : bool
        True  -> one random basis M of shape (d, d) is shared by all matrices.
        False -> every matrix gets its own random basis.

    Returns
    -------
    torch.Tensor
        Tensor of shape (N, 1, d, d).

    Raises
    ------
    ValueError
        If `eigen_values` is given but is not one-dimensional.

    Notes
    -----
    The basis M is drawn with `torch.rand` and is not checked for
    conditioning; `torch.linalg.inv` raises if M is singular. The generated
    matrices are invertible only when all eigenvalues are non-zero.
    """
    if eigen_values is not None:
        # Accept tensors as well as plain sequences, and use the default float
        # dtype so the later matmul with the torch.rand basis has one dtype.
        eig = torch.as_tensor(eigen_values, dtype=torch.get_default_dtype())
        if eig.dim() != 1:
            raise ValueError("eigen_values must be one-dimensional")
        # The eigenvalues fix the dimension; keep d in sync so M matches diag.
        d = eig.shape[0]
        x = eig.repeat(N,1)[:,:,None]                      # (N, d, 1)
    else:
        x = mu*torch.ones((N,d,1)) + sigma*torch.rand(N,d,1)

    # Broadcast the eigenvalues onto the diagonal: (d, d) * (N, 1, d, 1) -> (N, 1, d, d)
    diag = torch.eye(d)*x[:,None]

    # Basis of the similarity transform: one shared basis, or one per matrix.
    if similar:
        M = torch.rand(d,d)
    else:
        M = torch.rand(N,1,d,d)

    return torch.matmul(torch.matmul(M,diag),torch.linalg.inv(M))

In [None]:
# Training parameters
batch_size = 100  # only referenced by the commented-out "full training set" sampling in the training loop
lr = 1e-2
momentum = 0.9
# NOTE(review): `model` must already exist when this cell runs, but no visible
# cell defines it before this point (the training-loop cell assigns
# `model = CNN` further down). Define the network above this cell so the
# notebook survives Restart & Run All.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum = momentum)

# When a new network is created we init empty training logs
loss_log = []                 # training loss, one entry per iteration
test_loss_log = []            # held-out loss, one entry per logging step
weighted_average_log = []     # running mean of the training loss, one entry per iteration
weighted_average = deque(maxlen=100)  # window holding the last 100 training losses


# And store the best results
best_loss = np.inf
# state_dict() returns tensors that share storage with the live parameters,
# so clone them: otherwise the "best" snapshot silently follows every update.
best_model_state_dict = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

In [None]:
## Data sets used to (over)fit the model: a fixed training set and a smaller
## held-out set, both produced by generate_matrix with similar=True
n_train, n_eval = 1000, 100
x_train, x_eval = (generate_matrix(n, similar = True) for n in (n_train, n_eval))

# Quick sanity checks, enable when debugging:
#print(x_train[0])
#print(torch.linalg.inv(x_train)[0])
#model(x_train)

In [None]:
### Training Loop
# Target of the loss: pred @ x should reproduce the 3x3 identity.
# (The name shadows the builtin `id`; it is kept because later cells read it.)
id = torch.eye(3,3)
# NOTE(review): `CNN` is not defined in any visible cell, and `optimizer` was
# built from `model.parameters()` before this rebinding. Confirm that `CNN` is
# that same network object; otherwise optimizer.step() updates a different model.
model = CNN
n_iterations = 10000
log_every = 100  # how often progress is printed and the held-out loss is recorded
for i in range(n_iterations):
    # Re-use the fixed training set on every step (deliberate over-fitting)
    x = x_train #For over-fitting
    #x = generate_matrix(batch_size) #for "full training set"
    pred = model(x)

    # Loss: how well the prediction serves as an inverse, measured as the mean
    # squared entry of pred @ x - I.
    # Normalizing by the mean condition number was tried because training was
    # otherwise hard to converge / very unstable; it is unclear whether that
    # makes sense, so it is left disabled. The condition number is therefore
    # no longer computed on every step; re-enable the next line together with
    # one of the normalized variants.
    #mean_cond = torch.linalg.cond(x).mean()
    loss = (torch.matmul(pred,x) - id).square().mean()#/mean_cond  #LOOK IF THIS SHOULD BE REDEFINED
    #loss = (torch.matmul(pred,x) - id).square().sum((2,3)).mean()/mean_cond #Frobenius norm
    #loss = (pred -torch.linalg.inv(x)).square().mean()

    # Plain float: avoids keeping graph-attached tensors in the logs / best_loss
    loss_value = loss.item()

    # We store the model if it has the lowest loss so far
    # (this is to avoid losing good results during a run that goes wild).
    # state_dict() shares storage with the live parameters, so the tensors are
    # cloned; a plain reference would always equal the *current* weights.
    if loss_value < best_loss:
        best_model_state_dict = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
        best_loss = loss_value
        #print('New Best: ', loss_value)

    # Update the loss trend indicators
    weighted_average.append(loss_value)
    wa_out = np.mean(weighted_average)

    # Update the logs
    weighted_average_log.append(wa_out)
    loss_log.append(loss_value)

    # Print and evaluate on the held-out set every `log_every` iterations
    if i % log_every == 0:
        print(f"It={i}\t loss={loss_value:.3e}\t  weighted_average={wa_out:.3e}\t")
        # Evaluation only: no autograd graph is needed for the held-out loss
        with torch.no_grad():
            test_loss_log.append((torch.matmul(model(x_eval),x_eval) - id).square().mean().item())

    # Reset the gradients: backward() accumulates into .grad, so without this
    # the gradients of earlier iterations would be added to the new ones
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

### Some preliminary results from training. For more eval, see Model_Evaluation.ipynb

In [None]:
## Inspect if the model overfitted/how it trained on the data.
# Inspection only, so no autograd graph is built.
with torch.no_grad():
    pred = model(x_train)
    train_product = torch.matmul(pred, x_train)
#print(torch.inverse(X_data) - pred )
print(pred.shape)
print("Approximation of identity matrix:\n ", train_product) #Should be the identity matrix.
# Evaluated on x_train (the data `pred` was computed from), not on the loop
# variable `x` that happens to be left over from the training cell.
print("loss = " , (train_product - id).square().mean())

fig, ax = plt.subplots()
# 1-based iteration numbers so that iteration 0 is visible on the log x-axis.
# The x positions are derived from the log lengths, so the plot still works
# when the logs do not hold exactly 10000 / 100 entries.
train_iterations = np.arange(1, len(loss_log) + 1)
# The training loop records the test loss once every 100 iterations
test_iterations = 100*np.arange(len(test_loss_log)) + 1
ax.plot(train_iterations, loss_log, label = "Training loss")
ax.plot(test_iterations, test_loss_log, label= "Test loss")

#ax.set_yscale('log')
ax.set_xscale('log')
ax.legend()
fig.suptitle('Loss log', fontsize=18)
ax.set_xlabel('Iteration', fontsize=18)
ylab = ax.set_ylabel('Loss', fontsize=16)

In [None]:
# Let's see how it generalizes to fresh matrices, although they are still
# generated in the same manner as the training data
x_test = generate_matrix(100, similar = True)
# Evaluation only: compute the prediction and pred @ x once, without autograd
with torch.no_grad():
    inv_pred = model(x_test)
    test_product = torch.matmul(inv_pred, x_test)
test_residual = test_product - id
# Squared Frobenius norm of (pred @ x - I), summed over the last two axes
test_loss = test_residual.square().sum((2,3)).detach().numpy()

print("Approximation of identity matrix:\n ",test_product[1])
print("loss = ", test_residual.square().mean().item())

# Distance to the numerically computed inverse
print("MSE = ", (torch.linalg.inv(x_test) - inv_pred).square().mean().item())
fig, ax = plt.subplots()
ax.scatter(torch.linalg.cond(x_test).numpy(), test_loss)
ax.set_yscale('log')
ax.set_xscale('log')
fig.suptitle('Preliminary evaluation', fontsize=18)
ax.set_xlabel('Condition number', fontsize=18)
ylab = ax.set_ylabel('Test loss', fontsize=16)

In [None]:
#torch.save(model.state_dict(), "CNN_mu1sigma02_similar.pt")
#torch.save(model.state_dict(),"MLP_mu1sigma02_similar.pt")