### Prerequisites for this notebook
Run the first two code cells (path setup and helper definitions), then move on to the batch-training section.

In [1]:
import sys
# Put the parent directory first on the import path, presumably so that `models`
# and `common.config` (imported in the next cell) resolve from this notebook.
sys.path.insert(0, "..")
# Presumably the repository root relative to this notebook; not referenced in the cells shown here.
basedir = "../.."

In [10]:
def trainer(dconfig, wconfig, nows=False, **kwargs):
    """Build a dataset and a model from two config files, then load and/or train the model.

    Args:
        dconfig: data config handed to ``load_config`` (with the override
            ``datasize.spacedim=1``) and instantiated via ``create_object``.
        wconfig: experiment config handed to ``load_config``/``create_object``;
            the resulting object must be a ``models.WeldHelper`` or a
            ``models.TimeInputHelper`` (the latter is asserted).
        nows: weld path only. Appends ``'noWS'`` to the run name and passes
            ``warmstart_epochs=0`` (instead of ``baseepochs``) to ``train_aes``.
        **kwargs: ``is_training`` (default True) and ``reg_name`` (default
            ``"base"``) are consumed here and not forwarded. All remaining
            keys, plus a generated ``td`` label, are forwarded to
            ``create_weld`` / ``create_timeinput``. The weld path reads
            ``kwargs['autonomous']`` directly, so it raises KeyError if that key
            is missing. ``windows`` (default 1) > 1 enables the
            ``load_aes_and_props`` / ``train_transcoders`` steps.

    Returns:
        The object returned by ``create_weld`` (weld path) or by
        ``create_timeinput`` (time-input path). Load calls always run; with
        ``is_training=False`` no training runs, so the object holds only
        whatever state the load calls restored.
    """
    is_training = kwargs.pop("is_training", True)
    from common.config import create_object, load_config
    
    dconfig = load_config(dconfig, [f"datasize.spacedim=1"])
    dset = create_object(dconfig)
    wconfig = load_config(wconfig)
    experiment = create_object(wconfig)
    # Used only in the weld `td` label below; not forwarded to create_weld/create_timeinput.
    reg_name = kwargs.pop("reg_name", "base")
    
    if isinstance(experiment, models.WeldHelper):
        # Run name used for every load_* call below. It omits reg_name, k and
        # autonomy. NOTE(review): confirm the weld distinguishes saved runs by
        # `td`; otherwise different reg configs may load each other's checkpoints.
        name = f"{dset.name}{wconfig.aeclass}{'noWS' if nows else ''}"
        # NOTE(review): the label hard-codes k=4 and ignores kwargs.get("k")
        # (batch_trainer passes k=20); confirm this is intended.
        k = 4# wconfig.aeparams.encodeSeq[-1] if "encodeSeq" in wconfig.aeparams else wconfig.aeparams.sample.latents_dims
        kwargs["td"] = f"{name}{reg_name}-{k}-{'auton' if  kwargs['autonomous'] else 'nonauton'}"
        weld = experiment.create_weld(dset, **kwargs)
        baseepochs = 150
        is_all_loaded = weld.load_all(name)
        if is_training:
            # Multi-window runs can still reuse saved autoencoders + propagators
            # even when the full set (incl. transcoders) is not on disk.
            if not is_all_loaded and kwargs.get("windows",1) > 1:
                is_ae_prop_loaded = weld.load_aes_and_props(name,verbose=True)
            else:
                is_ae_prop_loaded = is_all_loaded

            if not is_ae_prop_loaded:
                # Autoencoders are trained only if they cannot be loaded;
                # propagators are always (re)trained in this branch.
                if not weld.load_aes(name,verbose=True):
                    weld.train_aes(baseepochs, warmstart_epochs=0 if nows else baseepochs, printinterval=25, batch=16, save=True, plottb=False, lr=1e-4)
                weld.train_propagators(baseepochs * 10, batch=32, printinterval=baseepochs, save=True, lr=1e-5)
                
            # Transcoders are only needed (and trained) for multi-window runs.
            if not is_all_loaded and kwargs.get("windows",1) > 1:
                weld.train_transcoders(baseepochs * 10, batch=32, printinterval=baseepochs, save=True, lr=1e-5)
        return weld
    else:
        assert(isinstance(experiment, models.TimeInputHelper))
        kwargs["td"] = f"{dset.name}{wconfig.ticlass}"
        ti = experiment.create_timeinput(dset, **kwargs)
        baseepochs = 1000
        is_all_loaded = ti.load_model(dset.name)
        if is_training:
            if not is_all_loaded:
                # NOTE(review): unlike the weld path, no save flag is passed here;
                # confirm train_model persists the model so later load_model calls find it.
                ti.train_model(baseepochs, printinterval=baseepochs//20, lr=1e-3)

        return ti

def batch_trainer(dconfig:str, wconfig:str, nows=False, **kwargs):
    """Train (or load) a grid of models for one data/model config pair.

    Calls ``trainer`` once per combination of autonomy setting, regularisation
    config and window count. Currently only ``autonomous=True``, the ``base``
    regularisation config and ``windows=1`` are enabled.

    Args:
        dconfig: path to the data config YAML.
        wconfig: path to the experiment/model config YAML.
        nows: forwarded to ``trainer`` (weld path: no autoencoder warm start).
        **kwargs: ``is_training`` (default True) is forwarded to ``trainer``.
            Any other keyword (e.g. ``aeclass``) is forwarded to every
            ``trainer`` call as well. A caller-supplied ``k`` overrides the
            default of 20, but the per-run ``windows``, ``autonomous`` and
            regularisation keys always take precedence.

    Returns:
        A list with one inner list per (autonomous, reg_config) pair. Each inner
        list holds the objects ``trainer`` returned, one per window count.
    """
    is_training = kwargs.pop("is_training", True)
    base_config = {"reg_name":"base", "straightness":0, "kinetic":0}
    straight_config = {"reg_name":"straight", "straightness":0.1, "kinetic":0}
    kinetic_config = {"reg_name":"kinetic", "straightness":0, "kinetic":1}
    batched_welds = []

    # Only the base config is active; add straight_config / kinetic_config to sweep them.
    regconfigs = [base_config]
    for autonomous in [True]:
        for reg_config in regconfigs:
            welds = []
            for windows in [1]:
                # Build a fresh per-run dict instead of rebinding `kwargs`, so
                # caller-supplied keywords are forwarded rather than discarded.
                run_kwargs = {"k": 20,
                              **kwargs,
                              "windows": windows,
                              "autonomous": autonomous,
                              **reg_config}
                welds.append(trainer(dconfig, wconfig, is_training=is_training, nows=nows, **run_kwargs))
            batched_welds.append(welds)
    return batched_welds

import models

def plots(welds:list, dir="plots"):
    """Save comparison figures for a group of models into ``dir``.

    The type of ``welds[0]`` selects the plot family:

    * ``models.WeldNet``: projection/operator errors, propagation errors and
      error-parameter comparisons, via ``models.WeldHelper``.
    * anything else: operator errors and error-parameter comparisons, via
      ``models.TimeInputHelper``.

    Every file is named ``{dir}/{welds[0].td}_<kind>.png``.

    Args:
        welds: models to compare; all are assumed to be the same kind as
            ``welds[0]``. An empty list is a no-op.
        dir: output directory, created if missing.
    """
    if not welds:
        # Nothing to compare; avoids an IndexError on welds[0] below.
        return

    import os
    import matplotlib.pyplot as plt

    os.makedirs(dir, exist_ok=True)
    prefix = f"{dir}/{welds[0].td}"

    def _save(fig, suffix, tight=True):
        # Write one figure to disk and close it, so that looping over many
        # model groups does not accumulate open figures.
        if tight:
            fig.tight_layout()
        fig.savefig(f"{prefix}_{suffix}.png", facecolor="white", dpi=200)
        plt.close(fig)

    if isinstance(welds[0], models.WeldNet):
        # The comparison helpers live on the module-level WeldHelper class
        # (mirroring TimeInputHelper below); the list `welds` has no such attribute.
        _save(models.WeldHelper.compare_projops(welds, windowlines=False), "projopserrs")
        _save(models.WeldHelper.compare_properrs(welds, windowlines=False), "properrs")
        _save(models.WeldHelper.compare_errorparams(welds), "errorparams")
    else:
        # The operator-error figure is saved without tight_layout, unlike the others.
        _save(models.TimeInputHelper.compare_operrs(welds), "operrs", tight=False)
        _save(models.TimeInputHelper.compare_errorparams(welds), "errorparams")
    
def tester(*args, **kwargs):
    """Build one model without training it and write its comparison plots.

    All arguments go to ``trainer`` with ``is_training=False``, so no training
    runs and the plots reflect whatever state ``trainer``'s load calls restored.
    """
    plots([trainer(*args, is_training=False, **kwargs)])
    
def batch_tester(*args, **kwargs):
    """Rebuild every model of a ``batch_trainer`` sweep without training, then plot each group.

    Arguments are forwarded to ``batch_trainer`` with ``is_training=False``;
    ``plots`` is called once per inner list that it returns.
    """
    for model_group in batch_trainer(*args, is_training=False, **kwargs):
        plots(model_group)

### WeldNet training (and optional testing)

*Follow the example below to train WeldNet with the preferred hyperparameters.*

### Batched training (and optional testing)

*Use `batch_trainer` with a data configuration and a model configuration. Pass an `aeclass` keyword argument for the CNN or LSTM autoencoder if necessary.*

In [5]:
# Dataset registry: short name -> data-config YAML path (consumed by configure_model_and_data).
DCONFIGS = {
    "bshift": "../autoencoder/configs/data/burgersshift.yaml",
    "bscale": "../autoencoder/configs/data/burgersscale.yaml",
    "tscale": "../autoencoder/configs/data/hatsscale.yaml",
    "tshift": "../autoencoder/configs/data/hatsshift.yaml",
    "kscale": "../autoencoder/configs/data/KdVscale.yaml",
    "kshift": "../autoencoder/configs/data/KdVshift.yaml",
    # 2-D datasets, currently disabled (get_2d_data still returns these names):
    # "b2width": "../autoencoder/configs/data/burgers2dwidth.yaml",
    # "t2scale": "../autoencoder/configs/data/hats2dscale.yaml"
}
# Alternative dataset registries; uncomment one to replace DCONFIGS above.
# DCONFIGS = {
#     "tboth": "../autoencoder/configs/data/hatsboth.yaml",
#     "tbothrandom": "../autoencoder/configs/data/hatsboth_randomt.yaml",
#     "tbothrandomstable": "../autoencoder/configs/data/hatsboth_randomtstable.yaml",
# }

# DCONFIGS = {
#     "lshift": "../autoencoder/configs/data/l96.yaml",
# }
# Model/experiment registry: short name -> experiment-config YAML path.
# trainer accepts only models.WeldHelper or models.TimeInputHelper experiments;
# presumably the "weld*" entries build the former and the "*ti" entries the latter
# ("ffti" trains an FFNet time-input model, per the run log below).
WCONFIGS = {
    "weldff": "../autoencoder/configs/experiments/weldnormal.yaml",
    "weldpca": "../autoencoder/configs/experiments/weldpca.yaml",
    "ffti": "../autoencoder/configs/experiments/ffnetnormal.yaml",
    "donti": "../autoencoder/configs/experiments/donnormal.yaml",
}
def configure_model_and_data(dname, wname):
    """Resolve short dataset/model names to their config file paths.

    Args:
        dname: key into ``DCONFIGS`` (e.g. ``"bshift"``).
        wname: key into ``WCONFIGS`` (e.g. ``"ffti"``).

    Returns:
        ``(data_config_path, model_config_path)``.

    Raises:
        KeyError: if either name is not registered. The message lists the
            available names; for example, the names returned by
            ``get_2d_data`` are not currently in ``DCONFIGS``.
    """
    # Apply any preferred per-run modification of the config paths here.
    for name, registry, kind in ((dname, DCONFIGS, "dataset"), (wname, WCONFIGS, "model")):
        if name not in registry:
            raise KeyError(f"Unknown {kind} config {name!r}; available: {sorted(registry)}")
    return DCONFIGS[dname], WCONFIGS[wname]
def get_all_data():
    """Names of every dataset currently registered in ``DCONFIGS`` (a live keys view)."""
    return DCONFIGS.keys()

def get_periodic_bc_1d_data():
    """1-D dataset names grouped as periodic-boundary (the Burgers configs)."""
    return "bshift bscale".split()

def get_zero_bc_1d_data():
    """1-D dataset names grouped as zero-boundary (the hats and KdV configs)."""
    return "tscale tshift kscale kshift".split()

def get_2d_data():
    """2-D dataset names.

    NOTE(review): these entries are commented out of ``DCONFIGS``, so passing
    them to ``configure_model_and_data`` fails until they are re-enabled.
    """
    return "b2width t2scale".split()

In [12]:
# For every registered dataset: train (or load) the "ffti" model, then rebuild it
# with is_training=False and write its comparison plots.
# NOTE(review): batch_tester re-creates the datasets and models and relies on the
# load calls to recover the trained state. Confirm the time-input path persists
# its weights (trainer's TI branch passes no save flag to train_model).
for dname in list(get_all_data()):
    dconfig, wconfig = configure_model_and_data(dname, "ffti")
    batch_trainer(dconfig, wconfig)
    batch_tester(dconfig, wconfig)

Searching for [<class 'models.FFNet'>, ([513, 400, 400, 400, 512],)] [-1.0, (500, 51, 512)]
Load failed. Could not match with any files
['savedmodels/timeinput\\bshiftFFNet-01-October-2024-20.24.pickle']
Tensorboard writer location is ./tensorboard/01-October-2024/bshiftFFNet/21.17.27/
Number of NN trainable parameters 731712
Starting training TI model FFNet at Tue Oct  1 21:17:27 2024...
train torch.Size([400, 51, 512]) test (100, 51, 512)
[0.001]
1: Train Loss 3.965e-02, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.355283, 0.389606, 0.724861
[0.001]
51: Train Loss 6.412e-04, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.047837, 0.053971, 0.158051
[0.001]
101: Train Loss 3.716e-04, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.040437, 0.046569, 0.136504
[0.001]
151: Train Loss 4.259e-04, LR 1.000e-03, Relative TI Error (1, 2, inf): 0.036888, 0.041808, 0.113890
[0.0001]
201: Train Loss 5.718e-05, LR 1.000e-04, Relative TI Error (1, 2, inf): 0.013319, 0.016507, 0.076259
[0.0001]
251

: 