This tutorial demonstrates how to create a model in BioNeMo.
Models must be Megatron modules. At the most basic level, this means:
  1. They need a config argument of type megatron.core.ModelParallelConfig. An easy way to satisfy this is to inherit from bionemo.llm.model.config.MegatronBioNeMoTrainableModelConfig, a BioNeMo class that supports use with Megatron models, as NeMo2 requires, and that itself inherits from ModelParallelConfig.
  2. They need to define self.model_type as a megatron.core.transformer.enums.ModelType enum value (ModelType.encoder_or_decoder is usually fine).
  3. They need to define set_input_tensor(self, input_tensor), which is used in model parallelism. This method can be a stub or placeholder.
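The three requirements above can be illustrated without any framework installed. The sketch below is only a shape reference: `ModelType`, `StandInConfig`, and `SketchModel` are hypothetical stand-ins defined here, not the Megatron or BioNeMo classes, and the `hidden_size` field is an illustrative assumption. A real model would additionally subclass a Megatron module and take a real ModelParallelConfig subclass.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Optional


class ModelType(Enum):
    """Stand-in for megatron.core.transformer.enums.ModelType (assumption)."""

    encoder_or_decoder = auto()


@dataclass
class StandInConfig:
    """Stand-in for a ModelParallelConfig subclass; the field is illustrative."""

    hidden_size: int = 64


class SketchModel:
    """Framework-free sketch of the three-part contract described above."""

    def __init__(self, config: StandInConfig) -> None:
        # Requirement 1: the model receives a config object.
        self.config = config
        # Requirement 2: the model exposes a model_type enum value.
        self.model_type = ModelType.encoder_or_decoder
        self.input_tensor: Optional[Any] = None

    def set_input_tensor(self, input_tensor: Any) -> None:
        # Requirement 3: the method must exist; it may be a stub.
        # Here it simply stores whatever it is given.
        self.input_tensor = input_tensor


model = SketchModel(StandInConfig())
model.set_input_tensor(None)
```

Storing the tensor rather than ignoring it is a design choice for this sketch; the text only requires that the method be present.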


In [None]:
from typing import TypedDict
from nemo.lightning.megatron_parallel import MegatronLossReduction
from torch import Tensor
from torchvision.datasets import MNIST


Losses: here we define a simple loss function. Loss classes should inherit from the losses in nemo.lightning.megatron_parallel.
The outputs of the forward and backward passes are produced in parallel.
The reduce function is required. It is used only to collect forward output for inference and for logging.

In [None]:
class MSELossReduction(MegatronLossReduction):
    pass
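To make the role of the reduce function more concrete, here is a minimal, framework-free sketch of a mean-squared-error loss with a reduce step. It assumes a two-method shape: a `forward` that returns a loss together with a per-microbatch record, and a `reduce` that combines those records for logging. The class name `MSELossSketch`, the `"data"` batch key, and the `"avg"` record key are assumptions made for illustration. The class deliberately does not subclass MegatronLossReduction and uses plain Python floats instead of tensors.

```python
from typing import Dict, List, Sequence, Tuple


class MSELossSketch:
    """Hypothetical stand-in showing the forward/reduce split with plain floats."""

    def forward(
        self, batch: Dict[str, Sequence[float]], forward_out: Sequence[float]
    ) -> Tuple[float, Dict[str, float]]:
        # Compare the model output with the target stored in the batch.
        target = batch["data"]
        squared = [(p - t) ** 2 for p, t in zip(forward_out, target)]
        loss = sum(squared) / len(squared)
        # Return the loss plus a small record that reduce() can aggregate.
        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: List[Dict[str, float]]) -> float:
        # Average the per-microbatch losses, e.g. for logging.
        values = [entry["avg"] for entry in losses_reduced_per_micro_batch]
        return sum(values) / len(values)


loss_fn = MSELossSketch()
_, record_a = loss_fn.forward({"data": [0.0, 0.0]}, [1.0, 1.0])
_, record_b = loss_fn.forward({"data": [0.0, 0.0]}, [3.0, 3.0])
mean_loss = loss_fn.reduce([record_a, record_b])
```

In this toy run, the two microbatch losses are 1.0 and 9.0, so the reduced value is 5.0.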

Datasets used for model training must be compatible with Megatron datasets.
The dataset modules must have a data_sampler attribute, which is a NeMo2 peculiarity. Also, the sampler will not shuffle your data, so you need to wrap your dataset in a dataset shuffler that maps sequential IDs to random IDs in your dataset. For further information, see docs/user-guide/background/megatron_datasets.md.
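The idea of mapping sequential IDs to random IDs can be sketched in a few lines. The `ShuffledView` class below is a hypothetical illustration, not a BioNeMo or NeMo class: it precomputes a seeded permutation so that a sampler walking indices 0, 1, 2, ... in order still visits the underlying records in a random but reproducible order.

```python
import random
from typing import Any, List, Sequence


class ShuffledView:
    """Hypothetical wrapper: maps sequential indices to a fixed random permutation."""

    def __init__(self, dataset: Sequence[Any], seed: int = 0) -> None:
        self.dataset = dataset
        # Build the index permutation once, using a private seeded generator
        # so that the global random state is left untouched.
        self.permutation: List[int] = list(range(len(dataset)))
        random.Random(seed).shuffle(self.permutation)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Any:
        # Sequential index in, randomly permuted record out.
        return self.dataset[self.permutation[index]]


records = [f"sample-{i}" for i in range(10)]
shuffled = ShuffledView(records, seed=42)
visited = [shuffled[i] for i in range(len(shuffled))]
```

Because the permutation depends only on the seed and the dataset length, every record is visited exactly once per pass and the order can be reproduced across runs.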


In [None]:
class MnistItem(TypedDict):
    data: Tensor
    label: Tensor
    idx: int
    
class MNISTCustom(MNIST):
    def __getitem__(self, index: int) -> MnistItem:
        """Wraps the getitem method of the MNIST dataset such that we return a Dict
        instead of a Tuple or tensor.

        Args:
            index: The index we want to grab, an int.

        Returns:
            A dict containing the data ("data"), label ("label"), and index ("idx").
        """  # noqa: D205
        x, y = super().__getitem__(index)

        return {
            "data": x,
            "label": y,
            "idx": index,
        }
