# Install

In [None]:
pip install torch

In [None]:
pip install torchvision

# Imports and define NN

In [1]:
import ray
from ray.train.torch import TorchTrainer
from torchvision import datasets, transforms
import os
import tempfile
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.layer1 = nn.Linear(1, 32)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(32, 1)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))

# Define Training Loop

In [2]:
# Training loop.
def train_loop_per_worker(config):

    # Read configurations.
    lr = config["lr"]
    batch_size = config["batch_size"]
    num_epochs = config["num_epochs"]

    # Fetch training dataset.
    train_dataset_shard = ray.train.get_dataset_shard("train")

    # Instantiate and prepare model for training.
    model = NeuralNetwork()
    model = ray.train.torch.prepare_model(model)

    # Define loss and optimizer.
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Create data loader.
    dataloader = train_dataset_shard.iter_torch_batches(
        batch_size=batch_size, dtypes=torch.float
    )

    # Train multiple epochs.
    for epoch in range(num_epochs):

        # Train epoch.
        for batch in dataloader:
            output = model(batch["input"])
            loss = loss_fn(output, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Create checkpoint.
        base_model = (model.module
            if isinstance(model, DistributedDataParallel) else model)
        checkpoint_dir = tempfile.mkdtemp()
        torch.save(
            {"model_state_dict": base_model.state_dict()},
            os.path.join(checkpoint_dir, "model.pt"),
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        # Report metrics and checkpoint.
        ray.train.report({"loss": loss.item()}, checkpoint=checkpoint)

# Initializing the Cluster

In [3]:
# Initialize Ray
ray.init()

2024-07-11 23:15:52,313	INFO worker.py:1743 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


0,1
Python version:,3.10.14
Ray version:,2.10.0
Dashboard:,http://127.0.0.1:8265


[36m(RayTrainWorker pid=2027)[0m Setting up process group for: env:// [rank=0, world_size=4]
[36m(TorchTrainer pid=2010)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=2010)[0m - (ip=127.0.0.1, pid=2027) world_rank=0, local_rank=0, node_rank=0
[36m(TorchTrainer pid=2010)[0m - (ip=127.0.0.1, pid=2028) world_rank=1, local_rank=1, node_rank=0
[36m(TorchTrainer pid=2010)[0m - (ip=127.0.0.1, pid=2029) world_rank=2, local_rank=2, node_rank=0
[36m(TorchTrainer pid=2010)[0m - (ip=127.0.0.1, pid=2030) world_rank=3, local_rank=3, node_rank=0
[36m(RayTrainWorker pid=2027)[0m Moving model to device: cpu
[36m(RayTrainWorker pid=2027)[0m Wrapping provided model in DistributedDataParallel.
[36m(SplitCoordinator pid=2037)[0m I0000 00:00:1720754211.920469 4292714 config.cc:230] gRPC experiments enabled: call_status_override_on_cancellation, event_engine_dns, event_engine_listener, http2_stats_fix, monitoring_experiment, pick_first_new, trace_record_callops, work_seria

In [5]:
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig, FailureConfig
from ray.train.torch import TorchTrainer

max_failures = 1
num_workers = 4
use_gpu = False
train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1), failure_config=FailureConfig(max_failures))

# Define datasets.
train_dataset = ray.data.from_items(
    [{"input": [x], "label": [2 * x + 1]} for x in range(2000)]
)
datasets = {"train": train_dataset}

# Initialize the Trainer.
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=datasets,
)

# Running Training Job

In [6]:
# Start training
result = trainer.fit()
last_checkpoint = result.checkpoint

0,1
Current time:,2024-07-11 23:16:56
Running for:,00:00:10.88
Memory:,22.8/32.0 GiB

Trial name,status,loc,iter,total time (s),loss
TorchTrainer_2b370_00000,TERMINATED,127.0.0.1:2010,20,7.39531,5712590000000.0


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).


(pid=2037) - split(4, equal=True) 1:   0%|                                                                    …

(pid=2037) Running 0:   0%|                                                                                   …

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).
You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).
2024-07-11 23:16:56,490	INFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/Users/farceo/ray_results/TorchTrainer_2024-07-11_23-16-45' in 0.0042s.
2024-07-11 23:16:56,494	INFO tune.py:1048 -- Total run time: 10.89 seconds (10.87 seconds for the tuning loop).


In [8]:
result.metrics

{'loss': 5712591192064.0,
 'timestamp': 1720754215,
 'checkpoint_dir_name': 'checkpoint_000019',
 'should_checkpoint': True,
 'done': True,
 'training_iteration': 20,
 'trial_id': '2b370_00000',
 'date': '2024-07-11_23-16-55',
 'time_this_iter_s': 0.17356491088867188,
 'time_total_s': 7.395307779312134,
 'pid': 2010,
 'hostname': 'farceo-mac',
 'node_ip': '127.0.0.1',
 'config': {'train_loop_config': {'num_epochs': 20,
   'lr': 0.01,
   'batch_size': 32}},
 'time_since_restore': 7.395307779312134,
 'iterations_since_restore': 20,
 'experiment_tag': '0'}

# Scoring new data

In [15]:
def score_model(checkpoint_path, data_points):
    model_state_dict = torch.load(checkpoint_path)["model_state_dict"]
    model = NeuralNetwork()
    model.load_state_dict(model_state_dict)
    model.eval()  # Set the model to evaluation mode
    
    with torch.no_grad():  # Disable gradient calculation for inference
        inputs = torch.tensor([data_points], dtype=torch.float).reshape(-1, 1)
        predictions = model(inputs)
        return predictions.numpy()

In [28]:
checkpoint_path = open(os.path.join(last_checkpoint.path, "model.pt"), 'rb')

predictions = score_model(checkpoint_path, train_dataset.to_pandas()['input'].values)

  inputs = torch.tensor([data_points], dtype=torch.float).reshape(-1, 1)


In [29]:
predictions

array([[-2338410.],
       [-2338410.],
       [-2338410.],
       ...,
       [-2338410.],
       [-2338410.],
       [-2338410.]], dtype=float32)

# End