Skip to content

Commit 22d7d03

Browse files
S-aiueo32, Borda, awaelchli
authored
Replace meta_tags.csv with hparams.yaml (Lightning-AI#1271)
* Add support for hierarchical dict * Support nested Namespace * Add docstring * Migrate hparam flattening to each logger * Modify URLs in CHANGELOG * typo * Simplify the conditional branch about Namespace Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * added examples section to docstring * renamed _dict -> input_dict * mata_tags.csv -> hparams.yaml * code style fixes * add pyyaml * remove unused import * create the member NAME_HPARAMS_FILE * improve tests * Update tensorboard.py * pass the local test w/o relavents of Horovod * formatting * update dependencies * fix dependencies * Apply suggestions from code review * add savings * warn * docstrings * tests * Apply suggestions from code review * saving * Apply suggestions from code review * use default * remove logging * typo fixes * update docs * update CHANGELOG * clean imports * add blank lines * Update pytorch_lightning/core/lightning.py Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * Update pytorch_lightning/core/lightning.py Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * back to namespace * add docs * test fix * update dependencies * add space Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
1 parent 35fe2ef commit 22d7d03

File tree

12 files changed

+226
-79
lines changed

12 files changed

+226
-79
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2626

2727
### Changed
2828

29+
- Replace meta_tags.csv with hparams.yaml ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))
30+
2931
- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
3032

3133
- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
@@ -36,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3638

3739
### Deprecated
3840

41+
- Deprecated `tags_csv` in favor of `hparams_file` ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))
42+
3943
### Removed
4044

4145
### Fixed

docs/source/test_set.rst

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@ To run the test set on a pre-trained model, use this method.
2121

2222
.. code-block:: python
2323
24-
model = MyLightningModule.load_from_metrics(
25-
weights_path='/path/to/pytorch_checkpoint.ckpt',
26-
tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
27-
on_gpu=True,
24+
model = MyLightningModule.load_from_checkpoint(
25+
checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
26+
hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
2827
map_location=None
2928
)
3029

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytorch>=1.1
1414
- tensorboard>=1.14
1515
- future>=0.17.1
16+
- pyyaml>=3.13
1617

1718
# For dev and testing
1819
- tox

pytorch_lightning/core/lightning.py

Lines changed: 75 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import collections
22
import inspect
33
import os
4+
import warnings
45
from abc import ABC, abstractmethod
56
from argparse import Namespace
67
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
@@ -16,7 +17,7 @@
1617
from pytorch_lightning.core.grads import GradInformation
1718
from pytorch_lightning.core.hooks import ModelHooks
1819
from pytorch_lightning.core.memory import ModelSummary
19-
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams
20+
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams
2021
from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
2122
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
2223
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -1438,49 +1439,88 @@ def load_from_checkpoint(
14381439
cls,
14391440
checkpoint_path: str,
14401441
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
1441-
tags_csv: Optional[str] = None,
1442+
hparams_file: Optional[str] = None,
1443+
tags_csv: Optional[str] = None, # backward compatible, todo: remove in v0.9.0
14421444
hparam_overrides: Optional[Dict] = None,
14431445
*args, **kwargs
14441446
) -> 'LightningModule':
14451447
r"""
14461448
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
14471449
it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule`
1448-
with an argument called ``hparams`` which is a :class:`~argparse.Namespace`
1449-
(output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments).
1450+
with an argument called ``hparams`` which is an object of :class:`~dict` or
1451+
:class:`~argparse.Namespace` (output of :meth:`~argparse.ArgumentParser.parse_args`
1452+
when parsing command line arguments).
1453+
If you want `hparams` to have a hierarchical structure, you have to define it as :class:`~dict`.
14501454
Any other arguments specified through \*args and \*\*kwargs will be passed to the model.
14511455
14521456
Example:
14531457
.. code-block:: python
14541458
1459+
# define hparams as Namespace
14551460
from argparse import Namespace
14561461
hparams = Namespace(**{'learning_rate': 0.1})
14571462
14581463
model = MyModel(hparams)
14591464
14601465
class MyModel(LightningModule):
1461-
def __init__(self, hparams):
1466+
def __init__(self, hparams: Namespace):
14621467
self.learning_rate = hparams.learning_rate
14631468
1469+
# ----------
1470+
1471+
# define hparams as dict
1472+
hparams = {
1473+
drop_prob: 0.2,
1474+
dataloader: {
1475+
batch_size: 32
1476+
}
1477+
}
1478+
1479+
model = MyModel(hparams)
1480+
1481+
class MyModel(LightningModule):
1482+
def __init__(self, hparams: dict):
1483+
self.learning_rate = hparams['learning_rate']
1484+
14641485
Args:
14651486
checkpoint_path: Path to checkpoint.
14661487
model_args: Any keyword args needed to init the model.
14671488
map_location:
14681489
If your checkpoint saved a GPU model and you now load on CPUs
14691490
or a different number of GPUs, use this to map to the new setup.
14701491
The behaviour is the same as in :func:`torch.load`.
1471-
tags_csv: Optional path to a .csv file with two columns (key, value)
1492+
hparams_file: Optional path to a .yaml file with hierarchical structure
14721493
as in this example::
14731494
1474-
key,value
1475-
drop_prob,0.2
1476-
batch_size,32
1495+
drop_prob: 0.2
1496+
dataloader:
1497+
batch_size: 32
14771498
14781499
You most likely won't need this since Lightning will always save the hyperparameters
14791500
to the checkpoint.
14801501
However, if your checkpoint weights don't have the hyperparameters saved,
1481-
use this method to pass in a .csv file with the hparams you'd like to use.
1482-
These will be converted into a :class:`~argparse.Namespace` and passed into your
1502+
use this method to pass in a .yaml file with the hparams you'd like to use.
1503+
These will be converted into a :class:`~dict` and passed into your
14831504
:class:`LightningModule` for use.
1505+
1506+
If your model's `hparams` argument is :class:`~argparse.Namespace`
1507+
and the .yaml file has a hierarchical structure, you need to refactor your model to treat
1508+
`hparams` as :class:`~dict`.
1509+
1510+
.csv files are acceptable here until v0.9.0; see the tags_csv argument for detailed usage.
1511+
tags_csv:
1512+
.. warning:: .. deprecated:: 0.7.6
1513+
1514+
`tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0.
1515+
1516+
Optional path to a .csv file with two columns (key, value)
1517+
as in this example::
1518+
1519+
key,value
1520+
drop_prob,0.2
1521+
batch_size,32
1522+
1523+
Use this method to pass in a .csv file with the hparams you'd like to use.
14841524
hparam_overrides: A dictionary with keys to override in the hparams
14851525
14861526
Return:
@@ -1502,7 +1542,7 @@ def __init__(self, hparams):
15021542
# or load weights and hyperparameters from separate files.
15031543
MyLightningModule.load_from_checkpoint(
15041544
'path/to/checkpoint.ckpt',
1505-
tags_csv='/path/to/hparams_file.csv'
1545+
hparams_file='/path/to/hparams_file.yaml'
15061546
)
15071547
15081548
# override some of the params with new values
@@ -1531,9 +1571,22 @@ def __init__(self, hparams):
15311571

15321572
# add the hparams from csv file to checkpoint
15331573
if tags_csv is not None:
1534-
hparams = load_hparams_from_tags_csv(tags_csv)
1535-
hparams.__setattr__('on_gpu', False)
1536-
checkpoint['hparams'] = vars(hparams)
1574+
hparams_file = tags_csv
1575+
rank_zero_warn('`tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0', DeprecationWarning)
1576+
1577+
if hparams_file is not None:
1578+
extension = hparams_file.split('.')[-1]
1579+
if extension.lower() in ('csv'):
1580+
hparams = load_hparams_from_tags_csv(hparams_file)
1581+
elif extension.lower() in ('yml', 'yaml'):
1582+
hparams = load_hparams_from_yaml(hparams_file)
1583+
else:
1584+
raise ValueError('.csv, .yml or .yaml is required for `hparams_file`')
1585+
1586+
hparams['on_gpu'] = False
1587+
1588+
# overwrite hparams by the given file
1589+
checkpoint['hparams'] = hparams
15371590

15381591
# override the hparam keys that were passed in
15391592
if hparam_overrides is not None:
@@ -1549,15 +1602,18 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh
15491602

15501603
if cls_takes_hparams:
15511604
if ckpt_hparams is not None:
1552-
is_namespace = checkpoint.get('hparams_type', 'namespace') == 'namespace'
1553-
hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
1605+
hparams_type = checkpoint.get('hparams_type', 'Namespace')
1606+
if hparams_type.lower() == 'dict':
1607+
hparams = ckpt_hparams
1608+
elif hparams_type.lower() == 'namespace':
1609+
hparams = Namespace(**ckpt_hparams)
15541610
else:
15551611
rank_zero_warn(
15561612
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
15571613
" contains argument 'hparams'. Will pass in an empty Namespace instead."
15581614
" Did you forget to store your model hyperparameters in self.hparams?"
15591615
)
1560-
hparams = Namespace()
1616+
hparams = {}
15611617
else: # The user's LightningModule does not define a hparams argument
15621618
if ckpt_hparams is None:
15631619
hparams = None
@@ -1568,7 +1624,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh
15681624
)
15691625

15701626
# load the state_dict on the model automatically
1571-
if hparams:
1627+
if cls_takes_hparams:
15721628
kwargs.update(hparams=hparams)
15731629
model = cls(*args, **kwargs)
15741630
model.load_state_dict(checkpoint['state_dict'])

pytorch_lightning/core/saving.py

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import ast
12
import csv
23
import os
4+
import yaml
35
from argparse import Namespace
46
from typing import Union, Dict, Any
57

68
from pytorch_lightning import _logger as log
9+
from pytorch_lightning.utilities import rank_zero_warn
710

811

912
class ModelIO(object):
@@ -79,30 +82,78 @@ def update_hparams(hparams: dict, updates: dict) -> None:
7982
hparams.update({k: v})
8083

8184

82-
def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
85+
def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
86+
"""Load hparams from a file.
87+
88+
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
89+
>>> path_csv = './testing-hparams.csv'
90+
>>> save_hparams_to_tags_csv(path_csv, hparams)
91+
>>> hparams_new = load_hparams_from_tags_csv(path_csv)
92+
>>> vars(hparams) == hparams_new
93+
True
94+
>>> os.remove(path_csv)
95+
"""
8396
if not os.path.isfile(tags_csv):
84-
log.warning(f'Missing Tags: {tags_csv}.')
85-
return Namespace()
97+
rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
98+
return {}
8699

87-
with open(tags_csv) as f:
88-
csv_reader = csv.reader(f, delimiter=',')
100+
with open(tags_csv) as fp:
101+
csv_reader = csv.reader(fp, delimiter=',')
89102
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
90-
ns = Namespace(**tags)
91-
return ns
103+
104+
return tags
105+
106+
107+
def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
108+
if not os.path.isdir(os.path.dirname(tags_csv)):
109+
raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')
110+
111+
if isinstance(hparams, Namespace):
112+
hparams = vars(hparams)
113+
114+
with open(tags_csv, 'w') as fp:
115+
fieldnames = ['key', 'value']
116+
writer = csv.DictWriter(fp, fieldnames=fieldnames)
117+
writer.writerow({'key': 'key', 'value': 'value'})
118+
for k, v in hparams.items():
119+
writer.writerow({'key': k, 'value': v})
120+
121+
122+
def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
123+
"""Load hparams from a file.
124+
125+
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
126+
>>> path_yaml = './testing-hparams.yaml'
127+
>>> save_hparams_to_yaml(path_yaml, hparams)
128+
>>> hparams_new = load_hparams_from_yaml(path_yaml)
129+
>>> vars(hparams) == hparams_new
130+
True
131+
>>> os.remove(path_yaml)
132+
"""
133+
if not os.path.isfile(config_yaml):
134+
rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
135+
return {}
136+
137+
with open(config_yaml) as fp:
138+
tags = yaml.load(fp, Loader=yaml.SafeLoader)
139+
140+
return tags
141+
142+
143+
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
144+
if not os.path.isdir(os.path.dirname(config_yaml)):
145+
raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
146+
147+
if isinstance(hparams, Namespace):
148+
hparams = vars(hparams)
149+
150+
with open(config_yaml, 'w', newline='') as fp:
151+
yaml.dump(hparams, fp)
92152

93153

94154
def convert(val: str) -> Union[int, float, bool, str]:
95-
constructors = [int, float, str]
96-
97-
if isinstance(val, str):
98-
if val.lower() == 'true':
99-
return True
100-
if val.lower() == 'false':
101-
return False
102-
103-
for c in constructors:
104-
try:
105-
return c(val)
106-
except ValueError:
107-
pass
108-
return val
155+
try:
156+
return ast.literal_eval(val)
157+
except (ValueError, SyntaxError) as e:
158+
log.debug(e)
159+
return val

0 commit comments

Comments
 (0)