## Training DEnKF
DEnKF contains four sub-modules: a state transition model, an observation model, an observation noise model, and a sensor model. The entire framework is trained end to end with a mean squared error (MSE) loss between the ground-truth state $\hat{{\bf x}}_{t|t}$ and the estimated state ${\bf \bar{x}}_{t|t}$ at every timestep. We also supervise two intermediate modules through the losses $\mathcal{L}_{f_{\pmb {\theta}}}$ and $\mathcal{L}_{s_{\pmb {\xi}}}$. Given the ground truth at time $t$, the gradient of the MSE between $\hat{{\bf x}}_{t|t}$ and the output of the state transition model, ${\bf \bar{x}}_{t|t-1}$, is applied to $f_{\pmb {\theta}}$. Likewise, the gradient of the MSE between the ground-truth observation $\hat{{\bf y}}_t$ and the output of the stochastic sensor model $\tilde{{\bf y}}_t$ is applied to $s_{\pmb {\xi}}$:
\begin{align}
\mathcal{L}_{f_{\pmb {\theta}}} = \| {\bf \bar{x}}_{t|t-1} - \hat{{\bf x}}_{t|t}\|_2^2, \qquad
\mathcal{L}_{s_{\pmb {\xi}}} = \| \tilde{{\bf y}}_t - \hat{{\bf y}}_t\|_2^2.
\end{align}
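The losses above are written as squared $\ell_2$ norms, whereas PyTorch's `nn.MSELoss` (used in the training code below) averages over all elements by default (`reduction='mean'`). For a single state vector the two therefore differ by a factor equal to the state dimension. A small NumPy sketch of both quantities for one timestep, using made-up 4-D vectors:

```python
import numpy as np

def squared_l2(pred, target):
    """Squared L2 norm ||pred - target||_2^2, as in the loss equations."""
    diff = np.asarray(pred, dtype=float) - np.asarray(target, dtype=float)
    return float(np.sum(diff ** 2))

def mse(pred, target):
    """Mean squared error, matching the default reduction='mean' of nn.MSELoss."""
    diff = np.asarray(pred, dtype=float) - np.asarray(target, dtype=float)
    return float(np.mean(diff ** 2))

# Illustrative 4-D states (dim_x = 4 in the car tracking example); values are made up
x_pred = np.array([1.0, 2.0, 0.5, -1.0])  # e.g. state transition output
x_gt = np.array([1.5, 2.0, 0.0, -1.0])    # ground-truth state

print(squared_l2(x_pred, x_gt))  # 0.5
print(mse(x_pred, x_gt))         # 0.125
```

Because the factor is constant for a given loss term, it only rescales that term's gradient, which the learning rate can absorb.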
    
All models in the experiments were trained for 50 epochs with a batch size of 64 and a learning rate of $\eta = 10^{-5}$. The model with the best performance on a validation set was selected for testing. The DEnKF used an ensemble of 32 members.
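With 32 ensemble members, the filter carries a set of state samples rather than a single state. Assuming, as in a standard ensemble Kalman filter, that the point estimate ${\bf \bar{x}}_{t|t}$ is the ensemble mean, and that the ensemble is stored as an array of shape `(batch, num_ensemble, dim_x)` (an assumed layout, not necessarily the one used inside the `DEnKF` class), the estimate and a per-dimension spread can be computed as follows:

```python
import numpy as np

def ensemble_statistics(ensemble):
    """Return the ensemble mean and sample standard deviation.

    ensemble: array of shape (batch, num_ensemble, dim_x)  # assumed layout
    """
    ensemble = np.asarray(ensemble, dtype=float)
    mean = ensemble.mean(axis=1)        # (batch, dim_x): point estimate
    std = ensemble.std(axis=1, ddof=1)  # (batch, dim_x): spread across members
    return mean, std

rng = np.random.default_rng(0)
batch, num_ensemble, dim_x = 64, 32, 4
ensemble = rng.normal(size=(batch, num_ensemble, dim_x))
mean, std = ensemble_statistics(ensemble)
print(mean.shape, std.shape)  # (64, 4) (64, 4)
```

The spread is not used by the training loss, but it is a convenient diagnostic of how uncertain the filter is.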

In this tutorial, we walk through the basic training process of DEnKF using the `car tracking` example.

### 1. Set training parameters
We begin by setting up the training parameters: the dimensionality of the state and observation, the batch size, the ensemble size, and the model mode. The training process is then implemented as an `Engine` class, as shown below.

```python
import os
import pickle
import random
import time
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from dataset import CarDataset
from model import DEnKF
from optimizer import build_optimizer
from optimizer import build_lr_scheduler

warnings.filterwarnings('ignore')


class Engine:
    def __init__(self, args, logger):
        # `args` holds the parsed .yaml configuration; `logger` records training progress
        self.args = args
        self.logger = logger
        self.batch_size = 64
        self.dim_x = 4  # state dimension
        self.dim_z = 4  # observation dimension
        self.num_ensemble = 32
        self.global_step = 0
        self.mode = 'train'
        self.dataset = CarDataset(self.args, self.mode)
        self.model = DEnKF(self.num_ensemble, self.dim_x, self.dim_z)
        # Check model type
        if not isinstance(self.model, nn.Module):
            raise TypeError("model must be an instance of nn.Module")
        # Use the GPU when one is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
```
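The `Engine` reads its settings through nested attribute access such as `self.args.train.learning_rate`. As a minimal sketch, assuming the `.yaml` file has already been parsed into a nested dictionary, such an object could be built with `types.SimpleNamespace`. The helper `to_namespace` is hypothetical and not part of the repository; the values are taken from the configuration printed in the training log.

```python
from types import SimpleNamespace

def to_namespace(cfg):
    """Recursively convert nested dicts into SimpleNamespace objects (hypothetical helper)."""
    if isinstance(cfg, dict):
        return SimpleNamespace(**{key: to_namespace(value) for key, value in cfg.items()})
    if isinstance(cfg, list):
        return [to_namespace(item) for item in cfg]
    return cfg

# A subset of the car tracking settings, written as a plain dict
cfg = {
    "mode": {"mode": "train", "do_online_eval": True},
    "optim": {"optim": "adamw", "lr_scheduler": "polynomial_decay"},
    "train": {"batch_size": 128, "learning_rate": 0.0001, "log_freq": 100, "eval_freq": 5},
}
args = to_namespace(cfg)
print(args.train.learning_rate)  # 0.0001
print(args.mode.mode)            # train
```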

Next, we define the training function itself, which builds the data loader, the optimizer, and the learning-rate scheduler.

```python
# define loss function, dataloader, optimizer for training the model
def train(self):
    mse_criterion = nn.MSELoss()
    dataloader = torch.utils.data.DataLoader(
        self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=1
    )
    pytorch_total_params = sum(
        p.numel() for p in self.model.parameters() if p.requires_grad
    )
    print("Total number of parameters: ", pytorch_total_params)

    # Create optimizer
    optimizer_ = build_optimizer(
        [self.model],
        self.args.network.name,
        self.args.optim.optim,
        self.args.train.learning_rate,
        self.args.train.weight_decay,
        self.args.train.adam_eps,
    )

    # Epoch calculations (training resumes from self.global_step if it is non-zero)
    steps_per_epoch = len(dataloader)
    num_total_steps = self.args.train.num_epochs * steps_per_epoch
    epoch = self.global_step // steps_per_epoch
    duration = 0

    # Create LR scheduler (only in training mode; otherwise no scheduler is used)
    scheduler = None
    if self.args.mode.mode == "train":
        scheduler = build_lr_scheduler(
            optimizer_,
            self.args.optim.lr_scheduler,
            self.args.train.learning_rate,
            num_total_steps,
            self.args.train.end_learning_rate,
        )
```
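The configuration selects `polynomial_decay` as the learning-rate schedule. The internals of `build_lr_scheduler` are not shown here, so the following is only a sketch of a common form of polynomial decay, $\eta_k = (\eta_0 - \eta_{\text{end}})(1 - k/K)^p + \eta_{\text{end}}$, where $k$ is the global step and $K$ the total number of steps. The `power` parameter and the treatment of `end_learning_rate` (the config uses a `-1.0` placeholder) are assumptions and may differ from the repository.

```python
def polynomial_decay_lr(step, base_lr, total_steps, end_lr=0.0, power=1.0):
    """Polynomial learning-rate decay from base_lr to end_lr over total_steps (sketch)."""
    step = min(step, total_steps)  # hold at end_lr once the schedule has finished
    fraction_left = 1.0 - step / total_steps
    return (base_lr - end_lr) * fraction_left ** power + end_lr

# Illustrative values: 50 epochs x 464 steps per epoch
base_lr = 1e-4
total_steps = 50 * 464
for step in (0, total_steps // 2, total_steps):
    print(step, polynomial_decay_lr(step, base_lr, total_steps))
```

Because the function depends only on the global step, it can be evaluated at any point, which matches how the training loop passes `self.global_step` to the scheduler.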

### 2. Training with curricula
Within the main training loop, it is useful to define training curricula. Our framework is modular, so individual components can be trained and used independently. However, different learning tasks may require different curricula. For instance, complex visual tasks may need a longer pretraining period for the sensor model before it can be integrated into end-to-end learning. Consequently, there is currently no universal curriculum that guarantees optimal performance of all sub-modules in every scenario. In the `car tracking` example, we first pretrain the sensor model for 10 epochs and then train the whole filter end to end.
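Stripped of the PyTorch details, the curriculum is a rule that picks which losses to optimize at a given epoch. A framework-free sketch, assuming the 10-epoch sensor pretraining described above; the function name `curriculum_loss` and the dictionary keys are illustrative and not taken from the repository:

```python
def curriculum_loss(epoch, losses, pretrain_epochs=10):
    """Combine per-module losses according to a two-stage curriculum.

    losses: dict with keys 'final', 'transition', 'sensor', 'observation'
    (hypothetical names for loss_1 ... loss_4 in the training loop).
    """
    if epoch < pretrain_epochs:
        # Stage 1: supervise only the sensor model
        return losses["sensor"]
    # Stage 2: end-to-end training with all supervision signals
    return losses["final"] + losses["transition"] + losses["sensor"] + losses["observation"]

losses = {"final": 0.4, "transition": 0.3, "sensor": 0.2, "observation": 0.1}
print(curriculum_loss(3, losses))   # sensor loss only
print(curriculum_loss(12, losses))  # sum of all four losses
```

Since the function only uses `+`, it works unchanged on PyTorch loss tensors.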

```python
def train(self):
    # Continuation of train(): `dataloader`, `optimizer_`, `scheduler`, `mse_criterion`,
    # `epoch`, and `steps_per_epoch` are created in the first part of train() shown above.
    ####################################################################################################
    # MAIN TRAINING LOOP
    ####################################################################################################
    while epoch < self.args.train.num_epochs:
        step = 0
        for data in dataloader:
            # collect data from data loader
            data = [item.to(self.device) for item in data]
            state_ensemble = data[1]
            state_pre = data[0]
            obs = data[3]
            state_gt = data[2]

            # reset gradients
            optimizer_.zero_grad()

            # forward pass
            input_state = (state_ensemble, state_pre)
            obs_action = obs
            output = self.model(obs_action, input_state)

            final_est = output[1]  # -> final estimation
            inter_est = output[2]  # -> state transition output
            obs_est = output[3]  # -> learned observation (sensor model output)
            hx = output[5]  # -> observation model output

            # calculate losses; the sensor and observation-model outputs are compared
            # directly with the ground-truth state, which works here because dim_z == dim_x
            loss_1 = mse_criterion(final_est, state_gt)  # end-to-end loss
            loss_2 = mse_criterion(inter_est, state_gt)  # state transition loss
            loss_3 = mse_criterion(obs_est, state_gt)  # sensor model loss
            loss_4 = mse_criterion(hx, state_gt)  # observation model loss

            # training curriculum: pretrain the sensor model for the first 10 epochs (0-9)
            if epoch < 10:
                final_loss = loss_3
            else:
                final_loss = loss_1 + loss_2 + loss_3 + loss_4

            # back prop
            final_loss.backward()
            optimizer_.step()
            current_lr = optimizer_.param_groups[0]["lr"]

            # verbose
            if self.global_step % self.args.train.log_freq == 0:
                string = "[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}"
                self.logger.info(
                    string.format(
                        epoch,
                        step,
                        steps_per_epoch,
                        self.global_step,
                        current_lr,
                        final_loss.item(),
                    )
                )
                if np.isnan(final_loss.item()):
                    self.logger.warning("NaN in loss occurred. Aborting training.")
                    return -1

            step += 1
            self.global_step += 1
            if scheduler is not None:
                scheduler.step(self.global_step)

        # Save a checkpoint based on the chosen save frequency
        if self.global_step != 0 and (epoch + 1) % self.args.train.save_freq == 0:
            checkpoint = {
                "global_step": self.global_step,
                "model": self.model.state_dict(),
                "optimizer": optimizer_.state_dict(),
            }
            save_dir = os.path.join(
                self.args.train.log_directory, self.args.train.model_name
            )
            os.makedirs(save_dir, exist_ok=True)
            torch.save(
                checkpoint, os.path.join(save_dir, "model-{}".format(self.global_step))
            )

        # online evaluation
        if (
            self.args.mode.do_online_eval
            and self.global_step != 0
            and epoch + 1 >= 10
            and (epoch + 1) % self.args.train.eval_freq == 0
        ):
            time.sleep(0.1)
            self.model.eval()
            self.test()
            self.model.train()
        # Update epoch
        epoch += 1
```
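The online-evaluation condition at the end of each epoch is easy to misread because `epoch` is zero-based. The sketch below restates that condition as a plain function (the name `eval_epochs` is hypothetical, and `global_step != 0` is assumed to hold) and lists the epochs after which evaluation runs for `eval_freq: 5` from the logged configuration and 50 training epochs:

```python
def eval_epochs(num_epochs, eval_freq, min_epochs=10):
    """Zero-based epochs after which online evaluation runs in the training loop."""
    return [
        epoch
        for epoch in range(num_epochs)
        if epoch + 1 >= min_epochs and (epoch + 1) % eval_freq == 0
    ]

print(eval_epochs(num_epochs=50, eval_freq=5))
```

In other words, evaluation first runs after the 10th epoch (zero-based epoch 9) and then after every `eval_freq` epochs.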

### 3. Test
Similar to the training script, the `test()` function uses a test dataloader to feed observations into the trained filter one timestep at a time. The filter is initialized with the state and ensemble of the first sample and then runs recursively: at each subsequent step, its own previous estimate is fed back as input, so only the observations come from the dataset. The outputs are collected in a dictionary and saved as a pickle file.

```python
def test(self):
    test_dataset = CarDataset(self.args, "test")
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, num_workers=1
    )
    step = 0
    data_out = {}
    data_save = []
    ensemble_save = []
    gt_save = []
    obs_save = []
    for data in test_dataloader:
        data = [item.to(self.device) for item in data]
        state_ensemble = data[1]
        state_pre = data[0]
        obs = data[3]
        state_gt = data[2]

        with torch.no_grad():
            if step == 0:
                # initialize the filter from the first sample
                ensemble = state_ensemble
                state = state_pre
            # afterwards, `ensemble` and `state` hold the model's outputs from the
            # previous step, so the filter runs recursively on the observations
            input_state = (ensemble, state)
            obs_action = obs
            output = self.model(obs_action, input_state)

            ensemble = output[0]  # -> ensemble estimation
            state = output[1]  # -> final estimation
            obs_p = output[3]  # -> learned observation

            # store NumPy copies on the CPU for saving
            data_save.append(state.cpu().numpy())
            ensemble_save.append(ensemble.cpu().numpy())
            gt_save.append(state_gt.cpu().numpy())
            obs_save.append(obs_p.cpu().numpy())
            step = step + 1

    data_out["state"] = data_save
    data_out["ensemble"] = ensemble_save
    data_out["gt"] = gt_save
    data_out["observation"] = obs_save

    save_path = os.path.join(
        self.args.train.eval_summary_directory,
        self.args.train.model_name,
        "eval-result-{}.pkl".format(self.global_step),
    )
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    with open(save_path, "wb") as f:
        pickle.dump(data_out, f)
```
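Once `test()` has written its results, the stored dictionary can be used for evaluation. A minimal sketch computing the root-mean-square error (RMSE) per state dimension between the filter estimates (`"state"`) and the ground truth (`"gt"`). It assumes each saved array holds one timestep whose last axis has size `dim_x`; `rmse_per_dimension` and `load_results` are hypothetical helpers, not part of the repository.

```python
import pickle

import numpy as np

def load_results(path):
    """Load the dictionary saved by test()."""
    with open(path, "rb") as f:
        return pickle.load(f)

def rmse_per_dimension(data_out):
    """RMSE between estimated and ground-truth states, per state dimension."""
    dim_x = np.asarray(data_out["state"][0]).shape[-1]
    est = np.concatenate([np.reshape(s, (-1, dim_x)) for s in data_out["state"]])
    gt = np.concatenate([np.reshape(g, (-1, dim_x)) for g in data_out["gt"]])
    return np.sqrt(np.mean((est - gt) ** 2, axis=0))

# Synthetic results shaped like batch_size=1 outputs: one (1, dim_x) array per step
data_out = {
    "state": [np.array([[1.0, 0.0, 0.0, 0.0]]), np.array([[0.0, 2.0, 0.0, 0.0]])],
    "gt": [np.array([[0.0, 0.0, 0.0, 0.0]]), np.array([[0.0, 0.0, 0.0, 0.0]])],
}
print(rmse_per_dimension(data_out))
```

For a real run, `load_results(...)` would be called on one of the `eval-result-*.pkl` files instead of building `data_out` by hand.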

### 4. Putting everything together
In our repository, a `.yaml` file initializes the parameters and manages all training and testing settings. The `Engine` class is implemented in `/pyTorch/engine.py`, and the configuration file is located at `/pyTorch/config/car_tracking.yaml`. The following shows how to launch training from the command line, together with the corresponding logs.

```python
os.system('python train.py --config ./config/car_tracking.yaml')
```

```text
2023-07-10 14:48:02,786 MainThread INFO DEnKF - mode:
  dist_backend: nccl
  dist_url: tcp://127.0.0.1:2345
  do_online_eval: True
  gpu: None
  mode: train
  multiprocessing_distributed: False
  num_threads: 1
  parameter_path: 
  rank: 0
  world_size: 1
network:
  activation_function: ELU
  encoder: resnet50
  name: DEnKF
optim:
  lr_scheduler: polynomial_decay
  optim: adamw
test:
  checkpoint_path: 
  data_path: ./dataset/car_dataset_test.pkl
  dataset: car_dataset
  dim_a: 
  dim_x: 4
  dim_z: 4
  eigen_crop: False
  garg_crop: False
  input_height: None
  input_size: 
  input_width: None
  model_name: DEnKF
  num_ensemble: 32
train:
  adam_eps: 0.001
  batch_size: 128
  checkpoint_path: 
  data_path: ./dataset/car_dataset_train.pkl
  dataset: car_dataset
  dim_a: 
  dim_x: 4
  dim_z: 4
  end_learning_rate: -1.0
  eval_freq: 5
  eval_summary_directory: ./experiments/
  input_size: 
  learning_rate: 0.0001
  log_directory: ./experiments
  log_freq: 100
  loss: mse
  loss_weights: [

Total number of parameters:  2072664

2023-07-10 14:48:06,324 MainThread INFO DEnKF - [epoch][s/s_per_e/gs]: [0][0/464/0], lr: 0.000100000000, loss: 0.486620664597
2023-07-10 14:49:20,899 MainThread INFO DEnKF - [epoch][s/s_per_e/gs]: [0][100/464/100], lr: 0.000099417894, loss: 0.292591094971
```
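To plot training curves, the step, learning rate, and loss can be recovered from log lines in the format printed by the training loop. A sketch using a regular expression written against that format string; `parse_log_line` is a hypothetical helper, not part of the repository:

```python
import re

# Matches the "[epoch][s/s_per_e/gs]: ..." format string used in the training loop
LOG_PATTERN = re.compile(
    r"\[epoch\]\[s/s_per_e/gs\]: \[(\d+)\]\[(\d+)/(\d+)/(\d+)\], "
    r"lr: ([0-9.eE+-]+), loss: ([0-9.eE+-]+|nan)"
)

def parse_log_line(line):
    """Extract epoch, step, steps per epoch, global step, lr and loss; None if no match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    epoch, step, steps_per_epoch, global_step = (int(g) for g in match.groups()[:4])
    return {
        "epoch": epoch,
        "step": step,
        "steps_per_epoch": steps_per_epoch,
        "global_step": global_step,
        "lr": float(match.group(5)),
        "loss": float(match.group(6)),
    }

line = ("2023-07-10 14:49:20,899 MainThread INFO DEnKF - [epoch][s/s_per_e/gs]: "
        "[0][100/464/100], lr: 0.000099417894, loss: 0.292591094971")
print(parse_log_line(line))
```

Lines that do not match, such as the configuration dump or the parameter count, return `None` and can simply be skipped.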
