In [1]:
import sys
import torch

# Put the parent directory first on the module search path.
# Presumably this is what lets project packages such as `common` be imported
# from the notebook's location — TODO confirm against the repo layout.
sys.path.insert(0, "..")
# Relative base directory; not referenced by any other cell visible in this notebook.
basedir = "../.."

from common.config import create_object, load_config

%matplotlib widget

# Tell TorchDynamo to suppress its own errors instead of raising them.
torch._dynamo.config.suppress_errors = True
# NOTE(review): the return value of this call is discarded. Verify that a bare
# `disable()` call has the intended global effect in the installed torch
# version — it may only build a decorator/context and otherwise do nothing.
torch._dynamo.disable()
# Prefer the Apple MPS backend when available, otherwise fall back to CPU.
# NOTE(review): none of the visible cells pass `device` to a model factory
# (e.g. `create_ltinet` is called without it and uses its own default).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# dconfig = load_config("../autoencoder/configs/data/burgersshift.yaml")
# dconfig.datasize.spacedim = 1
# dset = create_object(dconfig)

In [None]:
import time
import glob
import datetime
import copy
import os
import pickle
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
import seaborn as sns
import matplotlib.cm as cm

import itertools

from itertools import combinations
from sklearn.decomposition import PCA
from sklearn.kernel_ridge import KernelRidge
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from copy import deepcopy

import utils

class LTINetHelper():
  """Factory and diagnostics helper for ``LTINet`` models.

  The instance keeps a private deep copy of an experiment config and builds
  ``LTINet`` objects from it. The static methods compute or plot errors for
  an already constructed ``LTINet``.
  """

  def __init__(self, config):
    self.update_config(config)

  def update_config(self, config):
    """Replace the stored config with a deep copy of ``config``."""
    self.config = deepcopy(config)

  def create_ltinet(self, dataset, config=None, **args):
    """Build an ``LTINet`` for ``dataset``.

    Args:
      dataset: object exposing ``data`` and ``params`` arrays (at most 3-D data).
      config: config to use instead of the stored one; must expose
        ``recclass`` and ``recparams``.
      **args: optional overrides: ``td``, ``seed`` (default 0), ``device``
        (default 0), ``recclass`` (class *name*), ``recparams``.

    Returns:
      A new ``LTINet``.

    Note:
      The reconstruction class is looked up by name in this module's
      ``globals()``, so it must be bound at module/notebook scope first.
    """
    if config is None:
      config = self.config

    assert(len(dataset.data.shape) < 4)
    # Computed unconditionally (as LTINet.__init__ does) so that data which is
    # not exactly 3-D no longer ends in a NameError on din/dout.
    din = dataset.params.shape[-1]
    dout = dataset.data.shape[-1]

    td = args.get("td", None)
    seed = args.get("seed", 0)
    device = args.get("device", 0)

    recclass = globals()[args.get("recclass", config.recclass)]
    recparams = copy.deepcopy(dict(args.get("recparams", config.recparams)))

    # Input = parameters plus one time coordinate; output = last data axis.
    recparams["seq"][0] = din + 1
    recparams["seq"][-1] = dout

    return LTINet(dataset, recclass, recparams, td=td, seed=seed, device=device)

  @staticmethod
  def get_operrs(ltinet, times=None, testonly=False):
    """Return non-aggregated relative errors from ``ltinet.get_errors``.

    Args:
      ltinet: model exposing ``dataset``, ``numtrain`` and ``get_errors``.
      times: forwarded to ``get_errors``.
      testonly: if True, use only the samples from index ``numtrain`` on.

    Returns:
      Whatever ``get_errors(..., aggregate=False)`` returns.
    """
    if testonly:
      data = ltinet.dataset.data[ltinet.numtrain:]
      params = ltinet.dataset.params[ltinet.numtrain:]
    else:
      data = ltinet.dataset.data
      params = ltinet.dataset.params

    # get_errors requires the parameters that generate the predictions; the
    # previous call omitted them and raised TypeError.
    return ltinet.get_errors(data, params, times=times, aggregate=False)

  @staticmethod
  def plot_op_predicts(ltinet, testonly=False, xs=None, cmap="viridis"):
    """Interactively compare predictions with reference data.

    Prints the mean relative L2 error over all time steps and, for 3-D data,
    shows a widget with sliders for sample ``i`` and time index ``s``.

    Args:
      ltinet: trained model.
      testonly: if True, use only the samples from index ``numtrain`` on.
      xs: x coordinates for plotting; defaults to ``linspace(0, 1, n_points)``.
      cmap: accepted for interface compatibility; currently unused.
    """
    if testonly:
      data = ltinet.dataset.data[ltinet.numtrain:,]
      params = ltinet.dataset.params[ltinet.numtrain:,]
    else:
      data = ltinet.dataset.data
      params = ltinet.dataset.params

    # `xs == None` is elementwise (and ambiguous) for array input.
    if xs is None:
      xs = np.linspace(0, 1, len(data[0, 0]))

    params = torch.tensor(np.float32(params)).to(ltinet.device)

    # Plain NumPy keeps the arithmetic below from mixing tensors and arrays.
    predicts = ltinet.propagate(params).cpu().detach().numpy()

    # propagate() evaluates every time step starting at t0, so prediction
    # index s lines up with data index s (the old s-1 offset compared
    # neighbouring time steps and wrapped around at s=0).
    errors = []
    n = predicts.shape[0]
    for s in range(data.shape[1]):
      currpredict = predicts[:, s].reshape((n, -1))
      currreference = data[:, s].reshape((n, -1))
      errors.append(np.mean(np.linalg.norm(currpredict - currreference, axis=1) / np.linalg.norm(currreference, axis=1)))

    print(f"Average Relative L2 Error over all times: {np.mean(errors):.4f}")

    if len(data.shape) == 3:
      fig, ax = plt.subplots(figsize=(4, 3))

    @widgets.interact(i=(0, n-1), s=(1, ltinet.T-1))
    def plot_interact(i=0, s=1):
      print(f"Avg Relative L2 Error for t0 to t{s}: {errors[s]:.4f}")

      if len(data.shape) == 3:
        ax.clear()
        ax.set_title(f"RelL2 {np.linalg.norm(predicts[i, s] - data[i, s]) / np.linalg.norm(data[i, s])}")
        ax.plot(xs, data[i, 0], label="Input", linewidth=1)
        ax.plot(xs, predicts[i, s], label="Predicted", linewidth=1)
        ax.plot(xs, data[i, s], label="Exact", linewidth=1)
        ax.legend()

  @staticmethod
  def plot_errorparams(ltinet, param=-1):
    """Scatter the final-time relative error against dataset parameters.

    Args:
      ltinet: trained model.
      param: parameter column to plot against. ``-1`` picks the first column
        whose value differs between samples 0 and 1 (column 0 if none does).
        A length-2 list/tuple/array selects two columns and gives a 3-D scatter.
    """
    if param == -1:
      # Auto-detect one varying parameter
      param = 0
      P = ltinet.dataset.params.shape[1]
      for p in range(P):
        if np.abs(ltinet.dataset.params[0, p] - ltinet.dataset.params[1, p]) > 0:
          param = p
          break

    # A single requested time makes get_errors return one error per sample.
    l2error = np.asarray(LTINetHelper.get_operrs(ltinet, times=[ltinet.T - 1]))
    params = ltinet.dataset.params

    print(params.shape, l2error.shape)

    if isinstance(param, (list, tuple, np.ndarray)) and len(param) == 2:
      # 3D scatter plot for 2 varying parameters
      x = params[:, param[0]]
      y = params[:, param[1]]
      z = l2error

      fig = plt.figure()
      ax = fig.add_subplot(111, projection='3d')
      sc = ax.scatter(x, y, z, c=z, cmap='viridis', s=10)

      ax.set_xlabel(f"Param {param[0]}")
      ax.set_ylabel(f"Param {param[1]}")
      ax.set_zlabel("Operator Error")
      fig.colorbar(sc, ax=ax, label="Operator Error")

    else:
      # Fallback to 2D scatter if param is 1D
      fig, ax = plt.subplots()
      ax.scatter(params[:, param], l2error, s=2)
      ax.set_xlabel(f"Parameter {param}")
      ax.set_ylabel("Operator Error")

    fig.tight_layout()

class LTINet():
  """Network that maps (parameters, time) directly to the solution at that time.

  The reconstruction network ``recnet`` receives each parameter vector with a
  time coordinate appended and outputs one data row per requested time. The
  first 80% of the dataset samples are used for training, the rest for testing.
  """

  def __init__(self, dataset, recclass, recparams, td, seed, device):
    """Set up the train/test split, seed the RNGs and build ``recnet``.

    Args:
      dataset: object exposing ``name``, ``data`` (with a ``copy`` method) and
        ``params`` (2-D, one row per sample).
      recclass: class of the reconstruction network, called as
        ``recclass(**recparams)``.
      recparams: constructor kwargs; ``recparams["seq"]`` has its first and
        last entries overwritten with the input/output widths. The caller's
        dict is left untouched.
      td: tensorboard run name, or None to disable tensorboard logging.
      seed: seed for ``torch`` and ``numpy``.
      device: device the network and training tensors are moved to.
    """
    self.dataset = dataset
    self.device = device
    self.td = td
    self.f = self.dataset.params.shape[1]

    if self.td is None:
      self.prefix = f"{self.dataset.name}{str(recclass.__name__)}LTINet"
    else:
      self.prefix = self.td

    torch.manual_seed(seed)
    np.random.seed(seed)
    self.seed = seed

    datacopy = self.dataset.data.copy()
    self.numtrain = int(datacopy.shape[0] * 0.8)

    self.T = self.dataset.data.shape[1]
    self.trainarr = datacopy[:self.numtrain]
    self.testarr = datacopy[self.numtrain:]
    self.trainparams = self.dataset.params[:self.numtrain]
    self.testparams = self.dataset.params[self.numtrain:]
    self.optparams = None

    self.datadim = len(self.dataset.data.shape) - 2

    self.recclass = recclass

    # Work on a private copy: previously the caller's dict was patched in
    # place while self.recparams kept the unpatched values.
    recparams = copy.deepcopy(recparams)
    recparams["seq"][0] = self.f + 1
    recparams["seq"][-1] = self.dataset.data.shape[-1]
    self.recparams = recparams

    self.recnet = recclass(**recparams).float().to(device)

    # Used by load_model to decide whether a saved checkpoint fits this model.
    self.metadata = {
      "recclass": recclass.__name__,
      "recparams": copy.deepcopy(recparams),
      "dataset_name": dataset.name,
      "data_shape": list(dataset.data.shape),
      "data_checksum": float(np.sum(dataset.data)),
      "seed": seed,
      "epochs": []
    }

  def propagate(self, code, start=0, end=-1):
    """Evaluate the network on a slice of the uniform time grid on [0, 1].

    Args:
      code: parameter tensor whose last axis holds the parameters.
      start: first time index (inclusive).
      end: last time index (inclusive) when positive; otherwise the slice
        runs to the end of the grid.

    Returns:
      The output of ``forward`` for the selected times.
    """
    fullts = torch.linspace(0, 1, self.T).float().to(self.device)

    if end > 0:
      ts = fullts[start:end+1]
    else:
      ts = fullts[start:]

    out = self.forward(code, ts)
    return out

  def get_errors(self, testarr, testparams, ords=(2,), times=None, aggregate=True):
    """Relative errors of the prediction for ``testparams`` against ``testarr``.

    Args:
      testarr: reference data (ndarray or tensor), samples on axis 0 and time
        on axis 1.
      testparams: parameters (ndarray or tensor) used to generate predictions.
      ords: norm orders; exactly one is required when ``aggregate`` is False.
      times: time indices, used only when ``aggregate`` is False.
      aggregate: if True, flatten each sample over time and space and return
        a tuple with one mean relative error per order.

    Returns:
      * ``aggregate=True``: tuple of mean relative errors, one per order.
      * ``aggregate=False`` and one time: array of per-sample relative errors.
      * ``aggregate=False`` otherwise: list of mean relative errors, one per
        requested time (every time step when ``times`` is None).
    """
    assert(aggregate or len(ords) == 1)

    if isinstance(testarr, np.ndarray):
      testarr = torch.tensor(testarr, dtype=torch.float32)

    if isinstance(testparams, np.ndarray):
      testparams = torch.tensor(testparams, dtype=torch.float32)

    explicit_times = times is not None
    if times is None:
      times = range(self.T-1)

    # The inputs must live on the same device as the network; evaluation needs
    # no autograd graph.
    with torch.no_grad():
      out = self.propagate(testparams.to(self.device))

    n = testarr.shape[0]
    orig = testarr.cpu().detach().numpy()
    out = out.cpu().detach().numpy()

    if aggregate:
      orig = orig.reshape([n, -1])
      out = out.reshape([n, -1])
      testerrs = []
      for o in ords:
        testerrs.append(np.mean(np.linalg.norm(orig - out, axis=1, ord=o) / np.linalg.norm(orig, axis=1, ord=o)))

      return tuple(testerrs)

    o = ords[0]

    def sample_errors(t):
      # Per-sample relative error at time index t.
      origslice = orig[:, t].reshape([n, -1])
      outslice = out[:, t].reshape([n, -1])
      return np.linalg.norm(origslice - outslice, axis=1, ord=o) / np.linalg.norm(origslice, axis=1, ord=o)

    if len(times) == 1:
      return sample_errors(times[0])

    # An explicit list of times used to be ignored; honour it now. Without one,
    # keep reporting every time step as before.
    steps = times if explicit_times else range(orig.shape[1])
    return [np.mean(sample_errors(t)) for t in steps]

  def forward(self, z, ts):
    """Run ``recnet`` on every (parameter vector, time) pair.

    Args:
      z: tensor of shape ``(..., N)``.
      ts: 1-D tensor of ``T`` times.

    Returns:
      ``recnet`` applied to a tensor of shape ``(..., T, N + 1)`` whose last
      axis is the parameter vector followed by the time.
    """
    z_shape = z.shape
    *leading_dims, N = z_shape
    T = ts.shape[0]

    z_expanded = z.unsqueeze(-2).expand(*leading_dims, T, N)

    t_shape = [1] * len(leading_dims) + [T, 1]
    t_expanded = ts.view(*t_shape).expand(*leading_dims, T, 1)

    result = torch.cat([z_expanded, t_expanded], dim=-1)

    decoded = self.recnet(result)
    return decoded

  def load_model(self, filename_prefix, verbose=False, min_epochs=0):
    """Load the first saved checkpoint whose metadata matches this model.

    Args:
      filename_prefix: prefix of files under ``savedmodels/ltinet/``.
      verbose: print why each candidate was skipped.
      min_epochs: minimum total number of trained epochs required.

    Returns:
      True if a checkpoint was loaded, False otherwise.

    Raises:
      ValueError: if the instance has no ``metadata`` attribute.
    """
    search_path = f"savedmodels/ltinet/{filename_prefix}*.pickle"
    matching_files = glob.glob(search_path)

    print("Searching for model files matching prefix:", filename_prefix)
    if not hasattr(self, "metadata"):
      raise ValueError("Missing self.metadata. Cannot match models without metadata. Ensure model has been initialized with same config.")

    # "epochs" records training history, not model identity. Comparing it made
    # a fresh model (epochs == []) unable to match any trained checkpoint.
    match_keys = [k for k in self.metadata.keys() if k != "epochs"]

    for addr in matching_files:
      try:
        with open(addr, "rb") as handle:
          # SECURITY: pickle.load can execute arbitrary code; only load
          # checkpoints from trusted sources.
          dic = pickle.load(handle)
      except Exception as e:
        if verbose:
          print(f"Skipping {addr} due to read error: {e}")
        continue

      meta = dic.get("metadata", {})
      is_match = all(meta.get(k) == self.metadata.get(k) for k in match_keys)

      # Check if model meets the minimum epoch requirement
      model_epochs = meta.get("epochs")
      if model_epochs is None:
        if verbose:
          print(f"Skipping {addr} due to missing epoch metadata.")
        continue
      elif isinstance(model_epochs, list):  # handle legacy or list format
        if sum(model_epochs) < min_epochs:
          if verbose:
            print(f"Skipping {addr} due to insufficient epochs ({sum(model_epochs)} < {min_epochs})")
          continue
      elif model_epochs < min_epochs:
        if verbose:
          print(f"Skipping {addr} due to insufficient epochs ({model_epochs} < {min_epochs})")
        continue

      if is_match:
        print("Model match found. Loading from:", addr)
        self.recnet.load_state_dict(dic["recnet"])
        # Always keep a list so train_model can append to it, even when the
        # checkpoint stored a plain number.
        self.metadata["epochs"] = list(model_epochs) if isinstance(model_epochs, list) else [model_epochs]
        if "opt" in dic:
          self.optparams = dic["opt"]

        return True
      elif verbose:
        print("Metadata mismatch in file:", addr)
        for k in match_keys:
          print(f"{k}: saved={meta.get(k)} vs current={self.metadata.get(k)}")

    print("Load failed. No matching models found.")
    print("Searched:", matching_files)
    return False

  def train_model(self, epochs, save=True, optim=torch.optim.AdamW, lr=1e-4, printinterval=10, batch=32, ridge=0, loss=None, best=True, verbose=False):
    """Train ``recnet`` on the training split.

    Args:
      epochs: number of passes over the training data.
      save: pickle weights, metadata and optimizer state when done.
      optim: optimizer class.
      lr: learning rate.
      printinterval: evaluate and print test errors every this many epochs
        (0 or less disables it).
      batch: batch size.
      ridge: weight decay passed to the optimizer.
      loss: loss *class* (instantiated without arguments); MSE if None.
      best: restore the lowest mean-loss weights seen in the second half of
        training; if no epoch qualified, the final weights are kept.
      verbose: print a message whenever an epoch does not improve on the best.

    Returns:
      Dict with ``losses`` (one per batch) and ``testerrors1`` /
      ``testerrors2`` / ``testerrorsinf`` (one per evaluation).
    """
    def train_epoch(dataloader, writer=None, optimizer=None, scheduler=None, ep=0, printinterval=10, loss=None, testarr=None, testparams=None):
      # One pass over the data; returns batch losses and any test errors evaluated.
      losses = []
      testerrors1 = []
      testerrors2 = []
      testerrorsinf = []

      def closure(values, params):
        optimizer.zero_grad()

        out = self.propagate(params)
        target = values

        res = loss(out, target)
        res.backward()

        if writer is not None and self.trainstep % 5 == 0:
          writer.add_scalar("main/loss", res, global_step=self.trainstep)

        return res

      for values, params in dataloader:
        self.trainstep += 1
        error = optimizer.step(lambda: closure(values, params))
        losses.append(float(error.cpu().detach()))

      if scheduler is not None:
        scheduler.step(np.mean(losses))

      # print test
      if printinterval > 0 and (ep % printinterval == 0):
        testerr1, testerr2, testerrinf = self.get_errors(testarr, testparams, ords=(1, 2, np.inf))
        # Record the values; before, the returned lists were always empty.
        testerrors1.append(float(testerr1))
        testerrors2.append(float(testerr2))
        testerrorsinf.append(float(testerrinf))

        if scheduler is not None:
          print(f"{ep+1}: Train Loss {error:.3e}, LR {scheduler.get_last_lr()[-1]:.3e}, Relative LTINet Error (1, 2, inf): {testerr1:3f}, {testerr2:3f}, {testerrinf:3f}")
        else:
          print(f"{ep+1}: Train Loss {error:.3e}, Relative LTINet Error (1, 2, inf): {testerr1:3f}, {testerr2:3f}, {testerrinf:3f}")

        if writer is not None:
          writer.add_scalar("misc/relativeL1error", testerr1, global_step=ep)
          writer.add_scalar("main/relativeL2error", testerr2, global_step=ep)
          writer.add_scalar("misc/relativeLInferror", testerrinf, global_step=ep)

      return losses, testerrors1, testerrors2, testerrorsinf

    loss = nn.MSELoss() if loss is None else loss()

    losses, testerrors1, testerrors2, testerrorsinf = [], [], [], []
    self.trainstep = 0

    train = torch.tensor(self.trainarr, dtype=torch.float32).to(self.device)
    params = torch.tensor(self.trainparams, dtype=torch.float32).to(self.device)
    test = self.testarr

    opt = optim(self.recnet.parameters(), lr=lr, weight_decay=ridge)
    scheduler = lr_scheduler.ReduceLROnPlateau(opt, patience=30)
    dataloader = DataLoader(torch.utils.data.TensorDataset(train, params), shuffle=False, batch_size=batch)

    if self.optparams is not None:
      opt.load_state_dict(self.optparams)

    writer = None
    if self.td is not None:
      name = f"./tensorboard/{datetime.datetime.now().strftime('%d-%B-%Y')}/{self.td}/{datetime.datetime.now().strftime('%H.%M.%S')}/"
      writer = torch.utils.tensorboard.SummaryWriter(name)
      print("Tensorboard writer location is " + name)

    print("Number of NN trainable parameters", utils.num_params(self.recnet))
    print(f"Starting training LTINet model at {time.asctime()}...")
    print("train", train.shape, "test", test.shape)

    bestdict = { "loss": float(np.inf), "ep": 0 }
    for ep in range(epochs):
      lossesN, testerrors1N, testerrors2N, testerrorsinfN = train_epoch(dataloader, optimizer=opt, scheduler=scheduler, writer=writer, ep=ep, printinterval=printinterval, loss=loss, testarr=test, testparams=self.testparams)
      losses += lossesN; testerrors1 += testerrors1N; testerrors2 += testerrors2N; testerrorsinf += testerrorsinfN

      if best and ep > epochs // 2:
        avgloss = np.mean(lossesN)
        if avgloss < bestdict["loss"]:
          # state_dict() returns references to the live tensors, so copy them;
          # otherwise the "best" snapshot silently follows later updates.
          bestdict["recnet"] = copy.deepcopy(self.recnet.state_dict())
          bestdict["opt"] = copy.deepcopy(opt.state_dict())
          bestdict["loss"] = avgloss
          bestdict["ep"] = ep
        elif verbose:
          print(f"Loss not improved at epoch {ep} (Ratio: {avgloss/bestdict['loss']:.2f}) from {bestdict['ep']} (Loss: {bestdict['loss']:.2e})")

    print(f"Finished training LTINet model at {time.asctime()}...")

    # For very short runs no epoch satisfies ep > epochs // 2, so there may be
    # no snapshot to restore.
    if best and "recnet" in bestdict:
      self.recnet.load_state_dict(bestdict["recnet"])
      opt.load_state_dict(bestdict["opt"])

    if writer is not None:
      writer.close()

    self.optparams = opt.state_dict()
    self.metadata["epochs"].append(epochs)

    if save:
      now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

      # Compute total training epochs
      total_epochs = sum(self.metadata["epochs"]) if isinstance(self.metadata["epochs"], list) else self.metadata["epochs"]

      filename = (
        f"{self.dataset.name}_"
        f"{self.recclass.__name__}_"
        f"{self.recparams['seq']}_"
        f"{self.seed}_"
        f"{total_epochs}ep_"
        f"{now}.pickle"
      )

      dire = "savedmodels/ltinet"
      addr = os.path.join(dire, filename)

      os.makedirs(dire, exist_ok=True)

      with open(addr, "wb") as handle:
        pickle.dump({
          "recnet": self.recnet.state_dict(),
          "metadata": self.metadata,
          "opt": self.optparams
        }, handle, protocol=pickle.HIGHEST_PROTOCOL)

      print("Model saved at", addr)

    return { "losses": losses, "testerrors1": testerrors1, "testerrors2": testerrors2, "testerrorsinf": testerrorsinf }



In [3]:
# Load the Burgers-shift data config, override the spatial dimension to 1 and
# build the dataset object from it.
dconfig = load_config("../autoencoder/configs/data/burgersshift.yaml")
dconfig.datasize.spacedim = 1
dset = create_object(dconfig)

# Time downsampling is currently switched off.
#dset.downsample_time(10)
# NOTE(review): presumably an in-place downsampling of `dset` by a factor of 4
# (the return value is not used) — confirm which axis is affected.
dset.downsample(4)

In [7]:
# Baseline experiment: build a TimeInput model from the "donnormal" config and
# train it on `dset`.
import models
lconfig = load_config("../autoencoder/configs/experiments/donnormal.yaml")
experiment = models.TimeInputHelper(lconfig)

# Bind FFNet at notebook scope.
# NOTE(review): presumably needed because a helper resolves network classes by
# name from globals — confirm for TimeInputHelper.
FFNet = models.FFNet
# NOTE(review): `test` is the model object, not test data; the name is reused
# by later cells.
test = experiment.create_timeinput(dset)

# 400 epochs at learning rate 1e-3; the call's return value is echoed as the cell output.
test.train_model(400, lr=1e-3)

Number of NN trainable parameters 1581700
Starting training TI model DeepONet at Fri Jul  4 22:40:59 2025...
train torch.Size([400, 301, 128]) test (100, 301, 128)
1: Train Loss 1.228e-01, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.871023, 0.885350, 0.974652
11: Train Loss 2.543e-02, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.348901, 0.429166, 0.802944
21: Train Loss 1.710e-02, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.257463, 0.338546, 0.743950
31: Train Loss 1.596e-02, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.252372, 0.333476, 0.744745
41: Train Loss 1.260e-02, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.215464, 0.305711, 0.755190
51: Train Loss 1.176e-02, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.199164, 0.276510, 0.707517
61: Train Loss 9.608e-03, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.195823, 0.268300, 0.687340
71: Train Loss 9.289e-03, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.199168, 0.268056, 0.718723
81: Train Loss 9.818e-03, LR 1.000e-0

{'losses': [0.15397650003433228,
  0.16766208410263062,
  0.15065503120422363,
  0.14493978023529053,
  0.13401150703430176,
  0.13720950484275818,
  0.15102946758270264,
  0.11988787353038788,
  0.13123498857021332,
  0.13423579931259155,
  0.13364563882350922,
  0.12886355817317963,
  0.12280315905809402,
  0.1250154972076416,
  0.12196734547615051,
  0.11190628260374069,
  0.11065759509801865,
  0.10506986081600189,
  0.09584058076143265,
  0.0888606384396553,
  0.08075150102376938,
  0.08250419795513153,
  0.08053058385848999,
  0.0747310221195221,
  0.07164130359888077,
  0.0783054530620575,
  0.07394121587276459,
  0.07109012454748154,
  0.06953132152557373,
  0.07250446826219559,
  0.07464279979467392,
  0.06476087123155594,
  0.06525207310914993,
  0.061666227877140045,
  0.06442311406135559,
  0.06318215280771255,
  0.06095825880765915,
  0.05885211378335953,
  0.06190148741006851,
  0.058058105409145355,
  0.05872269347310066,
  0.056085728108882904,
  0.05460492521524429,
  

In [11]:
# Baseline experiment: build a WeldNet model with a single window from the
# "weldnormal" config and train its autoencoder and propagator together.
import models
lconfig = load_config("../autoencoder/configs/experiments/weldnormal.yaml")
experimentg = models.WeldHelper(lconfig)

# Bind FFNet at notebook scope.
# NOTE(review): presumably needed for a by-name class lookup — confirm for WeldHelper.
FFNet = models.FFNet
# Overwrites the `test` model object created by the earlier experiment cell.
test = experimentg.create_weld(dset, windows=1)

# 400 epochs at learning rate 1e-3 (the recorded run was interrupted manually).
test.train_aes_plus_props(400, lr=1e-3)

Training 1 WeldNet AEs and props together
Number of NN trainable parameters 1135132
Starting training WeldNet AE + Prop 1/1 (0->300) at Fri Jul  4 22:50:38 2025...
train torch.Size([400, 301, 128]) test (100, 301, 128)
1: Train Loss 5.020e-02 + 1.825e-04, LR 1.000e-03, Relative AE Error (1, 2, inf): 0.486830, 0.554858, 0.906445
11: Train Loss 9.709e-03 + 9.007e-06, LR 1.000e-03, Relative AE Error (1, 2, inf): 0.181283, 0.255140, 0.626496
21: Train Loss 4.984e-03 + 9.080e-06, LR 1.000e-03, Relative AE Error (1, 2, inf): 0.141325, 0.193488, 0.611253
31: Train Loss 2.565e-03 + 7.542e-06, LR 1.000e-03, Relative AE Error (1, 2, inf): 0.088458, 0.132817, 0.466957
41: Train Loss 1.489e-03 + 5.852e-06, LR 1.000e-03, Relative AE Error (1, 2, inf): 0.068213, 0.103190, 0.398821
51: Train Loss 1.795e-03 + 1.547e-05, LR 1.000e-03, Relative AE Error (1, 2, inf): 0.095790, 0.132489, 0.498541
61: Train Loss 1.360e-03 + 9.637e-06, LR 1.000e-03, Relative AE Error (1, 2, inf): 0.074751, 0.102068, 0.39836

KeyboardInterrupt: 

In [27]:
# LTINet experiment: build the model defined in this notebook from the
# "ltinet" config and train it on `dset`.
import models
lconfig = load_config("../autoencoder/configs/experiments/ltinet.yaml")
experiment = LTINetHelper(lconfig)

# LTINetHelper.create_ltinet looks the reconstruction class up by name in
# globals(), so the class must be bound at notebook scope first
# (assumes the config names "FFNet" — TODO confirm).
FFNet = models.FFNet
# No `device` is passed, so create_ltinet uses its own default (0) rather than
# the `device` chosen in the first cell.
test = experiment.create_ltinet(dset)

# 400 epochs at learning rate 1e-3.
# NOTE(review): the recorded traceback mentions `dynclass`, which does not occur
# in the LTINet class defined in this notebook — the stored output is
# presumably from an earlier revision; re-run to refresh it.
test.train_model(400, lr=1e-3)

Number of NN trainable parameters 1386528
Starting training LTINet model at Fri Jul  4 22:29:04 2025...
train torch.Size([400, 301, 128]) test (100, 301, 128)
1: Train Loss 9.602e-02, LR 1.000e-03, Relative LTINet Error (1, 2, inf): 0.708935, 0.792212, 0.964838
11: Train Loss 2.010e-02, LR 1.000e-03, Relative LTINet Error (1, 2, inf): 0.267529, 0.354214, 0.729388
21: Train Loss 1.268e-02, LR 1.000e-03, Relative LTINet Error (1, 2, inf): 0.205877, 0.287795, 0.694042
31: Train Loss 8.363e-03, LR 1.000e-03, Relative LTINet Error (1, 2, inf): 0.158464, 0.236354, 0.650605
41: Train Loss 6.621e-03, LR 1.000e-03, Relative LTINet Error (1, 2, inf): 0.143211, 0.212019, 0.626961
51: Train Loss 5.305e-03, LR 1.000e-03, Relative LTINet Error (1, 2, inf): 0.129349, 0.194317, 0.606483
61: Train Loss 4.532e-03, LR 1.000e-03, Relative LTINet Error (1, 2, inf): 0.119608, 0.179407, 0.584500
71: Train Loss 3.918e-03, LR 1.000e-03, Relative LTINet Error (1, 2, inf): 0.111702, 0.167281, 0.561656
81: Train 

AttributeError: 'LTINet' object has no attribute 'dynclass'