Commit

feat(ml): Visformer architecture with torch backend
beniz authored and mergify[bot] committed May 7, 2021
1 parent 5e7c8f6 commit 40ec03f
Showing 7 changed files with 869 additions and 3 deletions.
5 changes: 3 additions & 2 deletions docs/api.md
@@ -747,7 +747,7 @@ self_supervised | string | yes | "" | self-supervised mode: "mask" for
embedding_size | int | yes | 768 | embedding size for NLP models
freeze_traced | bool | yes | false | Freeze the traced part of the net during finetuning (e.g. for classification)
retain_graph | bool | yes | false | Whether to use `retain_graph` with torch autograd
-template | string | yes | "" | e.g. "bert", "gpt2", "recurrent", "nbeats", "vit", "ttransformer", "resnet50", ... All templates are listed in the [Model Templates](#model-templates) section.
+template | string | yes | "" | e.g. "bert", "gpt2", "recurrent", "nbeats", "vit", "visformer", "ttransformer", "resnet50", ... All templates are listed in the [Model Templates](#model-templates) section.
template_params | dict | yes | template dependent | Model parameter for templates. All parameters are listed in the [Model Templates](#model-templates) section.
regression | bool | yes | false | Whether the model is a regressor
timesteps | int | yes | N/A | Number of timesteps for time models (LSTM/NBEATS...) : this sets the length of sequences that will be given for learning, every timestep contains inputs and outputs as defined by the csv/csvts connector
@@ -1401,7 +1401,7 @@ vgg_16 | deep neural net | Images | Convolutional network for image c

- LSTM-like models (including autoencoder): `recurrent`
- NBEATS model: `nbeats`
-- Vision transformer: `vit`
+- Vision transformer: `vit` and `visformer`
- Transformer-based timeseries models: `ttransformer`
- [TorchVision image classification models](https://pytorch.org/vision/0.8/models.html):
- `resnet18`
@@ -1480,6 +1480,7 @@ Parameter | Template | Type | Default | D
--------- | --------- | ------ | ---------------------------- | -----------
template_params.stackdef | nbeats | array of string | ["t2","s","g3","b3","h10" ] | default means: trend stack with theta = 2, seasonal stack with theta maxed , generic stack with theta = 3, 3 blocks per stacks, hidden unit size of 10 everywhere
template_params.vit_flavor | vit | string | vit_base_patch16 | Vision transformer architecture, from smaller to larger: vit_tiny_patch16, vit_small_patch16, vit_base_patch32, vit_base_patch16, vit_large_patch16, vit_large_patch32, vit_huge_patch16, vit_huge_patch32
+template_params.visformer_flavor | visformer | string | visformer_tiny | Visformer architecture, either visformer_tiny or visformer_small
template_params.realformer | vit | bool | false | Whether to use the 'realformer' residual among attention heads
template_params.positional_encoding.type | ttransformer | string | "sincos" | Positional encoding "sincos for original frequential encoding, "naive" for simple enumeration
template_params.positional_encoding.learn | ttransformer | bool | false | learn or not positional encoding (starting from above value)
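The `visformer_flavor` row above documents a default of `visformer_tiny` and two accepted values. A minimal sketch of how such a parameter could be resolved is given here. The helper `resolve_visformer_flavor`, the flat `std::map` standing in for `template_params`, and the choice to throw on an unknown value are all assumptions made for illustration; none of them is taken from the commit.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical helper, not part of the commit: a flat string map stands in
// for the template_params dictionary described in docs/api.md.
inline std::string resolve_visformer_flavor(
    const std::map<std::string, std::string> &template_params)
{
  // Default value as documented in the parameter table.
  std::string flavor = "visformer_tiny";
  auto it = template_params.find("visformer_flavor");
  if (it != template_params.end())
    flavor = it->second;
  // Only the two documented flavors are accepted in this sketch.
  // Rejecting other values with an exception is an assumption.
  if (flavor != "visformer_tiny" && flavor != "visformer_small")
    throw std::invalid_argument("unknown visformer_flavor: " + flavor);
  return flavor;
}

// Small self-check of the behaviour described above.
inline void check_flavor_resolution()
{
  assert(resolve_visformer_flavor({}) == "visformer_tiny");
  assert(resolve_visformer_flavor({ { "visformer_flavor", "visformer_small" } })
         == "visformer_small");
}
```

Calling `check_flavor_resolution()` exercises the default and the explicit `visformer_small` case.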
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -92,6 +92,7 @@ if (USE_TORCH)
backends/torch/db_lmdb.cpp
backends/torch/native/templates/nbeats.cc
backends/torch/native/templates/vit.cc
+backends/torch/native/templates/visformer.cc
backends/torch/native/templates/ttransformer.cc
backends/torch/native/templates/ttransformer/tembedder.cc
backends/torch/native/templates/ttransformer/positionalenc.cc
5 changes: 4 additions & 1 deletion src/backends/torch/native/native_factory.cc
@@ -20,7 +20,6 @@
*/

#include "native_factory.h"

#include "native_wrapper.h"

namespace dd
@@ -89,6 +88,10 @@ namespace dd

return new ViT(inputc, template_params);
}
+else if (tdef.find("visformer") != std::string::npos)
+{
+return new Visformer(inputc, template_params);
+}
else if (VisionModelsFactory::is_vision_template(tdef))
{
return VisionModelsFactory::from_template(tdef, template_params,
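The hunk above selects a native module by substring match on the template name. A self-contained sketch of that dispatch pattern follows. The `NativeModule`, `ViT`, and `Visformer` types here are hypothetical stand-ins with no constructor arguments, not the real classes from the commit, and the fallback to `nullptr` replaces the remaining branches of the actual factory. Since the string "visformer" does not contain "vit", the relative order of these two branches does not affect the result.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins for the native module classes.
struct NativeModule
{
  virtual ~NativeModule() = default;
  virtual std::string name() const = 0;
};

struct ViT : NativeModule
{
  std::string name() const override { return "ViT"; }
};

struct Visformer : NativeModule
{
  std::string name() const override { return "Visformer"; }
};

// Substring-based dispatch in the same style as the hunk above.
// Returns nullptr when no branch matches (an assumption of this sketch).
inline std::unique_ptr<NativeModule> from_template(const std::string &tdef)
{
  if (tdef.find("vit") != std::string::npos)
    return std::unique_ptr<NativeModule>(new ViT());
  else if (tdef.find("visformer") != std::string::npos)
    return std::unique_ptr<NativeModule>(new Visformer());
  return nullptr;
}

// Small self-check of the dispatch behaviour.
inline void check_dispatch()
{
  assert(from_template("vit")->name() == "ViT");
  assert(from_template("visformer")->name() == "Visformer");
  assert(from_template("resnet50") == nullptr);
}
```

With substring matching, any template name that contains an earlier branch's key is captured by that branch, so branch order matters whenever one key is a substring of another name.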
2 changes: 2 additions & 0 deletions src/backends/torch/native/native_factory.h
@@ -25,6 +25,7 @@
#include "native_net.h"
#include "./templates/nbeats.h"
#include "./templates/vit.h"
+#include "./templates/visformer.h"
#include "./templates/ttransformer.h"
#include "../torchinputconns.h"
#include "apidata.h"
@@ -45,6 +46,7 @@ namespace dd
{
if (tdef.find("nbeats") != std::string::npos
|| tdef.find("vit") != std::string::npos
+|| tdef.find("visformer") != std::string::npos
|| tdef.find("ttransformer") != std::string::npos)
return true;
else if (VisionModelsFactory::is_vision_template(tdef))