In [2]:
import sys
import torch

sys.path.insert(0, "..")
basedir = "../.."

from common.config import create_object, load_config

%matplotlib widget

torch._dynamo.config.suppress_errors = True
torch._dynamo.disable()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# dconfig = load_config("../autoencoder/configs/data/burgersshift.yaml")
# dconfig.datasize.spacedim = 1
# dset = create_object(dconfig)

In [None]:
import time
import glob
import datetime
import copy
import os
import pickle
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
import seaborn as sns
import matplotlib.cm as cm

import itertools

from itertools import combinations
from sklearn.decomposition import PCA
from sklearn.kernel_ridge import KernelRidge
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from copy import deepcopy

import utils

class ETINetHelper():
  def __init__(self, config):
    self.update_config(config)

  def update_config(self, config):
    self.config = deepcopy(config)

  def create_etinet(self, dataset, k, config=None, **args):
    if config is None:
      config = self.config

    assert(len(dataset.data.shape) < 4)
    if len(dataset.data.shape) == 3:
      din = dataset.params.shape[-1]
      dout = dataset.data.shape[-1]

    td = args.get("td", None)
    seed = args.get("seed", 0)
    device = args.get("device", 0)

    recclass = globals()[args.get("recclass", config.recclass)]
    recparams = copy.deepcopy(dict(args.get("recparams", config.recparams)))

    aeclass = globals()[args.get("aeclass", config.aeclass)]
    aeparams = copy.deepcopy(dict(args.get("aeparams", config.aeparams)))

    recparams["seq"][0] = k + 1
    recparams["seq"][-1] = dout

    return ETINet(dataset, k, aeclass, aeparams, recclass, recparams, td=td, seed=seed, device=device)

  @staticmethod
  def get_operrs(etinet, times=None, testonly=False):
    if testonly:
      data = etinet.dataset.data[etinet.numtrain:,]
    else:
      data = etinet.dataset.data

    errors = etinet.get_errors(data, times=times, aggregate=False)

    return errors
  
  @staticmethod
  def plot_op_predicts(etinet, testonly=False, xs=None, cmap="viridis"):
    if testonly:
      data = etinet.dataset.data[etinet.numtrain:,]
      params = etinet.dataset.params[etinet.numtrain:,]
    else:
      data = etinet.dataset.data
      params = etinet.dataset.params

    if xs == None:
      xs = np.linspace(0, 1, len(data[0, 0]))

    params = torch.tensor(np.float32(params)).to(etinet.device)

    predicts = etinet.propagate(params).cpu().detach()

    errors = []
    n = predicts.shape[0]
    for s in range(data.shape[1]):
      currpredict = predicts[:, s-1].reshape((n, -1))
      currreference = data[:, s].reshape((n, -1))
      errors.append(np.mean(np.linalg.norm(currpredict - currreference, axis=1) / np.linalg.norm(currreference, axis=1)))
        
    print(f"Average Relative L2 Error over all times: {np.mean(errors):.4f}")

    if len(data.shape) == 3:
      fig, ax = plt.subplots(figsize=(4, 3))

    @widgets.interact(i=(0, n-1), s=(1, etinet.T-1))
    def plot_interact(i=0, s=1):
      print(f"Avg Relative L2 Error for t0 to t{s}: {errors[s-1]:.4f}")

      if len(data.shape) == 3:
        ax.clear()
        ax.set_title(f"RelL2 {np.linalg.norm(predicts[i, s-1] - data[i, s]) / np.linalg.norm(data[i, s])}")
        ax.plot(xs, data[i, 0], label="Input", linewidth=1)
        ax.plot(xs, predicts[i, s-1], label="Predicted", linewidth=1)
        ax.plot(xs, data[i, s], label="Exact", linewidth=1)
        ax.legend()
        
  @staticmethod
  def plot_errorparams(etinet, param=-1):
    if param == -1:
        # Auto-detect one varying parameter
        param = 0
        P = etinet.dataset.params.shape[1]
        for p in range(P):
            if np.abs(etinet.dataset.params[0, p] - etinet.dataset.params[1, p]) > 0:
                param = p
                break

    l2error = np.asarray(ETINetHelper.get_operrs(etinet, times=[etinet.T - 1]))
    params = etinet.dataset.params

    print(params.shape, l2error.shape)

    if isinstance(param, (list, tuple, np.ndarray)) and len(param) == 2:
        # 3D scatter plot for 2 varying parameters
        x = params[:, param[0]]
        y = params[:, param[1]]
        z = l2error

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        sc = ax.scatter(x, y, z, c=z, cmap='viridis', s=10)

        ax.set_xlabel(f"Param {param[0]}")
        ax.set_ylabel(f"Param {param[1]}")
        ax.set_zlabel("Operator Error")
        fig.colorbar(sc, ax=ax, label="Operator Error")

    else:
        # Fallback to 2D scatter if param is 1D
        fig, ax = plt.subplots()
        ax.scatter(params[:, param], l2error, s=2)
        ax.set_xlabel(f"Parameter {param}")
        ax.set_ylabel("Operator Error")

    fig.tight_layout()

class ETINet():
  def __init__(self, dataset, k, aeclass, aeparams, recclass, recparams, td, seed, device):
    self.dataset = dataset
    self.device = device
    self.td = td
    self.k = k
  
    if self.td is None:
      self.prefix = f"{self.dataset.name}{str(recclass.__name__)}ETINet"
    else:
      self.prefix = self.td

    torch.manual_seed(seed)
    np.random.seed(seed)
    self.seed = seed

    datacopy = self.dataset.data.copy()
    self.numtrain = int(datacopy.shape[0] * 0.8)
    
    self.T = self.dataset.data.shape[1]
    self.trainarr = datacopy[:self.numtrain]
    self.testarr = datacopy[self.numtrain:]
    self.optparams = None

    self.datadim = len(self.dataset.data.shape) - 2
    self.aestep = 0
    self.recstep = 0

    aeparams["encodeSeq"][0] = self.dataset.data.shape[-1]
    aeparams["encodeSeq"][-1] = self.k
    aeparams["decodeSeq"][0] = self.k
    aeparams["decodeSeq"][-1] = self.dataset.data.shape[-1]

    recparams["seq"][0] = self.k + 1
    recparams["seq"][-1] = self.dataset.data.shape[-1]

    self.aeclass = aeclass
    self.aeparams = copy.deepcopy(aeparams)
    self.recclass = recclass
    self.recparams = copy.deepcopy(recparams)

    self.aenet = aeclass(**aeparams).float().to(device)
    self.recnet = recclass(**recparams).float().to(device)

    self.metadata = {
      "aeclass": aeclass.__name__,
      "aeparams": aeparams,
      "recclass": recclass.__name__,
      "recparams": recparams,
      "dataset_name": dataset.name,
      "data_shape": list(dataset.data.shape),
      "data_checksum": float(np.sum(dataset.data)),
      "seed": seed,
      "epochs": []
    }

  def reconstruct(self, z, ts):
    z_shape = z.shape
    *leading_dims, N = z_shape
    T = ts.shape[0]

    z_expanded = z.unsqueeze(-2).expand(*leading_dims, T, N)

    t_shape = [1] * len(leading_dims) + [T, 1]
    t_expanded = ts.view(*t_shape).expand(*leading_dims, T, 1)

    result = torch.cat([z_expanded, t_expanded], dim=-1)

    recon = self.recnet(result)
    return recon

  def propagate(self, code, start=1, end=-1):
    fullts = torch.linspace(0, 1, self.T).float().to(self.device)
    
    if end > 0:
      ts = fullts[start:end+1]
    else:
      ts = fullts[start:]

    out = self.reconstruct(code, ts)
    return out

  def get_errors(self, testarr, testrest, ords=(2,), times=None, aggregate=True):
    assert(aggregate or len(ords) == 1)
    
    if isinstance(testarr, np.ndarray):
      testarr = torch.tensor(testarr, dtype=torch.float32)

    if isinstance(testrest, np.ndarray):
      testrest = torch.tensor(testrest, dtype=torch.float32)

    if times is None:
      times = range(self.T-1)
  
    out = self.propagate(testarr)

    n = testarr.shape[0]
    orig = testrest.cpu().detach().numpy()
    out = out.cpu().detach().numpy()

    if aggregate:
      orig = orig.reshape([n, -1])
      out = out.reshape([n, -1])
      testerrs = []
      for o in ords:
        testerrs.append(np.mean(np.linalg.norm(orig - out, axis=1, ord=o) / np.linalg.norm(orig, axis=1, ord=o)))

      return tuple(testerrs)
    
    else:
      o = ords[0]
      testerrs = []

      if len(times) == 1:
        t = times[0]
        origslice = orig[:, t].reshape([n, -1])
        outslice = out[:, t].reshape([n, -1])
        return np.linalg.norm(origslice - outslice, axis=1, ord=o) / np.linalg.norm(origslice, axis=1, ord=o)
      else:
        for t in range(orig.shape[1]):
          origslice = orig[:, t].reshape([n, -1])
          outslice = out[:, t].reshape([n, -1])
          testerrs.append(np.mean(np.linalg.norm(origslice - outslice, axis=1, ord=o) / np.linalg.norm(origslice, axis=1, ord=o)))

        return testerrs
      
  def get_ae_errors(self, testarr, ords=(2,)):
    if isinstance(testarr, np.ndarray):
      testarr = torch.tensor(testarr, dtype=torch.float32)
  
    out = self.aenet(testarr).cpu().detach().numpy()
    orig = testarr.cpu().detach().numpy()

    testerrs = []
    for o in ords:
      testerrs.append(np.mean(np.linalg.norm(orig - out, axis=1, ord=o) / np.linalg.norm(orig, axis=1, ord=o)))

    return tuple(testerrs)

  def load_models(self, filename_prefix, verbose=False, min_epochs=0):
    search_path = f"savedmodels/etinet/{filename_prefix}*.pickle"
    matching_files = glob.glob(search_path)

    print("Searching for model files matching prefix:", filename_prefix)
    if not hasattr(self, "metadata"):
        raise ValueError("Missing self.metadata. Cannot match models without metadata. Ensure model has been initialized with same config.")

    for addr in matching_files:
      try:
          with open(addr, "rb") as handle:
              dic = pickle.load(handle)
      except Exception as e:
          if verbose:
              print(f"Skipping {addr} due to read error: {e}")
          continue

      meta = dic.get("metadata", {})
      is_match = all(
          meta.get(k) == self.metadata.get(k)
          for k in self.metadata.keys()
      )

      # Check if model meets the minimum epoch requirement
      model_epochs = meta.get("epochs")
      if model_epochs is None:
          if verbose:
              print(f"Skipping {addr} due to missing epoch metadata.")
          continue
      elif isinstance(model_epochs, list):  # handle legacy or list format
          if sum(model_epochs) < min_epochs:
              if verbose:
                  print(f"Skipping {addr} due to insufficient epochs ({sum(model_epochs)} < {min_epochs})")
              continue
      elif model_epochs < min_epochs:
          if verbose:
              print(f"Skipping {addr} due to insufficient epochs ({model_epochs} < {min_epochs})")
          continue

      if is_match:
          print("Model match found. Loading from:", addr)
          self.recnet.load_state_dict(dic["recnet"])
          self.aenet.load_state_dict(dic["aenet"])
          self.metadata["epochs"] = meta.get("epochs")
          if "opt" in dic:     
            self.optparams = dic["opt"]

          return True
      elif verbose:
          print("Metadata mismatch in file:", addr)
          for k in self.metadata:
              print(f"{k}: saved={meta.get(k)} vs current={self.metadata.get(k)}")

    print("Load failed. No matching models found.")
    print("Searched:", matching_files)
    return False

  def train_recnet(self, epochs, save=True, optim=torch.optim.AdamW, lr=1e-4, printinterval=10, batch=32, ridge=0, loss=None, best=True, verbose=False):
    def recnet_epoch(dataloader, writer=None, optimizer=None, scheduler=None, ep=0, printinterval=10, loss=None, testarr=None, testrest=None):
      losses = []
      testerrors1 = []
      testerrors2 = []
      testerrorsinf = []

      def closure(codes, rest):
        optimizer.zero_grad()

        out = self.propagate(codes)
        target = rest
        
        res = loss(out, target)
        res.backward()
        
        if writer is not None and self.recstep % 5 == 0:
          writer.add_scalar("main/loss", res, global_step=self.recstep)

        return res

      for codes, rest in dataloader:
        self.recstep += 1
        error = optimizer.step(lambda: closure(codes, rest))
        losses.append(float(error.cpu().detach()))

      if scheduler is not None:
        scheduler.step(np.mean(losses))

      # print test
      if printinterval > 0 and (ep % printinterval == 0):
        testerr1, testerr2, testerrinf = self.get_errors(testarr, testrest, ords=(1, 2, np.inf))
        if scheduler is not None:
          print(f"{ep+1}: Train Loss {error:.3e}, LR {scheduler.get_last_lr()[-1]:.3e}, Relative ETINet Error (1, 2, inf): {testerr1:3f}, {testerr2:3f}, {testerrinf:3f}")
        else:
          print(f"{ep+1}: Train Loss {error:.3e}, Relative ETINet Error (1, 2, inf): {testerr1:3f}, {testerr2:3f}, {testerrinf:3f}")

        if writer is not None:
            writer.add_scalar("misc/relativeL1error", testerr1, global_step=ep)
            writer.add_scalar("main/relativeL2error", testerr2, global_step=ep)
            writer.add_scalar("misc/relativeLInferror", testerrinf, global_step=ep)

      return losses, testerrors1, testerrors2, testerrorsinf
  
    assert(self.aestep > 0)

    loss = nn.MSELoss() if loss is None else loss()

    losses, testerrors1, testerrors2, testerrorsinf = [], [], [], []

    initial = torch.tensor(self.trainarr[:, 0], dtype=torch.float32).to(self.device)
    rest = torch.tensor(self.trainarr[:, 1:], dtype=torch.float32).to(self.device)
    train = self.aenet.encode(initial).detach()

    testinitial = torch.tensor(self.testarr[:, 0], dtype=torch.float32).to(self.device)
    testrest = torch.tensor(self.testarr[:, 1:], dtype=torch.float32).to(self.device)
    test = self.aenet.encode(testinitial).detach()

    opt = optim(self.recnet.parameters(), lr=lr, weight_decay=ridge)
    scheduler = lr_scheduler.ReduceLROnPlateau(opt, patience=30)
    dataloader = DataLoader(torch.utils.data.TensorDataset(train, rest), shuffle=False, batch_size=batch)

    writer = None
    if self.td is not None:
      name = f"./tensorboard/{datetime.datetime.now().strftime('%d-%B-%Y')}/{self.td}/{datetime.datetime.now().strftime('%H.%M.%S')}/"
      writer = torch.utils.tensorboard.SummaryWriter(name)
      print("Tensorboard writer location is " + name)

    print("Number of NN trainable parameters", utils.num_params(self.recnet))
    print(f"Starting ETINet rec model at {time.asctime()}...")
    print("train", train.shape, "test", test.shape)
      
    bestdict = { "loss": float(np.inf), "ep": 0 }
    for ep in range(epochs):
      lossesN, testerrors1N, testerrors2N, testerrorsinfN = recnet_epoch(dataloader, optimizer=opt, scheduler=scheduler, writer=writer, ep=ep, printinterval=printinterval, loss=loss, testarr=test, testrest=testrest)
      losses += lossesN; testerrors1 += testerrors1N; testerrors2 += testerrors2N; testerrorsinf += testerrorsinfN

      if best and ep > epochs // 2:
        avgloss = np.mean(lossesN)
        if avgloss < bestdict["loss"]:
          bestdict["recnet"] = self.recnet.state_dict()
          bestdict["opt"] = opt.state_dict()
          bestdict["loss"] = avgloss
          bestdict["ep"] = ep
        elif verbose:
          print(f"Loss not improved at epoch {ep} (Ratio: {avgloss/bestdict['loss']:.2f}) from {bestdict['ep']} (Loss: {bestdict['loss']:.2e})")
      
    print(f"Finished training ETINet rec model at {time.asctime()}...")

    if best:
      self.recnet.load_state_dict(bestdict["recnet"])
      opt.load_state_dict(bestdict["opt"])

    self.aeoptparams = opt.state_dict()
    self.metadata["aeepochs"].append(epochs)

    if save:
      now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

      # Compute total training epochs
      total_epochs = sum(self.metadata["epochs"]) if isinstance(self.metadata["epochs"], list) else self.metadata["epochs"]

      filename = (
          f"{self.dataset.name}_"
          f"{self.aeclass.__name__}_"
          f"{self.aeparams['seq']}_"
          f"{self.recclass.__name__}_"
          f"{self.recparams['seq']}_"
          f"{self.seed}_"
          f"{total_epochs}ep_"
          f"{now}.pickle"
      )

      dire = "savedmodels/etinet"
      addr = os.path.join(dire, filename)

      if not os.path.exists(dire):
          os.makedirs(dire)

      with open(addr, "wb") as handle:
          pickle.dump({
              "aenet": self.aenet.state_dict(),
              "recnet": self.recnet.state_dict(),
              "metadata": self.metadata,
              "opt": self.optparams
          }, handle, protocol=pickle.HIGHEST_PROTOCOL)

      print("Model saved at", addr)

    return { "losses": losses, "testerrors1": testerrors1, "testerrors2": testerrors2, "testerrorsinf": testerrorsinf }

  def train_aenet(self, epochs, save=True, optim=torch.optim.AdamW, lr=1e-4, printinterval=10, batch=32, ridge=0, loss=None, best=True, verbose=False):
    def aenet_epoch(dataloader, writer=None, optimizer=None, scheduler=None, ep=0, printinterval=10, loss=None, testarr=None):
      losses = []
      testerrors1 = []
      testerrors2 = []
      testerrorsinf = []

      def closure(codes):
        optimizer.zero_grad()

        out = self.aenet(codes)
        target = codes
        
        res = loss(out, target)
        res.backward()
        
        if writer is not None and self.recstep % 5 == 0:
          writer.add_scalar("main/aeloss", res, global_step=self.recstep)

        return res

      for codes in dataloader:
        self.aestep += 1
        error = optimizer.step(lambda: closure(codes))
        losses.append(float(error.cpu().detach()))

      if scheduler is not None:
        scheduler.step(np.mean(losses))

      # print test
      if printinterval > 0 and (ep % printinterval == 0):
        testerr1, testerr2, testerrinf = self.get_ae_errors(testarr, ords=(1, 2, np.inf))
        if scheduler is not None:
          print(f"{ep+1}: Train Loss {error:.3e}, LR {scheduler.get_last_lr()[-1]:.3e}, Relative ETINet AE Error (1, 2, inf): {testerr1:3f}, {testerr2:3f}, {testerrinf:3f}")
        else:
          print(f"{ep+1}: Train Loss {error:.3e}, Relative ETINet AE Error (1, 2, inf): {testerr1:3f}, {testerr2:3f}, {testerrinf:3f}")

        if writer is not None:
            writer.add_scalar("misc/relativeL1error", testerr1, global_step=ep)
            writer.add_scalar("main/relativeL2error", testerr2, global_step=ep)
            writer.add_scalar("misc/relativeLInferror", testerrinf, global_step=ep)

      return losses, testerrors1, testerrors2, testerrorsinf
  
    loss = nn.MSELoss() if loss is None else loss()

    losses, testerrors1, testerrors2, testerrorsinf = [], [], [], []

    initial = torch.tensor(self.trainarr[:, 0], dtype=torch.float32).to(self.device)
    test = torch.tensor(self.testarr[:, 0], dtype=torch.float32).to(self.device)

    opt = optim(self.aenet.parameters(), lr=lr, weight_decay=ridge)
    scheduler = lr_scheduler.ReduceLROnPlateau(opt, patience=30)
    dataloader = DataLoader(initial, shuffle=False, batch_size=batch)

    if self.optparams is not None:
      opt.load_state_dict(self.optparams)

    writer = None
    if self.td is not None:
      name = f"./tensorboard/{datetime.datetime.now().strftime('%d-%B-%Y')}/{self.td}/{datetime.datetime.now().strftime('%H.%M.%S')}/"
      writer = torch.utils.tensorboard.SummaryWriter(name)
      print("Tensorboard writer location is " + name)

    print("Number of NN trainable parameters", utils.num_params(self.recnet))
    print(f"Starting training ETINet AE model at {time.asctime()}...")
    print("train", initial.shape, "test", test.shape)
      
    bestdict = { "loss": float(np.inf), "ep": 0 }
    for ep in range(epochs):
      lossesN, testerrors1N, testerrors2N, testerrorsinfN = aenet_epoch(dataloader, optimizer=opt, scheduler=scheduler, writer=writer, ep=ep, printinterval=printinterval, loss=loss, testarr=test)
      losses += lossesN; testerrors1 += testerrors1N; testerrors2 += testerrors2N; testerrorsinf += testerrorsinfN

      if best and ep > epochs // 2:
        avgloss = np.mean(lossesN)
        if avgloss < bestdict["loss"]:
          bestdict["aenet"] = self.aenet.state_dict()
          bestdict["opt"] = opt.state_dict()
          bestdict["loss"] = avgloss
          bestdict["ep"] = ep
        elif verbose:
          print(f"Loss not improved at epoch {ep} (Ratio: {avgloss/bestdict['loss']:.2f}) from {bestdict['ep']} (Loss: {bestdict['loss']:.2e})")
      
    print(f"Finished training ETINet AE model at {time.asctime()}...")

    if best:
      self.aenet.load_state_dict(bestdict["aenet"])
      opt.load_state_dict(bestdict["opt"])

    return { "losses": losses, "testerrors1": testerrors1, "testerrors2": testerrors2, "testerrorsinf": testerrorsinf }


In [4]:
dconfig = load_config("../autoencoder/configs/data/burgersshift.yaml")
dconfig.datasize.spacedim = 1
dset = create_object(dconfig)

#dset.downsample_time(10)
dset.downsample(4)

In [None]:
import models
lconfig = load_config("../autoencoder/configs/experiments/etinet.yaml")
experiment = ETINetHelper(lconfig)

FFNet = models.FFNet
FFAutoencoder = models.FFAutoencoder
test = experiment.create_etinet(dset, 2)

test.train_aenet(400, lr=1e-3)
test.train_recnet(400, lr=1e-3)

Number of NN trainable parameters 800528
Starting training ETINet AE model at Sat Jul  5 10:34:57 2025...
train torch.Size([400, 128]) test torch.Size([100, 128])
1: Train Loss 4.922e-02, LR 1.000e-03, Relative ETINet AE Error (1, 2, inf): 0.424543, 0.429133, 0.601456
11: Train Loss 2.471e-04, LR 1.000e-03, Relative ETINet AE Error (1, 2, inf): 0.034005, 0.034449, 0.046659
21: Train Loss 3.528e-03, LR 1.000e-03, Relative ETINet AE Error (1, 2, inf): 0.111951, 0.108367, 0.126833
31: Train Loss 6.569e-04, LR 1.000e-03, Relative ETINet AE Error (1, 2, inf): 0.054653, 0.054799, 0.069752
41: Train Loss 2.728e-04, LR 1.000e-03, Relative ETINet AE Error (1, 2, inf): 0.043904, 0.044011, 0.056270
51: Train Loss 6.194e-05, LR 1.000e-04, Relative ETINet AE Error (1, 2, inf): 0.016327, 0.016507, 0.022629
61: Train Loss 5.602e-05, LR 1.000e-04, Relative ETINet AE Error (1, 2, inf): 0.015526, 0.015650, 0.021682
71: Train Loss 5.224e-05, LR 1.000e-04, Relative ETINet AE Error (1, 2, inf): 0.014990, 0