# Intro to graph convolutional surrogate models
Eamon Whalen

In [1]:
import sys
import os
import numpy as np
import pandas as pd
import altair as alt

from gcnSurrogate.models.feastnetSurrogateModel import FeaStNet
from gcnSurrogate.readers.loadConmechGraphs import loadConmechGraphs
from gcnSurrogate.visualization.altTrussViz import plotTruss, interactiveErrorPlot
from gcnSurrogate.util.gcnSurrogateUtil import *

## 1. Load simulation data

In [2]:
# Directory of Conmech simulation results for this truss design set.
dataDir = 'data/2D_Truss_v1.3/conmech/design_7_N_1000/'
allGraphsUnfiltered = loadConmechGraphs(dataDir)

# Largest absolute target value (presumably nodal displacement) per graph.
# `ndarray.max()` with no axis already reduces over every element, so no
# `.flatten()` copy is needed. It also runs in native code instead of the
# element-by-element Python loop that builtin `max` performs on an array.
maxes = [np.abs(graph.y.numpy()).max() for graph in allGraphsUnfiltered]
source = pd.DataFrame(maxes, columns=['maxes'])
source.describe()

Unnamed: 0,maxes
count,1000.0
mean,0.137562
std,1.740254
min,0.00453
25%,0.010635
50%,0.016778
75%,0.032392
max,46.419552


In [3]:
# plotTruss(allGraphsUnfiltered[0], showDeformed=True, defScale=10)

## 2. Filter and partition

In [4]:
# Drop outlier designs, then split into train / validation / test sets.
# NOTE(review): going by the summaries before and after this cell
# (1000 -> 900 graphs, peak |y| 46.4 -> 0.067), 0.9 appears to keep the
# 90% of graphs with the smallest peak displacement. Confirm against
# filterbyDisp.
allGraphs = filterbyDisp(allGraphsUnfiltered, 0.9)
trainData, valData, testData = partitionGraphList(allGraphs)

# Re-summarize the per-graph peak |y| after filtering. A native ndarray
# reduction replaces builtin `max` and the unneeded `.flatten()` copy.
maxes = [np.abs(graph.y.numpy()).max() for graph in allGraphs]
source = pd.DataFrame(maxes, columns=['maxes'])
source.describe()

Unnamed: 0,maxes
count,900.0
mean,0.019962
std,0.013149
min,0.00453
25%,0.010309
50%,0.014975
75%,0.026453
max,0.067472


## 3. Train a GCN

In [5]:
# Fit a fresh FeaStNet surrogate on the training split. The log below
# reports a validation loss every epoch. Checkpoints are written to
# checkpointDir; the reload cell further down reads 'checkpoint_94' from
# the same directory.
numEpochs = 100
checkpointDir = 'results/gcn01/'

gcn = FeaStNet()
history = gcn.trainModel(
    trainData,
    valData,
    epochs=numEpochs,
    saveDir=checkpointDir,
)

# Visualize the per-epoch loss history returned by training.
plotHistory(history)

train model: flatten: True
train model: self.flatten: True
fitSS: self.flatten: True
epoch: 0   trainLoss: 9.3707e-01   valLoss:9.6553e-01  time: 6.21e+00
epoch: 1   trainLoss: 7.2877e-01   valLoss:9.5059e-01  time: 1.40e+00
epoch: 2   trainLoss: 6.1796e-01   valLoss:1.0302e+00  time: 1.43e+00
epoch: 3   trainLoss: 5.0096e-01   valLoss:1.3340e+00  time: 1.41e+00
epoch: 4   trainLoss: 4.5089e-01   valLoss:1.1515e+00  time: 1.39e+00
epoch: 5   trainLoss: 3.9906e-01   valLoss:1.1968e+00  time: 1.40e+00
epoch: 6   trainLoss: 3.4878e-01   valLoss:1.5384e+00  time: 1.40e+00
epoch: 7   trainLoss: 3.2077e-01   valLoss:1.3084e+00  time: 1.41e+00
epoch: 8   trainLoss: 3.0518e-01   valLoss:1.4648e+00  time: 1.41e+00
epoch: 9   trainLoss: 2.5161e-01   valLoss:2.5058e+00  time: 1.43e+00
epoch: 10   trainLoss: 2.1811e-01   valLoss:1.9082e+00  time: 1.43e+00
epoch: 11   trainLoss: 1.9170e-01   valLoss:3.4080e+00  time: 1.41e+00
epoch: 12   trainLoss: 1.8146e-01   valLoss:1.5016e+00  time: 1.43e+00
... (remaining training log truncated)

In [13]:
# Reload a saved checkpoint into a fresh model, then score it on the
# training and test splits.
# NOTE(review): '94' presumably refers to the epoch of one particular
# training run. Update the checkpoint name if the model is retrained.
checkpointPath = 'results/gcn01/checkpoint_94'

gcn2 = FeaStNet()
gcn2.loadModel(checkpointPath)

# Tuple assignment evaluates left to right: train first, then test.
trainRes, testRes = gcn2.testModel(trainData), gcn2.testModel(testData)
pd.DataFrame([trainRes, testRes], index=['train', 'test'])

Unnamed: 0,mse,mae,mre,peakR2,maxAggR2,meanAggR2,minAggR2
train,4e-06,0.0014,0.074304,0.942238,0.972997,0.847795,0.0
test,7e-06,0.001609,0.083742,0.841736,0.936382,0.804915,0.0


In [20]:
# Scaler sanity check (this does not run the model): push a 15x2 array of
# ones through the model's `ss` transform and inspect the result.
# NOTE(review): assumes `gcn2.ss` is a fitted scikit-learn-style scaler that
# accepts a torch tensor. Confirm against the FeaStNet class.
import torch

# NOTE(review): `graph` is assigned but not used in this cell.
graph = allGraphs[0]

a = np.ones((15,2))
b = torch.from_numpy(a)

gcn2.ss.transform(b)

array([[98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923],
       [98.55522923, 98.55522923]])

In [22]:
# Inspect the per-feature mean stored on the model's fitted scaler `ss`.
gcn2.ss.mean_

array([-0.00366223])

## 4. Test the GCN

In [9]:
# Score the in-memory model from the training cell on both splits.
# NOTE(review): the recorded table is identical to the checkpoint_94 reload
# above. This suggests trainModel leaves `gcn` holding that checkpoint's
# weights. Confirm.
trainRes, testRes = (gcn.testModel(split) for split in (trainData, testData))
pd.DataFrame([trainRes, testRes], index=['train', 'test'])

Unnamed: 0,mse,mae,mre,peakR2,maxAggR2,meanAggR2,minAggR2
train,4e-06,0.0014,0.074304,0.942238,0.972997,0.847795,0.0
test,7e-06,0.001609,0.083742,0.841736,0.936382,0.804915,0.0


## 5. Visualize some predictions

In [11]:
# Plot one test design in its deformed state together with the model's
# prediction for it. `predict` takes a list of graphs and returns a list of
# predictions, so wrap the single sample and unwrap the single result.
i = 4
sample = testData[i]
pred = gcn.predict([sample])[0]
plotTruss(sample, showDeformed=True, defScale=20, prediction=pred)

In [10]:
# # interactive scatter plot
# alt.data_transformers.enable('json')
# allPreds = gcn.predict(testData)
# display(interactiveErrorPlot(testData, allPreds))
# alt.data_transformers.enable('default');