In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
from matplotlib.animation import FuncAnimation
import os
from path import Path
import pickle
# Render every figure with a serif font, preferring Times New Roman when it is installed.
plt.rcParams["font.family"] = "serif"
# Rebuild the preference list with Times New Roman first, dropping any existing entry,
# so re-running this cell does not keep prepending duplicates (keeps the cell idempotent).
plt.rcParams['font.serif'] = ['Times New Roman'] + [
    font for font in plt.rcParams['font.serif'] if font != 'Times New Roman'
]

In [None]:
# Analysis settings. Run test.py first: it creates the results pickle files that the
# next cell loads from `save_dir`.
save_dir = Path('results')      # directory holding the pickled results
train_type = 'structure'        # used in the result file names, figure titles and figure file names
include_state = True            # True -> "state" result files / ViST, VSCNN labels; False -> "visu" / ViT, VCNN
architectures = ['vit', 'cnn']  # one results file is loaded per architecture

In [None]:
# Load one pickled results file per architecture into `results`, keyed by architecture.
# Files are named "<arch>_<state|visu>_<train_type>.pkl" inside `save_dir`.
# NOTE: pickle.load can execute arbitrary code, so only load result files you produced yourself.
input_kind = "state" if include_state else "visu"
results = {}
for arch in architectures:
    filename = save_dir / '{}_{}_{}.pkl'.format(arch, input_kind, train_type)
    with open(filename, 'rb') as f:
        data = pickle.load(f)
    results[arch] = data


In [None]:
# Plot the ViT and CNN force predictions against the ground truth for the 'shared'
# results, with one subplot per force component (X, Y, Z).
shared_gt = results['vit']['shared_gt']
# Time axis in seconds. It is sized from the plotted series itself so x and y always have
# the same length. NOTE(review): the 1/30 factor and 580 end point look like a 30 Hz
# recording spanning 580 frames — TODO confirm against the data.
x = (1/30)*np.linspace(0, 580, len(shared_gt))
vit_label = "ViST" if include_state else "ViT"
cnn_label = "VSCNN" if include_state else "VCNN"

fig, ax = plt.subplots(ncols=1, nrows=3, sharex=False, sharey=False, figsize=(12, 6))
fig.suptitle("Force estimation for {} train set".format(train_type), fontsize=16, fontweight='bold')
for i, component in enumerate(['X', 'Y', 'Z']):
    ax[i].plot(x, shared_gt[:, i], 'b')
    ax[i].plot(x, results['vit']['shared_pred'][:, i], 'r')
    ax[i].plot(x, results['cnn']['shared_pred'][:, i], 'g')
    ax[i].set_ylabel("Force {} (N)".format(component), fontsize=12)
ax[0].legend(["Ground truth", vit_label, cnn_label], fontsize=10)
ax[2].set_xlabel("Time (s)", fontsize=12)
fig.align_labels()

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures/shared', exist_ok=True)
fig.savefig('figures/shared/predictions_{}.png'.format(train_type), dpi=800)

In [None]:
# Bar chart comparing the mean of each architecture's 'shared_rmse' values.
# Labels follow include_state so they match the legend of the prediction plots.
name = ["ViST" if include_state else "ViT", "VSCNN" if include_state else "VCNN"]
rmse_values = [results['vit']['shared_rmse'].mean(), results['cnn']['shared_rmse'].mean()]

fig, ax = plt.subplots(figsize=(7, 5))
ax.bar(name, rmse_values, color=['red', 'green'], width=0.5)
ax.set_xlabel('Network architectures', fontsize=12)
ax.set_ylabel('RMSE (N)', fontsize=12)
ax.set_title('Error for {} train set'.format(train_type), fontsize=16, fontweight='bold')

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures/shared', exist_ok=True)
fig.savefig('figures/shared/metrics_{}.png'.format(train_type), dpi=800)

In [None]:
# Plot the ViT and CNN force predictions against the ground truth for the 'test'
# results, with one subplot per force component (X, Y, Z).
test_gt = results['vit']['test_gt']
# Time axis in seconds. It is sized from the plotted series itself so x and y always have
# the same length. NOTE(review): the 1/30 factor and 580 end point look like a 30 Hz
# recording spanning 580 frames — TODO confirm against the data.
x = (1/30)*np.linspace(0, 580, len(test_gt))
vit_label = "ViST" if include_state else "ViT"
cnn_label = "VSCNN" if include_state else "VCNN"

fig, ax = plt.subplots(ncols=1, nrows=3, sharex=False, sharey=False, figsize=(12, 6))
fig.suptitle("Force estimation for {} shift".format(train_type), fontsize=16, fontweight='bold')
for i, component in enumerate(['X', 'Y', 'Z']):
    ax[i].plot(x, test_gt[:, i], 'b')
    ax[i].plot(x, results['vit']['test_pred'][:, i], 'r')
    ax[i].plot(x, results['cnn']['test_pred'][:, i], 'g')
    ax[i].set_ylabel("Force {} (N)".format(component), fontsize=12)
ax[0].legend(["Ground truth", vit_label, cnn_label], fontsize=10)
ax[2].set_xlabel("Time (s)", fontsize=12)
fig.align_labels()

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures/test', exist_ok=True)
fig.savefig('figures/test/predictions_{}.png'.format(train_type), dpi=800)

In [None]:
# Bar chart comparing the mean of each architecture's 'test_rmse' values.
# Labels follow include_state so they match the legend of the prediction plots.
name = ["ViST" if include_state else "ViT", "VSCNN" if include_state else "VCNN"]
rmse_values = [results['vit']['test_rmse'].mean(), results['cnn']['test_rmse'].mean()]

fig, ax = plt.subplots(figsize=(7, 5))
ax.bar(name, rmse_values, color=['red', 'green'], width=0.5)
ax.set_xlabel('Network architectures', fontsize=12)
ax.set_ylabel('RMSE (N)', fontsize=12)
ax.set_title('Error for {} shift'.format(train_type), fontsize=16, fontweight='bold')

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures/test', exist_ok=True)
fig.savefig('figures/test/metrics_{}.png'.format(train_type), dpi=800)