Skip to content

Commit

Permalink
Merge pull request #5257 from nzw0301/replace-lightgbm-with-pytorch-f…
Browse files Browse the repository at this point in the history
…or-visualisation-tutorial

Replace lightgbm with PyTorch-based example to remove lightgbm dependency in visualization tutorial
  • Loading branch information
nabenabe0928 committed Feb 20, 2024
2 parents beacca8 + 359b5bc commit 52e2f5d
Showing 1 changed file with 112 additions and 35 deletions.
147 changes: 112 additions & 35 deletions tutorial/10_key_features/005_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Optuna provides various visualization features in :mod:`optuna.visualization` to analyze optimization results visually.
This tutorial walks you through this module by visualizing the history of lightgbm model for breast cancer dataset.
This tutorial walks you through this module by visualizing the optimization results of PyTorch model for FashionMNIST dataset.
For visualizing multi-objective optimization (i.e., the usage of :func:`optuna.visualization.plot_pareto_front`),
please refer to the tutorial of :ref:`multi_objective`.
Expand All @@ -33,11 +33,11 @@
"""

###################################################################################################
import lightgbm as lgb
import numpy as np
import sklearn.datasets
import sklearn.metrics
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


import optuna

Expand All @@ -54,75 +54,137 @@
from optuna.visualization import plot_timeline


SEED = 42
SEED = 13
torch.manual_seed(SEED)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10


def define_model(trial):
n_layers = trial.suggest_int("n_layers", 1, 2)
layers = []

in_features = 28 * 28
for i in range(n_layers):
out_features = trial.suggest_int("n_units_l{}".format(i), 64, 512)
layers.append(nn.Linear(in_features, out_features))
layers.append(nn.ReLU())

in_features = out_features

layers.append(nn.Linear(in_features, 10))
layers.append(nn.LogSoftmax(dim=1))

return nn.Sequential(*layers)


# Defines training and evaluation.
def train_model(model, optimizer, train_loader):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
optimizer.zero_grad()
F.nll_loss(model(data), target).backward()
optimizer.step()


def eval_model(model, valid_loader):
model.eval()
correct = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(valid_loader):
data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
pred = model(data).argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()

accuracy = correct / N_VALID_EXAMPLES

np.random.seed(SEED)
return accuracy


###################################################################################################
# Define the objective function.
def objective(trial):
data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
train_x, valid_x, train_y, valid_y = train_test_split(data, target, test_size=0.25)
dtrain = lgb.Dataset(train_x, label=train_y)
dvalid = lgb.Dataset(valid_x, label=valid_y)

param = {
"objective": "binary",
"metric": "auc",
"verbosity": -1,
"boosting_type": "gbdt",
"bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
}

# Add a callback for pruning.
gbm = lgb.train(param, dtrain, valid_sets=[dvalid])

preds = gbm.predict(valid_x)
pred_labels = np.rint(preds)
accuracy = sklearn.metrics.accuracy_score(valid_y, pred_labels)
return accuracy
train_dataset = torchvision.datasets.FashionMNIST(
DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(
torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
batch_size=BATCHSIZE,
shuffle=True,
)

val_dataset = torchvision.datasets.FashionMNIST(
DIR, train=False, transform=torchvision.transforms.ToTensor()
)
val_loader = torch.utils.data.DataLoader(
torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
batch_size=BATCHSIZE,
shuffle=True,
)
model = define_model(trial).to(DEVICE)

optimizer = torch.optim.Adam(
model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
)

for epoch in range(10):
train_model(model, optimizer, train_loader)

val_accuracy = eval_model(model, val_loader)
trial.report(val_accuracy, epoch)

if trial.should_prune():
raise optuna.exceptions.TrialPruned()

return val_accuracy


###################################################################################################
study = optuna.create_study(
direction="maximize",
sampler=optuna.samplers.TPESampler(seed=SEED),
pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),
pruner=optuna.pruners.MedianPruner(),
)
study.optimize(objective, n_trials=100, timeout=600)
study.optimize(objective, n_trials=30, timeout=300)

###################################################################################################
# Plot functions
# --------------
# Visualize the optimization history. See :func:`~optuna.visualization.plot_optimization_history` for the details.
plot_optimization_history(study)

###################################################################################################
# Visualize the learning curves of the trials. See :func:`~optuna.visualization.plot_intermediate_values` for the details.
plot_intermediate_values(study)

###################################################################################################
# Visualize high-dimensional parameter relationships. See :func:`~optuna.visualization.plot_parallel_coordinate` for the details.
plot_parallel_coordinate(study)

###################################################################################################
# Select parameters to visualize.
plot_parallel_coordinate(study, params=["bagging_freq", "bagging_fraction"])
plot_parallel_coordinate(study, params=["lr", "n_layers"])

###################################################################################################
# Visualize hyperparameter relationships. See :func:`~optuna.visualization.plot_contour` for the details.
plot_contour(study)

###################################################################################################
# Select parameters to visualize.
plot_contour(study, params=["bagging_freq", "bagging_fraction"])
plot_contour(study, params=["lr", "n_layers"])

###################################################################################################
# Visualize individual hyperparameters as slice plot. See :func:`~optuna.visualization.plot_slice` for the details.
plot_slice(study)

###################################################################################################
# Select parameters to visualize.
plot_slice(study, params=["bagging_freq", "bagging_fraction"])
plot_slice(study, params=["lr", "n_layers"])

###################################################################################################
# Visualize parameter importances. See :func:`~optuna.visualization.plot_param_importances` for the details.
Expand All @@ -145,3 +207,18 @@ def objective(trial):
###################################################################################################
# Visualize the optimization timeline of performed trials. See :func:`~optuna.visualization.plot_timeline` for the details.
plot_timeline(study)

###################################################################################################
# Customize generated figures
# ---------------------------
# In :mod:`optuna.visualization` and :mod:`optuna.visualization.matplotlib`, a function returns an editable figure object:
# :class:`plotly.graph_objects.Figure` or :class:`matplotlib.axes.Axes` depending on the module.
# This allows users to modify the generated figure for their demand by using API of the visualization library.
# The following example replaces figure titles drawn by Plotly-based :func:`~optuna.visualization.plot_intermediate_values` manually.
fig = plot_intermediate_values(study)

fig.update_layout(
title="Hyperparameter optimization for FashionMNIST classification",
xaxis_title="Epoch",
yaxis_title="Validation Accuracy",
)

0 comments on commit 52e2f5d

Please sign in to comment.