Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix endpoint issues in pyglove colab. Should resolve https://github.com/google/vizier/issues/1044 #1047

Merged
merged 1 commit into from
Jan 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions docs/advanced_topics/pyglove/vizier_as_backend.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"\n",
"import pyglove as pg\n",
"from vizier import pyglove as pg_vizier\n",
"from vizier.service import vizier_server"
"from vizier.service import servers"
]
},
{
Expand Down Expand Up @@ -136,7 +136,7 @@
"id": "zIbLEGi6prpm"
},
"source": [
"Alternatively, if using a remote server, the endpoint can be specified as well:"
"**Alternatively**, if using a remote server, the endpoint can be specified as well:"
]
},
{
Expand All @@ -147,7 +147,7 @@
},
"outputs": [],
"source": [
"server = vizier_server.DefaultVizierServer(host=hostname) # Normally hosted on a remote machine.\n",
"server = servers.DefaultVizierServer() # Normally hosted on a remote machine.\n",
"pg_vizier.init('my_study', vizier_endpoint=server.endpoint)"
]
},
Expand All @@ -169,15 +169,15 @@
},
"outputs": [],
"source": [
"num_workers = 10\n",
"NUM_WORKERS = 10\n",
"\n",
"\n",
"def work_fn(worker_id):\n",
" print(f\"Worker ID: {worker_id}\")\n",
" for value, feedback in pg.sample(\n",
" search_space,\n",
" algorithm=algorithm,\n",
" num_examples=num_trials // num_workers,\n",
" num_examples=num_trials // NUM_WORKERS,\n",
" name=\"worker_run\",\n",
" ):\n",
" reward = evaluator(value)\n",
Expand Down Expand Up @@ -217,7 +217,7 @@
"outputs": [],
"source": [
"with multiprocessing.pool.ThreadPool(num_workers) as pool:\n",
" pool.map(work_fn, range(num_workers))"
" pool.map(work_fn, range(NUM_WORKERS))"
]
},
{
Expand All @@ -238,7 +238,7 @@
"outputs": [],
"source": [
"processes = []\n",
"for worker_id in range(num_workers):\n",
"for worker_id in range(NUM_WORKERS):\n",
" p = multiprocessing.Process(target=work_fn, args=(worker_id,))\n",
" p.start()\n",
" processes.append(p)\n",
Expand All @@ -265,7 +265,7 @@
"outputs": [],
"source": [
"# Server Machine\n",
"server = vizier_server.DefaultVizierServer(host=hostname)"
"server = servers.DefaultVizierServer()"
]
},
{
Expand Down
Loading