Skip to content

Commit

Permalink
make visdom optional
Browse files Browse the repository at this point in the history
Summary: Make Implicitron run without visdom installed.

Reviewed By: shapovalov

Differential Revision: D40587974

fbshipit-source-id: dc319596c7a4d10a4c54c556dabc89ad9d25c2fb
  • Loading branch information
bottler authored and facebook-github-bot committed Oct 22, 2022
1 parent 46cb5aa commit ff933ab
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 22 deletions.
2 changes: 1 addition & 1 deletion projects/implicitron_trainer/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
Stats are logged and plotted to the file "train_stats.pdf" in the
same directory. The stats are also saved as part of the checkpoint file.
- Visualizations
Prredictions are plotted to a visdom server running at the
Predictions are plotted to a visdom server running at the
port specified by the `visdom_server` and `visdom_port` keys in the
config file.
Expand Down
21 changes: 14 additions & 7 deletions pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np
import torch
Expand All @@ -27,7 +27,9 @@
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.vis.plotly_vis import plot_scene
from tabulate import tabulate
from visdom import Visdom

if TYPE_CHECKING:
from visdom import Visdom


EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
Expand All @@ -43,14 +45,16 @@ class _Visualizer:

visdom_env: str = "eval_debug"

_viz: Visdom = field(init=False)
_viz: Optional["Visdom"] = field(init=False)

def __post_init__(self):
self._viz = vis_utils.get_visdom_connection()

def show_rgb(
self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor
):
if self._viz is None:
return
self._viz.images(
torch.cat(
(
Expand All @@ -68,7 +72,10 @@ def show_rgb(
def show_depth(
self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor
):
self._viz.images(
if self._viz is None:
return
viz = self._viz
viz.images(
torch.cat(
(
make_depth_image(self.depth_render, loss_mask_now),
Expand All @@ -80,13 +87,13 @@ def show_depth(
win="depth_abs" + name_postfix,
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"},
)
self._viz.images(
viz.images(
loss_mask_now,
env=self.visdom_env,
win="depth_abs" + name_postfix + "_mask",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
)
self._viz.images(
viz.images(
self.depth_mask,
env=self.visdom_env,
win="depth_abs" + name_postfix + "_maskd",
Expand Down Expand Up @@ -126,7 +133,7 @@ def show_depth(
pointcloud_max_points=10000,
pointcloud_marker_size=1,
)
self._viz.plotlyplot(
viz.plotlyplot(
plotlyplot,
env=self.visdom_env,
win=f"pcl{name_postfix}",
Expand Down
10 changes: 6 additions & 4 deletions pytorch3d/implicitron/models/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import math
import warnings
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import tqdm
Expand All @@ -34,7 +34,9 @@
from pytorch3d.renderer import utils as rend_utils

from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom

if TYPE_CHECKING:
from visdom import Visdom

from .base_model import ImplicitronModelBase, ImplicitronRender
from .feature_extractor import FeatureExtractorBase
Expand Down Expand Up @@ -544,7 +546,7 @@ def _get_objective(self, preds) -> Optional[torch.Tensor]:

def visualize(
self,
viz: Visdom,
viz: Optional["Visdom"],
visdom_env_imgs: str,
preds: Dict[str, Any],
prefix: str,
Expand All @@ -559,7 +561,7 @@ def visualize(
preds: predictions dict like returned by forward()
prefix: prepended to the names of images
"""
if not viz.check_connection():
if viz is None or not viz.check_connection():
logger.info("no visdom server! -> skipping batch vis")
return

Expand Down
10 changes: 6 additions & 4 deletions pytorch3d/implicitron/models/visualization/render_flyaround.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import math
import os
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np
import torch
Expand All @@ -27,7 +27,9 @@
make_depth_image,
)
from tqdm import tqdm
from visdom import Visdom

if TYPE_CHECKING:
from visdom import Visdom

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -272,7 +274,7 @@ def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.T
def _show_predictions(
preds: List[Dict[str, Any]],
sequence_name: str,
viz: Visdom,
viz: "Visdom",
viz_env: str = "visualizer",
predicted_keys: Sequence[str] = (
"images_render",
Expand Down Expand Up @@ -318,7 +320,7 @@ def _show_predictions(
def _generate_prediction_videos(
preds: List[Dict[str, Any]],
sequence_name: str,
viz: Optional[Visdom] = None,
viz: Optional["Visdom"] = None,
viz_env: str = "visualizer",
predicted_keys: Sequence[str] = (
"images_render",
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/implicitron/tools/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def plot_stats(
novisdom = False

viz = get_visdom_connection(server=visdom_server, port=visdom_port)
if not viz.check_connection():
if viz is None or not viz.check_connection():
print("no visdom server! -> skipping visdom plots")
novisdom = True

Expand Down
21 changes: 16 additions & 5 deletions pytorch3d/implicitron/tools/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING

import torch
from visdom import Visdom

if TYPE_CHECKING:
from visdom import Visdom


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -40,9 +42,9 @@ def get_visdom_env(visdom_env: str, exp_dir: str) -> str:
def get_visdom_connection(
server: str = "http://localhost",
port: int = 8097,
) -> Visdom:
) -> Optional["Visdom"]:
"""
Obtain a connection to a visdom server.
Obtain a connection to a visdom server if visdom is installed.
Args:
server: Server address.
Expand All @@ -51,14 +53,23 @@ def get_visdom_connection(
Returns:
connection: The connection object.
"""
try:
from visdom import Visdom
except ImportError:
logger.debug("Cannot load visdom")
return None

if server == "None":
return None

global _viz_singleton
if _viz_singleton is None:
_viz_singleton = Visdom(server=server, port=port)
return _viz_singleton


def visualize_basics(
viz: Visdom,
viz: "Visdom",
preds: Dict[str, Any],
visdom_env_imgs: str,
title: str = "",
Expand Down

0 comments on commit ff933ab

Please sign in to comment.