
Commit

fix(callbacks): fix wandb callbacks, changelog mod
okunator committed Oct 7, 2022
1 parent 2719ef7 commit ce4bc0c
Showing 4 changed files with 92 additions and 107 deletions.
CHANGELOG.md (33 changes: 16 additions & 17 deletions)
@@ -32,51 +32,50 @@
 
 ## Test
 
-- Update tests for Inferers and mask utils.
-
-- Add tests for the benchmarkers.
+- Update tests for Inferers and mask utils.
+- Add tests for the benchmarkers.
 
 ## Fixes
 
-- init and typing fixes
+- init and typing fixes
 
 ## Docs
 
-- Typo fixes in docs
+- Typo fixes in docs
 
 ## Features
 
-- Add numba parallelized median filter and majority voting for post-processing
-- Add support for own semantic and type seg post-proc funcs in Inferers
-
-- Add segmentation performance benchmarking helper class.
-- Add segmentation latency benchmarking helper class.
+- Add numba parallelized median filter and majority voting for post-processing
+- Add support for own semantic and type seg post-proc funcs in Inferers
+- Add segmentation performance benchmarking helper class.
+- Add segmentation latency benchmarking helper class.
 
 <a id='changelog-0.1.2'></a>
 
 # 0.1.2 — 2022-09-09
 
 ## Fixes
 
-- **datasets.writers**: Update `save2db` & `save2folder` for optional type_map and sem_map args.
-- **datasets.writers**: Pre-processing (`pre-proc`) callable arg for `_get_tiles` method. This enables the Lizard datamodule.
-- **inference**: Fix padding bug with sliding window inference.
+- Update `save2db` & `save2folder` for optional type_map and sem_map args.
+- Pre-processing (`pre-proc`) callable arg for `_get_tiles` method. This enables the Lizard datamodule.
+- Fix padding bug with sliding window inference.
 
 ## Features
 
-- **datamodules**: Lizard datamodule (https://arxiv.org/abs/2108.11195)
+- Lizard datamodule (https://arxiv.org/abs/2108.11195)
 
-- **models**: Add a universal multi-task U-net model builder (experimental)
+- Add a universal multi-task U-net model builder (experimental)
 
 ## Test
 
-- **dataset**: Update dataset tests.
+- Update dataset tests.
 
-- **models**: Update tests for multi-task U-Net
+- Update tests for multi-task U-Net
 
 ## Type Hints
 
-- **models**: Fix incorrect type hints.
+- Fix incorrect type hints.
 
 ## Examples
 
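The "numba parallelized median filter" entry in the Features list above refers to a JIT-compiled post-processing step. Below is a minimal sketch of what such a filter might look like; the function name and signature are hypothetical, not the actual API of cellseg_models_pytorch.

```python
import numpy as np
from numba import njit, prange


# Hypothetical sketch of a numba-parallelized median filter; the real
# implementation in cellseg_models_pytorch may differ in name and details.
@njit(parallel=True)
def median_filter2d(img: np.ndarray, kernel_size: int = 5) -> np.ndarray:
    """Apply a kernel_size x kernel_size median filter to a 2D array."""
    h, w = img.shape
    pad = kernel_size // 2
    out = np.zeros_like(img)
    for i in prange(h):  # rows are distributed across threads by numba
        for j in range(w):
            # Clamp the window to the image borders.
            y0, y1 = max(0, i - pad), min(h, i + pad + 1)
            x0, x1 = max(0, j - pad), min(w, j + pad + 1)
            window = img[y0:y1, x0:x1].copy().reshape(-1)
            out[i, j] = np.median(window)
    return out
```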
cellseg_models_pytorch/training/callbacks/metric_callbacks.py (4 changes: 0 additions & 4 deletions)
@@ -104,7 +104,6 @@ def __init__(
         dist_sync_on_step: bool = False,
         progress_grouo: Any = None,
         dist_sync_func: Callable = None,
-        num_classes: int = None,
         **kwargs
     ) -> None:
         """Create a custom torchmetrics mIoU callback.
@@ -121,9 +120,6 @@ def __init__(
         dist_sync_func : Callable, optional
             Callback that performs the allgather operation on the metric state.
            When None, DDP will be used to perform the allgather.
-        num_classes : int, optional
-            If not None, multi-class miou will be returned.
-
        """
        super().__init__(
            compute_on_step=compute_on_step,
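The callback above wraps a torchmetrics-style mIoU metric. For reference, per-class IoU is plain intersection over union computed channel-wise. The sketch below assumes one-hot (B, C, H, W) tensors; it is not the library's own `iou` helper from `cellseg_models_pytorch.training.functional`, whose signature may differ.

```python
import torch


def iou_per_class(
    pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-7
) -> torch.Tensor:
    """Return a (B, C) tensor of per-sample, per-class IoU values."""
    # Overlap and union over the spatial dims only, keeping batch and class.
    intersection = (pred * target).sum(dim=(2, 3))
    union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3)) - intersection
    return (intersection + eps) / (union + eps)
```

Averaged over the batch dimension, as in the `iou(pred, target).mean(dim=0)` call in the wandb callbacks below, this yields one value per class.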
cellseg_models_pytorch/training/callbacks/wandb_callbacks.py (160 changes: 76 additions & 84 deletions)
@@ -12,7 +12,7 @@
 
 from ..functional import iou
 
-__all__ = ["WandbImageCallback", "WandbClassMetricCallback"]
+__all__ = ["WandbImageCallback", "WandbClassBarCallback", "WandbClassLineCallback"]
 
 
 class WandbImageCallback(pl.Callback):
@@ -104,26 +104,22 @@ def on_validation_batch_end(
         trainer.logger.experiment.log(log_dict)
 
 
-class WandbClassMetricCallback(pl.Callback):
+class WandbIoUCallback(pl.Callback):
     def __init__(
         self,
         type_classes: Dict[str, int],
         sem_classes: Optional[Dict[str, int]],
         freq: int = 100,
-        return_series: bool = True,
-        return_bar: bool = True,
-        return_table: bool = False,
     ) -> None:
-        """Call back to compute per-class ious and log them to wandb."""
+        """Create a base class for IoU wandb callbacks."""
         super().__init__()
         self.type_classes = type_classes
         self.sem_classes = sem_classes
         self.freq = freq
-        self.return_series = return_series
-        self.return_bar = return_bar
-        self.return_table = return_table
-        self.cell_ious = np.empty(0)
-        self.sem_ious = np.empty(0)
+
+    def batch_end(self) -> None:
+        """Abstract batch end method."""
+        raise NotImplementedError
 
     def compute(
         self,
@@ -139,36 +135,47 @@ def compute(
         met = iou(pred, target).mean(dim=0)
         return met.to("cpu").numpy()
 
-    def get_table(
-        self, ious: np.ndarray, x: np.ndarray, classes: Dict[int, str]
-    ) -> wandb.Table:
-        """Return a wandb Table with step, iou and label values for every step."""
-        batch_data = [
-            [xi * self.freq, c, np.round(ious[xi, i], 4)]
-            for i, c, in classes.items()
-            for xi in x
-        ]
-
-        return wandb.Table(data=batch_data, columns=["step", "label", "value"])
+    def on_train_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Log the inputs and outputs of the model to wandb."""
+        self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="train")
+
+    def on_validation_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Log the inputs and outputs of the model to wandb."""
+        self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="val")
+
+
+class WandbClassBarCallback(WandbIoUCallback):
+    def __init__(
+        self,
+        type_classes: Dict[str, int],
+        sem_classes: Optional[Dict[str, int]],
+        freq: int = 100,
+    ) -> None:
+        """Create a wandb callback that logs per-class mIoU at batch ends."""
+        super().__init__(type_classes, sem_classes, freq)
 
     def get_bar(self, iou: np.ndarray, classes: Dict[int, str], title: str) -> Any:
         """Return a wandb bar plot object of the current per class iou values."""
         batch_data = [[lab, val] for lab, val in zip(list(classes.values()), iou)]
         table = wandb.Table(data=batch_data, columns=["label", "value"])
         return wandb.plot.bar(table, "label", "value", title=title)
 
-    def get_series(
-        self, ious: np.ndarray, x: np.ndarray, classes: Dict[int, str], title: str
-    ) -> Any:
-        """Return a wandb series plot obj of the per class iou values over timesteps."""
-        return wandb.plot.line_series(
-            xs=x.tolist(),
-            ys=[ious[:, c].tolist() for c in classes.keys()],
-            keys=list(classes.values()),
-            title=title,
-            xname="step",
-        )
 
     def batch_end(
         self,
         trainer: pl.Trainer,
@@ -182,69 +189,54 @@ def batch_end(
         log_dict = {}
         if "type" in list(batch.keys()):
             iou = self.compute("type", outputs, batch)
-            self.cell_ious = np.append(self.cell_ious, iou)
-            cell_ious = self.cell_ious.reshape(-1, len(self.type_classes))
-            x = np.arange(cell_ious.shape[0])
-
-            if self.return_table:
-                log_dict[f"{phase}/type_ious_table"] = self.get_table(
-                    cell_ious, x, self.type_classes
-                )
-
-            if self.return_series:
-                log_dict[f"{phase}/type_ious_per_class"] = self.get_series(
-                    cell_ious, x, self.type_classes, title="Per type class mIoU"
-                )
-
-            if self.return_bar:
-                log_dict[f"{phase}/type_ious_bar"] = self.get_bar(
-                    list(iou), self.type_classes, title="Cell class mIoUs"
-                )
+            log_dict[f"{phase}/type_ious_bar"] = self.get_bar(
+                list(iou), self.type_classes, title="Cell class mIoUs"
+            )
 
         if "sem" in list(batch.keys()):
             iou = self.compute("sem", outputs, batch)
-
-            self.sem_ious = np.append(self.sem_ious, iou)
-            sem_ious = self.sem_ious.reshape(-1, len(self.sem_classes))
-            x = np.arange(sem_ious.shape[0])
-
-            if self.return_table:
-                log_dict[f"{phase}/sem_ious_table"] = self.get_table(
-                    cell_ious, x, self.type_classes
-                )
-
-            if self.return_series:
-                log_dict[f"{phase}/sem_ious_per_class"] = self.get_series(
-                    cell_ious, x, self.type_classes, title="Per sem class mIoU"
-                )
-
-            if self.return_bar:
-                log_dict[f"{phase}/sem_ious_bar"] = self.get_bar(
-                    list(iou), self.type_classes, title="Sem class mIoUs"
-                )
+            log_dict[f"{phase}/sem_ious_bar"] = self.get_bar(
+                list(iou), self.sem_classes, title="Sem class mIoUs"
+            )
 
         trainer.logger.experiment.log(log_dict)
 
-    def on_train_batch_end(
-        self,
-        trainer: pl.Trainer,
-        pl_module: pl.LightningModule,
-        outputs: Dict[str, torch.Tensor],
-        batch: Dict[str, torch.Tensor],
-        batch_idx: int,
-        dataloader_idx: int,
-    ) -> None:
-        """Log the inputs and outputs of the model to wandb."""
-        self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="train")
 
-    def on_validation_batch_end(
+class WandbClassLineCallback(WandbIoUCallback):
+    def __init__(
+        self,
+        type_classes: Dict[str, int],
+        sem_classes: Optional[Dict[str, int]],
+        freq: int = 100,
+    ) -> None:
+        """Create a wandb callback that logs per-class mIoU at batch ends."""
+        super().__init__(type_classes, sem_classes, freq)
+
+    def get_points(self, iou: np.ndarray, classes: Dict[int, str]) -> Any:
+        """Return a dict of the current per-class IoU values."""
+        return {lab: val for lab, val in zip(list(classes.values()), iou)}
+
+    def batch_end(
         self,
         trainer: pl.Trainer,
         pl_module: pl.LightningModule,
         outputs: Dict[str, torch.Tensor],
         batch: Dict[str, torch.Tensor],
         batch_idx: int,
-        dataloader_idx: int,
+        phase: str,
     ) -> None:
-        """Log the inputs and outputs of the model to wandb."""
-        self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="val")
+        """Log metrics every `freq` batches to wandb."""
+        if batch_idx % self.freq == 0:
+            log_dict = {}
+            if "type" in list(batch.keys()):
+                iou = self.compute("type", outputs, batch)
+                log_dict[f"{phase}/type_ious_points"] = self.get_points(
+                    list(iou), self.type_classes
+                )
+
+            if "sem" in list(batch.keys()):
+                iou = self.compute("sem", outputs, batch)
+                log_dict[f"{phase}/sem_ious_points"] = self.get_points(
+                    list(iou), self.sem_classes
+                )
+
+            trainer.logger.experiment.log(log_dict)
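A minimal usage sketch for the two new callbacks follows. The import path is assumed from the file layout above, and the class-index mappings are made up for the example.

```python
import pytorch_lightning as pl

# Import path assumed from the repository layout shown in this commit.
from cellseg_models_pytorch.training.callbacks.wandb_callbacks import (
    WandbClassBarCallback,
    WandbClassLineCallback,
)

# Hypothetical class-index mappings for the example.
type_classes = {"bg": 0, "neoplastic": 1, "inflammatory": 2}
sem_classes = {"bg": 0, "tumor": 1, "stroma": 2}

callbacks = [
    # Bar plot of the current per-class IoUs, logged every 100th batch.
    WandbClassBarCallback(type_classes, sem_classes, freq=100),
    # Plain per-class IoU scalars; cheaper to log and easy to chart over steps.
    WandbClassLineCallback(type_classes, sem_classes, freq=100),
]

trainer = pl.Trainer(
    logger=pl.loggers.WandbLogger(project="cellseg"),  # callbacks log via wandb
    callbacks=callbacks,
    max_epochs=10,
)
# trainer.fit(model, datamodule=...)  # model and datamodule defined elsewhere
```

Splitting the old WandbClassMetricCallback into a small WandbIoUCallback base plus bar and line subclasses keeps each callback stateless per batch, which also removes the ever-growing `cell_ious`/`sem_ious` buffers and the sem-branch variable mix-up visible in the deleted lines above.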
cellseg_models_pytorch/training/lit/lightning_experiment.py (2 changes: 0 additions & 2 deletions)
@@ -73,8 +73,6 @@ def __init__(
         scheduler_params : Dict[str, Any]
             Params dict for the scheduler. Refer to torch lr_scheduler docs
             for the possible scheduler arguments.
-        return_soft_masks : bool, default=True
-            Return the model outputs for logging if True. Saves mem if set to False.
         log_freq : int, default=100
             Return soft masks every n batches for callbacks and logging.
