# 2. SHO Fitting in Pytorch

In [None]:
import sys
# Extend the module search path with a location two levels above the working directory.
sys.path.append('../../')
# NOTE(review): absolute, machine-specific path; this entry only resolves on a machine
# that has this exact directory. Presumably it points at the m3_learning sources — confirm
# and prefer an installed package or a relative path.
sys.path.append('/home/ferroelectric/m3_learning/m3_learning/src')

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np

from m3_learning.nn.random import random_seed
from m3_learning.viz.style import set_style
from m3_learning.viz.printing import printer
from m3_learning.be.viz import Viz
from m3_learning.be.dataset import BE_Dataset
# from m3_learning.be.nn import SHO_Model, SHO_NN_Model, SHO_fit_func_nn

# from m3_learning.be.dataset import BE_Dataset
# Figure printer configured with the 2023_rapid_fitting figures folder as its base path.
printing = printer(basepath = './../../../Figures/2023_rapid_fitting/')


# Select the "printing" plot style, then seed the random number generators with 42.
set_style("printing")
random_seed(seed=42)

%matplotlib inline

# import matplotlib.pyplot as plt
# import numpy as np

# import torch
# import torch.nn as nn
# from torch.utils.data import DataLoader

# from scipy.signal import resample
# from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split 
# from sklearn.metrics import mean_squared_error

# from m3_learning.optimizers.AdaHessian import AdaHessian
# from m3_learning.nn.SHO_fitter.SHO import SHO_fit_func_torch
# from m3_learning.be.processing import convert_amp_phase, SHO_fit_to_array
# from m3_learning.util.preprocessing import global_scaler
# from m3_learning.nn.random import random_seed
# from m3_learning.nn.benchmarks.inference import computeTime
from m3_learning.util.file_IO import make_folder
from m3_learning.be.nn import SHO_fit_func_nn, SHO_Model
# from m3_learning.be.dataset import BE_Dataset
# from m3_learning.viz.style import set_style

# set_style("printing")

In [None]:
# import seaborn as sns
# # sns.set_theme(style="whitegrid")

# # Load the example tips dataset
# tips = sns.load_dataset("tips")

# # Draw a nested violinplot and split the violins for easier comparison
# sns.violinplot(data=tips, x="day", y="total_bill", hue="smoker",
#                split=True, inner="quart", linewidth=1,
#                palette={"Yes": "b", "No": ".85"})
# sns.despine(left=True)

# plt.show()


## Loads Data

In [None]:
# Folder and file name of the raw data, relative to the notebook's working directory.
save_path = './../../../Data/2023_rapid_fitting'
filename = 'data_raw.h5'

# Full path handed to the dataset object.
data_path = f"{save_path}/{filename}"

# Build the dataset object with 80 resample bins and the NN SHO fit function.
dataset = BE_Dataset(
    data_path,
    resample_bins=80,
    SHO_fit_func_NN=SHO_fit_func_nn,
)

# Print the contents of the file.
dataset.print_be_tree()

## Testing the Torch Function

The function for a simple-harmonic oscillator needs to be recast in PyTorch. Here we prove that the PyTorch function is implemented identically to the Numpy model. 

Note: This uses the results from the least-squares fitting (LSQF).

In [None]:
# Settings describing the two fits that the fit tester compares.
true = dict(fitter='NN', resampled=False, label="NN Fit")

predicted = dict(fitter='LSQF', resampled=False, label="Raw", scaled=False)

# Visualization helper built from the dataset and the figure printer.
BE_viz = Viz(dataset, printing, verbose=True)


In [None]:
# Run the fit tester on the `true` and `predicted` settings, passing the Figure 7 filename.
BE_viz.fit_tester(true, predicted, filename="Figure_7_PyTorch_fit_tester")

**Figure 7** Shows the result of the PyTorch function. The result based on the LSQF results shows that the PyTorch function is implemented correctly.

## Pytorch Model

### Model Architecture

### Scaling the Data

When training the neural network it is useful to scale the data. We apply a global scaler such that the spectra have a mean of 0 and a standard deviation of 1.

#### Visualizing the Scaled Data

In [None]:
# Build the visualization helper again from the dataset and the figure printer.
BE_viz = Viz(dataset, printing, verbose=True)

# Resampled and scaled LSQF settings, labelled "Scaled".
state = dict(fitter='LSQF', resampled=True, scaled=True, label="Scaled")

BE_viz.nn_checker(state, filename="Figure_8_Scaled Raw Data")

**Figure 8** shows the scaled data. The data is scaled to have a mean of 0 and a standard deviation of 1. This is done using a global scaler of the entire spectrum. 

In [None]:
# Phase shift stored on the dataset for the LSQF results (pi/2).
dataset.LSQF_phase_shift = np.pi / 2

# Histogram the SHO fit results returned by the dataset.
BE_viz.SHO_hist(
    dataset.SHO_fit_results(),
    filename="Figure_9_Phase_Shifted_Scaled_Histograms",
)

**Figure 9** shows the histograms of the scaled a) amplitude, b) resonance frequency, c) quality factor, and d) phase. Note there is a transformation applied to the phase.  

### Training the Model

We will train the model from scratch. Generally the model trains very well in a few epochs. This will take less than 5 minutes to train on a GPU.

In [None]:
# Re-seed before building the model and splitting the data.
random_seed(seed=42)

# Model object built on the dataset, in training mode, with its base name for saved files.
model = SHO_Model(
    dataset,
    training=True,
    model_basename='SHO_Fitter_original_data',
)

# Shuffled train/test split provided by the dataset.
X_train, X_test, y_train, y_test = dataset.test_train_split_(shuffle=True)


In [None]:
# Toggle: True fits the model in this cell, False loads a saved checkpoint instead.
train = False

if train:
    # fits the model
    # NOTE(review): `CustomLoss` is not imported in any cell visible in this notebook, so
    # this branch would raise NameError once `train = True` — confirm its module and import it.
    # NOTE(review): this reads `dataset.X_train`, not the `X_train` returned by the split
    # cell — presumably the same data; confirm.
    model.fit(dataset.X_train, 200, loss_func=CustomLoss(penalty=2, verbose=False, scale_factor=dataset.SHO_scaler.mean_[0]))
else:
    # NOTE(review): absolute, machine-specific checkpoint path; loading fails on any machine
    # without this exact file — consider a path relative to the repository.
    model.load("/home/ferroelectric/m3_learning/m3_learning/papers/2023_Rapid_Fitting/Trained Models/SHO Fitter/SHO_Fitter_original_data_model_epoch_5_train_loss_0.0414678600463958.pth")

### GPU Inference Speedtest

Here we show the speedtest for the GPU. This is done using the torch.cuda.synchronize() function. This is used to ensure that the GPU is done processing before the timer is stopped.

In [None]:
# Inputs and targets for the whole dataset, as returned by the dataset object.
X_data, Y_data = dataset.NN_data()

# Time model inference on X_data using batches of 1000.
model.inference_timer(X_data, batch_size=1000)

### Visualization of the Distribution of the NN Fit Results

It is useful to check the distribution of the scaled and unscaled fit results for the entire dataset. This also allows us to add a correction for a phase shift (if necessary).

#### Unscaled Histograms of Neural Network Fit Results

In [None]:
# Phase shift stored on the dataset for the NN results (-3*pi/2).
dataset.NN_phase_shift = -3*np.pi/2

# predict returns three values; judging by the names and the figure filenames they are the
# predicted data, the scaled parameters and the unscaled parameters — TODO confirm.
pred_data, scaled_param, parm = model.predict(X_train)

# Histogram of `parm` for Figure 10.
BE_viz.SHO_hist(parm, filename = "Figure_10_NN_Unscaled_Histograms")

**Figure 10** Calculated fitting parameters from the neural network.  Histograms of the unscaled a) amplitude, b) resonance frequency, c) quality factor, and d) phase. Note there is a transformation applied to the phase.  

#### Scaled Histograms of Neural Network Fit Results

In [None]:
# Histogram of `scaled_param` (second value returned by model.predict) for Figure 11.
BE_viz.SHO_hist(scaled_param, filename = "Figure_11_NN_Scaled_Histograms")

**Figure 11** Calculated fitting parameters from the neural network.  Histograms of the scaled a) amplitude, b) resonance frequency, c) quality factor, and d) phase. Note there is a transformation applied to the phase.  

### Model Validation

It is helpful to view reconstructions of the data from the training and validation datasets. This ensures that the model is doing a good job of fitting the data.

#### Random Training Data Fit

In [None]:
# Validation plot on the training split, with y_train passed as the SHO results.
# NOTE(review): this filename starts with "Figure_11", the same number as the scaled
# histogram figure ("Figure_11_NN_Scaled_Histograms") — confirm the intended numbering.
BE_viz.nn_validation(model, X_train, 
                     filename = "Figure_11_NN_Validation_example_training", 
                     SHO_results = y_train)

**Figure 11** A random reconstruction of the neural network fits for the training dataset

In [None]:
# Validation plot on the test split, with y_test passed as the SHO results.
BE_viz.nn_validation(model, X_test, 
                     filename = "Figure_12_NN_Validation_example_test", 
                     SHO_results = y_test)

**Figure 12** A random reconstruction of the neural network fits for the test dataset

In [None]:
# Switch the visualizer to the magnitude-spectrum raw format.
state = dict(raw_format="magnitude spectrum")
BE_viz.set_attributes(**state)

# NOTE(review): the filename ends in "Training" but the full-dataset X_data / Y_data are
# passed rather than X_train / y_train — confirm which is intended.
BE_viz.best_median_worst_reconstructions(
    model,
    X_data,
    SHO_values=Y_data,
    filename="Figure_13_NN_Best_Median_Worst_Reconstructions_Training",
)


In [None]:
# Best / median / worst reconstructions on the test split.
# The filename uses "Figure_14_" (underscore) to match the "Figure_<N>_" pattern of every
# other figure written by this notebook; the original had a space after "Figure".
BE_viz.best_median_worst_reconstructions(
    model,
    X_test,
    SHO_values=y_test,
    filename="Figure_14_NN_Best_Median_Worst_Reconstructions_Testing",
)


Overall, the fit results are excellent for both the training and validation datasets.

## Comparison of NN and LSQF Results

It is useful to compare the NN and LSQF results. While generally, the LSQF results might be considered the ground truth, this is not really the case. It is unclear which fitting method is actually more precise and accurate. We conjecture through this analysis that the neural network is actually a more accurate and precise fitting method. 

This section will help to make this case. 

In [None]:
# Resampled, unscaled LSQF settings with the magnitude-spectrum raw format.
state = dict(
    fitter='LSQF',
    resampled=True,
    scaled=False,
    raw_format="magnitude spectrum",
)

# Spectra generated from the SHO fit results.
BE_viz.set_attributes(**state)
LSQF_results = dataset.raw_spectra(fit_results=dataset.SHO_fit_results())

# The same settings are applied a second time before fetching the spectra without fit results.
BE_viz.set_attributes(**state)
raw_spectra = dataset.raw_spectra()

# ind, mse = BE_viz.best_median_worst_fit_comparison(LSQF_results, raw_spectra, 1, 1)

In [None]:
# NOTE(review): called with no arguments — confirm the method's defaults are what is intended here.
BE_viz.best_median_worst_fit_comparison()

In [None]:
# Alias so the statements below can be written as they would appear inside a Viz method.
self = BE_viz

# For the SHO curves the error is computed on the normalized fit results in complex form.
state = dict(fitter='LSQF', resampled=False, scaled=True, raw_format="complex")
self.set_attributes(**state)

# Spectra generated from the SHO fit results, followed by the spectra without fit results.
fit_results_compare = self.dataset.raw_spectra(
    fit_results=self.dataset.SHO_fit_results()
)
raw_SHO = self.dataset.raw_spectra()

# Rank the two sets against each other; four values are unpacked from the result.
index1, mse1, d1, d2 = SHO_Model.get_rankings(raw_SHO, fit_results_compare, n=1)

In [None]:
# Display the first value returned by SHO_Model.get_rankings.
index1

In [None]:
# Each array is flattened into rows of this length, and one row is picked out.
row_length = 165
row_index = 745386

# First component: spectra without fit results vs. spectra from the fit results.
a = raw_SHO[0].reshape(-1, row_length)[row_index]
b = np.array(fit_results_compare[0].reshape(-1, row_length)[row_index])

# Second component, same row.
a1 = raw_SHO[1].reshape(-1, row_length)[row_index]
b1 = np.array(fit_results_compare[1].reshape(-1, row_length)[row_index])

In [None]:
# Average of the two mean squared differences (a vs. b and a1 vs. b1).
(np.mean((a-b)**2) + np.mean((a1-b1)**2))/2

In [None]:
# Display the shape of the full-dataset input array.
X_data.shape

In [None]:
# Display the shape of the dataset's resampled frequency vector.
dataset.resampled_freq.shape

In [None]:
# pyplot is imported here because this is the first cell that uses `plt`; without it the
# cell raises NameError on a fresh kernel (the other pyplot imports are in later cells).
import matplotlib.pyplot as plt

# Plot the dataset's frequency_bin values as a line and resampled_freq as markers.
plt.plot(dataset.frequency_bin)
plt.plot(dataset.resampled_freq, 'o')

In [None]:
import matplotlib.pyplot as plt
# Plot channel 0 of row 745386 of X_data (the same row index used in the spectrum comparison).
plt.plot(
    X_data[745386,:,0])
# plt.plot(dataset.frequency_bin, a)
# plt.plot(dataset.frequency_bin, b)

In [None]:
# Resampled, scaled LSQF settings with the complex raw format.
state = dict(fitter='LSQF', resampled=True, scaled=True, raw_format="complex")
BE_viz.set_attributes(**state)

# Spectra generated from the SHO fit results under these settings.
LSQF_results = dataset.raw_spectra(fit_results=dataset.SHO_fit_results())

In [None]:
# Display the shape of d1 (third value returned by get_rankings) after conversion to an array.
np.array(d1).shape

In [None]:
# Display the shape of element [1][0] of d1.
d1[1][0].shape

In [None]:
# Plot elements [0][0] and [1][0] of d1 on the same axes.
plt.plot(d1[0][0])
plt.plot(d1[1][0])
# Last expression: display the first entry of mse1 as the cell output.
mse1[0]

In [None]:
# Position of the first axis of length 2 in the shape of LSQF_results (as an array).
np.array(LSQF_results).shape.index(2)

In [None]:
# Position of the first axis of length 2 in the shape of LSQF_results (as an array).
# NOTE(review): this cell repeats the expression of the cell before it.
np.array(LSQF_results).shape.index(2)

In [None]:
# Rank LSQF_results against raw_spectra with curves=True; four values are unpacked.
index, mse, c1, c2 = SHO_Model.get_rankings(LSQF_results, raw_spectra, n= 1, curves=True)

In [None]:
# Display the shape of c1 (third value returned by get_rankings).
c1.shape

In [None]:
# Overlay element [0, 0] of c1 (line) and of c2 (markers).
plt.plot(c1[0,0])
plt.plot(c2[0,0] , 'o')

In [None]:
# Overlay slice [2, :, 1] of c1 (line) and of c2 (markers).
plt.plot(c1[2,:,1])
plt.plot(c2[2,:,1] , 'o')

In [None]:
# Call .numpy() on every element of LSQF_results; as the last expression, the whole list is
# echoed as the cell output.
# NOTE(review): this prints full arrays — consider displaying shapes instead.
[tensor.numpy() for tensor in LSQF_results]

In [None]:
import matplotlib.pyplot as plt

# Overlay element [100, 100, :] of the first entry of LSQF_results and of raw_spectra.
plt.plot(LSQF_results[0][100,100,:])
plt.plot(raw_spectra[0][100,100,:])


In [None]:
### LSQF NN comparison reconstruction

### LSQF NN comparison distributions

### LSQF NN comparison movies

In [None]:
import pandas as pd

# NOTE(review): `params_test` is not defined in any cell visible in this notebook, so this
# cell raises NameError on a fresh kernel — confirm which variable was intended.
# NOTE(review): `true` here overwrites the `true` settings dict defined earlier for the
# fit tester.
true = params_test
compare = dataset.nn_validation_params_scaled

names = [true, compare]
names_str = ['SHO', 'NN']
labels = ['Amplitude', 'Resonance', 'Q-Factor', 'Phase']

# Build one long-form frame per (source, parameter) pair: column i of each array becomes
# the "value" column, tagged with the parameter label and the source name.
frames = [
    pd.DataFrame({
        "value": name[:, i],
        "parameter": np.repeat(label, name.shape[0]),
        "dataset": np.repeat(names_str[j], name.shape[0]),
    })
    for j, name in enumerate(names)
    for i, label in enumerate(labels)
]

# Concatenate once, instead of growing a DataFrame inside the loop from an empty frame.
df = pd.concat(frames)


In [None]:
# pyplot is imported here so the cell does not depend on the imports in the scratch cells.
import matplotlib.pyplot as plt
import seaborn as sns

# Single 10x10 axes for the split violin plot.
fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(10, 10), sharey=True)

# One violin per parameter, split by the "dataset" column, with quartile lines inside.
# NOTE(review): recent seaborn releases deprecate `scale` in favor of `density_norm` —
# confirm against the installed version.
sns.violinplot(x='parameter', y='value', hue='dataset',
               data=df, ax=axs, scale='count', split=True, inner='quartile')

plt.show()