Skip to content

Commit 6e77181

Browse files
authored
Squashed commit of the following: (Lightning-AI#2164)
commit 29fb0506cd38a15c359e369cc8bc4435916b0c78 Author: Brendan Fahy <bmfahy@gmail.com> Date: Sat Aug 8 19:35:30 2020 +0000 fix checking for version for docs to build commit 467fd64 Author: Brendan Fahy <bmfahy@gmail.com> Date: Sat Aug 8 18:56:05 2020 +0000 remove no local test commit a7cc9f8 Author: Brendan Fahy <bmfahy@gmail.com> Date: Sat Aug 8 18:46:44 2020 +0000 fix commit 3fdbb72 Author: Brendan Fahy <bmfahy@gmail.com> Date: Sat Aug 8 18:23:30 2020 +0000 revert requirements commit 9b8686b Author: Brendan Fahy <bmfahy@gmail.com> Date: Sat Aug 8 18:16:42 2020 +0000 make it a fixture commit eec7495 Author: Brendan Fahy <bmfahy@gmail.com> Date: Sat Aug 8 18:01:32 2020 +0000 fix up the testing commit 896d94a Author: Brendan Fahy <bmfahy@gmail.com> Date: Sat Aug 8 17:47:28 2020 +0000 fix some tests commit 6d22bde Merge: 6175d4e 6ebe0d7 Author: Brendan Fahy <bmfahy@gmail.com> Date: Sat Aug 8 10:20:47 2020 +0000 Merge remote-tracking branch 'origin/master' into tb_use_gfile commit 6175d4e Author: Brendan Fahy <bmfahy@gmail.com> Date: Fri Aug 7 10:16:36 2020 +0000 Use tensorboard.compat.gfile to support remote writing
1 parent 983c030 commit 6e77181

File tree

6 files changed

+111
-38
lines changed

6 files changed

+111
-38
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytorch_lightning import _logger as log
1717
from pytorch_lightning.callbacks.base import Callback
1818
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
19+
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
1920

2021

2122
class ModelCheckpoint(Callback):
@@ -104,7 +105,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
104105
save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
105106
mode: str = 'auto', period: int = 1, prefix: str = ''):
106107
super().__init__()
107-
if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
108+
if(filepath):
109+
filepath = str(filepath) # the tests pass in a py.path.local but we want a str
110+
if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
108111
rank_zero_warn(
109112
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
110113
"All files in this directory will be deleted when a checkpoint is saved!"
@@ -116,12 +119,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
116119
if filepath is None: # will be determined by trainer at runtime
117120
self.dirpath, self.filename = None, None
118121
else:
119-
if os.path.isdir(filepath):
122+
if gfile.isdir(filepath):
120123
self.dirpath, self.filename = filepath, '{epoch}'
121124
else:
122125
filepath = os.path.realpath(filepath)
123126
self.dirpath, self.filename = os.path.split(filepath)
124-
os.makedirs(self.dirpath, exist_ok=True)
127+
if not gfile.exists(self.dirpath):
128+
makedirs(self.dirpath)
125129
self.save_last = save_last
126130
self.save_top_k = save_top_k
127131
self.save_weights_only = save_weights_only
@@ -163,16 +167,23 @@ def kth_best_model(self):
163167
return self.kth_best_model_path
164168

165169
def _del_model(self, filepath):
166-
if os.path.isfile(filepath):
167-
os.remove(filepath)
170+
if gfile.exists(filepath):
171+
try:
172+
# in compat mode, remove is not implemented so if running this
173+
# against an actual remote file system and the correct remote
174+
# dependencies exist then this will work fine.
175+
gfile.remove(filepath)
176+
except AttributeError:
177+
os.remove(filepath)
168178

169179
def _save_model(self, filepath, trainer, pl_module):
170180

171181
# in debugging, track when we save checkpoints
172182
trainer.dev_debugger.track_checkpointing_history(filepath)
173183

174184
# make paths
175-
os.makedirs(os.path.dirname(filepath), exist_ok=True)
185+
if not gfile.exists(os.path.dirname(filepath)):
186+
makedirs(os.path.dirname(filepath))
176187

177188
# delegate the saving to the model
178189
if self.save_function is not None:
@@ -308,7 +319,7 @@ def on_validation_end(self, trainer, pl_module):
308319

309320
filepath = self.format_checkpoint_name(epoch, metrics)
310321
version_cnt = 0
311-
while os.path.isfile(filepath):
322+
while gfile.exists(filepath):
312323
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
313324
# this epoch called before
314325
version_cnt += 1

pytorch_lightning/core/saving.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytorch_lightning import _logger as log
1212
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
1313
from pytorch_lightning.utilities.cloud_io import load as pl_load
14+
from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
1415

1516
PRIMITIVE_TYPES = (bool, int, float, str)
1617
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -273,30 +274,30 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
273274
True
274275
>>> os.remove(path_csv)
275276
"""
276-
if not os.path.isfile(tags_csv):
277-
rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
277+
if not gfile.exists(tags_csv):
278+
rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
278279
return {}
279280

280-
with open(tags_csv) as fp:
281-
csv_reader = csv.reader(fp, delimiter=',')
281+
with cloud_open(tags_csv, "r", newline="") as fp:
282+
csv_reader = csv.reader(fp, delimiter=",")
282283
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
283284

284285
return tags
285286

286287

287288
def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
288-
if not os.path.isdir(os.path.dirname(tags_csv)):
289-
raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')
289+
if not gfile.isdir(os.path.dirname(tags_csv)):
290+
raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")
290291

291292
if isinstance(hparams, Namespace):
292293
hparams = vars(hparams)
293294

294-
with open(tags_csv, 'w', newline='') as fp:
295-
fieldnames = ['key', 'value']
295+
with cloud_open(tags_csv, "w", newline="") as fp:
296+
fieldnames = ["key", "value"]
296297
writer = csv.DictWriter(fp, fieldnames=fieldnames)
297-
writer.writerow({'key': 'key', 'value': 'value'})
298+
writer.writerow({"key": "key", "value": "value"})
298299
for k, v in hparams.items():
299-
writer.writerow({'key': k, 'value': v})
300+
writer.writerow({"key": k, "value": v})
300301

301302

302303
def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
@@ -310,11 +311,11 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
310311
True
311312
>>> os.remove(path_yaml)
312313
"""
313-
if not os.path.isfile(config_yaml):
314-
rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
314+
if not gfile.exists(config_yaml):
315+
rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
315316
return {}
316317

317-
with open(config_yaml) as fp:
318+
with cloud_open(config_yaml, "r") as fp:
318319
tags = yaml.load(fp)
319320

320321
return tags
@@ -326,11 +327,12 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
326327
config_yaml: path to new YAML file
327328
hparams: parameters to be saved
328329
"""
329-
if not os.path.isdir(os.path.dirname(config_yaml)):
330-
raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
330+
if not gfile.isdir(os.path.dirname(config_yaml)):
331+
raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")
331332

332333
if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
333334
from omegaconf import OmegaConf
335+
334336
OmegaConf.save(hparams, config_yaml, resolve=True)
335337
return
336338

@@ -341,7 +343,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
341343
hparams = dict(hparams)
342344
assert isinstance(hparams, dict)
343345

344-
with open(config_yaml, 'w', newline='') as fp:
346+
with cloud_open(config_yaml, "w", newline="") as fp:
345347
yaml.dump(hparams, fp)
346348

347349

pytorch_lightning/loggers/tensorboard.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytorch_lightning.core.saving import save_hparams_to_yaml
1717
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
1818
from pytorch_lightning.utilities import rank_zero_only
19+
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
1920

2021
try:
2122
from omegaconf import Container, OmegaConf
@@ -109,7 +110,8 @@ def experiment(self) -> SummaryWriter:
109110
return self._experiment
110111

111112
assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
112-
os.makedirs(self.root_dir, exist_ok=True)
113+
if self.root_dir and not gfile.exists(str(self.root_dir)):
114+
makedirs(self.root_dir)
113115
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
114116
return self._experiment
115117

@@ -162,7 +164,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
162164
def save(self) -> None:
163165
super().save()
164166
dir_path = self.log_dir
165-
if not os.path.isdir(dir_path):
167+
if not gfile.isdir(dir_path):
166168
dir_path = self.save_dir
167169

168170
# prepare the file path
@@ -188,13 +190,13 @@ def version(self) -> int:
188190
def _get_next_version(self):
189191
root_dir = os.path.join(self.save_dir, self.name)
190192

191-
if not os.path.isdir(root_dir):
193+
if not gfile.isdir(root_dir):
192194
log.warning('Missing logger folder: %s', root_dir)
193195
return 0
194196

195197
existing_versions = []
196-
for d in os.listdir(root_dir):
197-
if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
198+
for d in gfile.listdir(root_dir):
199+
if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
198200
existing_versions.append(int(d.split("_")[1]))
199201

200202
if len(existing_versions) == 0:

pytorch_lightning/trainer/training_io.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
)
105105
from pytorch_lightning.utilities import rank_zero_warn, AMPType
106106
from pytorch_lightning.utilities.cloud_io import load as pl_load
107+
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
107108

108109
try:
109110
import torch_xla
@@ -407,9 +408,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
407408
did_restore = False
408409

409410
# look for hpc weights
410-
folderpath = self.weights_save_path
411-
if os.path.exists(folderpath):
412-
files = os.listdir(folderpath)
411+
folderpath = str(self.weights_save_path)
412+
if gfile.exists(folderpath):
413+
files = gfile.listdir(folderpath)
413414
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
414415

415416
# if hpc weights exist restore model
@@ -488,15 +489,17 @@ def restore_training_state(self, checkpoint):
488489
# ----------------------------------
489490
def hpc_save(self, folderpath: str, logger):
490491
# make sure the checkpoint folder exists
491-
os.makedirs(folderpath, exist_ok=True)
492+
folderpath = str(folderpath) # because the tests pass a path object
493+
if not gfile.exists(folderpath):
494+
makedirs(folderpath)
492495

493496
# save logger to make sure we get all the metrics
494497
logger.save()
495498

496499
ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
497500

498-
if not os.path.exists(folderpath):
499-
os.makedirs(folderpath, exist_ok=True)
501+
if not gfile.exists(folderpath):
502+
makedirs(folderpath)
500503
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
501504

502505
# give model a chance to do something on hpc_save
@@ -549,7 +552,7 @@ def hpc_load(self, folderpath, on_gpu):
549552
log.info(f'restored hpc model from: {filepath}')
550553

551554
def max_ckpt_in_folder(self, path, name_key='ckpt_'):
552-
files = os.listdir(path)
555+
files = gfile.listdir(str(path))
553556
files = [x for x in files if name_key in x]
554557
if len(files) == 0:
555558
return 0
Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,65 @@
1+
import sys
2+
import os
3+
from typing import Union
14
from pathlib import Path
25
from urllib.parse import urlparse
3-
46
import torch
57

8+
import tensorboard
9+
from packaging import version
10+
from pytorch_lightning import _logger as log
11+
12+
# we want this for tf.io.gfile, which if tf is installed gives full tf,
13+
# otherwise gives a pruned down version which works for some file backends but
14+
# not all
15+
from tensorboard.compat import tf
16+
17+
gfile = tf.io.gfile
18+
19+
pathlike = Union[Path, str]
20+
21+
# older version of tensorboard had buggy gfile compatibility layers
22+
# only support remote cloud paths if newer
23+
624

725
def load(path_or_url: str, map_location=None):
826
if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive: # no scheme or with a drive letter
927
return torch.load(path_or_url, map_location=map_location)
10-
else:
11-
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
28+
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
29+
30+
31+
def modern_gfile():
32+
"""Check the version number of tensorboard.
33+
34+
Checking to see if it has the gfile compatibility layers needed for remote
35+
file operations
36+
"""
37+
tb_version = version.parse(tensorboard.version.VERSION)
38+
modern_gfile = tb_version >= version.parse('2.0')
39+
40+
41+
def cloud_open(path: pathlike, mode: str, newline:str = None):
42+
if sys.platform == "win32":
43+
log.debug(
44+
"gfile does not handle newlines correctly on windows so remote files are not"
45+
"supported falling back to normal local file open."
46+
)
47+
return open(path, mode, newline=newline)
48+
if not modern_gfile():
49+
log.debug(
50+
"tenosrboard.compat gfile does not work on older versions "
51+
"of tensorboard for remote files, using normal local file open."
52+
)
53+
return open(path, mode, newline=newline)
54+
try:
55+
return gfile.GFile(path, mode)
56+
except NotImplementedError as e:
57+
# minimal dependencies are installed and only local files will work
58+
return open(path, mode, newline=newline)
59+
60+
61+
def makedirs(path: pathlike):
62+
if hasattr(gfile, "makedirs") and modern_gfile():
63+
return gfile.makedirs(str(path))
64+
# otherwise minimal dependencies are installed and only local files will work
65+
return os.makedirs(path, exist_ok=True)

requirements/base.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ future>=0.17.1 # required for builtins in setup.py
77
# pyyaml>=3.13
88
PyYAML>=5.1 # OmegaConf requirement >=5.1
99
tqdm>=4.41.0
10+
packaging

0 commit comments

Comments
 (0)