# Performance Verification and Transferability Evaluation of TTT Layer

:reference: https://github.com/test-time-training/ttt-lm-pytorch

:suggesting paper: https://arxiv.org/abs/2407.04620

## Import Libraries

In [None]:
import os

import torch
import torchaudio
import torchvision

import numpy as np
import pandas as pd

from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from huggingface_hub import hf_hub_download
from tqdm.notebook import tqdm

### Check GPU Availability

In [2]:
!nvidia-smi

Thu Aug  8 09:40:03 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  On   | 00000000:04:00.0 Off |                    0 |
| N/A   40C    P0    34W / 250W |   5661MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   39C    P0    25W / 250W |      2MiB / 16280MiB |      0%      Defaul

In [3]:
# Set CUDA Device
device_num = 1

if torch.cuda.is_available() and device_num != -1:
    torch.cuda.set_device(device_num)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    device_num = -1  # cpu
print(f"INFO: Using device - {device}:{device_num}")

INFO: Using device - cuda:1


## 0. From Quick Start Example

[**Paper**](https://arxiv.org/abs/2407.04620)
| [**JAX Codebase**](https://github.com/test-time-training/ttt-lm-jax)
| [**Setup**](#environment-setup)
| [**Quick Start**](#quick-start)
| [**Inference Benchmark**](https://github.com/test-time-training/ttt-lm-kernels)

This is the official PyTorch model implementation of [Learning to (Learn at Test Time): RNNs with Expressive Hidden States](https://arxiv.org/abs/2407.04620). 
We **do not recommend training** with this codebase, because it is written in pure PyTorch without any systems optimization, so training will be slow, especially when the per-device batch size is small.


For training code, or to replicate results from our paper, please view our [JAX codebase](https://github.com/test-time-training/ttt-lm-jax). For inference kernels, or to replicate speed benchmarks from our paper, please view our [kernel implementations](https://github.com/test-time-training/ttt-lm-kernels).

## Abstract

Self-attention performs well in long context but has quadratic complexity. Existing RNN layers
have linear complexity, but their performance in long context is limited by the expressive power
of their hidden state. We propose a new class of sequence modeling layers with linear complexity
and an expressive hidden state. The key idea is to make the hidden state a machine learning
model itself, and the update rule a step of self-supervised learning. 

Since the hidden state is updated by training even on test sequences, our layers are called **Test-Time Training (TTT) layers**.
We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model
and a two-layer MLP respectively. 

In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from ttt.lm.pytorch import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS
import torch

In [5]:
model_id = "meta-llama/Llama-2-7b-hf"

# Quantization Config
DO_QUANTIZATION = False
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

original_model_params = dict(
    low_cpu_mem_usage=True, quantization_config=bnb_config, device_map=device.type
) if DO_QUANTIZATION else dict(low_cpu_mem_usage=True, device_map=device.type)

In [6]:
# Common Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [7]:
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig()
configuration

TTTConfig {
  "bos_token_id": 1,
  "conv_kernel": 4,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5504,
  "max_position_embeddings": 2048,
  "mini_batch_size": 16,
  "model_type": "ttt",
  "num_attention_heads": 32,
  "num_hidden_layers": 24,
  "pre_conv": false,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_theta": 10000.0,
  "scan_checkpoint_group_size": 0,
  "share_qk": false,
  "transformers_version": "4.44.0",
  "ttt_base_lr": 1.0,
  "ttt_layer_type": "linear",
  "use_cache": false,
  "use_gate": false,
  "vocab_size": 32000
}

##### Model Arch Comparison

In [8]:
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM(configuration)
model.to(device)
model.eval()

TTTForCausalLM(
  (model): TTTModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-23): 24 x Block(
        (seq_modeling_block): TTTLinear(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): RotaryEmbedding()
          (post_norm): LayerNorm((2048,), eps=1e-06, elementwise_affine=True)
        )
        (mlp): SwiGluMLP(
          (gate_proj): Linear(in_features=2048, out_features=5504, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5504, bias=False)
          (down_proj): Linear(in_features=5504, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (seq_norm): RMSNorm()
        (ffn_norm): RMSNorm()
      )
    )
    (norm): RMSNorm()

In [8]:
# For comparison with the normal llm model architecture
original = AutoModelForCausalLM.from_pretrained(model_id, cache_dir="./.cache", **original_model_params)
original.eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 172.00 MiB. GPU  has a total capacity of 15.90 GiB of which 113.75 MiB is free. Including non-PyTorch memory, this process has 15.79 GiB memory in use. Of the allocated memory 15.15 GiB is allocated by PyTorch, and 1.39 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

##### Model Output Comparison

In [10]:
input_text = "Greeting from TTT! Please generate a text for me only in Korean."

inf_params = dict(
    input_ids=tokenizer(input_text, return_tensors="pt").to(device).input_ids,
    max_length=50,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=0.7,
    num_return_sequences=1,
    pad_token_id=tokenizer.eos_token_id
)

In [11]:
# Inference using TTT
with torch.no_grad():
    out_ids = model.generate(**inf_params)
    print(*tokenizer.batch_decode(out_ids, skip_special_tokens=True))

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Greeting from TTT! Please generate a text for me only in Korean.ühr opport Ton--------UTEinners ultimately FF beyondലnete Nord чу NC Login)`, reasonableívían]],Per MS Einwo abolt nearcost Abd--+ Whitearlo vector gab


In [None]:
# Inference using the Original Model
with torch.no_grad():
    out_ids = model.generate(**inf_params)
    print(*tokenizer.batch_decode(out_ids, skip_special_tokens=True))

## 1. [Vision][PyTorch] Training Speed & Accuracy Comparison (1)
    - replace attention layer with TTT layer from ResNet-like model
    - start from random initialized weights

## 2. [Audio][PyTorch] Training Speed & Accuracy Comparison (2)
    - replace attention layer with TTT layer from ResNet-like model
    - start from random initialized weights
    - evaluate the music genre classification performance (using dataset below)
    - https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=data&dataSetSn=71544

## 3. [Vision][JAX] Training Speed & Accuracy Comparison (3)
    - replace attention layer with TTT layer from Vi-T model

### 3-1. Start from random initialized weights

In [1]:
import jax
jax.default_backend()

'gpu'

### 3-2. Use pretrained weights

## 4. [Vision][JAX] Weight Transferability Evaluation (1)
    - replace attention layer with TTT layer from a Pre-Trained Vi-T model and transfer weights

## 5. [NLP][JAX] Weight Transferability Evaluation (2)
    - replace attention layer with TTT layer from a Llama3.1 model and transfer weights
    - evaluate the performance via perplexity / likelihood

## 6. [NLP][JAX] Weight Transferability Evaluation (3)
    - replace attention layer with TTT layer from a Llama3.1 model and transfer weights
    - evaluate the sentence domain classification performance (using dataset below)
    - https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=data&dataSetSn=71633