In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt

#Needed for training and evaluation
from losses import *
from RandomMatrixDataSet import get_sample,RandomMatrixDataSet
from train import train_on_batch, run_training
from evaluation import *
from plotting import plot_loss_logs, error_histogram, plot_mean_identity_approx

#Seed and looks
# Seed every random number generator imported by this notebook so that a
# fresh "Restart & Run All" starts from the same random state.
# (Assigning an int to `torch.random.seed` only replaces that function with a
# number and seeds nothing, so the explicit seeding calls are used instead.)
SEED = 1234
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# Global matplotlib defaults: figure size plus LaTeX-rendered serif text
# (Palatino) for every figure created afterwards.
plt.rcParams.update({
    "figure.figsize": [14, 6],
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Palatino"],
})

#Models
from ConvNet import EigConvNet
from MLP import EigMLP
from nerf import EigNERF
from siren import EigSiren

In [None]:
#Define operation
# Linear-algebra routine stored under the "operation" key of the parameter
# dicts built further down (training and test data).
# NOTE(review): presumably used by the data set to compute the labels; only
# the real part of the labels is used at evaluation time — confirm.
operation = torch.linalg.eig

def rayleigh_quotient(pred,x):
    """Rayleigh quotient ``v A v^T / ||v||^2`` of row vectors ``v`` against matrices ``A``.

    NOTE(review): this helper was originally flagged as "not correct yet".
    This version fixes the normalisation, which used the norm of the *whole*
    ``pred`` tensor (all samples together) instead of the norm of each vector.
    The expected input shapes are still an assumption — confirm against the
    models before relying on it.

    Args:
        pred: tensor with (at least) four dimensions whose last dimension
            holds the vector entries, e.g. ``(N, 1, 1, d)`` — TODO confirm.
        x: tensor with the matrices in dimensions 2 and 3, e.g.
            ``(N, 1, d, d)`` — TODO confirm.

    Returns:
        Tensor ``pred @ x @ pred^T`` divided by the squared norm of each row
        of ``pred``; for the shapes above this is ``(N, 1, 1, 1)``.
    """
    # Squared Euclidean norm of every row vector; the kept singleton dimension
    # lets it broadcast against the product below.
    norm = torch.linalg.norm(pred, dim=-1, keepdim=True)**2
    # (pred @ x^T)^T == x @ pred^T, hence the outer product is pred @ x @ pred^T.
    return torch.matmul(pred,torch.matmul(pred,x.transpose(2,3)).transpose(2,3))/norm

### Definitions and instantiation 

In [None]:
# Matrix dimension (the matrices are d x d)
d = 4

# Hyperparameters handed to the model constructors in the next cell
output_features = d
hidden_layers, hidden_features = 3, 100
filters, kernel_size = 32, 3
skip = [2, 4, 6]

In [None]:
#Instantiate some example models
# One instance per architecture; only the one selected as `model` is trained.
# NOTE(review): the meaning of each positional argument is defined in the
# project modules (MLP, ConvNet, siren, nerf) and is not visible here;
# d**2 is presumably the number of entries of a d x d matrix — confirm.
reluMLP = EigMLP(d,d**2,hidden_layers, hidden_features)
CNN = EigConvNet(d,hidden_layers,filters,kernel_size)
SIREN = EigSiren(d,hidden_features, hidden_layers)
nerf = EigNERF(d,d**2,skip,hidden_features)

In [None]:
# Choose model
# The selected instance is the one whose parameters are given to the optimizer
# and that is passed to run_training below.
# NOTE: the figure title in the plotting cell hard-codes "NERF" and its layer
# count; adjust it when a different model is selected here.
model = nerf

In [None]:
# Training parameters
lr = 3e-4        # learning rate handed to Adam
momentum = 0.9   # defined here but not passed to the Adam optimizer below
k = 2000         #Training iterations
loss_fcn = eigval_error

# Optimizer over the selected model's parameters; with gamma=1 the exponential
# scheduler keeps the learning rate constant.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=1)

### Run training

In [None]:
batch_size = 100
epoch = 10
# Data settings: batches of `batch_size` random symmetric d x d matrices.
matrix_parameters = dict(
    N=batch_size,
    d=d,
    operation=operation,
    det=False,
    det_channel=False,
    symmetric=True,
)

#Run training
(trained_model,
 loss_log,
 weighted_average_log,
 eval_loss_log,
 eval_set) = run_training(
    k, model, loss_fcn, optimizer, matrix_parameters, scheduler, epoch
)

In [None]:
#Generate test set and predictions on test set
eval_size = 1000
# Same data settings as for training, but with eval_size matrices.
test_parameters = {
    "N": eval_size,
    "d": d,
    "operation": operation,
    "det": False,
    "det_channel": False,
    "symmetric": True,
}

test_set = get_test_set(test_parameters)
test_set.compute_labels()

# Reference values: real part of the first label entry, sorted in ascending
# order along dimension 2.
# NOTE(review): Y[0] is presumably the eigenvalue part of the `operation`
# output — confirm against RandomMatrixDataSet.
eigvals = torch.sort(torch.real(test_set.Y[0]),2)[0]

# Pure inference: no autograd graph is needed for the test predictions, so
# skip gradient tracking (saves memory; the predicted values are unchanged).
with torch.no_grad():
    predicted_eigvals = trained_model(test_set.X)

### Plot some results

In [None]:
#Will add this to plotting function eventually
# Per-sample squared errors of every predicted eigenvalue and the mean test
# error as reported by eigval_error.
m = {"errors": (predicted_eigvals- eigvals).square().squeeze().detach().numpy()[:,:],
     "mean_error": eigval_error(predicted_eigvals,eigvals).mean().detach().numpy()}


plt.rcParams['figure.figsize'] = [18, 12]
fig = plt.figure()
spec = fig.add_gridspec(2, 2)
# NOTE: the model name and the layer count (8) in the title are hard-coded for
# the NERF model; update them when another model is selected.
fig.suptitle("Eigenvalue approximation of $ {} \\times {}$ random symmetric matrices \n Model: NERF  ({} layers, {} neurons)".format(d,d,8,hidden_features), fontsize = 25)
#fig.suptitle("Eigenvalue approximation of $ {} \\times {}$ random symmetric matrices \n Model: ConvNet  ({} layers, {}  kernels  $ {} \\times {}$ )".format(d,d,hidden_layers,filters, kernel_size, kernel_size), fontsize = 25)


### Plot training "dynamics"
ax = fig.add_subplot(spec[0, 0])
plot_loss_logs(ax,loss_log,weighted_average_log,eval_loss_log,k)


### Plot eigval samples
# Exact vs. predicted eigenvalues of the first test matrix.
ax = fig.add_subplot(spec[0, 1])
ax.scatter(np.arange(1,d+1), eigvals[0].detach().numpy(),s =200, color='pink', marker='o', label = "Exact")
ax.scatter(np.arange(1,d+1), predicted_eigvals[0].detach().numpy(), s = 200, color='blue', marker='x', label = " Approximated " )

ax.set_title("Sample eigenvalue approximation", fontsize = 18)
ax.legend(fontsize = 16, loc = "best")
ax.set_xticks(np.arange(0, d, 1) + 1)
# Raw string: "\#" is not a valid Python escape sequence, but the backslash is
# required by LaTeX (text.usetex is enabled).
ax.set_xlabel(r"\# Eigenvalue ", fontsize = 16)

### Histograms of the log10 squared errors of the smallest and largest eigenvalue
ax = fig.add_subplot(spec[1, :])
# Raw strings keep the LaTeX "\lambda" intact without invalid-escape warnings.
legend = r"$\lambda_{min} $ "
ax.hist(np.log10(m["errors"][:,0]), bins = 40, color = 'pink', edgecolor='black', alpha=0.65, label = legend )
legend = r"$\lambda_{max} $ "
# Eigenvalues are sorted in ascending order, so the largest one is the last
# column for any matrix dimension d (not only d = 4).
ax.hist(np.log10(m["errors"][:,-1]), bins = 40, edgecolor='black', alpha=0.65, label = legend )
ax.axvline(np.log10(m["mean_error"]), color='r', linestyle='dashed', linewidth=1, label = "Mean test error")
ax.set_xlabel("$log_{10}(error)$", fontsize = 16)
ax.set_ylabel("Frequency", fontsize = 16)
ax.legend(fontsize = 16, loc = "upper center")
min_ylim, max_ylim  =  ax.get_ylim()
# Annotate the mean-error line slightly to its right, at 80% of the axis height.
ax.text(np.log10(m["mean_error"])*0.95, max_ylim*0.8, 'Mean: {:.4f}'.format(m["mean_error"]), fontsize = 14)
ax.set_title("Evaluation results on test set, N = {} ".format(eval_size), fontsize = 20);
#plt.savefig("eigval_10by10_Nerf.png")