New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial pass at wandb Ludwig integration #514
Changes from 8 commits
64c00db
52ec945
717245a
6161ff2
44b682d
0b24405
d437278
b3df25b
0428aed
e11365e
078d2c9
70e1c60
b086647
6691c19
9bfdf2f
4a71e88
0285945
7568c78
09b559c
c796f5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# coding=utf-8 | ||
# Copyright (c) 2019 Uber Technologies, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
import logging | ||
import os | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Wandb(): | ||
"Class that defines the methods necessary to hook into process." | ||
|
||
@staticmethod | ||
def import_call(argv, *args, **kwargs): | ||
""" | ||
Enable Third-party support from wandb.ai | ||
Allows experiment tracking, visualization, and | ||
management. | ||
""" | ||
try: | ||
import wandb | ||
# Needed to call an attribute of wandb to make DeepSource not complain | ||
return Wandb() if wandb.__version__ else None | ||
except ImportError: | ||
logger.error( | ||
"Ignored --wandb: Please install wandb; see https://docs.wandb.com") | ||
return None | ||
|
||
@staticmethod | ||
def train_model(model, *args, **kwargs): | ||
import wandb | ||
logger.info("wandb.train_model() called...") | ||
config = model.hyperparameters.copy() | ||
del config["input_features"] | ||
del config["output_features"] | ||
wandb.config.update(config) | ||
|
||
@staticmethod | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was speced to be a regular method, rather than a static method. Can you adapt? |
||
def train_init(experiment_directory, experiment_name, model_name, | ||
resume, output_directory): | ||
import wandb | ||
logger.info("wandb.train_init() called...") | ||
wandb.init(project=os.getenv("WANDB_PROJECT", experiment_name), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea, I was not aware of what |
||
sync_tensorboard=True, dir=output_directory) | ||
wandb.save(os.path.join(experiment_directory, "*")) | ||
|
||
@staticmethod | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be more consistent to have all of the methods be regular class methods rather than some be static methods. I know you don't need the instance in your code, but it would be better I think to have all of the methods be the same, rather than some static and some not. Some contributions may want to store useful state in the class instance. |
||
def visualize_figure(fig): | ||
import wandb | ||
logger.info("wandb.visualize_figure() called...") | ||
if wandb.run: | ||
wandb.log({"figure": fig}) | ||
|
||
@staticmethod | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||
def predict_end(stats, *args, **kwargs): | ||
import wandb | ||
logger.info("wandb.predict() called... %s", stats) | ||
if wandb.run: | ||
wandb.summary.update(dict(stats)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be a regular method too?