Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace lightgbm with PyTorch-based example to remove lightgbm dependency in visualization tutorial #5257

Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
145 changes: 111 additions & 34 deletions tutorial/10_key_features/005_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

Optuna provides various visualization features in :mod:`optuna.visualization` to analyze optimization results visually.

This tutorial walks you through this module by visualizing the history of lightgbm model for breast cancer dataset.
This tutorial walks you through this module by visualizing the optimization results of PyTorch model for FashionMNIST dataset.

For visualizing multi-objective optimization (i.e., the usage of :func:`optuna.visualization.plot_pareto_front`),
please refer to the tutorial of :ref:`multi_objective`.
Expand All @@ -33,11 +33,11 @@
"""

###################################################################################################
import lightgbm as lgb
import numpy as np
import sklearn.datasets
import sklearn.metrics
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


import optuna

Expand All @@ -55,74 +55,136 @@


SEED = 42
torch.manual_seed(SEED)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10


def define_model(trial):
n_layers = trial.suggest_int("n_layers", 1, 2)
layers = []

in_features = 28 * 28
for i in range(n_layers):
out_features = trial.suggest_int("n_units_l{}".format(i), 64, 512)
layers.append(nn.Linear(in_features, out_features))
layers.append(nn.ReLU())

in_features = out_features

layers.append(nn.Linear(in_features, 10))
layers.append(nn.LogSoftmax(dim=1))

return nn.Sequential(*layers)


# Defines training and evaluation.
def train_model(model, optimizer, train_loader):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
optimizer.zero_grad()
F.nll_loss(model(data), target).backward()
optimizer.step()


def eval_model(model, valid_loader):
model.eval()
correct = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(valid_loader):
data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
pred = model(data).argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()

accuracy = correct / N_VALID_EXAMPLES

np.random.seed(SEED)
return accuracy


###################################################################################################
# Define the objective function.
def objective(trial):
data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
train_x, valid_x, train_y, valid_y = train_test_split(data, target, test_size=0.25)
dtrain = lgb.Dataset(train_x, label=train_y)
dvalid = lgb.Dataset(valid_x, label=valid_y)

param = {
"objective": "binary",
"metric": "auc",
"verbosity": -1,
"boosting_type": "gbdt",
"bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
}

# Add a callback for pruning.
gbm = lgb.train(param, dtrain, valid_sets=[dvalid])

preds = gbm.predict(valid_x)
pred_labels = np.rint(preds)
accuracy = sklearn.metrics.accuracy_score(valid_y, pred_labels)
return accuracy
train_dataset = torchvision.datasets.FashionMNIST(
DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(
torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
batch_size=BATCHSIZE,
shuffle=True,
)

val_dataset = torchvision.datasets.FashionMNIST(
DIR, train=False, transform=torchvision.transforms.ToTensor()
)
val_loader = torch.utils.data.DataLoader(
torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
batch_size=BATCHSIZE,
shuffle=True,
)
model = define_model(trial).to(DEVICE)

optimizer = torch.optim.Adam(
model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
)

for epoch in range(10):
train_model(model, optimizer, train_loader)

val_accuracy = eval_model(model, val_loader)
trial.report(val_accuracy, epoch)

if trial.should_prune():
raise optuna.exceptions.TrialPruned()

return val_accuracy


###################################################################################################
study = optuna.create_study(
direction="maximize",
sampler=optuna.samplers.TPESampler(seed=SEED),
pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),
pruner=optuna.pruners.MedianPruner(),
)
study.optimize(objective, n_trials=100, timeout=600)
study.optimize(objective, n_trials=30, timeout=300)

###################################################################################################
# Plot functions
# --------------
# Visualize the optimization history. See :func:`~optuna.visualization.plot_optimization_history` for the details.
plot_optimization_history(study)

###################################################################################################
# Visualize the learning curves of the trials. See :func:`~optuna.visualization.plot_intermediate_values` for the details.
plot_intermediate_values(study)

###################################################################################################
# Visualize high-dimensional parameter relationships. See :func:`~optuna.visualization.plot_parallel_coordinate` for the details.
plot_parallel_coordinate(study)

###################################################################################################
# Select parameters to visualize.
plot_parallel_coordinate(study, params=["bagging_freq", "bagging_fraction"])
plot_parallel_coordinate(study, params=["lr", "n_layers"])

###################################################################################################
# Visualize hyperparameter relationships. See :func:`~optuna.visualization.plot_contour` for the details.
plot_contour(study)

###################################################################################################
# Select parameters to visualize.
plot_contour(study, params=["bagging_freq", "bagging_fraction"])
plot_contour(study, params=["lr", "n_layers"])

###################################################################################################
# Visualize individual hyperparameters as slice plot. See :func:`~optuna.visualization.plot_slice` for the details.
plot_slice(study)

###################################################################################################
# Select parameters to visualize.
plot_slice(study, params=["bagging_freq", "bagging_fraction"])
plot_slice(study, params=["lr", "n_layers"])

###################################################################################################
# Visualize parameter importances. See :func:`~optuna.visualization.plot_param_importances` for the details.
Expand All @@ -145,3 +207,18 @@ def objective(trial):
###################################################################################################
# Visualize the optimization timeline of performed trials. See :func:`~optuna.visualization.plot_timeline` for the details.
plot_timeline(study)

###################################################################################################
# Customize generated figures
# ---------------------------
# In :mod:`optuna.visualization` and :mod:`optuna.visualization.matplotlib`, a function returns an editable figure object:
# :class:`plotly.graph_objects.Figure` or :class:`matplotlib.axes.Axes` depending on the module.
# This allows users to modify the generated figure for their demand by using API of the visualization library.
# The following example replaces figure titles drawn by Plotly-based :func:`~optuna.visualization.plot_intermediate_values` manually.
fig = plot_intermediate_values(study)

fig.update_layout(
title="Hyperparameter optimization for FashionMNIST classification",
xaxis_title="Epoch",
yaxis_title="Validation Accuracy",
)