# Vi-T for Food11

## Import Libraries

In [None]:
%pip install transformers pygwalker wandb
%pip install "git+https://github.com/b-re-w/lattent.git#egg=lattent[pytorch]"

In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from transformers import ViTConfig, ViTModel, PretrainedConfig

from lattent import TTTForCausalLM, TTTCausalLMOutput

from torchvision import transforms, datasets, utils

from os import path, rename, mkdir, listdir

from tqdm.notebook import tqdm
import pygwalker as pyg
import wandb

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

Multiple frameworks detected ['pytorch', 'jax']. Using pytorch. Set LATTENT_FRAMEWORK environment variable to override.


In [2]:
# WandB Initialization
wandb.init(project="food11_pytorch")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbrew[0m ([33mbrew-research[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Check GPU Availability

In [3]:
!nvidia-smi

Sun Dec  8 09:55:11 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  On   | 00000000:04:00.0 Off |                    0 |
| N/A   40C    P0    34W / 250W |   9600MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   39C    P0    25W / 250W |      2MiB / 16280MiB |      0%      Defaul

In [4]:
# Set CUDA Device Number 0~7
DEVICE_NUM = 6

if torch.cuda.is_available():
    #device = torch.device(f"cuda:{DEVICE_NUM}")
    torch.cuda.set_device(DEVICE_NUM)
    device = torch.device(f"cuda")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1
#print(f"INFO: Using device - {device}")
print(f"INFO: Using device - {device}:{DEVICE_NUM}")

INFO: Using device - cuda:6


## Load DataSets

In [5]:
from typing import Callable, Optional

datasets.utils.tqdm = tqdm


class FoodImageDataset(datasets.ImageFolder):
    download_method = datasets.utils.download_and_extract_archive
    download_url = "https://www.kaggle.com/api/v1/datasets/download/trolukovich/food11-image-dataset"

    def __init__(self, root: str, force_download: bool = True, train: bool = True, valid: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
        self.download(root, force=force_download)

        if train:
            if valid:
                root = path.join(root, "validation")
            else:
                root = path.join(root, "training")
        else:
            root = path.join(root, "evaluation")

        super().__init__(root=root, transform=transform, target_transform=target_transform)

    @classmethod
    def download(cls, root: str, force: bool = False):
        if force or not path.isfile(path.join(root, "archive.zip")):
            cls.download_method(cls.download_url, download_root=root, extract_root=root, filename="archive.zip")
            print("INFO: Dataset archive downloaded and extracted.")
        else:
            print("INFO: Dataset archive found in the root directory. Skipping download.")

    @property
    def df(self) -> pd.DataFrame:
        return pd.DataFrame(dict(path=[d[0] for d in self.samples], label=[self.classes[lb] for lb in self.targets]))

In [6]:
# Image Resizing and Tensor Conversion
IMG_SIZE = (224, 224)
IMG_NORM = dict(  # ImageNet Normalization
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

#feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
#vit_converter = lambda img, lb: (feature_extractor(images=img, return_tensors="pt"), lb)

augmenter = transforms.Compose([
    transforms.Resize(IMG_SIZE),  # Resize Image
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),  # Convert Image to Tensor
    transforms.Normalize(**IMG_NORM)  # Normalization
])

resizer = transforms.Compose([
    transforms.Resize(IMG_SIZE),  # Resize Image
    transforms.ToTensor(),  # Convert Image to Tensor
    transforms.Normalize(**IMG_NORM)  # Normalization
])

In [7]:
DATA_ROOT = path.join(".", "data", "food11")

train_dataset = FoodImageDataset(root=DATA_ROOT, force_download=False, train=True, transform=augmenter)
valid_dataset = FoodImageDataset(root=DATA_ROOT, force_download=False, valid=True, transform=resizer)
test_dataset = FoodImageDataset(root=DATA_ROOT, force_download=False, train=False, transform=resizer)

print(f"INFO: Dataset loaded successfully. Number of samples - Train({len(train_dataset)}), Valid({len(valid_dataset)}), Test({len(test_dataset)})")

INFO: Dataset archive found in the root directory. Skipping download.
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Dataset loaded successfully. Number of samples - Train(9866), Valid(3430), Test(3347)


In [8]:
# Train Dataset Distribution
pyg.walk(train_dataset.df)

Box(children=(HTML(value='\n<div id="ifr-pyg-000628bf3e8161a1BRtU87INdzsxWLSe" style="height: auto">\n    <hea…

<pygwalker.api.pygwalker.PygWalker at 0x7f1e87d769c0>

In [9]:
# Valid Dataset Distribution
pyg.walk(train_dataset.df)

Box(children=(HTML(value='\n<div id="ifr-pyg-000628bf3e87aa0fxauIF7dWX5iJs2cK" style="height: auto">\n    <hea…

<pygwalker.api.pygwalker.PygWalker at 0x7f1e87c63440>

In [10]:
# Test Dataset Distribution
pyg.walk(train_dataset.df)

Box(children=(HTML(value='\n<div id="ifr-pyg-000628bf3e8d1332RmBIgfSV8bFLvZWs" style="height: auto">\n    <hea…

<pygwalker.api.pygwalker.PygWalker at 0x7f1e87cbedb0>

## DataLoader

In [11]:
# Set Batch Size
BATCH_SIZE = 64, 64, 10

In [12]:
MULTI_PROCESSING = True  # Set False if DataLoader is causing issues

from platform import system
if MULTI_PROCESSING and system() != "Windows":  # Multiprocess data loading is not supported on Windows
    import multiprocessing
    cpu_cores = multiprocessing.cpu_count()
    print(f"INFO: Number of CPU cores - {cpu_cores}")
else:
    cpu_cores = 0
    print("INFO: Using DataLoader without multi-processing.")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE[0], shuffle=True, num_workers=cpu_cores)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE[1], shuffle=False, num_workers=cpu_cores)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE[2], shuffle=False, num_workers=cpu_cores)

INFO: Number of CPU cores - 48


In [13]:
# Image Visualizer
def imshow(image_list, mean=IMG_NORM['mean'], std=IMG_NORM['std']):
    np_image = np.array(image_list).transpose((1, 2, 0))
    de_norm_image = np_image * std + mean
    plt.figure(figsize=(10, 10))
    plt.imshow(de_norm_image)

In [14]:
images, targets = next(iter(train_loader))
grid_images = utils.make_grid(images, nrow=8, padding=10)
imshow(grid_images)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-4.053115898461357e-09..1.0000000236034394].


## Define Model

In [None]:
common_config = dict(
    image_size=IMG_SIZE[0],
    num_channels=3,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    qkv_bias=True
)

In [None]:
vit_config = dict(
    **common_config,
    patch_size=16,
    encoder_stride=16,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=768
)

In [None]:
ttt_config = dict(  # 2 GPU setting
    **common_config,
    patch_size=7,
    encoder_stride=2,
    hidden_size=768//2,
    num_hidden_layers=6,
    num_attention_heads=8,
    intermediate_size=768//2
)

In [None]:
ttt_config = dict(  # 1 GPU setting
    **common_config,
    patch_size=4,
    encoder_stride=2,
    hidden_size=64,
    num_hidden_layers=6,
    num_attention_heads=8,
    intermediate_size=64,
    scan_checkpoint_group_size=4
)

In [None]:
ttt_config = dict(  # 1 GPU setting
    **common_config,
    patch_size=4,           # Vi-T의 1/4
    encoder_stride=4,       # Vi-T의 1/4
    hidden_size=192,        # Vi-T의 1/4
    num_hidden_layers=6,    # Vi-T의 1/2
    num_attention_heads=8,
    intermediate_size=192,  # Vi-T의 1/4,
    scan_checkpoint_group_size=4
)

In [None]:
# ttt_config = dict(
#     patch_size=8,          # ViT의 1/2
#     encoder_stride=8,      # ViT의 1/2
#     hidden_size=192,       # ViT의 1/4
#     num_hidden_layers=6,   # ViT의 1/2
#     num_attention_heads=8,
#     intermediate_size=192  # ViT의 1/4
# )

### Normal Vi-T

In [None]:
class ViTImageClassifier(nn.Module):
    def __init__(self, config: ViTConfig, num_classes: int):
        super().__init__()
        self.config = config
        self.vit = ViTModel(config=config)
        self.fc = nn.Linear(self.config.hidden_size, num_classes)

    def forward(self, x):
        out = self.vit(x)
        pooled = out.pooler_output  # [batch_size, hidden_size]
        logits = self.fc(pooled)  # [batch_size, num_classes]
        return logits

In [None]:
# Initialize Model
USE_NORMAL = True
model = ViTImageClassifier(config=ViTConfig(**vit_config), num_classes=len(train_dataset.classes))
#model.to(device)
model

### TTT Vi-T

In [None]:
from lattent.models import ttt_pytorch


def scan(f, init, xs, out, checkpoint_group=0):
    """Optimized scan function with checkpoint groups"""
    carry = init
    if isinstance(xs, dict):
        num_items = len(next(iter(xs.values())))
    else:
        num_items = len(xs[0])

    @torch.compile(mode="reduce-overhead")
    def optimized_scan_fn(inputs, init_carry, start_idx, end_idx):
        current_carry = init_carry
        results = []

        for i in range(start_idx, end_idx):
            if isinstance(inputs, dict):
                current_input = {k: inputs[k][i] for k in inputs.keys()}
            else:
                current_input = inputs[i]

            current_carry, y = f(current_carry, current_input)
            results.append(y)

        return current_carry, torch.stack(results)

    # checkpoint_group 활용
    if checkpoint_group > 0:
        ckpt_every_n = max(1, num_items // checkpoint_group)

        for k in range(0, num_items, ckpt_every_n):
            end_idx = min(k + ckpt_every_n, num_items)
            # gradient checkpointing 적용
            carry = torch.utils.checkpoint.checkpoint(
                optimized_scan_fn,
                xs, carry, k, end_idx,
                use_reentrant=False
            )
            # 결과를 out에 복사
            _, results = optimized_scan_fn(xs, carry, k, end_idx)
            for i, res in enumerate(results):
                out[k + i] = res
    else:
        # checkpoint 없이 한번에 처리
        carry, results = optimized_scan_fn(xs, carry, 0, num_items)
        for i, res in enumerate(results):
            out[i] = res

    return carry, out


ttt_pytorch.scan = scan

In [None]:
from lattent.models import ttt_pytorch
from lattent import TTTLinear, TTTCache
ln_fused_l2_bwd, ln_fwd, tree_map = ttt_pytorch.ln_fused_l2_bwd, ttt_pytorch.ln_fwd, ttt_pytorch.tree_map


class EfficientTTTLinear(TTTLinear):
    def ttt(self, inputs, mini_batch_size, last_mini_batch_params_dict, cache_params: Optional[TTTCache] = None):
        if mini_batch_size is None:
            mini_batch_size = self.mini_batch_size

        # in this case, we are decoding
        if last_mini_batch_params_dict is None and cache_params is not None:
            last_mini_batch_params_dict = cache_params.ttt_params_to_dict(self.layer_idx)

        # [B, num_heads, num_mini_batch, mini_batch_size, head_dim]
        B = inputs["XV"].shape[0]
        num_mini_batch = inputs["XV"].shape[2]
        L = inputs["XV"].shape[2] * inputs["XV"].shape[3]
        device = inputs["XV"].device
        dtype = inputs["XV"].dtype

        use_dual_form = cache_params is None or mini_batch_size % self.mini_batch_size == 0

        def compute_mini_batch(params_dict, inputs):
            # [B, nh, f, f], nh=num_heads, f=head_dim
            W1_init = params_dict["W1_states"]
            # [B, nh, 1, f]
            b1_init = params_dict["b1_states"]

            # [B,nh,K,f], K=mini_batch_size
            XQ_mini_batch = inputs["XQ"]
            XV_mini_batch = inputs["XV"]
            XK_mini_batch = inputs["XK"]
            # [B, nh, K, 1]
            eta_mini_batch = inputs["eta"]
            token_eta_mini_batch = inputs["token_eta"]
            ttt_lr_eta_mini_batch = inputs["ttt_lr_eta"]

            X1 = XK_mini_batch
            # [B,nh,K,f] @ [B,nh,f,f] -> [B,nh,K,f]
            Z1 = X1 @ W1_init + b1_init
            reconstruction_target = XV_mini_batch - XK_mini_batch

            ln_weight = self.ttt_norm_weight.reshape(self.num_heads, 1, self.head_dim)
            ln_bias = self.ttt_norm_bias.reshape(self.num_heads, 1, self.head_dim)
            # [B,nh,K,f]
            grad_l_wrt_Z1 = ln_fused_l2_bwd(Z1, reconstruction_target, ln_weight, ln_bias)

            if use_dual_form:
                # [B,nh,K,K]
                Attn1 = torch.tril(XQ_mini_batch @ X1.transpose(-2, -1))
                # [B,nh,1,f] - [B,nh,K,K] @ [B,nh,K,f] -> [B,nh,K,f]
                b1_bar = b1_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z1
                # [B,nh,K,f] @ [B,nh,f,f] - ([B,nh,K,1] * [B,nh,K,K]) @ [B,nh,K,f] + [B,nh,K,f]
                Z1_bar = XQ_mini_batch @ W1_init - (eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar

                last_eta_mini_batch = eta_mini_batch[:, :, -1, :, None]
                # [B,nh,f,f] - [B,nh,f,K] @ [B,nh,K,f]
                W1_last = W1_init - (last_eta_mini_batch * X1).transpose(-1, -2) @ grad_l_wrt_Z1
                # [B,nh,1,f]
                b1_last = b1_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z1, dim=-2, keepdim=True)
                grad_W1_last = torch.zeros_like(W1_last)
                grad_b1_last = torch.zeros_like(b1_last)
            else:
                ttt_lr_eta_mini_batch = torch.broadcast_to(
                    ttt_lr_eta_mini_batch,
                    (
                        *ttt_lr_eta_mini_batch.shape[:2],
                        mini_batch_size,
                        mini_batch_size,
                    ),
                )

                # [B, nh, K, f, f]
                grad_W1 = torch.einsum("bhki,bhkj->bhkij", X1, grad_l_wrt_Z1)
                grad_W1 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W1)
                grad_W1 = grad_W1 + params_dict["W1_grad"].unsqueeze(2)
                # [B, nh, K, f]
                grad_b1 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z1)
                grad_b1 = grad_b1 + params_dict["b1_grad"]

                W1_bar = W1_init.unsqueeze(2) - grad_W1 * token_eta_mini_batch.unsqueeze(-1)
                b1_bar = b1_init - grad_b1 * token_eta_mini_batch

                # [B, nh, K, 1, f] @ [B, nh, K, f, f]
                Z1_bar = (XQ_mini_batch.unsqueeze(3) @ W1_bar).squeeze(3) + b1_bar

                W1_last = W1_bar[:, :, -1]
                b1_last = b1_bar[:, :, -1:]
                grad_W1_last = grad_W1[:, :, -1]
                grad_b1_last = grad_b1[:, :, -1:]

            Z1_bar = ln_fwd(Z1_bar, ln_weight, ln_bias)
            XQW_mini_batch = XQ_mini_batch + Z1_bar

            last_param_dict = {
                "W1_states": W1_last,
                "b1_states": b1_last,
                "W1_grad": grad_W1_last,
                "b1_grad": grad_b1_last,
            }
            return last_param_dict, XQW_mini_batch

        if last_mini_batch_params_dict is not None:
            init_params_dict = last_mini_batch_params_dict
        else:
            init_params_dict = {
                "W1_states": torch.tile(self.W1.unsqueeze(0), dims=(B, 1, 1, 1)),
                "b1_states": torch.tile(self.b1.unsqueeze(0), dims=(B, 1, 1, 1)),
            }
            init_params_dict.update(W1_grad=torch.zeros_like(init_params_dict["W1_states"]))
            init_params_dict.update(b1_grad=torch.zeros_like(init_params_dict["b1_states"]))

        # [B,num_heads, num_mini_batch, mini_batch_size, f] -> [num_mini_batch, B, num_heads, mini_batch_size, f]
        inputs = tree_map(lambda x: x.permute(2, 0, 1, 3, 4), inputs)

        # Initialize batch_params_dict and XQW_last
        batch_params_dict = init_params_dict
        XQW_last = None

        # Process mini-batches sequentially, only keeping the last result
        for i in range(num_mini_batch - 1):
            current_inputs = tree_map(lambda x: x[i], inputs)
            batch_params_dict, _ = compute_mini_batch(batch_params_dict, current_inputs)

        # Handle last mini-batch with padding if needed
        last_inputs = tree_map(lambda x: x[-1], inputs)
        current_size = last_inputs["XQ"].shape[-2]

        if current_size != mini_batch_size:
            # Pad last mini-batch inputs
            last_inputs = tree_map(
                lambda x: F.pad(x, (0, 0, 0, mini_batch_size - current_size), mode='constant', value=0),
                last_inputs
            )

        # Process final mini-batch
        batch_params_dict, XQW_last = compute_mini_batch(batch_params_dict, last_inputs)

        if cache_params is not None:
            cache_params.update(batch_params_dict, self.layer_idx, L)

        # Reshape final output (keeping the padding)
        B, H, K, D = XQW_last.shape
        XQW_last = XQW_last.permute(0, 2, 1, 3)  # [B, K, num_heads, head_dim]
        XQW_last = XQW_last.reshape(B, K, self.width)  # [B, K, C]

        return XQW_last, batch_params_dict


class EfficientTTTBlock(ttt_pytorch.Block):
    def __init__(self, config: ttt_pytorch.TTTConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.config = config

    def forward(self, hidden_states, attention_mask=None, position_ids=None, cache_params=None):
        if self.pre_conv:
            residual = hidden_states
            hidden_states = self.conv(hidden_states, cache_params=cache_params)
            hidden_states = residual + hidden_states

        # TTT Layer
        residual = hidden_states[:, -self.config.mini_batch_size:]  # 마지막 미니배치만큼만 residual connection
        hidden_states = self.seq_norm(hidden_states)
        hidden_states = self.seq_modeling_block(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_params=cache_params,
        )
        hidden_states = residual + hidden_states  # 이제 크기가 맞을 것입니다

        # Feed-Forward Network
        residual = hidden_states
        hidden_states = self.ffn_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


#ttt_pytorch.TTTLinear = EfficientTTTLinear
#ttt_pytorch.Block = EfficientTTTBlock

In [None]:
class TTTVisionConfig(PretrainedConfig):
    """Vision TTT configuration."""

    model_type = "vision_ttt"

    def __init__(
            self,
            image_size=224,
            patch_size=16,
            num_channels=3,
            num_classes=1000,
            hidden_size=768,
            intermediate_size=3072,
            num_hidden_layers=12,
            num_attention_heads=12,
            hidden_act="gelu",
            initializer_range=0.02,
            rms_norm_eps=1e-6,
            pretraining_tp=1,
            use_cache=True,
            rope_theta=10000.0,
            mini_batch_size=16,
            use_gate=False,
            share_qk=False,
            ttt_layer_type="linear",
            ttt_base_lr=1.0,
            pre_conv=True,
            conv_kernel=4,
            scan_checkpoint_group_size=0,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels

        self.num_classes = num_classes
        self.vocab_size = num_classes

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        self.use_gate = use_gate
        self.share_qk = share_qk
        self.ttt_layer_type = ttt_layer_type
        self.ttt_base_lr = ttt_base_lr
        self.mini_batch_size = mini_batch_size

        self.pre_conv = pre_conv
        self.conv_kernel = conv_kernel
        self.scan_checkpoint_group_size = scan_checkpoint_group_size

        # Vision-specific attributes
        self.num_patches = (image_size // patch_size) ** 2

In [None]:
from einops.layers.torch import Rearrange


class PatchEmbedding(nn.Module):
    """Converts images into patches and projects them into the model dimension."""
    def __init__(self, config: TTTVisionConfig):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.num_patches = config.num_patches
        self.num_channels = config.num_channels

        # Patch splitting and flattening
        patch_dim = self.num_channels * self.patch_size * self.patch_size

        self.to_patch_embedding = nn.Sequential(
            # Split image into patches and flatten
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size),
            nn.LayerNorm(patch_dim),
            # Project to hidden dimension
            nn.Linear(patch_dim, config.hidden_size),
            nn.LayerNorm(config.hidden_size),
        )

        # Learnable position embeddings for patches
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, config.hidden_size))

        # [CLS] token embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))

        # Dropout
        self.dropout = nn.Dropout(0.1)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        B = pixel_values.shape[0]

        # Convert image to patches -> (batch_size, num_patches, hidden_size)
        x = self.to_patch_embedding(pixel_values)

        # Add position embeddings
        x = x + self.pos_embedding

        # Add CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        return self.dropout(x)

In [None]:
from transformers.utils import ModelOutput
from typing import *


class TTTForVisionCausalLM(TTTForCausalLM):
    config_class = TTTVisionConfig

    def __init__(self, config: TTTVisionConfig):
        super().__init__(config)
        self.patch_embed = PatchEmbedding(config)

        # Initialize weights
        self.post_init()

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,  # [batch_size, seq_len]
            attention_mask: Optional[torch.Tensor] = None,  # [batch_size, seq_len]
            position_ids: Optional[torch.LongTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,  # [batch_size, seq_len, hidden_size]
            pixel_values: Optional[torch.FloatTensor] = None,  # [batch_size, channels, height, width]
            cache_params: Optional[Any] = None,
            labels: Optional[torch.LongTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            use_cache: Optional[bool] = None,
            **kwargs
    ) -> Union[Tuple, TTTCausalLMOutput]:
        batch_size = None
        if pixel_values is not None:
            batch_size = pixel_values.shape[0]
        elif input_ids is not None:
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            batch_size = inputs_embeds.shape[0]

        if batch_size is None:
            raise ValueError("No valid input provided to determine batch size")

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        final_attention_mask = None
        final_embeds = None

        # 이미지 처리
        if pixel_values is not None:
            # [batch_size, num_patches + 1, hidden_size]
            image_embeds = self.patch_embed(pixel_values)
            image_attention_mask = torch.ones(
                (batch_size, image_embeds.shape[1]),
                dtype=torch.long,
                device=image_embeds.device
            )
            final_embeds = image_embeds  # [batch_size, num_patches + 1, hidden_size]
            final_attention_mask = image_attention_mask  # [batch_size, num_patches + 1]

        # 텍스트 처리 (선택적)
        if input_ids is not None or inputs_embeds is not None:
            if inputs_embeds is None:
                # [batch_size, seq_len] -> [batch_size, seq_len, hidden_size]
                inputs_embeds = self.get_input_embeddings()(input_ids)

            text_attention_mask = attention_mask if attention_mask is not None else torch.ones(
                (batch_size, inputs_embeds.shape[1]),
                dtype=torch.long,
                device=inputs_embeds.device
            )

            # 이미지가 있는 경우 텍스트와 연결
            if final_embeds is not None:
                # [batch_size, num_patches + 1 + seq_len, hidden_size]
                final_embeds = torch.cat([final_embeds, inputs_embeds], dim=1)
                # [batch_size, num_patches + 1 + seq_len]
                final_attention_mask = torch.cat([final_attention_mask, text_attention_mask], dim=1)
            else:
                final_embeds = inputs_embeds  # [batch_size, seq_len, hidden_size]
                final_attention_mask = text_attention_mask  # [batch_size, seq_len]

        if final_embeds is None:
            raise ValueError("Either pixel_values or input_ids/inputs_embeds must be provided")

        # TTT 모델 실행
        outputs = self.model(
            input_ids=None,
            attention_mask=final_attention_mask,
            position_ids=position_ids,
            inputs_embeds=final_embeds,
            cache_params=cache_params,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
        )

        # [batch_size, total_seq_len, hidden_size]
        hidden_states = outputs[0]

        # 로짓 계산
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)  # [batch_size, total_seq_len, vocab_size]
        else:
            # [batch_size, total_seq_len, vocab_size]
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        # Loss 계산
        loss = None
        if labels is not None and (input_ids is not None or inputs_embeds is not None):
            text_start = self.config.num_patches + 1 if pixel_values is not None else 0
            # [batch_size, seq_len - 1, vocab_size]
            shift_logits = logits[:, text_start:-1, :].contiguous()
            # [batch_size, seq_len - 1]
            shift_labels = labels[:, 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)  # [batch_size * (seq_len - 1), vocab_size]
            shift_labels = shift_labels.view(-1)  # [batch_size * (seq_len - 1)]
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return TTTCausalLMOutput(
            loss=loss,
            logits=logits,  # [batch_size, total_seq_len, vocab_size]
            cache_params=outputs.cache_params,
            hidden_states=outputs.hidden_states
        )

    def prepare_inputs_for_generation(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            cache_params: Optional[Any] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            pixel_values: Optional[torch.FloatTensor] = None,
            **kwargs
    ):
        model_inputs = {}

        # 캐시가 있는 경우 마지막 토큰만 처리
        if cache_params is not None:
            if input_ids is not None:
                input_ids = input_ids[:, -1].unsqueeze(-1)
            if attention_mask is not None:
                attention_mask = attention_mask[:, -1].unsqueeze(-1)
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:, :]

        # 입력 설정
        if inputs_embeds is not None and cache_params is None:
            model_inputs["inputs_embeds"] = inputs_embeds
        elif input_ids is not None:
            model_inputs["input_ids"] = input_ids

        # 첫 forward pass에서만 이미지 처리
        if pixel_values is not None and cache_params is None:
            model_inputs["pixel_values"] = pixel_values

        model_inputs.update(
            {
                "cache_params": cache_params,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )

        return model_inputs

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            **kwargs
    ) -> Dict[str, Any]:
        model_kwargs["cache_params"] = outputs.get("cache_params", None)

        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            if attention_mask is not None:
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                    dim=-1,
                )

        # 첫 번째 forward pass 이후에는 이미지 처리 불필요
        if "pixel_values" in model_kwargs:
            del model_kwargs["pixel_values"]

        return model_kwargs

In [None]:
# Initialize Model
USE_NORMAL = False
model = TTTForVisionCausalLM(config=TTTVisionConfig(
    **ttt_config, num_classes=len(train_dataset.classes), mini_batch_size=56
))
#model = nn.DataParallel(model, device_ids=[DEVICE_NUM, DEVICE_NUM+1])
model.to(device)

## Training Loop

In [None]:
from IPython.display import display
import ipywidgets as widgets

# Interactive Loss Plot Update
def create_plot():
    train_losses, valid_losses = [], []

    # Enable Interactive Mode
    plt.ion()

    # Loss Plot Setting
    fig, ax = plt.subplots(figsize=(6, 2))
    train_line, = ax.plot(train_losses, label="Train Loss", color="purple")
    valid_line, = ax.plot(valid_losses, label="Valid Loss", color="red")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Cross Entropy Loss")

    # Display Plot
    plot = widgets.Output()
    display(plot)

    def update_plot(train_loss=None, valid_loss=None):
        if train_loss is not None:
            train_losses.append(train_loss)
        if valid_loss is not None:
            valid_losses.append(valid_loss)
        train_line.set_ydata(train_losses)
        train_line.set_xdata(range(len(train_losses)))
        valid_line.set_ydata(valid_losses)
        valid_line.set_xdata(range(len(valid_losses)))
        ax.relim()
        ax.autoscale_view()
        with plot:
            plot.clear_output(wait=True)
            display(fig)

    return update_plot

In [None]:
def avg(lst):
    try:
        return sum(lst) / len(lst)
    except ZeroDivisionError:
        return 0

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 50
LEARNING_RATE = 1e-4, 1e-6
WEIGHT_DECAY = 0.05
MAX_GRAD_NORM = 1.0
USE_CACHE = False

criterion = nn.CrossEntropyLoss()
wandb.watch(model, criterion, log="all", log_freq=10)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE[0], weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LEARNING_RATE[1])

### Normal Vi-T

In [None]:
train_length, valid_length = map(len, (train_loader, valid_loader))
if not USE_NORMAL:
    raise ValueError("Model is not set to Normal Vi-T. Please set USE_NORMAL to True.")

epochs = tqdm(range(EPOCHS), desc="Running Epochs")
with (tqdm(total=train_length, desc="Training") as train_progress,
        tqdm(total=valid_length, desc="Validation") as valid_progress):  # Set up Progress Bars
    update = create_plot()  # Create Loss Plot

    for epoch in epochs:
        train_progress.reset(total=train_length)
        valid_progress.reset(total=valid_length)

        train_acc, train_loss = 0, 0

        # Training
        model.train()
        for i, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()

            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.item() / train_length
            corrects = (torch.max(outputs, 1)[1] == targets.data).sum()
            train_acc += corrects / len(train_dataset)

            train_progress.update(1)
            if i != train_length-1: wandb.log({'Acc': corrects/len(inputs)*100, 'Loss': loss.item()})
            print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{i+1:2}/{train_length}], Acc: {corrects/len(inputs):.6%}, Loss: {loss.item():.6f}", end="")

        print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{train_length}/{train_length}], Acc: {train_acc:.6%}, Loss: {train_loss:.6f}", end="")
        val_acc, val_loss = 0, 0

        # Validation
        model.eval()
        with torch.no_grad():
            for inputs, targets in valid_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)

                val_loss += criterion(outputs, targets).item() / valid_length
                val_acc += (torch.max(outputs, 1)[1] == targets.data).sum() / (len(inputs) * valid_length)
                valid_progress.update(1)

        update(train_loss=train_loss, valid_loss=val_loss)
        wandb.log({'Train Acc': train_acc*100, 'Train Loss': train_loss, 'Val Acc': val_acc*100, 'Val Loss': val_loss})
        print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{train_length}/{train_length}], Acc: {train_acc:.6%}, Loss: {train_loss:.6f}, Valid Acc: {val_acc:.6%}, Valid Loss: {val_loss:.6f}", end="\n" if (epoch+1) % 5 == 0 or (epoch+1) == EPOCHS else "")

In [None]:
if not path.isdir(path.join(".", "models")):
    import os
    os.mkdir(path.join(".", "models"))

# Model Save
save_path = path.join(".", "models", f"normal_vit_model.pt")
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

### TTT Vi-T

In [None]:
train_length, valid_length = map(len, (train_loader, valid_loader))

epochs = tqdm(range(EPOCHS), desc="Running Epochs")
with (tqdm(total=train_length, desc="Training") as train_progress,
    tqdm(total=valid_length, desc="Validation") as valid_progress):  # Set up Progress Bars
    update = create_plot()  # Create Loss Plot

    for epoch in epochs:
        train_progress.reset(total=train_length)
        valid_progress.reset(total=valid_length)

        train_acc, train_loss, val_acc, val_loss = [], [], [], []

        # Training
        model.train()
        for i, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()

            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(pixel_values=inputs, labels=targets, use_cache=False).logits[:, -1, :]

            loss = criterion(outputs, targets)
            loss.backward()
            if MAX_GRAD_NORM > 0:
                nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            scheduler.step()

            train_loss.append(loss.item())
            train_acc.append((torch.max(outputs, 1)[1] == targets.data).sum().item() / len(inputs))

            train_progress.update(1)
            if i != train_length-1: wandb.log({'Acc': avg(train_acc)*100, 'Loss': avg(train_loss)})
            print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{i+1:2}/{train_length}], Acc: {avg(train_acc):.6%}, Loss: {avg(train_loss):.6f}", end="")

        # Validation
        model.eval()
        cache_params = None  # Save the training cache state
        with torch.no_grad():
            for inputs, targets in valid_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                if USE_CACHE:
                    outputs = model(
                        pixel_values=inputs,
                        labels=targets,  # Pass labels to compute model's internal loss
                        cache_params=cache_params,
                        use_cache=True
                    )
                else:
                    outputs = model(pixel_values=inputs, labels=targets, use_cache=False)
                outputs, cache_params = outputs.logits[:, -1, :], outputs.cache_params

                val_loss.append(criterion(outputs, targets).item())
                val_acc.append((torch.max(outputs, 1)[1] == targets.data).sum().item() / len(inputs))
                valid_progress.update(1)

        update(train_loss=avg(train_loss), valid_loss=avg(val_loss))
        wandb.log({'Train Acc': avg(train_acc)*100, 'Train Loss': avg(train_loss), 'Val Acc': avg(val_acc)*100, 'Val Loss': avg(val_loss)})
        print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{train_length}/{train_length}], Acc: {avg(train_acc):.6%}, Loss: {avg(train_loss):.6f}, Valid Acc: {avg(val_acc):.6%}, Valid Loss: {avg(val_loss):.6f}", end="\n" if (epoch+1) % 5 == 0 or (epoch+1) == EPOCHS else "")

In [None]:
if not path.isdir(path.join(".", "models")):
    import os
    os.mkdir(path.join(".", "models"))

# Model Save
save_path = path.join(".", "models", f"ttt_vit_model.pt")
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

# Model Evaluation

In [None]:
# Load Model
model_id = "normal_vit_model"

model.load_state_dict(torch.load(path.join(".", "models", f"{model_id}.pt")))
model.to(device)

In [None]:
corrects = 0
test_length = len(test_dataset)

model.eval()
with torch.no_grad():
    for inputs, targets in tqdm(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        corrects += (preds == targets.data).sum()
        print(f"Model Accuracy: {corrects/test_length:%}", end="\r")

In [None]:
corrects = 0
test_length = len(test_dataset)

model.eval()
with torch.no_grad():
    for inputs, targets in tqdm(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(pixel_values=inputs, labels=targets, use_cache=False)
        _, preds = torch.max(outputs.logits[:, -1, :], 1)
        corrects += (preds == targets.data).sum()
        print(f"Model Accuracy: {corrects/test_length:%}", end="\r")