In [1]:
# Setup code

import sys
sys.path.insert(0, "..")
basedir = "../.."

import datetime
import skdim
import glob
import time
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
import pickle
import re
import os
import ipywidgets as widgets

from sklearn.decomposition import PCA
from torch.utils.data import DataLoader
from torchdiffeq import odeint

import models
import utils

def determine_params(paramarr):
  encoding_param = []
  P = paramarr.shape[1]

  for p in range(P):
    if np.abs(paramarr[0, p] - paramarr[1, p]) > 0:
      encoding_param.append(p)

  return encoding_param

%matplotlib widget
plt.rcParams["figure.figsize"] = (7, 4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

dsets = [(f"{basedir}/datasets/burgers/grfarc2visc0p001-shift.mat", "grfarc2visc0p001"),
          (f"{basedir}/datasets/burgers/grfarc2visc0p001-scale.mat", "grfarc2visc0p001"),
          (f"{basedir}/datasets/transport/hats2_2500_shift.mat", "alldata"),
          (f"{basedir}/datasets/transport/hats2_2500_scale.mat", "alldata"),
          (f"{basedir}/datasets/kdv/kdv2-shift.mat", "kdv2wide"),
          (f"{basedir}/datasets/kdv/kdv2-scale.mat", "kdv2wide")]

names = ["bshift", "bscale", "tshift", "tscale", "kshift", "kscale"]

# dsets2d = [(f"{basedir}/datasets/transport/hats2d_shift.mat", "alldata"), 
#            (f"{basedir}/datasets/transport/hats2d_scale.mat", "alldata")]

# names2d = ["t2shift", "t2scale"]

# dsets = dsets2d
# names = names2d

Using backend: pytorch



In [2]:
class WeldNetNonUniform(models.WeldNet):
  def __init__(self, dataset, maxwindows, aeclass, aeparams, propclass, propparams, transclass=None, transparams=None, alpha=0.05, residualprop=True, straightness=0, warmstart=True, tensorboard_directory=None, seed=0, device=0, kinetic=0, autonomous=True):
    self.dataset = dataset
    self.device = device
    self.td = tensorboard_directory
    self.straightness = straightness
    self.kinetic = kinetic
    self.autonomous = autonomous
    self.residualprop = residualprop

    assert(autonomous)

    assert(self.straightness == 0 or self.kinetic == 0)

    torch.manual_seed(seed)
    np.random.seed(seed)

    datacopy = self.dataset.data.copy()
    self.numtrain = int(datacopy.shape[0] * 0.8)
    
    self.T = self.dataset.data.shape[1]
    self.Wmax = maxwindows
    self.alpha = alpha
    
    self.aes = []
    self.props = []

    self.alltrain = datacopy[:self.numtrain]
    self.alltest = datacopy[self.numtrain:]

    self.aedata = [aeclass, aeparams, maxwindows, alpha, warmstart, self.straightness, self.kinetic]
    self.propdata = [propclass, propparams, maxwindows, alpha, warmstart, self.autonomous, self.residualprop]
    self.datadata = [np.sum(self.dataset.data), self.dataset.data.shape]

    if transclass is not None:
      self.transcoderdata = self.aedata + self.propdata + [transclass, transparams]

    self.windowvals = False
    self.transcoders = []
    self.aeclass = aeclass
    self.aeparams = aeparams
    self.propclass = propclass
    self.propparams = propparams
    self.transclass = propclass
    self.transparams = transparams
    self.warmstart = warmstart

  def train_aes(self, epochs_first, warmstart_epochs=-1, save=True, optim=torch.optim.AdamW, lr=1e-4, plottb=True, backwardstime=False, printinterval=10, batch=32, ridge=0, loss=None, encoding_param=-1):
    if warmstart_epochs == -1 and self.warmstart:
      warmstart_epochs = epochs_first
    
    def train_one_ae(ae, a, b, epochs):
      train = torch.tensor(self.alltrain[:, a:b+1]).float()
      test = torch.tensor(self.alltest[:, a:b+1]).float()

      writer = None
      if self.td is not None:
        name = f"./tensorboard/{datetime.datetime.now().strftime('%d-%B-%Y')}/{self.td}-weld{a}to{b}/{datetime.datetime.now().strftime('%H.%M.%S')}/"
        writer = torch.utils.tensorboard.SummaryWriter(name)
        print("Tensorboard writer location is " + name)
        
      print(f"Starting training WeldNet AE on [{a}, {b}] at {time.asctime()}...")
      print("Number of NN trainable parameters", utils.num_params(ae))
      print("train", train.shape, "test", test.shape)

      self.aestep = 0
      opt = optim(ae.parameters(), lr=lr, weight_decay=ridge)
      dataloader = DataLoader(train, batch_size=batch)

      for ep in range(epochs):
          ae_epoch(ae, dataloader, optimizer=opt, writer=writer, ep=ep, printinterval=printinterval, loss=loss, testarr=test)

          if ep % 5 == 0 and plottb:
            pass#self.plot_encoding(w, encoding_param, step=self.aestep, writer=writer, tensorboard=True)
      
      print(f"Finish training AE [{a}, {b}] at {time.asctime()}.")

      return ae

    def ae_epoch(model, dataloader, writer=None, optimizer=None, ep=0, printinterval=10, loss=None, testarr=None):
      def closure(batch):
        optimizer.zero_grad()
        total = 0
        penalties = 0

        for N in range(batch.shape[0]):
          traj = batch[N, :, :]
          enc = model.encode(traj)
          proj = model.decode(enc)

          # compute regularization here
          # add penalization to enc here
          # (enc[:-1] - enc[1:]) ** 2 is proportional to velocity

          res = loss(traj, proj)

          if self.straightness > 0:
            T = traj.shape[0]
            i_values = torch.arange(1, T)
            weights = (T - i_values) / T
            term1 = torch.outer(weights, enc[0, :])
            term2 = torch.outer((i_values / T), enc[-1, :])
            term3 = enc[i_values, :]
            penalty = loss(term1 + term2, term3)
            penalties += self.straightness * penalty
          elif self.kinetic > 0:
            #acceleration
            #starts = enc[:-2]
            #mids = enc[1:-1]
            #ends = enc[2:]

            # maybe scale?

            # order 2
            starts = enc[:-2]
            ends = enc[2:]
            penalty = loss(starts - ends, torch.zeros_like(starts))
            penalties += self.kinetic * penalty

            # order one
            # starts = enc[:-1, :]
            # ends = enc[1:, :]
            # penalty = loss(starts - ends, torch.zeros_like(starts))
            # penalties += self.kinetic * penalty

          total += res
        
        total += penalties
        total.backward()
        
        if writer is not None and self.aestep % 5:
          writer.add_scalar("main/loss", total, global_step=self.aestep)
          writer.add_scalar("main/penalty", penalties, global_step=self.aestep)

        return total

      for batch in dataloader:
        self.aestep += 1
        error = optimizer.step(lambda: closure(batch))

      # print test
      if printinterval > 0 and (ep % printinterval == 0):
        proj = model(testarr)
        testarr = testarr.cpu().detach().numpy()
        proj = proj.cpu().detach().numpy()

        testerr1 = np.mean(np.linalg.norm(testarr - proj, axis=1, ord=1) / np.linalg.norm(testarr, axis=1, ord=1))
        testerr2 = np.mean(np.linalg.norm(testarr - proj, axis=1, ord=2) / np.linalg.norm(testarr, axis=1, ord=2))
        testerrinf = np.mean(np.linalg.norm(testarr - proj, axis=1, ord=np.inf) / np.linalg.norm(testarr, axis=1, ord=np.inf))
        
        print(f"{ep+1}: Train Loss {error:.3e}, Relative Projection Error (1, 2, inf): {testerr1:3f}, {testerr2:3f}, {testerrinf:3f}")

        if writer is not None:
            writer.add_scalar("misc/relativeL1proj", testerr1, global_step=ep)
            writer.add_scalar("main/relativeL2proj", testerr2, global_step=ep)
            writer.add_scalar("misc/relativeLInfproj", testerrinf, global_step=ep)

    if loss is None:
       loss = nn.MSELoss()
    else:
       loss = loss()

    if encoding_param == -1:
      encoding_param = []
      P = self.dataset.params.shape[1]

      for p in range(P):
        if np.abs(self.dataset.params[0, p] - self.dataset.params[1, p]) > 0:
          encoding_param.append(p)

    print(f"Training until we have up to {self.Wmax} WeldNet AEs. Starting with initial [0, {self.T-1}].")

    aefirst = self.aeclass(**self.aeparams)
    aefirst = train_one_ae(aefirst, 0, self.T-1, epochs_first)

    windowsleft = [[(0, self.T-1), aefirst]]
    windowsdone = []

    while len(windowsleft) > 0 and len(windowsleft) + len(windowsdone) < self.Wmax:
      out = windowsleft[0]
      a, b = out[0]
      ae = out[1]
      windowsleft = windowsleft[1:]

      c = a + (b - a) // 2
      state = ae.state_dict()

      ae_a = self.aeclass(**self.aeparams)
      if warmstart_epochs > 0:
        ae_a.load_state_dict(state)
      ae_a = train_one_ae(ae_a, a, c, warmstart_epochs)

      ae_b = self.aeclass(**self.aeparams)
      if warmstart_epochs > 0:
        ae_b.load_state_dict(state)
      ae_b = train_one_ae(ae_b, c, b, warmstart_epochs)

      # lets evaluate error at c
      testslice = torch.tensor(self.alltest[:, c, :]).float().to(self.device)
      error = torch.norm(testslice - ae.forward(testslice))
      errora = torch.norm(testslice - ae_a.forward(testslice))
      errorb = torch.norm(testslice - ae_b.forward(testslice))

      print(float(errora), float(errorb), "versus", float(error))

      if 0.5 * (errora + errorb) < (1 - self.alpha) * error:
        if c - a < 2 or b - c < 2:
          if backwardstime:
            windowsleft.append([(c, b), ae_b])
            windowsleft.append([(a, c), ae_a])
          else:
            windowsleft.append([(a, c), ae_a])
            windowsleft.append([(c, b), ae_b])    
        
          print("Splitting and finishing", (a, b), "at", c)

        else:
          if backwardstime:
            windowsleft.append([(c, b), ae_b])
            windowsleft.append([(a, c), ae_a])
          else:
            windowsleft.append([(a, c), ae_a])
            windowsleft.append([(c, b), ae_b])    
       
          print("Splitting", (a, b), "at", c)
      else:
        windowsdone.append([(a, b), ae])
        print("No more improvement for", (a, b))

    windows = sorted(windowsleft + windowsdone, key=lambda x: x[0][0])
    self.aes = [x[1] for x in windows]
    self.windowvals = [range(x[0][0], x[0][1]+1) for x in windows]

    if save:
      dire = "savedmodels/weldnonunif"
      addr = f"{dire}/{self.td}{self.Wmax}-{self.alpha}-{datetime.datetime.now().strftime('%d-%B-%Y-%H.%M')}.pickle"

      if not os.path.exists(dire):
        os.makedirs(dire)

      with open(addr, "wb") as handle:
        pickle.dump({"aes": self.aes, "aedata": self.aedata, "datadata": self.datadata}, handle, protocol=pickle.HIGHEST_PROTOCOL)
        print("AEs saved at", addr)

    print("Finished training all timewindows")

  def load_aes(self, filename, verbose=False):
    matching_files = glob.glob(f"savedmodels/weldnonunif/{filename}*{self.Wmax}-{self.alpha}-*")
    print("Searching for", str(self.aedata), str(self.datadata))

    for addr in matching_files:
      with open(addr, "rb") as handle:
        dic = pickle.load(handle)

        if str(self.aedata) == str(dic["aedata"]) and str(self.datadata) == str(dic["datadata"]):
          print("Loading AEs from", addr)
          self.aes = dic["aes"]
          return True
        elif verbose:
          print("NO MATCH", str(dic["aedata"]), str(dic["datadata"]))
            
    print(f"Load failed. Could not match with any files")
    print(matching_files)
    return False

In [3]:
baseepochs = 200

aeclass = models.FFAutoencoder
propclass = models.FFNet

welds = {}
start = 0
end = 1
for dset, name in zip(dsets[start:end], names[start:end]):
  dataset = utils.DynamicData(dset)
  print(dset, name)

  dataset.shuffle_inplace()
  dataset.subset_data(500)
  dataset.downsample(int(dataset.data.shape[2] / 256))
  dataset.data = dataset.data[:, ::2]
  dataset.scaledown()

  din = dataset.data.shape[-1]
  Lae = 3
  Lprop = 3
  pae = 400
  pprop = 200
  rp = True
  trans = True
  ws = True

  welds[name] = []

  for k in [4]:
    for style in ["base"]:
      for w in [10]:
        for auton in [True]:
          aeargs = { "encodeSeq": [din] + [pae] * Lae + [k], "decodeSeq": [k] + [pae] * Lae + [din], "activation": nn.ReLU() }
          propargs = { "seq": [k if auton else k+1] + [pprop] * Lprop + [k], "activation": nn.ReLU() }
          transclass = models.FFNet if trans else None
          transargs = { "seq": [k] + [pprop] * Lprop + [k], "activation": nn.ReLU() } if trans else None

          if style == "base":
            weld = WeldNetNonUniform(dataset, w, aeclass, aeargs, propclass, propargs, warmstart=ws, transclass=transclass, transparams=transargs, device=device, tensorboard_directory=f"{name}base-{k}-{'auton' if auton else 'nonauton'}", straightness=0, autonomous=auton, residualprop=rp)
          elif style == "straight":
            weld = WeldNetNonUniform(dataset, w, aeclass, aeargs, propclass, propargs, warmstart=ws, transclass=transclass, transparams=transargs, device=device, tensorboard_directory=f"{name}straight-{k}-{'auton' if auton else 'nonauton'}", straightness=0.1, autonomous=auton, residualprop=rp)
          else:
            weld = WeldNetNonUniform(dataset, w, aeclass, aeargs, propclass, propargs, warmstart=ws, transclass=transclass, transparams=transargs, device=device, tensorboard_directory=f"{name}kinetic-{k}-{'auton' if auton else 'nonauton'}", kinetic=1, autonomous=auton, residualprop=rp)

          loadae = weld.load_aes(name, verbose=True)
          if not loadae:
            weld.train_aes(baseepochs, printinterval=25, batch=16, save=True, plottb=False, lr=1e-4)
          
          assert(False)
          # loadboth = False#weld.load_all(name)
          
          # if not loadboth:
          #   loadae = weld.load_aes(name)
          #   if not loadae:
          #     weld.train_aes(baseepochs, warmstart_epochs=baseepochs, printinterval=25, batch=16, save=True, plottb=False, lr=1e-4)

          #   weld.train_propagators(baseepochs * 10, batch=32, printinterval=baseepochs, save=True, lr=1e-10)

          # loadweld = False#weld.load_transcoders(name)
          # if not loadweld:
          #   weld.train_transcoders(baseepochs * 5, printinterval=baseepochs//2, lr=1e-4, save=True)
          # #assert(False)

          # welds[name].append(weld)

('../../datasets/burgers/grfarc2visc0p001-shift.mat', 'grfarc2visc0p001') bshift
Searching for [<class 'models.FFAutoencoder'>, {'encodeSeq': [256, 400, 400, 400, 4], 'decodeSeq': [4, 400, 400, 400, 256], 'activation': ReLU()}, 10, 0.05, True, 0, 0] [-0.00015802056145730603, (500, 26, 256)]
Loading AEs from savedmodels/weldnonunif\bshiftbase-4-auton10-0.05-04-September-2024-18.50.pickle


AssertionError: 

In [5]:
weld.windowvals

False

In [4]:
from models import WeldAnalyzer

WeldAnalyzer.plot_encoding_window(weld)

TypeError: 'bool' object is not subscriptable