-
Notifications
You must be signed in to change notification settings - Fork 425
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* docs: update new algorithm service details * feat: trial augmentation strategy * feat: pbt suggestion service * feat: PbtTemplate and associated test image * feat: introduce annotation field to trial specifications * feat: trial assignment changes to support annotations from suggestion - Add new Annotation types to suggestion_types.go - Add Annotation object and update Trial parser in trial.py * feat: update pbt suggestion to use new Annotation api - Suggestion uses exact match to track spawned trials - Trials that get transmitted, but not created (or added to experiment) are added back to the respawn pool (population_size consistency) * chore: gofmt and black run across PBT changes * feedback: remove tf summary export, change default print unit, reduce range to be percentage compatible. * feedback: move PBT template to example. * feedback: changes to inject_webhook and utils. - Rename mutateVolume to mutateMetricsCollectorVolume - Add addContainerVolumeMount - Add getPrimaryContainerIndex * feedback: change suggestion mutation mount variable name and add to consts * feedback: Add trial_names to GetSuggestionsReply and change suggestion path to <experiment>/<trial> * feedback: removed unnecessary checks and moved to async pbt implementation * feedback: update trial name override location and change annotations override to labels. * feedback: add pbt to github workflow * feedback: move labels to ParameterAssignments in GetSuggestionsReply and cleanup pbt.yaml. * feedback: remove operator changes * feedback: GHA updates * feedback: new formatting changes * feedback: add suggestion-pbt to gh-actions build-load.sh. * fix: missing pbt->simple-pbt name changes, add simple-pbt to update-images.sh update yaml function (causing failing gha). * feedback: add pointer to website from main readme for pbt
- Loading branch information
Showing
45 changed files
with
1,526 additions
and
232 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Pull-request workflow that runs the end-to-end test for the simple-pbt experiment.
name: E2E Test with simple-pbt

on: [pull_request]

env:
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

jobs:
  e2e:
    runs-on: ubuntu-20.04
    timeout-minutes: 120

    # One job per (kubernetes-version, experiments) combination; with fail-fast
    # disabled, a failing combination does not cancel the remaining ones.
    strategy:
      fail-fast: false
      matrix:
        # Detail: https://hub.docker.com/r/kindest/node
        # TODO (tenzen-y): We need to consider running tests on more kubernetes versions.
        # kubernetes-version: ["v1.20.15", "v1.21.12", "v1.22.9", "v1.23.6", "v1.24.1"]
        kubernetes-version:
          - "v1.21.12"
          - "v1.22.9"
          - "v1.23.6"
        # Comma Delimited
        experiments:
          - "simple-pbt"

    steps:
      - name: Checkout
        uses: actions/checkout@v2

      # Local action from this repository, parameterized by the matrix Kubernetes version.
      - name: Setup Test Env
        uses: ./.github/workflows/template-setup-e2e-test
        with:
          kubernetes-version: ${{ matrix.kubernetes-version }}

      # Local action from this repository that receives the experiment list and trial images.
      - name: Run e2e test with ${{ matrix.experiments }} experiments
        uses: ./.github/workflows/template-e2e-test
        with:
          experiments: ${{ matrix.experiments }}
          # Comma Delimited
          trial-images: simple-pbt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Image for the PBT suggestion service: installs the gRPC health probe, copies the
# shared ./pkg tree plus the suggestion entry point, and runs main.py.
FROM python:3.9-slim

# key=value form; the space-separated `ENV key value` syntax is deprecated.
ENV TARGET_DIR=/opt/katib \
    SUGGESTION_DIR=cmd/suggestion/pbt/v1beta1 \
    GRPC_HEALTH_PROBE_VERSION=v0.4.6

# wget is installed on every architecture (it downloads grpc_health_probe below).
# ppc64le and aarch64 additionally get a Fortran compiler and the OpenBLAS/LAPACK
# development packages -- presumably for Python packages that must be built from
# source on those platforms (TODO confirm).
RUN apt-get -y update && \
    apt-get -y install wget && \
    if [ "$(uname -m)" = "ppc64le" ] || [ "$(uname -m)" = "aarch64" ]; then \
        apt-get -y install gfortran libopenblas-dev liblapack-dev; \
    fi && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Map the machine name onto the architecture suffix of the release asset;
# anything that is not ppc64le or aarch64 falls back to the amd64 binary.
RUN case "$(uname -m)" in \
        ppc64le) PROBE_ARCH=ppc64le ;; \
        aarch64) PROBE_ARCH=arm64 ;; \
        *) PROBE_ARCH=amd64 ;; \
    esac && \
    wget -qO/bin/grpc_health_probe "https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-${PROBE_ARCH}" && \
    chmod +x /bin/grpc_health_probe

# COPY rather than ADD: these are plain local directories, so none of ADD's extra
# behavior (archive extraction, URL fetch) is wanted.
COPY ./pkg/ ${TARGET_DIR}/pkg/
COPY ./${SUGGESTION_DIR}/ ${TARGET_DIR}/${SUGGESTION_DIR}/
WORKDIR ${TARGET_DIR}/${SUGGESTION_DIR}
RUN pip install --no-cache-dir -r requirements.txt

# Hand the tree to group 0 with group read/write (and execute where applicable).
# NOTE(review): presumably so the container can run under an arbitrary non-root UID
# that belongs to the root group -- confirm.
RUN chgrp -R 0 ${TARGET_DIR} \
    && chmod -R g+rwX ${TARGET_DIR}

ENV PYTHONPATH=${TARGET_DIR}:${TARGET_DIR}/pkg/apis/manager/v1beta1/python:${TARGET_DIR}/pkg/apis/manager/health/python

ENTRYPOINT ["python", "main.py"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright 2022 The Kubeflow Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import grpc | ||
import time | ||
from pkg.apis.manager.v1beta1.python import api_pb2_grpc | ||
from pkg.apis.manager.health.python import health_pb2_grpc | ||
from pkg.suggestion.v1beta1.pbt.service import PbtService | ||
from concurrent import futures | ||
|
||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 | ||
DEFAULT_PORT = "0.0.0.0:6789" | ||
|
||
|
||
def serve():
    """Run the PBT suggestion gRPC server until the process is interrupted.

    One ``PbtService`` instance is registered twice on the same server: as the
    Suggestion servicer and as the Health servicer. The server listens on
    ``DEFAULT_PORT`` without transport security and handles requests on a
    thread pool of 10 workers.
    """
    worker_pool = futures.ThreadPoolExecutor(max_workers=10)
    grpc_server = grpc.server(worker_pool)

    pbt_service = PbtService()
    api_pb2_grpc.add_SuggestionServicer_to_server(pbt_service, grpc_server)
    health_pb2_grpc.add_HealthServicer_to_server(pbt_service, grpc_server)

    grpc_server.add_insecure_port(DEFAULT_PORT)
    print("Listening...")
    grpc_server.start()

    # Keep the main thread alive in day-long naps; Ctrl-C stops the server
    # immediately (grace period of 0).
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        grpc_server.stop(0)


if __name__ == "__main__":
    serve()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
grpcio==1.41.1 | ||
protobuf==3.19.1 | ||
googleapis-common-protos==1.53.0 | ||
numpy==1.22.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Katib Experiment that drives the simple-pbt trial image with the "pbt" algorithm.
apiVersion: kubeflow.org/v1beta1
kind: Experiment
metadata:
  name: simple-pbt
  namespace: kubeflow
spec:
  maxTrialCount: 2
  parallelTrialCount: 2
  maxFailedTrialCount: 3
  resumePolicy: FromVolume
  objective:
    type: maximize
    goal: 0.99
    # Same metric name that the trial command prints to stdout.
    objectiveMetricName: Validation-accuracy
  algorithm:
    algorithmName: pbt
    algorithmSettings:
      # Same directory as the --checkpoint argument of the trial command below.
      - name: suggestion_trial_dir
        value: /var/log/katib/checkpoints/
      - name: n_population
        value: "40"
      - name: truncation_threshold
        value: "0.2"
  parameters:
    - name: lr
      parameterType: double
      feasibleSpace:
        min: "0.0001"
        max: "0.02"
        step: "0.0001"
  trialTemplate:
    primaryContainerName: training-container
    trialParameters:
      # Exposes the search parameter "lr" to the trial spec as "learningRate".
      - name: learningRate
        description: Learning rate for training the model
        reference: lr
    trialSpec:
      apiVersion: batch/v1
      kind: Job
      spec:
        template:
          spec:
            containers:
              - name: training-container
                image: docker.io/kubeflowkatib/simple-pbt:latest
                command:
                  - "python3"
                  - "/opt/pbt/pbt_test.py"
                  - "--epochs=20"
                  - "--lr=${trialParameters.learningRate}"
                  - "--checkpoint=/var/log/katib/checkpoints/"
            restartPolicy: Never
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Trial image for the simple-pbt example: ships pbt_test.py and its requirements.
FROM python:3.9-slim

# COPY rather than ADD: the source is a plain local directory, so ADD's archive
# extraction / URL handling is not wanted.
COPY examples/v1beta1/trial-images/simple-pbt /opt/pbt
WORKDIR /opt/pbt

# --no-cache-dir keeps pip's download cache out of the image layer.
RUN python3 -m pip install --no-cache-dir -r requirements.txt

# Hand the tree to group 0 with group read/write (and execute where applicable).
# NOTE(review): presumably so the container can run under an arbitrary non-root UID
# that belongs to the root group -- confirm.
RUN chgrp -R 0 /opt/pbt \
    && chmod -R g+rwX /opt/pbt

ENTRYPOINT ["python3", "/opt/pbt/pbt_test.py"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
#!/usr/bin/env python | ||
|
||
# Implementation based on: | ||
# https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_example.py | ||
|
||
import argparse | ||
import numpy as np | ||
import os | ||
import pickle | ||
import random | ||
import time | ||
|
||
# Ensure job runs for at least this long (secs) to allow metrics collector to | ||
# read PID correctly before cleanup | ||
_METRICS_COLLECTOR_SPAWN_LATENCY = 7 | ||
|
||
|
||
class PBTBenchmarkExample:
    r"""Toy PBT problem for benchmarking adaptive learning rate.

    The goal is to optimize this trainable's accuracy. The accuracy increases
    fastest at the optimal lr, which is a function of the current accuracy.
    The optimal lr schedule for this problem is the triangle wave as follows.
    Note that many lr schedules for real models also follow this shape:

         best lr
          ^
          |    /\
          |   /  \
          |  /    \
          | /      \
          ------------> accuracy

    In this problem, using PBT with a population of 2-4 is sufficient to
    roughly approximate this lr schedule. Higher population sizes will yield
    faster convergence. Training will not converge without PBT.

    The step counter and the accuracy (kept internally on a 0-100 scale) are
    persisted as a pickle at ``<checkpoint>/training.ckpt``.
    """

    def __init__(self, lr, checkpoint: str):
        """Create the benchmark, resuming from ``checkpoint`` when possible.

        Args:
            lr: Learning rate used by every call to ``step``.
            checkpoint: Directory holding ``training.ckpt``. If the file
                exists, step and accuracy are restored from it; otherwise the
                directory is created and training starts at step 1 with
                accuracy 0.0.
        """
        self._lr = lr

        self._checkpoint_file = os.path.join(checkpoint, "training.ckpt")
        if os.path.exists(self._checkpoint_file):
            # NOTE(review): pickle.load can execute arbitrary code from the
            # file; only point this at checkpoint directories you trust.
            with open(self._checkpoint_file, "rb") as fin:
                checkpoint_data = pickle.load(fin)
            self._accuracy = checkpoint_data["accuracy"]
            self._step = checkpoint_data["step"]
        else:
            os.makedirs(checkpoint, exist_ok=True)
            self._step = 1
            self._accuracy = 0.0

    def save_checkpoint(self):
        """Persist step and accuracy to the checkpoint file.

        The state is first written to a sibling temporary file and then moved
        over the real checkpoint with ``os.replace``, so an interrupted write
        cannot leave a truncated ``training.ckpt`` that would break the next
        resume.
        """
        tmp_file = self._checkpoint_file + ".tmp"
        try:
            with open(tmp_file, "wb") as fout:
                pickle.dump({"step": self._step, "accuracy": self._accuracy}, fout)
            os.replace(tmp_file, self._checkpoint_file)
        except BaseException:
            # Do not leave a partial temporary file behind on failure.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
            raise

    def step(self):
        """Advance one epoch: update the accuracy and bump the step counter."""
        midpoint = 50  # lr starts decreasing after acc > midpoint
        q_tolerance = 3  # penalize exceeding lr by more than this multiple
        noise_level = 2  # add gaussian noise to the acc increase
        # triangle wave:
        #  - start at 0.001 @ t=0,
        #  - peak at 0.01 @ t=midpoint,
        #  - end at 0.001 @ t=midpoint * 2,
        if self._accuracy < midpoint:
            optimal_lr = 0.01 * self._accuracy / midpoint
        else:
            optimal_lr = 0.01 - 0.01 * (self._accuracy - midpoint) / midpoint
        optimal_lr = min(0.01, max(0.001, optimal_lr))

        # compute accuracy increase
        # q_err is the ratio between the larger and the smaller of (lr, optimal lr);
        # eps guards against division by zero.
        q_err = max(self._lr, optimal_lr) / (
            min(self._lr, optimal_lr) + np.finfo(float).eps
        )
        if q_err < q_tolerance:
            self._accuracy += (1.0 / q_err) * random.random()
        elif self._lr > optimal_lr:
            self._accuracy -= (q_err - q_tolerance) * random.random()
        # An lr that is too small by q_tolerance or more gets neither a gain
        # nor a penalty; only the noise below applies.
        self._accuracy += noise_level * np.random.normal()
        self._accuracy = max(0, min(100, self._accuracy))

        self._step += 1

    def __repr__(self):
        """Return the epoch, lr and accuracy (scaled to 0-1) as ``name=value`` lines."""
        return "epoch {}:\nlr={:0.4f}\nValidation-accuracy={:0.4f}".format(
            self._step, self._lr, self._accuracy / 100
        )
|
||
|
||
if __name__ == "__main__":
    # Command-line options
    cli = argparse.ArgumentParser(description="PBT Basic Test")
    cli.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        help="learning rate (default: 0.0001)",
    )
    cli.add_argument(
        "--epochs",
        type=int,
        default=20,
        help="number of epochs to train (default: 20)",
    )
    cli.add_argument(
        "--checkpoint",
        type=str,
        default="/var/log/katib/checkpoints/",
        help="checkpoint directory (resume and save)",
    )
    args = cli.parse_args()

    trainer = PBTBenchmarkExample(args.lr, args.checkpoint)

    started_at = time.time()
    for _ in range(args.epochs):
        trainer.step()

    # Pad the run so the job lasts at least _METRICS_COLLECTOR_SPAWN_LATENCY
    # seconds before the checkpoint is written and the result is printed.
    remaining = _METRICS_COLLECTOR_SPAWN_LATENCY - (time.time() - started_at)
    if remaining > 0:
        time.sleep(remaining)
    trainer.save_checkpoint()

    print(trainer)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
numpy==1.22.2 |
Oops, something went wrong.