In [1]:
import sys
import os
import glob
import numpy as np
import pandas as pd
import altair as alt

from gcnSurrogate.models.feastnetSurrogateModel import FeaStNet
from gcnSurrogate.readers.loadConmechGraphs import loadConmechGraphs
from gcnSurrogate.visualization.altTrussViz import plotTruss, interactiveErrorPlot
from gcnSurrogate.util.gcnSurrogateUtil import *

# Design space and loading conditions

In [5]:
# alt.themes.enable('none')
sweepDirs = glob.glob('data/2D_Truss_v1.3/conmech/param_sweep*/')
width = 600

sweepPlotList = []
for i,dataDir in enumerate(np.sort(sweepDirs)):
    print(f'loading from {dataDir}')
    allGraphs = loadConmechGraphs(dataDir)
    
    # make a title
    title = alt.Chart(
        {'values': [{"text": f'x{i+1}'}]}
    ).mark_text(size=30, align='center', font='Arial').encode(
        text="text:N"
    ).properties(width=width)
    sweepPlotList.append(title)

    opacities = np.linspace(0.1,0.6,5)
    opacities[-1] = 1
    plotList = [plotTruss(g, withoutConfigure=True, 
                          lineOpacity=o, 
                          baseColor='#000000',
                          domX=[-5,55], 
                          domY=[-20,0], 
                width=width) for g,o in zip(allGraphs, opacities)]
    sweepPlot = alt.layer(*plotList)
    sweepPlotList.append(sweepPlot)
    
# make a title
nomTitle = alt.Chart(
    {'values': [{"text": f'Loading conditions'}]}
).mark_text(size=30, align='center', font='Arial').encode(
    text="text:N"
).properties(width=width)
    
nomDesign = plotTruss(allGraphs[2], 
                      withoutConfigure=True, 
                      domX=[-5,55], 
                      domY=[-20,0], 
                      width=width,
                      baseColor='#000000')
    
col0 = alt.vconcat(*sweepPlotList[:6])
col1 = alt.vconcat(*sweepPlotList[6:],nomTitle,nomDesign)
finalChart = alt.hconcat(col0, col1, background='white').configure_view(strokeOpacity=0)
finalChart

loading from data/2D_Truss_v1.3/conmech/param_sweep_0_7_N_5/
loading from data/2D_Truss_v1.3/conmech/param_sweep_1_7_N_5/
loading from data/2D_Truss_v1.3/conmech/param_sweep_2_7_N_5/
loading from data/2D_Truss_v1.3/conmech/param_sweep_3_7_N_5/
loading from data/2D_Truss_v1.3/conmech/param_sweep_4_7_N_5/


# Error plot and sample predictions

In [51]:
# load model
gcn = FeaStNet()
gcn.loadModel('results/gcn01/checkpoint_94')

# load data
dataDir = 'data/2D_Truss_v1.3/conmech/design_7_N_1000/'
allGraphsUnfiltered = loadConmechGraphs(dataDir)
allGraphs = filterbyDisp(allGraphsUnfiltered, 0.9)
trainData, valData, testData = partitionGraphList(allGraphs)

# run inference
predictions = gcn.predict(testData)
resultsDictWide = computeFieldLossMetrics([g.y.cpu().numpy() for g in testData], 
                              predictions, 
                              baselineRef=None, level='field')
resultsDict = [dict(zip(resultsDictWide,t)) for t in zip(*resultsDictWide.values())]
df = pd.DataFrame(resultsDict)
df['Percentile'] = df['mse'].rank(pct=True)
df = df.sort_values(['Percentile'])

In [72]:
cdf = alt.Chart(df).mark_circle().encode(
    y=alt.Y('Percentile:Q'),
    x=alt.X('mse:Q', scale=alt.Scale(type='log'), axis=alt.Axis(tickCount=5, format=".0e"), title='MSE'),
).properties(width=300, height=300)

trussList = []
for percentileTarget in np.linspace(0,1, num=5):
    i = df[df.Percentile >= percentileTarget].index[0]
    t = plotTruss(testData[i], showDeformed=True, prediction=predictions[i], defScale=200, withoutConfigure=True, width=100)
    trussList.append(t)
    
col0 = alt.vconcat(*trussList)

alt.hconcat(cdf, col0).configure_view(strokeOpacity=0)

# Multi-topology generalization

In [3]:
dpi = 500 # dots per inch
picW = 3.5625 # inches
imageWidth = dpi*picW
imageWidth

1781.25

In [14]:
tempFile = 'figures/TEMP.svg'
nomDesign = plotTruss(allGraphs[2], withoutConfigure=True, domX=[-5,55], domY=[-20,5], baseColor='#000000', width=500)
nomDesign.save(tempFile)
# img = Image.open(tempFile)
# print(img.size)
img

UnidentifiedImageError: cannot identify image file 'figures/TEMP.svg'

In [18]:
svg2png(url=tempFile, write_to='figures/CAIROTEMP.png', dpi=300)

In [None]:
from PIL import Image
from cairosvg import svg2png

def alt2png(altChart, tempFile='figures/TEMP.svg', dpi=300):
    altChart.save(tempFile)

In [29]:
from vega_datasets import data
cars = data.cars()

alt.Chart(cars).mark_point().encode(
    x='Horsepower',
    y='Miles_per_Gallon',
    color='Origin',
).properties(title="Blah blah blah")