<img src="https://cdn.comet.ml/img/notebook_logo.png">

[Comet](https://www.comet.com/site/products/ml-experiment-tracking/?utm_campaign=ray_train&utm_medium=colab) is an MLOps platform designed to help data scientists and teams build better models faster! Comet provides tooling to track, explain, manage, and monitor your models in a single place! It works with Jupyter Notebooks and scripts, and, most importantly, it's 100% free to get started!

[Ray Train](https://docs.ray.io/en/latest/train/train.html) abstracts away the complexity of setting up a distributed training system.

Instrument your runs with Comet to start managing experiments, creating dataset versions, and tracking hyperparameters for faster and easier reproducibility and collaboration.

[Find more information about our integration with Ray Train](https://www.comet.ml/docs/v2/integrations/ml-frameworks/ray/)

Get a preview of what's to come. Check out a completed experiment created from this notebook [here](https://www.comet.com/examples/comet-example-ray-train-keras/99d169308c854be7ac222c995a2bfa26?experiment-tab=systemMetrics).

This example is based on the [following Ray Train Lightning example](https://docs.ray.io/en/latest/train/getting-started-pytorch-lightning.html).

# Install Dependencies

In [None]:
%pip install "comet_ml>=3.49.0" "ray[air]>=2.1.0" "lightning" "torchvision"

# Initialize Comet

In [None]:
import comet_ml

# Log in to Comet once, before any training code in the cells below runs.
comet_ml.login()

# Import Dependencies

In [None]:
import os
import tempfile

import comet_ml.integration.ray
from comet_ml.integration.ray import comet_worker

import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
import lightning.pytorch as pl

import ray.train.lightning
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, RunConfig

# Prepare your model

In [None]:
# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    """ResNet-18 classifier for single-channel (grayscale) images with 10 classes.

    The stock torchvision ResNet-18 stem is replaced with a 1-input-channel
    convolution so the network accepts grayscale inputs (e.g. FashionMNIST,
    as used by ``train_func`` in this notebook).
    """

    def __init__(self):
        super().__init__()
        self.model = resnet18(num_classes=10)
        # The default ResNet-18 stem takes 3-channel input; swap it for a
        # 1-channel convolution with otherwise identical hyperparameters.
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the raw class logits produced by the wrapped ResNet."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log, and return the cross-entropy loss for one batch.

        Args:
            batch: An ``(inputs, targets)`` pair as yielded by the DataLoader.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The scalar loss tensor Lightning will back-propagate.
        """
        x, y = batch
        # Call the module (not ``forward`` directly) so nn.Module hooks run.
        outputs = self(x)
        loss = self.criterion(outputs, y)
        self.log("lightning_loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer (lr=0.001) over the ResNet parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=0.001)

# Define your distributed training function

This function is distributed to, and executed on, each Ray Train worker.

In [None]:
@comet_worker
def train_func(config):
    """Per-worker training loop run by Ray Train.

    Builds the FashionMNIST training loader, an ``ImageClassifier`` and a
    Lightning trainer that logs to Comet and uses Ray's DDP strategy, then fits
    the model.

    Args:
        config: The ``train_loop_config`` dict handed to ``TorchTrainer``;
            only ``config["epochs"]`` is read here.
    """
    from lightning.pytorch.loggers import CometLogger

    ray_pl = ray.train.lightning

    # --- Data -----------------------------------------------------------
    preprocess = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    dataset_root = os.path.join(tempfile.gettempdir(), "data")
    fashion_mnist = FashionMNIST(
        root=dataset_root, train=True, download=True, transform=preprocess
    )
    loader = DataLoader(fashion_mnist, batch_size=128, shuffle=True)

    # --- Model & logging ------------------------------------------------
    classifier = ImageClassifier()

    comet_pl_logger = CometLogger()
    # Temporary workaround, can be removed once
    # https://github.com/Lightning-AI/pytorch-lightning/pull/20275 has been
    # merged and released: point the Lightning logger at the currently
    # running Comet experiment (presumably the one opened by @comet_worker).
    comet_pl_logger._experiment = comet_ml.get_running_experiment()

    # --- Trainer --------------------------------------------------------
    trainer_options = dict(
        max_epochs=config["epochs"],
        devices="auto",
        accelerator="auto",
        strategy=ray_pl.RayDDPStrategy(),
        plugins=[ray_pl.RayLightningEnvironment()],
        callbacks=[ray_pl.RayTrainReportCallback()],
        logger=comet_pl_logger,
        # Lightning's default checkpointing is turned off; reporting is left
        # to RayTrainReportCallback instead.
        enable_checkpointing=False,
        log_every_n_steps=2,
    )
    lightning_trainer = ray_pl.prepare_trainer(pl.Trainer(**trainer_options))
    lightning_trainer.fit(classifier, train_dataloaders=loader)

# Define the function that schedules the distributed job

In [None]:
def train(
    num_workers: int = 2,
    use_gpu: bool = False,
    epochs: int = 1,
    project_name: str = "comet-example-ray-train-pytorch-lightning",
):
    """Launch distributed training of ``train_func`` on a Ray ``TorchTrainer``.

    Args:
        num_workers: Number of Ray Train workers to start.
        use_gpu: Whether each worker is allocated a GPU.
        epochs: Number of training epochs, forwarded to each worker via
            ``train_loop_config``.
        project_name: Comet project the run is logged under.

    Returns:
        The result object returned by ``TorchTrainer.fit()``, so callers can
        inspect the final metrics, checkpoint, or error.
    """
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
    # train_func in this notebook only reads "epochs"; "use_gpu" is passed
    # along for reference.
    config = {"use_gpu": use_gpu, "epochs": epochs}

    ray_trainer = TorchTrainer(
        train_func,
        scaling_config=scaling_config,
        train_loop_config=config,
    )
    # Register Comet logging on the trainer before it starts running.
    comet_ml.integration.ray.comet_ray_train_logger(
        ray_trainer, project_name=project_name
    )
    return ray_trainer.fit()

# Train the model

Ray waits indefinitely if we request more workers (`num_workers`) than the available resources can support. The code below ensures we never request more CPUs than are available locally: it keeps one core in reserve and always starts at least one worker.

In [None]:
# Request at most `ideal_num_workers`, capped at the local CPU count minus one
# (one core is kept in reserve, presumably for the driver), and never fewer
# than a single worker.
ideal_num_workers = 2

# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 1 so the arithmetic below does not raise a TypeError.
available_local_cpu_count = (os.cpu_count() or 1) - 1
num_workers = max(1, min(ideal_num_workers, available_local_cpu_count))

train(num_workers, use_gpu=False, epochs=3)

In [None]:
# Final cell: end Comet logging for this notebook session.
comet_ml.end()