In [None]:
import torch, pickle, argparse, os
import importlib
from fig_settings import *

In [None]:
# Training run to inspect; its saved hyper-parameters drive every cell below.
run_id = 'material_model_run_3'

# args.pkl is read with torch.load(weights_only=True), so it may only contain types
# that the restricted unpickler accepts (primitives, containers, tensors). It is a
# mapping of keyword names, unpacked into a Namespace so the rest of the notebook can
# use attribute access (presumably mirroring the training script's argparse args).
params = torch.load(run_id + '/args.pkl', weights_only=True)
args = argparse.Namespace(**params)

# Output directory for this run's figures. exist_ok=True keeps the cell idempotent on
# re-run and avoids the check-then-create race of os.path.exists + os.makedirs.
folder = f'overleaf/figs/{run_id}'
os.makedirs(folder, exist_ok=True)

In [None]:
mm = importlib.import_module(args.material_model)
from util import LossFunction
from m_encoder import *

# Use the requested device only when CUDA is actually available; otherwise use the CPU.
device = torch.device(args.device if torch.cuda.is_available() else "cpu")

# NOTE(review): pickle.load can execute arbitrary code; args.data_path must point at trusted data.
with open(args.data_path, "rb") as f:
    data = pickle.load(f)

N = args.n_samples  # only the first N rows of every field are used
step = args.step    # stride applied along axis 1 of the sequence fields


def _to_device_tensor(raw):
    """Convert an array-like to a float32 tensor placed on `device`."""
    return torch.tensor(raw, dtype=torch.float32).to(device)


# Sequence fields: first N rows, every `step`-th entry along axis 1.
e, e_dot, s = (
    _to_device_tensor(data[key][:N, ::step])
    for key in ("strain", "strain_rate", "stress")
)
# Per-sample fields: first N rows only, no striding.
E, nu = (_to_device_tensor(data[key][:N]) for key in ("E", "nu"))

loss_function = LossFunction()

# Both autoencoders share the hidden/latent sizes; the input width is axis 1 of each field.
ae_E = AutoEncoder(E.shape[1], args.encoder_hidden_dim, args.encoder_latent_dim).to(device)
ae_nu = AutoEncoder(nu.shape[1], args.encoder_hidden_dim, args.encoder_latent_dim).to(device)

# Checkpoints are mapped to CPU first; load_state_dict then copies the values into
# the parameters, which already live on `device`.
cpu = torch.device('cpu')
ae_E.load_state_dict(torch.load(f"{args.encoder_path}/ae_E.pth", weights_only=True, map_location=cpu))
ae_nu.load_state_dict(torch.load(f"{args.encoder_path}/ae_nu.pth", weights_only=True, map_location=cpu))

In [None]:
# Energy and dissipation networks take the same input shape (1, 1, 2 * latent_dim).
# NOTE(review): the factor 2 presumably reflects concatenating the E and nu latent
# codes; confirm in the material-model module.
energy_input_dim = (1, 1, args.encoder_latent_dim * 2)
energy_hidden_dim = args.hidden_dim
dissipation_input_dim = (1, 1, args.encoder_latent_dim * 2)
dissipation_hidden_dim = args.hidden_dim

# NOTE(review): dt = step / 5000.0 appears to assume the raw sequences have 5000 samples
# per unit time before the `step` striding applied when loading; confirm against the
# data generator.
vmm = mm.ViscoelasticMaterialModelM(
    energy_input_dim,
    energy_hidden_dim,
    dissipation_input_dim,
    dissipation_hidden_dim,
    ae_E.encoder,
    ae_nu.encoder,
    dt=step / 5000.0,
).to(device)

# Load the weights from the same run directory the args were read from (`run_id`)
# instead of rebuilding the path from args.run_id. This way the args and weights can
# never come from different runs (e.g. after editing run_id or renaming/copying a folder).
vmm.load_state_dict(torch.load(f"{run_id}/vmm.pth", weights_only=True, map_location=torch.device('cpu')))

# Run the model's prediction over the loaded sequences; the outputs are presumably the
# predicted stress and internal variables (TODO confirm in mm.prediction_step).
spred, xipred = mm.prediction_step(vmm, e, e_dot, E, nu)