Skip to content

Commit

Permalink
InceptionTimePlus was found to be missing from the special architectu…
Browse files Browse the repository at this point in the history
…re config clauses in build_ts_model, an assumption was made to put it in the list next to InceptionTime -> FIX for issue timeseriesAI#847
  • Loading branch information
Craig Versek committed Nov 8, 2023
1 parent 7701373 commit 004c0f3
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion nbs/030_models.utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@
" \"MLSTM_FCNPlus\", \"MGRU_FCNPlus\", \"RNNAttentionPlus\", \"LSTMAttentionPlus\", \"GRUAttentionPlus\", \"ConvTran\", \"ConvTranPlus\", 'mWDNPlus']:\n",
" pv(f'arch: {arch.__name__}(c_in={c_in} c_out={c_out} seq_len={seq_len} d={d} arch_config={arch_config}, kwargs={kwargs})', verbose)\n",
" model = (arch(c_in=c_in, c_out=c_out, seq_len=seq_len, d=d, **arch_config, **kwargs)).to(device=device)\n",
" elif sum([1 for v in ['RNN_FCN', 'LSTM_FCN', 'RNNPlus', 'LSTMPlus', 'GRUPlus', 'InceptionTime', 'TSiT', 'Sequencer', 'XceptionTimePlus',\n",
" elif sum([1 for v in ['RNN_FCN', 'LSTM_FCN', 'RNNPlus', 'LSTMPlus', 'GRUPlus', 'InceptionTime', 'InceptionTimePlus', 'TSiT', 'Sequencer', 'XceptionTimePlus',\n",
" 'GRU_FCN', 'OmniScaleCNN', 'mWDN', 'TST', 'XCM', 'MLP', 'MiniRocket', 'InceptionRocket', 'ResNetPlus', \n",
" 'RNNAttention', 'LSTMAttention', 'GRUAttention', 'MultiRocket', 'MultiRocketPlus', 'Hydra', 'HydraPlus', \n",
" 'HydraMultiRocket', 'HydraMultiRocketPlus']\n",
Expand Down
2 changes: 1 addition & 1 deletion tsai/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def build_ts_model(arch, c_in=None, c_out=None, seq_len=None, d=None, dls=None,
"MLSTM_FCNPlus", "MGRU_FCNPlus", "RNNAttentionPlus", "LSTMAttentionPlus", "GRUAttentionPlus", "ConvTran", "ConvTranPlus", 'mWDNPlus']:
pv(f'arch: {arch.__name__}(c_in={c_in} c_out={c_out} seq_len={seq_len} d={d} arch_config={arch_config}, kwargs={kwargs})', verbose)
model = (arch(c_in=c_in, c_out=c_out, seq_len=seq_len, d=d, **arch_config, **kwargs)).to(device=device)
elif sum([1 for v in ['RNN_FCN', 'LSTM_FCN', 'RNNPlus', 'LSTMPlus', 'GRUPlus', 'InceptionTime', 'TSiT', 'Sequencer', 'XceptionTimePlus',
elif sum([1 for v in ['RNN_FCN', 'LSTM_FCN', 'RNNPlus', 'LSTMPlus', 'GRUPlus', 'InceptionTime', 'InceptionTimePlus', 'TSiT', 'Sequencer', 'XceptionTimePlus',
'GRU_FCN', 'OmniScaleCNN', 'mWDN', 'TST', 'XCM', 'MLP', 'MiniRocket', 'InceptionRocket', 'ResNetPlus',
'RNNAttention', 'LSTMAttention', 'GRUAttention', 'MultiRocket', 'MultiRocketPlus', 'Hydra', 'HydraPlus',
'HydraMultiRocket', 'HydraMultiRocketPlus']
Expand Down

0 comments on commit 004c0f3

Please sign in to comment.