In [None]:
#!pip install "ray[tune]"

In [1]:
import os
import tempfile

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from filelock import FileLock

In [3]:
from torchvision import datasets, transforms

import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import AsyncHyperBandScheduler

In [4]:
# Sample budgets consumed by each Tune iteration (tweak to trade speed for signal).
EPOCH_SIZE = 5  # training samples per iteration; with batch size 64 this is a single batch
TEST_SIZE = 256  # test samples scored when computing "mean_accuracy"


class ConvNet(nn.Module):
    """Tiny CNN mapping single-channel 28x28 images to 10-class log-probabilities.

    Shape walk-through for an (N, 1, 28, 28) input:
    conv1 (3x3, no padding) -> (N, 3, 26, 26); max-pool 3 -> (N, 3, 8, 8);
    flatten -> (N, 192); fc -> (N, 10); log-softmax over dim 1.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        # Flatten everything except the batch dimension. The previous
        # view(-1, 192) would silently reinterpret a wrongly sized input as a
        # different batch size; flatten(1) makes fc raise a shape error instead.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


def train_func(model, optimizer, train_loader, device=None, epoch_size=None):
    """Run one truncated training pass over ``train_loader``.

    Batches are consumed until ``epoch_size`` samples have been trained on
    (checked before each batch), which keeps each Tune iteration short.

    Args:
        model: network whose output is fed to ``F.nll_loss`` (so it should
            return log-probabilities); switched to training mode.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(data, target)`` batches.
        device: device each batch is moved to; defaults to CPU.
        epoch_size: sample budget for this call; ``None`` uses the
            module-level ``EPOCH_SIZE``.
    """
    device = device or torch.device("cpu")
    limit = EPOCH_SIZE if epoch_size is None else epoch_size
    model.train()
    seen = 0
    for data, target in train_loader:
        # Count samples actually trained on. The old test
        # `batch_idx * len(data) > EPOCH_SIZE` ran one batch past the budget
        # and miscounted when the final batch was smaller.
        if seen >= limit:
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        seen += len(data)


def test_func(model, data_loader, device=None):
    device = device or torch.device("cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total


def get_data_loaders(batch_size=64, data_dir="~/data"):
    """Build shuffled MNIST train and test DataLoaders.

    Images are converted to tensors and normalised with mean 0.1307 and
    std 0.3081. ``download=True`` is passed, so the dataset is fetched into
    ``data_dir`` if it is not already there.

    Note: the test loader is also shuffled, so callers that score only a prefix
    of it (such as ``test_func``) see a different random subset on each call.

    Args:
        batch_size: samples per batch for both loaders.
        data_dir: dataset root; ``~`` is expanded. The lock file lives next to
            it as ``<data_dir>.lock``.

    Returns:
        ``(train_loader, test_loader)``.
    """
    data_dir = os.path.normpath(os.path.expanduser(data_dir))
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    def _make_loader(train):
        # Build the MNIST split (downloading if needed) and wrap it in a loader.
        dataset = datasets.MNIST(
            data_dir, train=train, download=True, transform=mnist_transforms
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )

    # Several trial processes may construct the dataset at the same time, and
    # each could try to download into the same directory. The cross-process
    # file lock serialises dataset construction so downloads don't overwrite
    # each other.
    with FileLock(data_dir + ".lock"):
        train_loader = _make_loader(train=True)
        test_loader = _make_loader(train=False)
    return train_loader, test_loader


def train_mnist(config):
    """Ray Tune trainable: train ConvNet on MNIST with SGD and report accuracy each iteration.

    Reads ``config["lr"]`` and ``config["momentum"]`` for the optimizer and the
    optional boolean ``config["should_checkpoint"]`` (default False). The loop
    never returns on its own; it relies on Tune's stop criteria or scheduler to
    end the trial.
    """
    checkpointing_enabled = config.get("should_checkpoint", False)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    train_loader, test_loader = get_data_loaders()
    model = ConvNet().to(device)

    optimizer = optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
    )

    while True:
        train_func(model, optimizer, train_loader, device)
        metrics = {"mean_accuracy": test_func(model, test_loader, device)}

        if not checkpointing_enabled:
            train.report(metrics)
            continue

        # Write the weights to a throwaway directory and report it as a checkpoint
        # together with the metrics; the directory is removed when the block exits.
        with tempfile.TemporaryDirectory() as ckpt_dir:
            torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pt"))
            train.report(metrics, checkpoint=Checkpoint.from_directory(ckpt_dir))


In [5]:
# ignore_reinit_error lets this cell be re-run in the same kernel; a plain
# second ray.init() call raises instead.
ray.init(num_cpus=2, ignore_reinit_error=True)

# AsyncHyperBand stops under-performing trials early.
sched = AsyncHyperBandScheduler()

# Each trial requests both CPUs Ray was started with, so trials run one at a time.
resources_per_trial = {"cpu": 2, "gpu": 0}
tuner = tune.Tuner(
    tune.with_resources(train_mnist, resources=resources_per_trial),
    tune_config=tune.TuneConfig(
        metric="mean_accuracy",
        mode="max",
        scheduler=sched,
        num_samples=10,
    ),
    run_config=train.RunConfig(
        name="exp",
        stop={
            "mean_accuracy": 0.98,
            "training_iteration": 100,
        },
    ),
    param_space={
        "lr": tune.loguniform(1e-4, 1e-2),
        "momentum": tune.uniform(0.1, 0.9),
    },
)
results = tuner.fit()

# Surface trial failures before reporting a "best" config drawn from a broken run.
assert not results.errors

print("Best config is:", results.get_best_result().config)

0,1
Current time:,2025-01-22 10:41:26
Running for:,00:01:11.85
Memory:,3.8/15.5 GiB

Trial name,status,loc,lr,momentum,acc,iter,total time (s)
train_mnist_e1a7f_00000,TERMINATED,172.25.227.84:514227,0.00107739,0.237296,0.296875,100,13.4715
train_mnist_e1a7f_00001,TERMINATED,172.25.227.84:514419,0.00011328,0.682774,0.078125,1,1.29499
train_mnist_e1a7f_00002,TERMINATED,172.25.227.84:514535,0.00177864,0.679697,0.0875,1,1.28485
train_mnist_e1a7f_00003,TERMINATED,172.25.227.84:514676,0.00290934,0.780965,0.096875,1,1.40882
train_mnist_e1a7f_00004,TERMINATED,172.25.227.84:514797,0.000144907,0.857826,0.09375,16,3.29882
train_mnist_e1a7f_00005,TERMINATED,172.25.227.84:514938,0.00485494,0.549558,0.05625,1,1.18666
train_mnist_e1a7f_00006,TERMINATED,172.25.227.84:515070,0.000136043,0.139247,0.046875,1,1.20644
train_mnist_e1a7f_00007,TERMINATED,172.25.227.84:515188,0.000137921,0.405633,0.09375,1,1.33575
train_mnist_e1a7f_00008,TERMINATED,172.25.227.84:515320,0.000553241,0.241452,0.040625,1,1.23723
train_mnist_e1a7f_00009,TERMINATED,172.25.227.84:515458,0.00121003,0.74033,0.090625,1,1.44194


[36m(train_mnist pid=514227)[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
[36m(train_mnist pid=514227)[0m Failed to download (trying next):
[36m(train_mnist pid=514227)[0m <urlopen error [Errno 111] Connection refused>
[36m(train_mnist pid=514227)[0m 
[36m(train_mnist pid=514227)[0m Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
[36m(train_mnist pid=514227)[0m Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /home/jdowling/data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]
  1%|          | 98304/9912422 [00:00<00:18, 525633.44it/s]
  3%|▎         | 262144/9912422 [00:00<00:13, 725215.30it/s]
 14%|█▍        | 1376256/9912422 [00:00<00:02, 3391858.34it/s]
 24%|██▍       | 2424832/9912422 [00:00<00:01, 5511123.75it/s]
 34%|███▎      | 3342336/9912422 [00:00<00:01, 4729231.36it/s]
 68%|██████▊   | 6717440/9912422 [00:01<00:00, 11551889.22it/s]
 83%|████████▎ | 8192000/9912422 [00:01<00:00, 12074231.61it/s]
100%|██████████| 9912422/9912422 [00:01<00:00, 7514320.11it/s] 


[36m(train_mnist pid=514227)[0m Extracting /home/jdowling/data/MNIST/raw/train-images-idx3-ubyte.gz to /home/jdowling/data/MNIST/raw
[36m(train_mnist pid=514227)[0m 
[36m(train_mnist pid=514227)[0m Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
[36m(train_mnist pid=514227)[0m Failed to download (trying next):
[36m(train_mnist pid=514227)[0m <urlopen error [Errno 111] Connection refused>
[36m(train_mnist pid=514227)[0m 
[36m(train_mnist pid=514227)[0m Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
[36m(train_mnist pid=514227)[0m Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /home/jdowling/data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]
100%|██████████| 28881/28881 [00:00<00:00, 296365.86it/s]


[36m(train_mnist pid=514227)[0m Extracting /home/jdowling/data/MNIST/raw/train-labels-idx1-ubyte.gz to /home/jdowling/data/MNIST/raw
[36m(train_mnist pid=514227)[0m 
[36m(train_mnist pid=514227)[0m Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
[36m(train_mnist pid=514227)[0m Failed to download (trying next):
[36m(train_mnist pid=514227)[0m <urlopen error [Errno 111] Connection refused>
[36m(train_mnist pid=514227)[0m 
[36m(train_mnist pid=514227)[0m Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
[36m(train_mnist pid=514227)[0m Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /home/jdowling/data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]
  4%|▍         | 65536/1648877 [00:00<00:04, 341275.85it/s]
 16%|█▌        | 262144/1648877 [00:00<00:01, 733870.73it/s]
 32%|███▏      | 524288/1648877 [00:00<00:00, 1296843.71it/s]
 50%|████▉     | 819200/1648877 [00:00<00:00, 1783927.93it/s]
100%|██████████| 1648877/1648877 [00:00<00:00, 2383116.20it/s]


[36m(train_mnist pid=514227)[0m Extracting /home/jdowling/data/MNIST/raw/t10k-images-idx3-ubyte.gz to /home/jdowling/data/MNIST/raw
[36m(train_mnist pid=514227)[0m 
[36m(train_mnist pid=514227)[0m Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
[36m(train_mnist pid=514227)[0m Failed to download (trying next):
[36m(train_mnist pid=514227)[0m <urlopen error [Errno 111] Connection refused>
[36m(train_mnist pid=514227)[0m 
[36m(train_mnist pid=514227)[0m Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
[36m(train_mnist pid=514227)[0m Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /home/jdowling/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3881525.83it/s]


[36m(train_mnist pid=514227)[0m Extracting /home/jdowling/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/jdowling/data/MNIST/raw
[36m(train_mnist pid=514227)[0m 


2025-01-22 10:41:26,559	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/jdowling/ray_results/exp' in 0.0135s.
2025-01-22 10:41:26,572	INFO tune.py:1041 -- Total run time: 72.21 seconds (71.83 seconds for the tuning loop).


Best config is: {'lr': 0.0010773901567539795, 'momentum': 0.2372958260270055}
