In [1]:
# Example: optimize ViT-base-16 for RTX4090

In [2]:
import os, sys, pathlib
# Make the project packages imported further down (core, data, adapters,
# examples) resolvable by putting the parent of the working directory on
# sys.path. Only the working directory matters here, so it is used directly
# instead of going through an unrelated notebook filename.
REPO_ROOT = pathlib.Path.cwd().resolve().parent
if str(REPO_ROOT) not in sys.path:  # re-running this cell must not add duplicates
    sys.path.append(str(REPO_ROOT))

# Device every model and latency measurement in this notebook is placed on.
DEVICE = "cuda:0"

In [3]:
import gc
import torch
from huggingface_hub import notebook_login, hf_hub_download

# notebook_login()

In [4]:
# Hugging Face Hub repositories used by this notebook
slim_model_repo, gated_model_repo = (
    "hawada/vit-base-patch16-224-rtx4090-slim",   # pruned model, loaded with from_pretrained
    "hawada/vit-base-patch16-224-rtx4090-gated",  # checkpoint holding the pre-trained gates
)

In [5]:
from core.profiler import measure_latency_ms
from data.vision import build_imagenet_like_loaders, VisionDataConfig, _images_from_batch

from adapters.huggingface.vit import (
    ViTAdapter,
    ViTGatingConfig,
    ViTExportPolicy,
    ViTLatencyProxy,
    ViTProxyConfig,
    ViTGrid,
    vit_search_best_export,
    SlimViTForImageClassification
)

# Script to build config from a recipe
from examples.run_vit_optimize import build_from_recipe, make_vit_policy

# Build the working "pack" (models, policies, loaders, metadata) from the
# RTX4090 recipe; per the original note this also downloads the base model.
recipe_path = "../recipes/RTX4090/vit_base_patch16_224.yaml"
pack = build_from_recipe(recipe_path)


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


Latency proxy scale set to: 9.856479e-09
Keep-all latency: 169.4334 ms on batch size = 64
Target latency = 169.4334 * 0.60 ~ 101.6600


In [6]:
# Show which entries build_from_recipe returned; later cells index into these.
pack.keys()

dict_keys(['student', 'teacher', 'student_head', 'teacher_head', 'adapter', 'export_policy', 'probe_policy', 'proxy', 'trainer_cfg', 'train_loader', 'val_loader', 'get_s', 'get_t', 'img_size', 'batch_size', 'device', 'recipe'])

In [7]:
# The "student" entry is the ViT with gates attached; place it on DEVICE.
gated_model = pack["student"]
gated_model = gated_model.to(DEVICE)

In [8]:
# Image size and recommended batch size, both taken from the pack
img_size, B = pack["img_size"], pack["batch_size"]

In [None]:
# Fetch the already-pruned model (optimised for RTX4090) from the Hugging Face
# Hub, then move it onto DEVICE.
slim_model = SlimViTForImageClassification.from_pretrained(slim_model_repo)
slim_model = slim_model.to(DEVICE)

In [10]:
# Latency comparison: keep-all export of the gated model vs. the slim model,
# both measured on the same input shape and device.
print(f"\nStarting benchmarking with batch size = {B}...")

shape = (B, 3, img_size, img_size)
full_model = ViTAdapter.export_keepall(gated_model)
full_model = full_model.to(DEVICE)

# measure_latency_ms returns three values; only the first two are reported.
keep_stats = measure_latency_ms(full_model, shape, device=DEVICE)
slim_stats = measure_latency_ms(slim_model, shape, device=DEVICE)
mean_keep, p95_keep, _ = keep_stats
mean_slim, p95_slim, _ = slim_stats

# "Speedup" here is the relative reduction of the mean latency, in percent.
summary = (
    f"Keep-all: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms | Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms | \n"
    f"Speedup={(mean_keep-mean_slim)/max(1e-6,mean_keep)*100:.2f}%"
)
print(summary)


Starting benchmarking with batch size = 64...
Keep-all: mean=170.051ms p95=184.496ms | Slim: mean=148.580ms p95=162.181ms | 
Speedup=12.63%


In [11]:
# Get your own slim model with custom export parameters

In [12]:
# Download pre-trained gates for ViT on RTX4090 and load them into the gated model

ckpt_path  = hf_hub_download(gated_model_repo, "pytorch_model.bin")
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects
# from a downloaded file. If this checkpoint is a plain tensor state dict,
# weights_only=True is the safer choice -- confirm against the checkpoint.
state_dict = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)

# strict=False tolerates key mismatches, so report them explicitly instead of
# only counting them.
missing, unexpected = gated_model.load_state_dict(state_dict, strict=False)
print("missing:", len(missing), "unexpected:", len(unexpected))
if missing:
    print("  missing keys:", list(missing))
if unexpected:
    print("  unexpected keys:", list(unexpected))

missing: 0 unexpected: 0


In [13]:
# Display the two policies shipped in the pack, then prune the gated model
# with the final export policy.
probe_policy = pack["probe_policy"]
export_policy = pack["export_policy"]
print("Policy for the probes during training:", probe_policy)
print("\nPolicy for the final pruning:", export_policy)

slim_model = ViTAdapter.export_pruned(gated_model, policy=export_policy, step=9999)
slim_model = slim_model.to(DEVICE)

Policy for the probes during training: ExportPolicy(warmup_steps=0, rounding=Rounding(floor_groups=1, multiple_groups=1, min_keep_ratio=0.0))

Policy for the final pruning: ExportPolicy(warmup_steps=150, rounding=Rounding(floor_groups=1, multiple_groups=1, min_keep_ratio=0.0))


In [14]:
# Run the grid search for the best export parameters

num_heads = int(gated_model.config.num_attention_heads)

# Candidate values come from the recipe's export -> search section.
grid_cfg = pack["recipe"].get("export").get("search")
head_grid = tuple(grid_cfg.get("grid_multiple_groups"))
ffn_grid  = tuple(grid_cfg.get("ffn_snaps"))

search = vit_search_best_export(
    gated_model.to(DEVICE),
    export_fn=ViTAdapter.export_pruned,
    num_heads=num_heads,
    step=9999,  # no warmup
    batch_shape=(B, 3, img_size, img_size),
    # Use the same device the model was just moved to (and that the benchmark
    # cells use), so changing DEVICE cannot split model and measurement across
    # two devices.
    device=DEVICE,
    make_policy=make_vit_policy,
    grid=ViTGrid(head_multiple_grid=head_grid, ffn_snap_grid=ffn_grid),
)


# Keep the best candidate as the slim model used by the following cells.
slim_model = search.best_model.to(DEVICE)
print("Best export params:", search.best_params)

[0/11] head_multiple 2 | ffn_snap 1 | mean_ms = 143.0825799560547
Best export params: {'head_multiple': 2, 'ffn_snap': 1}


In [15]:
# Repeat the latency comparison, this time against the slim model produced in
# this notebook rather than the one downloaded from the Hub.
print(f"\nStarting benchmarking with batch size = {B}...")

shape = (B, 3, img_size, img_size)
full_model = ViTAdapter.export_keepall(gated_model)
full_model = full_model.to(DEVICE)

# measure_latency_ms returns three values; only the first two are reported.
keep_stats = measure_latency_ms(full_model, shape, device=DEVICE)
slim_stats = measure_latency_ms(slim_model, shape, device=DEVICE)
mean_keep, p95_keep, _ = keep_stats
mean_slim, p95_slim, _ = slim_stats

# "Speedup" here is the relative reduction of the mean latency, in percent.
summary = (
    f"Keep-all: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms | Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms | \n"
    f"Speedup={(mean_keep-mean_slim)/max(1e-6,mean_keep)*100:.2f}%"
)
print(summary)


Starting benchmarking with batch size = 64...
Keep-all: mean=172.554ms p95=185.943ms | Slim: mean=145.194ms p95=156.236ms | 
Speedup=15.86%


In [16]:
# The slim model still needs fine-tuning for the downstream task

In [17]:
from core.finetune import FinetuneConfig, finetune_student
from core.distill import KDConfig

teacher = pack["teacher"].to(DEVICE) # Another instance of ViT for optional training / fine tuning

# Shortened to a single epoch for this example. To use the recipe value instead:
#   ft_epochs = int(pack["recipe"].get("finetune").get("epochs"))
ft_epochs = 1
print(f"\nStarting fine tuning for {ft_epochs} epochs...")

# Learning rate, KD settings and AMP flag are read from the recipe.
ft_cfg = FinetuneConfig(
    epochs=ft_epochs,
    lr=float(pack["recipe"].get("finetune").get("lr")),
    kd=KDConfig(**pack["recipe"].get("trainer").get("kd")),
    amp=bool(pack["recipe"].get("trainer").get("amp")),
    # Student and teacher were both moved to DEVICE above; train on that same
    # device so changing DEVICE cannot leave the config pointing elsewhere.
    device=DEVICE,
    log_every=200,
)

# Fine-tune the slim model against the teacher, validating on the pack's
# validation loader.
slim = finetune_student(
    slim_model,
    teacher,
    pack["train_loader"],
    get_student_logits=pack["get_s"],
    get_teacher_logits=pack["get_t"],
    cfg=ft_cfg,
    val_loader=pack["val_loader"],
    save_best=True
)


# torch.save(slim, os.path.join(args.outdir, "vit_slim_finetune.pth"))


Starting fine tuning for 1 epochs...
[AMP] Mode=BF16 | GradScaler=OFF | KD: T=4.0 alpha=2.0 | LR=0.0001 WD=0.0 | Trainable params=69,938,404
Step 200/1625 (ep 1/1): running loss = 4.3769


KeyboardInterrupt: 