# Neural Network Fitting and Benchmarking

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

import os
from m3util.viz.printing import printer
from m3util.ml.optimizers.TrustRegion import TRCG
from m3util.viz.style import set_style
from m3util.ml.rand import set_seeds
from belearn.dataset.dataset import BE_Dataset
from belearn.functions.sho import SHO_nn
from belearn.nn.nn import BatchTrainer
from belearn.nn.inference import BEInference
from datafed_torchflow.datafed import DataFed
from datafed_torchflow.pytorch import TorchViewer
from datetime import datetime

from autophyslearn.postprocessing.complex import ComplexPostProcessor
from autophyslearn.spectroscopic.nn import Multiscale1DFitter, Model


In [None]:
# Pin this process to one physical GPU. The index "4" is specific to the
# machine this notebook was written on.
# NOTE(review): presumably this must run before torch initialises CUDA — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

# Input HDF5 file and the directory that holds it.
filename = "data_raw.h5"
save_path = "./Data"

# Optimizers to benchmark: Adam (referenced by name) and trust-region CG.
# NOTE(review): "ADAM_epochs" presumably sets a number of Adam epochs used
# alongside TRCG — confirm its meaning in BatchTrainer.
optimizer_TR = {
    "name": "TRCG",
    "optimizer": TRCG,
    "radius": 5,
    "device": "cuda",
    "ADAM_epochs": 2,
}
optimizers = ['Adam', optimizer_TR]

# Hyperparameter grid handed to BatchTrainer below.
noise_list = [0, 1, 2, 3, 4, 5, 6, 7, 8]
batch_size = [500, 1000, 5000, 10000]
epochs = [5]
seed = [41, 43, 44, 45, 46]
# NOTE(review): presumably a wall-clock limit in seconds (3 minutes) — confirm units.
early_stopping_time = 60 * 3
# Suffix appended to the timestamped output directory name.
basepath_postfix = 'nn_benchmarks_noise'

# NOTE(review): csv_name is not passed to BatchTrainer; the training cell
# writes to a hard-coded "Batch_Trainging_SpeedTest.csv" instead.
csv_name = 'nn_benchmarks_noise.csv'

# Figure output helper and global plotting style.
printing = printer(basepath='./Figures/')

set_style("printing")
# Global RNG seed for this notebook session (separate from the per-run `seed` grid).
set_seeds(seed=42)

# Build the dataset path portably instead of concatenating with "/".
data_path = os.path.join(save_path, filename)

In [None]:
# Wrap the raw HDF5 file in a BE_Dataset, using SHO_nn as the LSQF fit
# function and linking it to the DataFed collection for this benchmark run.
dataset = BE_Dataset(
    data_path,
    SHO_fit_func_LSQF=SHO_nn,
    datafed="2024_SHO_Fitting/Training_Benchmarks_NN_SHO_9_22_2024",
)

# Display the group/dataset tree of the loaded file.
dataset.print_be_tree()

In [None]:
# Set to False to build the trainer without launching the training sweep.
batch_training = True

# Timestamp for this run, e.g. "2024-09-22_14-05-00".
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime('%Y-%m-%d_%H-%M-%S')

# Each run gets its own timestamped output directory name.
basepath = f'{formatted_datetime}_{basepath_postfix}'

trainer = BatchTrainer(
    dataset=dataset,
    optimizers=optimizers,
    noise_list=noise_list,
    batch_size=batch_size,
    epochs=epochs,
    seed=seed,
    basepath=basepath,
    datafed_path="2024_SHO_Fitting/Training_Benchmarks_NN_SHO_9_22_2024",
    # Build the notebook path portably rather than with a hard-coded "/".
    script_path=os.path.join(os.getcwd(), "5_nn_fitting_all.ipynb"),
    # Only the time-based stopping criterion is enabled.
    early_stopping_loss=None,
    early_stopping_count=None,
    early_stopping_time=early_stopping_time,
    skip=0,
    # NOTE(review): "Trainging" looks like a typo and this name does not match
    # csv_name ('nn_benchmarks_noise.csv'). It is left unchanged so existing
    # output files and any readers of them keep working — confirm the intended name.
    write_CSV="Batch_Trainging_SpeedTest.csv",
)

if batch_training:
    trainer.run_training(dataset)

In [None]:
# Connect to the DataFed collection holding saved model checkpoints.
# This line was previously commented out, which made the next statement
# raise NameError when the notebook was run top to bottom.
# NOTE(review): the training cell writes to ".../Training_Benchmarks_NN_SHO_9_22_2024";
# confirm that reading the un-dated collection here is intended.
torch_viewer = TorchViewer("2024_SHO_Fitting/Training_Benchmarks_NN_SHO")

# Checkpoint listing (presumably a pandas DataFrame, given the .head() call).
# The name `pd` shadows the usual pandas alias but is kept because the
# inference cell passes `pd` to BEInference.
pd = torch_viewer.getModelCheckpoints()

pd.head()

In [None]:
# Set to True to run BEInference over the checkpoints listed in `pd`.
# (Stray `;'` / `'` characters that made this cell a syntax error were removed.)
inference_ = False

if inference_:
    inference = BEInference(
        pd,
        dataset,
        df_api=DataFed("2024_SHO_Fitting/Training_Benchmarks_NN_SHO"),
        root_directory="./Trained Models",
    )

    inference.run()

In [None]:
# Re-open the checkpoint collection on DataFed.
torch_viewer = TorchViewer(
    "2024_SHO_Fitting/Training_Benchmarks_NN_SHO"
)

# Fetch the checkpoint listing and preview its first rows as the cell output.
pd = torch_viewer.getModelCheckpoints()
pd.head()