### Prerequisites for this notebook
Run the cells in this section, then move on to the next section.

In [1]:
import sys

# Prepend the parent directory to the import path; presumably this is where the
# project packages imported later (`common.config`, `models`) live.
sys.path[:0] = [".."]
# NOTE(review): not referenced in the cells shown here.
basedir = "../.."

In [6]:
def trainer(dconfig, wconfig, **kwargs):
    """Create one WeldNet run, load it from disk, and train any missing parts.

    Parameters
    ----------
    dconfig : str
        Data config (a YAML path in this notebook), passed to ``load_config``
        with the override ``datasize.spacedim=1``.
    wconfig : str
        WeldNet experiment config (a YAML path in this notebook).
    **kwargs
        Consumed here:
          * ``is_training`` (default True): when False, only ``load_all`` is
            attempted and nothing is trained.
          * ``reg_name`` (default "base"): regularization tag used in the run tag.
        Read here and also forwarded to ``experiment.create_weld``:
          * ``autonomous`` (required, KeyError if missing): picks the
            "auton"/"nonauton" suffix of the run tag.
          * ``windows`` (default 1): more than one window also loads/trains
            transcoders.
          * ``epochs`` (default 150): base epoch count of the training schedule.
        Any other keys (e.g. ``straightness``, ``kinetic``) are forwarded to
        ``create_weld`` unchanged, together with the generated ``td`` run tag.

    Returns
    -------
    The weld object returned by ``experiment.create_weld``.
    """
    import warnings

    from common.config import create_object, load_config

    is_training = kwargs.pop("is_training", True)
    reg_name = kwargs.pop("reg_name", "base")

    # NOTE(review): spacedim is forced to 1 even for the 2-D data configs
    # (burgers2dwidth, hats2dscale); confirm this override is intended.
    data_cfg = load_config(dconfig, ["datasize.spacedim=1"])
    dset = create_object(data_cfg)
    weld_cfg = load_config(wconfig)
    experiment = create_object(weld_cfg)

    # `name` is what the load_* methods receive; `td` is the run tag given to create_weld.
    name = f"{dset.name}{weld_cfg.aeclass}"
    aeparams = weld_cfg.aeparams
    # Latent size: last entry of encodeSeq when present, otherwise sample.latents_dims.
    latent_dim = aeparams.encodeSeq[-1] if "encodeSeq" in aeparams else aeparams.sample.latents_dims
    mode = "auton" if kwargs["autonomous"] else "nonauton"
    kwargs["td"] = f"{name}{reg_name}-{latent_dim}-{mode}"
    weld = experiment.create_weld(dset, **kwargs)

    baseepochs = kwargs.get("epochs", 150)
    multi_window = kwargs.get("windows", 1) > 1

    is_all_loaded = weld.load_all(name)
    if not is_training:
        if not is_all_loaded:
            # Otherwise an unloaded weld would be returned silently and any
            # plots made from it would be meaningless.
            warnings.warn(
                f"is_training=False but load_all({name!r}) failed for run {kwargs['td']!r}; "
                "the returned weld may be untrained.",
                stacklevel=2,
            )
        return weld

    # For multi-window runs, try loading AEs + propagators on their own so that
    # only the transcoders still need training.
    if multi_window and not is_all_loaded:
        is_ae_prop_loaded = weld.load_aes_and_props(name, verbose=True)
    else:
        is_ae_prop_loaded = is_all_loaded

    if not is_ae_prop_loaded:
        if not weld.load_aes(name, verbose=True):
            weld.train_aes(baseepochs, warmstart_epochs=baseepochs, printinterval=25,
                           batch=16, save=True, plottb=False, lr=1e-4)
        weld.train_propagators(baseepochs * 10, batch=32, printinterval=baseepochs,
                               save=True, lr=1e-5)

    if multi_window and not is_all_loaded:
        weld.train_transcoders(baseepochs * 10, batch=32, printinterval=baseepochs,
                               save=True, lr=1e-5)
    return weld

def batch_trainer(dconfig:str, wconfig:str, **kwargs):
    """Run ``trainer`` over a sweep of window counts (and optionally regimes).

    Parameters
    ----------
    dconfig, wconfig : str
        Data and experiment config paths, passed to every ``trainer`` call.
    **kwargs
        ``is_training`` (default True) is consumed here and passed to every
        ``trainer`` call. All other keys (e.g. ``aeclass``, ``epochs``) are
        forwarded to every ``trainer`` call. The sweep's own ``windows``,
        ``autonomous`` and regularization values take precedence over keys
        with the same name.

    Returns
    -------
    list[list]
        One inner list per (autonomous, regularization) combination, holding
        the welds for 1, 3 and 10 windows in that order.
    """
    is_training = kwargs.pop("is_training", True)
    base_config = {"reg_name": "base", "straightness": 0, "kinetic": 0}
    # Alternative regularizers; add them to the reg_config loop below to sweep them.
    straight_config = {"reg_name": "straight", "straightness": 0.1, "kinetic": 0}
    kinetic_config = {"reg_name": "kinetic", "straightness": 0, "kinetic": 10}

    batched_welds = []
    for autonomous in [True]:  # add False to also sweep non-autonomous runs
        for reg_config in [base_config]:  # e.g. [base_config, straight_config, kinetic_config]
            welds = []
            for windows in [1, 3, 10]:
                # Build on the caller's extras instead of discarding them.
                run_kwargs = {**kwargs, "windows": windows, "autonomous": autonomous, **reg_config}
                welds.append(trainer(dconfig, wconfig, is_training=is_training, **run_kwargs))
            batched_welds.append(welds)
    return batched_welds

def plots(welds:list):
    """Save projection- and propagation-error comparison figures for ``welds``.

    The figures go to the current working directory as ``<td>_projerrs.png``
    and ``<td>_properrs.png``, where ``td`` is the run tag of the first weld.

    Raises
    ------
    ValueError
        If ``welds`` is empty (there is no run tag to name the files after).
    """
    from models import WeldAnalyzer
    import matplotlib.pyplot as plt

    if not welds:
        raise ValueError("plots() needs at least one weld to compare")
    tag = welds[0].td
    comparisons = (
        ("projection", WeldAnalyzer.compare_projerrs, "projerrs"),
        ("propagation", WeldAnalyzer.compare_properrs, "properrs"),
    )
    for label, compare, suffix in comparisons:
        print(f"We're now comparing the {label} errors from {tag}")
        fig = compare(welds)
        try:
            fig.savefig(f"{tag}_{suffix}.png")
        finally:
            # Close even if saving fails, so figures do not pile up across batch loops.
            plt.close(fig)
    
def tester(*args, **kwargs):
    weld = trainer(*args, is_training=False, **kwargs)
    plots([weld])
    
def batch_tester(*args, **kwargs):
    """Reload every run of a ``batch_trainer`` sweep (no training) and plot each group.

    Arguments are forwarded to ``batch_trainer`` with ``is_training=False`` added.
    """
    groups = batch_trainer(*args, **kwargs, is_training=False)
    for group in groups:
        plots(group)



### WeldNet Training (and optional testing)

*Follow the example below to train WeldNet with your preferred hyperparameters.*

In [8]:
from models import *

#ConvAutoEncoder

def configure_model_and_data(
    dconfig="../autoencoder/configs/data/burgers2dwidth.yaml",
    wconfig="../autoencoder/configs/experiments/weldconv.yaml",
    windows=1,
    autonomous=True,
    reg_config=None,
):
    """Return ``(dconfig, wconfig, kwargs)`` for a single ``trainer`` run.

    Override any argument instead of editing this function. A zero-argument
    call returns the original example settings.

    Parameters
    ----------
    dconfig, wconfig : str
        Data and WeldNet experiment config paths.
    windows : int
        Number of windows forwarded to ``trainer``.
    autonomous : bool
        Forwarded to ``trainer`` (selects the "auton"/"nonauton" run tag).
    reg_config : dict or None
        Regularization settings (``reg_name``, ``straightness``, ``kinetic``).
        None means no regularization.

    NOTE: a later cell redefines ``configure_model_and_data(dname, wname)``,
    which shadows this definition once that cell has run.
    """
    if reg_config is None:
        # No regularization. None used to crash at `**reg_config` below.
        reg_config = {"reg_name": "none", "straightness": 0, "kinetic": 0}
    kwargs = {"windows": windows,
              "autonomous": autonomous,
              **reg_config}
    return dconfig, wconfig, kwargs

### Batched Training (and optional testing)

*Use `batch_trainer` with a data configuration and a model configuration. If necessary, pass an `aeclass` keyword argument to select the CNN or LSTM autoencoder.*

In [10]:

# Short dataset name -> data config YAML (relative paths).
DCONFIGS = {
    "bshift": "../autoencoder/configs/data/burgersshift.yaml",
    "bscale": "../autoencoder/configs/data/burgersscale.yaml",
    "tscale": "../autoencoder/configs/data/hatsscale.yaml",
    "tshift": "../autoencoder/configs/data/hatsshift.yaml",
    "kscale": "../autoencoder/configs/data/KdVscale.yaml",
    "kshift": "../autoencoder/configs/data/KdVshift.yaml",
    "b2width": "../autoencoder/configs/data/burgers2dwidth.yaml",
    "t2scale": "../autoencoder/configs/data/hats2dscale.yaml"
}
# Short experiment name -> WeldNet experiment config YAML (relative paths).
WCONFIGS = {
    "weldff": "../autoencoder/configs/experiments/weldnormal.yaml",
    "weldconv": "../autoencoder/configs/experiments/weldconv.yaml",
    "weldconv_periodic": "../autoencoder/configs/experiments/weldconv_periodic.yaml",
    "weldconv2d": "../autoencoder/configs/experiments/weldconv2d.yaml",
    # Requested by the t2scale run at the end of the notebook; without this entry
    # configure_model_and_data raises KeyError there.
    # TODO confirm this YAML exists under configs/experiments/.
    "weldconv2d_deeper": "../autoencoder/configs/experiments/weldconv2d_deeper.yaml"
}
def configure_model_and_data(dname, wname):
    """Return ``(data_config_path, experiment_config_path)`` for the given short names.

    Parameters
    ----------
    dname : str
        A key of ``DCONFIGS``.
    wname : str
        A key of ``WCONFIGS``.

    Raises
    ------
    KeyError
        If either name is not registered; the message lists the valid keys.
    """
    # Apply your preferred modification here.
    if dname not in DCONFIGS:
        raise KeyError(f"Unknown dataset {dname!r}; choose one of {sorted(DCONFIGS)}")
    if wname not in WCONFIGS:
        raise KeyError(f"Unknown experiment {wname!r}; choose one of {sorted(WCONFIGS)}")
    return DCONFIGS[dname], WCONFIGS[wname]
def get_all_data():
    """Return the short names of all registered datasets (a live ``dict_keys`` view of DCONFIGS)."""
    registry = DCONFIGS
    return registry.keys()
def get_periodic_bc_1d_data():
    """Return the dataset keys of the 1-D problems with periodic boundary conditions."""
    periodic_keys = ["bshift", "bscale"]
    return periodic_keys
def get_zero_bc_1d_data():
    """Return the dataset keys of the 1-D problems with zero boundary conditions."""
    zero_bc_keys = ["tscale", "tshift", "kscale", "kshift"]
    return zero_bc_keys
def get_2d_data():
    """Return the dataset keys of the 2-D problems."""
    two_d_keys = ["b2width", "t2scale"]
    return two_d_keys

In [11]:
# Feed-forward WeldNet ("weldff") on every registered dataset, including the 2-D ones.
# batch_trainer trains whatever cannot be loaded from disk; batch_tester then
# reloads the same runs with is_training=False and saves the comparison plots.
for dname in get_all_data():
    dconfig, wconfig = configure_model_and_data(dname, "weldff")
    batch_trainer(dconfig, wconfig)
    batch_tester(dconfig, wconfig)

KeyboardInterrupt: 

In [13]:
# Periodic-boundary 1-D datasets with the "weldconv_periodic" experiment config:
# train what is missing, then reload and plot.
for dname in get_periodic_bc_1d_data():
    dconfig, wconfig = configure_model_and_data(dname, "weldconv_periodic")
    batch_trainer(dconfig, wconfig)
    batch_tester(dconfig, wconfig)

Propagator failed. Could not match with any files
['savedmodels/weld/props\\bshiftConvAutoEncoder-4-auton1w-props-21-September-2024-22.11.pickle']
Searching for [<class 'autoencoder.networks.ConvAutoEncoder'>, {'sample': {'spatial_resolution': 512, 'latents_dims': 4}, 'downblocks': {'channels': [8, 16, 32, 32, 32], 'kernel_stride_paddings': [[8, 2, 1], [8, 2, 1], [8, 2, 1], [4, 2, 1], [4, 2, 1]], 'actvn': 'relu', 'padding': 'circular'}, 'datadim': 1}, 1, 0, 0] [-1.0, (500, 51, 512)]
NO MATCH [<class 'autoencoder.networks.ConvAutoEncoder'>, {'sample': {'spatial_resolution': 512, 'latents_dims': 4}, 'downblocks': {'channels': [16, 32, 64, 128], 'kernel_stride_paddings': [[8, 2, 1], [8, 2, 1], [4, 2, 1], [4, 2, 1]], 'actvn': 'relu', 'padding': 'circular'}, 'datadim': 1}, 1, 0, 0, (200, 200)] [-1.0, (500, 51, 512)]
Load failed. Could not match with any files
['savedmodels/weld\\bshiftConvAutoEncoder-4-auton1w-21-September-2024-22.11.pickle']
Training 1 WeldNet AEs
Tensorboard writer locati

KeyboardInterrupt: 

In [None]:
# Zero-boundary 1-D datasets (hats and KdV) with the "weldconv" experiment config:
# train what is missing, then reload and plot.
for dname in get_zero_bc_1d_data():
    dconfig, wconfig = configure_model_and_data(dname, "weldconv")
    batch_trainer(dconfig, wconfig)
    batch_tester(dconfig, wconfig)

In [None]:
# 2-D Burgers ("b2width") dataset with the "weldconv2d" experiment config:
# train what is missing, then reload and plot.
dconfig, wconfig = configure_model_and_data("b2width", "weldconv2d")
batch_trainer(dconfig, wconfig)
batch_tester(dconfig, wconfig)

In [None]:
# 2-D hats ("t2scale") dataset with the "weldconv2d_deeper" experiment config.
# NOTE(review): "weldconv2d_deeper" must be a key of WCONFIGS; otherwise
# configure_model_and_data raises KeyError before any training starts.
dconfig, wconfig = configure_model_and_data("t2scale", "weldconv2d_deeper")
batch_trainer(dconfig, wconfig)
batch_tester(dconfig, wconfig)