/
__init__.py
103 lines (92 loc) · 3.76 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import warnings
import torch.nn as nn
from neuralhydrology.modelzoo.arlstm import ARLSTM
from neuralhydrology.modelzoo.cudalstm import CudaLSTM
from neuralhydrology.modelzoo.mamba import Mamba
from neuralhydrology.modelzoo.customlstm import CustomLSTM
from neuralhydrology.modelzoo.ealstm import EALSTM
from neuralhydrology.modelzoo.embcudalstm import EmbCudaLSTM
from neuralhydrology.modelzoo.handoff_forecast_lstm import HandoffForecastLSTM
from neuralhydrology.modelzoo.hybridmodel import HybridModel
from neuralhydrology.modelzoo.gru import GRU
from neuralhydrology.modelzoo.mclstm import MCLSTM
from neuralhydrology.modelzoo.mtslstm import MTSLSTM
from neuralhydrology.modelzoo.multihead_forecast_lstm import MultiHeadForecastLSTM
from neuralhydrology.modelzoo.odelstm import ODELSTM
from neuralhydrology.modelzoo.sequential_forecast_lstm import SequentialForecastLSTM
from neuralhydrology.modelzoo.stacked_forecast_lstm import StackedForecastLSTM
from neuralhydrology.modelzoo.transformer import Transformer
from neuralhydrology.utils.config import Config
# Models that only support a single input frequency. `get_model` raises a ValueError when one of these is
# requested together with more than one entry in `cfg.use_frequencies`.
SINGLE_FREQ_MODELS = [
    "cudalstm",
    "ealstm",
    "customlstm",
    "embcudalstm",
    "gru",
    "transformer",
    "mamba",
    "mclstm",
    "arlstm",
    "handoff_forecast_lstm",
    "sequential_forecast_lstm",
    "multihead_forecast_lstm",
    "stacked_forecast_lstm",
]

# Models that may be used with `cfg.autoregressive_inputs`; `get_model` rejects autoregressive inputs for any
# model not listed here.
AUTOREGRESSIVE_MODELS = ["arlstm"]
def get_model(cfg: Config) -> nn.Module:
    """Get model object, depending on the run configuration.

    Before instantiating the model, the configuration is checked against the capabilities of the requested
    model: models listed in ``SINGLE_FREQ_MODELS`` may not be combined with more than one frequency, only models
    listed in ``AUTOREGRESSIVE_MODELS`` may use autoregressive inputs, and only ``mclstm`` may use mass inputs.
    The model name is matched case-insensitively. The deprecated name ``lstm`` is treated as ``customlstm``
    (including for the checks above) and emits a ``FutureWarning``.

    Parameters
    ----------
    cfg : Config
        The run configuration.

    Returns
    -------
    nn.Module
        A new model instance of the type specified in the config.

    Raises
    ------
    ValueError
        If the configuration uses multiple frequencies, autoregressive inputs, or mass inputs with a model
        that does not support them.
    NotImplementedError
        If ``cfg.model`` does not name a model known to this function.
    """
    model_name = cfg.model.lower()

    if model_name == "lstm":
        # Resolve the deprecated alias before validation so it is subject to exactly the same capability
        # checks as "customlstm" (previously it slipped past the single-frequency check).
        warnings.warn(
            "The `LSTM` class has been renamed to `CustomLSTM`. Support for `LSTM` will be dropped in the future.",
            FutureWarning,
            stacklevel=2)
        model_name = "customlstm"

    if model_name in SINGLE_FREQ_MODELS and len(cfg.use_frequencies) > 1:
        raise ValueError(f"Model {cfg.model} does not support multiple frequencies.")
    if model_name not in AUTOREGRESSIVE_MODELS and cfg.autoregressive_inputs:
        raise ValueError(f"Model {cfg.model} does not support autoregression.")
    if model_name != "mclstm" and cfg.mass_inputs:
        raise ValueError(f"The use of 'mass_inputs' with {cfg.model} is not supported.")

    # Built per call (cheap) so the model classes are only looked up when a model is actually requested.
    model_classes = {
        "arlstm": ARLSTM,
        "cudalstm": CudaLSTM,
        "ealstm": EALSTM,
        "customlstm": CustomLSTM,
        "gru": GRU,
        "embcudalstm": EmbCudaLSTM,
        "mtslstm": MTSLSTM,
        "odelstm": ODELSTM,
        "mclstm": MCLSTM,
        "transformer": Transformer,
        "mamba": Mamba,
        "handoff_forecast_lstm": HandoffForecastLSTM,
        "multihead_forecast_lstm": MultiHeadForecastLSTM,
        "sequential_forecast_lstm": SequentialForecastLSTM,
        "stacked_forecast_lstm": StackedForecastLSTM,
        "hybrid_model": HybridModel,
    }

    model_class = model_classes.get(model_name)
    if model_class is None:
        raise NotImplementedError(f"{cfg.model} not implemented or not linked in `get_model()`")
    return model_class(cfg=cfg)