-
Notifications
You must be signed in to change notification settings - Fork 400
/
types.py
63 lines (44 loc) · 2.38 KB
/
types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Reference for common types used throughout the composer library.

Attributes:
    Batch (Any): Alias to type Any.
        A batch of data can be represented in several formats, depending on the application.
    PyTorchScheduler (torch.optim.lr_scheduler._LRScheduler): Alias for the base class of learning rate
        schedulers such as :class:`torch.optim.lr_scheduler.ConstantLR`.
    JSON (str | float | int | None | List['JSON'] | Dict[str, 'JSON']): Any JSON-serializable value.
    Dataset (torch.utils.data.Dataset[Batch]): Alias for :class:`torch.utils.data.Dataset`.
"""
from __future__ import annotations
from typing import Any, Dict, List, Union
import torch
import torch.utils.data
from composer.utils.string_enum import StringEnum
# Public names of this module. ``Dataset`` is documented as a public alias in the
# module docstring, so it is exported alongside the other type aliases.
__all__ = ['Batch', 'PyTorchScheduler', 'JSON', 'Dataset', 'MemoryFormat', 'BreakEpochException']

# A batch may be represented in several formats depending on the application,
# so no narrower static type is imposed.
Batch = Any

# Dataset parameterized by the (untyped) ``Batch`` alias above.
Dataset = torch.utils.data.Dataset[Batch]

# Base class of PyTorch learning-rate schedulers.
# NOTE(review): ``_LRScheduler`` is a private torch name; newer torch releases expose a
# public ``LRScheduler`` — confirm the minimum supported torch version before switching.
PyTorchScheduler = torch.optim.lr_scheduler._LRScheduler

# Recursive alias for any JSON value; the quoted 'JSON' entries are forward references
# to this alias itself, allowing arbitrarily nested lists and string-keyed dicts.
JSON = Union[str, float, int, None, List['JSON'], Dict[str, 'JSON']]
class BreakEpochException(Exception):
    """Signal that the epoch currently in progress should stop right away.

    Raising an instance of this exception ends the current epoch immediately.
    It is almost never the right tool: if you are wondering whether to use it, don't.
    """
class MemoryFormat(StringEnum):
    """Enum class to represent different memory formats.

    The string values match the names of PyTorch's memory-format objects (for example,
    ``'channels_last'`` corresponds to ``torch.channels_last``). See :class:`torch.memory_format`
    for more details.

    Attributes:
        CONTIGUOUS_FORMAT: Default PyTorch memory format, representing a tensor whose elements are laid out
            sequentially in allocated memory in the order of its dimensions.
        CHANNELS_LAST: Also known as NHWC. Typically used for images with 2 spatial dimensions (height and
            width), where elements at adjacent channel indices are also adjacent in allocated memory. For
            example, if ``C[0]`` is at memory location ``M_0``, then ``C[1]`` is at memory location ``M_1``, etc.
        CHANNELS_LAST_3D: Also referred to as NDHWC (or NTHWC for video). Same as :attr:`CHANNELS_LAST`, but
            for data with 3 spatial-like dimensions (e.g., depth or time, height, and width).
        PRESERVE_FORMAT: Tells an operation to give its output tensor the same memory format as its input
            tensor.
    """
    CONTIGUOUS_FORMAT = 'contiguous_format'
    CHANNELS_LAST = 'channels_last'
    CHANNELS_LAST_3D = 'channels_last_3d'
    PRESERVE_FORMAT = 'preserve_format'