In [17]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

from ray import tune
from ray import train
from ray.tune.schedulers import ASHAScheduler

In [14]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # In this example, we don't change the model architecture
        # due to simplicity.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

In [15]:
def train_mnist(config):
    # Data Setup
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    train_loader = DataLoader(
        datasets.MNIST("~/data", train=True, download=True, transform=mnist_transforms),
        batch_size=64,
        shuffle=True)
    
    #train_loader = train.torch.prepare_data_loader(train_loader)
    test_loader = DataLoader(
        datasets.MNIST("~/data", train=False, transform=mnist_transforms),
        batch_size=64,
        shuffle=True)
    
    #test_loader = train.torch.prepare_data_loader(test_loader)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ConvNet()
    #model = train.torch.prepare_model(model)
    model.to(device)
    
    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], 
        momentum=config["momentum"], 
        weight_decay=config['lambda']
    )
    
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)

        # Send the current training result back to Tune
        tune.report(mean_accuracy=acc)

        if i % 5 == 0:
            # This saves the model to the trial directory
            torch.save(model.state_dict(), "./model.pth")


In [16]:
search_space = {
    "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),
    "momentum": tune.uniform(0.1, 0.9),
    "lambda": tune.grid_search([0.1, 0.01, 0.001])
}

# Uncomment this to enable distributed execution
# `ray.init(address="auto")`

# Download the dataset first
datasets.MNIST("~/data", train=True, download=True)

analysis = tune.run(train_mnist, config=search_space)



Trial name,status,loc,lr,momentum
train_mnist_0cb5f_00000,RUNNING,127.0.0.1:9421,2.4136e-10,0.500739


[2m[36m(train_mnist pid=9421)[0m 2022-08-15 05:37:45,578	ERROR function_runner.py:286 -- Runner Thread raised error.
[2m[36m(train_mnist pid=9421)[0m Traceback (most recent call last):
[2m[36m(train_mnist pid=9421)[0m   File "/Users/gaominquan/anaconda3/lib/python3.7/site-packages/ray/tune/function_runner.py", line 277, in run
[2m[36m(train_mnist pid=9421)[0m     self._entrypoint()
[2m[36m(train_mnist pid=9421)[0m   File "/Users/gaominquan/anaconda3/lib/python3.7/site-packages/ray/tune/function_runner.py", line 352, in entrypoint
[2m[36m(train_mnist pid=9421)[0m     self._status_reporter.get_checkpoint(),
[2m[36m(train_mnist pid=9421)[0m   File "/Users/gaominquan/anaconda3/lib/python3.7/site-packages/ray/util/tracing/tracing_helper.py", line 462, in _resume_span
[2m[36m(train_mnist pid=9421)[0m     return method(self, *_args, **_kwargs)
[2m[36m(train_mnist pid=9421)[0m   File "/Users/gaominquan/anaconda3/lib/python3.7/site-packages/ray/tune/function_runner.py

Result for train_mnist_0cb5f_00000:
  date: 2022-08-15_05-37-45
  experiment_id: 0c344cf527d74a50a6ba68aefd1d8a36
  hostname: Minquans-MacBook-Pro.local
  node_ip: 127.0.0.1
  pid: 9421
  timestamp: 1660567065
  trial_id: 0cb5f_00000
  


Trial name,status,loc,lr,momentum
train_mnist_0cb5f_00000,ERROR,127.0.0.1:9421,2.4136e-10,0.500739

Trial name,# failures,error file
train_mnist_0cb5f_00000,1,"/Users/gaominquan/ray_results/train_mnist_2022-08-15_05-37-39/train_mnist_0cb5f_00000_0_lr=0.0000,momentum=0.5007_2022-08-15_05-37-39/error.txt"


Trial name,status,loc,lr,momentum
train_mnist_0cb5f_00000,ERROR,127.0.0.1:9421,2.4136e-10,0.500739

Trial name,# failures,error file
train_mnist_0cb5f_00000,1,"/Users/gaominquan/ray_results/train_mnist_2022-08-15_05-37-39/train_mnist_0cb5f_00000_0_lr=0.0000,momentum=0.5007_2022-08-15_05-37-39/error.txt"


TuneError: ('Trials did not complete', [train_mnist_0cb5f_00000])