In [1]:
# Main package imports
import os
import time
import torch
import functools
import torch.nn as nn
import torch.nn.functional as F
import math
from math import *
import numpy as np
from copy import deepcopy
import forge
import forge.experiment_tools as fet
from forge import flags
from tqdm import tqdm

# Specific utility imports
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from oil.utils.utils import FixedNumpySeed, islice
from oil.datasetup.datasets import IndexedDataset
from sklearn.model_selection import train_test_split
from types import SimpleNamespace
from einops.einops import rearrange, reduce
from attrdict import AttrDict
from collections import OrderedDict
from torchdiffeq import odeint

# LieConv dependencies
from lie_conv.datasets import SpringDynamics
from lie_conv.lieGroups import SE3
from lie_conv.lieConv import Swish
from lie_conv.utils import Pass, Expression
from lie_conv.masked_batchnorm import MaskBatchNormNd
from lie_conv.dynamicsTrainer import Partial
from lie_conv.hamiltonian import SpringV, SpringH, HamiltonianDynamics, KeplerV, KeplerH
from lie_conv.dynamicsTrainer import HNet
from lie_conv.hamiltonian import HamiltonianDynamics
from lie_conv.lieGroups import T, SE2, SE2_canonical, SO2

# Import auxiliary functions 
from lie_transformer.eqv_transformer.train_tools import (
    log_tensorboard,
    parse_reports,
    parse_reports_cpu,
    print_reports,
    load_checkpoint,
    save_checkpoint,
    nested_to,
    param_count,
)
from lie_transformer.eqv_transformer.multihead_neural import (
    MultiheadWeightNet,
    MultiheadMLP,
    LinearBNact,
    MLP,
)
from lie_transformer.eqv_transformer.kernels import (
    AttentionKernel,
    SumKernel,
    DotProductKernel,
    RelativePositionKernel,
)

# Model, data, training configuration

LieTransformer uses the "forge" library to simplify training, evaluation, and inference. The following flags allow the user to specify various settings either in the script (like we are doing here) or via the command line. 

In [2]:
# Guard for the flag-definition cell below: that cell only defines its flags
# while this is False, and sets it to True once it has done so.
# NOTE(review): re-running *this* cell resets the guard, so the next run of the
# flag cell will attempt to define every flag again — confirm forge tolerates that.
flags_evaluated = False

In [3]:
# Training configuration & flags

# Define every configuration flag used by this notebook through `forge.flags`.
# The whole block is skipped when `flags_evaluated` is already True, and the
# variable is set to True on the last line, so re-running this cell alone does
# not define the flags a second time.
if not flags_evaluated:
    # Where results are written.
    flags.DEFINE_string(
        "results_dir", "lie_transformer/checkpoints/", "Top directory for all experimental results."
    )
    
    # Configuration files to load
    flags.DEFINE_string(
        "data_config",
        "lie_transformer/configs/dynamics/spring_dynamics_data.py",
        "Path to a data config file.",
    )
    flags.DEFINE_string(
        "model_config",
        "lie_transformer/configs/dynamics/eqv_transformer_model.py",
        "Path to a model config file.",
    )
    # Job management
    flags.DEFINE_string(
        "run_name",
        "demo_run",
        "Name of this job and name of results folder.",
    )
    flags.DEFINE_boolean("resume", False, "Tries to resume a job if True.")
    
    # Logging
    flags.DEFINE_integer(
        "report_loss_every", 10, "Number of iterations between reporting minibatch loss."
    )
    flags.DEFINE_integer(
        "save_check_points",
        10,
        "frequency with which to save checkpoints, in number of epoches.",
    )
    flags.DEFINE_boolean("log_train_values", True, "Logs train values if True.")
    flags.DEFINE_integer(
        "total_evaluations",
        100,
        "Maximum number of evaluations on test and validation data during training.",
    )
    
    # Optimization
    flags.DEFINE_integer("train_epochs", 3, "Maximum number of training epochs.")
    flags.DEFINE_integer("batch_size", 100, "Mini-batch size.")
    flags.DEFINE_float("learning_rate", 1e-3, "Adam learning rate.")
    flags.DEFINE_float("beta1", 0.9, "Adam Beta 1 parameter")
    flags.DEFINE_float("beta2", 0.999, "Adam Beta 2 parameter")
    flags.DEFINE_string("lr_schedule", "cosine_annealing", "Learning rate schedule.")
    
    # GPU device
    flags.DEFINE_integer("device", 0, "GPU to use.")
    
    # Debug mode tracks more stuff
    flags.DEFINE_boolean("debug", False, "Track and show on tensorboard more metrics.")
    flags.DEFINE_boolean(
        "save_test_predictions",
        True,
        "Makes and saves test predictions on one or more test sets (e.g. 5-step and 100-step predictions) at the end of training.",
    )
    # NOTE(review): the flag name and its default (True) suggest that True
    # *enables* validation/test logging, yet the help text says "Turns off".
    # Confirm the intended meaning where the flag is consumed.
    flags.DEFINE_boolean(
        "log_val_test", True, "Turns off computation of validation and test errors."
    )
    
    # NOTE(review): the model-construction cell in this notebook passes a
    # hard-coded `T(2)` group and does not read this flag — confirm whether the
    # flag is meant to drive the group choice.
    flags.DEFINE_string("group", "T(2)", "Group to be invariant to.")
    
    
    # Model hyper-parameters (read when the network is built).
    flags.DEFINE_integer("dim_hidden", 160, "Dimension of features to use in each layer")
    flags.DEFINE_string(
        "activation_function", "swish", "Activation function to use in the network"
    )
    # NOTE(review): the help text mentions "sum pooling" as the alternative, but
    # the non-mean branch of `GlobalPool` in this notebook is commented as max
    # pooling — confirm which description is right.
    flags.DEFINE_boolean(
        "mean_pooling",
        True,
        "Use mean pooling insteave of sum pooling in the invariant layer",
    )
    flags.DEFINE_integer("num_heads", 8, "Number of attention heads in each layer")
    flags.DEFINE_integer("kernel_dim", 16, "Hidden layer size to use in kernel MLPs")
    flags.DEFINE_integer("num_layers", 5, "Number of ResNet layers to use")
    flags.DEFINE_integer(
        "lift_samples",
        1,
        "Number of coset lift samples to use for non-trivial stabilisers.",
    )
    flags.DEFINE_integer("model_seed", 0, "Model rng seed")
    flags.DEFINE_string(
        "attention_fn", "dot_product", "How to form the attention weights from the 'logits'."
    )
    
    # Normalisation and architecture choices forwarded to the transformer.
    flags.DEFINE_string(
        "block_norm", "layer_pre", "Normalization to use around the attention blocks."
    )
    flags.DEFINE_string("output_norm", "none", "Normalization to use in final output MLP.")
    flags.DEFINE_string("kernel_norm", "none", "Normalization to use in kernel MLP.")
    flags.DEFINE_string("kernel_type", "mlp", "Attention kernel type.")
    flags.DEFINE_string("architecture", "model_1", "Overall model architecture.")
    flags.DEFINE_boolean(
        "model_with_dict",
        True,
        "Makes model output predictions in dictionary instead of directly."
    )
    
    
    # Mark the flags as defined so a re-run of this cell is a no-op.
    flags_evaluated = True

In [4]:
# auxiliary functions

def evaluate(model, loader, device):
    """Average the model's reported metrics over every batch in ``loader``.

    Parameters
    ----------
    model : callable
        Called once per batch; its return value must expose a ``reports``
        mapping of metric name -> tensor.
    loader : iterable
        Yields batches. It does not need to define ``__len__``: the average is
        taken over the number of batches actually iterated.
    device :
        Forwarded to ``nested_to`` together with ``torch.float32``.

    Returns
    -------
    dict
        Metric name -> CPU tensor holding the per-batch mean of that metric.

    Raises
    ------
    ValueError
        If ``loader`` yields no batches (there is nothing to average).
    """
    # NOTE(review): there is deliberately no ``torch.no_grad()`` here, matching
    # the original; presumably the dynamics model needs autograd inside its
    # forward pass — confirm before adding one.
    totals = None
    num_batches = 0

    for data in loader:
        data = nested_to(data, device, torch.float32)
        outputs = model(data)

        # Detach and move to CPU so the running sums do not keep graphs alive.
        batch_reports = {
            k: v.detach().clone().cpu() for k, v in outputs.reports.items()
        }

        if totals is None:
            totals = batch_reports
        else:
            for k, v in batch_reports.items():
                totals[k] += v

        num_batches += 1

    if totals is None:
        raise ValueError("Cannot evaluate on an empty loader: no batches were produced.")

    return {k: v / num_batches for k, v in totals.items()}

# Construct model

## Equivariant multihead attention & equivariant transformer

In [5]:
class EquivairantMultiheadAttention(nn.Module):
    """Multi-head attention whose un-normalised weights come from a kernel.

    ``forward`` consumes and returns a triple
    ``(pairwise_g, coset_functions, mask)``; only ``coset_functions`` is
    replaced. The kernel (chosen by ``kernel_type``) produces one weight per
    (query, neighbour, head); the weights are masked and normalised according
    to ``attention_fn`` and then used to mix the linearly projected features.

    Parameters
    ----------
    c_in, c_out : int
        Input / output feature sizes of the value projection.
    n_heads : int
        Number of attention heads (``c_out`` is split across heads).
    group :
        Object providing ``lie_dim``, ``q_dim`` and ``distance``.
    kernel_type : str
        One of ``"mlp"``, ``"relative_position"``, ``"dot_product_only"``,
        ``"location_only"``, or a 4-character digit code that configures an
        ``AttentionKernel`` (and overrides ``attention_fn``).
    kernel_dim, act, bn :
        Forwarded to the kernel networks.
    mc_samples : int
        If > 0, each query attends to a random subset of its nearest points;
        otherwise every point attends to every point.
    fill : float
        With ``mc_samples > 0``, the ball searched has
        ``int(mc_samples / fill)`` points.
    attention_fn : str
        ``"softmax"``, ``"dot_product"`` or ``"norm_exp"``.
    feature_embed_dim :
        Forwarded to ``AttentionKernel``.

    Raises
    ------
    NotImplementedError
        If ``attention_fn`` is not one of the supported names.
    ValueError
        If ``kernel_type`` is not recognised.
    """

    # Normalisations that ``forward`` knows how to apply.
    _SUPPORTED_ATTENTION_FNS = ("softmax", "dot_product", "norm_exp")

    def __init__(
        self,
        c_in,
        c_out,
        n_heads,
        group,
        kernel_type="mlp",
        kernel_dim=16,
        act="swish",
        bn=False,
        mc_samples=0,
        fill=1.0,
        attention_fn="softmax",
        feature_embed_dim=None,
    ):

        super().__init__()

        self.c_in = c_in
        self.c_out = c_out
        self.n_heads = n_heads
        self.group = group

        self.mc_samples = mc_samples
        self.fill = fill
        self.kernel_type = kernel_type

        if attention_fn not in self._SUPPORTED_ATTENTION_FNS:
            raise NotImplementedError(f"{attention_fn} not implemented.")
        self.attention_fn = attention_fn

        if len(kernel_type) == 4:
            # Four-digit code: each digit indexes one of the option lists below.
            # Note that digit 0 can select "none" as the normalisation, which
            # ``forward`` rejects with an explicit error.
            normalisation = ["none", "softmax", "dot_product"]
            self.attention_fn = normalisation[int(kernel_type[0])]

            location_feature_combination = ["none", "sum", "mlp", "multiply"]
            location_feature_combination = location_feature_combination[
                int(kernel_type[1])
            ]

            feature_featurisation = [
                "none",
                "dot_product",
                "linear_concat",
                "linear_concat_linear",
            ]
            feature_featurisation = feature_featurisation[int(kernel_type[2])]

            location_featurisation = ["none", "mlp", "none"]
            location_featurisation = location_featurisation[int(kernel_type[3])]

            self.kernel = AttentionKernel(
                c_in,
                group.lie_dim + 2 * group.q_dim,
                n_heads,
                feature_featurisation=feature_featurisation,
                location_featurisation=location_featurisation,
                location_feature_combination=location_feature_combination,
                hidden_dim=kernel_dim,
                feature_embed_dim=feature_embed_dim,
                activation=act,
            )

        elif kernel_type == "mlp":
            # Location MLP plus feature dot-product.
            self.kernel = SumKernel(
                MultiheadWeightNet(
                    group.lie_dim + 2 * group.q_dim,
                    1,
                    n_heads,
                    hid_dim=kernel_dim,
                    act=act,
                    bn=bn,
                ),
                DotProductKernel(c_in, c_in, c_in, n_heads=n_heads),
                n_heads,
            )
        elif kernel_type == "relative_position":
            self.kernel = RelativePositionKernel(
                c_in,
                c_in,
                group.lie_dim + 2 * group.q_dim,
                n_heads=n_heads,
                bias=True,
                lamda=1.0,
            )
        elif kernel_type == "dot_product_only":
            self.kernel = SumKernel(
                lambda x: [
                    None,
                    torch.zeros(x[1].shape[:-1], device=x[2].device).unsqueeze(-1),
                    None,
                ],  # NOTE(review): original author was unsure about the dims here — verify against SumKernel.
                DotProductKernel(c_in, c_in, c_in, n_heads=n_heads),
                n_heads,
            )
        elif kernel_type == "location_only":
            self.kernel = SumKernel(
                MultiheadWeightNet(
                    group.lie_dim + 2 * group.q_dim,
                    1,
                    n_heads,
                    hid_dim=kernel_dim,
                    act=act,
                    bn=bn,
                ),
                lambda x1, x2, x3: 0,
                n_heads,
            )
        else:
            raise ValueError(f"{kernel_type} is not a valid kernel type")

        self.input_linear = nn.Linear(c_in, c_out)
        self.output_linear = nn.Linear(c_out, c_out)

    @staticmethod
    def _fill_masked(weights, nbhd_mask, value):
        """Return ``weights`` with entries whose neighbour is masked out set to ``value``."""
        # (bs, n * ns, nbhd_size) -> (bs, n * ns, nbhd_size, 1). Constant along head dim
        filler = torch.tensor(
            value, dtype=weights.dtype, device=weights.device
        ) * torch.ones_like(weights)
        return torch.where(nbhd_mask.unsqueeze(-1), weights, filler)

    @staticmethod
    def _valid_neighbour_count(nbhd_mask):
        """Number of unmasked neighbours per query, at least 1 to keep division safe."""
        count = nbhd_mask.unsqueeze(-1).sum(-2, keepdim=True)
        return torch.clamp(count, min=1)

    def extract_neighbourhoods(self, input, query_indices=None):
        """Select, for every point, the points it attends to.

        Parameters
        ----------
        input : (pairwise_g, coset_functions, mask)
        query_indices :
            Must be ``None``; sub-sampling of queries is not implemented.

        Returns
        -------
        tuple
            ``(nbhd_pairwise_g, None, nbhd_mask, nbhd_idx, BS, NNS)`` where
            ``BS`` / ``NNS`` are batch / query index arrays expanded to the
            shape of ``nbhd_idx`` (handy for advanced indexing). The second
            element is always ``None``.

        Raises
        ------
        NotImplementedError
            If ``query_indices`` is given.
        """
        # TODO: Currently no down sampling in this step.

        pairwise_g, coset_functions, mask = input

        if query_indices is not None:
            raise NotImplementedError()

        n_points = coset_functions.shape[1]

        if self.mc_samples > 0:
            dists = self.group.distance(pairwise_g)
            # Push masked-out points far away so they are chosen last.
            dists = torch.where(
                mask[:, None, :].expand(*dists.shape),
                dists,
                1e8 * torch.ones_like(dists),
            )
            k = min(self.mc_samples, n_points)
            k_ball = min(int(self.mc_samples / self.fill), n_points)

            _, points_in_ball_indices = dists.topk(
                k=k_ball, dim=-1, largest=False, sorted=False
            )
            # Random subset of the ball; the same subset positions are used
            # for every query and batch element.
            ball_indices = torch.randperm(k_ball)[:k]

            nbhd_idx = points_in_ball_indices[:, :, ball_indices]

        else:
            # Dense attention: every point is a neighbour of every point.
            nbhd_idx = (
                torch.arange(n_points, device=coset_functions.device)
                .long()[None, None, :]
                .expand(pairwise_g.shape[:-1])
            )

        # Get batch index array
        BS = (
            torch.arange(coset_functions.shape[0], device=coset_functions.device)
            .long()[:, None, None]
            .expand(*nbhd_idx.shape)
        )
        # Get NNS indexes
        NNS = (
            torch.arange(n_points, device=coset_functions.device)
            .long()[None, :, None]
            .expand(*nbhd_idx.shape)
        )

        # (bs, n * ns, n * ns, g_dim) -> (bs, n * ns, nbhd_size, g_dim)
        nbhd_pairwise_g = pairwise_g[BS, NNS, nbhd_idx]
        # (bs, n * ns) -> (bs, n * ns, nbhd_size)
        nbhd_mask = mask[BS, nbhd_idx]

        return (
            nbhd_pairwise_g,
            None,
            nbhd_mask,
            nbhd_idx,
            BS,
            NNS,
        )  # TODO: last two are conveniences - is there an easier way to do this?

    def forward(self, input):
        """Apply attention and return ``(pairwise_g, new_coset_functions, mask)``.

        Raises
        ------
        NotImplementedError
            If ``self.attention_fn`` is not a supported normalisation (this can
            happen when a 4-character ``kernel_type`` selects ``"none"``).
        """
        # (bs, n * ns, n * ns, g_dim), (bs, n * ns, c_in), (bs, n * ns)
        pairwise_g, coset_functions, mask = input

        # Fail with a clear message instead of an UnboundLocalError further down.
        if self.attention_fn not in self._SUPPORTED_ATTENTION_FNS:
            raise NotImplementedError(
                f"Attention normalisation '{self.attention_fn}' is not implemented."
            )

        bs, n = coset_functions.shape[0], coset_functions.shape[1]

        (
            nbhd_pairwise_g,
            _unused_nbhd_features,
            nbhd_mask,
            nbhd_idx,
            BS,
            NNS,
        ) = self.extract_neighbourhoods(input)

        # -> (bs, n * ns, nbhd_size, h)
        presoftmax_weights = self.kernel(
            nbhd_pairwise_g, nbhd_mask, coset_functions, coset_functions, nbhd_idx
        )

        if self.attention_fn == "softmax":
            # Masked neighbours get a very negative logit, i.e. ~0 weight.
            presoftmax_weights = self._fill_masked(
                presoftmax_weights, nbhd_mask, -1e38
            )
            attention_weights = F.softmax(presoftmax_weights, dim=2)

        elif self.attention_fn == "norm_exp":
            presoftmax_weights = self._fill_masked(
                presoftmax_weights, nbhd_mask, -1e38
            )
            # exp() of the logits, divided by the number of valid neighbours.
            attention_weights = presoftmax_weights.exp()
            attention_weights = attention_weights / self._valid_neighbour_count(
                nbhd_mask
            )

        else:
            # "dot_product" — from the non-local attention paper: raw weights,
            # zeroed where masked, divided by the number of valid neighbours.
            attention_weights = self._fill_masked(presoftmax_weights, nbhd_mask, 0.0)
            attention_weights = attention_weights / self._valid_neighbour_count(
                nbhd_mask
            )

        # Pass the inputs through the value linear layer
        coset_functions = self.input_linear(coset_functions)

        # Scatter neighbourhood weights back into a dense (bs, n, n, h) array;
        # pairs that were not selected as neighbours keep weight 0.
        attention_weights_expanded = torch.zeros(
            (bs, n, n, self.n_heads),
            dtype=attention_weights.dtype,
            device=attention_weights.device,
        )
        attention_weights_expanded[BS, NNS, nbhd_idx] = attention_weights
        attention_weights_expanded = rearrange(
            attention_weights_expanded, "b n m h -> b h n m"
        )

        coset_functions = rearrange(
            coset_functions, "b m (h d) -> b h m d", h=self.n_heads
        )

        coset_functions = attention_weights_expanded.matmul(coset_functions)
        coset_functions = rearrange(coset_functions, "b h n d -> b n (h d)")

        coset_functions = self.output_linear(coset_functions)

        # ( (bs, n * ns, n * ns, g_dim), (bs, n * ns, c_out), (bs, n * ns) )
        return (pairwise_g, coset_functions, mask)


class EquivariantTransformerBlock(nn.Module):
    """Residual attention sub-block followed by a residual MLP sub-block.

    The block operates on a triple whose element 1 holds the features; the
    other two elements are passed through unchanged. ``block_norm`` selects
    where normalisation is applied around each sub-block:

    * ``"none"``        - no normalisation
    * ``"layer_pre"``   - ``LayerNorm`` on the features before the sub-block
    * ``"layer_post"``  - ``LayerNorm`` on the sub-block output
    * ``"batch_pre"``   - ``MaskBatchNormNd`` on the triple before the sub-block
    * ``"batch_post"``  - ``MaskBatchNormNd`` on the triple after the sub-block

    Raises
    ------
    ValueError
        If ``block_norm`` is not one of the names above.
    """

    def __init__(
        self,
        dim,
        n_heads,
        group,
        block_norm="layer_pre",
        kernel_norm="none",
        kernel_type="mlp",
        kernel_dim=16,
        kernel_act="swish",
        hidden_dim_factor=1,
        mc_samples=0,
        fill=1.0,
        attention_fn="softmax",
        feature_embed_dim=None,
    ):
        super().__init__()
        self.ema = EquivairantMultiheadAttention(
            dim,
            dim,
            n_heads,
            group,
            kernel_type=kernel_type,
            kernel_dim=kernel_dim,
            act=kernel_act,
            bn=kernel_norm == "batch",
            mc_samples=mc_samples,
            fill=fill,
            attention_fn=attention_fn,
            feature_embed_dim=feature_embed_dim,
        )

        self.mlp = MLP(dim, dim, dim, 2, kernel_act, kernel_norm == "batch")

        # Each *_function maps the triple to the new feature tensor
        # (residual input + sub-block output).
        if block_norm == "none":
            self.attention_function = lambda inpt: inpt[1] + self.ema(inpt)[1]
            self.mlp_function = lambda inpt: inpt[1] + self.mlp(inpt)[1]
        elif block_norm == "layer_pre":
            self.ln_ema = nn.LayerNorm(dim)
            self.ln_mlp = nn.LayerNorm(dim)

            self.attention_function = (
                lambda inpt: inpt[1]
                + self.ema((inpt[0], self.ln_ema(inpt[1]), inpt[2]))[1]
            )
            self.mlp_function = (
                lambda inpt: inpt[1]
                + self.mlp((inpt[0], self.ln_mlp(inpt[1]), inpt[2]))[1]
            )
        elif block_norm == "layer_post":
            self.ln_ema = nn.LayerNorm(dim)
            self.ln_mlp = nn.LayerNorm(dim)

            self.attention_function = lambda inpt: inpt[1] + self.ln_ema(
                self.ema(inpt)[1]
            )
            self.mlp_function = lambda inpt: inpt[1] + self.ln_mlp(self.mlp(inpt)[1])
        elif block_norm == "batch_pre":
            self.bn_ema = MaskBatchNormNd(dim)
            self.bn_mlp = MaskBatchNormNd(dim)

            self.attention_function = (
                lambda inpt: inpt[1] + self.ema(self.bn_ema(inpt))[1]
            )
            self.mlp_function = lambda inpt: inpt[1] + self.mlp(self.bn_mlp(inpt))[1]
        elif block_norm == "batch_post":
            self.bn_ema = MaskBatchNormNd(dim)
            self.bn_mlp = MaskBatchNormNd(dim)

            self.attention_function = (
                lambda inpt: inpt[1] + self.bn_ema(self.ema(inpt))[1]
            )
            self.mlp_function = lambda inpt: inpt[1] + self.bn_mlp(self.mlp(inpt))[1]
        else:
            raise ValueError(f"{block_norm} is invalid block norm type.")

    def forward(self, inpt):
        """Run the attention and MLP sub-blocks; return the updated triple as a list.

        ``inpt`` may be any sequence (list or tuple). It is copied first, so
        the caller's object is never modified.
        """
        out = list(inpt)
        out[1] = self.attention_function(out)
        out[1] = self.mlp_function(out)

        return out


class GlobalPool(nn.Module):
    """Reduce per-point values over the point axis, skipping masked-out points.

    With ``mean=True`` the valid values are averaged; otherwise their
    element-wise maximum is taken.
    """

    def __init__(self, mean=False):
        super().__init__()
        self.mean = mean

    def forward(self, x):
        """x [xyz (bs,n,d), vals (bs,n,c), mask (bs,n)] -> pooled (bs,c).

        A 2-element ``x`` carries no mask; its second element is simply
        averaged over axis 1.
        """
        if len(x) == 2:
            return x[1].mean(1)

        _, values, valid = x
        keep = valid.unsqueeze(-1)

        if not self.mean:
            # max pooling: masked entries are replaced by a very negative value
            floor = torch.tensor(-1e38, dtype=values.dtype, device=values.device)
            candidates = torch.where(keep, values, floor * torch.ones_like(values))
            return candidates.max(dim=1)[0]

        # mean pooling: sum the valid entries, divide by how many there are
        total = torch.where(keep, values, torch.zeros_like(values)).sum(1)
        # a count of 0 is bumped to 1 so fully-masked rows divide safely
        count = valid.sum(-1).unsqueeze(-1).clamp(min=1)
        return total.div_(count)


class EquivariantTransformer(nn.Module):
    """Stack of ``EquivariantTransformerBlock`` layers applied to group-lifted input.

    ``forward`` first calls ``group.lift`` on the input and then feeds the
    result through ``self.net``, whose layout is chosen by ``architecture``:

    * ``"model_1"``: linear embedding -> attention blocks -> pooling (or just
      the features when ``global_pool`` is False) -> three-layer output head.
    * ``"lieconv"``: linear embedding -> attention blocks -> activation and
      linear layer on the features -> pooling (or just the features).

    Parameters
    ----------
    dim_input, dim_output : int
        Sizes of the first and last linear layers.
    dim_hidden : int or sequence
        An int is expanded to a list of ``num_layers + 1`` equal entries.
    num_layers : int
        Number of attention blocks.
    num_heads : int or sequence
        An int is expanded to a list of ``num_layers`` equal entries.
    global_pool, global_pool_mean :
        Whether to append ``GlobalPool`` and whether it averages.
    group :
        Object whose ``lift`` method is called in ``forward``.
        NOTE(review): the default ``SE3(0.2)`` is created once, when the class
        is defined, and is therefore shared by every instance that relies on it.
    liftsamples :
        Second argument passed to ``group.lift``.
    block_norm, kernel_norm, kernel_type, kernel_dim, kernel_act, mc_samples,
    fill, attention_fn, feature_embed_dim :
        Forwarded to every ``EquivariantTransformerBlock``.
    output_norm : str
        ``"batch"``, ``"layer"`` or ``"none"`` for ``"model_1"``;
        ``"batch"`` or ``"none"`` for ``"lieconv"``.
    max_sample_norm : float or None
        If set, the lift is repeated until the largest last-axis norm of the
        first lifted element does not exceed this value.
    lie_algebra_nonlinearity : str or None
        Only ``"tanh"`` is supported.

    Raises
    ------
    ValueError
        For an unknown ``architecture``, ``output_norm`` or
        ``lie_algebra_nonlinearity``.
    """

    def __init__(
        self,
        dim_input,
        dim_output,
        dim_hidden,
        num_layers,
        num_heads,
        global_pool=True,
        global_pool_mean=True,
        group=SE3(0.2),
        liftsamples=1,
        block_norm="layer_pre",
        output_norm="none",
        kernel_norm="none",
        kernel_type="mlp",
        kernel_dim=16,
        kernel_act="swish",
        mc_samples=0,
        fill=1.0,
        architecture="model_1",
        attention_fn="softmax",  # softmax or dot product? SZ: TODO: "dot product" is used to describe both the attention weights being non-softmax (non-local attention paper) and the feature kernel. should fix terminology
        feature_embed_dim=None,
        max_sample_norm=None,
        lie_algebra_nonlinearity=None,
    ):
        super().__init__()

        # Broadcast scalar sizes to one entry per layer.
        if isinstance(dim_hidden, int):
            dim_hidden = [dim_hidden] * (num_layers + 1)

        if isinstance(num_heads, int):
            num_heads = [num_heads] * num_layers

        # Factory so every block shares the same hyper-parameters.
        attention_block = lambda dim, n_head: EquivariantTransformerBlock(
            dim,
            n_head,
            group,
            block_norm=block_norm,
            kernel_norm=kernel_norm,
            kernel_type=kernel_type,
            kernel_dim=kernel_dim,
            kernel_act=kernel_act,
            mc_samples=mc_samples,
            fill=fill,
            attention_fn=attention_fn,
            feature_embed_dim=feature_embed_dim,
        )

        # Activation classes selectable through `kernel_act`; an unknown name
        # surfaces as a KeyError when the output layers are built.
        activation_fn = {
            "swish": Swish,
            "relu": nn.ReLU,
            "softplus": nn.Softplus,
        }

        if architecture == "model_1":
            # An empty nn.Sequential is used as the "no normalisation" layer.
            if output_norm == "batch":
                norm1 = nn.BatchNorm1d(dim_hidden[-1])
                norm2 = nn.BatchNorm1d(dim_hidden[-1])
                norm3 = nn.BatchNorm1d(dim_hidden[-1])
            elif output_norm == "layer":
                norm1 = nn.LayerNorm(dim_hidden[-1])
                norm2 = nn.LayerNorm(dim_hidden[-1])
                norm3 = nn.LayerNorm(dim_hidden[-1])
            elif output_norm == "none":
                norm1 = nn.Sequential()
                norm2 = nn.Sequential()
                norm3 = nn.Sequential()
            else:
                raise ValueError(f"{output_norm} is not a valid norm type.")

            # NOTE(review): `Pass(module, dim=1)` presumably applies `module`
            # to element 1 of the triple — confirm in lie_conv.utils.
            self.net = nn.Sequential(
                Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1),
                *[
                    attention_block(dim_hidden[i], num_heads[i])
                    for i in range(num_layers)
                ],
                GlobalPool(mean=global_pool_mean)
                if global_pool
                else Expression(lambda x: x[1]),
                nn.Sequential(
                    norm1,
                    activation_fn[kernel_act](),
                    nn.Linear(dim_hidden[-1], dim_hidden[-1]),
                    norm2,
                    activation_fn[kernel_act](),
                    nn.Linear(dim_hidden[-1], dim_hidden[-1]),
                    norm3,
                    activation_fn[kernel_act](),
                    nn.Linear(dim_hidden[-1], dim_output),
                ),
            )
        elif architecture == "lieconv":
            # `norm` is validated and built here, but the only layer that would
            # use it is commented out below, so it is not part of `self.net`.
            if output_norm == "batch":
                norm = nn.BatchNorm1d(dim_hidden[-1])
            elif output_norm == "none":
                norm = nn.Sequential()
            else:
                raise ValueError(f"{output_norm} is not a valid norm type.")

            self.net = nn.Sequential(
                Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1),
                *[
                    attention_block(dim_hidden[i], num_heads[i])
                    for i in range(num_layers)
                ],
                nn.Sequential(
                    OrderedDict(
                        [
                            # ("norm", Pass(norm, dim=1)),
                            (
                                "activation",
                                Pass(
                                    activation_fn[kernel_act](),
                                    dim=1,
                                ),
                            ),
                            (
                                "linear",
                                Pass(nn.Linear(dim_hidden[-1], dim_output), dim=1),
                            ),
                        ]
                    )
                ),
                GlobalPool(mean=global_pool_mean)
                if global_pool
                else Expression(lambda x: x[1]),
            )
        else:
            raise ValueError(f"{architecture} is not a valid architecture.")

        self.group = group
        self.liftsamples = liftsamples
        self.max_sample_norm = max_sample_norm

        # Stored as None, or replaced by the module implementing the named nonlinearity.
        self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
        if lie_algebra_nonlinearity is not None:
            if lie_algebra_nonlinearity == "tanh":
                self.lie_algebra_nonlinearity = nn.Tanh()
            else:
                raise ValueError(
                    f"{lie_algebra_nonlinearity} is not a supported nonlinearity"
                )

    def forward(self, input):
        """Lift ``input`` with the group and run the network on the lifted data."""
        if self.max_sample_norm is None:
            lifted_data = self.group.lift(input, self.liftsamples)
        else:
            # Placeholder whose norm (2 * max_sample_norm) exceeds the limit, so
            # the loop below always performs at least one real lift.
            lifted_data = [
                torch.tensor(self.max_sample_norm * 2, device=input[0].device),
                0,
                0,
            ]
            # Re-lift until the largest norm is within the limit.
            while lifted_data[0].norm(dim=-1).max() > self.max_sample_norm:
                lifted_data = self.group.lift(input, self.liftsamples)

        if self.lie_algebra_nonlinearity is not None:
            lifted_data = list(lifted_data)
            # Rescale each vector in element 0 to length nonlinearity(norm / 7);
            # the 1e-6 guards the division for zero-length vectors.
            pairs_norm = lifted_data[0].norm(dim=-1) + 1e-6
            lifted_data[0] = lifted_data[0] * (
                self.lie_algebra_nonlinearity(pairs_norm / 7) / pairs_norm
            ).unsqueeze(-1)

        return self.net(lifted_data)

## Construct dynamics predictor (specific application of EquivariantTransformer)

In [6]:
class DynamicsPredictor(nn.Module):
    """This class implements forward pass through our model, including loss computation.

    Parameters
    ----------
    predictor :
        Dynamics model. It is integrated over time by ``_rollout_model`` and,
        in debug mode, is also called directly and through ``compute_V``.
    debug : bool
        If True, additionally report the error of the predicted potential and
        of the predicted dynamics against the analytic spring system.
    task : str
        Only ``"spring"`` has debug metrics implemented.
    model_with_dict : bool
        If True ``forward`` returns an ``AttrDict`` of outputs, otherwise only
        the predicted trajectory.
    """

    def __init__(self, predictor, debug=False, task="spring", model_with_dict=True):
        super().__init__()
        self.predictor = predictor
        self.debug = debug
        self.task = task
        self.model_with_dict = model_with_dict

        if self.debug:
            print("DynamicsPredictor is in DEBUG MODE.")

    def _rollout_model(self, z0, ts, sys_params, tol=1e-4):
        """inputs [z0: (bs, z_dim), ts: (bs, T), sys_params: (bs, n, c)]
        outputs pred_zs: (bs, T, z_dim)

        Only ``ts[0]`` is passed to the integrator, i.e. the first batch
        element's time grid is used for the whole batch.
        """
        dynamics = Partial(self.predictor, sysP=sys_params)
        zs = odeint(dynamics, z0, ts[0], rtol=tol, method="rk4")
        # odeint output is time-first; move batch to the front.
        return zs.permute(1, 0, 2)

    def _spring_debug_metrics(self, z0, sys_params, ts):
        """Return ``(mse_V, mse_dyn)`` against the analytic spring system.

        ``mse_V`` compares ``predictor.compute_V`` with ``SpringV``;
        ``mse_dyn`` compares the predictor's output with the dynamics derived
        from ``SpringH``. Computed without tracking gradients.
        """
        with torch.no_grad():
            m = sys_params[..., 0]  # assume the first component encodes masses
            D = z0.shape[-1]  # of ODE dims, 2*num_particles*space_dim
            # The first half of the state is reshaped to one row per particle.
            q = z0[:, : D // 2].reshape(*m.shape, -1)
            V_pred = self.predictor.compute_V((q, sys_params))

            k = sys_params[..., 1]
            V_true = SpringV(q, k)
            mse_V = (V_pred - V_true).pow(2).mean()

            # dynamics
            dyn_tz_pred = self.predictor(ts, z0, sys_params)

            H = lambda t, z: SpringH(
                z, sys_params[..., 0].squeeze(-1), sys_params[..., 1].squeeze(-1)
            )
            dynamics = HamiltonianDynamics(H, wgrad=False)
            dyn_tz_true = dynamics(ts, z0)

            mse_dyn = (dyn_tz_true - dyn_tz_pred).pow(2).mean()

        return mse_V, mse_dyn

    def forward(self, data):
        """Roll the model out and compute the trajectory MSE.

        Parameters
        ----------
        data : ((z0, sys_params, ts), true_zs)

        Returns
        -------
        AttrDict or tensor
            With ``model_with_dict`` an ``AttrDict`` holding ``prediction``,
            ``mse``, ``loss`` and ``reports`` (plus ``mse_V`` / ``mse_dyn`` in
            debug mode); otherwise just the predicted trajectory.

        Raises
        ------
        NotImplementedError
            In debug mode when ``task`` is not ``"spring"`` (no debug metrics
            exist for other tasks).
        """
        if self.debug and self.task != "spring":
            raise NotImplementedError(
                f"Debug metrics are only implemented for task 'spring', got {self.task!r}."
            )

        o = AttrDict()

        (z0, sys_params, ts), true_zs = data

        pred_zs = self._rollout_model(z0, ts, sys_params)
        mse = (pred_zs - true_zs).pow(2).mean()

        if self.debug:
            mse_V, mse_dyn = self._spring_debug_metrics(z0, sys_params, ts)
            o.mse_dyn = mse_dyn
            o.mse_V = mse_V

        o.prediction = pred_zs
        o.mse = mse
        o.loss = mse  # loss wrt which we train the model

        if self.debug:
            o.reports = AttrDict({"mse": o.mse, "mse_V": o.mse_V, "mse_dyn": o.mse_dyn})
        else:
            o.reports = AttrDict({"mse": o.mse})

        if not self.model_with_dict:
            return pred_zs

        return o

class DynamicsEquivariantTransformer(EquivariantTransformer, HNet):
    """Equivariant transformer wrapped so it can act as a learned Hamiltonian.

    ``compute_V`` evaluates the transformer on positions and system
    parameters, and ``forward`` turns ``compute_H`` into a time derivative
    through ``HamiltonianDynamics``. ``compute_H`` is not defined in this
    class; presumably it is inherited from ``HNet`` (confirm against lie_conv).
    """

    def __init__(self, center=True, **kwargs):
        """Forward ``kwargs`` to the parent constructors and store the options.

        Args:
            center: when true, ``compute_V`` subtracts the mean over axis 1
                from the positions before evaluating the network.
            **kwargs: passed unchanged to ``super().__init__``.
        """
        super().__init__(**kwargs)
        self.nfe = 0
        self.center = center

    def forward(self, t, z, sysP, wgrad=True):
        """Return the Hamiltonian dynamics of state ``z`` at time ``t``.

        ``sysP`` is closed over so that the Hamiltonian handed to
        ``HamiltonianDynamics`` only takes ``(t, z)``.
        """

        def hamiltonian(t, z):
            # Time is accepted for the (t, z) call signature but not used.
            return self.compute_H(z, sysP)

        return HamiltonianDynamics(hamiltonian, wgrad=wgrad)(t, z)

    def compute_V(self, x):
        """Input is a canonical position variable and the system parameters,
        shapes (bs, n,d) and (bs,n,c)"""
        positions, sys_params = x
        # Entries whose first coordinate is NaN are flagged as invalid.
        valid = ~torch.isnan(positions[..., 0])
        if self.center:
            positions = positions - positions.mean(1, keepdims=True)
        potential = super().forward((positions, sys_params, valid))
        return potential.squeeze(-1)

## Construction of specific model

In [7]:
# Load the forge configuration flags that parameterise the model.
config = forge.config()

# All constructor arguments of the potential network, gathered in one mapping.
network_kwargs = dict(
    group=T(2),
    dim_input=2,
    dim_output=1,  # Potential term in Hamiltonian is scalar
    dim_hidden=config.dim_hidden,
    num_layers=config.num_layers,
    num_heads=config.num_heads,
    global_pool=True,
    global_pool_mean=config.mean_pooling,
    liftsamples=config.lift_samples,
    kernel_dim=config.kernel_dim,
    kernel_act=config.activation_function,
    block_norm=config.block_norm,
    output_norm=config.output_norm,
    kernel_norm=config.kernel_norm,
    kernel_type=config.kernel_type,
    architecture=config.architecture,
    attention_fn=config.attention_fn,
)
network = DynamicsEquivariantTransformer(**network_kwargs)

# Wrap the network in the trajectory predictor; debug diagnostics stay off.
model = DynamicsPredictor(
    network,
    debug=False,
    task="spring",
    model_with_dict=config.model_with_dict,
)
model

DynamicsPredictor(
  (predictor): DynamicsEquivariantTransformer(
    (net): Sequential(
      (0): Pass(
        (module): Linear(in_features=2, out_features=160, bias=True)
      )
      (1): EquivariantTransformerBlock(
        (ema): EquivairantMultiheadAttention(
          (kernel): SumKernel(
            (location_kernel): Sequential(
              (LinNormAct_1): Sequential(
                (linear): Pass(
                  (module): MultiheadLinear()
                )
                (norm): Sequential()
                (activation): Pass(
                  (module): Expression()
                )
              )
              (LinNormAct_2): Sequential(
                (linear): Pass(
                  (module): MultiheadLinear()
                )
                (norm): Sequential()
                (activation): Pass(
                  (module): Expression()
                )
              )
              (LinNormAct_3): Sequential(
                (linear): Pass(
           

# Training and execution

### Load data

In [8]:
# Settings handed to the data-loading config in the next cell.
data_config = SimpleNamespace(
    n_train=3000,
    n_test=2000,
    n_val=2000,
    n_systems=10000,
    data_path='./datasets/ODEDynamics/SpringDynamics/',
    sys_dim=2,
    space_dim=2,
    data_seed=0,
    batch_size=100,
    device=0,
    num_particles=6,
    chunk_len=5,
    load_preprocessed=False,
    nested_and_unshuffled=False,
)

In [9]:
# Build the dataloaders from the data config module named in the forge flags,
# parameterised by the notebook-level `data_config` namespace.
dataloaders, data_name = fet.load(config.data_config, config=data_config)

# Pull out the three splits used by the training cell.
train_loader, test_loader, val_loader = (
    dataloaders[split] for split in ("train", "test", "val")
)

Loading 'spring_dynamics_data' from lie_transformer/configs/dynamics/spring_dynamics_data.py


### Set up training and directories

In [10]:
# Pick the compute device: the configured GPU when CUDA is present, else the CPU.
cuda_available = torch.cuda.is_available()
device = f"cuda:{config.device}" if cuda_available else "cpu"
if cuda_available:
    torch.cuda.set_device(device)

model = model.to(device)

# Only the predictor's parameters are handed to the optimiser.
model_params = model.predictor.parameters()
opt_learning_rate = config.learning_rate
adam_betas = (config.beta1, config.beta2)
model_opt = torch.optim.Adam(model_params, lr=opt_learning_rate, betas=adam_betas)

# Cosine annealing whose period (T_max) is the number of training epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    model_opt, T_max=config.train_epochs
)

In [11]:
# directory management
# Each run is stored in an integer-named sub-folder of
# <results_dir>/<results_folder_name>; a new run takes the next free integer,
# a resumed run reuses the highest existing one.
results_folder_name = "demonstration"
checkpoint_dir = os.path.join(config.results_dir, results_folder_name.replace(".", "_"))

# On a fresh checkout the results directory does not exist yet. Treat that as
# "no previous runs" instead of letting os.listdir raise FileNotFoundError.
if os.path.isdir(checkpoint_dir):
    # Keep only integer-named entries. This drops hidden and '_'-prefixed
    # entries as before, and also any stray non-numeric name that would make
    # int() fail below.
    experiment_folders = [f for f in os.listdir(checkpoint_dir) if f.isdecimal()]
else:
    experiment_folders = []

if experiment_folders:
    experiment_folder = max(int(f) for f in experiment_folders)
    if not config.resume:
        experiment_folder += 1
else:
    if config.resume:
        raise ValueError("Can't resume since no experiments were run before in checkpoint"
                         " dir '{}'.".format(checkpoint_dir))
    else:
        experiment_folder = 1

experiment_folder = os.path.join(checkpoint_dir, str(experiment_folder))
if not config.resume:
    # makedirs (rather than mkdir) also creates checkpoint_dir when it is missing.
    os.makedirs(experiment_folder)

logdir = experiment_folder

checkpoint_name = os.path.join(logdir, "model.ckpt")

### Train model

In [12]:
# Train the model, periodically evaluate on the test/val splits, checkpoint,
# and optionally dump trajectory predictions of the final model.
start_epoch = 1
# Steps per epoch as used for iteration bookkeeping and progress printing.
iters_per_epoch = len(train_loader.dataset) // config.batch_size
train_iter = (start_epoch - 1) * iters_per_epoch + 1
print("Starting training at epoch = {}, iter = {}".format(start_epoch, train_iter))
# Setup tensorboard writing
summary_writer = SummaryWriter(logdir)

train_reports = []  # one parsed report per training step
report_all = {}  # metric name -> list of test-set values, one per evaluation
report_all_val = {}  # same for the validation set

# Saving model at epoch 0 before training
print("saving model at epoch 0 before training ... ")
save_checkpoint(checkpoint_name, 0, model, model_opt, loss=0.0)
print("finished saving model at epoch 0 before training")

num_params = param_count(model)
print(f"Number of model parameters: {num_params}")

# Training
start_t = time.time()

total_train_iters = len(train_loader) * config.train_epochs
iters_per_eval = max(1, int(total_train_iters / config.total_evaluations))

assert (
    config.n_train % min(config.batch_size, config.n_train) == 0
), "Batch size doesn't divide dataset size. Can be inaccurate for loss computation (see below)."

training_failed = False
best_val_loss_so_far = 1e7

for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
    model.train()

    for batch_idx, data in enumerate(train_loader):
        data = nested_to(
            data, device, torch.float32
        )  # the format is ((z0, sys_params, ts), true_zs) for data
        true_zs = data[-1]
        if config.model_with_dict:
            outputs = model(data)
        else:
            # The model returns the bare prediction; build the same structure here.
            pred_zs = model(data)
            loss = (pred_zs - true_zs).pow(2).mean()
            outputs = AttrDict({"loss": loss, "prediction": pred_zs})
            outputs.reports = AttrDict({"mse": loss})

        # A NaN loss is tolerated for the epoch it first appears in and the
        # following one; a NaN in any later epoch aborts the run.
        if torch.isnan(outputs.loss):
            if not training_failed:
                epoch_of_nan = epoch
            if (epoch > epoch_of_nan + 1) and training_failed:
                raise ValueError("Loss Nan-ed.")
            training_failed = True

        model_opt.zero_grad()
        outputs.loss.backward(retain_graph=False)

        model_opt.step()

        train_reports.append(parse_reports_cpu(outputs.reports))

        # Only parse the reports on the steps where they are actually logged.
        if config.log_train_values and batch_idx % config.report_loss_every == 0:
            reports = parse_reports(outputs.reports)
            log_tensorboard(summary_writer, train_iter, reports, "train/")
            print_reports(
                reports,
                start_t,
                epoch,
                batch_idx,
                iters_per_epoch,
                prefix="train",
            )
            log_tensorboard(
                summary_writer,
                train_iter,
                {"lr": model_opt.param_groups[0]["lr"]},
                "hyperparams/",
            )

        # Do learning rate schedule steps per STEP for cosine_annealing_warmup
        if config.lr_schedule == "cosine_annealing_warmup":
            scheduler.step()

        # Logging and evaluation
        if (
            train_iter % iters_per_eval == 0 or (train_iter == total_train_iters)
        ) and config.log_val_test:
            model.eval()
            with torch.no_grad():
                # Same procedure for both held-out splits: evaluate, append
                # every metric to that split's history, log and print.
                for split, loader, history in (
                    ("test", test_loader, report_all),
                    ("val", val_loader, report_all_val),
                ):
                    reports = parse_reports(evaluate(model, loader, device))
                    reports["time"] = time.time() - start_t
                    for key, value in reports.items():
                        history.setdefault(key, []).append(value)

                    log_tensorboard(summary_writer, train_iter, reports, f"{split}/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        iters_per_epoch,
                        prefix=split,
                    )

                # Keep a separate checkpoint of the best model by validation MSE.
                if report_all_val["mse"][-1] < best_val_loss_so_far:
                    save_checkpoint(
                        checkpoint_name,
                        "early_stop",
                        model,
                        model_opt,
                        loss=outputs.loss,
                    )
                    best_val_loss_so_far = report_all_val["mse"][-1]

            model.train()

        train_iter += 1

    # Do learning rate schedule steps per *epoch* for cosine_annealing
    if config.lr_schedule == "cosine_annealing":
        scheduler.step()

    if epoch % config.save_check_points == 0:
        save_checkpoint(
            checkpoint_name, train_iter, model, model_opt, loss=outputs.loss
        )

    # Persist the metric histories after every epoch. This previously went
    # through `dd.io.save`, but `dd` is not imported by the notebook's import
    # cell, so torch.save (already a dependency) is used instead; the files are
    # therefore written as .pt rather than .h5.
    torch.save(train_reports, os.path.join(logdir, "results_dict_train.pt"))
    torch.save(report_all, os.path.join(logdir, "results_dict.pt"))
    torch.save(report_all_val, os.path.join(logdir, "results_dict_val.pt"))

# always save final model
save_checkpoint(checkpoint_name, train_iter, model, model_opt, loss=outputs.loss)

if config.save_test_predictions:
    print("Starting to make model predictions on test sets for *final model*.")
    # Inference only: switch off training-time layer behaviour.
    model.eval()
    for chunk_len in [5, 100]:
        start_t_preds = time.time()
        # Separate names so the notebook-level `data_config` and `dataloaders`
        # are not overwritten by this prediction-only configuration.
        pred_data_config = SimpleNamespace(
            **{
                **config.__dict__["__flags"],
                **{"chunk_len": chunk_len, "batch_size": 500},
            }
        )
        pred_dataloaders, _ = fet.load(config.data_config, config=pred_data_config)
        test_loader_preds = pred_dataloaders["test"]

        torch.cuda.empty_cache()
        with torch.no_grad():
            preds = []
            true = []
            num_datapoints = 0
            for idx, d in enumerate(test_loader_preds):
                true.append(d[-1])
                d = nested_to(d, device, torch.float32)
                outputs = model(d)

                # The model returns a dict-like object or the bare prediction,
                # depending on the same flag the training loop branches on.
                pred_zs = outputs.prediction if config.model_with_dict else outputs
                preds.append(pred_zs)

                num_datapoints += len(pred_zs)

                # Stop once at least 2000 trajectories have been collected.
                if num_datapoints >= 2000:
                    break

            preds = torch.cat(preds, dim=0).cpu()
            true = torch.cat(true, dim=0).cpu()

            # `osp` is not imported by the notebook's import cell; use os.path directly.
            save_path = os.path.join(logdir, f"traj_preds_{chunk_len}_steps_2k_test.pt")
            torch.save(preds, save_path)

            save_path = os.path.join(logdir, f"traj_true_{chunk_len}_steps_2k_test.pt")
            torch.save(true, save_path)

            print(
                f"Completed making test predictions for chunk_len = {chunk_len} in {time.time() - start_t_preds:.2f} seconds."
            )

Starting training at epoch = 1, iter = 1
saving model at epoch 0 before training ... 
Saving model training checkpoint to lie_transformer/checkpoints/demonstration/4/model.ckpt-0
finished saving model at epoch 0 before training
Number of model parameters: 841641


  0%|                                                     | 0/3 [00:00<?, ?it/s]

train: time 4.423,  epoch: 1 [0 / 30]: mse:0.121947


  0%|                                                     | 0/3 [00:23<?, ?it/s]


KeyboardInterrupt: 