Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

☄️ comet integration #129

Merged
29 commits merged on Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/src/models/hello_world.py
Expand Up @@ -45,7 +45,7 @@
),
ModelCheckpoint(),
EmissionTrackerCallback(),
CometCallback(offline=True),
CometCallback(offline=False),
]

if __name__ == "__main__":
Expand Down
12 changes: 7 additions & 5 deletions gradsflow/callbacks/gpu.py
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from loguru import logger

from gradsflow.callbacks import Callback
from gradsflow.utility.imports import requires

Expand All @@ -35,13 +37,13 @@ def __init__(self, offline: bool = False, **kwargs):
from codecarbon import EmissionsTracker, OfflineEmissionsTracker

if offline:
self.tracker = OfflineEmissionsTracker(**kwargs)
self._emission_tracker = OfflineEmissionsTracker(**kwargs)
else:
self.tracker = EmissionsTracker(**kwargs)
self.tracker.start()
self._emission_tracker = EmissionsTracker(**kwargs)
self._emission_tracker.start()

super().__init__(model=None)

def on_fit_end(self):
emissions: float = self.tracker.stop()
print(f"Emissions: {emissions} kg")
emissions: float = self._emission_tracker.stop()
logger.info(f"Emissions: {emissions} kg")
75 changes: 56 additions & 19 deletions gradsflow/callbacks/logger/comet.py
Expand Up @@ -34,32 +34,69 @@ class CometCallback(Callback):
def __init__(
self,
project_name: str = "awesome-project",
workspace: Optional[str] = None,
experiment_id: Optional[str] = None,
api_key: Optional[str] = os.environ.get("COMET_API_KEY"),
code_file: str = CURRENT_FILE,
offline: bool = False,
**kwargs
**kwargs,
):
super().__init__(
model=None,
)
self._code_file = code_file
self.experiment = self._create_experiment(project_name=project_name, api_key=api_key, offline=offline, **kwargs)
self._experiment_id = experiment_id
self.experiment = self._create_experiment(
project_name=project_name,
workspace=workspace,
api_key=api_key,
offline=offline,
experiment_id=experiment_id,
**kwargs,
)
self._train_prefix = "train"
self._val_prefix = "val"

@requires("comet_ml", "CometCallback requires comet_ml to be installed!")
def _create_experiment(
self, project_name: str, offline: bool = False, api_key: Optional[str] = None, **kwargs
self,
project_name: str,
workspace: str,
offline: bool = False,
api_key: Optional[str] = None,
experiment_id: Optional[str] = None,
**kwargs,
) -> BaseExperiment:
from comet_ml import Experiment, OfflineExperiment
from comet_ml import (
ExistingExperiment,
ExistingOfflineExperiment,
Experiment,
OfflineExperiment,
)

if offline:
experiment = OfflineExperiment(project_name=project_name, **kwargs)
if experiment_id:
experiment = ExistingOfflineExperiment(
project_name=project_name, workspace=workspace, previous_experiment=experiment_id, **kwargs
)
else:
experiment = OfflineExperiment(project_name=project_name, workspace=workspace, **kwargs)
else:
experiment = Experiment(project_name=project_name, api_key=api_key, **kwargs)
if experiment_id:
experiment = ExistingExperiment(
project_name=project_name,
workspace=workspace,
api_key=api_key,
previous_experiment=experiment_id,
**kwargs,
)
else:
experiment = Experiment(project_name=project_name, workspace=workspace, api_key=api_key, **kwargs)
return experiment

def on_fit_start(self):
self.experiment.set_model_graph(self.model.learner)
self.experiment.set_code(self._code_file)
self.experiment.log_code(self._code_file)

def on_train_epoch_start(
self,
Expand All @@ -71,31 +108,31 @@ def on_val_epoch_start(
):
self.experiment.validate()

def on_train_step_end(self, *args, **kwargs):
def _step(self, prefix: str, *args, **kwargs):
step = self.model.tracker.mode(prefix).steps
outputs = kwargs["outputs"]
loss = outputs["loss"].item()
self.experiment.log_metrics(outputs.get("metrics", {}))
self.experiment.log_metric("train_step_loss", loss)
self.experiment.log_metrics(outputs.get("metrics", {}), step=step, prefix=prefix)
self.experiment.log_metric(f"{prefix}_step_loss", loss, step=step)

def on_train_step_end(self, *args, **kwargs):
self._step(*args, **kwargs, prefix=self._train_prefix)

def on_val_step_end(self, *args, **kwargs):
outputs = kwargs["outputs"]
loss = outputs["loss"].item()
self.experiment.log_metrics(outputs.get("metrics", {}))
self.experiment.log_metric("val_step_loss", loss)
self._step(*args, **kwargs, prefix=self._val_prefix)

def on_epoch_end(self):
step = self.model.tracker.current_step
epoch = self.model.tracker.current_epoch
train_loss = self.model.tracker.train_loss
train_metrics = self.model.tracker.train_metrics
val_loss = self.model.tracker.val_loss
val_metrics = self.model.tracker.val_metrics

self.experiment.train()
self.experiment.log_metric("epoch_loss", train_loss, step=step, epoch=epoch)
self.experiment.log_metrics(train_metrics, step=step, epoch=epoch)
self.experiment.log_metric("train_epoch_loss", train_loss, epoch=epoch)
self.experiment.log_metrics(train_metrics, epoch=epoch, prefix=self._train_prefix)

self.experiment.validate()
self.experiment.log_metric("epoch_loss", val_loss, step=step, epoch=epoch)
self.experiment.log_metrics(val_metrics, step=step, epoch=epoch)
self.experiment.log_metric("val_epoch_loss", val_loss, epoch=epoch)
self.experiment.log_metrics(val_metrics, epoch=epoch, prefix=self._val_prefix)
self.experiment.log_epoch_end(epoch)
1 change: 0 additions & 1 deletion gradsflow/core/base.py
Expand Up @@ -60,7 +60,6 @@ def reset_metrics(self):
class BaseTracker:
max_epochs: int = 0
current_epoch: int = 0 # current train current_epoch
current_step: int = 0 # current current_step
steps_per_epoch: Optional[int] = None
train: TrackingValues = TrackingValues()
val: TrackingValues = TrackingValues()
13 changes: 4 additions & 9 deletions gradsflow/models/tracker.py
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Dict, List

from rich import box
Expand Down Expand Up @@ -39,13 +38,10 @@ def mode(self, mode) -> TrackingValues:

raise NotImplementedError(f"mode {mode} is not implemented!")

def track(self, key, value, render=False):
"""Tracks values for each step"""
if render:
warnings.warn("render is deprecated!")
def track(self, key, value):
"""Tracks value"""
epoch = self.current_epoch
step = self.current_step
data = {"current_epoch": epoch, "current_step": step, key: to_item(value)}
data = {"current_epoch": epoch, key: to_item(value)}
self.logs.append(data)

def track_loss(self, loss: float, mode: str):
Expand All @@ -56,7 +52,7 @@ def track_loss(self, loss: float, mode: str):
self.track(key, loss)

def track_metrics(self, metric: Dict[str, float], mode: str):
"""Update `TrackingValues` metrics. mode can be train or val and will update logs if render is True"""
"""Update `TrackingValues` metrics. mode can be train or val"""
value_tracker = self.mode(mode)
# Track values that averages with epoch
for key, value in metric.items():
Expand Down Expand Up @@ -103,7 +99,6 @@ def create_table(self) -> Table:
def reset(self):
self.max_epochs = 0
self.current_epoch = 0
self.current_step = 0
self.steps_per_epoch = None
self.train = TrackingValues()
self.val = TrackingValues()
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_tracker.py
Expand Up @@ -33,8 +33,8 @@ def test_mode():


def test_track():
tracker.track("val", 0.9, render=True)
tracker.track("score", 0.5, render=False)
tracker.track("val", 0.9)
tracker.track("score", 0.5)


def test_create_table():
Expand Down