In [15]:
import os
import optuna
from optuna.trial import TrialState
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
from pathlib import Path
from optuna_dashboard import run_server
from optuna.visualization import plot_pareto_front

In [14]:
# Optuna storage: a SQLite DB. Set OPTUNA_DB_PATH to point at an existing
# database (e.g. to resume studies stored elsewhere); otherwise db.sqlite3 in the
# current working directory is used instead of a user-specific absolute path.
DB_PATH = Path(os.environ.get("OPTUNA_DB_PATH", "db.sqlite3")).expanduser().resolve()
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
STORAGE = f"sqlite:///{DB_PATH.as_posix()}"

SEED = 42


DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# Always seed the default (CPU) generator: model layers are constructed on CPU
# before `.to(DEVICE)`, and DataLoader shuffling draws from it as well.
# Seeding only the MPS generator left both of those unseeded.
torch.manual_seed(SEED)
if DEVICE.type == "mps":
    # Explicit for PyTorch versions where torch.manual_seed does not cover MPS.
    torch.mps.manual_seed(SEED)


BATCHSIZE = 128
CLASSES = 10
DIR = os.getcwd()
EPOCHS = 10
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10

Using device: mps


In [None]:
def define_model(trial):
    """Build an MLP for flattened 28x28 inputs whose depth and widths come from ``trial``.

    Samples ``n_layers`` in [1, 10] and, for every hidden layer ``i``, a width
    ``n_units_l{i}`` in [4, 128]. Each hidden layer is ``Linear`` + ``ReLU``;
    a ``Linear(..., CLASSES)`` + ``LogSoftmax(dim=1)`` head is appended last.

    Returns:
        (model, in_out_features): the ``nn.Sequential`` model and a list with one
        ``(in_features, out_features)`` tuple per *hidden* Linear layer. The
        output head is not included in that list.
    """
    depth = trial.suggest_int("n_layers", 1, 10)

    modules = []
    shapes = []
    width_in = 28 * 28
    for idx in range(depth):
        width_out = trial.suggest_int(f"n_units_l{idx}", 4, 128)
        modules += [nn.Linear(width_in, width_out), nn.ReLU()]
        shapes.append((width_in, width_out))
        width_in = width_out

    # Head emits log-probabilities, i.e. what nll_loss expects.
    modules += [nn.Linear(width_in, CLASSES), nn.LogSoftmax(dim=1)]
    return nn.Sequential(*modules), shapes

In [None]:
def estimate_layer_latency(in_features, out_features, num_cores, batch):
    """Crude cost proxy for one fully-connected layer.

    Computes ``in_features * out_features * batch`` (the multiply-accumulate
    count of a dense layer over ``batch`` samples, ignoring bias) split evenly
    across ``num_cores``. The result is a unitless cost, not a time.

    Args:
        in_features: Input width of the layer.
        out_features: Output width of the layer.
        num_cores: Number of cores the work is divided across; must be > 0.
        batch: Number of samples processed.

    Returns:
        float: the per-core operation count.

    Raises:
        ValueError: if ``num_cores`` is not positive (previously this divided by
            zero or silently produced a negative cost).
    """
    if num_cores <= 0:
        raise ValueError(f"num_cores must be positive, got {num_cores}")
    return in_features * out_features * batch / num_cores

In [None]:
def get_mnist():
    """Return ``(train_loader, valid_loader)`` DataLoaders for FashionMNIST.

    Despite the name, this loads FashionMNIST (not MNIST) rooted at ``DIR``.
    The training split is shuffled and downloaded if missing; the validation
    loader wraps the ``train=False`` split without shuffling. Both loaders use
    ``BATCHSIZE`` and convert images with ``transforms.ToTensor()``.
    """
    def build(split_is_train, shuffle, **dataset_options):
        dataset = datasets.FashionMNIST(
            DIR, train=split_is_train, transform=transforms.ToTensor(), **dataset_options
        )
        return torch.utils.data.DataLoader(dataset, batch_size=BATCHSIZE, shuffle=shuffle)

    # NOTE(review): only the training split passes download=True; the test split
    # is presumably fetched by that same download, so keep this call first.
    train_loader = build(True, shuffle=True, download=True)
    valid_loader = build(False, shuffle=False)
    return train_loader, valid_loader

In [None]:
def _run_train_epoch(model, optimizer, train_loader):
    """Train ``model`` for one epoch, capped at about N_TRAIN_EXAMPLES samples."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Limiting training data for faster epochs.
        if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
            break
        # Flatten each image to a vector for the MLP.
        data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)

        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()


def _evaluate_accuracy(model, valid_loader):
    """Return the fraction of correctly classified samples among those scored.

    Scoring stops once ``batch_idx * BATCHSIZE >= N_VALID_EXAMPLES``. The
    denominator is the number of samples actually scored, so the result stays
    in [0, 1] even when N_VALID_EXAMPLES is not a multiple of BATCHSIZE or a
    batch is short. Returns 0.0 if nothing was scored.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            # Limiting validation data.
            if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                break
            data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
            # The index of the max log-probability is the predicted class.
            pred = model(data).argmax(dim=1)
            correct += pred.eq(target).sum().item()
            seen += target.size(0)
    return correct / seen if seen else 0.0


def objective(trial):
    """Optuna objective: train a sampled MLP on FashionMNIST and score it.

    The architecture comes from ``define_model(trial)``; the optimizer is fixed
    (Adam, lr=0.001). Also stores the latency proxy as the user attribute
    ``"latency_ms_est"``.

    Returns:
        (accuracy, lat): validation accuracy after the last epoch (0.0 if
        EPOCHS == 0) and the summed ``estimate_layer_latency`` cost of the hidden
        layers reported by ``define_model``.
    """
    model, layer_shapes = define_model(trial)
    model = model.to(DEVICE)
    # Optimizer and learning rate are fixed; only the architecture is searched.
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    lat = sum(
        (estimate_layer_latency(in_f, out_f, num_cores=1, batch=BATCHSIZE) for in_f, out_f in layer_shapes),
        0.0,
    )
    # NOTE(review): despite the "_ms" in the key, this is an operation-count
    # proxy, not milliseconds; the key is kept so stored trials stay comparable.
    trial.set_user_attr("latency_ms_est", lat)

    train_loader, valid_loader = get_mnist()

    # Initialized so that EPOCHS == 0 returns 0.0 instead of raising UnboundLocalError.
    accuracy = 0.0
    for _ in range(EPOCHS):
        _run_train_epoch(model, optimizer, train_loader)
        # Validated every epoch; only the final value is returned. Per-epoch
        # trial.report/pruning is not used because Optuna does not support it
        # for multi-objective studies.
        accuracy = _evaluate_accuracy(model, valid_loader)

    return accuracy, lat

In [None]:
# Two objectives, in the order objective() returns them:
# validation accuracy (maximize) and the latency proxy (minimize).
# load_if_exists=True resumes the stored study, so each run of this cell adds
# up to another 100 trials / 600 s of search to it.
# NOTE(review): the study name contains non-ASCII characters where "_" was
# probably intended; it is kept verbatim so the existing stored study is reused.
study = optuna.create_study(
    study_name="Jup_fashion_mnist_fixed_lr_with_latÙ€batch+plot",
    directions=["maximize", "minimize"],
    storage=STORAGE,
    load_if_exists=True,
)
study.optimize(objective, n_trials=100, timeout=600)

# --- Summary: trial counts and the Pareto front ---------------------------
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of complete trials: ", len(complete_trials))

best_trials = study.best_trials
print(f"\nNumber of Pareto-optimal trials: {len(best_trials)}")

for pareto_trial in best_trials:
    accuracy_value = pareto_trial.values[0]
    latency_value = pareto_trial.values[1]
    print(f"  Values: accuracy={accuracy_value:.4f}, latency={latency_value:.4f}")
    print("  Params:")
    for name, value in pareto_trial.params.items():
        print(f"    {name}: {value}")


fig = plot_pareto_front(study, target_names=["Accuracy", "Latency"])
fig.show()

In [None]:
# Start the Optuna Dashboard server on localhost:8080.
# run_server blocks until interrupted, which would freeze the kernel and make
# "Restart & Run All" never finish. It therefore runs in a background daemon
# thread that lives as long as the kernel. The name check keeps a re-run of this
# cell from trying to bind the port a second time.
import threading

_DASHBOARD_THREAD_NAME = "optuna-dashboard"
if not any(t.name == _DASHBOARD_THREAD_NAME for t in threading.enumerate()):
    threading.Thread(
        target=run_server,
        args=(STORAGE,),
        name=_DASHBOARD_THREAD_NAME,
        daemon=True,
    ).start()
print("Optuna Dashboard: http://localhost:8080")