## Train a convolutional neural network regressor to predict selected material properties

#### Overview: This notebook trains a convolutional neural network (CNN) to predict selected homogenized material properties.
#### <u>Discussion</u>: The neural network architecture is modular, allowing the user to select which material properties a model instance will be trained to predict. The number of property prediction modules is also variable. Each module is a fully connected linear neural network that learns to predict material properties from the feature vector that the convolutional layers produce. The feature vector is weighted using a linear self-attention layer.
![Diagram](./media/CNN3D_diagram.png)

#### __Configuration:__ The modules can be set to predict one or more material properties in custom order or grouping. The trained model provided in this package and made available for inference in folder 1_2 predicts all material properties using eight modules that each predict a set of grouped properties.




Sections are as follows:
1. Preliminaries - set filepaths and parameters, import libraries
2. Load data - training/validation/testing splits
3. Model training
4. Predictions and plotting

# 1. Preliminaries

In [1]:
import torch

In [None]:
# For setting the directory references for the entire package
from ML_workflow_utils_v3.PackageDirectories import PackageDirectories as PD   

# This code sets rootpath to the directory that contains the entire package; rootpath is then used to initialize the PackageDirectories class below.
# Note: os.chdir uses a relative path, so re-running this cell moves up three more directories each time.
import os
# check current path if desired
# currentpath = os.getcwd()
# print(currentpath)

os.chdir('../../../')
rootpath = os.getcwd()
# print(rootpath)

# Alternately, rootpath can be set manually
# rootpath = 'filepath/containing/entire/ML_package/'

directory = PD(rootpath = rootpath)

In [2]:
# Directory of the entire ML package
pkgpath = directory.pkgpath

## Model configuration - material properties

#### <u>Discussion</u>: The material properties that the model will be trained to predict are set below using two variables: `matprops` (list) and `matprops_by_module` (list of lists).


For example, predicting volume fraction, CH_11, CH_22, vH_12, vH_13, kappaH_11 in four separate modules would be configured as follows:

`matprops = ['volFrac', 'CH_11 scaled', 'CH_22 scaled', 'vH_12 scaled', 'vH_13 scaled', 'kappaH_11 scaled']`

`matprops_by_module = [['volFrac'], ['CH_11 scaled', 'CH_22 scaled',], ['vH_12 scaled', 'vH_13 scaled'], ['kappaH_11 scaled']]`

The order of the material properties is maintained, and `matprops_by_module` is an input to the call to instantiate the model. Its number of elements defines the number of modules and the length of each sub-list defines the output dimension of the module.

__Note:__ Because the model can separate material properties into different modules, using *scaled* values is not necessary. However, if you use unscaled properties, we recommend grouping similar properties together to avoid confounding due to scale mismatch.
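
The relationship between the two variables can be checked before training. The sketch below uses a hypothetical helper, `check_module_config`, which is not part of the package. It confirms that flattening `matprops_by_module` reproduces `matprops` in the same order and returns the output dimension of each module.

```python
from itertools import chain

def check_module_config(matprops, matprops_by_module):
    """Hypothetical helper: validate the grouping and return per-module output sizes."""
    flattened = list(chain.from_iterable(matprops_by_module))
    if flattened != list(matprops):
        raise ValueError(
            "matprops_by_module must contain the same properties as matprops, in the same order"
        )
    # One entry per module; each entry is that module's output dimension
    return [len(group) for group in matprops_by_module]

matprops = ['volFrac', 'CH_11 scaled', 'CH_22 scaled', 'vH_12 scaled', 'vH_13 scaled', 'kappaH_11 scaled']
matprops_by_module = [['volFrac'], ['CH_11 scaled', 'CH_22 scaled'], ['vH_12 scaled', 'vH_13 scaled'], ['kappaH_11 scaled']]

module_dims = check_module_config(matprops, matprops_by_module)
print(module_dims)  # [1, 2, 2, 1]
```

Here four modules are defined, and the module sizes sum to `len(matprops)`, the length of the model's full output vector.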


In [5]:
#### Material properties #####

matprops = ['volFrac', 'CH_11 scaled', 'CH_22 scaled', 'vH_12 scaled', 'vH_13 scaled', 'kappaH_11 scaled']

matprops_by_module = [['volFrac'], ['CH_11 scaled', 'CH_22 scaled',], ['vH_12 scaled', 'vH_13 scaled'], ['kappaH_11 scaled']]


num_props = len(matprops)

In [6]:
"""
set    file_suffix    to indicate the material properties being predicted, differentiating between different models
we recommend setting it to a readable abbreviation

"""
# for the properties selected above
file_suffix = 'vfC12v1213kap11'

# fname_base = "filename base" - This is used as the base for model checkpoints and training history files
fname_base = f'CNN3D_{file_suffix}'

In [7]:
# Path for saving outputs of the notebook
# nbpath is the directory of this code notebook
nbpath = directory.nb_1_1_path

In [8]:
import pandas as pd
import numpy as np
import json
import glob
import os
import sys

# PyTorch deep learning library
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

import torch.utils.data as data
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim

# Will automatically select GPU acceleration if hardware is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# For plotting
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from itertools import cycle
from plotly.colors import sequential, qualitative


In [9]:
# Custom classes and functions
from ML_workflow_utils_v3.Dataset_Preprocessor import Dataset_Preprocessor as DataP
from ML_workflow_utils_v3.VoxelDataset import VoxelDataset

# 2. Load data
#### Our updated dataset handling class, `Dataset_Preprocessor`, does not require the data to be pre-split into training/validation/test sets. It loads the prepared material property database, splits it, and stores the splits as attributes of the class instance. We have indexed our voxel topology arrays using a part numbering scheme, which lets us keep every file (in compressed NumPy format, .npz) in the same directory. During training, a custom dataset based on the PyTorch `Dataset` class loads and formats the voxel topologies and assembles them into batches. This approach is also more memory-efficient: previously the dataset splits fit in memory but were still quite large, whereas now only small batches of arrays are loaded at a time.
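
The memory argument can be illustrated without PyTorch. The sketch below is a simplified stand-in and not the package's `VoxelDataset`: `LazyArrayDataset`, `iter_batches`, and `fake_load` are hypothetical names. The dataset stores only part numbers and calls a loader when an item is requested, so requesting one batch reads only that batch's files.

```python
class LazyArrayDataset:
    """Stand-in for a map-style dataset: stores identifiers only, loads arrays on access."""

    def __init__(self, part_numbers, load_fn):
        self.part_numbers = list(part_numbers)
        self.load_fn = load_fn  # e.g. a function that reads one .npz file from disk

    def __len__(self):
        return len(self.part_numbers)

    def __getitem__(self, i):
        return self.load_fn(self.part_numbers[i])


def iter_batches(dataset, batch_size):
    """Yield lists of samples, loading only the items needed for the current batch."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield [dataset[i] for i in range(start, stop)]


loaded = []  # records which parts were read

def fake_load(part_no):
    # Hypothetical loader; a real one would read the voxel array for part_no from disk
    loaded.append(part_no)
    return [[0, 1], [1, 0]]

dataset = LazyArrayDataset(range(100), fake_load)
first_batch = next(iter_batches(dataset, batch_size=32))
print(len(first_batch), len(loaded))  # 32 32
```

Only 32 of the 100 items have been loaded after the first batch is drawn, which is the behavior the batch-wise loading described above relies on.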

#### `Dataset_Preprocessor` maintains our approach of removing entire sets of topologies from the training set, so the model's performance is evaluated on topologies that it was not trained on. The held-out topologies can be chosen randomly using a random seed, or specified with a custom-defined list.
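
As a rough illustration of holding out whole topologies, the sketch below samples a fixed number of topologies per family with a seeded generator and then splits the records by topology. It is an assumption-laden stand-in, not the logic inside `Dataset_Preprocessor`: the function name, the topology names, and the record layout are all made up.

```python
import random

def sample_holdout_topologies(topologies_by_family, counts, seed):
    """Hypothetical sketch: pick counts[i] topologies from the i-th family, reproducibly."""
    if len(counts) != len(topologies_by_family):
        raise ValueError("counts needs one entry per topology family")
    rng = random.Random(seed)
    holdout = []
    for topologies, k in zip(topologies_by_family.values(), counts):
        holdout.extend(rng.sample(sorted(topologies), k))
    return holdout

# Made-up topology names, for illustration only
topologies_by_family = {
    'lattice': ['lat_a', 'lat_b', 'lat_c', 'lat_d'],
    'tpms': ['tpms_a', 'tpms_b'],
    'topopt': ['opt_a', 'opt_b', 'opt_c'],
}
# Three records (e.g. different volume fractions) per topology: (cell_type, sample_id)
records = [(name, i) for names in topologies_by_family.values() for name in names for i in range(3)]

holdout = sample_holdout_topologies(topologies_by_family, counts=[3, 1, 2], seed=17)
test_records = [r for r in records if r[0] in holdout]
train_records = [r for r in records if r[0] not in holdout]
print(len(holdout), len(test_records), len(train_records))  # 6 18 9
```

Because the split is made by topology rather than by record, no topology appears in both the training and test records.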

In [10]:
# For displaying the number of voxel meshes in each topology family
# for topfam in datasplits.matpropcsv['topology_family'].unique():
#     subdf = datasplits.matpropcsv[datasplits.matpropcsv.topology_family == topfam]
#     print(len(subdf.cell_type.unique()))

In [11]:
# Define the following variables from the directory instance to feed to the Dataset_Preprocessor instance:
data_source_directory = directory.source_data_path
meshdir = directory.voxeltopo_path

# set seed for test set so group is replicable
testset_seed = seed = 17
volfrac_range = (0.01, 0.98)

csv_fn = 'topology_multiphysics_database_by_partno.csv'

datasplits = DataP(csv_fn = csv_fn, data_source_directory = data_source_directory, meshdir = meshdir, volfrac_range=volfrac_range)

"""
The *length* of    test_set_topo_counts    must match the number of topology families being sampled (topfam_sampling_set).
If the test set is chosen randomly, each number in test_set_topo_counts is the number of topologies from the respective topology family that will be placed in the test set.

By default, every topology family is sampled to construct the test set. If fewer are desired, create a custom list for topfam_sampling_set
and construct test_set_topo_counts to reflect the number of topology families and the number of topologies per family. For example:

topo_families = ['lattice', 'tpms', 'topopt']
test_set_topo_counts = [3, 1, 2]


Further configuration options for Dataset_Preprocessor are available in the .py file
"""

# Define list of topology families for sampling the test split
topo_families = datasplits.matpropcsv['topology_family'].unique()
# returns the topology families - ['interp', 'lattice', 'plttube', 'synth', 'topopt', 'tpms']
test_set_topo_counts = [1,2,2,2,2,2] # six topology families: one topology from the interpolated ('interp') family and two from each of the others



# Split the dataset with the call to this function
datasplits.TrainTestSplit(topfam_sampling_set=topo_families, test_set_topo_counts=test_set_topo_counts, translate=False, testset_seed=testset_seed)


In [None]:
# Inspect list of topologies placed in testing split
print(datasplits.testsubset)


### The splits are accessible as follows
`idx` signifies that these are the indices of the data points

`idxTr`  - Training

`idxVal` - Validation

`idxTe`  - Test

e.g. `datasplits.idxTe` is the Testing subset

In [None]:
# Inspect size of Training, Validation, and Testing splits

print(datasplits.idxTr.shape, datasplits.idxVal.shape, datasplits.idxTe.shape)


In [14]:
batch_size = 32


trdat = VoxelDataset(datasplits.idxTr, directory.voxeltopo_path, matprops)
trloader = DataLoader(trdat, batch_size = batch_size, shuffle=True)

valdat = VoxelDataset(datasplits.idxVal, directory.voxeltopo_path, matprops)
valloader = DataLoader(valdat, batch_size = batch_size, shuffle=True)

tedat = VoxelDataset(datasplits.idxTe, directory.voxeltopo_path, matprops)
teloader = DataLoader(tedat, batch_size = batch_size, shuffle=False)

# 3. Model training

In [15]:
from ML_workflow_utils_v3.Training_Utils import train, validate
from ML_workflow_utils_v3.CNN_Property_Predictor_Multimodule import MatProp_CNN3D_varmod

In [16]:
"""
This call instantiates the neural network class defined in CNN_Property_Predictor_Multimodule.py
matprops_by_module >>> one sub-list per prediction module
    the length of each sub-list determines the output dimension of that module
"""
cnn = MatProp_CNN3D_varmod(matprops_by_module).to(device)

In [17]:
# Loss function
lossfunc = torch.nn.L1Loss() # Mean Absolute Error
lossfunc_name = 'MAE'

# lossfunc = torch.nn.MSELoss() # Mean Squared Error can be used as well - comment out the lines above (add # at start or press ctrl/cmd + /) and uncomment this line and line below (delete #)
# lossfunc_name = 'MSE'

# Optimizer
optimizer = optim.Adam(cnn.parameters(), lr=0.001)

# Filepath of model checkpoint - specifies path where the model's weights are saved
cp_dir = os.path.join(nbpath, 'model_CPs')
os.makedirs(cp_dir, exist_ok=True)  # create the folder if it does not exist yet

# cp = 'checkpoint', which is the term for the model's saved state (trainable weights) for an epoch that has improved validation error
cp_name = f'{fname_base}_model_weights.pth'
best_weights_path = os.path.join(cp_dir, cp_name)

In [18]:
"""
These parameters establish the behavior of the training loop. The loop trains the model until 
validation loss fails to decrease by more than |earlystop_min_delta| for |patience| consecutive epochs,
or until |EPOCHS| epochs have been completed.
I.e., with the values below, if validation loss does not improve by more than 0.00075 for 75 consecutive epochs,
training will stop.
"""

# Number of epochs to train a model
EPOCHS = 150


min_val_loss = float('inf')
best_val_loss = float('inf')
early_stop_counter = 0

# Number of consecutive epochs after which to terminate training if validation loss does not improve by more than *earlystop_min_delta*
patience = 75
earlystop_min_delta = 0.00075 # change to 0.00005 or smaller if using MSE loss

# Initializes records of losses and best epoch for tracking model training
best_epoch = 0
train_losses = []
val_losses = []

epochs_completed = 0

# Path for training history JSONs
histdir = os.path.join(nbpath, 'training_history_JSONs')
os.makedirs(histdir, exist_ok=True)  # create the folder if it does not exist yet
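
The early stopping rule configured in the cell above can be tried in isolation on a list of validation losses. The function below is a minimal sketch with a hypothetical name, `early_stopping_epoch`, and made-up loss values. It follows the same counter logic: the counter resets only when the loss beats the best value so far by more than `min_delta`.

```python
def early_stopping_epoch(val_losses, patience, min_delta):
    """Return the index of the epoch at which training would stop, or None if it never stops."""
    best = float('inf')
    counter = 0
    for epoch, loss in enumerate(val_losses):
        improvement = best - loss
        if loss < best:
            best = loss  # any improvement, however small, updates the best loss
        if improvement > min_delta:
            counter = 0
        else:
            counter += 1
        if counter >= patience:
            return epoch
    return None

# Made-up validation losses: steady improvement, then a plateau
losses = [0.50, 0.40, 0.35, 0.3499, 0.3498, 0.3497, 0.3496]
print(early_stopping_epoch(losses, patience=3, min_delta=0.00075))  # 5
```

Note that the best loss is updated on every improvement, so a long run of gains that are each smaller than `min_delta` still counts toward `patience`.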

In [None]:
try:           
    for epoch in range(EPOCHS):
        # Train the model
        train_loss = train(cnn, trloader, lossfunc, optimizer)

        # Calculate validation loss 
        val_loss = validate(cnn, valloader, lossfunc)

        
        # Collect model training history first so that saved histories include the current epoch
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Save the model's weights if validation loss is improved
        improvement_delta = best_val_loss - val_loss
        
        if val_loss < best_val_loss:
            if best_val_loss == float('inf'):
                # First epoch: no previous loss exists, so a percentage improvement is undefined
                print(f"Initial val loss {val_loss:.5f}, saving model state...")
            else:
                pct_improved = (best_val_loss - val_loss) / best_val_loss * 100
                print(f"Val loss improved from {best_val_loss:.5f} to {val_loss:.5f} ({pct_improved:.2f}% improvement) saving model state...")
            best_val_loss = val_loss
            torch.save(cnn.state_dict(), best_weights_path)  # Save model weights to file

            best_epoch = epoch

            # Save training history at each epoch where validation loss improves
            hist_dict = {f'train_loss {lossfunc_name}': train_losses, f'val_loss {lossfunc_name}': val_losses}
            histdict_name = f'{fname_base}_traininghist_placeholder_{epoch}.json'
            histpath = os.path.join(histdir, histdict_name)
            with open(histpath, 'w') as f:
                json.dump(hist_dict, f)
        else:
            print(f'Val loss did not improve from {best_val_loss:.5f}.')
            
        # Reset the early stopping counter only when the improvement exceeds earlystop_min_delta
        if improvement_delta > earlystop_min_delta:
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            
        # Check for early stopping
        if early_stop_counter >= patience:
            print(f'Validation loss did not improve for {early_stop_counter} epochs. Early stopping...')

            # map_location places the weights on whichever device is in use (CPU or GPU)
            model_weights = torch.load(best_weights_path, map_location=device)
            cnn.load_state_dict(model_weights)
            print(f"Model best weights restored - training epoch {best_epoch}")
            break
        
        print(f'Epoch [{epoch+1}/{EPOCHS}]\tTrain Loss: {train_loss:.5f}\tValidation Loss: {val_loss:.5f}')


    # Load the best weights at end of training epochs
    # map_location places the weights on whichever device is in use (CPU or GPU)
    model_weights = torch.load(best_weights_path, map_location=device)
    cnn.load_state_dict(model_weights)
    print(f'Training epochs completed, best model weights restored - epoch {best_epoch}')
    min_val_loss = best_val_loss

# Reloads the model's weights from the last saved checkpoint if training is interrupted manually
except KeyboardInterrupt:
    hist_dict = {f'train_loss {lossfunc_name}': train_losses, f'val_loss {lossfunc_name}': val_losses}
    # map_location places the weights on whichever device is in use (CPU or GPU)
    model_weights = torch.load(best_weights_path, map_location=device)
    cnn.load_state_dict(model_weights)

In [20]:
## Save model training history
hist_dict = {f'train_loss {lossfunc_name}': train_losses, f'val_loss {lossfunc_name}': val_losses}

histdict_name = f'{fname_base}_model_hist_{EPOCHS}epochs_{best_epoch}_best.json'
histpath = os.path.join(histdir, histdict_name)

with open(f'{histpath}', 'w') as f:
    json.dump(hist_dict, f)

# 4. Predictions and plotting

In [None]:
# Set model to evaluation mode
cnn.eval()

In [22]:
# Setting up columns for dataframe of test set parameters, predictions, residuals, and percent error. 

# idxcols is the list of column names to pull from the Test split dataframe.
idxcols = ['topology_family', 'cell_type', 'dim_idx', 'volFrac',]
idxcols.extend(matprops)
# remove any duplicates - mainly volFrac - with OrderedDict
from collections import OrderedDict
idxcols = list(OrderedDict.fromkeys(idxcols))


predcols = [f'Predicted {prop}' for prop in matprops]

In [23]:
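# Creates the dataframe from the test split stored in the Dataset_Preprocessor class instance.
# The index is reset so that the predictions, which are produced in row order, align by position when joined later.
test_split_dataframe = datasplits.idxTe[idxcols].reset_index(drop=True)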

#### Run inference on the test set, calculate errors by unit cell topology and material property and save the dataframe as a CSV

In [None]:
test_split_dataframe

In [25]:
"""
Runs each voxel array in the Testing dataset through the trained CNN and produces a dataframe
of the predictions under columns named "Predicted [property name]"
"""
predsarray_batched = []

for batch in teloader:
    arrays, param_vecs = batch

    with torch.no_grad():
        outputs = cnn(arrays.to(device))

    predsarray_batched.append(outputs.cpu().numpy())

predsarray = np.concatenate(predsarray_batched, axis=0)

predictions_dataframe = pd.DataFrame(predsarray, columns=predcols)

test_split_dataframe = test_split_dataframe.join(predictions_dataframe)


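for col in predcols:
    par = col[len('Predicted '):]  # strip the 'Predicted ' prefix to recover the property name
    rescol = f'Residual {par}'
    # Residual = actual - predicted
    test_split_dataframe[rescol] = test_split_dataframe[par] - test_split_dataframe[col]

    # Keep the full property name so the plotting cell can look up 'Pct error {property}'
    pctcol = f'Pct error {par}'

    # Percent error = (predicted - actual) / actual * 100
    test_split_dataframe[pctcol] = (test_split_dataframe[col] - test_split_dataframe[par]) / test_split_dataframe[par] *100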



In [None]:
# Set to true if reporting Mean Absolute Error
mae_loss = True

totalerrors = []
for col in predcols:
    par = col[10:]
    residuals = test_split_dataframe[f'Residual {par}']
    if mae_loss: # Calculates Mean Absolute Error
        test_error = residuals.abs().mean()

    else: # Calculates Root Mean Squared Error: the mean is taken before the square root
        test_error = (residuals**2).mean()**0.5

    if par == 'volFrac':
        print(f'Error of {par}: \t{test_error:.4f}')
    else:
        print(f'Error of {par[:-7]}: \t{test_error:.4f}')
    totalerrors.append(test_error)

totalerrors.insert(0, 'mean')


errorcols = ['cell_type']
errorcols.extend([i[10:] for i in predcols])
errordf = pd.DataFrame(columns = errorcols)
errordf['cell_type'] = test_split_dataframe.cell_type.unique()




for col in predcols:

    par = col[10:]

    errorseries = []

    for celltype in test_split_dataframe.cell_type.unique():

        celltypedf = test_split_dataframe[test_split_dataframe.cell_type == celltype]
        residuals = celltypedf[f'Residual {par}']
        if mae_loss: # Calculates Mean Absolute Error
            test_error = residuals.abs().mean()

        else: # Calculates Root Mean Squared Error: the mean is taken before the square root
            test_error = (residuals**2).mean()**0.5

        errorseries.append(test_error)

    errordf[par] = np.asarray(errorseries)

errordf.loc[len(errordf.index)] = totalerrors

cols = errordf.columns[2:]
means = []
for i in errordf.index:
    celltype_mean = errordf.loc[i][cols].mean()
    means.append(celltype_mean)
errordf['cell_type mean error'] = means

csvpath = os.path.join(nbpath, 'predictions_csvs_and_plots', f'{fname_base}_test_set_errors.csv')
errordf.to_csv(csvpath)
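
For reference, the two error metrics reported above follow the standard definitions: MAE is the mean of the absolute residuals, and RMSE is the square root of the mean squared residual. The sketch below uses only the standard library and made-up residuals. The function names are illustrative and not part of the package.

```python
import math

def mean_absolute_error(residuals):
    """Mean of the absolute residuals."""
    return sum(abs(r) for r in residuals) / len(residuals)

def root_mean_squared_error(residuals):
    """Square root of the mean of the squared residuals."""
    return math.sqrt(sum(r * r for r in residuals) / len(residuals))

residuals = [3.0, -4.0, 0.0, 0.0]  # made-up values of actual minus predicted
print(mean_absolute_error(residuals))      # 1.75
print(root_mean_squared_error(residuals))  # 2.5
```

RMSE weights large residuals more heavily than MAE, which is why the two values differ for the same residuals.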


# Plot predictions vs. actual and Volume Fraction vs. Residual or % Error

In [None]:
"""
This cell produces plots of each Predicted parameter versus the actual value, as well as either residuals or percentage error.
Plots are:
(1) Predicted vs. actual parameter
(2) volume fraction vs residual or percentage error
"""

###################This section contains parameters for defining particulars of the plots ##################

# As shipped, *second_plot* can be of either residuals or percentage error 
second_plot = 'Residual'# or 'Pct error' 

width = 1400  # width of plot
height_incr = 450 # height of each subplot

# color plot markers by individual cell type or topology family
category =  'cell_type'# or 'topology_family'

# Name of plot to save
img_name = f'{fname_base}_predictions_plot_{second_plot}'
img_path = os.path.join(nbpath, 'predictions_csvs_and_plots', img_name)

# if desired to save the plot as a PNG, set to True
save_img = True

##########################

colors = cycle(qualitative.G10)


predsdf = test_split_dataframe

# Makes subplots equivalent to the number of parameters Predicted
fig = make_subplots(rows=len(matprops), cols=2, row_heights=[20 for i in range(len(matprops))], column_widths=[15,15],
                    column_titles = ['Predicted vs Actual', f'{second_plot}'.capitalize()],
                    row_titles = matprops)

# Iterates over rows
row=1
for param in matprops:
    # colors and colors2 definitions ensure that the predicted vs. actual and residuals plot the same colors with the same chosen category
    colors = cycle(qualitative.G10)
    colors2 = cycle(qualitative.G10)

    # Plots predictions by category as selected above
    for uc in predsdf[f'{category}'].unique():

        df = predsdf[predsdf[f'{category}']==uc]

        # add_trace replaces the deprecated append_trace
        fig.add_trace(go.Scatter(x=df[f'Predicted {param}'], 
                                 y=df[f'{param}'], 
                                 mode='markers', marker=go.scatter.Marker(color=next(colors)), legendgroup='1',
                                 showlegend=True, name=uc), row=row, col=1, )

        
    # Plots second_plot by category, both as selected above
    for uc in predsdf[f'{category}'].unique():

        df = predsdf[predsdf[f'{category}']==uc]

        # df = predsdf[predsdf.cell_type==uc]
        # add_trace replaces the deprecated append_trace
        fig.add_trace(go.Scatter(x=df['volFrac'], 
                                 y=df[f'{second_plot} {param}'],  
                                 #y=df[df[f'{second_plot} {param}'].between(-100,100)], # This can filter out extremely large percentage errors if desired
                                 mode='markers', marker=go.scatter.Marker(color=next(colors2)), legendgroup='2',
                                 showlegend=True, name=uc), row=row, col=2, )

    # Plots a line of X=Y on all Predicted-vs-actual plots
    x = np.linspace(predsdf['Predicted {param}'.format(param=param)].min(), predsdf['Predicted {param}'.format(param=param)].max(),100)
    y = x
    fig.add_trace(go.Scatter(x=x, y=y, name='Predicted = actual', legendgroup='1',marker=go.scatter.Marker(color=next(colors))), row=row, col=1 )

    # Updates figure with labels
    fig.update_xaxes(title_text="Predicted {param}".format(param=param), row=row, col=1)
    fig.update_yaxes(title_text="Actual {param}".format(param=param), row=row, col=1)
    fig.update_xaxes(title_text="Volume Fraction", row=row, col=2)
    fig.update_yaxes(title_text=f"{second_plot.capitalize()}", row=row, col=2)
    fig.update_layout(title_text = f'Predictions & {second_plot.capitalize()}')

    # fig = go.Figure(data=go.Scatter(x=x, y=x**2))

    fig.update_xaxes(griddash='solid', minor_griddash="solid")
    fig.update_yaxes(griddash='solid', minor_griddash="solid")
    row+=1

fig.update_layout(
    title={
        'text': f'Plots of test data - multi-parameter prediction', #<br>{runnum} </br>
        'xref':"paper",
        'xanchor':'center',
        'x':0.5},
        height=height_incr*len(matprops), width=width,
        legend_tracegroupgap = 50*len(matprops))
fig.show()

if save_img:
    # Static image export requires an image export engine such as the kaleido package
    pngpath = os.path.join(nbpath, 'predictions_csvs_and_plots', f'{img_name}.png')
    fig.write_image(pngpath)
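
The plotting cell restarts two color cycles for every row so that a category receives the same color in both columns. The sketch below shows that mechanism with `itertools.cycle`. The three-color palette and the category names are placeholders chosen for illustration.

```python
from itertools import cycle

palette = ['red', 'green', 'blue']  # stand-in for a qualitative palette
categories = ['lattice', 'tpms', 'topopt', 'synth']

# Two independent cycles over the same palette, consumed in the same category order
left_colors = cycle(palette)
right_colors = cycle(palette)
left = {c: next(left_colors) for c in categories}
right = {c: next(right_colors) for c in categories}

print(left == right)   # True
print(left['synth'])   # red
```

When there are more categories than palette entries, colors repeat, as with `synth` here. `qualitative.G10` provides ten colors, so the same wrap-around occurs once more than ten categories are plotted.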