In [None]:
import os
import time

import pandas as pd
import torch
import torch.multiprocessing as mp

from burgers.approximator import BurgersApproximator
from GNRK.experiment import run
from GNRK.hyperparameter import get_hp
from GNRK.path import DATA_DIR
from heat.approximator import HeatApproximator
from kuramoto.approximator import KuramotoApproximator
from rossler.approximator import RosslerApproximator

In [None]:
# Hyperparameters are handed to ``get_hp`` as a command-line style token list.
# The tokens are grouped by the part of the experiment they configure and then
# flattened in section order, so ``get_hp`` sees one contiguous argument list.
# Commented-out flags are optional toggles (presumably falling back to the
# parser's defaults when absent — confirm in GNRK.hyperparameter).
hp_arg_sections = {
    # Which governing equation to learn and which dataset files to load.
    "data": [
        "--equation=burgers",
        "--dataset=burgers_dataset1",
    ],
    # Integrator choice and approximator network sizes / regularization.
    "nn": [
        "--rk=RK4",
        "--approximator_state_embedding", "8",
        # "--approximator_node_embedding", "16",
        "--approximator_edge_embedding", "8",
        "--approximator_glob_embedding", "8",
        "--approximator_edge_hidden=32",
        "--approximator_node_hidden=32",
        "--approximator_activation=gelu",
        "--approximator_dropout=0.0",
        "--approximator_bn_momentum=-1.0",
    ],
    # Learning-rate scheduler settings.
    "scheduler": [
        "--scheduler_name=step",
        "--scheduler_lr=0.0001",
        "--scheduler_lr_max=0.004",
        "--scheduler_lr_max_mult=0.5",
        "--scheduler_period=20",
        "--scheduler_period_mult=1.5",
        "--scheduler_warmup=0",
    ],
    # Early stopping: every flag is currently disabled.
    "early_stop": [
        # "--earlystop_patience=60",
        # "--earlystop_delta=0.0",
    ],
    # Training loop configuration. Listing more than one device id makes the
    # training cell spawn one worker process per device.
    "train": [
        "--weight_decay=0.0",
        "--device", "0", "1", "2", "3",
        # "--seed=0",
        "--port=3184",
        "--epochs=2",
        "--batch_size=64",
        # "--rollout_batch_size=256",
        "--tqdm",
        # "--wandb",
        "--amp",
    ],
}
hp = get_hp([token for section in hp_arg_sections.values() for token in section])


# Read data

In [None]:
# Load the pickled train and validation splits of the configured dataset
# (files named "<dataset>_<split>.pkl" under DATA_DIR) and report the read time.
# NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files.
start = time.perf_counter()
train_df, val_df = (
    pd.read_pickle(DATA_DIR / f"{hp.dataset}_{split}.pkl")
    for split in ("train", "val")
)
print(f"Reading data took {time.perf_counter()-start} seconds")


# Create governing equation approximator

In [None]:
# Dispatch table from the ``--equation`` value to the approximator class that
# models it. Every class is constructed the same way, from the approximator
# sub-section of the hyperparameters.
approximator_classes = {
    "burgers": BurgersApproximator,
    "heat": HeatApproximator,
    "kuramoto": KuramotoApproximator,
    "rossler": RosslerApproximator,
}
if hp.equation not in approximator_classes:
    raise NotImplementedError(f"No such equation {hp.equation}")
approximator = approximator_classes[hp.equation].from_hp(hp.approximator)


# Training

In [None]:
# Train the approximator, timing the whole run.
# With exactly one device, ``run`` executes in this process as rank 0.
# Otherwise one worker process is launched per entry of ``hp.device``.
# torch.multiprocessing.spawn calls ``run(rank, *run_args)`` in each worker.
start = time.perf_counter()
save = False  # forwarded to ``run``; presumably disables checkpoint saving — confirm
run_args = (hp, approximator, train_df, val_df, save)
n_devices = len(hp.device)

if n_devices != 1:
    # Rendezvous address for the worker processes (presumably read by the
    # torch.distributed env:// initialization inside ``run`` — confirm).
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT=f"{hp.port}")
    mp.spawn(run, args=run_args, nprocs=n_devices, join=True)  # type:ignore
else:
    run(0, *run_args)

print(f"Training took {time.perf_counter()-start} seconds")