Initial pass at wandb Ludwig integration #514

Merged
merged 20 commits into from Feb 2, 2020
Changes from 8 commits
6 changes: 4 additions & 2 deletions ludwig/contribs/__init__.py
@@ -35,13 +35,15 @@
 method with `pass`, or just don't implement the method.
 """

-## Contributors, import your class here:
+# Contributors, import your class here:
 from .comet import Comet
+from .wandb import Wandb

 contrib_registry = {
-    ## Contributors, add your class here:
+    # Contributors, add your class here:
     'classes': {
         'comet': Comet,
+        'wandb': Wandb,
     },
     'instances': [],
 }
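
For readers unfamiliar with the registry, here is a minimal sketch of how a dispatcher such as `contrib_command` could walk `contrib_registry['instances']` and call a named hook only where it is implemented. The `Recorder` class and the body of `contrib_command` below are assumptions made for illustration; they are not taken from Ludwig's source.

```python
# Sketch only: the dispatcher body and Recorder are hypothetical.

class Recorder:
    """Hypothetical contrib class that implements a single hook."""

    def __init__(self):
        self.calls = []

    def predict_end(self, stats, *args, **kwargs):
        self.calls.append(("predict_end", stats))


contrib_registry = {
    "classes": {"recorder": Recorder},
    "instances": [],
}


def contrib_command(command, *args, **kwargs):
    """Call `command` on every registered instance that implements it."""
    for instance in contrib_registry["instances"]:
        method = getattr(instance, command, None)
        if callable(method):
            method(*args, **kwargs)


# Instantiate the contrib and register the instance.
recorder = contrib_registry["classes"]["recorder"]()
contrib_registry["instances"].append(recorder)

contrib_command("predict_end", {"accuracy": 0.9})
# Recorder has no train_init method, so this call is silently skipped.
contrib_command("train_init", experiment_name="demo")
```

Looking hooks up with `getattr` matches the docstring's advice that a contrib may simply not implement a method it does not care about.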
73 changes: 73 additions & 0 deletions ludwig/contribs/wandb.py
@@ -0,0 +1,73 @@
# coding=utf-8
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import logging
import os


logger = logging.getLogger(__name__)


class Wandb():
    "Class that defines the methods necessary to hook into process."

    @staticmethod
    def import_call(argv, *args, **kwargs):
        """
        Enable Third-party support from wandb.ai
        Allows experiment tracking, visualization, and
        management.
        """
        try:
            import wandb
            # Needed to call an attribute of wandb to make DeepSource not complain
            return Wandb() if wandb.__version__ else None
        except ImportError:
            logger.error(
                "Ignored --wandb: Please install wandb; see https://docs.wandb.com")
            return None

    @staticmethod
Contributor commented:

Can this be a regular method too?

    def train_model(model, *args, **kwargs):
        import wandb
        logger.info("wandb.train_model() called...")
        config = model.hyperparameters.copy()
        del config["input_features"]
        del config["output_features"]
        wandb.config.update(config)

    @staticmethod
Contributor commented:

This was specced as a regular method rather than a static method. Can you adapt it?

    def train_init(experiment_directory, experiment_name, model_name,
                   resume, output_directory):
        import wandb
        logger.info("wandb.train_init() called...")
        wandb.init(project=os.getenv("WANDB_PROJECT", experiment_name),
Collaborator commented:

I would add `name=model_name` to the `wandb.init` parameters.

Contributor commented:

Good idea. I wasn't initially aware of what `model_name` was used for, but it makes sense to use it as the W&B run name.

                   sync_tensorboard=True, dir=output_directory)
        wandb.save(os.path.join(experiment_directory, "*"))

    @staticmethod
Contributor commented:

It would be more consistent to make all of the methods regular (instance) methods rather than a mix of static and regular ones. I know your code doesn't need the instance, but some contributions may want to store useful state on the class instance.

    def visualize_figure(fig):
        import wandb
        logger.info("wandb.visualize_figure() called...")
        if wandb.run:
            wandb.log({"figure": fig})

    @staticmethod
Contributor commented:

Same here.

    def predict_end(stats, *args, **kwargs):
        import wandb
        logger.info("wandb.predict() called... %s", stats)
        if wandb.run:
            wandb.summary.update(dict(stats))
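
The review comments on this file ask for two changes: make the hooks regular methods so that a contrib can keep state on its instance, and pass `name=model_name` to `wandb.init`. A rough sketch of what that could look like follows. To keep it runnable without wandb installed, it talks to a hypothetical `FakeTracker` stand-in rather than the real module; `WandbSketch`, `FakeTracker`, and the `run_started` flag are illustrative names, not code from the merged PR.

```python
import os


class FakeTracker:
    """Hypothetical stand-in for the wandb module so the sketch runs offline."""

    def __init__(self):
        self.init_kwargs = None
        self.summary = {}

    def init(self, **kwargs):
        self.init_kwargs = kwargs


class WandbSketch:
    """Hooks written as regular methods; state lives on the instance."""

    def __init__(self, tracker):
        self.tracker = tracker
        self.run_started = False

    def train_init(self, experiment_directory, experiment_name, model_name,
                   resume, output_directory):
        self.tracker.init(
            project=os.getenv("WANDB_PROJECT", experiment_name),
            name=model_name,  # reviewer's suggestion: model name as run name
            sync_tensorboard=True,
            dir=output_directory,
        )
        self.run_started = True

    def predict_end(self, stats, *args, **kwargs):
        # Only record statistics once a run has been started.
        if self.run_started:
            self.tracker.summary.update(dict(stats))


os.environ.pop("WANDB_PROJECT", None)  # make the project fallback deterministic
tracker = FakeTracker()
contrib = WandbSketch(tracker)
contrib.predict_end({"loss": 1.0})  # ignored: no run has been started yet
contrib.train_init("results/exp_run", "exp", "run", False, "results")
contrib.predict_end({"loss": 0.5})
```

Here the instance flag plays the role that the `if wandb.run:` check plays in the static version, which is one example of the "useful state" the reviewer mentions.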
7 changes: 6 additions & 1 deletion ludwig/predict.py
@@ -190,6 +190,8 @@ def predict(
train_set_metadata
)

contrib_command("predict_end", test_stats)

return test_stats


@@ -216,7 +218,10 @@ def save_prediction_outputs(
     for output_field, outputs in postprocessed_output.items():
         for output_type, values in outputs.items():
             if output_type not in skip_output_types:
-                save_csv(csv_filename.format(output_field, output_type), values)
+                save_csv(
+                    csv_filename.format(output_field, output_type),
+                    values
+                )


 def save_test_statistics(test_stats, experiment_dir_name):
3 changes: 3 additions & 0 deletions ludwig/train.py
@@ -317,6 +317,9 @@ def full_train(
train_set_metadata
)

contrib_command("train_init", experiment_directory=experiment_dir_name, experiment_name=experiment_name,
model_name=model_name, output_directory=output_directory, resume=model_resume_path is not None)

# run the experiment
model, result = train(
training_set=training_set,
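
Because the `train_init` call in this hunk passes every argument by keyword, the hook's parameter names are effectively part of the interface. The sketch below uses a hypothetical stand-in hook (no wandb import) and made-up paths to show the keyword call and how `resume` is derived from `model_resume_path`; `call_train_init` is an illustrative helper, not a Ludwig function.

```python
def train_init(experiment_directory, experiment_name, model_name,
               resume, output_directory):
    """Hypothetical stand-in for a contrib's train_init hook."""
    return {
        "experiment_directory": experiment_directory,
        "experiment_name": experiment_name,
        "model_name": model_name,
        "resume": resume,
        "output_directory": output_directory,
    }


def call_train_init(model_resume_path):
    # Mirrors the keyword style of the call in the diff; all values are made up.
    return train_init(
        experiment_directory="results/experiment_run",
        experiment_name="experiment",
        model_name="run",
        output_directory="results",
        resume=model_resume_path is not None,
    )


fresh = call_train_init(None)
resumed = call_train_init("results/experiment_run_0")
```

Keyword arguments let the call site list `output_directory` before `resume` even though the hook declares them in the opposite order.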