# Example: ViT-Base/16 (vit-base-patch16-224) on RTX4090

In [None]:
import os
import sys
import pathlib

# Make the repository root importable. The notebook presumably runs from a
# sub-directory of the repo (the recipe is loaded via "../recipes/..."), so the
# root is the parent of the current working directory. This does not depend on
# any particular notebook filename.
PROJECT_ROOT = str(pathlib.Path.cwd().resolve().parent)
if PROJECT_ROOT not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(PROJECT_ROOT)

# Single device used for the models and latency benchmarks in this notebook.
DEVICE = "cuda:0"

In [None]:
import gc  # NOTE(review): not used anywhere else in this notebook
import torch
from huggingface_hub import notebook_login, hf_hub_download

# Log in to the Hugging Face Hub from the notebook so the hawada/* repos used
# below can be downloaded (presumably required if those repos are gated/private).
notebook_login()

## Build model with trainable gates

First, we use the gated model without any pruning as a latency baseline and compare it with a ready-to-use optimized model from the HawAda repo.

Then, we use the gates to export our own "slim" model.

In [None]:
from core.profiler import measure_latency_ms
from data.vision import build_imagenet_like_loaders, VisionDataConfig, _images_from_batch

from adapters.huggingface.vit import (
    ViTAdapter,
    ViTGatingConfig,
    ViTExportPolicy,
    ViTLatencyProxy,
    ViTProxyConfig,
    ViTGrid,
    vit_search_best_export,
    SlimViTForImageClassification
)

# Script to build config from a recipe
from examples.run_vit_optimize import build_from_recipe, make_vit_policy

# Build everything the notebook needs from the RTX4090 ViT recipe (models,
# data loaders, pruning policies, metadata); this also downloads the base model.
# The path is relative to the notebook's working directory.
RECIPE_PATH = "../recipes/RTX4090/vit_base_patch16_224.yaml"
pack = build_from_recipe(RECIPE_PATH)


In [None]:
# The recipe's "student" is the ViT with trainable gates attached.
student = pack["student"]
gated_model = student.to(DEVICE)

In [None]:
# Input resolution and recommended batch size, both taken from the recipe.
img_size, B = pack["img_size"], pack["batch_size"]

## Get a pruned version from the HawAda collection

In [None]:
# Ready-made pruned ViT from the HawAda collection on the Hugging Face Hub
# (optimised for RTX4090).
slim_model_repo = "hawada/vit-base-patch16-224-rtx4090-slim"
downloaded = SlimViTForImageClassification.from_pretrained(slim_model_repo)
slim_model = downloaded.to(DEVICE)

### Measure latency

In [None]:
print(f"\nStarting benchmarking with batch size = {B}...")

# Baseline: an unpruned ("keep-all") export of the gated model.
full_model = ViTAdapter.export_keepall(gated_model).to(DEVICE)
shape = (B, 3, img_size, img_size)  # NCHW input batch

mean_keep, p95_keep, _ = measure_latency_ms(full_model, shape, device=DEVICE)
mean_slim, p95_slim, _ = measure_latency_ms(slim_model, shape, device=DEVICE)

# (keep - slim) / keep is the relative latency *reduction*, not a speedup;
# the speedup factor is keep / slim. Report both, correctly labelled.
latency_reduction_pct = (mean_keep - mean_slim) / max(1e-6, mean_keep) * 100
speedup_x = mean_keep / max(1e-6, mean_slim)

print(f"Keep-all: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms | Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms | \n"
      f"Latency reduction={latency_reduction_pct:.2f}% | Speedup={speedup_x:.2f}x")

## Download pre-trained gates for ViT on RTX4090

To export your own pruned model, download the trained gates from the HawAda repo; they determine what to prune and how.

In [None]:
gated_model_repo = "hawada/vit-base-patch16-224-rtx4090-gated"

ckpt_path = hf_hub_download(gated_model_repo, "pytorch_model.bin")
# NOTE(review): weights_only=False unpickles arbitrary objects from a file
# downloaded from the Hub, which can execute code on load. If this checkpoint
# is a plain tensor state dict, prefer weights_only=True.
state_dict = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)

# strict=False tolerates key mismatches, so show which keys did not line up
# instead of silently continuing with partially loaded gates.
missing, unexpected = gated_model.load_state_dict(state_dict, strict=False)
print("missing:", len(missing), "unexpected:", len(unexpected))
for label, keys in (("missing", missing), ("unexpected", unexpected)):
    if keys:
        preview = ", ".join(keys[:10])
        more = f" ... (+{len(keys) - 10} more)" if len(keys) > 10 else ""
        print(f"  {label} keys: {preview}{more}")

In [None]:
# Show the pruning policies shipped with the recipe.
for title, key in (("Policy for the probes during training:", "probe_policy"),
                   ("\nPolicy for the final pruning:", "export_policy")):
    print(title, pack[key])

# Export a pruned model directly with the recipe's export policy.
# step=9999: a large step value, presumably so no warmup schedule applies —
# confirm in ViTAdapter.export_pruned.
export_policy = pack["export_policy"]
slim_model = ViTAdapter.export_pruned(gated_model, policy=export_policy, step=9999)
slim_model = slim_model.to(DEVICE)

In [None]:
# Run the grid search for the best export parameters.

num_heads = int(gated_model.config.num_attention_heads)

# Index the recipe explicitly: a missing section fails with a clear KeyError
# instead of an AttributeError / TypeError on None.
grid_cfg = pack["recipe"]["export"]["search"]
head_grid = tuple(grid_cfg["grid_multiple_groups"])  # candidates for ViTGrid.head_multiple_grid
ffn_grid = tuple(grid_cfg["ffn_snaps"])               # candidates for ViTGrid.ffn_snap_grid

search = vit_search_best_export(
    gated_model.to(DEVICE),
    export_fn=ViTAdapter.export_pruned,
    num_heads=num_heads,
    step=9999,  # no warmup
    batch_shape=(B, 3, img_size, img_size),
    # Search on the same device the model was just moved to (and that the
    # latency cells measure on), so candidates and benchmarks cannot diverge.
    device=DEVICE,
    make_policy=make_vit_policy,
    grid=ViTGrid(head_multiple_grid=head_grid, ffn_snap_grid=ffn_grid),
)

slim_model = search.best_model.to(DEVICE)
print("Best export params:", search.best_params)

### Measure latency

In [None]:
print(f"\nStarting benchmarking with batch size = {B}...")

# Baseline: an unpruned ("keep-all") export of the gated model (now carrying
# the downloaded gates), compared against the slim model found by the search.
full_model = ViTAdapter.export_keepall(gated_model).to(DEVICE)
shape = (B, 3, img_size, img_size)  # NCHW input batch

mean_keep, p95_keep, _ = measure_latency_ms(full_model, shape, device=DEVICE)
mean_slim, p95_slim, _ = measure_latency_ms(slim_model, shape, device=DEVICE)

# (keep - slim) / keep is the relative latency *reduction*, not a speedup;
# the speedup factor is keep / slim. Report both, correctly labelled.
latency_reduction_pct = (mean_keep - mean_slim) / max(1e-6, mean_keep) * 100
speedup_x = mean_keep / max(1e-6, mean_slim)

print(f"Keep-all: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms | Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms | \n"
      f"Latency reduction={latency_reduction_pct:.2f}% | Speedup={speedup_x:.2f}x")

# Distillation

This slim model needs fine-tuning for your downstream task. In this notebook, we use a 10-class classification head for ViT and the ImageNet dataset.

[!] This step does not have to run on the target GPU; you can use any other device.

In [None]:
from core.finetune import FinetuneConfig, finetune_student
from core.distill import KDConfig

teacher = pack["teacher"].to(DEVICE)  # Another instance of ViT for optional training / fine tuning

# Deliberately short run for the demo. To use the recipe's value instead:
#   ft_epochs = int(pack["recipe"]["finetune"]["epochs"])
ft_epochs = 1
print(f"\nStarting fine tuning for {ft_epochs} epochs...")

# Index required recipe sections explicitly so a missing key fails with a clear
# KeyError; `amp` keeps its original falsy default when absent.
recipe = pack["recipe"]
ft_cfg = FinetuneConfig(
    epochs=ft_epochs,
    lr=float(recipe["finetune"]["lr"]),
    kd=KDConfig(**recipe["trainer"]["kd"]),
    amp=bool(recipe["trainer"].get("amp")),
    device=pack["device"],
    log_every=200,
)

# Distil the teacher into the slim student on the recipe's data loaders.
slim = finetune_student(
    slim_model,
    teacher,
    pack["train_loader"],
    get_student_logits=pack["get_s"],
    get_teacher_logits=pack["get_t"],
    cfg=ft_cfg,
    val_loader=pack["val_loader"],
    save_best=True
)


# Optional: persist the fine-tuned model (there is no `args` object in this notebook).
# torch.save(slim, "vit_slim_finetune.pth")

Now you have your own ViT version, optimized for RTX4090 and fine-tuned on ImageNet.

The HawAda framework also lets you optimize the model for other GPUs. To do so, follow these steps:

* Create your own recipe
* Attach the HawAda adapter to the model
* Run gate training on your device (see the ResNet notebook)
* Export the pruned model once the gates are trained
* Run a grid search to choose the best shapes
* Run distillation (fine-tuning) for your downstream task