# Direct Ink Write ddDT Demo
This notebook presents an example for training a ddDT on the direct ink write advanced manufacturing exemplar.

Before running this notebook, make sure the package is installed in your environment by running
`pip install -e .` from the base directory of this repository.

In [1]:
# imports
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from tsl.engines import Predictor
from tsl.data import ImputationDataset
from tsl.utils.casting import torch_to_numpy
from tsl.ops.connectivity import adj_to_edge_index
from tsl.metrics.torch import MaskedMSE, MaskedMAE, MaskedMAPE
from tsl.data.preprocessing import MinMaxScaler
from tsl.data.datamodule import SpatioTemporalDataModule, TemporalSplitter

from graphfoundationmodels.models.stGAE import STConvAE
from graphfoundationmodels.dataloaders.dataloader_DIW import DIWDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
# Seed every random number generator the notebook relies on (Python's `random`
# for the sampled plots, NumPy, and torch for weight initialisation) so that
# Restart Kernel -> Run All reproduces the same results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# check for GPU: use CUDA when torch reports one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device used: {device}")

Device used: cpu



Next, we set up data loading using the package's `DIWDataset` dataloader and split the data temporally into training, validation, and test sets.

In [2]:
# Location of the DIW feature matrix (parquet file), relative to this notebook.
DATA_PATH = '../data/diw-stgnn-featmat.parquet'

# Arguments for the DIW dataset; no evaluation mask or explicit connectivity
# is supplied here (both are passed as None).
dataset_kwargs = dict(
    target_path=DATA_PATH,
    eval_mask=None,
    connectivity=None,
    window=250,
    stride=25,
)
torch_dataset = DIWDataset(**dataset_kwargs)

# Temporal train/validation/test split.
# NOTE(review): the float lengths are presumably fractions of the sequence -- confirm against tsl.
splitter = TemporalSplitter(test_len=0.05, val_len=0.25)

# Bundle dataset and splitter into a datamodule and prepare the splits.
dm = SpatioTemporalDataModule(dataset=torch_dataset, splitter=splitter, batch_size=8)
dm.setup()

Next, we define our model, importing the standard `STConvAE` from the package. We track multiple metrics (MSE, MAE, and MAPE) and train with the MAE loss.
Torch Spatiotemporal (`tsl`) also provides us with a useful `Predictor` object to keep track of our training details, such as loss function, learning rate, and optimizer, for which we use fairly standard values.

In [3]:
# Channel sizes handed to the model as a 2 x 3 array.
# NOTE(review): the two rows presumably describe the encoder and decoder stacks -- confirm in STConvAE.
channel_sizes = np.array([[3, 8, 16], [16, 8, 3]])

# Spatio-temporal convolutional autoencoder, sized to the dataset's node count.
stgnn = STConvAE(
    device=device,
    num_nodes=torch_dataset.n_nodes,
    channel_size_list=channel_sizes,
    num_layers=2,
    kernel_size=4,
    K=2,
    kernel_size_de=2,
    stride=1,
    padding=1,
    normalization='sym',
    bias=True,
)

# Training objective: masked mean absolute error.
loss_fn = MaskedMAE()

# Additional metrics reported alongside the loss.
metrics = dict(
    mse=MaskedMSE(),
    mae=MaskedMAE(),
    mape=MaskedMAPE(),
)

# The Predictor bundles model, optimiser (Adam, lr=0.01), loss, metrics and a
# StepLR learning-rate scheduler with step_size=15.
predictor = Predictor(
    model=stgnn,
    optim_class=torch.optim.Adam,
    optim_kwargs={'lr': 0.01},
    loss_fn=loss_fn,
    metrics=metrics,
    scheduler_class=torch.optim.lr_scheduler.StepLR,
    scheduler_kwargs={'step_size': 15},
)

We can now begin training. We also pull in a helper callback to make some plots.

In [5]:
from graphfoundationmodels.util.callbacks import LossPlotCallback

# Helper callback from the package; its plot_losses() method is called after
# fitting to draw the loss curves.
loss_plot_callback = LossPlotCallback()

# Pick the Lightning accelerator from the hardware that is actually present.
# Hard-coding 'gpu' makes Trainer construction fail on CPU-only machines, which
# contradicts the CPU fallback used when `device` was selected.
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

trainer = pl.Trainer(max_epochs=1,
                     accelerator=accelerator,
                     callbacks=[loss_plot_callback])

# Train the predictor on the datamodule's training/validation splits.
trainer.fit(predictor, datamodule=dm)

loss_plot_callback.plot_losses()

To validate the model, we now predict on the test set that the model hasn't seen.

In [None]:
# Freeze the trained predictor before evaluating it on the held-out test split.
predictor.freeze()
trainer.test(predictor, datamodule=dm)

# generate predictions: run inference batch by batch, merge the per-batch
# results, then convert the merged tensors to NumPy arrays
batch_outputs = trainer.predict(predictor, dataloaders=dm.test_dataloader())
collated_outputs = predictor.collate_prediction_outputs(batch_outputs)
output = torch_to_numpy(collated_outputs)

# 'y' holds the targets and 'y_hat' the model's predictions.
truth, pred = output['y'], output['y_hat']

To visualize the results, make a scatter plot of predicted versus actual values for each feature:

In [None]:
# The feature channels live on the fourth axis of the prediction array;
# collapse every other axis so each row is one sample of all features.
num_features = pred.shape[3]
truth_flattened = truth.reshape(-1, num_features)
pred_flattened = pred.reshape(-1, num_features)

# Known channel names, padded with generic labels so that arrays with more
# than three feature channels do not raise an IndexError.
feature_names = ['X-error', 'Y-error', 'Z-error']
feature_names += [f'Feature {k}' for k in range(len(feature_names), num_features)]

# squeeze=False always returns a 2-D array of Axes, so indexing also works
# when there is only a single feature (otherwise a bare Axes is returned).
fig, axes = plt.subplots(1, num_features, figsize=(15, 5), squeeze=False)

for feature_idx in range(num_features):
    ax = axes[0, feature_idx]
    # One point per (window, timestep, node): actual value on x, prediction on y.
    ax.scatter(truth_flattened[:, feature_idx], pred_flattened[:, feature_idx], alpha=0.5, s=1)

    ax.set_title(f'{feature_names[feature_idx]}')
    ax.set_xlabel('Actual Values')
    ax.set_ylabel('Predicted Values')
    ax.grid(True)

plt.tight_layout()
plt.show()


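Alternatively, plot the reconstructed time series against the ground truth for a few randomly chosen test windows and nodes: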

In [None]:
# Recompute the channel count here so this cell does not depend on a variable
# left behind by the scatter-plot cell.
num_features = pred.shape[3]

# Known channel names, padded with generic labels for any extra channels.
feature_names = ['X-error', 'Y-error', 'Z-error']
feature_names += [f'Feature {k}' for k in range(len(feature_names), num_features)]

# Draw up to three distinct test windows; the cap avoids a ValueError from
# random.sample when fewer than three windows are available.
num_examples = min(3, pred.shape[0])

for i in random.sample(range(pred.shape[0]), num_examples):

    # A scalar node index keeps the slices below one-dimensional, with no need
    # for list-based advanced indexing followed by a `[0]`.
    node_id = random.randrange(pred.shape[2])

    # squeeze=False keeps `axes` two-dimensional even for a single feature.
    fig, axes = plt.subplots(1, num_features, figsize=(15, 5), squeeze=False)
    for feature_idx in range(num_features):
        ax = axes[0, feature_idx]
        # Label both curves so the target and the reconstruction can be told apart.
        ax.plot(truth[i, :, node_id, feature_idx], label='Actual')
        ax.plot(pred[i, :, node_id, feature_idx], label='Predicted')

        ax.set_title(f'{feature_names[feature_idx]}')
        ax.set_xlabel('Timestep')
        ax.set_ylabel('Error')
        ax.grid(True)
        ax.legend()

    plt.tight_layout()
    plt.show()