# Ondrej Platek: Report on Multi-Token Prediction for BottleCap

## Initial Notes

- Typically, single-token prediction is trained with the Next Token Prediction (NTP) task, using a Cross-Entropy (CE) loss.
- Speculative decoding will improve the task and is cheap at inference time.
- Question: how will a multi-token loss speed up training? Speed-ups at inference are ignored here.
- The code is in a private repo https://github.com/oplatek/llama2.c and the experiments are logged to the private https://wandb.ai/moosenet/bottlecap project. Ask me for access if interested.
- See also the last section, `Suggestions for other approaches to speed up training`.

## Plan
1. Quickly scan some relevant papers. Found [Better & Faster Large Language Models via Multi-token Prediction](https://arxiv.org/pdf/2404.19737).
2. Recap the baseline model training in the code-base
3. Test model export and running inference
4. Read the code & implement the auxiliary heads and losses so that the export is still valid
5. Finally, experiment with the losses and hyperparameters

Initial thoughts:
- CE losses for predicting the 1st (default), 2nd, 3rd, 4th, ..., k-th next token as the first thing to try.
- Consistency loss between predictions made with access to tokens -1, -2, -3 (as is standard with masked attention) and predictions made without access to them.
- Alternative and ambitious approaches to training speed-up: multiple-tokenizer support? Diffusion-based transformers? See the end of the report for more details.

## Problems encountered/Progress description
In general, I spent most of the time on technical details, mostly related to logging.
On the other hand, I ended up with a professional experiment logging setup at https://wandb.ai/moosenet/bottlecap/overview (send me your wandb login if you want access to the details).

### Update the script for the smaller model
I updated the hyperparameters to match the smaller model:

| model | dim | n_layers | n_heads | n_kv_heads | max context length | parameters | val loss | download |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 260K | 64 | 5 | 8 | 4 | 512 | 260K | 1.297 | [stories260K](https://huggingface.co/karpathy/tinyllamas/tree/main/stories260K) |

### Update the code to have auxiliary function: Technical problems
Mostly classics:
1. Transferring tensors between the model and the training loop for logging
2. Again, transferring tensors from the CUDA device to the CPU for logging

### Finding the right setup of next token prediction loss
First, I wrote a naive implementation where the baseline Cross-Entropy (CE) loss was summed with `n` Cross-Entropy losses for predicting the `n+1` token using the same embedding `h_{n-1}`.

In other words, the baseline next token prediction (`NTP`) CE loss used the Transformer architecture and applied the CE loss to a context of length $n-1$ and the logits for token $n$, where $n \in \{1,\ldots,\text{MAX\_SEQ\_LEN}\}$.
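
Written out as a formula, the baseline objective looks roughly as follows. This is a sketch under the assumption that the loss is averaged over positions; the symbols $x_n$ (the token at position $n$), $p_\theta$ (the model distribution), and $N$ are introduced here for illustration and do not come from the code.

$$
\mathcal{L}_{\text{NTP}} = -\frac{1}{N} \sum_{n=1}^{N} \log p_\theta\left(x_n \mid x_1, \ldots, x_{n-1}\right), \qquad N = \text{MAX\_SEQ\_LEN}
$$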

Each `NiTP` (next $n + i$ token prediction) loss used an extra $\text{head}_i$.
The hyperparameter `aux_losses` allowed me to experiment with multiple losses.
However, I present results with only one auxiliary loss.

**It was important** to scale down the influence of the auxiliary loss, so I introduced a weighting of the NiTP loss.
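
The weighted combination can be sketched in plain Python as below. This is a minimal illustration, not the code from the private repo: the helper names (`log_softmax`, `shifted_ce`, `combined_loss`) are hypothetical, and the indexing is an assumption in which position `t` of the main head is scored against `tokens[t + 1]` and position `t` of the auxiliary head against `tokens[t + 2]`.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def shifted_ce(logits_per_pos, tokens, shift):
    """Mean cross-entropy where position t is scored against tokens[t + shift]."""
    n = len(tokens) - shift
    if n <= 0:
        raise ValueError("sequence too short for this shift")
    total = 0.0
    for t in range(n):
        total -= log_softmax(logits_per_pos[t])[tokens[t + shift]]
    return total / n


def combined_loss(ntp_logits, aux_logits, tokens, aux_weight=0.1):
    """Return (total, ntp, aux) with total = ntp + aux_weight * aux."""
    ntp = shifted_ce(ntp_logits, tokens, shift=1)  # ordinary next token
    aux = shifted_ce(aux_logits, tokens, shift=2)  # assumed: one token further
    return ntp + aux_weight * aux, ntp, aux
```

Returning the individual terms alongside the total makes it easy to log the NTP loss separately, which is the quantity used for comparison between setups.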

I also introduced a **consistency loss** between the NTP logits for a context of length $n - 1 + i$ and the NiTP logits predicting token $n + i$ from a context of length $n-1$, i.e. the same heads I used for the NiTP loss.
The basic idea is that the transformer should predict the same distribution for a token from the shorter context as from the longer, later context.
I used the KL divergence as the loss.
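
A minimal plain-Python sketch of such a consistency term follows. Several details are assumptions rather than facts about the repo: the direction of the divergence (here KL from the NTP distribution to the auxiliary one), the `offset` of one position between the two heads, averaging over positions, and all function names.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(
        pi * math.log((pi + eps) / (qi + eps))
        for pi, qi in zip(p, q)
        if pi > 0.0
    )


def consistency_loss(ntp_logits, aux_logits, offset=1):
    """Mean KL between the NTP head at position t + offset (longer context)
    and the auxiliary head at position t (shorter context).
    Both are assumed to predict the same target token."""
    n = len(ntp_logits) - offset
    if n <= 0:
        raise ValueError("sequence too short for this offset")
    total = 0.0
    for t in range(n):
        p = softmax(ntp_logits[t + offset])  # prediction with more context
        q = softmax(aux_logits[t])           # prediction made further ahead
        total += kl_divergence(p, q)
    return total / n
```

The loss is zero when the auxiliary head already reproduces what the main head predicts once it has seen the intermediate tokens, and grows as the two distributions diverge.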


## Experiment notes

I ran only very basic experiments due to time constraints (2h on Tuesday and 5h on Thursday).
I am submitting the first reasonable result.

I evaluated the speed of training in terms of the improvement in NTP (next token prediction cross-entropy) validation and training loss per iteration step.
I used the default, fixed `batch_size`, `learning_rate`, `optimizer`, ..., `max_iter` parameters.

The screenshots below depict four setups:
1. [Base (Cyan)](): a baseline with no auxiliary loss
2. [NiTP (Yellow)](https://wandb.ai/moosenet/bottlecap/runs/q1jnto67/overview): on top of the baseline, only the next i=2 token prediction CE loss was used, with weight 0.1
3. [Consist (Seafoam)](https://wandb.ai/moosenet/bottlecap/runs/j1gi1ayi/overview): on top of the baseline, only the consistency loss for the head predicting the next i=2 token was used, mixed in with weight 0.01 (other values such as 0.1 are left for future work)
4. [Consist&NiTP (Red)](https://wandb.ai/moosenet/bottlecap/runs/808ar57j/overview): both losses from above, i.e. a combination of the Consist and NiTP losses

**The first image below shows that at 20k steps _NiTP (yellow)_ is actually better than the baseline in terms of validation NTP loss.**
At the same time, the consistency loss setup and the setup with both the consistency and NiTP losses perform worse than the baseline.

However, the second and third images below show a different situation at 3k steps: all three setups with auxiliary losses outperform the baseline.
This may suggest that auxiliary loss scheduling, or simply a better hyperparameter setup, could show that the consistency loss is useful.
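
One possible form of auxiliary loss scheduling is sketched below. It is purely hypothetical and was not part of the experiments: the function name, the linear decay shape, and the default values (hold the weight until 3k steps, reach zero at 20k steps) are illustrative assumptions loosely motivated by the observation that the auxiliary losses help early in training.

```python
def aux_weight_schedule(step, base_weight=0.1, hold_steps=3000, total_steps=20000):
    """Hypothetical schedule: keep the auxiliary weight constant for the
    first hold_steps iterations, then decay it linearly to zero."""
    if step < 0:
        raise ValueError("step must be non-negative")
    if step <= hold_steps:
        return base_weight
    if step >= total_steps:
        return 0.0
    remaining = (total_steps - step) / (total_steps - hold_steps)
    return base_weight * remaining
```

In a training loop, the returned value would replace the fixed weight that multiplies the auxiliary loss at each iteration.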

![val20k.png](https://raw.githubusercontent.com/oplatek/llama2.c/refs/heads/oplatek/doc/val20k.png?token=GHSAT0AAAAAACOZMRFDEKKAIQLQP33NV5WGZ7PB2RQ)


![val3k.png](https://raw.githubusercontent.com/oplatek/llama2.c/refs/heads/oplatek/doc/val3k.png?token=GHSAT0AAAAAACOZMRFCGZOZFCDCR7AQIVOCZ7PB6OA)

![train3k.png](https://raw.githubusercontent.com/oplatek/llama2.c/refs/heads/oplatek/doc/train3k.png?token=GHSAT0AAAAAACOZMRFDF42DISG7V2QOWKHSZ7PB7QA)

<!-- https://github.com/oplatek/llama2.c/blob/oplatek/doc/train3k.png?raw=true) -->

## Note on auxiliary losses behaviour
I noticed interesting behaviour in the three setups with auxiliary losses.
The image below shows the same three setups:

2. [NiTP (Yellow)](https://wandb.ai/moosenet/bottlecap/runs/q1jnto67/overview): on top of the baseline, only the next i=2 token prediction CE loss was used, with weight 0.1
3. [Consist (Seafoam)](https://wandb.ai/moosenet/bottlecap/runs/j1gi1ayi/overview): on top of the baseline, only the consistency loss for the head predicting the next i=2 token was used, mixed in with weight 0.01 (other values such as 0.1 are left for future work)
4. [Consist&NiTP (Red)](https://wandb.ai/moosenet/bottlecap/runs/808ar57j/overview): both losses from above, i.e. a combination of the Consist and NiTP losses

We logged the `ntp` loss and `aux_loss_0`, which was the NiTP, Consist, and NiTP loss for setups 2, 3, and 4, respectively.
Setup 4 also logged the consistency loss as the second auxiliary loss `aux_loss_0`.

**Interestingly, the consistency loss seems to be completely minimized around 3k steps; after that it is probably in conflict with the NTP loss, because its value rises tenfold to 0.2 around 15k steps and then decreases consistently until the end.**

Note that we logged the total loss that was used in the optimizer, but for evaluation one should compare only the NTP loss and look at the individual auxiliary losses for more detailed behaviour.
![auxilary losses](https://raw.githubusercontent.com/oplatek/llama2.c/refs/heads/oplatek/doc/aux_losses.png?token=GHSAT0AAAAAACOZMRFDCBPASTK6JUTVDHHGZ7PBZZA)

## Summary
I managed to run the training on both CPU (on a Mac) and GPU.
I achieved promising preliminary results, showing that even within 1 MD the auxiliary next token prediction losses may (slightly) improve the training speed.

The immediate future work would be to find the optimal hyperparameter setup and to optimize the code so that the speed-up can be measured not only in iterations but also in wall time.
Personally, I would need to set up an environment with a GPU where identical runs take the same amount of time.

# Suggestions for other approaches to speed up training

## Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach
When I was thinking about predicting multiple tokens, I immediately remembered diffusion models and, in particular, a similar model from the paper [Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach](https://arxiv.org/abs/2502.05171).
I think the paper is worth reading, but diffusion processes are anything but fast or efficient in terms of iterations.
After thinking about it for longer than 2 minutes, this is not among the first directions I would start with.

## Redefining what the next tokens mean 🚀

### Spelling loss
I suggest using a new tokenization for the targets, namely only single-letter tokens.
We would simply keep a copy of the targets tokenized as letters, plus pointers to where each original target starts in the letter array.
The auxiliary head would simply predict the letters following the context.

For this particular loss there are three easy to implement variants:
1. True case
2. Lower case
3. Upper case
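
Preparing such targets could look like the sketch below. The function name `spelling_targets`, the `case` argument, and the representation of target tokens as Python strings are assumptions made for illustration; a real implementation would first decode token ids to strings with the tokenizer in use.

```python
def spelling_targets(target_tokens, case="true"):
    """Return (letters, starts): a flat list of characters and, for each
    target token, the index where its characters begin in that list."""
    if case == "true":
        normalize = lambda s: s
    elif case == "lower":
        normalize = str.lower
    elif case == "upper":
        normalize = str.upper
    else:
        raise ValueError("case must be 'true', 'lower', or 'upper'")

    letters, starts = [], []
    for tok in target_tokens:
        starts.append(len(letters))      # pointer to the token's first letter
        letters.extend(normalize(tok))   # a string extends the list char by char
    return letters, starts
```

For example, the tokens `["The", " cat", " sat"]` yield the characters of `"The cat sat"` and the start pointers `[0, 3, 7]`.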

Benefits are:
- Targets are easy to prepare
- The head could be very small thanks to the very limited vocabulary (just letters)
- The LLM will (arguably) learn dependencies between different tokenizations

Note: This is off the top of my head; I should do a literature review before implementing it.

### Use softer cross-entropy especially for next token
I am not sure at the moment how quickly the softer CE defined in the paper [Soft Alignment Objectives for Robust Adaptation of Language Generation](https://aclanthology.org/2023.acl-long.492.pdf) can be implemented,
but for next token prediction it would certainly be beneficial to count synonym tokens as valid predictions.
It would smooth the gradients.
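
To make the idea concrete, here is a deliberately simplified sketch. It is not the objective from the cited paper: it assumes a hypothetical, precomputed `synonyms` table mapping a token id to synonym token ids, and it moves a fixed share of the target probability mass (`synonym_mass`, an arbitrary illustrative value) onto those synonyms.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def soft_targets(target, synonyms, vocab_size, synonym_mass=0.2):
    """Target distribution with 1 - synonym_mass on the true token and
    synonym_mass shared equally among its (hypothetical) synonyms."""
    dist = [0.0] * vocab_size
    syn = [s for s in synonyms.get(target, []) if s != target]
    if not syn:
        dist[target] = 1.0  # no synonyms known: ordinary one-hot target
        return dist
    dist[target] = 1.0 - synonym_mass
    for s in syn:
        dist[s] += synonym_mass / len(syn)
    return dist


def soft_cross_entropy(logits, target_dist):
    """Cross-entropy between a soft target distribution and the model logits."""
    logp = log_softmax(logits)
    return -sum(p * lp for p, lp in zip(target_dist, logp))
```

With a one-hot target this reduces to the usual CE, so the softened variant can be compared against the baseline simply by setting `synonym_mass` to zero or leaving the synonym table empty.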