In [None]:
import torch, pickle
from tqdm import tqdm
import importlib

from types import SimpleNamespace

# Run configuration.
#
# This is a namespace rather than a plain dict because the rest of the script
# reads every setting by attribute (args.device, args.n_samples, ...);
# attribute access on a dict raises AttributeError.
args = SimpleNamespace(
    run_id="1",
    data_path="data/2024-10-13_PC1D_process10_data.pkl",  # pickled dataset
    device="cuda",  # requested device; the script uses "cpu" if CUDA is missing
    n_samples=1000,  # number of leading samples kept from each data array
    epochs=1000,  # number of training steps
    lr=1e-3,  # Adam learning rate
    hidden_dim=10,  # hidden width of the energy and dissipation networks
    step=50,  # stride used to subsample the second axis of the histories
    encoder_hidden_dim=128,
    encoder_latent_dim=10,
    encoder_path="encoder_run_4",  # directory holding ae_E.pth / ae_nu.pth
    material_model="m_dependent_b",  # module name passed to importlib
)
# Alternative (disabled): reuse the arguments stored by an earlier run.
# NOTE(review): the loaded object must also support attribute access.
# args = torch.load("material_model_run_1/args.pkl")

# The material-model implementation is looked up by module name, so the
# variant being trained is chosen purely through the run configuration.
mm = importlib.import_module(args.material_model)
from util import LossFunction
from m_encoder import *

# Use the requested device only when CUDA is usable; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device(args.device)
else:
    device = torch.device("cpu")

# NOTE: unpickling can execute arbitrary code, so data_path must only ever
# point at a file from a trusted source.
with open(args.data_path, "rb") as f:
    data = pickle.load(f)


N = args.n_samples
step = args.step

# Index shared by the strain, strain-rate and stress arrays: the first N rows
# and every step-th entry along the second axis.
_history_index = (slice(None, N), slice(None, None, step))

# Subsampled histories, created as float32 tensors on the training device.
e = torch.tensor(data["strain"][_history_index], dtype=torch.float32, device=device)
e_dot = torch.tensor(
    data["strain_rate"][_history_index], dtype=torch.float32, device=device
)
s = torch.tensor(data["stress"][_history_index], dtype=torch.float32, device=device)

# Material-parameter arrays: only the sample axis is truncated.
E = torch.tensor(data["E"][:N], dtype=torch.float32, device=device)
nu = torch.tensor(data["nu"][:N], dtype=torch.float32, device=device)

loss_function = LossFunction()

# One autoencoder per material-parameter tensor; each input width is taken
# from the second dimension of the tensor it encodes.
ae_E = AutoEncoder(E.shape[1], args.encoder_hidden_dim, args.encoder_latent_dim).to(
    device
)
ae_nu = AutoEncoder(nu.shape[1], args.encoder_hidden_dim, args.encoder_latent_dim).to(
    device
)

# Restore the pretrained weights. map_location remaps the stored tensors onto
# the device chosen for this run; without it, a checkpoint written from CUDA
# cannot be deserialized once the script has fallen back to the CPU.
ae_E.load_state_dict(
    torch.load(
        f"{args.encoder_path}/ae_E.pth", map_location=device, weights_only=True
    )
)
ae_nu.load_state_dict(
    torch.load(
        f"{args.encoder_path}/ae_nu.pth", map_location=device, weights_only=True
    )
)

# The energy and dissipation networks are sized identically: twice the
# encoder latent width plus 2 for the input, and hidden_dim for the hidden
# layer.
# NOTE(review): presumably the two latent codes are those of E and nu, and the
# extra 2 inputs are strain and strain rate -- confirm against the model module.
energy_input_dim = dissipation_input_dim = 2 * args.encoder_latent_dim + 2
energy_hidden_dim = dissipation_hidden_dim = args.hidden_dim

# Only the encoder halves of the pretrained autoencoders are handed to the model.
vmm = mm.ViscoelasticMaterialModelM(
    energy_input_dim,
    energy_hidden_dim,
    dissipation_input_dim,
    dissipation_hidden_dim,
    ae_E.encoder,
    ae_nu.encoder,
)
vmm = vmm.to(device)

# NOTE(review): if the model registers the encoders as submodules, their
# weights are part of vmm.parameters() and get updated too -- confirm that
# fine-tuning them is intended.
optimizer_m = torch.optim.Adam(vmm.parameters(), lr=args.lr)

# One entry per training step, appended by the training loop.
loss_history_m = []

# Full-batch training: every step passes the complete tensors to the
# material-model module's train_step_M and records the value it returns.
epochs = args.epochs
for epoch in tqdm(range(epochs)):
    loss = mm.train_step_M(vmm, optimizer_m, e, e_dot, E, nu, s)
    loss_history_m.append(loss)

    # Report progress only on every 100th completed epoch.
    if (epoch + 1) % 100 != 0:
        continue
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss:.4f}")