# NODE Transformer: Tale of an unsuccessful, deeply non-ecological yet fruitful study on Neural-ODE applied to Transformers

## Abstract

## Transformer as residual neural network

First, a quick reminder of the well-known architecture of the [Transformer model](http://arxiv.org/abs/1706.03762):

<img src="../media/Transformer.svg" width="50%"/>

The Transformer is built on an encoder and a decoder, both made of `N` successive self-attention and feed-forward layers with residual connections.
So, beyond their apparent complexity, the Transformer encoder and decoder can be seen as residual neural networks.

## Neural ODE as continuous limit of residual network

As you may know, a residual neural layer is defined by the following equation:

<img src="../media/residual_network.png" width="25%"/>

In paper [Neural Ordinary Differential Equations](http://arxiv.org/abs/1806.07366), Neural ODE is defined as the continuous limit of a residual neural network:

$$
\begin{aligned}
h_{t+1} - h_{t} &= f(h_{t}, \theta) \cdot 1 \\
h_{t+\delta_{t}} - h_{t} &= f(h_{t}, \theta) \cdot \delta_{t} \\
\frac{h_{t+\delta_{t}} - h_{t}}{\delta_{t}} &= f(h_{t}, \theta) \\
\lim_{\delta_{t}\to0} \frac{h_{t+\delta_{t}} - h_{t}}{\delta_{t}} &= f(h_{t}, \theta) \\
\frac{dh(t)}{dt} &= f(h(t), t, \theta)
\end{aligned}
$$

This definition is a well-known ordinary differential equation (ODE); such equations are classically solved with ODE Solvers (Euler, Runge-Kutta variants, as shown in the next figure).

<img src="../media/Runge-Kutta_slopes.svg" width="50%"/>
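
To make the link between a residual layer and a solver step concrete, here is a minimal sketch in plain Python. The toy dynamics `f(h, t) = -h` is an assumption standing in for the learned network, and `euler_step`, `rk4_step` and `integrate` are illustrative names, not part of any library. It shows that one Euler step with `dt = 1` is exactly a residual layer, and that a Runge-Kutta step of order 4 is far more accurate on the same grid.

```python
import math

def f(h, t):
    """Toy dynamics dh/dt = -h (stand-in for the learned network)."""
    return -h

def euler_step(f, h, t, dt):
    # With dt = 1 this is exactly a residual layer: h + f(h).
    return h + dt * f(h, t)

def rk4_step(f, h, t, dt):
    # Classic 4th-order Runge-Kutta: a weighted average of four slopes.
    k1 = f(h, t)
    k2 = f(h + 0.5 * dt * k1, t + 0.5 * dt)
    k3 = f(h + 0.5 * dt * k2, t + 0.5 * dt)
    k4 = f(h + dt * k3, t + dt)
    return h + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

def integrate(step, f, h0, t0, t1, n_steps):
    """Apply the same step function n_steps times on a fixed grid."""
    dt = (t1 - t0) / n_steps
    h, t = h0, t0
    for _ in range(n_steps):
        h = step(f, h, t, dt)
        t += dt
    return h

exact = math.exp(-1.0)  # analytical solution of dh/dt = -h at t = 1, h(0) = 1
euler_error = abs(integrate(euler_step, f, 1.0, 0.0, 1.0, 5) - exact)
rk4_error = abs(integrate(rk4_step, f, 1.0, 0.0, 1.0, 5) - exact)
print(euler_error, rk4_error)
```

With the same 5 steps, the Runge-Kutta error is several orders of magnitude smaller than the Euler error, which is why solvers from that family are preferred.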


In the [Neural ODE paper](http://arxiv.org/abs/1806.07366), without going into details, such an equation is shown to be learnable by a neural network trained by back-propagation, using an augmented version of the gradient and a call to an ODE Solver.

To give some intuition, in a 5-layer residual network sharing weights across all layers (see following figure), backpropagation is performed on 5 fixed steps from the output layer back to the input layer: if you imagine the layers as discrete steps in the interval `[0.0, 1.0]`, you can define 5 equally spaced points `[0.0, 0.25, 0.5, 0.75, 1.0]` in that interval to represent those fixed steps.

Meanwhile, in a 1-layer Neural ODE, backpropagation is performed in a shallow way by an ODE Solver which explores the interval `[0.0, 1.0]` in a continuous and dynamic way according to its solving algorithm. Modern ODE Solvers (like DOPRI5, aka the Dormand–Prince method, used in the current study) use adaptive steps that are dynamically refined, based on an approximation of the max error, until they reach enough precision.

<img src="../media/node_grad.png" width="50%"/>

So, by relying on ODE Solver's continuous and dynamic sampling of the steps of backpropagation, a Neural ODE network can be seen as a network able to adapt its own depth during training depending on the gradient horizon it explores. This feature is very appealing to me as it means you can have a network that is able to adapt its own complexity to the task it's trying to learn without increasing the number of parameters in the network.

In the paper [Neural ODEs as the deep limit of resnets with constant weights](https://arxiv.org/pdf/1906.12183v1.pdf), the authors prove mathematically that Neural ODEs are the limit of ResNets when depth grows to infinity (with a few assumptions and constraints).

Beyond the mathematical aspects of Neural ODEs, those networks have proven to be very efficient at modelling continuous flows and continuous time series, better than RNNs and with fewer parameters.

<img src="../media/node_spirals.png" width="50%"/>


## The stupid idea: a Neural-ODE Transformer

I've been playing with Neural ODE since last year.
I've been working with transformers in NLP domain.

Suddenly a very stupid idea materialized in my mind:

> **What if I replace residual layers in Transformers by Neural ODE layers?**

When I have such a foolish idea, I try to un-think as much as possible and be crazy till the end, because it is in such ideas that you can find the best or the worst things. But in both cases, you'll have fun and learn a lot of things!

So I decided to hack [Transformer from cool Facebook FairSeq library](https://fairseq.readthedocs.io/en/latest/#) and [Pytorch Neural ODE](https://github.com/rtqichen/torchdiffeq) and build **my creature half-Transformer, half-Neural-ODE: the NODE-Transformer**

The code is not yet polished or complete, but it is already on [my github](https://github.com/mandubian/pytorch-neural-ode/tree/master/node-transformer-fair/node_transformer)


## NODE-Transformer Architecture

Here is the first **NODE-Transformer architecture**:

<img src="../media/Node-Transformer-Full.svg" width="50%"/>

You can see that it still has separate Encoder and Decoder sides. But now, instead of several layers of self-attention/feed-forward with residual connections, it has one single Neural-ODE layer solving a single self-attention/feed-forward sub-network without any residual connection.

After some hacking on FairSeq and Torchdiffeq, it was time to test it...

# Normal Transformer Training on multi30k-en-ge

First of all, to have a reference, I trained a default transformer (as defined by Fairseq) on Multi30k-EN-GE translation task.

Here is the best validation loss plot of 2 training sessions:

- a classic Transformer with a 6-layer encoder/decoder (<span style="color:blue">blue</span>)
- a classic Transformer with a 1-layer encoder/decoder (<span style="color:green">green</span>)

<img src="../media/transformer_full_decoder_1layer_best_loss.svg" width="50%"/>

The 6-layer transformer converges and reaches a lower loss (3.04), but this is not very far from the 1-layer one (3.194).

The 1-layer encoder/decoder absorbs most of the complexity of the Multi30k task, and 6 times more layers only give a slight improvement.


_Naturally, for a translation task, other metrics such as BLEU or perplexity are used and are provided by Fairseq, but I'll focus here on loss as the other metrics behave the same way in my tests._


# NODE Transformer Training on multi30k-en-ge

I started by training a NODE-transformer as described in the architecture above.

Here is the plot for a training session (<span style="color:grey">train loss</span>/<span style="color:orange">validation loss</span>):

<img src="../media/transformer_1layer_node_transformer_full_loss.svg" width="50%"/>

Good news: it looks like it is converging. On the validation dataset, the loss reaches a minimum before rising again. On the training dataset, it keeps going down; in another session, I let it train further and it tends to overfit the training data, as expected.

**Conclusion 0: NODE transformer learns something**

Here is a plot comparing the best loss (as computed by the fairseq library, differently from the loss above) between:

- a classic Transformer with 1-layer encoder/decoder in <span style="color:green">green</span>
- a NODE Transformer with NODE encoder/decoder in <span style="color:orange">orange</span>

<img src="../media/transformer_1layer_node_transformer_full.svg" width="50%"/>

We can see that NODE Transformer converges down to a limit loss of ~4.5, much higher than a classic 1-layer transformer (~3.2).

**Conclusion 1: As is, NODE transformer isn't able to reach the performance of a normal transformer with 1 single layer**



Let's see the same chart in relative time (number of hours on the x-axis):

<img src="../media/transformer_1layer_node_transformer_full_relative.svg" width="50%"/>

Training this NODE transformer took about 1.5 days (35h), against 1.5 hours for the classic transformer. That's also why it was early-stopped after stagnating a bit, because this experiment is clearly not very ecological after all ;)

**Conclusion 2: Training a NODE transformer is much slower and less efficient than training a classic transformer.**

> Those disappointing results led me to wonder why it clearly learns less well than a 1-layer transformer. My first idea was naturally to check the ODE Solver parameters.


## ODE-Solver Absolute and relative max error tolerance

The first point that arises when using an ODE Solver such as DOPRI5 (Dormand–Prince method) is the choice of its 2 main hyper-parameters:

- absolute error tolerance (ATOL)
- relative error tolerance (RTOL)

Without going too deep into the mathematics, let's say that in ODE solvers derived from the Runge-Kutta method, an approximation of the max error between the ground-truth value and the estimated value can be computed. So, it's possible to set a max error tolerance that lets the ODE Solver know whether the estimated value is a viable solution. This description of both parameters gives some more intuition: https://www.mathworks.com/help/matlab/math/troubleshoot-common-ode-problems.html#bu8pzr7

DOPRI5 is an adaptive explicit ODE Solver. It means that, if it can't reach a low-enough error approximation, the solver can _refine the sampling steps_ until it finds an acceptable value. The more it refines the steps, the more the ODE Solver calls the neural model function until it finds an acceptable solution depending on the error tolerance.
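
The mechanism can be sketched in a few lines of plain Python. This is not DOPRI5: to stay short, it swaps in the much simpler embedded Heun-Euler pair (orders 2 and 1), but the accept/reject logic driven by `atol` and `rtol` follows the same principle. `CountingFunction`, `adaptive_heun_euler`, the step-size factors and the toy dynamics `dh/dt = -h` are all assumptions made for illustration.

```python
import math

class CountingFunction:
    """Wraps the dynamics and counts how many times the solver calls it."""
    def __init__(self, func):
        self.func = func
        self.nfe = 0  # number of function evaluations ("ODE calls")

    def __call__(self, h, t):
        self.nfe += 1
        return self.func(h, t)

def adaptive_heun_euler(func, h0, t0, t1, atol, rtol, dt_init=0.1):
    """Adaptive solver using the embedded Heun-Euler pair (NOT DOPRI5)."""
    t, h, dt = t0, h0, dt_init
    while t1 - t > 1e-12:
        dt = min(dt, t1 - t)               # never step past the end point
        k1 = func(h, t)
        k2 = func(h + dt * k1, t + dt)
        h_low = h + dt * k1                # Euler estimate (order 1)
        h_high = h + 0.5 * dt * (k1 + k2)  # Heun estimate (order 2)
        err = abs(h_high - h_low)          # local error approximation
        tol = atol + rtol * max(abs(h), abs(h_high))
        if err <= tol:                     # accept the step
            t, h = t + dt, h_high
        # Shrink the step after a rejection, grow it when the error is small.
        factor = 5.0 if err == 0.0 else 0.9 * math.sqrt(tol / err)
        dt *= min(5.0, max(0.2, factor))
    return h

loose = CountingFunction(lambda h, t: -h)
h_loose = adaptive_heun_euler(loose, 1.0, 0.0, 1.0, atol=1e-2, rtol=1e-2)
tight = CountingFunction(lambda h, t: -h)
h_tight = adaptive_heun_euler(tight, 1.0, 0.0, 1.0, atol=1e-4, rtol=1e-4)
print(loose.nfe, tight.nfe)
```

Tightening the tolerance forces smaller steps and therefore more calls to the wrapped function, which is the quantity plotted as "ODE calls" in the rest of this post.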

In the previous training session, ATOL/RTOL were set to `0.001`, meaning _roughly_ that the solution can't be more precise than 3 digits. It took 1.5 days to train with this error tolerance on my 1080 Ti GPU. I would have liked to use a lower ATOL/RTOL, but that's not reasonable with my current computing power (especially this summer in France, where it has been >40°C). That's why I preferred to loosen ATOL/RTOL to `0.01` and check what happens.

Here is the best_loss plot in relative time of a new training session with ATOL/RTOL set to `0.01`:

<img src="../media/transformer_full_lower_atol.svg" width="50%"/>

It converges down to a loss limit of ~5.98, much higher than the previous value ~4.5 with ATOL/RTOL of `0.001`. So, it doesn't learn as well as using a lower error tolerance which was expected. Yet, it's interesting to remark that it reached this local minimum in a bit more than 4h compared to the previous 35h. So, it demonstrates that a bigger error tolerance leads to faster ODE Solving.

Now, let's observe the number of calls to the 2 sides (encoder and decoder) by the ODE Solver during the training:

#### NODE-Transformer Decoder ODE Calls

<img src="../media/transformer_full_lower_atol_nfe_decoder.svg" width="50%"/>

We see here that the number of calls increases progressively with high peak values and becomes noisier when the NODE-Transformer saturates and doesn't optimize anymore.

This progressive increase of calls could mean the ODE Solver is exploring a relatively "simple" loss landscape at the beginning of training. We can imagine that the knowledge learnt at first is generally basic. We can make an analogy with a human-crafted NLP pipeline. At first, you do basic structural tasks: tag words, lemmatize them, find dependencies between lemmas, etc.

Then, the ODE Solver encounters a more complex loss landscape and needs to sample its surroundings finer and finer to evaluate viable values. This corresponds to the increasing complexity of the knowledge learnt during training: starting from basic language forms, it needs to model syntax and semantics, relationships between entities, and to contextualize all of that.

So, it's really interesting to validate this intuition that a Neural-ODE is truly increasing its own complexity dynamically to accommodate the increasing complexity of knowledge along training. But it's not just like a ResNet arbitrarily increasing its depth: the ODE-solver is naturally refining its computations depending on the complexity of the loss landscape it's exploring. Personally, I find this idea mind-blowing, and it also shows we still need to work a lot to understand and study more efficient optimizers... Note that I have not yet read anything concrete, theoretically speaking, about these ideas. It's more of an experimental hypothesis for now, so don't take it as a mathematical proof.

The peak values are interesting too. They mean there are batches which contain much more complicated information, for which the ODE-Solver has a hard time estimating a value. Those very peaky values are also an issue for Neural-ODE trainings, as one single batch can lead to a very long computation or even a never-ending one: some trainings might never end if the solver can't reach a viable estimation on a single batch. Maybe we could consider skipping those overly complex batches and re-focusing on them later, when the ODE-Solver can digest a higher complexity. For a next version...

Finally, the noisier aspect at the end of training is quite logical as knowledge then is really more complicated to learn so the ODE-Solver needs to explore finer and finer depending on the complexity it encounters in different batches.


#### NODE-Transformer Encoder ODE Calls

<img src="../media/transformer_full_lower_atol_nfe_encoder.svg" width="50%"/>

This is really interesting: the encoder here learns with an almost constant number of calls, except for a few peak values (corresponding to validation passes between epochs). Remember that the NODE-Transformer decoder above increased its number of calls during training, showing that the decoder is building more and more complex knowledge.

It would require an in-depth study in itself, but we can suggest that encoding might be a rather brute-force and systematic task, a bit like storing all the information about your data in a database: the items, the lemmas, even the relations. But it's raw information; it doesn't really contain complex knowledge. The responsibility of using it and modelling the fine relations between entities falls to the decoder, which knows the task to fulfill and needs to model more and more complex knowledge to learn.


**Conclusion 3: A Neural ODE really increases its complexity during training when it needs to model more complex knowledge.**

**Conclusion 4: The NODE-Transformer Decoder truly increases its complexity during training while the Encoder keeps it almost constant. It might suggest that complex knowledge modelling happens more in the decoding part of the network, but a deeper study would be required...**

I'd like to do the same study with lower ATOL/RTOL, like `0.0001` and `0.00001`, and longer trainings, to see whether it can learn more and reach classic transformer performance. But I don't have the computing power for that.

**If you have GPU power that you can share with me for free and want to help me check whether a lower ODE Solver error tolerance can help the NODE-Transformer reach 1-layer transformer performance, and perform more studies on that idea, don't hesitate to contact me on Twitter @mandubian or Github, I'd be happy ;)**


## Can Neural-ODE learn a continuous function that represents translation task based on self-attention + feed-forward?

Let's be honest, I have no theoretical answer to this question. I haven't studied ODE math enough to have a proven intuition on those points.

Yet, I've read this very interesting paper [Augmented Neural ODEs](http://arxiv.org/abs/1904.01681).

### Introduction to Augmented Neural ODE

> **Please note that next figures are just copied from the [paper](http://arxiv.org/abs/1904.01681)**

The authors of this paper demonstrate that Neural-ODEs cannot represent all types of functions and fail on this kind of function:

$$
\begin{aligned}
h(-1) &= 1 \\
h(1) &= -1
\end{aligned}
$$

<img src="../media/func_aug.svg" width="30%"/>

As an ODE Solver is made to solve continuous flows, not functions with discrete values or "holes", it tries to "draw" trajectories (aka the ODE flow) from one point to the other in the following way:

<img src="../media/func_traj.png" width="30%"/>

But in the paper, they prove that if an ODE flow had 2 trajectories starting from different initial conditions that intersect, it would lead to a contradiction, as it would mean the initial conditions are the same. So it's impossible for an ODE flow to have crossing trajectories.

For the above function, ODE flow can't cross trajectories so it ends in something like (full red and blue lines):

<img src="../media/func_traj_error.png" width="30%"/>

Amusingly, the dotted red and blue lines are the trajectories learnt by a classic ResNet, which is able to learn this kind of function: the sampling of space is so discrete in a ResNet that it jumps over the crossing of trajectories without even being aware of it. It's not a feature of ResNets, it's just luck!
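
This contrast can be reproduced with a tiny 1D sketch in plain Python. The vector field `f(h) = -2h` is my own choice for illustration (an assumption, not taken from the paper), and the fine-grained Euler loop in `ode_flow` only approximates the continuous flow. A single residual step `h + f(h)` maps `1` to `-1` and `-1` to `1`, while the quasi-continuous flow of the same field keeps the two trajectories on their own side of zero.

```python
def f(h):
    """Hand-picked vector field (an assumption for illustration)."""
    return -2.0 * h

def residual_layer(h):
    # One discrete residual step: h + f(h) = -h, so it "jumps" over 0.
    return h + f(h)

def ode_flow(h0, n_steps=1000):
    """Approximates the continuous flow on [0, 1] with many tiny steps."""
    dt = 1.0 / n_steps
    h = h0
    trajectory = [h]
    for _ in range(n_steps):
        h = h + dt * f(h)
        trajectory.append(h)
    return trajectory

print(residual_layer(1.0), residual_layer(-1.0))

lower = ode_flow(-1.0)
upper = ode_flow(1.0)
# The two continuous trajectories both shrink towards 0 but never cross.
never_cross = all(a < b for a, b in zip(lower, upper))
print(never_cross, lower[-1], upper[-1])
```

Both continuous trajectories decay towards zero without ever swapping order, so with this field the flow cannot send `1` to `-1` and `-1` to `1` the way the single discrete step does.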

The same issue happens with the following function:

$$
\begin{aligned}
h(x) = 
\begin{cases}
    -1 & \text{if } \|x\| \le r_1 \\
    1 & \text{if } r_2 \le \|x\| \le r_3
\end{cases}
\end{aligned}
$$

<img src="../media/func_circle.png" width="25%"/>

A Neural-ODE fails to learn such a function:

<img src="../media/func_circle_error.png" width="50%"/>

In the paper, the flow of an ODE is proven to be a homeomorphism, which can only continuously deform the input space. So, it can't create holes or tear a region apart. For both previous functions, a Neural-ODE is stuck in one or the other point/region and can't "jump" across the gap.

To solve this issue, an augmented version of Neural-ODE is then proposed:

<img src="../media/func_aug_solution.png" width="25%"/>

This simple augmentation by adding a few more synthetic dimensions to the vector with an initialization to 0 allows the ODE flow to solve previous cases by "jumping" across gaps:

<img src="../media/func_aug_solving.png" width="75%"/>
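
Here is a minimal sketch of why one extra dimension is enough for the first function above. The rotation field in `field` is hand-picked rather than learned (an assumption for illustration), and `rk4_flow` and `augmented_map` are hypothetical helper names. The scalar `x` is augmented with a zero coordinate `a`, the pair is rotated by half a turn over `t` in `[0.0, 1.0]`, and the first coordinate is read back: `1` ends at `-1` and `-1` ends at `1`, with the two trajectories following opposite half-circles that never meet.

```python
import math

def field(state):
    """Hand-picked rotation field in the augmented (x, a) plane."""
    x, a = state
    return (-math.pi * a, math.pi * x)

def rk4_flow(state, n_steps=100):
    """Integrates the field over t in [0, 1] with fixed RK4 steps."""
    dt = 1.0 / n_steps
    for _ in range(n_steps):
        k1 = field(state)
        k2 = field(tuple(s + 0.5 * dt * k for s, k in zip(state, k1)))
        k3 = field(tuple(s + 0.5 * dt * k for s, k in zip(state, k2)))
        k4 = field(tuple(s + dt * k for s, k in zip(state, k3)))
        state = tuple(
            s + dt * (a + 2 * b + 2 * c + d) / 6
            for s, a, b, c, d in zip(state, k1, k2, k3, k4)
        )
    return state

def augmented_map(x):
    # Augment with one zero-initialised dimension, solve, keep coordinate 0.
    x_final, _ = rk4_flow((x, 0.0))
    return x_final

print(augmented_map(1.0), augmented_map(-1.0))
```

In a real Augmented Neural ODE the field is of course learned, and the extra coordinates are handled by the layers that follow; this sketch only shows that the extra room removes the crossing problem.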

So, we know that classic Neural-ODEs aren't able to represent all functions.

### Is our function/task of translation relying on simplified Transformer representable with Neural-ODE?

At first sight, the translation task doesn't look like a continuous function. It is about language, which is not a continuous mathematical function. It maps discrete sentences to discrete sentences, which are made of discrete tokens. Can the feature space of this function be continuous, without holes?

Moreover, let's consider the transformer itself. Firstly, the Transformer encoder and decoder embed input and output tokens with a positional embedding that aims at introducing the notion of the position of tokens with respect to each other. This embedding is based on a sinusoidal function or a learned function. What is the impact of this positional embedding on our feature space? Does it make it continuous (or the reverse)? Secondly, the self-attention and feed-forward functions are all based on simple linear operations, the softmax function and normalization, but also on the non-linear ReLU. A priori, there is no reason for those functions to introduce strange holes in the feature space, but a deeper study would be necessary to check that.

Anyway, without studying further, it's hard to characterize clearly the translation task based on transformer and be sure it can be learned by a classic Neural-ODE.

**Thus, I decided to enhance NODE-Transformer with Augmented Neural ODE and test it to see if it helps improving performances.**



## Augmented NODE-Transformer

Please note that when augmenting the NODE-Transformer using the technique above, for alignment reasons, we need to take into account the Transformer attention blocks, which are multi-head blocks. So, when we augment the NODE-Transformer by `n` dimensions, it is in fact an augmentation of `n * number_attention_heads` dimensions.
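
As a quick sanity check of that arithmetic, here is a small sketch in plain Python. `augmented_dim` and `augment` are hypothetical helpers, and the values `512` and `8` are only illustrative (assumed, not read from my training configuration). Adding `n` dimensions per head keeps the embedding size divisible by the number of heads, which multi-head attention requires.

```python
def augmented_dim(embed_dim, n_aug, num_heads):
    """Embedding size after adding n_aug dimensions per attention head."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    return embed_dim + n_aug * num_heads

def augment(vector, n_aug, num_heads):
    """Pads a feature vector with zeros, as in Augmented Neural ODE."""
    return list(vector) + [0.0] * (n_aug * num_heads)

new_dim = augmented_dim(512, 1, 8)       # illustrative values
padded = augment([0.1] * 512, 1, 8)
print(new_dim, new_dim % 8, len(padded))
```

With these illustrative values, 1 augmented dimension turns a 512-wide vector into a 520-wide one, i.e. 65 dimensions per head instead of 64.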

### Failed test 1: Training an Augmented NODE-Transformer

Augmented Neural ODE is implemented on top of the TorchDiffEq library in this [github repository](https://github.com/EmilienDupont/augmented-neural-odes), so it was quite easy to port it to my own hack of TorchDiffEq.

In next figures, you can see the plots of a training session of augmented full NODE-Transformer with ODE error tolerance of `0.001`:


| Best Validation Loss | Training/Validation Loss |
|:---:|:---:|
| <img src="../media/node_transformer_full_aug1_tol001_best_loss.svg" width="150%"/> | <img src="../media/node_transformer_full_aug1_tol001_loss.svg" width="150%"/> | 

| Decoder ODE Calls | Encoder ODE Calls |
|:---:|:---:|
| <img src="../media/node_transformer_full_aug1_tol001_nfe_decoder.svg" width="150%"/> | <img src="../media/node_transformer_full_aug1_tol001_nfe_encoder.svg" width="150%"/> |

After 7h, I stopped the training. The Decoder ODE calls stayed stable, even decreasing for a few epochs, and then increased a lot in a noisy way. The Encoder ODE calls stayed stable. It would have taken tens of hours to converge to a minimum loss and I could not wait that long: it was around 30°C in my place, due to the high temperature outside made worse by GPU heat.

With this early stopping, we can't say whether the Augmented NODE-Transformer reaches better performance than without augmentation. Neural-ODEs, for now, are clearly not GreenAI.

So if I wanted to test further, I needed to reduce required computing resources.

**Conclusion 5: No bigger GPU, no augmented NODE-Transformer**


## Reducing NODE-Transformer computing burden

In the current configuration, with my computing power and temperature limitations, if I wanted to study further, I clearly needed to change my approach. I was far too ambitious in considering the NODE-Transformer as a whole, with a NODE-Encoder and a NODE-Decoder.

We've seen in all previous experiments that NODE-Encoder seems to learn without increasing its complexity, suggesting that using the Neural-ODE on the encoding part might not be very useful compared to a classic neural network.

**A very simple idea to reduce computing resources and duration is to get rid of Node-Encoder and just keep Node-Decoder**


### Optimization 1: NODE-Transformer with NODE-Decoder only

Here is the architecture of NODE-Transformer-Decoder-Only:

<img src="../media/Node-Transformer-1.png" width="50%"/>

The encoder part is the original Transformer Encoder with N layers and residual connections and the decoder part is our NODE-Decoder with 1 layer.


#### Yet another inconclusive experiment: NODE-Decoder only with augmented dimensions

Here is a training session of that architecture with 1 augmented dimension and a high ODE Error Tolerance of `0.01` to reduce training duration:

| Best Validation Loss | Training/Validation Loss | Decoder ODE Calls |
|:---:|:---:|:---:|
| <img src="../media/node_transformer_decoder_only_aug_1_best_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_nfe_decoder.svg" width="150%"/> |

After 4h, the training ended up stuck in a never-ending ODE solver computation. Apparently, it couldn't reach an estimated value within the error tolerance and started refining the sampling steps finer and finer without finding any solution. We can imagine that the augmented dimension allows the ODE-Solver to transform the space, but maybe it deforms it a bit too "hard" and ends up in a very perturbed or noisy landscape in which this quite high error tolerance of `0.01` never allows it to reach a good estimation. But an in-depth study of the ODE feature space would be required, and I wouldn't pretend I know anything serious about that.

> Please note that in the ODE solver used here, no limit is set on the max number of calls, so it can end in this forever-computing case. If you limit the number of ODE calls, how should it be managed in the training?... To be studied!
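
One possible way to handle such a limit, sketched in plain Python under heavy assumptions: wrap the function given to the solver so that it raises once a call budget is exhausted, and let the training loop skip the offending batch. Nothing here is part of TorchDiffEq or Fairseq; `CallBudgetExceeded`, `BudgetedFunction`, `toy_solve` and `run_epoch` are hypothetical names, and the toy solver uses a per-batch `n_steps` only to mimic batches that require more or fewer calls.

```python
class CallBudgetExceeded(RuntimeError):
    """Raised when the solver asks for more evaluations than allowed."""

class BudgetedFunction:
    """Wraps the dynamics and enforces a max number of calls per solve."""
    def __init__(self, func, max_calls):
        self.func = func
        self.max_calls = max_calls
        self.calls = 0

    def reset(self):
        self.calls = 0

    def __call__(self, h, t):
        self.calls += 1
        if self.calls > self.max_calls:
            raise CallBudgetExceeded(f"more than {self.max_calls} calls")
        return self.func(h, t)

def toy_solve(func, h0, n_steps):
    """Stand-in for an ODE solver: fixed-step Euler on [0, 1].
    n_steps mimics how hard the solver has to work on a given batch."""
    dt = 1.0 / n_steps
    h, t = h0, 0.0
    for _ in range(n_steps):
        h = h + dt * func(h, t)
        t += dt
    return h

def run_epoch(batches, func, max_calls):
    """Solves each batch, skipping those that exhaust the call budget."""
    budgeted = BudgetedFunction(func, max_calls)
    skipped = []
    for index, (h0, n_steps) in enumerate(batches):
        budgeted.reset()
        try:
            toy_solve(budgeted, h0, n_steps)
        except CallBudgetExceeded:
            skipped.append(index)  # candidate for a later retry
    return skipped

batches = [(1.0, 10), (1.0, 1000), (1.0, 20)]
skipped = run_epoch(batches, lambda h, t: -h, max_calls=100)
print(skipped)  # [1]
```

Whether skipped batches should be retried later, dropped, or solved with a looser tolerance remains an open design question.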

Anyway, when it fails, never give up, push further! So I pushed the reduction of computations further ;)

In the decoding path of the Transformer, there are 2 parts:

- A multi-head self-attention block on the output sequence,
- A multi-head attention between this self-attention computed on the output sequence and the output of the Transformer encoder self-attention applied to the input sequence.

So, we can suggest the intuition that the self-attention block on the output sequence might behave like the self-attention block applied to input sequence: it learns "raw" knowledge about the sequence and we showed experimentally that it doesn't take much advantage of Neural-ODE capabilities of increasing its own complexity.

Despite being just an intuition, it was interesting to test this approach with the so-called _Separated Node-Decoder architecture_



### Optimization 2: NODE-Transformer with Separated NODE-Decoder

<img src="../media/Node-Transformer-2.png" width="50%"/>

The Neural-ODE is applied only on the attention between the encoder output computed on the input sequence and the self-attention on the output sequence.

#### Finally a concluding experiment: Augmented NODE-Transformer with Separated NODE-Decoder

We performed a training with 1 augmented dimension, same high error tolerance of 0.01.

| Best Validation Loss | Training/Validation Loss | Decoder ODE Calls |
|:---:|:---:|:---:|
| <img src="../media/node_transformer_decoder_only_aug_1_sep_best_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_sep_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_sep_nfe_decoder.svg" width="150%"/> |

This training converged in around 4h.

The loss reached ~4.38, which is lower than the 4.5 obtained after the first 35h training of the full NODE-Transformer with the lower error tolerance of 0.001. It is also much lower than the `6.5` obtained with the higher error tolerance of 0.01 on the full NODE-Transformer.

Yet, it's still much higher than the ~3.2 loss reached by classic transformer with 1 layer in encoder/decoder. 

Moreover, we can remark that ODE calls started much lower than previous trainings (65 compared to the 180-200 in other trainings) and started to rise progressively with longer stability steps before rising again.

Encouraging but not fully satisfying. As expected, the presence of classic Transformer self-attention limits the computing resources required and seems to provide more stable learning. The NODE-decoder is able to use that knowledge, starting from a lower level of complexity and then building more and more complex decoding knowledge. But it is not able to reach the same level of performance as the classic Transformer 1-layer decoder with its residual connection pipeline. The overly high error tolerance might naturally prevent it from refining knowledge. It would be worth testing a lower error tolerance, but I would need more computing power again. Note that the augmented dimension doesn't seem to help a lot, even if we reach a slightly better performance.

As we obtain almost the same results as with the full NODE-Transformer, we can infer that the performance bottleneck of the original NODE-Transformer might be located in this part of the Decoder (i.e. the attention between the encoder output on the input sequence and the self-attention on the output sequence). For now, not having the computing power to test the same code with lower error tolerances, it's hard to make further suppositions about the causes.

**Conclusion 6: restricting the NODE-Transformer to a separated NODE-Decoder drastically reduces computations, as expected, without impacting final performance, suggesting the NODE-Transformer learning bottleneck might be located in this part of the decoder.**

I decided to test a last solution to improve not the computation but the performance: make NODE-Transformer decoder aware of time!


## Optimization 3: NODE-Transformer trained with Time Dependency

>I had this idea of time-dependency for some time but I saw they had implemented it in the augmented neural-ODE code so I decided to try it myself.

The DOPRI5 ODE Solver is an explicit adaptive method that samples time-steps in the interval `[0.0, 1.0]` according to its algorithm and max error tolerance. The sampled time-step is provided to the embedded function, i.e. our simplified Transformer neural network. But our Transformer is not aware of time: it doesn't use this time-step in its computation. So if the ODE Solver calls our network with the same value at time-step `t1=0.1` and time-step `t2=0.9`, it will return the same result. Intuitively, one would like the Transformer to compute something different at different time-steps, as it is not the same "place" in the feature space, right?

So after augmenting the Neural-ODE with a synthetic dimension, I also augmented the inference vector with the time information:

$$
\begin{aligned}
f(x, t) &= f(\begin{bmatrix}
           x \\
           t \\
         \end{bmatrix})
\end{aligned}
$$
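
The difference can be sketched in a few lines of plain Python. Both functions and the `weights` matrix are hypothetical toys (assumptions for illustration, not the actual NODE-Transformer code): the first one ignores `t`, the second one appends `t` to its input before applying a linear map, so the same `x` gives different outputs at different time-steps.

```python
def autonomous_f(x, t):
    """Ignores t: returns the same output whenever the solver calls it."""
    return [0.5 * xi for xi in x]

def time_aware_f(x, t, weights):
    """Appends t to the input vector, then applies a linear map."""
    z = list(x) + [t]
    return [sum(w * zi for w, zi in zip(row, z)) for row in weights]

# Last column of each row is the weight applied to the time coordinate.
weights = [[0.5, 0.0, 1.0],
           [0.0, 0.5, -1.0]]
x = [1.0, 2.0]

same = autonomous_f(x, 0.1) == autonomous_f(x, 0.9)
early = time_aware_f(x, 0.1, weights)
late = time_aware_f(x, 0.9, weights)
print(same, early, late)
```

The time-aware version has one extra input column to learn per output, which is a small cost in parameters for a function that can now behave differently along the interval.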


#### A semi-conclusive experiment: NODE-Transformer with Augmented Separated NODE-Decoder and Time-Dependency

Here is a training of NODE-Transformer restricted to separated NODE-Decoder augmented with one dimension and time-dependency and error tolerance of 0.01.

_In next plots, dark blue and red plots are the current training. The light blue are from previous training "Augmented NODE-Transformer with Separated NODE-Decoder" to compare it directly._

| Best Validation Loss | Training/Validation Loss | Decoder ODE Calls (dark blue) |
|:---:|:---:|:---:|
| <img src="../media/node_transformer_decoder_only_aug_1_timedep_best_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_timedep_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_sep_with_timedep_nfe_decoder.svg" width="150%"/> |

We see here that the training is almost the same as without time-dependence. It reaches the same lower value of loss with almost the same curve. The only difference is in the NODE-Decoder ODE calls which start higher and rise faster with time-dependency. It sounds logical as adding time information naturally increases apparent complexity at beginning of training. Yet at the end of training, with or without time-dependency, the number of ODE calls is about the same. Without having done any further study, it might mean that it reaches the same level of complexity in both cases or maybe even the same time-steps distribution.

**Conclusion 7: Adding time-dependency didn't improve our training performance but it remains an interesting idea for other Neural-ODE use-cases in which it could be more meaningful.**


## Optimization 4: NODE-Transformer trained with Weight Decay

I finally tested one last aspect: is it possible to limit the increase of ODE calls by acting on our training process? In the paper [FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models](http://arxiv.org/abs/1810.01367), the authors remarked that the number of function evaluations in a Neural-ODE is reduced by using different forms of regularization, such as weight decay and spectral normalization, at the cost of hurting performance a bit.

In our case, performance isn't really good but I decided to try weight decay as it is very easy to put in place. Weight decay is a very simple idea: after updating weights, they are multiplied by a factor below 1 to prevent those weights from growing too fast.
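
Taking that description literally, one update step could be sketched as follows in plain Python. `sgd_step_multiplicative_decay` and `sgd_step_l2` are hypothetical helper names and the numbers are illustrative. The second function shows the other common formulation, where weight decay is a coefficient added to the gradient; for plain SGD it amounts to shrinking the weights by `1 - lr * coefficient` at each step.

```python
def sgd_step_multiplicative_decay(weights, grads, lr, decay_factor):
    """Gradient step, then shrink every weight by a factor below 1."""
    return [(w - lr * g) * decay_factor for w, g in zip(weights, grads)]

def sgd_step_l2(weights, grads, lr, coefficient):
    """Common alternative: add coefficient * w to the gradient (L2 penalty)."""
    return [w - lr * (g + coefficient * w) for w, g in zip(weights, grads)]

weights = [1.0, -2.0]
zero_grads = [0.0, 0.0]

# With no gradient signal, weights only shrink towards zero.
shrunk = sgd_step_multiplicative_decay(weights, zero_grads, lr=0.1, decay_factor=0.9)
shrunk_twice = sgd_step_multiplicative_decay(shrunk, zero_grads, lr=0.1, decay_factor=0.9)
l2_shrunk = sgd_step_l2(weights, zero_grads, lr=0.1, coefficient=1.0)
print(shrunk, shrunk_twice, l2_shrunk)
```

With these numbers both forms coincide on a single step, since `1 - 0.1 * 1.0 = 0.9`; in general they differ in how the shrink interacts with the learning rate.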

#### Disappointing experiment: NODE-Transformer with Augmented Separated NODE-Decoder and Weight Decay

We trained a NODE-Transformer restricted to separated NODE-Decoder augmented with 1 dimension and weight decay factor of 0.9.

_In next plots, green and grey plots are the current training. The other plots (light blue and red) are from previous training "Augmented NODE-Transformer with Separated NODE-Decoder" to compare it directly._

| Best Validation Loss (grey) | Training/Validation Loss (green & grey) | Decoder ODE Calls (green) |
|:---:|:---:|:---:|
| <img src="../media/node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_sep_weight_decay_loss.svg" width="150%"/> | <img src="../media/node_transformer_decoder_only_aug_1_sep_weight_decay_nfe_decoder.svg" width="150%"/> |

We can see that the number of ODE calls is the same in the first epoch, then decreases with weight decay and stays lower than in the training without weight decay for a couple of epochs. Then it increases suddenly, higher than in the training without weight decay at the same epoch. But then it stays quite stable despite being quite noisy. Meanwhile, in the training without weight decay, the number of calls grows constantly and overtakes the one with weight decay.

So, weight decay seems to lower the ODE-Solver computations globally, but the impact on loss is not as small as reported in [FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models](http://arxiv.org/abs/1810.01367).

**Conclusion 8: using weight decay in NODE-Transformer training globally reduces the number of ODE evaluations but hurts performance. Yet the idea that regularization of learnable parameters can help tune the resource requirements of an ODE Solver, at the cost of some performance, is really interesting.**


## Implementation



## Conclusion

NODE-Transformer is a nice subject of study but not really an efficient network in reality, both in terms of performance and computing burden.




## Next steps
