# In [7]:
from typing import Any, Callable, Dict, List, Optional, Type, Union

import numpy as np
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode

from mace.modules.blocks import (
    AtomicEnergiesBlock,
    EquivariantProductBasisBlock,
    InteractionBlock,
    LinearNodeEmbeddingBlock,
    LinearReadoutBlock,
    NonLinearReadoutBlock,
    RadialEmbeddingBlock,
    ScaleShiftBlock,
)
from mace.modules.utils import (
    get_edge_vectors_and_lengths,
    get_outputs,
    get_symmetric_displacement,
)
from mace.tools.scatter import scatter_sum

# pylint: disable=C0302

@compile_mode("script")
class MACE(torch.nn.Module):
    """MACE interatomic potential with one energy readout per interaction layer.

    The energy of each graph is the sum of the reference atomic energies (E0)
    and the outputs of every readout block. Forces, virials, stress and
    Hessians are derived from that energy by ``get_outputs``.

    Args:
        r_max: Radial cutoff; also stored as the ``r_max`` buffer.
        num_bessel, num_polynomial_cutoff, radial_type, distance_transform:
            Options forwarded to ``RadialEmbeddingBlock``.
        max_ell: Highest degree of the edge spherical harmonics.
        interaction_cls_first, interaction_cls: Interaction block classes for
            the first layer and for every subsequent layer.
        num_interactions: Number of interaction/product/readout layers.
        num_elements: Number of channels of ``node_attrs`` (one per element).
        hidden_irreps: Irreps of the node features between layers.
        MLP_irreps: Hidden irreps of the final non-linear readout (repeated
            once per head).
        atomic_energies: Reference energies given to ``AtomicEnergiesBlock``.
        avg_num_neighbors: Normalisation given to the interaction blocks.
        atomic_numbers: Stored as the ``atomic_numbers`` buffer and passed to
            the radial embedding.
        correlation: Body order of each product block; an ``int`` applies to
            every layer.
        gate: Activation of the final non-linear readout.
        pair_repulsion: Accepted for interface compatibility only; it is NOT
            used anywhere in this implementation.
        radial_MLP: Hidden sizes of the radial MLP (default ``[64, 64, 64]``).
        heads: Names of the output heads (default ``["default"]``).
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        max_ell: int,
        interaction_cls: Type[InteractionBlock],
        interaction_cls_first: Type[InteractionBlock],
        num_interactions: int,
        num_elements: int,
        hidden_irreps: o3.Irreps,
        MLP_irreps: o3.Irreps,
        atomic_energies: np.ndarray,
        avg_num_neighbors: float,
        atomic_numbers: List[int],
        correlation: Union[int, List[int]],
        gate: Optional[Callable],
        pair_repulsion: bool = False,
        distance_transform: str = "None",
        radial_MLP: Optional[List[int]] = None,
        radial_type: Optional[str] = "bessel",
        heads: Optional[List[str]] = None,
    ):
        super().__init__()
        # Buffers: stored in the state dict and moved together with the module.
        self.register_buffer(
            "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
        )
        self.register_buffer(
            "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
        )
        self.register_buffer(
            "num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
        )

        # Output heads
        if heads is None:
            heads = ["default"]
        self.heads = heads

        # A single correlation order is reused for every layer.
        if isinstance(correlation, int):
            correlation = [correlation] * num_interactions

        # Embedding irreps: one scalar channel per element for the attributes,
        # and as many scalar channels as hidden_irreps has 0e copies.
        node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
        node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])

        # Node embedding
        self.node_embedding = LinearNodeEmbeddingBlock(
            irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
        )

        # Radial (Bessel) embedding of edge lengths
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
            radial_type=radial_type,
            distance_transform=distance_transform,
        )

        edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
        sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
        num_features = hidden_irreps.count(o3.Irrep(0, 1))
        interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()

        # Spherical/angular embedding of edge directions
        self.spherical_harmonics = o3.SphericalHarmonics(
            sh_irreps, normalize=True, normalization="component"
        )
        if radial_MLP is None:
            radial_MLP = [64, 64, 64]

        # Reference atomic energies (E0). Keep the registration order of the
        # submodules unchanged: it fixes the order of model.parameters().
        self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies)

        # First interaction
        inter = interaction_cls_first(
            node_attrs_irreps=node_attr_irreps,
            node_feats_irreps=node_feats_irreps,
            edge_attrs_irreps=sh_irreps,
            edge_feats_irreps=edge_feats_irreps,
            target_irreps=interaction_irreps,
            hidden_irreps=hidden_irreps,
            avg_num_neighbors=avg_num_neighbors,
            radial_MLP=radial_MLP,
        )
        self.interactions = torch.nn.ModuleList([inter])

        # Only residual first-layer interactions get a self connection in the
        # first product block.
        use_sc_first = "Residual" in str(interaction_cls_first)

        # First product block
        node_feats_irreps_out = inter.target_irreps
        prod = EquivariantProductBasisBlock(
            node_feats_irreps=node_feats_irreps_out,
            target_irreps=hidden_irreps,
            correlation=correlation[0],
            num_elements=num_elements,
            use_sc=use_sc_first,
        )
        self.products = torch.nn.ModuleList([prod])

        # First readout: one scalar output per head
        readout = LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
        self.readouts = torch.nn.ModuleList([readout])

        for i in range(num_interactions - 1):
            is_last = i == num_interactions - 2
            # The last layer keeps only hidden_irreps[0] (expected to be the
            # scalar block, per the original design).
            if is_last:
                hidden_irreps_out = str(hidden_irreps[0])
            else:
                hidden_irreps_out = hidden_irreps

            # Interaction block
            inter = interaction_cls(
                node_attrs_irreps=node_attr_irreps,
                node_feats_irreps=hidden_irreps,
                edge_attrs_irreps=sh_irreps,
                edge_feats_irreps=edge_feats_irreps,
                target_irreps=interaction_irreps,
                hidden_irreps=hidden_irreps_out,
                avg_num_neighbors=avg_num_neighbors,
                radial_MLP=radial_MLP,
            )
            self.interactions.append(inter)

            # Product block
            prod = EquivariantProductBasisBlock(
                node_feats_irreps=interaction_irreps,
                target_irreps=hidden_irreps_out,
                correlation=correlation[i + 1],
                num_elements=num_elements,
                use_sc=True,
            )
            self.products.append(prod)

            # Readout: non-linear for the last layer, linear otherwise
            if is_last:
                self.readouts.append(
                    NonLinearReadoutBlock(
                        hidden_irreps_out,
                        (len(heads) * MLP_irreps).simplify(),
                        gate,
                        o3.Irreps(f"{len(heads)}x0e"),
                        len(heads),
                    )
                )
            else:
                self.readouts.append(
                    LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
                )

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        training: bool = False,
        compute_force: bool = True,
        compute_virials: bool = False,
        compute_stress: bool = False,
        compute_displacement: bool = False,
        compute_hessian: bool = False,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Predict energies, and the requested derivatives, for a batch of graphs.

        ``data`` must provide ``node_attrs``, ``positions``, ``ptr``, ``batch``,
        ``edge_index``, ``shifts`` and ``cell``. ``head`` is optional (every
        node uses head 0 without it). ``unit_shifts`` is also required when
        virials, stress or the displacement are requested.

        Returns a dict with ``energy`` (num_graphs,), ``node_energy``
        (num_nodes,), ``contributions`` (num_graphs, 1 + num_interactions),
        ``forces``, ``virials``, ``stress``, ``displacement``, ``hessian`` and
        ``node_feats``.
        """
        # Setup: enable gradients with respect to the inputs
        data["node_attrs"].requires_grad_(True)  # (num_nodes, num_elements)
        data["positions"].requires_grad_(True)  # (num_nodes, 3)

        num_atoms_arange = torch.arange(data["positions"].shape[0])  # (num_nodes,)
        # ptr holds one start offset per graph plus the end, i.e. num_graphs + 1.
        num_graphs = data["ptr"].numel() - 1

        # Head index of every node, taken from the graph it belongs to.
        node_heads = (
            data["head"][data["batch"]]
            if "head" in data
            else torch.zeros_like(data["batch"])
        )

        # (num_graphs, 3, 3) zero strain; replaced below when it is needed.
        displacement = torch.zeros(
            (num_graphs, 3, 3),
            dtype=data["positions"].dtype,
            device=data["positions"].device,
        )
        # get_outputs obtains virials/stress by differentiating the energy with
        # respect to ``displacement``. The zero tensor above is not part of the
        # graph, so strain the positions and shifts before any edge geometry is
        # computed.
        if compute_virials or compute_stress or compute_displacement:
            (
                data["positions"],
                data["shifts"],
                displacement,
            ) = get_symmetric_displacement(
                positions=data["positions"],
                unit_shifts=data["unit_shifts"],
                cell=data["cell"],
                edge_index=data["edge_index"],
                num_graphs=num_graphs,
                batch=data["batch"],
            )

        # Atomic energies: atomic_energies_fn gives one column per head; keep
        # the column of each node's own head.
        node_e0 = self.atomic_energies_fn(data["node_attrs"])[
            num_atoms_arange, node_heads
        ]  # (num_nodes,)
        e0 = scatter_sum(
            src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs
        )  # (num_graphs,)

        # Embeddings
        node_feats = self.node_embedding(data["node_attrs"])
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=data["positions"],
            edge_index=data["edge_index"],
            shifts=data["shifts"],
        )
        edge_attrs = self.spherical_harmonics(vectors)
        edge_feats = self.radial_embedding(
            lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
        )

        # Interactions
        energies = [e0]
        node_energies_list = [node_e0]
        node_feats_list = []

        for interaction, product, readout in zip(
            self.interactions, self.products, self.readouts
        ):
            node_feats, sc = interaction(
                node_attrs=data["node_attrs"],
                node_feats=node_feats,
                edge_attrs=edge_attrs,
                edge_feats=edge_feats,
                edge_index=data["edge_index"],
            )
            node_feats = product(
                node_feats=node_feats,
                sc=sc,
                node_attrs=data["node_attrs"],
            )
            node_feats_list.append(node_feats)

            # The readout has one scalar per head; keep each node's own head.
            node_energies = readout(node_feats, node_heads)[
                num_atoms_arange, node_heads
            ]  # (num_nodes,)
            energy = scatter_sum(
                src=node_energies,
                index=data["batch"],
                dim=0,
                dim_size=num_graphs,
            )  # (num_graphs,)

            energies.append(energy)
            node_energies_list.append(node_energies)

        # Node features of all layers side by side
        node_feats_out = torch.cat(node_feats_list, dim=-1)

        # One energy term from E0 plus one per readout
        contributions = torch.stack(energies, dim=-1)  # (num_graphs, 1 + num_interactions)
        total_energy = torch.sum(contributions, dim=-1)  # (num_graphs,)
        node_energy_contributions = torch.stack(
            node_energies_list, dim=-1
        )  # (num_nodes, 1 + num_interactions)
        node_energy = torch.sum(node_energy_contributions, dim=-1)  # (num_nodes,)

        # Outputs
        forces, virials, stress, hessian = get_outputs(
            energy=total_energy,
            positions=data["positions"],
            displacement=displacement,
            cell=data["cell"],
            training=training,
            compute_force=compute_force,
            compute_virials=compute_virials,
            compute_stress=compute_stress,
            compute_hessian=compute_hessian,
        )

        return {
            "energy": total_energy,
            "node_energy": node_energy,
            "contributions": contributions,
            "forces": forces,
            "virials": virials,
            "stress": stress,
            "displacement": displacement,
            "hessian": hessian,
            "node_feats": node_feats_out,
        }


# In [8]:
@compile_mode("script")
class ScaleShiftMACE(MACE):
    """MACE variant whose interaction energies pass through a scale/shift block.

    Each node's readout contributions are summed and passed through
    ``ScaleShiftBlock`` (built from ``atomic_inter_scale`` and
    ``atomic_inter_shift``) together with the node's head index. Only then are
    the reference atomic energies E0 added.
    """

    def __init__(
        self,
        atomic_inter_scale: float,
        atomic_inter_shift: float,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # scale/shift applied to the summed interaction energies
        self.scale_shift = ScaleShiftBlock(
            scale=atomic_inter_scale, shift=atomic_inter_shift
        )

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        training: bool = False,
        compute_force: bool = True,
        compute_virials: bool = False,
        compute_stress: bool = False,
        compute_displacement: bool = False,
        compute_hessian: bool = False,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Predict energies, and the requested derivatives, for a batch of graphs.

        Required ``data`` keys are the same as for ``MACE.forward``:
        ``node_attrs``, ``positions``, ``ptr``, ``batch``, ``edge_index``,
        ``shifts`` and ``cell``, plus ``unit_shifts`` when virials, stress or
        the displacement are requested. ``head`` is optional.

        Returns a dict with ``energy`` (num_graphs,), ``node_energy``
        (num_nodes,), ``interaction_energy`` (num_graphs,), ``forces``,
        ``virials``, ``stress``, ``hessian``, ``displacement`` and
        ``node_feats``.
        """
        # Setup: enable gradients with respect to the inputs
        data["positions"].requires_grad_(True)  # (num_nodes, 3)
        data["node_attrs"].requires_grad_(True)  # (num_nodes, num_elements)

        num_atoms_arange = torch.arange(data["positions"].shape[0])  # (num_nodes,)
        # ptr holds one start offset per graph plus the end, i.e. num_graphs + 1.
        num_graphs = data["ptr"].numel() - 1

        # Head index of every node, taken from the graph it belongs to.
        node_heads = (
            data["head"][data["batch"]]
            if "head" in data
            else torch.zeros_like(data["batch"])
        )

        # (num_graphs, 3, 3) zero strain; replaced below when it is needed.
        displacement = torch.zeros(
            (num_graphs, 3, 3),
            dtype=data["positions"].dtype,
            device=data["positions"].device,
        )
        # get_outputs obtains virials/stress by differentiating the energy with
        # respect to ``displacement``. The zero tensor above is not part of the
        # graph, so strain the positions and shifts before any edge geometry is
        # computed.
        if compute_virials or compute_stress or compute_displacement:
            (
                data["positions"],
                data["shifts"],
                displacement,
            ) = get_symmetric_displacement(
                positions=data["positions"],
                unit_shifts=data["unit_shifts"],
                cell=data["cell"],
                edge_index=data["edge_index"],
                num_graphs=num_graphs,
                batch=data["batch"],
            )

        # Atomic energies: one column per head; keep each node's own head.
        node_e0 = self.atomic_energies_fn(data["node_attrs"])[
            num_atoms_arange, node_heads
        ]  # (num_nodes,)
        e0 = scatter_sum(
            src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs
        )  # (num_graphs,)

        # Embeddings
        node_feats = self.node_embedding(data["node_attrs"])
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=data["positions"],
            edge_index=data["edge_index"],
            shifts=data["shifts"],
        )
        edge_attrs = self.spherical_harmonics(vectors)
        edge_feats = self.radial_embedding(
            lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
        )

        # Interactions
        node_es_list = []
        node_feats_list = []
        for interaction, product, readout in zip(
            self.interactions, self.products, self.readouts
        ):
            node_feats, sc = interaction(
                node_attrs=data["node_attrs"],
                node_feats=node_feats,
                edge_attrs=edge_attrs,
                edge_feats=edge_feats,
                edge_index=data["edge_index"],
            )
            node_feats = product(
                node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"]
            )
            node_feats_list.append(node_feats)

            # The readout has one scalar per head; keep each node's own head.
            node_es_list.append(
                readout(node_feats, node_heads)[num_atoms_arange, node_heads]
            )  # each entry: (num_nodes,)

        # Node features of all layers side by side
        node_feats_out = torch.cat(node_feats_list, dim=-1)

        # Sum the per-layer node energies, then scale/shift them.
        node_inter_es = torch.sum(torch.stack(node_es_list, dim=0), dim=0)  # (num_nodes,)
        node_inter_es = self.scale_shift(node_inter_es, node_heads)  # (num_nodes,)

        # Sum over the nodes of each graph
        inter_e = scatter_sum(
            src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs
        )  # (num_graphs,)

        # Add E0 and the (scaled) interaction energy
        total_energy = e0 + inter_e  # (num_graphs,)
        node_energy = node_e0 + node_inter_es  # (num_nodes,)

        # E0 depends only on node_attrs, not on positions, so differentiating
        # the interaction energy alone yields the position derivatives.
        forces, virials, stress, hessian = get_outputs(
            energy=inter_e,
            positions=data["positions"],
            displacement=displacement,
            cell=data["cell"],
            training=training,
            compute_force=compute_force,
            compute_virials=compute_virials,
            compute_stress=compute_stress,
            compute_hessian=compute_hessian,
        )
        output = {
            "energy": total_energy,
            "node_energy": node_energy,
            "interaction_energy": inter_e,
            "forces": forces,
            "virials": virials,
            "stress": stress,
            "hessian": hessian,
            "displacement": displacement,
            "node_feats": node_feats_out,
        }
        return output

# In [9]:
###########################################################################################
# Training script
# Authors: Ilyes Batatia, Gregor Simm, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

import dataclasses
import logging
import time
from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel
from torch.optim.swa_utils import SWALR, AveragedModel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch_ema import ExponentialMovingAverage
from torchmetrics import Metric

from mace.tools import torch_geometric
from mace.tools.checkpoint import CheckpointHandler, CheckpointState
from mace.tools.torch_tools import to_numpy
from mace.tools.utils import (
    MetricsLogger,
    compute_mae,
    compute_q95,
    compute_rel_mae,
    compute_rel_rmse,
    compute_rmse,
)


@dataclasses.dataclass
class SWAContainer:
    """Everything the trainer needs for the second (SWA) training stage.

    Once ``epoch >= start``, ``train`` switches its loss to ``loss_fn``,
    folds the current weights into ``model`` via ``update_parameters`` and
    steps ``scheduler`` instead of the stage-one LR scheduler.
    """

    # Averaged copy of the trained model's weights.
    model: AveragedModel
    # Learning-rate schedule used during the SWA stage.
    scheduler: SWALR
    # First epoch of the SWA stage.
    start: int
    # Loss that replaces the stage-one loss once SWA starts.
    loss_fn: torch.nn.Module


def valid_err_log(
    valid_loss,
    eval_metrics,
    logger,
    log_errors,
    epoch=None,
    valid_loader_name="Default",
):
    """Record validation metrics and emit a one-line human-readable summary.

    Args:
        valid_loss: Validation loss shown in the summary.
        eval_metrics: Metric dict. It is updated in place with ``mode="eval"``
            and ``epoch`` and then passed to ``logger.log``.
        logger: Object exposing ``log(dict)``.
        log_errors: ``"PerAtomRMSE"``, ``"TotalRMSE"``, ``"PerAtomMAE"`` or
            ``"TotalMAE"``. Any other value skips the summary line.
        epoch: Current epoch, or ``None`` for the evaluation before training.
        valid_loader_name: Name of the validation head shown in the summary.
    """
    eval_metrics["mode"] = "eval"
    eval_metrics["epoch"] = epoch
    logger.log(eval_metrics)

    prefix = "Initial" if epoch is None else f"Epoch {epoch}"

    # log_errors -> (energy key, force key, energy label, force label)
    summary_formats = {
        "PerAtomRMSE": ("rmse_e_per_atom", "rmse_f", "RMSE_E_per_atom", "RMSE_F"),
        "TotalRMSE": ("rmse_e", "rmse_f", "RMSE_E", "RMSE_F"),
        "PerAtomMAE": ("mae_e_per_atom", "mae_f", "MAE_E_per_atom", "MAE_F"),
        "TotalMAE": ("mae_e", "mae_f", "MAE_E", "MAE_F"),
    }
    selected = summary_formats.get(log_errors)
    if selected is None:
        return

    energy_key, force_key, energy_label, force_label = selected
    # Scale by 1e3 so the values read as meV and meV / A.
    error_e = eval_metrics[energy_key] * 1e3
    error_f = eval_metrics[force_key] * 1e3
    logging.info(
        f"{prefix}: head: {valid_loader_name}, loss={valid_loss:8.8f}, "
        f"{energy_label}={error_e:8.2f} meV, {force_label}={error_f:8.2f} meV / A"
    )
    


def train(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    train_loader: DataLoader,
    valid_loaders: Dict[str, DataLoader],
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler.ExponentialLR,
    start_epoch: int,
    max_num_epochs: int,
    patience: int,
    checkpoint_handler: CheckpointHandler,
    logger: MetricsLogger,
    eval_interval: int,
    output_args: Dict[str, bool],
    device: torch.device,
    log_errors: str,
    swa: Optional[SWAContainer] = None,
    ema: Optional[ExponentialMovingAverage] = None,
    max_grad_norm: Optional[float] = 10.0,
):
    """Run the full training loop with early stopping, optional SWA and EMA.

    Before the first epoch, every loader in ``valid_loaders`` is evaluated and
    logged with ``epoch=None``. Each epoch then does the following:

    * In stage one, step ``lr_scheduler`` with the latest validation loss.
      Once ``swa`` is set and ``epoch >= swa.start``, instead switch to
      ``swa.loss_fn``, update ``swa.model`` and step ``swa.scheduler``. No
      scheduler is stepped on the first epoch (``start_epoch``).
    * Run ``train_one_epoch`` over ``train_loader``.
    * On epochs divisible by ``eval_interval``, re-evaluate every validation
      loader, using the EMA-averaged weights when ``ema`` is given.

    Only the loss of the *last* loader in ``valid_loaders`` drives the
    scheduler, early stopping and checkpointing. A checkpoint is saved (with
    EMA-averaged weights if ``ema`` is given) whenever that loss beats the best
    seen so far. After ``patience`` consecutive non-improving evaluations,
    training either jumps to the SWA stage (if ``swa`` is set and not yet
    started) or stops.

    Note: if ``valid_loaders`` is empty, ``valid_loss_head`` is never assigned
    and a ``NameError`` is raised before training starts.
    """
    # bookkeeping for early stopping and the SWA switch
    lowest_loss = np.inf
    valid_loss = np.inf
    patience_counter = 0
    swa_start = True  # True until the first SWA epoch has been entered
    keep_last = False  # forwarded to checkpoint_handler.save

    if max_grad_norm is not None:
        logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}")

    logging.info("")
    logging.info("===========TRAINING===========")
    logging.info("Started training, reporting errors on validation set")
    logging.info("Loss metrics on validation set")
    epoch = start_epoch

    # log validation loss before _any_ training
    valid_loss = 0.0
    for valid_loader_name, valid_loader in valid_loaders.items():
        # evaluate val loss and metrics for each val_dataloader and log them
        valid_loss_head, eval_metrics = evaluate(
                                                    model=model,
                                                    loss_fn=loss_fn,
                                                    data_loader=valid_loader,
                                                    output_args=output_args,
                                                    device=device,
                                                )
        valid_err_log(
                        valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name
                    )
    valid_loss = valid_loss_head  # consider only the last head for the checkpoint

    # loop until max_num_epochs (or until early stopping breaks out)
    while epoch < max_num_epochs:
        # LR scheduler and SWA update
        if swa is None or epoch < swa.start:
            if epoch > start_epoch:
                # NOTE(review): torch's ExponentialLR.step() has no `metrics`
                # argument; this presumably expects a metric-aware scheduler
                # (e.g. ReduceLROnPlateau or a wrapper) — confirm.
                lr_scheduler.step(
                    metrics=valid_loss
                )  # Can break if exponential LR, TODO fix that!
        else:
            # SWA stage: on its first epoch, reset the best loss (the new loss
            # is not comparable) and keep the stage-one checkpoint on next save
            if swa_start:
                logging.info("Changing loss based on Stage Two Weights")
                lowest_loss = np.inf
                swa_start = False
                keep_last = True
            loss_fn = swa.loss_fn
            swa.model.update_parameters(model)
            if epoch > start_epoch:
                swa.scheduler.step()

        # Train: schedule-free optimizers must be switched into train mode
        if "ScheduleFree" in type(optimizer).__name__:
            optimizer.train()
        
        # trains one epoch
        train_one_epoch(
            model=model,
            loss_fn=loss_fn,
            data_loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            output_args=output_args,
            max_grad_norm=max_grad_norm,
            ema=ema,
            logger=logger,
            device=device,
        )

        # Validate on epochs divisible by eval_interval
        if epoch % eval_interval == 0:
            model_to_evaluate = model
            # evaluate with the EMA-averaged weights when EMA is used
            param_context = (
                ema.average_parameters() if ema is not None else nullcontext()
            )
            if "ScheduleFree" in type(optimizer).__name__:
                optimizer.eval()
            # evaluate model on val loader, for each val loader
            with param_context:
                valid_loss = 0.0
                for valid_loader_name, valid_loader in valid_loaders.items():
                    valid_loss_head, eval_metrics = evaluate(
                                                                model=model_to_evaluate,
                                                                loss_fn=loss_fn,
                                                                data_loader=valid_loader,
                                                                output_args=output_args,
                                                                device=device,
                                                            )
                    # log the val metrics
                    valid_err_log(
                            valid_loss_head,
                            eval_metrics,
                            logger,
                            log_errors,
                            epoch,
                            valid_loader_name,
                        )
                valid_loss = (
                    valid_loss_head  # consider only the last head for the checkpoint
                )

            # no improvement: counts evaluations (not epochs) without progress
            if valid_loss >= lowest_loss:
                # increase patience counter
                patience_counter += 1
                # terminate (or move to SWA) once patience is exhausted
                if patience_counter >= patience:
                    if swa is not None and epoch < swa.start:
                        logging.info(
                            f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two"
                        )
                        # NOTE(review): the `epoch += 1` below makes the next
                        # epoch swa.start + 1, so epoch swa.start itself is
                        # skipped; patience_counter is not reset here.
                        epoch = swa.start
                    else:
                        logging.info(
                            f"Stopping optimization after {patience_counter} epochs without improvement"
                        )
                        break


            # val loss decreased, reset patience counter and save model
            else:
                lowest_loss = valid_loss
                patience_counter = 0
                # save the EMA-averaged weights when EMA is used
                param_context = (
                    ema.average_parameters() if ema is not None else nullcontext()
                )
                with param_context:
                    checkpoint_handler.save(
                        state=CheckpointState(model, optimizer, lr_scheduler),
                        epochs=epoch,
                        keep_last=keep_last,
                    )
                    keep_last = False

        epoch += 1

    logging.info("Training complete")


def train_one_epoch(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    data_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    output_args: Dict[str, bool],
    max_grad_norm: Optional[float],
    ema: Optional[ExponentialMovingAverage],
    logger: MetricsLogger,
    device: torch.device,
) -> None:
    """Run one optimisation pass over ``data_loader``.

    Each batch gets exactly one ``take_step`` call. The step's metric dict is
    tagged with ``mode="opt"`` and the current ``epoch`` and then handed to
    ``logger.log``.
    """
    for batch in data_loader:
        # take_step returns (loss, metrics); only the metrics are logged here
        step_metrics = take_step(
            model=model,
            loss_fn=loss_fn,
            batch=batch,
            optimizer=optimizer,
            ema=ema,
            output_args=output_args,
            max_grad_norm=max_grad_norm,
            device=device,
        )[1]
        step_metrics.update(mode="opt", epoch=epoch)
        logger.log(step_metrics)


def take_step(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    batch: torch_geometric.batch.Batch,
    optimizer: torch.optim.Optimizer,
    ema: Optional[ExponentialMovingAverage],
    output_args: Dict[str, bool],
    max_grad_norm: Optional[float],
    device: torch.device,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Perform one optimisation step on ``batch``.

    Moves the batch to ``device``, runs the model in training mode, backprops
    ``loss_fn``, optionally clips the gradient norm to ``max_grad_norm``, steps
    ``optimizer`` and, if given, updates ``ema``.

    Returns:
        ``(loss, info)``: ``loss`` is the loss tensor of this step, and
        ``info`` holds ``"loss"`` (as NumPy) and ``"time"`` (seconds taken by
        the step).
    """
    # perf_counter is monotonic, so the duration is immune to wall-clock jumps.
    start_time = time.perf_counter()

    batch = batch.to(device)
    optimizer.zero_grad(set_to_none=True)
    batch_dict = batch.to_dict()

    output = model(
        batch_dict,
        training=True,
        compute_force=output_args["forces"],
        compute_virials=output_args["virials"],
        compute_stress=output_args["stress"],
    )
    loss = loss_fn(pred=output, ref=batch)
    loss.backward()

    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()

    if ema is not None:
        ema.update()

    loss_dict = {
        "loss": to_numpy(loss),
        "time": time.perf_counter() - start_time,
    }

    return loss, loss_dict


def evaluate(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    data_loader: DataLoader,
    output_args: Dict[str, bool],
    device: torch.device,
) -> Tuple[float, Dict[str, Any]]:
    """Evaluate ``model`` on ``data_loader`` and return ``(avg_loss, aux)``.

    Parameters are excluded from autograd during the evaluation. Each
    parameter's original ``requires_grad`` flag is restored afterwards, even if
    evaluation raises. ``aux`` is the dict returned by ``MACELoss.compute``,
    with the wall time of the evaluation added under ``"time"``.
    """
    # Snapshot the flags so parameters the caller froze stay frozen; setting
    # requires_grad=True on everything afterwards would unfreeze them.
    requires_grad_flags = [param.requires_grad for param in model.parameters()]
    for param in model.parameters():
        param.requires_grad = False

    try:
        metrics = MACELoss(loss_fn=loss_fn).to(device)
        start_time = time.perf_counter()

        for batch in data_loader:
            batch = batch.to(device)
            batch_dict = batch.to_dict()
            output = model(
                batch_dict,
                training=False,
                compute_force=output_args["forces"],
                compute_virials=output_args["virials"],
                compute_stress=output_args["stress"],
            )
            # update() only accumulates state. Calling the metric (forward)
            # would also compute per-batch values that were discarded anyway.
            metrics.update(batch, output)

        avg_loss, aux = metrics.compute()
        aux["time"] = time.perf_counter() - start_time
        metrics.reset()
    finally:
        # restore the caller's trainability settings
        for param, flag in zip(model.parameters(), requires_grad_flags):
            param.requires_grad = flag

    return avg_loss, aux


class MACELoss(Metric):
    """
    torchmetrics metric that accumulates the loss and energy/force errors over batches.

    ``update`` consumes one ``(batch, output)`` pair; ``compute`` turns the
    accumulated state into an ``aux`` dict of error statistics and returns
    ``(aux["loss"], aux)``.
    """

    def __init__(self, loss_fn: torch.nn.Module):
        super().__init__()
        self.loss_fn = loss_fn
        # Register states in a fixed order: scalar counters are summed across
        # processes, per-batch tensor lists are concatenated.
        state_specs = (
            ("total_loss", False),
            ("num_data", False),
            ("E_computed", False),
            ("delta_es", True),
            ("delta_es_per_atom", True),
            ("Fs_computed", False),
            ("fs", True),
            ("delta_fs", True),
        )
        for state_name, is_list_state in state_specs:
            if is_list_state:
                self.add_state(state_name, default=[], dist_reduce_fx="cat")
            else:
                self.add_state(state_name, default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, batch, output):  # pylint: disable=arguments-differ
        """Accumulate the loss and the energy/force errors of a single batch."""
        batch_loss = self.loss_fn(pred=output, ref=batch)
        self.total_loss += batch_loss
        self.num_data += batch.num_graphs

        predicted_energy = output.get("energy")
        if predicted_energy is not None and batch.energy is not None:
            self.E_computed += 1.0
            energy_error = batch.energy - predicted_energy
            self.delta_es.append(energy_error)
            # per-graph atom counts from consecutive ptr offsets
            atoms_per_graph = batch.ptr[1:] - batch.ptr[:-1]
            self.delta_es_per_atom.append(energy_error / atoms_per_graph)

        predicted_forces = output.get("forces")
        if predicted_forces is not None and batch.forces is not None:
            self.Fs_computed += 1.0
            self.fs.append(batch.forces)
            self.delta_fs.append(batch.forces - predicted_forces)

    def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray:
        """Return ``delta`` as a numpy array, concatenating it first when it is a list of tensors."""
        merged = torch.cat(delta) if isinstance(delta, list) else delta
        return to_numpy(merged)

    def compute(self):
        """
        Turn the accumulated state into error statistics.

        Returns:
            ``(loss, aux)`` where ``aux`` holds ``loss`` plus energy metrics
            (when any energies were seen) and force metrics (when any forces
            were seen).
        """
        # NOTE(review): total_loss sums the value loss_fn returns per batch while
        # num_data counts graphs; if loss_fn returns a per-batch mean this is the
        # mean loss divided by (roughly) the batch size -- confirm intended.
        aux = {"loss": to_numpy(self.total_loss / self.num_data).item()}

        if self.E_computed:
            delta_es = self.convert(self.delta_es)
            delta_es_per_atom = self.convert(self.delta_es_per_atom)
            aux.update(
                mae_e=compute_mae(delta_es),
                mae_e_per_atom=compute_mae(delta_es_per_atom),
                rmse_e=compute_rmse(delta_es),
                rmse_e_per_atom=compute_rmse(delta_es_per_atom),
                q95_e=compute_q95(delta_es),
            )

        if self.Fs_computed:
            fs = self.convert(self.fs)
            delta_fs = self.convert(self.delta_fs)
            aux.update(
                mae_f=compute_mae(delta_fs),
                rel_mae_f=compute_rel_mae(delta_fs, fs),
                rmse_f=compute_rmse(delta_fs),
                rel_rmse_f=compute_rel_rmse(delta_fs, fs),
                q95_f=compute_q95(delta_fs),
            )

        return aux["loss"], aux


In [None]:
###########################################################################################
# Training script for MACE
# Authors: Ilyes Batatia, Gregor Simm, David Kovacs
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

import argparse
import ast
import glob
import json
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import List, Optional

import torch.distributed
import torch.nn.functional
from torch.utils.data import ConcatDataset
from torch_ema import ExponentialMovingAverage

import mace
from mace import data, tools
from mace.tools import torch_geometric
from mace.tools.model_script_utils import configure_model
from mace.tools.multihead_tools import (
    HeadConfig,
    dict_head_to_dataclass,
    prepare_default_head,
)
from mace.tools.scripts_utils import (
    LRScheduler,
    check_path_ase_read,
    dict_to_array,
    get_atomic_energies,
    get_avg_num_neighbors,
    get_config_type_weights,
    get_dataset_from_xyz,
    get_loss_fn,
    get_optimizer,
    get_params_options,
    get_swa,
    remove_pt_head,
)
from mace.tools.tables_utils import create_error_table
from mace.tools.utils import AtomicNumberTable


def main() -> None:
    """
    Command-line entry point: parse the MACE training arguments and hand them to ``run``.
    """
    parser = tools.build_default_arg_parser()
    run(parser.parse_args())


def run(args: argparse.Namespace) -> None:
    """
    Run MACE training / fine-tuning end to end.

    Validates the arguments and sets up logging, dtype and device; optionally
    loads a foundation model (plus pretraining data for multihead
    fine-tuning); loads the data of every head; builds the atomic-number table
    and the atomic energies (E0s); creates data loaders, model, optimizer,
    scheduler, SWA/EMA and the checkpoint handler; trains; then reloads the
    latest checkpoint(s), reports error tables on train/valid/test data and
    saves the final model(s).

    Args:
        args: parsed command-line arguments. Several fields are overwritten in
            place (e.g. ``r_max``, ``heads``, ``multiheads_finetuning``,
            ``loss``, ``compute_energy``, ``compute_dipole``,
            ``avg_num_neighbors``).
    """
    # tag is like model name
    tag = tools.get_tag(name=args.name, seed=args.seed)

    # check all args
    args, input_log_messages = tools.check_args(args)

    # Setup
    tools.set_seeds(args.seed)
    tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir)
    logging.info("===========VERIFYING SETTINGS===========")
    for message, loglevel in input_log_messages:
        logging.log(level=loglevel, msg=message)

    try:
        logging.info(f"MACE version: {mace.__version__}")
    except AttributeError:
        logging.info("Cannot find MACE version, please install MACE via pip")
    logging.debug(f"Configuration: {args}")

    # setup dtype, init device
    tools.set_default_dtype(args.default_dtype)
    device = tools.init_device(args.device)

    # see if pretrained model is given and can we use multihead finetuning
    model_foundation: Optional[torch.nn.Module] = None
    if args.foundation_model is not None:
        assert os.path.exists(args.foundation_model), f"Couldn't find the model at path {args.foundation_model}"

        # load the model; its cutoff overrides args.r_max
        model_foundation = torch.load(args.foundation_model, map_location=args.device)
        logging.info(f"Using foundation model {args.foundation_model} as initial checkpoint.")
        args.r_max = model_foundation.r_max.item()
        # if pretraining file is not provided, can't do multihead finetuning
        if args.pt_train_file is None:
            logging.warning("Multihead finetuning requires a pretraining file, provide a path to one with --pt_train_file. Multihead finetuning is disabled.")
            args.multiheads_finetuning = False

        # if multihead finetuning is selected
        if args.multiheads_finetuning:
            assert args.E0s != "average", "average atomic energies cannot be used for multiheads finetuning"
            # reduce a multi-head foundation model to a single head
            # NOTE(review): the message says "first head" but remove_pt_head
            # receives args.foundation_head -- confirm which head is kept
            if hasattr(model_foundation, "heads"):
                if len(model_foundation.heads) > 1:
                    logging.warning("Multihead finetuning with models with more than one head is not supported, using the first head as foundation head.")
                    model_foundation = remove_pt_head(model_foundation, args.foundation_head)
    else:
        args.multiheads_finetuning = False

    # if head is provided, use, else prepare
    # head is dict(str:things), eg train_file, valid_file, E0 etc
    if args.heads is not None:
        args.heads = ast.literal_eval(args.heads)
    else:
        args.heads = prepare_default_head(args)

    # load input data for each head one by one
    logging.info("===========LOADING INPUT DATA===========")
    heads = list(args.heads.keys())
    logging.info(f"Using heads: {heads}")
    head_configs: List[HeadConfig] = []

    for head, head_args in args.heads.items():
        logging.info(f"=============    Processing head {head}     ===========")
        head_config = dict_head_to_dataclass(head_args, head, args)
        # if statistics file is given, use
        if head_config.statistics_file is not None:
            with open(head_config.statistics_file, "r", encoding="utf-8") as f:
                statistics = json.load(f)
            logging.info("Using statistics json file")
            head_config.r_max = (
                statistics["r_max"] if args.foundation_model is None else args.r_max
            )
            head_config.atomic_numbers = statistics["atomic_numbers"]
            head_config.mean = statistics["mean"]
            head_config.std = statistics["std"]
            head_config.avg_num_neighbors = statistics["avg_num_neighbors"]
            head_config.compute_avg_num_neighbors = False
            if isinstance(statistics["atomic_energies"], str) and statistics["atomic_energies"].endswith(".json"):
                with open(statistics["atomic_energies"], "r", encoding="utf-8") as f:
                    atomic_energies = json.load(f)
                head_config.E0s = atomic_energies
                # json.load already returns Python objects (usually a dict with
                # string keys), which ast.literal_eval cannot parse
                head_config.atomic_energies_dict = _parse_atomic_energies(atomic_energies)
            else:
                head_config.E0s = statistics["atomic_energies"]
                head_config.atomic_energies_dict = _parse_atomic_energies(
                    statistics["atomic_energies"]
                )

        # Data preparation, if train_file is .d5 or .hdf5 or dir that contains it, or empty dir => False
        if check_path_ase_read(head_config.train_file):
            if head_config.valid_file is not None:
                assert check_path_ase_read(head_config.valid_file), "valid_file if given must be same format as train_file"

            config_type_weights = get_config_type_weights(head_config.config_type_weights)
            collections, atomic_energies_dict = get_dataset_from_xyz(
                work_dir=args.work_dir,
                train_path=head_config.train_file,
                valid_path=head_config.valid_file,
                valid_fraction=head_config.valid_fraction,
                config_type_weights=config_type_weights,
                test_path=head_config.test_file,
                seed=args.seed,
                energy_key=head_config.energy_key,
                forces_key=head_config.forces_key,
                stress_key=head_config.stress_key,
                virials_key=head_config.virials_key,
                dipole_key=head_config.dipole_key,
                charges_key=head_config.charges_key,
                head_name=head_config.head_name,
                keep_isolated_atoms=head_config.keep_isolated_atoms,
            )
            head_config.collections = collections
            head_config.atomic_energies_dict = atomic_energies_dict
            logging.info(f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, "
                         f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}],")
        head_configs.append(head_config)

    # check if enough number of samples are there
    if all(check_path_ase_read(head_config.train_file) for head_config in head_configs):
        size_collections_train = sum(len(head_config.collections.train) for head_config in head_configs)
        size_collections_valid = sum(len(head_config.collections.valid) for head_config in head_configs)
        if size_collections_train < args.batch_size:
            logging.error(f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})")
        if size_collections_valid < args.valid_batch_size:
            logging.warning(f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})")

    # if we are using multihead finetuning, load pretrain data as well
    if args.multiheads_finetuning:
        logging.info("==================Using multiheads finetuning mode==================")
        args.loss = "universal"

        logging.info(f"Using foundation model for multiheads finetuning with {args.pt_train_file}")
        # add pretraining head at start (dict.fromkeys de-duplicates, keeping order)
        heads = list(dict.fromkeys(["pt_head"] + heads))
        # load data
        collections, atomic_energies_dict = get_dataset_from_xyz(
            work_dir=args.work_dir,
            train_path=args.pt_train_file,
            valid_path=args.pt_valid_file,
            valid_fraction=args.valid_fraction,
            config_type_weights=None,
            test_path=None,
            seed=args.seed,
            energy_key=args.energy_key,
            forces_key=args.forces_key,
            stress_key=args.stress_key,
            virials_key=args.virials_key,
            dipole_key=args.dipole_key,
            charges_key=args.charges_key,
            head_name="pt_head",
            keep_isolated_atoms=args.keep_isolated_atoms,
        )

        # create pretrain head
        head_config_pt = HeadConfig(
            head_name="pt_head",
            train_file=args.pt_train_file,
            valid_file=args.pt_valid_file,
            E0s="foundation",
            statistics_file=args.statistics_file,
            valid_fraction=args.valid_fraction,
            config_type_weights=None,
            energy_key=args.energy_key,
            forces_key=args.forces_key,
            stress_key=args.stress_key,
            virials_key=args.virials_key,
            dipole_key=args.dipole_key,
            charges_key=args.charges_key,
            keep_isolated_atoms=args.keep_isolated_atoms,
            collections=collections,
            avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors,
            compute_avg_num_neighbors=False,
        )
        head_config_pt.collections = collections
        head_configs.append(head_config_pt)
        logging.info(f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}")

    # Atomic number table
    # yapf: disable
    for head_config in head_configs:
        # if atomic numbers not given, extract from train and valid datasets
        if head_config.atomic_numbers is None:
            z_table_head = tools.get_atomic_number_table_from_zs(
                z
                for configs in (head_config.collections.train, head_config.collections.valid)
                for config in configs
                for z in config.atomic_numbers
            )
            head_config.atomic_numbers = z_table_head.zs
            head_config.z_table = z_table_head
        else:
            # if given, but not in stat file, read from command line
            if head_config.statistics_file is None:
                logging.info("Using atomic numbers from command line argument")
            else:
                logging.info("Using atomic numbers from statistics file")
            zs_list = ast.literal_eval(head_config.atomic_numbers)
            assert isinstance(zs_list, list)
            z_table_head = tools.AtomicNumberTable(zs_list)
            head_config.atomic_numbers = zs_list
            head_config.z_table = z_table_head
        # yapf: enable

    #  pool all atomic numbers from all heads
    all_atomic_numbers = set()
    for head_config in head_configs:
        all_atomic_numbers.update(head_config.atomic_numbers)
    z_table = AtomicNumberTable(sorted(list(all_atomic_numbers)))
    logging.info(f"Atomic Numbers used: {z_table.zs}")

    # Atomic energies
    atomic_energies_dict = {}
    for head_config in head_configs:
        # if no atomic energy dict is given
        if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0:
            # E0 can't be none then
            assert head_config.E0s is not None, "Atomic energies must be provided"
            # if not foundation, calculate if train file given
            if check_path_ase_read(head_config.train_file) and head_config.E0s.lower() != "foundation":
                atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s,
                                                                                  head_config.collections.train,
                                                                                  head_config.z_table)
            # if E0 to be used from foundation
            elif head_config.E0s.lower() == "foundation":
                assert args.foundation_model is not None
                atomic_energies_dict[head_config.head_name] = _atomic_energies_from_foundation(model_foundation, z_table.zs)
            else:
                # if train file not given, may have to read from json
                atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table)
        # use the given atomic energy dict
        else:
            atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict

    # Atomic energies for multiheads finetuning, for pretrain head
    if args.multiheads_finetuning:
        assert model_foundation is not None, "Model foundation must be provided for multiheads finetuning"
        atomic_energies_dict["pt_head"] = _atomic_energies_from_foundation(model_foundation, z_table.zs)

    # set the output args
    dipole_only = False
    args.compute_energy = True
    args.compute_dipole = False
    atomic_energies = dict_to_array(atomic_energies_dict, heads)

    # log the atomic energies
    for head_config in head_configs:
        try:
            logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}")
        except KeyError as e:
            raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e

    # get val and train sets
    valid_sets = {head: [] for head in heads}
    train_sets = {head: [] for head in heads}

    for head_config in head_configs:
        if check_path_ase_read(head_config.train_file):
            train_sets[head_config.head_name] = [data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max, heads=heads)
                                                 for config in head_config.collections.train]
            valid_sets[head_config.head_name] = [data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max, heads=heads)
                                                 for config in head_config.collections.valid]
        else:
            raise ValueError(f"Provide file that ends with .xyz instead of {head_config.train_file}")

        # per-head train loader, used for the final per-head error tables
        train_loader_head = torch_geometric.dataloader.DataLoader(
            dataset=train_sets[head_config.head_name],
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=args.pin_memory,
            num_workers=args.num_workers,
            generator=torch.Generator().manual_seed(args.seed),
        )
        head_config.train_loader = train_loader_head

    # concatenate all the trainsets
    train_set = ConcatDataset([train_sets[head] for head in heads])

    train_loader = torch_geometric.dataloader.DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=args.pin_memory,
        num_workers=args.num_workers,
        generator=torch.Generator().manual_seed(args.seed),
    )

    # valid loaders will be different for each head
    valid_loaders = dict.fromkeys(heads)
    for head, valid_set in valid_sets.items():
        valid_loaders[head] = torch_geometric.dataloader.DataLoader(
            dataset=valid_set,
            batch_size=args.valid_batch_size,
            shuffle=False,
            drop_last=False,
            pin_memory=args.pin_memory,
            num_workers=args.num_workers,
            generator=torch.Generator().manual_seed(args.seed),
        )

    # get loss function and avg num of neighbors
    loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole)
    args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device)

    # Model
    model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table)
    model.to(device)

    logging.debug(model)
    logging.info(f"Total number of parameters: {tools.count_parameters(model)}")
    logging.info("")
    logging.info("===========OPTIMIZER INFORMATION===========")
    logging.info(f"Using {args.optimizer.upper()} as parameter optimizer")
    logging.info(f"Batch size: {args.batch_size}")
    if args.ema:
        logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}")
    logging.info(f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}")
    logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}")
    logging.info(loss_fn)

    # Optimizer
    param_options = get_params_options(args, model)
    optimizer: torch.optim.Optimizer
    optimizer = get_optimizer(args, param_options)
    logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + "_train")  # pylint: disable=E1123

    # scheduler and swa
    lr_scheduler = LRScheduler(optimizer, args)

    swa: Optional[tools.SWAContainer] = None
    swas = [False]
    if args.swa:
        swa, swas = get_swa(args, model, optimizer, swas, dipole_only)

    # checkpoint handler
    checkpoint_handler = tools.CheckpointHandler(
        directory=args.checkpoints_dir,
        tag=tag,
        keep=args.keep_checkpoints,
        swa_start=args.start_swa,
    )

    start_epoch = 0
    if args.restart_latest:
        # prefer an SWA checkpoint, fall back to a regular one
        try:
            opt_start_epoch = checkpoint_handler.load_latest(state=tools.CheckpointState(model, optimizer, lr_scheduler),
                                                             swa=True,
                                                             device=device,)
        except Exception:  # pylint: disable=W0703
            opt_start_epoch = checkpoint_handler.load_latest(state=tools.CheckpointState(model, optimizer, lr_scheduler),
                                                             swa=False,
                                                             device=device,)
        if opt_start_epoch is not None:
            start_epoch = opt_start_epoch

    # initialize ema
    ema: Optional[ExponentialMovingAverage] = None
    if args.ema:
        ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay)
    else:
        for group in optimizer.param_groups:
            group["lr"] = args.lr

    train(
        model=model,
        loss_fn=loss_fn,
        train_loader=train_loader,
        valid_loaders=valid_loaders,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        checkpoint_handler=checkpoint_handler,
        eval_interval=args.eval_interval,
        start_epoch=start_epoch,
        max_num_epochs=args.max_num_epochs,
        logger=logger,
        patience=args.patience,
        output_args=output_args,
        device=device,
        swa=swa,
        ema=ema,
        max_grad_norm=args.clip_grad,
        log_errors=args.error_table,
    )

    logging.info("")
    logging.info("===========RESULTS===========")
    logging.info("Computing metrics for training, validation, and test sets")

    # get train and all valid loaders for all heads
    train_valid_data_loader = {}
    for head_config in head_configs:
        data_loader_name = "train_" + head_config.head_name
        train_valid_data_loader[data_loader_name] = head_config.train_loader
    for head, valid_loader in valid_loaders.items():
        data_load_name = "valid_" + head
        train_valid_data_loader[data_load_name] = valid_loader

    test_sets = {}
    # if all heads share the same (non-None) test file or test dir, load it only once
    stop_first_test = False
    if all(head_config.test_file == head_configs[0].test_file for head_config in head_configs) and head_configs[0].test_file is not None:
        stop_first_test = True
    if all(head_config.test_dir == head_configs[0].test_dir for head_config in head_configs) and head_configs[0].test_dir is not None:
        stop_first_test = True

    # load test data head by head, stop if all heads have same test file
    for head_config in head_configs:
        if check_path_ase_read(head_config.train_file):
            for name, subset in head_config.collections.tests:
                test_sets[name] = [data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max, heads=heads)
                                   for config in subset]
        if stop_first_test:
            break

    # one loader per test set (built once, after all test sets are collected)
    test_data_loader = {}
    for test_name, test_set in test_sets.items():
        test_data_loader[test_name] = torch_geometric.dataloader.DataLoader(
            test_set,
            batch_size=args.valid_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=args.num_workers,
            pin_memory=args.pin_memory,
        )

    for swa_eval in swas:
        epoch = checkpoint_handler.load_latest(state=tools.CheckpointState(model, optimizer, lr_scheduler), swa=swa_eval, device=device,)
        model.to(device)

        if swa_eval:
            logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation")
        else:
            logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation")

        # freeze model params
        for param in model.parameters():
            param.requires_grad = False

        table_train_valid = create_error_table(
            table_type=args.error_table,
            all_data_loaders=train_valid_data_loader,
            model=model,
            loss_fn=loss_fn,
            output_args=output_args,
            log_wandb=False,
            device=device,
        )
        logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid))

        if test_data_loader:
            table_test = create_error_table(
                table_type=args.error_table,
                all_data_loaders=test_data_loader,
                model=model,
                loss_fn=loss_fn,
                output_args=output_args,
                log_wandb=False,
                device=device,
            )
            logging.info("Error-table on TEST:\n" + str(table_test))

        # Save entire model; move it to CPU first so that --save_cpu applies to
        # both the model_dir copy and the checkpoints_dir copy
        if args.save_cpu:
            model = model.to("cpu")

        if swa_eval:
            model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model")
            final_model_path = Path(args.model_dir) / (args.name + "_stagetwo.model")
            torch.save(model, final_model_path)
            logging.info(f"Saved stagetwo model at {final_model_path}")
        else:
            model_path = Path(args.checkpoints_dir) / (tag + ".model")
            final_model_path = Path(args.model_dir) / (args.name + ".model")
            torch.save(model, final_model_path)
            logging.info(f"Saved model at {final_model_path}")

        torch.save(model, model_path)
        logging.info(f"Saved model to {model_path}")

    logging.info("Done")


def _parse_atomic_energies(value):
    """
    Convert an atomic-energies specification into a ``{atomic_number: energy}`` dict.

    Args:
        value: either a string holding a Python literal (parsed with
            ``ast.literal_eval``, as before) or a dict as produced by
            ``json.load``, whose keys are strings.

    Returns:
        The parsed literal for strings; for dicts, a copy with integer keys.

    Raises:
        ValueError: if ``value`` is neither a string nor a dict, or a dict key
            is not an integer.
    """
    if isinstance(value, str):
        return ast.literal_eval(value)
    if isinstance(value, dict):
        # JSON object keys are always strings, atomic numbers are looked up as ints
        return {int(z): energy for z, energy in value.items()}
    raise ValueError(f"Unsupported atomic energies format: {value!r}")


def _atomic_energies_from_foundation(model_foundation, zs):
    """
    Read the atomic energies (E0s) for the atomic numbers ``zs`` from a foundation model.

    If the stored energies have more than one dimension they are squeezed, and
    if two dimensions remain the first row (first head) is used.

    Args:
        model_foundation: model exposing ``atomic_numbers`` and
            ``atomic_energies_fn.atomic_energies``.
        zs: atomic numbers to look up; each must be known to the foundation model.

    Returns:
        ``{z: energy}`` with Python floats.
    """
    z_table_foundation = AtomicNumberTable([int(z) for z in model_foundation.atomic_numbers])
    foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies

    # if the foundation model is itself multihead take the first head
    if foundation_atomic_energies.ndim > 1:
        foundation_atomic_energies = foundation_atomic_energies.squeeze()
        if foundation_atomic_energies.ndim == 2:
            foundation_atomic_energies = foundation_atomic_energies[0]
            logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
    return {z: foundation_atomic_energies[z_table_foundation.z_to_index(z)].item() for z in zs}


In [None]:
import ast
import logging

# import numpy as np
# from e3nn import o3

from mace import modules
from mace.tools.finetuning_utils import load_foundations_elements
from mace.tools.scripts_utils import extract_config_mace_model


def configure_model(args, train_loader, atomic_energies, model_foundation=None, heads=None, z_table=None):
    """
    Decide which quantities are reported, then build (and optionally seed) the model.

    ``args`` is updated in place: ``compute_stress``/``error_table`` for
    stress/virial-type losses, ``mean``/``std`` from the scaling choice, and,
    when a MACE/ScaleShiftMACE foundation model is used, ``max_L``,
    ``avg_num_neighbors`` and ``model`` (set to ``"FoundationMACE"``).

    Args:
        args: parsed training arguments.
        train_loader: loader passed to the scaling function when mean/std are missing.
        atomic_energies: atomic energies array stored in the model config.
        model_foundation: optional foundation model whose config and weights are reused.
        heads: list of head names.
        z_table: atomic number table (``.zs`` and ``len()`` are used).

    Returns:
        ``(model, output_args)`` where ``output_args`` maps quantity names to booleans.
    """
    # Selecting outputs: stress/virial-type losses need virials and stress.
    compute_virials = args.loss in ("stress", "virials", "huber", "universal")
    if compute_virials:
        args.compute_stress = True
        args.error_table = "PerAtomRMSEstressvirials"

    output_args = dict(
        energy=args.compute_energy,
        forces=args.compute_forces,
        virials=compute_virials,
        stress=args.compute_stress,
        dipoles=args.compute_dipole,
    )
    logging.info(
        f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}"
    )
    logging.info("===========MODEL DETAILS===========")

    _configure_scaling(args, train_loader, atomic_energies)

    # Build model
    reuse_foundation_config = model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]
    if reuse_foundation_config:
        model_config_foundation = _foundation_model_config(
            args, model_foundation, atomic_energies, heads, z_table
        )
        model_config = model_config_foundation
    else:
        model_config = _fresh_model_config(args, atomic_energies, z_table)
        model_config_foundation = None

    model = _build_model(args, model_config, model_config_foundation, heads)

    if model_foundation is None:
        return model, output_args

    seeded_model = load_foundations_elements(
        model,
        model_foundation,
        z_table,
        load_readout=args.foundation_filter_elements,
        max_L=args.max_L,
    )
    return seeded_model, output_args


def _configure_scaling(args, train_loader, atomic_energies):
    """Set ``args.std`` (and ``args.mean`` when it must be computed) according to ``args.scaling``."""
    if args.scaling == "no_scaling":
        args.std = 1.0
        logging.info("No scaling selected")
        return
    statistics_missing = args.mean is None or args.std is None
    if statistics_missing and args.model != "AtomicDipolesMACE":
        args.mean, args.std = modules.scaling_classes[args.scaling](
            train_loader, atomic_energies
        )


def _foundation_model_config(args, model_foundation, atomic_energies, heads, z_table):
    """Extract the foundation model's config, adapt it to this run, log it and return it (mutates ``args``)."""
    logging.info("Loading FOUNDATION model")
    model_config_foundation = extract_config_mace_model(model_foundation)
    model_config_foundation["atomic_energies"] = atomic_energies
    model_config_foundation["atomic_numbers"] = z_table.zs
    model_config_foundation["num_elements"] = len(z_table)
    args.max_L = model_config_foundation["hidden_irreps"].lmax

    # plain MACE fine-tuned from plain MACE keeps zero shifts
    zero_shift = args.model == "MACE" and model_foundation.__class__.__name__ == "MACE"
    model_config_foundation["atomic_inter_shift"] = (
        [0.0] * len(heads) if zero_shift else _determine_atomic_inter_shift(args.mean, heads)
    )
    model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads)
    args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"]
    args.model = "FoundationMACE"
    model_config_foundation["heads"] = heads

    logging.info("Model configuration extracted from foundation model")
    logging.info("Using universal loss function for fine-tuning")
    logging.info(
        f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})"
    )
    logging.info(
        f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}"
    )
    logging.info(
        f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)"
    )
    logging.info(
        f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}"
    )
    return model_config_foundation


def _fresh_model_config(args, atomic_energies, z_table):
    """Log the architecture requested in ``args`` and return the base config for a new model."""
    logging.info("Building model")
    logging.info(
        f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})"
    )
    logging.info(
        f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}"
    )
    logging.info(
        f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions"
    )
    logging.info(
        f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)"
    )
    logging.info(
        f"Distance transform for radial basis functions: {args.distance_transform}"
    )

    channel_multiplicities = {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}
    assert (
        len(channel_multiplicities) == 1
    ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"

    logging.info(f"Hidden irreps: {args.hidden_irreps}")

    return {
        "r_max": args.r_max,
        "num_bessel": args.num_radial_basis,
        "num_polynomial_cutoff": args.num_cutoff_basis,
        "max_ell": args.max_ell,
        "interaction_cls": modules.interaction_classes[args.interaction],
        "num_interactions": args.num_interactions,
        "num_elements": len(z_table),
        "hidden_irreps": o3.Irreps(args.hidden_irreps),
        "atomic_energies": atomic_energies,
        "avg_num_neighbors": args.avg_num_neighbors,
        "atomic_numbers": z_table.zs,
    }


def _determine_atomic_inter_shift(mean, heads):
    if isinstance(mean, np.ndarray):
        if mean.size == 1:
            return mean.item()
        if mean.size == len(heads):
            return mean.tolist()
        logging.info("Mean not in correct format, using default value of 0.0")
        return [0.0] * len(heads)
    if isinstance(mean, list) and len(mean) == len(heads):
        return mean
    if isinstance(mean, float):
        return [mean] * len(heads)
    logging.info("Mean not in correct format, using default value of 0.0")
    return [0.0] * len(heads)


def _build_model(args, model_config, model_config_foundation, heads):  # pylint: disable=too-many-return-statements
    if args.model == "MACE":
        return ScaleShiftMACE(
            **model_config,
            pair_repulsion=args.pair_repulsion,
            distance_transform=args.distance_transform,
            correlation=args.correlation,
            gate=modules.gate_dict[args.gate],
            interaction_cls_first=modules.interaction_classes[
                "RealAgnosticInteractionBlock"
            ],
            MLP_irreps=o3.Irreps(args.MLP_irreps),
            atomic_inter_scale=args.std,
            atomic_inter_shift=[0.0] * len(heads),
            radial_MLP=ast.literal_eval(args.radial_MLP),
            radial_type=args.radial_type,
            heads=heads,
        )
    if args.model == "ScaleShiftMACE":
        return ScaleShiftMACE(
            **model_config,
            pair_repulsion=args.pair_repulsion,
            distance_transform=args.distance_transform,
            correlation=args.correlation,
            gate=modules.gate_dict[args.gate],
            interaction_cls_first=modules.interaction_classes[args.interaction_first],
            MLP_irreps=o3.Irreps(args.MLP_irreps),
            atomic_inter_scale=args.std,
            atomic_inter_shift=args.mean,
            radial_MLP=ast.literal_eval(args.radial_MLP),
            radial_type=args.radial_type,
            heads=heads,
        )
    if args.model == "FoundationMACE":
        return ScaleShiftMACE(**model_config_foundation)
    
    raise RuntimeError(f"Unknown model: '{args.model}'")
