This tutorial demonstrates how to create a model in BioNeMo.
`Megatron` / `NeMo` modules and datasets are special derivatives of PyTorch modules and datasets that extend and accelerate the distributed training and inference capabilities of PyTorch.

Some key differences when working with Megatron / NeMo are:

- `torch.nn.Module`/`LightningModule` is replaced by `MegatronModule`.
- Loss functions should extend the `MegatronLossReduction` module and implement a `reduce` method for aggregating loss across multiple micro-batches.
- Megatron configuration classes (e.g. `megatron.core.transformer.TransformerConfig`) are extended with a `configure_model` method that defines how model weights are initialized and loaded in a way that is compliant with training via NeMo2.
- Various modifications and extensions to common PyTorch classes, such as adding a `MegatronDataSampler` (and a resampler such as `PRNGResampleDataset` or `MultiEpochDatasetResampler`) to your `LightningDataModule`.

In [None]:
from typing import Dict, Sequence, Tuple
import torch
from nemo.lightning.megatron_parallel import MegatronLossReduction, ReductionT
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from torch import Tensor, nn
from torchvision.datasets import MNIST


Losses: here we define a simple loss function. Loss classes should inherit from `MegatronLossReduction` in `nemo.lightning.megatron_parallel`.
Forward and backward passes run in parallel across micro-batches, so `forward` computes the loss for a single micro-batch. The `reduce` method is required, but it is used only to collect forward outputs for inference and to aggregate losses for logging; its result is not used for backpropagation.

In [None]:
class MSELossReduction(MegatronLossReduction):
    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""

    def forward(self, batch: "MnistItem", forward_out: Dict[str, Tensor]) -> Tuple[Tensor, ReductionT]:
        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

        Args:
            batch: A micro-batch of data, as passed to the model's forward method.
            forward_out: The output of the model's forward method; it must contain an "x_hat" entry.

        Returns:
            A tuple (loss_tensor, ReductionT) where the loss tensor is used for
                backpropagation and the ReductionT is passed to the reduce method
                (which is currently used only for logging).
        """
        x = batch["data"]
        x_hat = forward_out["x_hat"]
        xview = x.view(x.size(0), -1).to(x_hat.dtype)
        loss = nn.functional.mse_loss(x_hat, xview)

        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> Tensor:
        """Aggregates the per-micro-batch outputs of forward (all micro-batches on a single GPU).

        Note: This currently only works for logging and this loss will not be used for backpropagation.

        Args:
            losses_reduced_per_micro_batch: a list of the ReductionT dicts returned by forward, one per micro-batch.

        Returns:
            A tensor that is the mean of the losses. (used for logging).
        """
        mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return mse_losses.mean()
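To see what the `forward`/`reduce` split computes, here is a framework-free sketch in plain Python. It is an illustration under simplifying assumptions, not the NeMo implementation: `mse`, `forward_step`, and `reduce_step` are hypothetical stand-ins for `nn.functional.mse_loss`, `MSELossReduction.forward`, and `MSELossReduction.reduce`, and each micro-batch is a flat list of floats rather than a tensor.

```python
from statistics import fmean
from typing import Dict, List, Sequence, Tuple


def mse(predictions: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error over one flattened micro-batch."""
    return fmean((p - t) ** 2 for p, t in zip(predictions, targets))


def forward_step(predictions: Sequence[float], targets: Sequence[float]) -> Tuple[float, Dict[str, float]]:
    """Stand-in for MSELossReduction.forward: the loss plus a dict handed to reduce."""
    loss = mse(predictions, targets)
    return loss, {"avg": loss}


def reduce_step(per_micro_batch: Sequence[Dict[str, float]]) -> float:
    """Stand-in for MSELossReduction.reduce: mean of the per-micro-batch "avg" values."""
    return fmean(item["avg"] for item in per_micro_batch)


# One global batch of 8 values, processed as 4 micro-batches of 2.
predictions = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
targets = [0.5, 1.0, 2.5, 3.0, 3.0, 5.0, 7.0, 7.0]
micro_batches = [(predictions[i : i + 2], targets[i : i + 2]) for i in range(0, len(predictions), 2)]

reductions: List[Dict[str, float]] = [forward_step(p, t)[1] for p, t in micro_batches]
logged_loss = reduce_step(reductions)
full_batch_loss = mse(predictions, targets)
print(logged_loss, full_batch_loss)  # 0.3125 0.3125
```

Averaging the per-micro-batch means reproduces the full-batch MSE here only because every micro-batch has the same number of elements; with unequal sizes, the mean of means would weight small micro-batches more heavily. Since `reduce` is used only for logging, this would affect the reported value, not the gradients.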



In [None]:
from typing import TypedDict

import torch
from torch import Tensor
from torchvision import transforms
from torchvision.datasets import MNIST


class MnistItem(TypedDict):
    data: Tensor
    label: int  # torchvision's MNIST returns the class label as a Python int
    idx: int


class MNISTCustom(MNIST):
    def __getitem__(self, index: int) -> MnistItem:
        """Wraps the getitem method of the MNIST dataset such that we return a Dict
        instead of a Tuple or tensor.

        Args:
            index: The index we want to grab, an int.

        Returns:
            A dict containing the data ("data"), label ("label"), and index ("idx").
        """  # noqa: D205
        x, y = super().__getitem__(index)

        return {
            "data": x,
            "label": y,
            "idx": index,
        }


data_dir = "./mnist_data"  # hypothetical download directory; point this wherever MNIST should be cached
mnist_full = MNISTCustom(data_dir, download=True, transform=transforms.ToTensor(), train=True)
# Deterministic 55,000 / 5,000 train/validation split of the 60,000 MNIST training images.
mnist_train_set_data, mnist_val_set_data = torch.utils.data.random_split(
    mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
mnist_test_set_data = MNISTCustom(data_dir, download=True, transform=transforms.ToTensor(), train=False)


Datasets used for model training must be compatible with Megatron datasets.
The data module must have a `data_sampler` attribute, which is a NeMo2 peculiarity. Note that this sampler does not shuffle your data! You therefore need to wrap each dataset in a shuffler that maps sequential indices to random indices in the dataset; this is what `PRNGResampleDataset` does. For further information, see docs/user-guide/background/megatron_datasets.md. You can check whether a dataset is compatible with Megatron by running `bionemo.testing.data_utils.assert_dataset_compatible_with_megatron`.
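The idea behind wrapping a dataset in a shuffler can be sketched with the standard library alone. `SeededShuffleView` is a hypothetical class, not part of BioNeMo: it maps each sequential index requested by a sampler to a position in a fixed, seeded permutation, so the order looks random but is reproducible for a given seed. BioNeMo's `PRNGResampleDataset` may be implemented differently (for example, without storing the whole permutation in memory).

```python
import random
from typing import Generic, List, Sequence, TypeVar

T = TypeVar("T")


class SeededShuffleView(Generic[T]):
    """Hypothetical sketch: serves a dataset in a fixed, seeded random order."""

    def __init__(self, dataset: Sequence[T], seed: int = 42) -> None:
        self.dataset = dataset
        # A private Random instance keeps the order independent of the global random state.
        self.permutation: List[int] = list(range(len(dataset)))
        random.Random(seed).shuffle(self.permutation)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> T:
        # Sequential index -> fixed random position in the underlying dataset.
        return self.dataset[self.permutation[index]]


data = list(range(10))
view = SeededShuffleView(data, seed=43)
first_pass = [view[i] for i in range(len(view))]
# Rebuilding the view with the same seed reproduces the same order.
second_pass = [SeededShuffleView(data, seed=43)[i] for i in range(len(data))]
```

Because the permutation depends only on `seed`, rebuilding the view (for example after a restart) yields the same order, which a sequential sampler can then consume deterministically.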


In [None]:
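from bionemo.core.data.resamplers import PRNGResampleDataset

# Wrap each split in a seeded shuffler, since the Megatron data sampler reads indices sequentially.
mnist_test = PRNGResampleDataset(mnist_test_set_data, seed=43)
mnist_train = PRNGResampleDataset(mnist_train_set_data)
mnist_val = PRNGResampleDataset(mnist_val_set_data)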



The data module can now take these in as inputs. 

The data module class must have a `data_sampler` attribute that allows its data loaders to be used with Megatron. A `nemo.lightning.pytorch.plugins.MegatronDataSampler` is the best choice: it sets up micro-batching and gradient accumulation, and it is where the global batch size is defined. It does not shuffle the data, which is why the datasets are wrapped in `PRNGResampleDataset` first.
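The batch-size settings passed to `MegatronDataSampler` are linked by the usual Megatron relation: global batch size = micro-batch size × data-parallel size × number of micro-batches per optimizer step, where the number of micro-batches is the gradient-accumulation count. The helper below is a hypothetical sketch of that arithmetic, not a NeMo API; Megatron performs an equivalent computation internally.

```python
def num_micro_batches(global_batch_size: int, micro_batch_size: int, data_parallel_size: int = 1) -> int:
    """Hypothetical helper: micro-batches (gradient-accumulation steps) per optimizer step.

    Uses global_batch_size = micro_batch_size * data_parallel_size * num_micro_batches.
    """
    samples_per_micro_step = micro_batch_size * data_parallel_size
    if global_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} is not divisible by "
            f"micro_batch_size * data_parallel_size = {samples_per_micro_step}"
        )
    return global_batch_size // samples_per_micro_step


print(num_micro_batches(32, 32))  # single GPU, micro == global batch size -> 1
print(num_micro_batches(256, 8, data_parallel_size=4))  # -> 8 accumulation steps
```

With the `MNISTDataModule` settings (micro-batch and global batch size both equal to `batch_size`) on a single GPU, there is exactly one micro-batch per step, so no gradient accumulation takes place.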

In [None]:
import lightning.pytorch as pl
from bionemo.core.data.resamplers import PRNGResampleDataset
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def __init__(  # noqa: D107
        self,
        mnist_train: PRNGResampleDataset,
        mnist_val: PRNGResampleDataset,
        mnist_test: PRNGResampleDataset,
        batch_size: int = 32,
        max_len: int = 1048,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.micro_batch_size = batch_size
        self.max_len = max_len  # placeholder sequence length: MNIST images are not sequences
        self.mnist_train = mnist_train
        self.mnist_test = mnist_test
        self.mnist_val = mnist_val
        # The MegatronDataSampler makes the DataLoaders below usable with Megatron: it sets up
        # micro-batching and gradient accumulation and defines the global batch size.
        # It reads samples in order and does not shuffle them.
        self.data_sampler = MegatronDataSampler(
            seq_len=self.max_len,
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.batch_size,
            rampup_batch_size=None,
        )

    def train_dataloader(self) -> DataLoader:  # noqa: D102
        return DataLoader(self.mnist_train, batch_size=self.micro_batch_size, num_workers=0)

    def val_dataloader(self) -> DataLoader:  # noqa: D102
        return DataLoader(self.mnist_val, batch_size=self.micro_batch_size, num_workers=0)

    def test_dataloader(self) -> DataLoader:  # noqa: D102
        return DataLoader(self.mnist_test, batch_size=self.micro_batch_size, num_workers=0)


The model config class is used to instantiate the model. These configs must have:
1. A `configure_model` method, which allows the Megatron strategy to lazily initialize the model after the parallel computing environment has been set up. It also handles loading starting weights for fine-tuning.
2. A `get_loss_reduction_class` method that returns the loss class to pair with the model, which tells the trainer which loss to use.
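The three-step order used by `ExampleGenericConfig.configure_model` (load settings, build the model, load weights) can be illustrated without any framework. In this sketch `ToyConfig`, `ToyModel`, and `FAKE_CHECKPOINTS` are hypothetical stand-ins for the real config, model, and checkpoint loading (`load_settings_from_checkpoint` / `update_model_from_checkpoint`). It shows why settings are read before construction: the default config says `hidden_size=2`, but the checkpoint was saved with 4, so building first would create a model that the saved weights cannot fill.

```python
from typing import Dict, List, Optional

# Hypothetical in-memory "checkpoint store": saved settings plus saved weights.
FAKE_CHECKPOINTS: Dict[str, dict] = {
    "ckpt/v1": {"settings": {"hidden_size": 4}, "weights": [0.1, 0.2, 0.3, 0.4]},
}


class ToyModel:
    """Stand-in model whose weight count is fixed by hidden_size at construction."""

    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size
        self.weights: List[float] = [0.0] * hidden_size  # fresh initialization


class ToyConfig:
    """Stand-in config: cheap to create; the model is only built in configure_model."""

    def __init__(self, hidden_size: int = 2, initial_ckpt_path: Optional[str] = None) -> None:
        self.hidden_size = hidden_size
        self.initial_ckpt_path = initial_ckpt_path

    def configure_model(self) -> ToyModel:
        ckpt = FAKE_CHECKPOINTS[self.initial_ckpt_path] if self.initial_ckpt_path else None
        # 1. Load architecture settings first so the model is built with the checkpoint's shapes.
        if ckpt is not None:
            self.hidden_size = ckpt["settings"]["hidden_size"]
        # 2. Only now construct the model (lazily, after any environment setup has happened).
        model = ToyModel(self.hidden_size)
        # 3. Copy the saved weights into the freshly built model.
        if ckpt is not None:
            if len(ckpt["weights"]) != len(model.weights):
                raise ValueError("checkpoint weights do not match the model shape")
            model.weights = list(ckpt["weights"])
        return model


scratch = ToyConfig().configure_model()  # no checkpoint: hidden_size stays 2
finetune = ToyConfig(initial_ckpt_path="ckpt/v1").configure_model()  # hidden_size becomes 4
```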

In [None]:
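from dataclasses import dataclass, field
from typing import Generic, List, Type, TypeVar

from bionemo.llm.model.config import OVERRIDE_BIONEMO_CONFIG_DEFAULTS, MegatronBioNeMoTrainableModelConfig
from nemo.lightning.megatron_parallel import MegatronLossReduction

# typevar for capturing subclasses of ExampleModelTrunk. Useful for Generic type hints as below.
# The bound is a string (forward reference) because ExampleModelTrunk is defined in a later cell.
ExampleModelT = TypeVar("ExampleModelT", bound="ExampleModelTrunk")
# Local typevar for the loss class used by the config.
Loss = TypeVar("Loss", bound=MegatronLossReduction)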


@dataclass
class ExampleGenericConfig(Generic[ExampleModelT, Loss], MegatronBioNeMoTrainableModelConfig[ExampleModelT, Loss]):
    """ExampleConfig is a dataclass that is used to configure the model.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    loss_cls: Type[Loss] = MSELossReduction 
    hidden_size: int = 64  # Needs to be set to avoid zero division error in megatron :(
    num_attention_heads: int = 1  # Needs to be set to avoid zero division error in megatron :(
    num_layers: int = 1  # Needs to be set to avoid zero division error in megatron :(
    # IMPORTANT: Since we're adding/overriding the loss_cls, and that's not how we generally track this, we need to
    #   add this into the list of config settings that we do not draw from the loaded checkpoint when restoring.
    override_parent_fields: List[str] = field(default_factory=lambda: OVERRIDE_BIONEMO_CONFIG_DEFAULTS + ["loss_cls"])

    def configure_model(self) -> ExampleModelT:
        """Uses model_cls and loss_cls to configure the model.

        Note: Must pass self into Model since model requires having a config object.

        Returns:
            The model object.
        """
        # 1. first load any settings that may exist in the checkpoint related to the model.
        if self.initial_ckpt_path:
            self.load_settings_from_checkpoint(self.initial_ckpt_path)
        # 2. then initialize the model
        model = self.model_cls(self)
        # 3. Load weights from the checkpoint into the model
        if self.initial_ckpt_path:
            self.update_model_from_checkpoint(model, self.initial_ckpt_path)
        return model

    def get_loss_reduction_class(self) -> Type[Loss]:
        """Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config."""
        return self.loss_cls


Models need to be Megatron modules. At the most basic level this just means:
  1. They take a `config` argument of type `megatron.core.ModelParallelConfig`. An easy way to provide one is to define the model's config class as a subclass of `bionemo.llm.model.config.MegatronBioNeMoTrainableModelConfig`, a BioNeMo class that supports Megatron models as NeMo2 requires and that also inherits from `ModelParallelConfig`.
  2. They define a `self.model_type` attribute holding a `megatron.core.transformer.enums.ModelType` enum value (`ModelType.encoder_or_decoder` is usually fine).
  3. They implement `set_input_tensor(self, input_tensor)`. This is used in model parallelism; for a simple model it can be a stub or placeholder.
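The three requirements above can be written down as a structural type. The sketch below uses a hypothetical `typing.Protocol` named `MegatronModuleLike` to check that a class exposes `config`, `model_type`, and `set_input_tensor`; it only checks that the names exist and is not how Megatron itself validates modules. A plain string stands in for the `ModelType` enum so the example runs without Megatron installed.

```python
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class MegatronModuleLike(Protocol):
    """Hypothetical structural check mirroring the three requirements above."""

    config: Any  # 1. the config object passed to the constructor
    model_type: Any  # 2. a ModelType value such as ModelType.encoder_or_decoder

    def set_input_tensor(self, input_tensor: Optional[Any]) -> None:  # 3. the hook
        ...


class StubModel:
    """Satisfies the checklist; a plain string stands in for the ModelType enum."""

    def __init__(self, config: Any) -> None:
        self.config = config
        self.model_type = "encoder_or_decoder"

    def set_input_tensor(self, input_tensor: Optional[Any]) -> None:
        pass  # a stub is sufficient for a simple model, as noted above


class MissingHook:
    """Has config and model_type but no set_input_tensor."""

    def __init__(self) -> None:
        self.config = None
        self.model_type = None


print(isinstance(StubModel(config={}), MegatronModuleLike))  # True
print(isinstance(MissingHook(), MegatronModuleLike))  # False
```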


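In [None]:
from typing import Optional

import torch
from megatron.core import ModelParallelConfig
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.module import MegatronModule
from torch import Tensor, nn


class ExampleModelTrunk(MegatronModule):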
    def __init__(self, config: ModelParallelConfig) -> None:
        """Constructor of the model.

        Args:
            config: The config object is responsible for telling the strategy what model to create.
        """
        super().__init__(config)
        # FIXME add an assertion that the user is not trying to do tensor parallelism since this doesn't use
        #  parallelizable megatron linear layers.
        self.model_type: ModelType = ModelType.encoder_or_decoder
        self.linear1 = nn.Linear(28 * 28, 64)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(64, 3)

    def forward(self, x: torch.Tensor) -> Tensor:
        # we could return a dictionary of strings to tensors here, but let's demonstrate this is not necessary
        x = x.view(x.size(0), -1)
        z = self.linear1(x)
        z = self.relu(z)
        z = self.linear2(z)
        return z

    def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None:
        """This _would_ be needed for model parallel and other kinds of more complicated forward passes in megatron."""
        pass

