In [16]:
import os
import tempfile

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.models import resnet18

from ray import train, tune
from ray.tune.schedulers import ASHAScheduler

import ray
import ray.train.torch
from ray.train import Checkpoint

In [17]:
# Absolute, normalized path to the shared dataset directory, one level above the
# notebook's working directory. It is resolved once here in the driver because
# Ray Tune may run each trial in its own working directory, where a relative
# path inside `train_mnist` would point somewhere else.
data_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, "data"))

def train_func(model, optimizer, criterion, train_loader, device=None):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: ``nn.Module`` to train; it is switched to train mode.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss callable invoked as ``criterion(output, target)``.
        train_loader: iterable of ``(data, target)`` tensor batches
            (e.g. a ``DataLoader``).
        device: device each batch is moved to. Defaults to the device of the
            model's parameters, so batches always land next to the model even
            when CUDA is available but the model was left on the CPU. A model
            with no parameters falls back to CUDA-if-available, else CPU.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()


def test_func(model, data_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total

def train_mnist(config):
    """Train a single-channel ResNet-18 on FashionMNIST and report accuracy.

    Written as a Ray Tune function trainable; the notebook also calls it
    directly with a fixed config.

    Args:
        config: hyperparameter dict.
            Required keys:
              - ``"lr"``: SGD learning rate.
              - ``"momentum"``: SGD momentum.
            Optional keys (defaults reproduce the original hard-coded values):
              - ``"epochs"`` (default 10): number of training epochs.
              - ``"batch_size"`` (default 128): batch size for both loaders.
              - ``"checkpoint_every"`` (default 5): attach a model checkpoint to
                the report every this many epochs; ``0``/``None`` disables it.

    Side effects:
        Loads FashionMNIST from ``data_dir`` (``download=True``). Prints the
        accuracy after each epoch. Calls ``train.report({"mean_accuracy": acc})``
        once per epoch.
    """
    epochs = config.get("epochs", 10)
    batch_size = config.get("batch_size", 128)
    checkpoint_every = config.get("checkpoint_every", 5)

    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         torchvision.transforms.Normalize((0.5,), (0.5,))]
    )

    # NOTE(review): several Tune trials starting at once may race on the first
    # download into the shared data_dir; consider pre-downloading in the driver.
    train_set = torchvision.datasets.FashionMNIST(
        root=data_dir, train=True, download=True, transform=transform)
    test_set = torchvision.datasets.FashionMNIST(
        root=data_dir, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Accuracy does not depend on evaluation order, so the test set is not shuffled.
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = resnet18(num_classes=10)
    # FashionMNIST images have 1 channel; replace ResNet-18's stock 3-channel stem.
    model.conv1 = nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"])

    for epoch in range(epochs):
        train_func(model, optimizer, criterion, train_loader)
        acc = test_func(model, test_loader)
        print(f"epoch: {epoch}, acc: {acc}")

        metrics = {"mean_accuracy": acc}
        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            # Ray persists the checkpoint during train.report, so the temporary
            # directory must stay alive until that call returns.
            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                torch.save(
                    model.state_dict(),
                    os.path.join(temp_checkpoint_dir, "model.pth")
                )
                train.report(
                    metrics,
                    checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
                )
        else:
            train.report(metrics)

In [18]:
# search_space = {
#     "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),
#     "momentum": tune.uniform(0.1, 0.9),
# }

# # Uncomment this to enable distributed execution
# # `ray.init(address="auto")`
# trainable_with_gpu = tune.with_resources(train_mnist, {"gpu": 1})

# tuner = tune.Tuner(
#     trainable_with_gpu,
#     param_space=search_space,
# )
# results = tuner.fit()

In [19]:
# One baseline run with fixed hyperparameters, called directly rather than through a Tuner.
baseline_config = {"lr": 0.001, "momentum": 0.5}
train_mnist(config=baseline_config)

epoch: 0, acc: 0.8123
epoch: 1, acc: 0.8381
epoch: 2, acc: 0.8509
epoch: 3, acc: 0.8576
epoch: 4, acc: 0.8594
epoch: 5, acc: 0.8609
epoch: 6, acc: 0.8643
epoch: 7, acc: 0.865
epoch: 8, acc: 0.8702
epoch: 9, acc: 0.8661


In [20]:
# tuner = tune.Tuner(
#     train_mnist,
#     tune_config=tune.TuneConfig(
#         num_samples=20,
#         scheduler=ASHAScheduler(metric="mean_accuracy", mode="max"),
#     ),
#     param_space=search_space,
# )
# results = tuner.fit()

# # Collect each trial's metrics dataframe from this `tuner.fit()` run, keyed by trial path.
# dfs = {result.path: result.metrics_dataframe for result in results}

In [21]:
# # Plot by epoch
# ax = None  # This plots everything on the same plot
# for d in dfs.values():
#     ax = d.mean_accuracy.plot(ax=ax, legend=False)

In [22]:
# import os
# import tempfile

# import torch
# from torch.nn import CrossEntropyLoss
# from torch.optim import Adam
# from torch.utils.data import DataLoader
# from torchvision.models import resnet18
# from torchvision.datasets import FashionMNIST
# from torchvision.transforms import ToTensor, Normalize, Compose

# # Model, Loss, Optimizer
# model = resnet18(num_classes=10)
# model.conv1 = torch.nn.Conv2d(
#     1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
# )
# model.to("cuda")
# criterion = CrossEntropyLoss()
# optimizer = Adam(model.parameters(), lr=0.001)

# # Data
# transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
# train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
# train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

# # Training
# for epoch in range(10):
#     for images, labels in train_loader:
#         print(images.shape)
#         print(labels.shape)
#         images, labels = images.to("cuda"), labels.to("cuda")
#         outputs = model(images)
#         loss = criterion(outputs, labels)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#     metrics = {"loss": loss.item(), "epoch": epoch}
#     checkpoint_dir = tempfile.mkdtemp()
#     checkpoint_path = os.path.join(checkpoint_dir, "model.pt")
#     torch.save(model.state_dict(), checkpoint_path)
#     print(metrics)