This tutorial demonstrates the creation of a model in BioNeMo.
`Megatron` / `NeMo` modules and datasets are special derivatives of PyTorch modules and datasets that extend and accelerate the distributed training and inference capabilities of PyTorch.

Some distinctions of Megatron / NeMo are:

- `torch.nn.Module`/`LightningModule` is replaced by `MegatronModule`.
- Loss functions should extend the `MegatronLossReduction` module and implement a `reduce` method for aggregating loss across multiple micro-batches.
- Megatron configuration classes (e.g. `megatron.core.transformer.TransformerConfig`) are extended with a `configure_model` method that defines how model weights are initialized and loaded in a way that is compliant with training via NeMo2.
- Various modifications and extensions to common PyTorch classes, such as adding a `MegatronDataSampler` (and a re-sampler such as `PRNGResampleDataset` or `MultiEpochDatasetResampler`) to your `LightningDataModule`.

In [None]:
from nemo.lightning.megatron_parallel import MegatronLossReduction
from torchvision.datasets import MNIST
from nemo.lightning.pytorch.plugins import MegatronDataSampler


Losses: here we define a simple loss function. Loss classes should inherit from the losses in `nemo.lightning.megatron_parallel`, such as `MegatronLossReduction`.
The outputs of the forward and backward passes are produced in parallel.
The `reduce` method is required. It is used only for collecting forward output for inference and for logging.

In [None]:
class MSELossReduction(MegatronLossReduction):
    pass
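
As a dependency-free illustration of what a `reduce` step does, the sketch below averages per-micro-batch losses. `MicroBatchLossReducer` is a hypothetical stand-in written for this example, not the NeMo `MegatronLossReduction` class, and plain floats stand in for tensors.

```python
from statistics import fmean
from typing import Sequence


class MicroBatchLossReducer:
    """Hypothetical stand-in that mirrors the forward/reduce split of a loss reduction."""

    def forward(self, predictions: Sequence[float], targets: Sequence[float]) -> float:
        # Mean squared error for a single micro-batch.
        return fmean((p - t) ** 2 for p, t in zip(predictions, targets))

    def reduce(self, micro_batch_losses: Sequence[float]) -> float:
        # Aggregate the per-micro-batch losses into one value, e.g. for logging.
        return fmean(micro_batch_losses)


reducer = MicroBatchLossReducer()
losses = [
    reducer.forward([1.0, 2.0], [1.0, 4.0]),  # squared errors 0 and 4
    reducer.forward([0.0, 0.0], [0.0, 0.0]),  # perfect prediction
]
print(reducer.reduce(losses))
```

Averaging the per-micro-batch means matches the mean over all samples only when every micro-batch has the same size, which is the assumption made here.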

In [None]:
from typing import TypedDict

import torch
from torch import Tensor
from torchvision import transforms


class MnistItem(TypedDict):
    data: Tensor
    label: Tensor
    idx: int


class MNISTCustom(MNIST):
    def __getitem__(self, index: int) -> MnistItem:
        """Wraps the getitem method of the MNIST dataset such that we return a Dict
        instead of a Tuple or tensor.

        Args:
            index: The index we want to grab, an int.

        Returns:
            A dict containing the data ("data"), label ("label"), and index ("idx").
        """  # noqa: D205
        x, y = super().__getitem__(index)

        return {
            "data": x,
            "label": y,
            "idx": index,
        }


# Assumption: any writable directory works as the download location.
data_dir = "./"
mnist_full = MNISTCustom(data_dir, download=True, transform=transforms.ToTensor(), train=True)
# Split the 60,000 training images into training and validation subsets with a fixed seed.
mnist_train_set_data, mnist_val_set_data = torch.utils.data.random_split(
    mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
mnist_test_set_data = MNISTCustom(data_dir, download=True, transform=transforms.ToTensor(), train=False)


Datasets used for model training must be compatible with Megatron datasets.
A data module must have a `data_sampler` attribute, which is a NeMo2 requirement. The sampler does not shuffle your data, so you need to wrap your dataset in a shuffler that maps sequential ids to random ids in your dataset. This is what `PRNGResampleDataset` does. For further information, see docs/user-guide/background/megatron_datasets.md. The compatibility of a dataset with Megatron can be checked by running `bionemo.testing.data_utils.assert_dataset_compatible_with_megatron`.


In [None]:
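from bionemo.core.data.resamplers import PRNGResampleDataset

# Wrap each split so that sequential indices map to pseudo-random items.
mnist_test = PRNGResampleDataset(mnist_test_set_data)
mnist_train = PRNGResampleDataset(mnist_train_set_data)
mnist_val = PRNGResampleDataset(mnist_val_set_data)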
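
To make the idea of mapping sequential ids to random ids concrete, here is a minimal, dependency-free sketch. `SeededShuffleView` is a hypothetical class written for this illustration. It precomputes a seeded permutation, which is a simplification and not a description of how `PRNGResampleDataset` is implemented; the real class may differ, for example by generating indices lazily.

```python
import random
from typing import Sequence, TypeVar

T = TypeVar("T")


class SeededShuffleView:
    """Hypothetical stand-in: maps sequential indices to a fixed pseudo-random order."""

    def __init__(self, dataset: Sequence[T], seed: int = 42) -> None:
        self.dataset = dataset
        # Precompute one seeded permutation of the valid indices.
        self.order = list(range(len(dataset)))
        random.Random(seed).shuffle(self.order)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> T:
        # A sequential request is redirected to a pseudo-random position.
        return self.dataset[self.order[index]]


data = ["a", "b", "c", "d", "e"]
view = SeededShuffleView(data, seed=0)
print([view[i] for i in range(len(view))])
```

Because the permutation depends only on the seed, a sequential sampler reading `view[0]`, `view[1]`, and so on sees the same shuffled order on every run.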



The data module can now take these in as inputs. 

In the data module class, a `data_sampler` attribute must be set. A `MegatronDataSampler` is the best choice. Note that this sampler is sequential and does not shuffle, which is why the datasets are wrapped in `PRNGResampleDataset` before being passed in.

In [None]:
import pytorch_lightning as pl  # or `lightning.pytorch`, depending on the installed NeMo version
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def __init__(
        self,
        mnist_train: PRNGResampleDataset,
        mnist_val: PRNGResampleDataset,
        mnist_test: PRNGResampleDataset,
        max_len: int,  # sequence length handed to the sampler; the original never defines it
        batch_size: int = 32,
    ) -> None:  # noqa: D107
        super().__init__()
        self.batch_size = batch_size
        self.max_len = max_len
        self.mnist_train = mnist_train
        self.mnist_test = mnist_test
        self.mnist_val = mnist_val
        # The MegatronDataSampler is sequential and does no shuffling, so the datasets passed in
        # should already be wrapped in a shuffler such as PRNGResampleDataset.
        # The sampler sets up micro-batching and gradient accumulation, and it is also the place
        # where the global batch size is constructed.
        self.data_sampler = MegatronDataSampler(
            seq_len=self.max_len,
            micro_batch_size=self.batch_size,
            global_batch_size=self.batch_size,
            rampup_batch_size=None,
        )

    def train_dataloader(self) -> DataLoader:  # noqa: D102
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=0)

    def val_dataloader(self) -> DataLoader:  # noqa: D102
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=0)

    def test_dataloader(self) -> DataLoader:  # noqa: D102
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=0)
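
The relationship between micro-batching, gradient accumulation, and the global batch size can be sketched with a little arithmetic. This assumes the usual Megatron convention that the global batch size equals the micro-batch size times the data-parallel size times the number of accumulated micro-batches. `num_micro_batches` is a hypothetical helper for this illustration, not part of `MegatronDataSampler`.

```python
def num_micro_batches(global_batch_size: int, micro_batch_size: int, data_parallel_size: int = 1) -> int:
    """Micro-batches each data-parallel rank accumulates before one optimizer step."""
    samples_per_pass = micro_batch_size * data_parallel_size
    if global_batch_size % samples_per_pass != 0:
        raise ValueError("global_batch_size must be divisible by micro_batch_size * data_parallel_size")
    return global_batch_size // samples_per_pass


# Same micro and global batch size on one device: no gradient accumulation.
print(num_micro_batches(global_batch_size=32, micro_batch_size=32))  # 1
# Two data-parallel ranks, each accumulating two micro-batches of 32 samples.
print(num_micro_batches(global_batch_size=128, micro_batch_size=32, data_parallel_size=2))  # 2
```

In the data module above, `micro_batch_size` and `global_batch_size` are both set to `batch_size`, which corresponds to the first case on a single device.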


Models need to be Megatron modules. At the most basic level this just means:
  1. They need a config argument of type `megatron.core.ModelParallelConfig`. An easy way to satisfy this is to inherit from `bionemo.llm.model.config.MegatronBioNeMoTrainableModelConfig`, a BioNeMo class that supports usage with Megatron models, as NeMo2 requires. This class also inherits from `ModelParallelConfig`.
  2. They need a `self.model_type` attribute of type `megatron.core.transformer.enums.ModelType` (`ModelType.encoder_or_decoder` is usually fine).
  3. They need a `set_input_tensor(self, input_tensor)` method. It is used in model parallelism and can be a stub or placeholder.
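
The three requirements above can be sketched structurally as follows. Everything here is a hypothetical stand-in so that the snippet runs without Megatron installed: `ModelType` mimics `megatron.core.transformer.enums.ModelType`, `ExampleConfig` takes the place of a `ModelParallelConfig` subclass, and `ExampleModel` would subclass `MegatronModule` in real code. The `hidden_size` field is purely illustrative.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional


class ModelType(Enum):
    # Stand-in for megatron.core.transformer.enums.ModelType.
    encoder_or_decoder = 1


@dataclass
class ExampleConfig:
    # Stand-in for a ModelParallelConfig subclass; the field is illustrative only.
    hidden_size: int = 64


class ExampleModel:
    def __init__(self, config: ExampleConfig) -> None:
        # Requirement 1: the model receives a config object.
        self.config = config
        # Requirement 2: the model declares its model type.
        self.model_type = ModelType.encoder_or_decoder
        self.input_tensor: Optional[Any] = None

    def set_input_tensor(self, input_tensor: Any) -> None:
        # Requirement 3: the hook exists; here it only stores the tensor.
        self.input_tensor = input_tensor


model = ExampleModel(ExampleConfig())
model.set_input_tensor(None)
print(model.model_type)
```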
