Skip to content
Permalink
Browse files

tensorboardX: catch the exception when tensorboardX isn't installed (#…

…1711)

Import SummaryWriter from torch if torch.version > 1.2.0

Signed-off-by: Prachi Gupta <pragupta@us.ibm.com>
  • Loading branch information
pragupta committed Feb 5, 2020
1 parent 7d7a9df commit 7b3edd3a3a84d23047471abeef7833e891b97e8e
Showing with 9 additions and 3 deletions.
  1. +9 −3 examples/pytorch_imagenet_resnet50.py
@@ -8,10 +8,10 @@
import torch.utils.data.distributed
from torchvision import datasets, transforms, models
import horovod.torch as hvd
import tensorboardX
import os
import math
from tqdm import tqdm
from distutils.version import LooseVersion

# Training settings
parser = argparse.ArgumentParser(description='PyTorch ImageNet Example',
@@ -85,8 +85,14 @@
verbose = 1 if hvd.rank() == 0 else 0

# Horovod: write TensorBoard logs on first worker.
log_writer = tensorboardX.SummaryWriter(args.log_dir) if hvd.rank() == 0 else None

try:
if LooseVersion(torch.__version__) >= LooseVersion('1.2.0'):
from torch.utils.tensorboard import SummaryWriter
else:
from tensorboardX import SummaryWriter
log_writer = SummaryWriter(args.log_dir) if hvd.rank() == 0 else None
except ImportError:
log_writer = None

# Horovod: limit # of CPU threads to be used per worker.
torch.set_num_threads(4)

0 comments on commit 7b3edd3

Please sign in to comment.
You can’t perform that action at this time.