In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch import nn, optim

from data_loading import *
from pytorch_utils import *
from models import *
from evaluation import *

### Load datasets

- 6 within-sample-set datasets: stored in `wss_data` dictionary - key names in `data_names`
- 6 out-of-sample-set datasets: stored in `oss_data` dictionary - key names in `data_names`
- Load from datasets in `data/` directory
- Also get the training/validation split indices for each dataset, stored in the `i_tr` (training) and `i_val` (validation, reported as "test" below) dictionaries


In [None]:
# Fix the global RNG state up front so that anything random downstream -- presumably
# the split returned by get_split_indices (TODO confirm it is random) and the SSENet
# weight initialisation in the next cell -- is reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

shape_names = ["vor", "lat", "both"]
stress_names = ["stress_" + shape for shape in shape_names]
temp_names = ["temp_" + shape for shape in shape_names]
data_names = stress_names + temp_names

wss_data = dict()   # within-sample-set datasets, keyed by the names in data_names
oss_data = dict()   # out-of-sample-set datasets, same keys
i_tr = dict()       # training indices into wss_data[name]
i_val = dict()      # validation indices into wss_data[name]
datadir = "data/"

for name in data_names:
    quantity, shape = name.split("_", 1)   # e.g. "stress_vor" -> ("stress", "vor")

    if shape == "both":
        # Combined set = Voronoi set + lattice set. Both parts were already loaded on
        # earlier iterations because "both" is listed last in shape_names.
        wss_data[name] = wss_data[quantity + "_vor"] + wss_data[quantity + "_lat"]
        oss_data[name] = oss_data[quantity + "_vor"] + oss_data[quantity + "_lat"]
    else:
        scale = 1. if quantity == "temp" else 10000.   # Divide stress values by 10000Pa
        wss_data[name] = load_matlab_dataset(datadir + name + "_w.mat", scale)
        oss_data[name] = load_matlab_dataset(datadir + name + "_o.mat", scale)

    # Only the within-sample set is split; the out-of-sample set is used whole.
    i_tr[name], i_val[name] = get_split_indices(wss_data[name])

## Models
### Create models

- 6 `SSENet` models, one per dataset: `models` dict

In [None]:
# One fresh, untrained SSENet per dataset, keyed by the same names as wss_data / oss_data.
models = {name: SSENet() for name in data_names}

## Train models

- Train each model for 50 epochs with learning rate 0.001
- Store loss curves in `hist_tr` (training) and `hist_val` (validation)
- Store training times in `times`
- Save models as .pth files in this directory

In [None]:
# Per-dataset training artefacts, keyed like `models`.
hist_tr = dict()    # training loss curve returned by train_model
hist_val = dict()   # validation loss curve returned by train_model
times = dict()      # training time returned by train_model

for name in data_names:
    print(f"\n\n_________________________ Now Training: {name} _________________________")
    trained, tr_curve, val_curve, elapsed = train_model(
        models[name], wss_data[name], i_tr[name], i_val[name]
    )

    # Replace the untrained entry with the returned model, then record its history.
    models[name] = trained
    hist_tr[name], hist_val[name], times[name] = tr_curve, val_curve, elapsed

    # torch.save pickles the whole model object (not only a state_dict) into this directory.
    torch.save(trained, "model_" + name + ".pth")


### Plot loss

- Plot loss for stress prediction models

In [None]:
# Legend names for the three stress datasets.
label_names = dict(stress_lat = "Lattice Set", stress_vor = "Voronoi Set", stress_both = "Combined Set" )
SAVE_LOSS_FIG = False   # set True to also write stress_loss.png next to the notebook

fig, ax = plt.subplots(dpi=120, figsize=(5, 3))

for name in stress_names:
    # Dashed line = validation loss, solid line = training loss, one pair per dataset.
    ax.plot(hist_val[name], "--", label=label_names[name] + ": Validation")
    ax.plot(hist_tr[name], "-", label=label_names[name] + ": Training")

ax.set(title="Stress models: training and validation loss", xlabel="Epoch", ylabel="Loss")
ax.legend()

if SAVE_LOSS_FIG:
    fig.savefig("stress_loss.png", bbox_inches="tight")
plt.show()

### Evaluate models

- Compute $R^2$ values for all models
- Save all distributions to `r2s_tr` (training), `r2s_te` (testing), and `r2s_oss` (out-of-sample set)
- Display median values for each

In [None]:
# R^2 distributions for every model, keyed like `models`.
r2s_tr  = dict()   # on the training indices of the within-sample set
r2s_te  = dict()   # on the validation indices of the within-sample set
r2s_oss = dict()   # on the out-of-sample set

for name in data_names:
    tr_r2, te_r2, oss_r2 = evaluate_all_data(
        models[name], wss_data[name], i_tr[name], i_val[name], oss_data[name]
    )
    r2s_tr[name], r2s_te[name], r2s_oss[name] = tr_r2, te_r2, oss_r2

    medians = tuple(np.median(dist) for dist in (tr_r2, te_r2, oss_r2))
    print(f"Model: {name}")
    print("Train Median: %0.4f    Test Median: %0.4f    OSS Median: %0.4f\n" % medians)


### Plot $R^2$ distributions

- Using the model trained on Combined Set for stress prediction, `models["stress_both"]`
- Display boxplot distribution of $R^2$ for training, testing, and out-of-sample sets



In [None]:
# Boxplots of R^2 for the Combined Set stress model. Reuse the distributions already
# computed by the evaluation cell above instead of calling evaluate_all_data again on
# the same model and data; this also keeps the boxes consistent with the printed medians.
name = "stress_both"
plot_boxes(r2s_tr[name], r2s_te[name], r2s_oss[name])

### Train partial models

#### Train models with the following input combinations:

|  (x, y, SDF) | Local features   | Global features |
|  :-:         | :-:              | :-:             |
| $\checkmark$ | -                | $\checkmark$    |
| $\checkmark$ | $\checkmark$     |  -              |
| -            | $\checkmark$     | $\checkmark$    |

- Models use `SSENetCustom()` to get access to partial inputs
- Train for Combined Set stress data only
- Save models in this directory

In [None]:
# Partial-input variants, trained on the Combined Set stress data only. Each flag tuple
# is passed straight to SSENetCustom; matching the table above, its entries presumably
# toggle (x, y, SDF), local features and global features -- TODO confirm in models.py.
partial_names = ["xyd_global","xyd_local","local_global"]
partials = [(1,0,1),(1,1,0),(0,1,1)]
partial_models = []
name = "stress_both"
dataset = wss_data[name]
idxs_tr, idxs_val = i_tr[name], i_val[name]

for part_name, flags in zip(partial_names, partials):
    # Loss curves and timing are not kept for the partial models.
    model, _tr_curve, _val_curve, _elapsed = train_model(
        SSENetCustom(flags), dataset, idxs_tr, idxs_val
    )
    partial_models.append(model)
    torch.save(model, "model_" + part_name + ".pth")
    

### Evaluate partial models

In [None]:
# Median R^2 (train / test / out-of-sample) for each partial-input model on the
# Combined Set stress data.
name = "stress_both"
wss = wss_data[name]
oss = oss_data[name]
idxs_tr, idxs_val = i_tr[name], i_val[name]

# partial_models is filled in the same order as partial_names by the training cell.
assert len(partial_models) == len(partial_names), "run the partial-model training cell first"

# Pair each name with its model directly; enumerate() would yield (index, name) tuples,
# which cannot be used to index the partial_models list.
for part_name, model in zip(partial_names, partial_models):
    vals1, vals2, vals3 = evaluate_all_data(model, wss, idxs_tr, idxs_val, oss)

    print(f"Model: {part_name}")
    print("Train Median: %0.4f    Test Median: %0.4f    OSS Median: %0.4f\n" 
        %(    np.median(vals1),     np.median(vals2),    np.median(vals3)))
