In [5]:
import os

from pathlib import Path
from datetime import datetime
from functools import partial

import torch

from torchvision import datasets, transforms

from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.utils.data import Subset

from torchmetrics import Accuracy, Precision, Recall

from source.resnet import TorchModel, ResNet, BasicBlock, Bottleneck
from source.callback import CompositeCallback, ClassificationReporter, Profiler, Saver, Tuner, DefaultCallback
from source.plotting import matplotlib_imshow


from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [6]:
# Kernel working directory; the dataset and the Ray results folder live beneath it.
cwd = Path.cwd()

# ImageFolder-style training data: one sub-directory per class.
train_dir = cwd.joinpath("imagenette2-320", "train")

# Training pipeline: 224x224 centre crop, random horizontal flip, tensor
# conversion, then per-channel normalisation with the usual ImageNet statistics.
tsfm_train = transforms.Compose(
    [
        transforms.CenterCrop(size=(224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Not consumed by the tuning cell (to_tune builds its own dataset from
# train_dir), but constructing it here surfaces a missing or malformed
# dataset directory before any Ray workers are started.
trainset = datasets.ImageFolder(root=train_dir, transform=tsfm_train)

In [7]:
# Ray Tune search space, sampled once per trial.
# tune.randint's upper bound is exclusive: layer1/layer2/layer4 get 2 or 3
# blocks per ResNet stage, layer3 gets 4..15. The learning rate is drawn
# log-uniformly from [1e-4, 1e-1].
config = dict(
    layer1=tune.randint(2, 4),
    layer2=tune.randint(2, 4),
    layer3=tune.randint(4, 16),
    layer4=tune.randint(2, 4),
    lr=tune.loguniform(1e-4, 1e-1),
    batch_size=tune.choice([2, 4, 8, 16, 32, 64]),
)

def to_tune(config, checkpoint_dir=None, data_dir=None, epochs=100, split_seed=0):
    """Ray Tune trainable: build and train one ResNet variant sampled from ``config``.

    The ImageFolder dataset under ``data_dir`` is split 80/20 into training and
    validation subsets. Training is delegated to ``TorchModel.train`` with a
    ``Tuner`` callback (presumably what reports "loss"/"accuracy" back to Ray —
    TODO confirm in source.callback).

    Args:
        config: Sampled hyperparameters. Reads ``"layer1"``..``"layer4"``
            (blocks per ResNet stage), ``"lr"`` (Adam learning rate) and
            ``"batch_size"`` (training batch size).
        checkpoint_dir: Part of Ray's function-trainable signature. Not used:
            every trial starts from freshly initialised weights.
        data_dir: Root of an ImageFolder dataset (one sub-directory per class).
            Required; bind it with ``functools.partial``.
        epochs: Epochs passed to ``TorchModel.train``. Keep in sync with the
            scheduler's ``max_t``.
        split_seed: Seed for the train/validation permutation. Fixed by default
            so every trial is scored on the same validation images.

    Raises:
        ValueError: If ``data_dir`` is None.
    """
    if data_dir is None:
        raise ValueError("to_tune needs data_dir; bind it with functools.partial")

    # Validation must be deterministic: reuse the training pipeline minus its
    # random horizontal flip (add any other random augmentations here too).
    tsfm_eval = transforms.Compose(
        [t for t in tsfm_train.transforms if not isinstance(t, transforms.RandomHorizontalFlip)]
    )
    # Two views of the same folder. ImageFolder orders samples deterministically,
    # so index i refers to the same file in both views.
    train_view = datasets.ImageFolder(root=data_dir, transform=tsfm_train)
    eval_view = datasets.ImageFolder(root=data_dir, transform=tsfm_eval)
    classes = train_view.classes
    num_classes = len(classes)

    # Seeded 80/20 split: all trials train/validate on identical subsets, so the
    # losses the scheduler compares are measured on the same data.
    card = int(len(train_view) * 0.8)
    order = torch.randperm(
        len(train_view), generator=torch.Generator().manual_seed(split_seed)
    ).tolist()
    trainset = Subset(train_view, order[:card])
    valset = Subset(eval_view, order[card:])

    # Use the sampled batch size (previously hard-coded to 32, which made the
    # batch_size search dimension a no-op).
    batch_size = config["batch_size"]
    trainloader = DataLoader(
        dataset=trainset,
        batch_size=batch_size,
        shuffle=True,
        # A trailing batch of exactly one sample breaks BatchNorm in training
        # mode (assumes source.resnet uses BatchNorm — TODO confirm); drop only
        # in that case so no other data is discarded.
        drop_last=len(trainset) % batch_size == 1,
    )
    # Evaluation needs no shuffling.
    valloader = DataLoader(dataset=valset, batch_size=16, shuffle=False)

    # Presumably read downstream (TorchModel / callbacks) — TODO confirm.
    trainloader.classes, valloader.classes = classes, classes

    model = ResNet(
        block_cls=BasicBlock,
        layers=[config["layer1"], config["layer2"], config["layer3"], config["layer4"]],
        num_classes=num_classes,
    )
    # reduction="sum": per-batch loss is summed, so its scale grows with
    # batch_size. NOTE(review): check how TorchModel normalises the reported
    # loss before comparing trials with different batch sizes.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    callback = Tuner()

    torchmodel = TorchModel(model, optimizer, criterion, callback=callback)
    torchmodel.train(trainloader, valloader, epochs=epochs)

# Asynchronous Successive Halving: trials are ranked on the reported "loss"
# (lower is better), run at least 1 and at most 100 iterations, and with
# reduction_factor=2 roughly the better half survives each rung (1, 2, 4, ..., 64).
scheduler = ASHAScheduler(metric="loss", mode="min", max_t=100, grace_period=1, reduction_factor=2)

# Metric columns for the console progress table (shown beside the sampled hyperparameters).
reporter = CLIReporter(metric_columns=["loss", "accuracy", "training_iteration"])

# Launch 20 sampled configurations. Each trial asks for one CPU and, when CUDA
# is available, one GPU; trial logs are written under <cwd>/ray_result.
# NOTE(review): in the logged run 8 trials ran concurrently on 8 CPUs and node
# memory climbed to 13.3/16 GiB before a trial errored (cause not confirmed).
# If trials die, raise the per-trial "cpu" request to lower concurrency.
result = tune.run(
    partial(to_tune, data_dir=train_dir),
    config=config,
    num_samples=20,
    scheduler=scheduler,
    progress_reporter=reporter,
    resources_per_trial={
        "cpu": 1,
        "gpu": int(torch.cuda.is_available()),
    },
    local_dir=cwd.joinpath("ray_result"),
)


== Status ==
Current time: 2022-06-21 04:56:22 (running for 00:00:00.70)
Memory usage on this node: 7.7/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 64.000: None | Iter 32.000: None | Iter 16.000: None | Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 1.0/8 CPUs, 0/0 GPUs, 0.0/6.96 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/pavelkiselev/PycharmProjects/pythonProject4/ray_result/to_tune_2022-06-21_04-56-21
Number of trials: 16/20 (15 PENDING, 1 RUNNING)
+---------------------+----------+-----------------+--------------+----------+----------+----------+----------+-------------+
| Trial name          | status   | loc             |   batch_size |   layer1 |   layer2 |   layer3 |   layer4 |          lr |
|---------------------+----------+-----------------+--------------+----------+----------+----------+----------+-------------|
| to_tune_59710_00000 | RUNNING  | 127.0.0.1:33743 |           32 |        3 |        2 |      

[2m[36m(func pid=33743)[0m E0621 04:56:24.528817000 6158610432 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers


[2m[36m(func pid=33743)[0m Batch loop:   0%|          | 0/2 [00:00<?, ?batch/s]


[2m[36m(func pid=33753)[0m E0621 04:56:27.321573000 12901707776 fork_posix.cc:76]                 Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(func pid=33751)[0m E0621 04:56:27.306417000 6251950080 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(func pid=33754)[0m E0621 04:56:27.371380000 6265761792 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers


[2m[36m(func pid=33753)[0m Epoch loop:   0%|          | 0/100 [00:00<?, ?epoch/s]
[2m[36m(func pid=33753)[0m Batch loop:   0%|          | 0/2 [00:00<?, ?batch/s]
[2m[36m(func pid=33754)[0m Epoch loop:   0%|          | 0/100 [00:00<?, ?epoch/s]
[2m[36m(func pid=33754)[0m Batch loop:   0%|          | 0/2 [00:00<?, ?batch/s]
[2m[36m(func pid=33751)[0m Epoch loop:   0%|          | 0/100 [00:00<?, ?epoch/s]
[2m[36m(func pid=33751)[0m Batch loop:   0%|          | 0/2 [00:00<?, ?batch/s]
[2m[36m(func pid=33752)[0m Epoch loop:   0%|          | 0/100 [00:00<?, ?epoch/s]
[2m[36m(func pid=33752)[0m Batch loop:   0%|          | 0/2 [00:00<?, ?batch/s]
[2m[36m(func pid=33748)[0m Epoch loop:   0%|          | 0/100 [00:00<?, ?epoch/s]
[2m[36m(func pid=33748)[0m Batch loop:   0%|          | 0/2 [00:00<?, ?batch/s]


[2m[36m(func pid=33748)[0m E0621 04:56:27.523975000 6164344832 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(func pid=33752)[0m E0621 04:56:27.491097000 6220263424 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(func pid=33750)[0m E0621 04:56:27.562372000 6320353280 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(func pid=33749)[0m E0621 04:56:27.589824000 6228832256 fork_posix.cc:76]                  Other threads are currently calling into gRPC, skipping fork() handlers


[2m[36m(func pid=33750)[0m Epoch loop:   0%|          | 0/100 [00:00<?, ?epoch/s]
[2m[36m(func pid=33750)[0m Batch loop:   0%|          | 0/2 [00:00<?, ?batch/s]
[2m[36m(func pid=33749)[0m Epoch loop:   0%|          | 0/100 [00:00<?, ?epoch/s]
[2m[36m(func pid=33749)[0m Batch loop:   0%|          | 0/2 [00:00<?, ?batch/s]
== Status ==
Current time: 2022-06-21 04:56:29 (running for 00:00:07.59)
Memory usage on this node: 11.2/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 64.000: None | Iter 32.000: None | Iter 16.000: None | Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 8.0/8 CPUs, 0/0 GPUs, 0.0/6.96 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/pavelkiselev/PycharmProjects/pythonProject4/ray_result/to_tune_2022-06-21_04-56-21
Number of trials: 20/20 (12 PENDING, 8 RUNNING)
+---------------------+----------+-----------------+--------------+----------+----------+----------+----------+-------------+
| Trial na

2022-06-21 04:56:32,997	ERROR trial_runner.py:886 -- Trial to_tune_59710_00006: Error processing event.
NoneType: None


Result for to_tune_59710_00006:
  date: 2022-06-21_04-56-27
  experiment_id: e75f63737bf04f9298d28fa5741fde6f
  hostname: MacBook-Air-Pavel.local
  node_ip: 127.0.0.1
  pid: 33753
  timestamp: 1655776587
  trial_id: '59710_00006'
  
== Status ==
Current time: 2022-06-21 04:56:33 (running for 00:00:11.09)
Memory usage on this node: 13.3/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 64.000: None | Iter 32.000: None | Iter 16.000: None | Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 7.0/8 CPUs, 0/0 GPUs, 0.0/6.96 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/pavelkiselev/PycharmProjects/pythonProject4/ray_result/to_tune_2022-06-21_04-56-21
Number of trials: 20/20 (1 ERROR, 12 PENDING, 7 RUNNING)
+---------------------+----------+-----------------+--------------+----------+----------+----------+----------+-------------+
| Trial name          | status   | loc             |   batch_size |   layer1 |   layer2 |   layer3 |   l

2022-06-21 04:56:33,389	ERROR tune.py:743 -- Trials did not complete: [to_tune_59710_00000, to_tune_59710_00001, to_tune_59710_00002, to_tune_59710_00003, to_tune_59710_00004, to_tune_59710_00005, to_tune_59710_00006, to_tune_59710_00007, to_tune_59710_00008, to_tune_59710_00009, to_tune_59710_00010, to_tune_59710_00011, to_tune_59710_00012, to_tune_59710_00013, to_tune_59710_00014, to_tune_59710_00015, to_tune_59710_00016, to_tune_59710_00017, to_tune_59710_00018, to_tune_59710_00019]
2022-06-21 04:56:33,390	INFO tune.py:747 -- Total run time: 11.46 seconds (11.09 seconds for the tuning loop).

In [None]:
# Summarise the trial with the lowest last-reported "loss".
# get_best_trial returns None when no trial reported "loss" (e.g. every trial
# errored or was interrupted before its first report), so guard before
# dereferencing it instead of failing with an AttributeError.
best_trial = result.get_best_trial("loss", "min", "last")
if best_trial is None:
    print("No trial reported a 'loss' metric; there is no best trial to show.")
else:
    last = best_trial.last_result
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(last["loss"]))
    # "accuracy" depends on what the callback reports; print None if it is absent.
    print("Best trial final validation accuracy: {}".format(last.get("accuracy")))