Skip to content

Commit

Permalink
PROD-290: Show DGAN training progress
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 6d0e6c5c09286940c7ec71bc3810d2159cc3df06
  • Loading branch information
misberner committed Jan 4, 2023
1 parent b0c6c15 commit 13ed083
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
21 changes: 18 additions & 3 deletions src/gretel_synthetics/timeseries_dgan/dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,14 @@
import logging
import math

from typing import Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch

from gretel_synthetics.timeseries_dgan.config import DfStyle, DGANConfig, OutputType
from gretel_synthetics.timeseries_dgan.structures import ProgressInfo
from gretel_synthetics.timeseries_dgan.torch_modules import Discriminator, Generator
from gretel_synthetics.timeseries_dgan.transformations import (
create_additional_attribute_outputs,
Expand Down Expand Up @@ -150,6 +151,7 @@ def train_numpy(
feature_types: Optional[List[OutputType]] = None,
attributes: Optional[np.ndarray] = None,
attribute_types: Optional[List[OutputType]] = None,
progress_callback: Optional[Callable[[ProgressInfo]]] = None,
):
"""Train DGAN model on data in numpy arrays.
Expand Down Expand Up @@ -304,7 +306,7 @@ def train_numpy(
torch.Tensor(internal_features),
)

self._train(dataset)
self._train(dataset, progress_callback=progress_callback)

def train_dataframe(
self,
Expand All @@ -315,6 +317,7 @@ def train_dataframe(
time_column: Optional[str] = None,
discrete_columns: Optional[List[str]] = None,
df_style: DfStyle = DfStyle.WIDE,
progress_callback: Optional[Callable[[ProgressInfo]]] = None,
):
"""Train DGAN model on data in pandas DataFrame.
Expand Down Expand Up @@ -394,6 +397,7 @@ def train_dataframe(
features=features,
attribute_types=self.data_frame_converter.attribute_types,
feature_types=self.data_frame_converter.feature_types,
progress_callback=progress_callback,
)

def generate_numpy(
Expand Down Expand Up @@ -619,6 +623,7 @@ def init_weights(m):
def _train(
self,
dataset: Dataset,
progress_callback: Optional[Callable[[ProgressInfo]]] = None,
):
"""Internal method for training DGAN model.
Expand Down Expand Up @@ -685,7 +690,7 @@ def _train(
for epoch in range(self.config.epochs):
logger.info(f"epoch: {epoch}")

for real_batch in loader:
for batch_idx, real_batch in enumerate(loader):
global_step += 1

with torch.cuda.amp.autocast(
Expand Down Expand Up @@ -788,6 +793,16 @@ def _train(
scaler.step(opt_generator)
scaler.update()

if progress_callback is not None:
progress_callback(
ProgressInfo(
epoch=epoch,
total_epochs=self.config.epochs,
batch=batch_idx,
total_batches=len(loader),
)
)

def _generate(
self, attribute_noise: torch.Tensor, feature_noise: torch.Tensor
) -> NumpyArrayTriple:
Expand Down
36 changes: 36 additions & 0 deletions src/gretel_synthetics/timeseries_dgan/structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Auxiliary datastructures for DGAN
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class ProgressInfo:
"""Information about DGAN training progress.
Args:
epoch: the current epoch, zero-based.
total_epochs: the total number of epochs.
batch: the current batch within the current epoch, zero-based.
total_batches: the total number of batches in this epoch.
"""

epoch: int
total_epochs: int
batch: int
total_batches: int

@property
def frac_completed(self) -> float:
"""
An estimation of which fraction of the overall task is completed.
Returns:
A number between 0.0 and 1.0 indicating which fraction of the task is completed.
"""
return (
self.epoch + 1 + float(self.batch + 1) / self.total_batches
) / self.total_epochs

0 comments on commit 13ed083

Please sign in to comment.