# Performance Verification and Transferability Evaluation of the TTT Layer

:reference: https://github.com/test-time-training/ttt-lm-pytorch

:original paper: https://arxiv.org/abs/2407.04620

## 0. From the Quick Start Example

[**Paper**](https://arxiv.org/abs/2407.04620)
| [**JAX Codebase**](https://github.com/test-time-training/ttt-lm-jax)
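| [**Setup**](https://github.com/test-time-training/ttt-lm-pytorch#environment-setup)
| [**Quick Start**](https://github.com/test-time-training/ttt-lm-pytorch#quick-start)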
| [**Inference Benchmark**](https://github.com/test-time-training/ttt-lm-kernels)

The reference repository is the official PyTorch model implementation of [Learning to (Learn at Test Time): RNNs with Expressive Hidden States](https://arxiv.org/abs/2407.04620).
Its authors **do not recommend training** with that codebase: it is written in pure PyTorch without any systems optimization, so training is slow, especially when the per-device batch size is small.


For training code, or to replicate the paper's results, see the authors' [JAX codebase](https://github.com/test-time-training/ttt-lm-jax). For inference kernels, or to replicate the paper's speed benchmarks, see their [kernel implementations](https://github.com/test-time-training/ttt-lm-kernels).

## Abstract

Self-attention performs well in long context but has quadratic complexity. Existing RNN layers
have linear complexity, but their performance in long context is limited by the expressive power
of their hidden state. We propose a new class of sequence modeling layers with linear complexity
and an expressive hidden state. The key idea is to make the hidden state a machine learning
model itself, and the update rule a step of self-supervised learning. 
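To make the complexity claim concrete, the sketch below counts token interactions, a deliberately simplified cost proxy (not FLOPs, and not taken from the paper). In causal self-attention token $t$ attends to all $t$ tokens up to and including itself, so the total over $T$ tokens is $T(T+1)/2$. An RNN-style layer such as TTT instead does a fixed amount of work per token; `cost_per_token` is a hypothetical constant standing in for that work.

```python
def attention_interactions(seq_len: int) -> int:
    """Causal self-attention: token t attends to tokens 1..t, so the total is
    1 + 2 + ... + T = T(T + 1) / 2, which grows quadratically in T."""
    return seq_len * (seq_len + 1) // 2


def recurrent_updates(seq_len: int, cost_per_token: int = 1) -> int:
    """RNN-style layer (TTT included): one fixed-cost hidden-state update per
    token, so the total grows linearly in T. cost_per_token is hypothetical."""
    return seq_len * cost_per_token


for T in (1_024, 2_048, 4_096):
    print(T, attention_interactions(T), recurrent_updates(T))
```

Doubling $T$ roughly quadruples the attention count but only doubles the recurrent one, which is the gap the paper targets for long contexts.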

Since the hidden state is updated by training even on test sequences, our layers are called **Test-Time Training (TTT) layers**.
We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model
and a two-layer MLP respectively. 
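A minimal pure-Python sketch of the TTT-Linear idea, assuming the paper's per-token formulation: the hidden state is the weight matrix $W$ of a linear model $f(x; W) = Wx$. For each token it takes one gradient step on the self-supervised loss $\ell(W; x_t) = \lVert W\theta_K x_t - \theta_V x_t\rVert^2$ and then outputs $z_t = W_t\,\theta_Q x_t$. The function name `ttt_linear_naive`, the zero initial state $W_0$, and the learning rate are illustrative assumptions. The paper's layers add refinements such as mini-batch TTT and a LayerNorm with a residual connection inside $f$, which this sketch omits.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Return W @ x for a row-major matrix W."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def ttt_linear_naive(xs: List[Vector], theta_k: Matrix, theta_v: Matrix,
                     theta_q: Matrix, lr: float = 0.1) -> List[Vector]:
    """Hypothetical online TTT-Linear sketch (one token per update step).

    For each token x_t:
      1. loss   l(W) = ||W k - v||^2 with k = theta_k x_t, v = theta_v x_t
      2. update W <- W - lr * dl/dW, where dl/dW = 2 (W k - v) k^T
      3. output z_t = W q with q = theta_q x_t, using the updated W
    """
    d_in = len(theta_k)    # dimension of the training view k
    d_out = len(theta_v)   # dimension of the label view v
    W = [[0.0] * d_in for _ in range(d_out)]  # W_0 = 0 (assumption)
    outputs = []
    for x in xs:
        k = matvec(theta_k, x)  # training view
        v = matvec(theta_v, x)  # label view
        q = matvec(theta_q, x)  # test view
        err = [pred - target for pred, target in zip(matvec(W, k), v)]
        # One gradient step on the reconstruction loss (outer product err k^T)
        W = [[W[i][j] - lr * 2.0 * err[i] * k[j] for j in range(d_in)]
             for i in range(d_out)]
        outputs.append(matvec(W, q))
    return outputs


identity = [[1.0, 0.0], [0.0, 1.0]]
# Repeating one token: W_t[0][0] = 1 - 0.8**t, so outputs approach the target 1.0
zs = ttt_linear_naive([[1.0, 0.0]] * 3, identity, identity, identity, lr=0.1)
print(zs)
```

Because the hidden state is itself trained, seeing the same token repeatedly drives its reconstruction error down, which is the "learning at test time" behavior the layer is named after.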

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from ttt import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS
import torch

In [None]:
model_id = "meta-llama/Llama-2-7b-hf"

# 4-bit NF4 quantization config for the Llama-2-7B baseline (needs the bitsandbytes package, typically with a CUDA GPU)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
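Why quantize the baseline at all: a rough weight-only estimate, assuming a nominal 7B parameters, shows the saving of 4-bit NF4 over bf16. The helper `weight_memory_gb` is hypothetical and ignores quantization metadata (per-block scales, which `bnb_4bit_use_double_quant=True` shrinks further), activations, the KV cache, and any layers kept in higher precision, so real usage is higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Rough weight-only footprint in GB (1 GB = 1e9 bytes).

    Ignores quantization metadata, activations, the KV cache,
    and layers that stay in higher precision.
    """
    return num_params * bits_per_param / 8 / 1e9


NOMINAL_PARAMS = 7e9  # nominal "7B"; the exact Llama-2-7B count differs slightly
print(f"bf16: {weight_memory_gb(NOMINAL_PARAMS, 16):.1f} GB")  # 14.0 GB
print(f"nf4 : {weight_memory_gb(NOMINAL_PARAMS, 4):.1f} GB")   # 3.5 GB
```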

In [None]:
# Shared tokenizer: the TTT quick start also uses the Llama-2 tokenizer, so both models see the same token IDs
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig()
configuration

##### Model Architecture Comparison

In [None]:
# Initializing a model from the ttt-1b style configuration
# (weights are randomly initialized; no pretrained checkpoint is loaded)
model = TTTForCausalLM(configuration)
model.eval()

In [None]:
# For comparison, load a standard Transformer LLM (pretrained Llama-2-7B, 4-bit quantized)
original = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, quantization_config=bnb_config)
original.eval()

##### Model Output Comparison

In [None]:
input_text = "Greeting from TTT!"

input_ids = tokenizer(input_text, return_tensors="pt").input_ids

In [None]:
# Inference (max_length counts the prompt tokens; use max_new_tokens to cap only the generated ones)
out_ids = model.generate(input_ids=input_ids, max_length=50)
out_str = tokenizer.batch_decode(out_ids, skip_special_tokens=True)
out_str
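`generate()` hides the decoding loop. Assuming greedy decoding (the Hugging Face default when `do_sample=False` and `num_beams=1`), it behaves roughly like the toy loop below. `greedy_generate` and `next_token_logits` are hypothetical stand-ins for the library call and a model forward pass; the real `generate()` also handles state caching, batching, and other stopping criteria.

```python
from typing import Callable, List, Optional


def greedy_generate(prompt_ids: List[int],
                    next_token_logits: Callable[[List[int]], List[float]],
                    max_length: int,
                    eos_token_id: Optional[int] = None) -> List[int]:
    """Toy greedy decoding loop; max_length includes the prompt tokens."""
    ids = list(prompt_ids)
    while len(ids) < max_length:
        logits = next_token_logits(ids)
        # Greedy choice: index of the highest logit (first one on ties)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        ids.append(next_id)
        if eos_token_id is not None and next_id == eos_token_id:
            break  # stop after emitting the end-of-sequence token
    return ids


def toy_model(ids: List[int]) -> List[float]:
    """Hypothetical 5-token 'model' that always predicts (last + 1) % 5."""
    return [1.0 if i == (ids[-1] + 1) % 5 else 0.0 for i in range(5)]


print(greedy_generate([0], toy_model, max_length=4))  # [0, 1, 2, 3]
```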