In [None]:

# tag::load_cifar[]
from torchvision import transforms, datasets


def load_cifar(train: bool):
    """Return a zero-argument factory that builds one CIFAR-10 split.

    The dataset itself is not created here; it is created each time the
    returned callable is invoked.

    Args:
        train: Forwarded to ``datasets.CIFAR10`` as its ``train`` argument
            to select the split.

    Returns:
        A callable taking no arguments that returns a ``datasets.CIFAR10``
        rooted at ``./data`` (downloading it if necessary) with the
        tensor-conversion and normalization transform attached.
    """
    # Convert images to tensors, then normalize each of the three channels
    # with 0.5 / 0.5 (passed to Normalize as its two positional tuples).
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    pipeline = transforms.Compose([to_tensor, normalize])  # <1>

    cifar_kwargs = dict(
        root="./data",
        download=True,
        train=train,  # <3>
        transform=pipeline,
    )

    def make_split():
        return datasets.CIFAR10(**cifar_kwargs)  # <2>

    return make_split
# end::load_cifar[]

# tag::std[]
from ray.data import read_datasource, datasource


# Create one SimpleTorchDatasource and read both CIFAR-10 splits through it;
# each read receives the zero-argument factory returned by load_cifar.
source = datasource.SimpleTorchDatasource()  # <1>
train_dataset = read_datasource(source, dataset_factory=load_cifar(train=True))  # <2>
test_dataset = read_datasource(source, dataset_factory=load_cifar(train=False))
# end::std[]


# tag::batch_conversion[]
import pandas as pd
from ray.data.extensions import TensorArray


def convert_to_pandas(batch):
    """Turn a batch of ``(image, label)`` pairs into a two-column DataFrame.

    Args:
        batch: Iterable of ``(image, label)`` pairs where each image exposes
            a ``.numpy()`` method.

    Returns:
        A ``pd.DataFrame`` with an ``"image"`` column holding the images as
        a ``TensorArray`` and a ``"label"`` column holding the labels.
    """
    images = []
    labels = []
    for image, label in batch:
        images.append(image.numpy())
        labels.append(label)

    columns = {
        "image": TensorArray(images),  # <1>
        "label": labels,  # <2>
    }
    return pd.DataFrame(columns)


# Run convert_to_pandas over the batches of both datasets and rebind the
# names to the converted datasets.
train_dataset = train_dataset.map_batches(convert_to_pandas)  # <3>
test_dataset = test_dataset.map_batches(convert_to_pandas)
# end::batch_conversion[]


# tag::torch_model[]
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small convolutional classifier producing 10 scores per input image.

    Two convolution + max-pool stages are followed by three fully connected
    layers. The first linear layer expects ``16 * 5 * 5`` features, which is
    what the two stages produce for 3-channel 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 5x5 convolutions sharing one 2x2 max-pool.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Compute the unnormalized class scores for a batch of images."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        # Keep the batch dimension, collapse everything else.
        hidden = torch.flatten(features, 1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))

        # No activation on the final layer: raw scores are returned.
        return self.fc3(hidden)
# end::torch_model[]

# tag::torch_training_loop[]
from ray import train
from ray.air import session, Checkpoint


def train_loop(config):
    """Per-worker training function: fit ``Net`` on the "train" dataset shard.

    After every epoch the loss accumulated since the last progress print is
    reported through ``session.report`` together with a checkpoint holding
    the model's ``state_dict`` under the ``"model"`` key.

    Args:
        config: Mapping with the keys ``"batch_size"`` (batch size requested
            from the dataset shard) and ``"epochs"`` (number of passes over
            the shard).
    """
    model = train.torch.prepare_model(Net())  # <1>
    loss_fct = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    train_shard = session.get_dataset_shard("train")  # <2>
    log_every = 1000  # number of batches between progress prints

    for epoch in range(config["epochs"]):
        running_loss = 0.0

        # Ask for a fresh batch iterator every epoch. An iterator created once
        # outside this loop can be exhausted after the first pass, which would
        # leave every later epoch with no batches to train on.
        train_batches = train_shard.iter_torch_batches(
            batch_size=config["batch_size"],
        )

        for i, data in enumerate(train_batches):
            inputs, labels = data["image"], data["label"]  # <3>

            optimizer.zero_grad()  # <4>
            forward_outputs = model(inputs)
            loss = loss_fct(forward_outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()  # <5>
            # Print once a full window of `log_every` batches has been
            # accumulated, so the printed value is a true per-batch average
            # (a check against 0 would fire on the very first batch).
            if i % log_every == log_every - 1:
                print(f"[{epoch + 1}, {i + 1:4d}] loss: {running_loss / log_every:.3f}")
                running_loss = 0.0

        # The prepared model only exposes `.module` when it has been wrapped
        # for distributed training; fall back to the model itself otherwise
        # so a single-worker run can still checkpoint.
        unwrapped_model = getattr(model, "module", model)
        session.report(  # <6>
            dict(running_loss=running_loss),
            checkpoint=Checkpoint.from_dict(dict(model=unwrapped_model.state_dict())),
        )
# end::torch_training_loop[]


# tag::torch_trainer[]
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, RunConfig
from ray.air.callbacks.mlflow import MLflowLoggerCallback


# Configure the training run: train_loop executes on two workers over the
# "train" dataset with batch size 10 for 5 epochs, and an MLflow logger
# callback is attached under the experiment name "torch_trainer".
trainer = TorchTrainer(
    train_loop_per_worker=train_loop,
    train_loop_config=dict(batch_size=10, epochs=5),
    datasets=dict(train=train_dataset),
    scaling_config=ScalingConfig(num_workers=2),
    run_config=RunConfig(
        callbacks=[MLflowLoggerCallback(experiment_name="torch_trainer")],
    ),
)
# Run the training and keep the returned result object.
result = trainer.fit()
# end::torch_trainer[]

# tag::store_checkpoint[]
# Directory name handed to the checkpoint's to_directory call below, which
# exports the checkpoint attached to the training result.
CHECKPOINT_PATH = "torch_checkpoint"
result.checkpoint.to_directory(CHECKPOINT_PATH)
# end::store_checkpoint[]

# Deliberately terminate here with an explanatory message: the snippets below
# contain `...` placeholders and are kept for illustration only, so they are
# never reached when this file is run as a script.
import sys
sys.exit("End of executable script. Stopping smoke tests here.")

# tag::custom_data[]
from ray.data import read_datasource, datasource


class SnowflakeDatasource(datasource.Datasource):
    """Placeholder for a custom datasource.

    Subclasses ``datasource.Datasource`` without overriding anything yet.
    """
    pass


# Illustrative only: `...` stands in for the read arguments, which are left
# unspecified here.
dataset = read_datasource(SnowflakeDatasource(), ...)
# end::custom_data[]


# tag::custom_trainer[]
from ray.train.data_parallel_trainer import DataParallelTrainer


class JaxTrainer(DataParallelTrainer):
    """Placeholder for a custom trainer.

    Subclasses ``DataParallelTrainer`` without overriding anything yet.
    """
    pass


# Illustrative only: the `...` placeholders mark arguments left unspecified.
# The dataset is registered under the "train" key.
trainer = JaxTrainer(
    ...,
    scaling_config=ScalingConfig(...),
    datasets=dict(train=dataset),
)
# end::custom_trainer[]


# tag::custom_tuner[]
from ray.tune import logger, tuner
from ray.air.config import RunConfig, ScalingConfig

class NeptuneCallback(logger.LoggerCallback):
    """Placeholder for a custom logger callback.

    Subclasses ``logger.LoggerCallback`` without overriding anything yet.
    """
    pass


# Illustrative only: wrap the trainer in a Tuner whose run configuration
# registers the custom logger callback (`...` marks unspecified arguments).
# NOTE(review): this assignment rebinds `tuner` from the imported module to
# the Tuner instance, so the module is no longer reachable under that name.
tuner = tuner.Tuner(
    trainer,
    run_config=RunConfig(callbacks=[NeptuneCallback(...)])
)
# end::custom_tuner[]