# Training on multiple topologies
Eamon Whalen

In [1]:
import sys
import glob
import numpy as np
import pandas as pd
import altair as alt
from sklearn.model_selection import train_test_split

sys.path.append('./models')
from feastnetSurrogateModel import FeaStNet

sys.path.append('./readers')
from loadGhGraphs import loadGhGraphs

sys.path.append('./visualization')
from altTrussViz import plotTruss, interactiveErrorPlot

sys.path.append('./util')
from gcnSurrogateUtil import *

## 1. Load simulation data

In [2]:
# Each matching DOE csv holds the simulated designs for one truss topology;
# the topology name is the file-name prefix before '_N' (e.g. 'design_5').
# NOTE(review): absolute, user-specific path — adjust for your machine.
doeGlob = "/home/ewhalen/projects/data/trusses/2D_Truss_v1.3/*1000.csv"
doeFiles = np.sort(glob.glob(doeGlob))
if len(doeFiles) == 0:
    # fail here rather than silently producing empty train/val/test dicts
    raise FileNotFoundError(f'no DOE files match {doeGlob}')

trainSets, valSets, testSets = {}, {}, {}
for doeFile in doeFiles:
    designName = doeFile.split('/')[-1].split('_N')[0]
    print(f'loading {designName}')
    allGraphsUnfiltered = loadGhGraphs(doeFile, NUM_DV=5)
    # filterbyDisp / partitionGraphList come from gcnSurrogateUtil (star import);
    # presumably 0.9 is a displacement quantile/threshold — TODO confirm.
    allGraphs = filterbyDisp(allGraphsUnfiltered, 0.9)
    trainData, valData, testData = partitionGraphList(allGraphs)
    trainSets[designName] = trainData
    valSets[designName] = valData
    testSets[designName] = testData

loading design_5
loading design_6
loading design_7
loading design_8
loading design_9


## 2. Train on each group separately

In [3]:
# Shared settings for every experiment below.
epochs = 100                       # training epochs for each model
saveDir = './results/topoTest01/'  # passed as trainModel's saveDir prefix; testResults.csv also goes here
# Per-sample test metrics accumulated by every section; the training cells
# append to it, so re-run this cell before re-running them.
resultsList = list()

In [4]:
# Baseline: one model per topology, trained and tested on that topology only.
trainKwargs = dict(epochs=epochs, batch_size=256, flatten=True, logTrans=False, ssTrans=True)

for trainName in trainSets:
    trainSet = trainSets[trainName]
    print(f'training on {trainName}')

    # fit a fresh model on this topology's data alone
    gcn = FeaStNet()
    history = gcn.trainModel(trainSet, valSets[trainName], saveDir=saveDir + trainName, **trainKwargs)
    display(plotHistory(history))

    # evaluate on the same topology's test split and tag each sample's metrics
    print(f'testing on {trainName}\n')
    resultsDict = gcn.testModel(testSets[trainName], level='field')
    nSamples = len(resultsDict['mse'])
    resultsDict['Trained on'] = ['test group'] * nSamples
    resultsDict['Tested on'] = [trainName] * nSamples
    resultsList.extend(pivotDict(resultsDict))

pd.DataFrame(resultsList)

training on design_5
epoch: 0   trainLoss: 1.0356e+00   valLoss:1.0439e+00  time: 1.29e+00
epoch: 1   trainLoss: 9.3205e-01   valLoss:1.0446e+00  time: 9.85e-01
epoch: 2   trainLoss: 9.2407e-01   valLoss:1.0541e+00  time: 9.83e-01
epoch: 3   trainLoss: 8.4189e-01   valLoss:1.0973e+00  time: 1.05e+00
epoch: 4   trainLoss: 8.3505e-01   valLoss:1.1607e+00  time: 9.54e-01
epoch: 5   trainLoss: 8.1768e-01   valLoss:1.1738e+00  time: 9.43e-01
epoch: 6   trainLoss: 7.4498e-01   valLoss:1.0955e+00  time: 9.45e-01
epoch: 7   trainLoss: 6.7474e-01   valLoss:9.3193e-01  time: 9.41e-01
epoch: 8   trainLoss: 6.6743e-01   valLoss:7.7763e-01  time: 9.98e-01
epoch: 9   trainLoss: 6.2234e-01   valLoss:6.7188e-01  time: 9.52e-01
epoch: 10   trainLoss: 5.9795e-01   valLoss:6.1265e-01  time: 9.67e-01
epoch: 11   trainLoss: 5.8180e-01   valLoss:5.7374e-01  time: 9.56e-01
epoch: 12   trainLoss: 5.3583e-01   valLoss:5.4524e-01  time: 1.05e+00
epoch: 13   trainLoss: 5.0522e-01   valLoss:5.1771e-01  time: 1.01

testing on design_5

training on design_6
epoch: 0   trainLoss: 9.6502e-01   valLoss:1.1972e+00  time: 9.85e-01
epoch: 1   trainLoss: 8.9800e-01   valLoss:1.2026e+00  time: 9.93e-01
epoch: 2   trainLoss: 8.6612e-01   valLoss:1.2054e+00  time: 9.87e-01
epoch: 3   trainLoss: 7.8615e-01   valLoss:1.2015e+00  time: 1.00e+00
epoch: 4   trainLoss: 7.9711e-01   valLoss:1.2041e+00  time: 1.15e+00
epoch: 5   trainLoss: 7.5902e-01   valLoss:1.1768e+00  time: 1.16e+00
epoch: 6   trainLoss: 6.8668e-01   valLoss:1.1238e+00  time: 1.14e+00
epoch: 7   trainLoss: 6.7836e-01   valLoss:1.0371e+00  time: 1.00e+00
epoch: 8   trainLoss: 6.2324e-01   valLoss:9.2628e-01  time: 1.01e+00
epoch: 9   trainLoss: 5.9268e-01   valLoss:7.5886e-01  time: 1.02e+00
epoch: 10   trainLoss: 5.5817e-01   valLoss:6.7341e-01  time: 1.02e+00
epoch: 11   trainLoss: 5.7479e-01   valLoss:6.2014e-01  time: 1.14e+00
epoch: 12   trainLoss: 5.0688e-01   valLoss:5.7041e-01  time: 1.14e+00
epoch: 13   trainLoss: 4.7484e-01   valLoss:5

testing on design_6

training on design_7
epoch: 0   trainLoss: 8.8450e-01   valLoss:9.4310e-01  time: 1.05e+00
epoch: 1   trainLoss: 8.3724e-01   valLoss:9.4267e-01  time: 1.03e+00
epoch: 2   trainLoss: 8.1404e-01   valLoss:9.4263e-01  time: 1.07e+00
epoch: 3   trainLoss: 7.3584e-01   valLoss:9.4545e-01  time: 1.04e+00
epoch: 4   trainLoss: 7.3168e-01   valLoss:9.3272e-01  time: 1.03e+00
epoch: 5   trainLoss: 7.0533e-01   valLoss:9.2125e-01  time: 1.16e+00
epoch: 6   trainLoss: 6.3549e-01   valLoss:8.7636e-01  time: 1.04e+00
epoch: 7   trainLoss: 6.0489e-01   valLoss:7.8333e-01  time: 1.05e+00
epoch: 8   trainLoss: 6.4063e-01   valLoss:6.7551e-01  time: 1.06e+00
epoch: 9   trainLoss: 5.5619e-01   valLoss:6.1818e-01  time: 1.04e+00
epoch: 10   trainLoss: 5.5086e-01   valLoss:5.9239e-01  time: 1.03e+00
epoch: 11   trainLoss: 5.4329e-01   valLoss:5.7101e-01  time: 1.05e+00
epoch: 12   trainLoss: 5.6116e-01   valLoss:5.5412e-01  time: 1.04e+00
epoch: 13   trainLoss: 4.8820e-01   valLoss:5

testing on design_7

training on design_8
epoch: 0   trainLoss: 8.8766e-01   valLoss:9.0016e-01  time: 1.09e+00
epoch: 1   trainLoss: 7.7776e-01   valLoss:8.7512e-01  time: 1.06e+00
epoch: 2   trainLoss: 7.7424e-01   valLoss:8.4472e-01  time: 1.07e+00
epoch: 3   trainLoss: 7.0403e-01   valLoss:8.5272e-01  time: 1.08e+00
epoch: 4   trainLoss: 6.5942e-01   valLoss:9.0604e-01  time: 1.08e+00
epoch: 5   trainLoss: 6.2513e-01   valLoss:9.7293e-01  time: 1.11e+00
epoch: 6   trainLoss: 5.6999e-01   valLoss:9.5564e-01  time: 1.25e+00
epoch: 7   trainLoss: 5.2024e-01   valLoss:8.0793e-01  time: 1.07e+00
epoch: 8   trainLoss: 5.4949e-01   valLoss:6.2151e-01  time: 1.10e+00
epoch: 9   trainLoss: 4.7895e-01   valLoss:4.8270e-01  time: 1.10e+00
epoch: 10   trainLoss: 4.5267e-01   valLoss:4.5090e-01  time: 1.10e+00
epoch: 11   trainLoss: 4.2458e-01   valLoss:4.3038e-01  time: 1.09e+00
epoch: 12   trainLoss: 3.9712e-01   valLoss:4.0851e-01  time: 1.06e+00
epoch: 13   trainLoss: 3.5359e-01   valLoss:3

testing on design_8

training on design_9
epoch: 0   trainLoss: 1.0742e+00   valLoss:1.3588e+00  time: 1.13e+00
epoch: 1   trainLoss: 9.3244e-01   valLoss:1.3202e+00  time: 1.22e+00
epoch: 2   trainLoss: 8.3198e-01   valLoss:1.2608e+00  time: 1.14e+00
epoch: 3   trainLoss: 7.9700e-01   valLoss:1.1973e+00  time: 1.21e+00
epoch: 4   trainLoss: 7.2501e-01   valLoss:1.2067e+00  time: 1.14e+00
epoch: 5   trainLoss: 6.4700e-01   valLoss:1.1764e+00  time: 1.13e+00
epoch: 6   trainLoss: 6.2242e-01   valLoss:1.1342e+00  time: 1.25e+00
epoch: 7   trainLoss: 5.6671e-01   valLoss:1.0049e+00  time: 1.12e+00
epoch: 8   trainLoss: 5.4338e-01   valLoss:7.8998e-01  time: 1.12e+00
epoch: 9   trainLoss: 4.8971e-01   valLoss:6.0398e-01  time: 1.22e+00
epoch: 10   trainLoss: 4.5206e-01   valLoss:4.9935e-01  time: 1.14e+00
epoch: 11   trainLoss: 4.2324e-01   valLoss:4.5414e-01  time: 1.13e+00
epoch: 12   trainLoss: 4.0892e-01   valLoss:4.2691e-01  time: 1.12e+00
epoch: 13   trainLoss: 3.9189e-01   valLoss:4

testing on design_9



Unnamed: 0,mse,mae,mre,maxAE,mae/peak,maxAE/peak,relEAtPeak,Trained on,Tested on
0,0.000005,0.001795,0.118980,0.005643,0.109546,0.344441,0.038204,test group,design_5
1,0.000004,0.001591,0.088504,0.003511,0.047683,0.105254,0.047008,test group,design_5
2,0.000014,0.002641,0.098149,0.009333,0.046424,0.164063,0.140409,test group,design_5
3,0.000021,0.003645,0.080626,0.010110,0.037601,0.104297,0.104297,test group,design_5
4,0.000007,0.002221,0.145833,0.005644,0.156502,0.397689,0.015627,test group,design_5
...,...,...,...,...,...,...,...,...,...
895,0.000035,0.004708,0.170756,0.017067,0.126866,0.459868,0.418951,test group,design_9
896,0.000109,0.008930,0.141644,0.020024,0.055678,0.124842,0.101105,test group,design_9
897,0.000026,0.003981,0.159829,0.011790,0.120447,0.356676,0.168230,test group,design_9
898,0.000045,0.005326,0.157918,0.021916,0.096798,0.398298,0.336077,test group,design_9


## 3. Train on all groups at once

In [5]:
# Pool the training/validation data of every topology, fit a single model,
# then test it separately on each topology's test split.
print('training on all groups')
allTrainData, allValData = [], []
for name, data in trainSets.items():
    # extend in place: `lst = lst + data` re-copies everything gathered so far
    allTrainData.extend(data)
    allValData.extend(valSets[name])

gcn = FeaStNet()
history = gcn.trainModel(allTrainData, allValData, epochs=epochs, batch_size=256, flatten=True, logTrans=False, 
                         ssTrans=True, saveDir=saveDir+'allGroups')

display(plotHistory(history))

# test the pooled model on every topology
for testName, testSet in testSets.items():
    print('testing on '+testName+'\n')
    resultsDict = gcn.testModel(testSet, level='field')
    resultsDict['Trained on'] = ['all groups']*len(resultsDict['mse'])
    resultsDict['Tested on'] = [testName]*len(resultsDict['mse'])
    results = pivotDict(resultsDict)
    resultsList.extend(results)
    
pd.DataFrame(resultsList)

training on all groups
epoch: 0   trainLoss: 9.2671e-01   valLoss:1.1365e+00  time: 5.43e+00
epoch: 1   trainLoss: 7.2841e-01   valLoss:1.0169e+00  time: 5.54e+00
epoch: 2   trainLoss: 5.9773e-01   valLoss:5.9983e-01  time: 5.71e+00
epoch: 3   trainLoss: 5.0316e-01   valLoss:5.0962e-01  time: 5.60e+00
epoch: 4   trainLoss: 4.3783e-01   valLoss:4.4042e-01  time: 5.47e+00
epoch: 5   trainLoss: 3.8774e-01   valLoss:4.0421e-01  time: 5.61e+00
epoch: 6   trainLoss: 3.5105e-01   valLoss:3.6468e-01  time: 5.60e+00
epoch: 7   trainLoss: 3.1636e-01   valLoss:3.6055e-01  time: 5.47e+00
epoch: 8   trainLoss: 2.8824e-01   valLoss:3.2132e-01  time: 5.70e+00
epoch: 9   trainLoss: 2.6415e-01   valLoss:3.3713e-01  time: 5.32e+00
epoch: 10   trainLoss: 2.4084e-01   valLoss:2.8333e-01  time: 5.61e+00
epoch: 11   trainLoss: 2.2124e-01   valLoss:2.6678e-01  time: 5.33e+00
epoch: 12   trainLoss: 2.1076e-01   valLoss:2.9544e-01  time: 5.64e+00
epoch: 13   trainLoss: 1.9571e-01   valLoss:2.6086e-01  time: 5.

testing on design_5

testing on design_6

testing on design_7

testing on design_8

testing on design_9



Unnamed: 0,mse,mae,mre,maxAE,mae/peak,maxAE/peak,relEAtPeak,Trained on,Tested on
0,0.000005,0.001795,0.118980,0.005643,0.109546,0.344441,0.038204,test group,design_5
1,0.000004,0.001591,0.088504,0.003511,0.047683,0.105254,0.047008,test group,design_5
2,0.000014,0.002641,0.098149,0.009333,0.046424,0.164063,0.140409,test group,design_5
3,0.000021,0.003645,0.080626,0.010110,0.037601,0.104297,0.104297,test group,design_5
4,0.000007,0.002221,0.145833,0.005644,0.156502,0.397689,0.015627,test group,design_5
...,...,...,...,...,...,...,...,...,...
1795,0.000009,0.002165,0.082089,0.007113,0.058327,0.191664,0.116318,all groups,design_9
1796,0.000071,0.006632,0.099831,0.017901,0.041346,0.111605,0.064973,all groups,design_9
1797,0.000038,0.004583,0.192846,0.015018,0.138644,0.454345,0.085863,all groups,design_9
1798,0.000005,0.001618,0.051679,0.004988,0.029405,0.090646,0.015784,all groups,design_9


## 4. Leave one group out

In [6]:
# Leave-one-topology-out: for each topology, train on all the *other*
# topologies and test on the held-out one. The checkpoints recorded in
# checkptFiles are reused as starting points for transfer learning.
checkptFiles = {}
for heldOutName in trainSets:
    print('training on all but '+ heldOutName)
    allTrainData, allValData = [], []
    for name, data in trainSets.items():
        if name == heldOutName:
            continue
        # extend in place: `lst = lst + data` re-copies everything gathered so far
        allTrainData.extend(data)
        allValData.extend(valSets[name])

    gcn = FeaStNet()
    history = gcn.trainModel(allTrainData, allValData, epochs=epochs, batch_size=256, flatten=True, logTrans=False, 
                             ssTrans=True, saveDir=saveDir+'allBut_'+ heldOutName)

    display(plotHistory(history))

    # test on the topology the model never saw
    print('testing on '+heldOutName+'\n')
    resultsDict = gcn.testModel(testSets[heldOutName], level='field')
    resultsDict['Trained on'] = ['all groups but test group']*len(resultsDict['mse'])
    resultsDict['Tested on'] = [heldOutName]*len(resultsDict['mse'])
    results = pivotDict(resultsDict)
    resultsList.extend(results)
    
    # note checkpoint files for use in transfer learning
    checkptFiles[heldOutName] = gcn.checkptFile
    
pd.DataFrame(resultsList)

training on all but design_5
epoch: 0   trainLoss: 9.6930e-01   valLoss:1.0775e+00  time: 4.29e+00
epoch: 1   trainLoss: 7.8108e-01   valLoss:1.0078e+00  time: 4.15e+00
epoch: 2   trainLoss: 6.6328e-01   valLoss:6.9700e-01  time: 4.36e+00
epoch: 3   trainLoss: 5.3653e-01   valLoss:5.3344e-01  time: 4.40e+00
epoch: 4   trainLoss: 4.5843e-01   valLoss:4.5802e-01  time: 4.40e+00
epoch: 5   trainLoss: 4.0087e-01   valLoss:4.1856e-01  time: 4.28e+00
epoch: 6   trainLoss: 3.6536e-01   valLoss:3.8926e-01  time: 4.54e+00
epoch: 7   trainLoss: 3.3237e-01   valLoss:3.8573e-01  time: 4.32e+00
epoch: 8   trainLoss: 3.0631e-01   valLoss:4.0324e-01  time: 4.41e+00
epoch: 9   trainLoss: 2.8531e-01   valLoss:3.0220e-01  time: 4.42e+00
epoch: 10   trainLoss: 2.5992e-01   valLoss:2.8157e-01  time: 4.37e+00
epoch: 11   trainLoss: 2.4162e-01   valLoss:2.8961e-01  time: 4.49e+00
epoch: 12   trainLoss: 2.3087e-01   valLoss:2.8926e-01  time: 4.56e+00
epoch: 13   trainLoss: 2.1378e-01   valLoss:2.8547e-01  ti

testing on design_5

training on all but design_6
epoch: 0   trainLoss: 9.0773e-01   valLoss:1.0125e+00  time: 4.60e+00
epoch: 1   trainLoss: 7.5854e-01   valLoss:9.3181e-01  time: 4.58e+00
epoch: 2   trainLoss: 6.4352e-01   valLoss:6.9308e-01  time: 4.46e+00
epoch: 3   trainLoss: 5.5340e-01   valLoss:5.6763e-01  time: 4.24e+00


KeyboardInterrupt: 

## 5. Transfer learning

In [None]:
# Transfer learning: start from the leave-one-out checkpoint that never saw a
# topology, fine-tune on a 20% subset of that topology's training data, and
# test on its test split.
for trainName, trainSet in trainSets.items():
    if trainName not in checkptFiles:
        # e.g. the leave-one-out cell was interrupted before reaching this design
        print('no pre-trained checkpoint for '+trainName+'; skipping')
        continue

    # Choose the subset per design so the ids always match this design's size
    # (the previous version sized it from a `trainSet` leaked by an earlier loop).
    # The *second* output of train_test_split (test_size=0.2) is the 20% we keep.
    allIds = list(range(len(trainSet)))
    _, trainIds = train_test_split(allIds, test_size=0.2, shuffle=True, random_state=1234)
    smallTrainSet = [trainSet[i] for i in trainIds]

    # load pre-trained model and train on new data
    print('transfer learning on '+trainName)
    gcn = FeaStNet()
    # use the same preprocessing flags the checkpoint was trained with
    history = gcn.trainModel(smallTrainSet, valSets[trainName], restartFile=checkptFiles[trainName], 
                             epochs=epochs, batch_size=256, flatten=True, logTrans=False,
                             ssTrans=True, saveDir=saveDir+'transferLearn_'+ trainName)

    display(plotHistory(history))

    # test
    print('testing on '+trainName+'\n')
    resultsDict = gcn.testModel(testSets[trainName], level='field')
    resultsDict['Trained on'] = ['transfer learning (20%)']*len(resultsDict['mse'])
    resultsDict['Tested on'] = [trainName]*len(resultsDict['mse'])
    results = pivotDict(resultsDict)
    resultsList.extend(results)
        
pd.DataFrame(resultsList)

In [None]:
# save test results to file
df = pd.DataFrame(resultsList)
df.to_csv(saveDir+'testResults.csv', index=False)

In [None]:
# load test results from file
df = pd.read_csv(saveDir+'testResults.csv')
df

## 6. Plot results

In [None]:
# Mean MSE per training regime (bars) with per-sample MSE overlaid (points),
# one panel per test topology, on a log-scaled y axis.
# List every 'Trained on' label the notebook produces so bar position and
# colour follow the experiment order rather than an implicit default.
order = ['test group', 'all groups', 'all groups but test group', 'transfer learning (20%)']
barChart = alt.Chart(df).mark_bar().encode(
    x=alt.X('Trained on:N', sort=order, title='', axis=alt.Axis(ticks=False, labels=False)),
    y=alt.Y('mean(mse):Q', scale=alt.Scale(type='log'), axis=alt.Axis(tickCount=8, format=".0e"), title='MSE'),
    color=alt.Color('Trained on:N', sort=order),
    opacity = alt.OpacityValue(0.8),
    tooltip='mean(mse):Q'
).properties(width=75, height=200)

scatter = alt.Chart(df).mark_circle(size=20).encode(
    x=alt.X('Trained on:N', title='', sort=order),
    y=alt.Y('mse:Q', scale=alt.Scale(type='log')),
    color=alt.Color('Trained on:N', sort=order),
    opacity = alt.OpacityValue(0.3),
    tooltip='mse:Q'
)

alt.layer(barChart, scatter, data=df).facet(
    column=alt.Column('Tested on:N'))

In [None]:
# Same as the previous chart but with a linear y axis: mean MSE per training
# regime (bars) plus per-sample MSE (points), one panel per test topology.
# List every 'Trained on' label the notebook produces so bar position and
# colour follow the experiment order rather than an implicit default.
order = ['test group', 'all groups', 'all groups but test group', 'transfer learning (20%)']
barChart = alt.Chart(df).mark_bar().encode(
    x=alt.X('Trained on:N', sort=order, title='', axis=alt.Axis(ticks=False, labels=False)),
    y=alt.Y('mean(mse):Q', axis=alt.Axis(tickCount=8, format=".0e"), title='MSE'),
    color=alt.Color('Trained on:N', sort=order),
    opacity = alt.OpacityValue(0.8),
    tooltip='mean(mse):Q'
).properties(width=75, height=200)

scatter = alt.Chart(df).mark_circle(size=20).encode(
    x=alt.X('Trained on:N', title='', sort=order),
    y=alt.Y('mse:Q'),
    color=alt.Color('Trained on:N', sort=order),
    opacity = alt.OpacityValue(0.3),
    tooltip='mse:Q'
)

alt.layer(barChart, scatter, data=df).facet(
    column=alt.Column('Tested on:N'))

In [None]:
# Bars only: mean MSE per training regime, one panel per test topology.
# List every 'Trained on' label the notebook produces so bar position and
# colour follow the experiment order rather than an implicit default.
order = ['test group', 'all groups', 'all groups but test group', 'transfer learning (20%)']
barChart = alt.Chart(df).mark_bar().encode(
    x=alt.X('Trained on:N', sort=order, title='', axis=alt.Axis(ticks=False, labels=False)),
    y=alt.Y('mean(mse):Q', axis=alt.Axis(tickCount=8, format=".0e"), title='MSE'),
    color=alt.Color('Trained on:N', sort=order),
    opacity = alt.OpacityValue(0.8),
    tooltip='mean(mse):Q'
).properties(width=75, height=200)

alt.layer(barChart, data=df).facet(
    column=alt.Column('Tested on:N'))

In [None]:
# Bars only with the legend placed below the panels: mean MSE per training
# regime, one panel per test topology.
# List every 'Trained on' label the notebook produces so bar position and
# colour follow the experiment order rather than an implicit default.
order = ['test group', 'all groups', 'all groups but test group', 'transfer learning (20%)']
barChart = alt.Chart(df).mark_bar().encode(
    x=alt.X('Trained on:N', sort=order, title='', axis=alt.Axis(ticks=False, labels=False)),
    y=alt.Y('mean(mse):Q', axis=alt.Axis(tickCount=8, format=".0e"), title='MSE'),
    color=alt.Color('Trained on:N', sort=order, legend=alt.Legend(orient='bottom')),
    opacity = alt.OpacityValue(0.8),
    tooltip='mean(mse):Q'
).properties(width=75, height=200)

alt.layer(barChart, data=df).facet(
    column=alt.Column('Tested on:N'))