# Tale of the NODE-Transformer: an unsuccessful, deeply non-ecological yet relatively fruitful study about cross-breeding Transformer with Neural-ODE

Code is available under the Apache 2.0 License on [Github Project](https://github.com/mandubian/pytorch-neural-ode/)

You can contact me on Twitter [@mandubian](http://twitter.com/mandubian)

> Disclaimer: I don't claim to be an expert in ODEs and optimization mathematics, so sorry if I'm imprecise or even wrong on some aspects or intuitions... Yet, I'm continuously learning, like my models.

## Abstract

Once upon a time, there was a guy who decided to give life to a creature, half-Transformer, half-Neural-ODE. It seemed a cool idea at first sight, but it turned out to be a monster, not so cute, always hungry to swallow all the power of a single GPU without beating any SOTA. Yet, along his misadventures, he met Prince Serendip and learned many interesting things.

In this _tale of the NODE-Transformer_, you will learn more about those weird creatures, the Transformer and the Neural-ODE, and you will see how they were hybridized in a very foolish way. You will discover the close link between mathematical complexity and knowledge complexity, and how it is naturally managed by the ODE-Solver in NODE-Transformer. You will then see how the NODE-Transformer encoder seems to learn knowledge while keeping a stable complexity, whereas the decoder builds more and more complex knowledge to fulfill the output task. Yet, NODE-Transformer appears to be much less performant than a classic Transformer, even one with a 1-layer encoder/decoder. Searching for possible causes, you will learn that Neural-ODEs can't represent all kinds of functions but can be augmented to overcome those limitations. Then, despite augmented ODEs, a scope reduced to the decoder only, and a few optimization and regularization tricks, NODE-Transformer never reaches classic Transformer's performance.

Finally, although this creature might not be as nice as expected, you will understand why the guy above still liked his ugly creature and was convinced it was worth the effort.

Now, let me tell you of the days of not so high adventures!

## Transformer as residual neural network

First, a quick reminder of the well-known architecture of the [Transformer model](http://arxiv.org/abs/1706.03762):

<img src="../media/Transformer.svg" width="50%"/>

Transformer is built on an encoder and a decoder which are both made of `N` successive self-attention and feed-forward layers with residual connections.
So, beyond their a-priori complexity, Transformer encoder and decoder can be seen as residual neural networks.

## Neural ODE as continuous limit of residual network

As you may know, a residual neural layer is defined by the following equation:

<img src="../media/residual_network.png" width="15%"/>

A residual connection keeps the input data and just adds an update to it. This prevents the network from completely forgetting the original data while it learns to update it progressively along the layers.

In paper [Neural Ordinary Differential Equations](http://arxiv.org/abs/1806.07366), Neural ODE is defined as the continuous limit of a residual neural network:

$$
\begin{aligned}
h_{t+1} - h_{t} &= f(h_{t}, \theta) \cdot 1 \\
h_{t+\delta_{t}} - h_{t} &= f(h_{t}, \theta) \cdot \delta_{t} \\
\frac{h_{t+\delta_{t}} - h_{t}}{\delta_{t}} &= f(h_{t}, \theta) \\
\lim_{\delta_{t}\to0} \frac{h_{t+\delta_{t}} - h_{t}}{\delta_{t}} &= f(h_{t}, \theta) \\
\frac{dh(t)}{dt} &= f(h(t), t, \theta)
\end{aligned}
$$
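
The derivation above can be checked numerically. In this minimal sketch, the vector field is a toy scalar `f(h, theta) = theta * h` with `theta = -1` (an assumption chosen only because its exact solution `exp(-t)` is known). One residual block is exactly one Euler step with `delta_t = 1`, and a weight-shared ResNet with more and more layers, i.e. Euler with smaller and smaller steps on `[0, 1]`, converges to the ODE solution:

```python
import math

def f(h, theta):
    # Toy vector field: dh/dt = theta * h (theta plays the role of the shared weights).
    return theta * h

def residual_block(h, theta):
    # Residual layer: h_{t+1} = h_t + f(h_t, theta), i.e. one Euler step with delta_t = 1.
    return h + f(h, theta)

def euler(h0, theta, n_steps, t1=1.0):
    # Fixed-step Euler on [0, t1]: h_{t+dt} = h_t + f(h_t, theta) * dt.
    # With theta shared by all steps, this is a weight-shared ResNet of depth n_steps.
    dt = t1 / n_steps
    h = h0
    for _ in range(n_steps):
        h = h + f(h, theta) * dt
    return h

h0, theta = 1.0, -1.0
exact = h0 * math.exp(theta)  # closed-form solution of dh/dt = -h at t = 1
print(residual_block(h0, theta) == euler(h0, theta, n_steps=1))
for n in (1, 4, 100, 10_000):
    print(n, abs(euler(h0, theta, n) - exact))
```

The error shrinks as the number of (weight-shared) layers grows, which is the intuition behind seeing a Neural ODE as the continuous limit of a ResNet.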

This definition is a well-known Ordinary Differential Equation, classically solved with ODE Solvers (Euler or Runge-Kutta variants, as shown in the next figure).

<img src="../media/Runge-Kutta_slopes.svg" width="50%"/>


In the [Neural ODE paper](http://arxiv.org/abs/1806.07366), without going into details, such an equation is shown to be learnable by a neural network trained by back-propagation, using an augmented (adjoint) system for the gradient that is itself solved by a call to an ODE Solver.

To give some intuition, in a 5-layer residual network sharing weights across all layers (see following figure), backpropagation is performed on 5 fixed steps `[0.0, 0.25, 0.5, 0.75, 1.0]` in the interval `[0.0, 1.0]`, from the output layer back to the input layer.

Meanwhile, in a 1-layer Neural ODE, backpropagation through this single shallow layer is driven by an ODE Solver that explores the interval `[0.0, 1.0]` continuously and dynamically, according to its solving algorithm. Modern ODE Solvers (like DOPRI5, aka the Dormand–Prince method, used in this study) use adaptive steps that are dynamically refined, based on an estimate of the maximum error, until they reach enough precision.

<img src="../media/node_grad.png" width="50%"/>

So, by relying on the ODE Solver's continuous and dynamic sampling of steps, a Neural ODE network can be seen as a network able to adapt its own depth during training, depending on the loss landscape it explores. This feature is very appealing: it means you can have a network that adapts its own complexity to the task it's trying to learn without increasing its number of parameters.

In the paper [Neural ODEs as the deep limit of resnets with constant weights](https://arxiv.org/pdf/1906.12183v1.pdf), the authors prove mathematically that Neural ODEs are the limit of ResNets when depth grows to infinity (under a few assumptions and constraints).

Beyond their mathematical aspects, these networks have proven very efficient at modeling continuous flows and continuous time series, outperforming RNNs with fewer parameters in the paper's experiments.

<img src="../media/node_spirals.png" width="50%"/>


## The stupid idea: a Neural-ODE Transformer

After playing with Neural ODEs, and having used Transformers intensively in the NLP domain, a very stupid idea suddenly materialized:

> **What if Neural ODE layers replace residual layers in Transformers?**

When one has such a foolish idea, one has to try to un-think as much as possible and be crazy till the end, because it is in such ideas that one finds the best or the worst things. In both cases, one will have fun and learn a lot!

So it was decided to hack the [Transformer from the cool Facebook FairSeq library](https://fairseq.readthedocs.io/en/latest/#) and [Pytorch Neural ODE TorchDiffEq](https://github.com/rtqichen/torchdiffeq) to build **the creature, half-Transformer, half-Neural-ODE: the NODE-Transformer**.

The code is not polished and complete yet but it is available on [this github](https://github.com/mandubian/pytorch-neural-ode/tree/master/node-transformer-fair/node_transformer) based on Fairseq and TorchDiffEq.


## NODE-Transformer Architecture

Here is the first **NODE-Transformer architecture**:

<img src="../media/Node-Transformer-Full.svg" width="50%"/>

You can see that it still has separated Encoder and Decoder pipelines. But, now instead of several layers of self-attention/feed-forward with residual connections, it has one single layer of Neural-ODE solving a single self-attention/feed-forward sub-network without any residual connection.

After some hacking on FairSeq and Torchdiffeq, it was time to test it...

## Normal Transformer Training on multi30k-en-ge

First of all, to get a baseline, the default Transformer, as defined in Fairseq, was trained on the Multi30k-EN-GE translation task.

Here is the best loss plot of 2 training sessions:

- the default Transformer with a 6-layer encoder/decoder and all default parameters from FairSeq (<span style="color:blue">blue</span>)
- a classic Transformer with a 1-layer encoder/decoder and all default parameters from FairSeq (<span style="color:green">green</span>)

<img src="../media/transformer_full_decoder_1layer_best_loss.svg" width="50%"/>

The 6-layer Transformer converges and reaches a lower loss (3.04), but not far below the 1-layer one (3.194).

A 1-layer encoder/decoder absorbs most of the Multi30k task complexity, and 6 times more layers only bring a slight improvement.


_Naturally, for a translation task, other metrics like BLEU or perplexity are provided in Fairseq but we'll focus here on loss as other metrics behave the same in all tests performed in this study._
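
To make the loss values a bit more tangible, they can be converted to a per-token perplexity. This is a rough sketch under an assumption: that the plotted loss is a base-2 negative log-likelihood per token (as far as we know, Fairseq logs its losses in base 2 and derives perplexity as 2 raised to `nll_loss`; with label smoothing, the plotted `loss` is usually a bit higher than `nll_loss`, so these are rough upper estimates). `perplexity_from_bits` is a hypothetical helper:

```python
def perplexity_from_bits(loss_bits):
    """Per-token perplexity for a negative log-likelihood measured in bits (base 2)."""
    return 2.0 ** loss_bits

# Best losses quoted above for the two classic Transformers.
for name, loss in [("6-layer Transformer", 3.04), ("1-layer Transformer", 3.194)]:
    print(f"{name}: loss {loss} bits/token -> perplexity ~ {perplexity_from_bits(loss):.2f}")
```

Read this way, the gap between 6 layers and 1 layer is roughly 8.2 versus 9.2 equally likely candidate tokens per position.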


## NODE Transformer Training on multi30k-en-ge

Next, the NODE-Transformer described in the architecture above was trained on the same task as the classic Transformers.

Here is the plot of a training session (<span style="color:grey">train loss</span>/<span style="color:orange">validation loss</span>):

<img src="../media/transformer_1layer_node_transformer_full_loss.svg" width="50%"/>

Good news: it seems to converge. On the validation dataset, it reaches a minimum before rising again; on the training dataset, it keeps going down and tends to overfit, like a normal Transformer.

----

**Conclusion 0: NODE-Transformer learns something**

----

Next is a plot for the compared best loss between:

- a classic Transformer with 1-layer encoder/decoder in <span style="color:green">green</span>
- a NODE Transformer with NODE encoder/decoder in <span style="color:orange">orange</span>

<img src="../media/transformer_1layer_node_transformer_full.svg" width="50%"/>

We can see that NODE Transformer converges down to a limit loss of ~4.5, much higher than a classic 1-layer transformer (~3.2).

----

**Conclusion 1: As is, NODE transformer isn't able to reach the performance of a normal transformer with 1 single layer**

----



Let's see the same chart in relative time (number of hours on the x-axis):

<img src="../media/transformer_1layer_node_transformer_full_relative.svg" width="50%"/>

Training this NODE-Transformer took about 1.5 days (35h), against 1.5 hours for the classic Transformer. That's also why it was early-stopped after stagnating a bit: this experiment is clearly not very ecological after all ;)

----

**Conclusion 2: Training a NODE-Transformer is much slower and less efficient than training a classic Transformer.**

----

> These disappointing results raise the question of why it clearly learns less well than a 1-layer Transformer. The first idea was naturally to check the ODE Solver parameters.


## ODE-Solver Absolute and relative max error tolerance

The first thing to look at when using an ODE Solver such as DOPRI5 (Dormand–Prince method) is its 2 main hyper-parameters:

- absolute error tolerance (ATOL)
- relative error tolerance (RTOL)

Without going too deep into mathematics, let's say that in ODE solvers derived from the Runge-Kutta method, an estimate of the max error between the true value and the computed value can be obtained (by comparing two solutions of different orders). So, it's possible to fix a max error tolerance that lets the ODE Solver decide whether the estimated value is a viable solution. This description of both parameters gives some more intuition: https://www.mathworks.com/help/matlab/math/troubleshoot-common-ode-problems.html#bu8pzr7

DOPRI5 is an adaptive explicit ODE Solver. It means that, if it can't reach a low-enough error approximation, the solver can _refine the sampling steps_ until it finds an acceptable value. The more it refines the steps, the more the ODE Solver calls the neural model function until it finds an acceptable solution depending on the error tolerance.

In the previous training session, ATOL/RTOL were fixed to 0.001, meaning, "roughly", that each step only needs to be accurate to about 3 digits. Training with this error tolerance took about 1.5 days on a 1080-TI GPU. Lower ATOL/RTOL would have been a good test, but it's not reasonable with current computing power (especially this summer in France, where it has been >40°C). That's why it was decided to increase ATOL/RTOL to 0.01 and check what happens.
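
The link between error tolerance, step refinement and number of function evaluations (NFE) can be sketched in a few lines. DOPRI5 is a 7-stage, order 5(4) embedded Runge-Kutta pair; to stay short, this sketch swaps in the much simpler Heun-Euler pair (order 2(1)) but keeps the same mechanics: two estimates of different orders, a scaled error `|gap| / (atol + rtol * |y|)`, rejection and step shrinking when that error exceeds 1, step growth otherwise. It uses a max norm where real implementations often use an RMS norm, and `solve_adaptive` is a hypothetical name. In torchdiffeq, the corresponding knobs are the `rtol` and `atol` arguments of `odeint`.

```python
import math

def solve_adaptive(f, y0, t0=0.0, t1=1.0, atol=1e-3, rtol=1e-3, h=0.1):
    """Integrate dy/dt = f(t, y) on [t0, t1] with an adaptive Heun-Euler pair.

    Returns the final state and the number of function evaluations (NFE)."""
    t, y, nfe = t0, list(y0), 0
    while t < t1:
        h = min(h, t1 - t)
        k1 = f(t, y)
        y_low = [yi + h * a for yi, a in zip(y, k1)]                       # Euler, order 1
        k2 = f(t + h, y_low)
        nfe += 2
        y_high = [yi + 0.5 * h * (a + b) for yi, a, b in zip(y, k1, k2)]   # Heun, order 2
        # Error estimate = gap between both solutions, scaled by atol + rtol * |y|.
        err = max(abs(hi - lo) / (atol + rtol * max(abs(yi), abs(hi)))
                  for yi, hi, lo in zip(y, y_high, y_low))
        if err <= 1.0:  # accept the step
            t = t1 if h >= t1 - t else t + h
            y = y_high
        # Shrink the step after a rejection, grow it after an easy step
        # (0.9 is a safety factor; exponent -1/2 matches the order-1 error estimate).
        h *= min(5.0, max(0.2, 0.9 * err ** -0.5)) if err > 0 else 5.0
    return y, nfe

decay = lambda t, y: [-y[0]]  # dy/dt = -y, exact solution y(1) = exp(-1)
for tol in (1e-2, 1e-3, 1e-5):
    y, nfe = solve_adaptive(decay, [1.0], atol=tol, rtol=tol)
    print(f"tol={tol:g}  NFE={nfe}  error={abs(y[0] - math.exp(-1)):.1e}")
```

Tighter tolerances buy accuracy with more calls to `f`; in a Neural ODE, `f` is the whole neural block, which is why lowering ATOL/RTOL makes training so much slower.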

Here is the Best Loss plot in relative time of a new training session with ATOL/RTOL set to 0.01:

<img src="../media/transformer_full_lower_atol.svg" width="50%"/>

It converges down to a loss limit of ~5.98, much higher than the previous value of ~4.5 with ATOL/RTOL of 0.001. So it doesn't learn as well as with a lower error tolerance, which was expected. Yet, it's interesting to note that it reached this local minimum in a bit more than 4h, compared to the previous 35h, which confirms that a larger error tolerance truly leads to faster ODE solving.

Now, let's observe the number of calls made by the ODE Solver to the neural networks in both encoder and decoder during the training:

#### NODE-Transformer Decoder ODE Calls

<img src="../media/transformer_full_lower_atol_nfe_decoder.svg" width="50%"/>

We see here that the number of calls increases progressively with high peak values and becomes noisier when the NODE-Transformer saturates and doesn't optimize anymore.

This progressive increase of calls could mean that the ODE Solver explores a relatively "simple" loss landscape at the beginning of training. We can imagine that the knowledge learnt at first is basic in general. We can make an analogy with a human-crafted NLP pipeline: at first, you do basic structural tasks: tag words, lemmatize them, find dependencies between lemmas, etc...

Then, the ODE Solver is encountering more complex loss landscape and needs to sample its surroundings finer and finer to evaluate viable values. This corresponds to the increasing complexity of knowledge that is learnt during training: from basic language forms, it needs to model syntax and semantics, contextualize and build relationships between entities.

So, it's really interesting to observe something that supports the intuition that a Neural-ODE truly increases its own complexity dynamically to accommodate the increasing knowledge complexity along training. And it's not just like a ResNet arbitrarily increasing its depth: the ODE-solver naturally refines its computations depending on the complexity of the loss landscape it's exploring.

----

**A bit of philosophy: this idea is mind-blowing, as the optimization mechanism effectively forces the ODE-Solver to increase its own complexity to solve more complex equations while, at the same time, optimizing requires modeling more and more complex knowledge to learn the task. Is it mathematical complexity driving knowledge complexity, or knowledge complexity driving mathematical complexity? More like chicken and egg maybe ;)**

----

It also shows that we still need to work a lot to understand and study optimizers. Let's note that no paper has proven anything about those aspects theoretically, so it's still an experimental hypothesis for now.

The peak values are interesting too. They mean that some batches contain much more complicated information to represent, and the ODE-Solver has a hard time estimating a value for them. These very peaky values are also an issue for Neural-ODE training, as one single batch can lead to a very long computation or even a never-ending one: some trainings might never end if a single batch can't be solved because the solver never reaches a viable estimation. Maybe we could consider skipping those overly complex batches and re-focusing on them later, when the ODE-Solver can digest a higher complexity. To be studied in the next version...

Finally, the noisier aspect at the end of training is quite logical as knowledge then is really at its maximum of complexity and might vary a lot depending on the batches.


#### NODE-Transformer Encoder ODE Calls

<img src="../media/transformer_full_lower_atol_nfe_encoder.svg" width="50%"/>

This is really interesting: the encoder here learns with an almost constant number of calls, except for a few peak values (corresponding to batches of the validation dataset between epochs).

Remember that the NODE-Transformer decoder above increased the number of calls during training and the intuition is that it is building more and more complex knowledge.

It would require an in-depth study in itself but we can suggest that encoding task might be a quite "brute-force" and systematic mechanism, a bit like extracting and storing all information about data in a database: the items, the lemma, even the basic relations like ordering or context. But it's raw information, it doesn't really contain complex knowledge and it doesn't really care about the final task to learn. Meanwhile, the decoder knows the final task to learn and has the responsibility to use that encoded raw information and model the fine relations between entities. Thus, it naturally models more and more complex knowledge.

----
**Conclusion 3: A Neural ODE really increases its complexity during training when it needs to model more complex knowledge.**

----

**Conclusion 4: NODE-Transformer Decoder truly increases its complexity during training while Encoder keeps it almost constant. It might suggest that the modeling of complex, task-related knowledge happens more in the decoding part of the network. A deeper study of how knowledge is built in a Transformer would be required to validate this intuition.**

----

The same study with lower ATOL/RTOL like 0.0001 and 0.00001 and longer trainings would be good to see whether it can learn more and reach classic transformer performances. But it requires having more powerful computing resources than a single 1080-TI GPU.

----

**REQUEST FOR RESOURCES: If you like this topic and have GPU resources that you can share for free and want to help perform more studies on that idea, don't hesitate to contact me on Twitter @mandubian or Github, I'd be happy to consume your resources ;)**

----

> We have seen here the NODE-Transformer learns something, builds some complex knowledge but doesn't reach Transformer performances. So, we can wonder whether Neural-ODE based on self-attention from transformer is even able to represent translation task.


## Can Neural-ODE learn a continuous function that represents translation task based on self-attention + feed-forward ?

Let's be honest, the narrator has no theoretical answer to this question as he hasn't studied ODE math enough to have a proven fact on those points.

Yet, he has read this very interesting paper [Augmented Neural ODEs](http://arxiv.org/abs/1904.01681).

### Introduction to Augmented Neural ODE

> **Please note that next figures are just copied from the [paper](http://arxiv.org/abs/1904.01681)**

The authors of this paper demonstrate that Neural-ODEs cannot represent all types of functions and fail on functions such as:

$$
\begin{aligned}
h(-1) &= 1 \\
h(1) &= -1
\end{aligned}
$$

<img src="../media/func_aug.svg" width="30%"/>

ODE Solver is made to solve continuous flows, not this kind of functions with discrete values or discontinuities in feature space. When solving equations, it's trying to "draw" continuous trajectories (aka ODE flow) from one point to the other in the feature space in the following way:

<img src="../media/func_traj.png" width="30%"/>

But in the paper, they prove that if two trajectories of an ODE flow starting from different initial conditions intersected, the uniqueness of ODE solutions would force the initial conditions to be the same, which is a contradiction. So it's impossible for an ODE flow to "draw" crossing trajectories.

For the function above, since the ODE flow can't cross trajectories, it ends up with something like this (full red and blue lines):

<img src="../media/func_traj_error.png" width="30%"/>
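
In one dimension, "trajectories can't cross" simply means the flow preserves order: if `h(0) = -1 < 1 = h(0)'`, the end points keep the same order, so the flow can never send `-1` to `1` and `1` to `-1` at the same time. This sketch checks that numerically with a fixed-step RK4 integrator on a few arbitrary smooth vector fields (hypothetical stand-ins for whatever a network might learn):

```python
import math

def rk4_flow(f, h0, t1=1.0, n_steps=1000):
    """Integrate dh/dt = f(t, h) for a scalar h with fixed-step RK4."""
    dt = t1 / n_steps
    t, h = 0.0, h0
    for _ in range(n_steps):
        k1 = f(t, h)
        k2 = f(t + dt / 2, h + dt / 2 * k1)
        k3 = f(t + dt / 2, h + dt / 2 * k2)
        k4 = f(t + dt, h + dt * k3)
        h += dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += dt
    return h

# Arbitrary smooth vector fields standing in for "whatever the network learns".
fields = {
    "linear": lambda t, h: -3.0 * h,
    "sine": lambda t, h: 2.0 * math.sin(3.0 * h + t),
    "cubic": lambda t, h: h - h ** 3 + math.cos(5 * t),
}
for name, f in fields.items():
    a, b = rk4_flow(f, -1.0), rk4_flow(f, 1.0)
    # The target would be a = 1 and b = -1 (order flipped), but the flow keeps a < b.
    print(name, round(a, 4), round(b, 4), a < b)
```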

Amusingly, the dotted red and blue lines are the trajectories learnt by a classic ResNet network which is able to learn this kind of functions: the sampling of space is so discrete in Resnet that it jumps above the crossing of trajectories without even being aware of it. It's not a feature of Resnet, it's just blind imprecision ;)

The same issue happens with the following function:

$$
h(x) =
\begin{cases}
    -1 & \text{if } \lVert x \rVert \le r_1 \\
    1 & \text{if } r_2 \le \lVert x \rVert \le r_3
\end{cases}
$$

<img src="../media/func_circle.png" width="25%"/>

A Neural-ODE fails to learn such a function:

<img src="../media/func_circle_error.png" width="50%"/>

In the paper, they also prove that the flow of an ODE is a homeomorphism, which can only continuously deform the input space. But a homeomorphism can't create holes or tear a region apart.

Thus for all those reasons, for the 2 previous kinds of functions, a Neural-ODE is stuck in one or the other point/region and can't "jump" across the gap.

To solve this issue, an augmented version of Neural-ODE is then proposed:

<img src="../media/func_aug_solution.png" width="25%"/>

This simple augmentation by adding a few more synthetic dimensions to the vector with an initialization to 0 allows the ODE flow to solve previous cases by "jumping" across gaps and progressively learn to separate the regions:

<img src="../media/func_aug_solving.png" width="75%"/>
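
Here is a minimal sketch of why one extra zero-initialized dimension is enough for the `h(-1) = 1, h(1) = -1` case. The vector field below is hand-crafted, not learned (an assumption for illustration): it rotates the augmented plane `(h, a)` by exactly `pi` over `t` in `[0, 1]`, so the two trajectories go around opposite half-circles instead of crossing, and reading out the first coordinate gives the "impossible" 1D map:

```python
import math

def augment(x, n_aug):
    # Append n_aug zero-initialized dimensions to the state vector.
    return list(x) + [0.0] * n_aug

def rotation_field(t, z):
    # Hand-crafted field on (h, a): dh/dt = -pi * a, da/dt = pi * h,
    # i.e. a rotation at angular speed pi (half a turn over [0, 1]).
    h, a = z
    return [-math.pi * a, math.pi * h]

def rk4(f, z0, t1=1.0, n_steps=200):
    """Fixed-step RK4 for a vector state z."""
    dt = t1 / n_steps
    t, z = 0.0, list(z0)
    for _ in range(n_steps):
        k1 = f(t, z)
        k2 = f(t + dt / 2, [zi + dt / 2 * k for zi, k in zip(z, k1)])
        k3 = f(t + dt / 2, [zi + dt / 2 * k for zi, k in zip(z, k2)])
        k4 = f(t + dt, [zi + dt * k for zi, k in zip(z, k3)])
        z = [zi + dt / 6 * (a + 2 * b + 2 * c + d)
             for zi, a, b, c, d in zip(z, k1, k2, k3, k4)]
        t += dt
    return z

for x in (-1.0, 1.0):
    h_final, a_final = rk4(rotation_field, augment([x], 1))
    # Read out the original coordinate and drop the augmented one.
    print(x, "->", round(h_final, 6))
```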

So, we know that classic Neural-ODE aren't able to represent all functions.

### Can a Neural-ODE represent our translation task based on a simplified Transformer?

At first sight, translation task doesn't look like a continuous function. It is about language which doesn't sound like a mathematical continuous function. It maps discrete sentences to discrete sentences which are constituted of discrete tokens. Can the feature space of this function be continuous without holes?

Moreover, let's consider the Transformer itself. Firstly, the Transformer encoder and decoder embed input and output tokens with a positional embedding that introduces the notion of position of tokens with respect to each other. This embedding is based on a sinusoidal function or a learned function. What is the impact of this positional embedding on our feature space? Does it make it continuous (or the reverse)? Secondly, self-attention and feed-forward functions are all based on simple linear operations, softmax and normalization, but also on the non-linear ReLU. A priori, there is no reason for these functions to introduce strange holes in the feature space, but a deeper study would be necessary to check that.

Anyway, without studying further, it's hard to characterize clearly the translation task based on transformer and be sure it can be learned by a classic Neural-ODE.

**Thus, it was decided to enhance NODE-Transformer with Augmented Neural ODE and test whether it helps improve performance.**



## Augmented NODE-Transformer

Please note that when augmenting the NODE-Transformer using technique above, for alignment reasons, we need to take into account the Transformer attention blocks which are multi-head blocks. So, when we augment the NODE-Transformer by `n` dimensions, it's in fact an augmentation of `n * number_attention_heads`.
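
Multi-head attention splits the model dimension evenly across heads, so the augmented width must stay divisible by the number of heads. This small sketch (a hypothetical helper, assuming the base Fairseq Transformer sizes of a 512-dimension model and 8 heads) shows why augmenting by `n` dimensions per head adds `n * number_attention_heads` dimensions in total:

```python
def augmented_dims(d_model, n_heads, n_aug):
    """Hypothetical helper: model and head widths after adding n_aug dims per head."""
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    new_dim = d_model + n_aug * n_heads  # n_aug extra dims in every head
    return new_dim, new_dim // n_heads

# Assumed sizes of the base Fairseq Transformer: 512-dim model, 8 attention heads.
print(augmented_dims(512, 8, 1))  # (520, 65): each head grows from 64 to 65 dims
print((512 + 1) % 8)              # 1: a single extra dim would not split evenly
```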

### Failed Experiment: Training an Augmented NODE-Transformer

Augmented Neural ODE is implemented in this [github repository](https://github.com/EmilienDupont/augmented-neural-odes) based on TorchDiffEq library so it was quite easy to port it.

Here are the plots of a training session of augmented full NODE-Transformer with ODE error tolerance of 0.001:


| Best Validation Loss | Training/Validation Loss |
|:---:|:---:|
| <img src="../media/node_transformer_full_aug1_tol001_best_loss.svg" width="150%"/> | <img src="../media/node_transformer_full_aug1_tol001_loss.svg" width="150%"/> | 

| Decoder ODE Calls | Encoder ODE Calls |
|:---:|:---:|
| <img src="../media/node_transformer_full_aug1_tol001_nfe_decoder.svg" width="150%"/> | <img src="../media/node_transformer_full_aug1_tol001_nfe_encoder.svg" width="150%"/> |

After 7h, the training was early-stopped. The Decoder ODE calls stayed stable for one epoch, then decreased for a couple of epochs and then increased a lot in a noisy way. Encoder ODE calls stayed stable as usual. It would have taken tens of hours to converge to a minimum loss, and it was already more than 30°C in the room due to external and GPU heat.

With this early-stopping, we can't say whether Augmented NODE-Transformer reaches better performance than without augmentation.

So if we wanted to test further, we needed to reduce required computing resources.

----

**Conclusion 5: No bigger GPU, no augmented full NODE-Transformer**

----

## Reducing NODE-Transformer computing burden

We've seen in all previous experiments that NODE-Encoder seems to learn without increasing its complexity, suggesting that using the Neural-ODE on the encoding part might not be very useful compared to a classic neural network.

**A very simple idea to reduce computing resources and duration is to get rid of Node-Encoder and just keep Node-Decoder**


### Optimization 1: NODE-Transformer with NODE-Decoder only

Here is the architecture of NODE-Transformer-Decoder-Only:

<img src="../media/Node-Transformer-1.png" width="50%"/>

The encoder part is the original Transformer Encoder with N layers and residual connections and the decoder part is our NODE-Decoder with 1 layer.


#### Yet another inconclusive experiment: NODE-Decoder only with augmented dimensions

Here is a training session of that architecture with 1 augmented dimension and a high ODE Error Tolerance of `0.01` to reduce training duration:

| Best Validation Loss | Training/Validation Loss | Decoder ODE Calls |
|:---:|:---:|:---:|
| <img src="../media/node_transformer_decoder_only_aug_1_best_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_nfe_decoder.svg" width="150%"/> |

After 4h of training, the training got stuck in a never-ending ODE solver computation. Apparently, it couldn't reach an estimated value within the error tolerance and kept refining the sampling steps finer and finer without finding any solution. Without any certainty but with lots of imagination, we can imagine that the augmented dimension allows the ODE-Solver to transform the space, but maybe it deforms the space a bit too "hard", ending in a very perturbed or noisy landscape in which this high error tolerance of `0.01` never allows reaching a good estimation. An in-depth study of the ODE feature space would be required for any serious analysis...

----

**Conclusion 6: Neural-ODE can end in a never-ending ODE solving state**

----


> Please note that in the ODE solver used here, no limit is set to the max number of calls so it can end in this forever-computing case. If you limit the number of ODE calls, how should it be managed in the training?... To be studied!
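
One possible way to handle such a limit is sketched below, under stated assumptions: the ODE function is wrapped in a counter that raises once a budget of calls is exceeded, and the training loop sets the offending batch aside (for example to retry it later, as suggested earlier for overly complex batches). `CountingODEFunc`, `max_nfe` and `NFEBudgetExceeded` are hypothetical names for this sketch, not torchdiffeq options, and a toy fixed-step Euler solver with a per-batch step count stands in for the adaptive solver:

```python
class NFEBudgetExceeded(RuntimeError):
    """Raised when the ODE function is called more often than allowed."""

class CountingODEFunc:
    """Wraps an ODE right-hand side f(t, y), counts calls and enforces a budget."""

    def __init__(self, f, max_nfe=None):
        self.f = f
        self.max_nfe = max_nfe
        self.nfe = 0

    def __call__(self, t, y):
        self.nfe += 1
        if self.max_nfe is not None and self.nfe > self.max_nfe:
            raise NFEBudgetExceeded(f"more than {self.max_nfe} function evaluations")
        return self.f(t, y)

    def reset(self):
        self.nfe = 0

def euler_solve(func, y0, n_steps):
    # Toy fixed-step solver on [0, 1]; n_steps mimics how "hard" a batch is.
    y, dt = y0, 1.0 / n_steps
    for i in range(n_steps):
        y = y + dt * func(i * dt, y)
    return y

func = CountingODEFunc(lambda t, y: -y, max_nfe=50)
skipped = []
for batch_id, n_steps in enumerate([10, 40, 1000]):
    func.reset()
    try:
        euler_solve(func, 1.0, n_steps)
    except NFEBudgetExceeded:
        skipped.append(batch_id)  # set the batch aside instead of computing forever
print(skipped)  # [2]
```

Whether a skipped batch should be retried later, dropped, or trained with a truncated solve is exactly the open question raised above.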

Anyway, when it fails, never abandon, push further! So reducing computation was pushed further!

In Decoding path of Transformer, there are 2 parts:

- A multi-head self-attention block on the output sequence,
- A multi-head attention between this self-attention computed on the output sequence and the output of the Transformer encoder self-attention applied to the input sequence.

So, one intuition is that the self-attention block on the output sequence might behave like the self-attention block applied to the input sequence: it learns "raw" knowledge about the sequence, and we showed experimentally (on the encoder) that such a block doesn't take much advantage of the Neural-ODE's ability to increase its own complexity.

Despite being just an intuition, it was interesting to test this approach with the so-called _Separated Node-Decoder architecture_



### Optimization 2: NODE-Transformer with Separated NODE-Decoder

<img src="../media/Node-Transformer-2.png" width="50%"/>

The Neural-ODE is applied only on the attention between the encoder output computed on the input sequence and the self-attention on the output sequence.

#### Finally, a conclusive experiment: Augmented NODE-Transformer with Separated NODE-Decoder

Here is a training of this architecture with 1 augmented dimension, same high error tolerance of 0.01.

| Best Validation Loss | Training/Validation Loss | Decoder ODE Calls |
|:---:|:---:|:---:|
| <img src="../media/node_transformer_decoder_only_aug_1_sep_best_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_sep_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_sep_nfe_decoder.svg" width="150%"/> |

This training converged in around 4h.

Loss reached a lower value of ~4.38, which is lower than the ~4.5 obtained in the first 35h training of the full NODE-Transformer with the lower error tolerance of 0.001. It is also much lower than the ~5.98 obtained with the higher error tolerance of 0.01 on the full NODE-Transformer.

Yet, it's still much higher than the ~3.2 loss reached by classic transformer with 1 layer in encoder/decoder. 

Moreover, we can note that ODE calls started much lower than in previous trainings (65 compared to 180-200) and rose progressively, with longer plateaus before rising again.

Encouraging but not fully satisfying. As expected, the presence of the classic Transformer self-attention limits the required computing resources and seems to provide more stable learning. The NODE-decoder is able to use that knowledge starting from a lower level of complexity and then build more and more complex decoding knowledge. But it is not able to reach the same level of performance as the classic Transformer 1-layer decoder with its residual connection pipeline. The too-high error tolerance might naturally prevent knowledge from being refined. It would be worth testing a lower error tolerance but, as explained earlier, more computing power would be welcome again.

As we obtain almost the same results as with the full NODE-Transformer, we can infer that the learning bottleneck of the original NODE-Transformer might lie in this part of the Decoder (i.e. the attention between the encoder output on the input sequence and the self-attention on the output sequence). For now, without the computing power to test the same code with lower error tolerances, it's hard to make further suppositions about the causes.

**Conclusion 7: Restricting NODE-Transformer to a separated NODE-Decoder drastically reduces computations, as expected, without impacting final performance, suggesting the NODE-Transformer learning bottleneck might be located in this part of the decoder.**

**Conclusion 8: The augmented dimension didn't significantly improve performance, so our translation function based on Transformer might not be among the functions that Neural-ODEs cannot represent.**

Then it was decided to test a last solution to improve not the computation but the performance: make NODE-Transformer decoder aware of time!


## Optimization 3: NODE-Transformer trained with Time Dependency

>This idea of time-dependency was stolen from the augmented neural-ODE code base.

DOPRI5 is an explicit adaptive ODE Solver that samples time-steps in the interval `[0.0, 1.0]` according to its algorithm and max error tolerance. The sampled time-step is provided to the embedded function, i.e. our simplified Transformer neural network. But our Transformer is not aware of time: it doesn't use this time-step in its computation. So if the ODE Solver calls our network with the same input value at time-step `t1=0.1` and at time-step `t2=0.9`, it will return the same result. Intuitively, one would like the Transformer to compute something different at different time-steps, as it is not the same "place" in the feature space, right?

So after augmenting the Neural-ODE equation with a synthetic dimension, the final vector passed to the transformer function was also augmented with the time information:

$$
\begin{aligned}
f(x, t) &= f(\begin{bmatrix}
           x \\
           t \\
         \end{bmatrix})
\end{aligned}
$$
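
A minimal sketch of the difference, assuming a hypothetical one-layer `tanh` block in place of the self-attention + feed-forward network: the time-blind function returns the same output at `t1=0.1` and `t2=0.9`, while the time-dependent one concatenates `t` to its input, as in `f([x; t])`, and so can behave differently along the trajectory. For a Transformer state of shape (batch, sequence, features), the same idea would mean appending a feature filled with `t` to every token; how the original code base does it exactly is not shown here.

```python
import math

def layer(z, weights):
    # Hypothetical stand-in for the self-attention + feed-forward block:
    # one tanh unit per row of weights.
    return [math.tanh(sum(w * zi for w, zi in zip(row, z))) for row in weights]

W_BLIND = [[0.5, -1.0], [1.0, 0.3]]            # acts on the state x only
W_TIME = [[0.5, -1.0, 2.0], [1.0, 0.3, -0.7]]  # extra last column weights the time input

def f_blind(t, x):
    return layer(x, W_BLIND)             # t is ignored: same output at every time-step

def f_time(t, x):
    return layer(list(x) + [t], W_TIME)  # f(x, t) = f([x; t])

x = [0.2, -0.4]
print(f_blind(0.1, x) == f_blind(0.9, x))  # True
print(f_time(0.1, x) == f_time(0.9, x))    # False
```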


#### A semi-conclusive experiment: NODE-Transformer with Augmented Separated NODE-Decoder and Time-Dependency

Here is a training of NODE-Transformer restricted to separated NODE-Decoder augmented with one dimension and time-dependency and error tolerance of 0.01.

_In next plots, dark blue and red plots are the current training. The light blue are from previous training "Augmented NODE-Transformer with Separated NODE-Decoder" to compare it directly._

| Best Validation Loss | Training/Validation Loss | Decoder ODE Calls (dark blue) |
|:---:|:---:|:---:|
| <img src="../media/node_transformer_decoder_only_aug_1_timedep_best_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_timedep_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_sep_with_timedep_nfe_decoder.svg" width="150%"/> |

The training is almost identical to the one without time-dependency: it reaches the same lowest loss with almost the same curve. The only difference is in the NODE-Decoder ODE calls, which start higher and rise faster with time-dependency. This sounds logical, as adding time information naturally increases the apparent complexity at the beginning of training. Yet at the end of training, the number of ODE calls is about the same with or without time-dependency. Without further study, this might mean that both trainings reach the same level of complexity, or maybe even the same distribution of time-steps.

**Conclusion 9: Adding time-dependency didn't improve our training performance but it remains an interesting idea for other Neural-ODE use-cases in which it could be more meaningful.**


## Optimization 4: NODE-Transformer trained with Weight Decay

One last aspect was explored: is it possible to limit the increase of ODE calls by acting on the training process?

In the paper [FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models](http://arxiv.org/abs/1810.01367), the authors observed that the number of function evaluations of a Neural-ODE can be reduced by using different forms of regularization, such as weight decay and spectral normalization, at the cost of slightly hurting performance.

In our case, performance isn't really good anyway, but we decided to try weight decay as it is very easy to set up. Weight decay is a very simple idea: at each update, the weights are multiplied by a factor slightly below 1, which prevents them from growing too large.
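
For plain SGD, the two usual descriptions of weight decay coincide: adding an L2 penalty (λ/2)·‖w‖² to the loss gives the same update as multiplying the weights by (1 − η·λ) at each step before applying the gradient. The plain-Python sketch below checks this (function names are hypothetical). With adaptive optimizers such as Adam the two forms are no longer equivalent, which is what AdamW's decoupled weight decay addresses, and the sketch says nothing about which form or optimizer the experiment below used.

```python
from typing import List


def sgd_step_l2(w: List[float], grad: List[float], lr: float, wd: float) -> List[float]:
    """SGD on loss + (wd / 2) * ||w||^2: the penalty adds wd * w to the gradient."""
    return [wi - lr * (gi + wd * wi) for wi, gi in zip(w, grad)]


def sgd_step_shrink(w: List[float], grad: List[float], lr: float, wd: float) -> List[float]:
    """Same update, written as: multiply w by (1 - lr * wd) < 1, then apply the gradient."""
    return [(1.0 - lr * wd) * wi - lr * gi for wi, gi in zip(w, grad)]


w = [1.0, -2.0]
grad = [0.5, 0.5]
a = sgd_step_l2(w, grad, lr=0.1, wd=0.5)
b = sgd_step_shrink(w, grad, lr=0.1, wd=0.5)

# With a zero gradient, weight decay alone shrinks w geometrically (factor 0.95 here).
v = [1.0, -2.0]
for _ in range(10):
    v = sgd_step_shrink(v, [0.0, 0.0], lr=0.1, wd=0.5)
```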

#### Disappointing experiment: NODE-Transformer with Augmented Separated NODE-Decoder and Weight Decay

We trained a NODE-Transformer restricted to a separated NODE-Decoder augmented with 1 dimension, with a weight decay factor of 0.9 and an error tolerance of 0.01.

_In the following plots, the green and grey curves are the current training. The other curves (light blue and red) come from the previous training "Augmented NODE-Transformer with Separated NODE-Decoder", for direct comparison._

| Best Validation Loss (grey) | Training/Validation Loss (green & grey) | Decoder ODE Calls (green) |
|:---:|:---:|:---:|
| <img src="../media/node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_sep_weight_decay_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_sep_weight_decay_nfe_decoder.svg" width="150%"/> |

The number of ODE calls is the same in the first epoch, then decreases with weight decay and stays lower than without weight decay for a couple of epochs. Then it suddenly jumps above the training without weight decay at the same epoch, but afterwards it stays quite stable, although noisy. Meanwhile, in the training without weight decay, the number of calls grows steadily and ends up above the one with weight decay.

So weight decay seems to lower the ODE-Solver computations globally, but its impact on the loss is not as minor as reported in [FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models](http://arxiv.org/abs/1810.01367).

**Conclusion 10: using weight decay in NODE-Transformer training globally reduces the number of ODE evaluations but hurts performance. Yet the idea that regularizing learnable parameters can help tune the resource requirements of an ODE Solver, at the cost of some performance, is really interesting.**


## Implementation details

### Hacking TorchDiffEq Neural-ODE

This project uses PyTorch, and the Neural-ODE implementation comes from [torchdiffeq on GitHub](https://github.com/rtqichen/torchdiffeq).

The TorchDiffEq Neural-ODE code works well for basic neural networks with one input and one output. But a Transformer encoder/decoder is not really a basic neural network: attention requires multiple inputs (Q/K/V) and various options.

Without going into details, we had to extend the TorchDiffEq code to handle multiple and optional parameters in `odeint_adjoint` and its sub-functions. The code can be found in [odeint_ext](https://github.com/mandubian/pytorch-neural-ode/tree/master/odeint_ext), and we'll see later if it's generic enough to be contributed back to the torchdiffeq project.
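
Solvers such as torchdiffeq's `odeint` call the ODE function as `func(t, y)`: only the integrated state is passed. One common way to give such a function extra, constant inputs (say an encoder output and a mask) is to bind them beforehand, e.g. with `functools.partial`. The plain-Python sketch below shows this with a fixed-step Euler solver; `euler_odeint`, `decoder_dynamics`, `memory` and `mask` are hypothetical, and this is not how `odeint_ext` works. It only covers the forward pass: it does not address how gradients would reach the bound inputs with the adjoint method, presumably the delicate part of such an extension.

```python
import functools
import math
from typing import Callable, List

State = List[float]


def euler_odeint(func: Callable[[float, State], State], y0: State,
                 t0: float = 0.0, t1: float = 1.0, steps: int = 1000) -> State:
    """Minimal fixed-step Euler solver: func only ever receives (t, y)."""
    h = (t1 - t0) / steps
    y = list(y0)
    for n in range(steps):
        dy = func(t0 + n * h, y)
        y = [yi + h * di for yi, di in zip(y, dy)]
    return y


def decoder_dynamics(t: float, y: State, memory: State, mask: List[bool]) -> State:
    """Hypothetical dynamics needing extra inputs: each unmasked position is
    pulled towards the matching 'memory' value; masked positions do not move."""
    return [(m - yi) if keep else 0.0 for yi, m, keep in zip(y, memory, mask)]


memory = [1.0, 2.0, 3.0]
mask = [True, True, False]
# Bind the extra inputs so the solver sees the standard f(t, y) signature.
func = functools.partial(decoder_dynamics, memory=memory, mask=mask)
y1 = euler_odeint(func, [0.0, 0.0, 0.0])

# Exact solution at t = 1 for comparison: m * (1 - e^-1) where unmasked, else 0.
expected = [m * (1 - math.exp(-1)) if keep else 0.0 for m, keep in zip(memory, mask)]
```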


### Creating NODE-Transformer with fairseq

NODE-Transformer is just a new kind of Transformer, and Transformers are already implemented in the [FairSeq library](https://github.com/pytorch/fairseq).

So it was simply implemented as a new Transformer architecture using the FairSeq API: the [NODE-Transformer](https://github.com/mandubian/pytorch-neural-ode/blob/master/node-transformer-fair/node_transformer/node_transformer.py). Implementing it wasn't so complicated: the API is quite complete; you need to read some code to be sure of what to do, but nothing crazy. _The code is still raw, not yet cleaned-up and polished, so don't be surprised to find weird comments or useless leftover lines in a few places._

A custom [NODE-Trainer](https://github.com/mandubian/pytorch-neural-ode/blob/master/node-transformer-fair/node_transformer/node_trainer.py) was also required to include the number of ODE function calls in the training reports. This part could probably be improved to make it more easily extensible.
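
Reporting ODE calls boils down to counting how many times the solver calls the ODE function for each batch. Below is a minimal plain-Python sketch of that pattern: a wrapper that increments a counter on every call, and a meter that averages the count per batch for the logs. `CountedODEFunc`, `NFEMeter` and `solve` are hypothetical names (the toy Euler `solve` stands in for the real ODE solver), and how NODE-Trainer actually plugs these numbers into FairSeq's reports is not shown.

```python
from typing import Callable, List


class CountedODEFunc:
    """Wraps an ODE function f(t, y) and counts its evaluations (NFE)."""

    def __init__(self, f: Callable[[float, float], float]) -> None:
        self.f = f
        self.nfe = 0

    def __call__(self, t: float, y: float) -> float:
        self.nfe += 1
        return self.f(t, y)

    def pop_nfe(self) -> int:
        """Return the evaluations counted since the last call and reset the counter."""
        n, self.nfe = self.nfe, 0
        return n


class NFEMeter:
    """Running average of NFE per batch, to be logged next to the loss."""

    def __init__(self) -> None:
        self.history: List[int] = []

    def update(self, nfe: int) -> None:
        self.history.append(nfe)

    @property
    def avg(self) -> float:
        return sum(self.history) / len(self.history) if self.history else 0.0


def solve(func: Callable[[float, float], float], y0: float, steps: int) -> float:
    """Toy fixed-step Euler solve on [0, 1], standing in for the real ODE solver."""
    y, h = y0, 1.0 / steps
    for n in range(steps):
        y += h * func(n * h, y)
    return y


func = CountedODEFunc(lambda t, y: -y)
meter = NFEMeter()
for steps in (4, 8, 6):            # pretend each batch needed a different number of steps
    solve(func, 1.0, steps)
    meter.update(func.pop_nfe())   # one NFE value per batch
```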

Here are the new options to configure the FairSeq NODE-Transformer:

```
    --arch node_transformer    
    --node-encoder
    --node-decoder
    --node-rtol 0.01
    --node-atol 0.01
    --node-ts [0.0, 1.0]
    --node-augment-dims 1
    --node-time-dependent
    --node-separated-decoder
```



## Conclusion

NODE-Transformer is cool just because it's a nice kind of Transformer! But, as implemented right now, it is clearly not an efficient network, both in terms of performance and computing burden. Yet, despite the huge amount of electricity consumed by useless trainings, it was a nice trip: a lot was learnt about many topics, and this trip might not be over yet.

Many other aspects would be worth studying, such as:

- having more GPU computing power to train NODE-Transformer with smaller error tolerances, to see if performance improves at least to the level of a 1-layer classic Transformer
- studying what "increasing knowledge complexity" truly means along the training process: we could explore attention matrices, gradient fields, ODE-Solver feature landscape...
- testing Neural-ODE with the attention mechanism alone on a simpler task to see how both behave together
- studying whether it's viable to skip batches that take too many ODE calls and recompute them later, when the global network has reached a sufficient complexity level for those batches
- ...

---- 

**REQUEST FOR RESOURCES: If you like this topic and have GPU resources that you can share for free and want to help perform more studies on that idea, don't hesitate to contact me on Twitter @mandubian or Github, I'd be happy to consume your resources ;)**

----

Thanks and have Ordinary Differential Enjoyment!



## References

1. **Neural Ordinary Differential Equations**, Chen et al. (2018), http://arxiv.org/abs/1806.07366

1. **Augmented Neural ODEs**, Dupont, Doucet, Teh (2019), http://arxiv.org/abs/1904.01681

1. **Neural ODEs as the Deep Limit of ResNets with constant weights**, Avelin, Nyström (2019), https://arxiv.org/abs/1906.12183v1

1. **FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models**, Grathwohl et al. (2018), http://arxiv.org/abs/1810.01367