# Example: optimize ResNet-18 for a target GPU

In [1]:
# !pip install torch torchvision torchaudio -U
# !pip install numpy

In [2]:
import os, sys, pathlib

# Make the project packages imported further down (`core`, `examples`, `adapters`,
# `notebooks`) importable. The path is resolved relative to the current working
# directory, so the notebook is expected to be run from its own folder; the entry
# added is then that folder's parent.
REPO_ROOT = str(pathlib.Path("resnet.ipynb").resolve().parents[1])
if REPO_ROOT not in sys.path:  # guard keeps the cell idempotent when re-run
    sys.path.append(REPO_ROOT)

# Device string passed to the model loaders, benchmarks and exports below
DEVICE = "cuda:0"

In [3]:
# !pip install kagglehub
# import kagglehub
# path = kagglehub.dataset_download("lyfora/processed-imagenet-dataset-224")
# print("Path to dataset files:", path)

In [4]:
# os.makedirs("/data/imagenet100-224/train", exist_ok=True)
# !mv /root/.cache/kagglehub/datasets/lyfora/processed-imagenet-dataset-224/versions/1 /data/imagenet100-224/train

In [5]:
# Utility helpers; `_images_from_batch` is used later to feed batches to the models
from core.utils import load_yaml, _images_from_batch

# Builder that turns a recipe file into the package used throughout this notebook
from examples.run_resnet_optimize import build_from_recipe

# Build the package for the RTX4090 ResNet-18 recipe. Later cells read its
# "recipe", "student", "teacher", "train_loader", "val_loader", "proxy",
# "export_policy", "trainer_cfg", "img_size" and "batch_size" entries.
pack = build_from_recipe("../recipes/RTX4090/resnet18_imagenet224.yaml")

# Quick sanity check of what the recipe points at
recipe = pack["recipe"]
print("Model name:", recipe["model"]["name"])
print("Dataset root:", recipe["data"]["train_root"])

Model name: resnet18
Dataset root: /data/imagenet100-224/train


In [6]:
## Download snapshots: gated model + slim (pruned) model

In [7]:
# import transformers
# from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Where the gated checkpoint lives locally, and the Hugging Face repo it comes from
gated_local_dir = "ckpt/resnet/gated"
gated_model_repo = "hawada/resnet18-rtx4090-gated"

# Uncomment to fetch the pre-trained gated model (full weights + tuned gates for pruning)
# snapshot_download(repo_id=gated_model_repo, local_dir=gated_local_dir, repo_type="model")

In [9]:
# Where the slim checkpoint lives locally, and the Hugging Face repo it comes from
slim_local_dir = "ckpt/resnet/slim"
slim_model_repo = "hawada/resnet18-rtx4090-slim"

# Uncomment to fetch the pre-trained slim model (already pruned)
# snapshot_download(repo_id=slim_model_repo, local_dir=slim_local_dir, repo_type="model")

In [10]:
from notebooks.ckpt.resnet.slim.minimal_resnet_loader import load_student

# Both checkpoints are read with the same custom loader, passing device=DEVICE.
# The weight files are expected at <local_dir>/pytorch_model.bin.

# Slim (already pruned) model
slim_model = load_student(f"{slim_local_dir}/pytorch_model.bin", device=DEVICE)

# Gated model
gated_model = load_student(f"{gated_local_dir}/pytorch_model.bin", device=DEVICE)

In [11]:
from adapters.torchvision.resnet import ResNetAdapter
from core.profiler import measure_latency_ms

# B = pack["batch_size"]; H = W = pack["img_size"]

# print(f"Starting benchmarking with batch size = {B}...\n")
# mean_keep, p95_keep, _ = measure_latency_ms(ResNetAdapter.export_keepall(gated_model), (B, 3, H, W), device=DEVICE)
# mean_slim, p95_slim, _ = measure_latency_ms(slim_model, (B, 3, H, W), device=DEVICE)

# print(f"Base: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms")
# print(f"Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms\n")
# if mean_keep > 0:
#     print(f"Speedup={100.0*(mean_keep-mean_slim)/mean_keep:.2f}%")

In [12]:
## Prune gated model with a custom export policy

In [13]:
from adapters.torchvision.resnet import ResNetExportPolicy
from core.export import Rounding as CoreRounding

# M = 2    # Multiples of this number will be used for pruned layers shapes
# K = 0.1  # Minimum ratio of kept shapes

# policy = ResNetExportPolicy(
#     warmup_steps=0,
#     rounding=CoreRounding(floor_groups=1, multiple_groups=M, min_keep_ratio=K),
#     min_keep_ratio=K,
# )

# # Obtain a new pruned model
# slim_model_new = ResNetAdapter.export_pruned(gated_model, policy, step=9999).to(DEVICE)

# print(f"Starting benchmarking with batch size = {B}...\n")
# mean_keep, p95_keep, _ = measure_latency_ms(ResNetAdapter.export_keepall(gated_model), (B, 3, H, W), device=DEVICE)
# mean_slim, p95_slim, _ = measure_latency_ms(slim_model_new, (B, 3, H, W), device=DEVICE)

# print(f"Base: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms")
# print(f"Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms\n")
# if mean_keep > 0:
#     print(f"Speedup={100.0*(mean_keep-mean_slim)/mean_keep:.2f}%")

In [14]:
## Fine-tune a new slim model (distillation from teacher)

In [15]:
# from core.finetune import FinetuneConfig, finetune_student, recalibrate_bn_stats
# from core.distill import KDConfig

# # Get the suggested fine-tuning config from the recipe package

# teacher = pack["teacher"] # Distillation from a bigger ResNet by default
# ft_epochs = 1 # int(pack["recipe"].get("finetune", {}).get("epochs", 10))
# learning_rate = float(pack["recipe"].get("finetune", {}).get("lr", 3e-4))
# weight_decay = float(pack["recipe"].get("finetune", {}).get("wd", 1e-5))

# train_loader = pack["train_loader"]
# val_loader   = pack["val_loader"]

# # Build a config for fine-tuning
# ft_cfg = FinetuneConfig(
#     epochs=ft_epochs,
#     lr=learning_rate,
#     wd=weight_decay,
#     kd=KDConfig(**pack["recipe"].get("trainer", {}).get("kd", {})),
#     amp=bool(pack["recipe"].get("trainer", {}).get("amp", True)),
#     mse_weight = float(pack["recipe"].get("trainer", {}).get("mse_weight", 0.0)),
#     device=DEVICE,
#     log_every=50,
# )

# print(f"\nStarting fine tuning for {ft_epochs} epochs, LR={learning_rate} ...\n")
# slim_finetuned = finetune_student(
#     slim_model_new,
#     teacher,
#     train_loader,
#     get_student_logits=lambda m, batch: m(_images_from_batch(batch)),
#     get_teacher_logits=lambda m, batch: m(_images_from_batch(batch)).detach(),
#     cfg=ft_cfg,
#     val_loader=val_loader,
# )

# # Recalibrate BatchNorm stats before export
# recalibrate_bn_stats(slim_finetuned, train_loader, max_batches=1000)

# # # Now you have a faster model that behaves well on the selected dataset
# # import torch
# # out_path = "resnet18_slim_new.pth"
# # torch.save(slim_finetuned.state_dict(), out_path)
# # print("Saved pruned model to", out_path)



In [16]:
## Train a new gated model on your GPU

In [None]:
import gc

# Drop the references to the models created in the earlier sections so their
# memory can be reclaimed before training starts. `slim_model_new` and
# `slim_finetuned` only exist if the optional (commented-out) cells were run,
# so each name is removed only when present; a plain `del` would raise
# NameError on a fresh "Run All".
for _name in ("gated_model", "slim_model", "slim_model_new", "slim_finetuned"):
    globals().pop(_name, None)
del _name
gc.collect()

In [17]:
from core.train import TrainerConfig


# Output directory for the artifacts of this run
os.makedirs("runs/resnet18", exist_ok=True)

# Everything below is taken from the package built from the recipe.

# Models: the student is the network to prune, the teacher is distilled from
# (NOTE(review): the original notes call these a fresh ResNet-18 and a bigger
# ResNet-50 - confirm against the recipe).
teacher = pack["teacher"]
student = pack["student"]

# Data loaders for the dataset configured in the recipe
val_loader = pack["val_loader"]
train_loader = pack["train_loader"]

# Differentiable proxy to estimate latency
proxy = pack["proxy"]

# Pruning policy and training configuration
export_policy = pack["export_policy"]
train_cfg = pack["trainer_cfg"]


# Summary of the pruning policy ...
print("Pruning policy: multiple =", export_policy.rounding.multiple_groups)
print("Pruning policy: min keep ratio =", export_policy.rounding.min_keep_ratio)

# ... and of the main training hyper-parameters
print("\nTraining conf: early stopping patience =", train_cfg.early_stopping_patience)
print("Training conf: LR for gates =", train_cfg.lr_gate)
print("Training conf: LR for linear layers =", train_cfg.lr_linear)
print("Training conf: LR for affine layers =", train_cfg.lr_affine)

Pruning policy: multiple = 1
Pruning policy: min keep ratio = 0.8

Training conf: early stopping patience = 3
Training conf: LR for gates = 0.01
Training conf: LR for linear layers = 0.0001
Training conf: LR for affine layers = 0.0003


In [18]:
from core.train import LagrangeTrainer


def _student_logits(m, batch):
    """Run the student `m` on the images extracted from `batch`."""
    return m(_images_from_batch(batch))


def _teacher_logits(m, batch):
    """Run the teacher `m` on the images extracted from `batch`; the result is `.detach()`-ed."""
    return m(_images_from_batch(batch)).detach()


def _export_pruned(m, pol, step):
    """Forward to `ResNetAdapter.export_pruned` with the same positional arguments."""
    return ResNetAdapter.export_pruned(m, pol, step)


# Assemble the trainer from the models, latency proxy, export policy and
# training configuration prepared in the previous cell
trainer = LagrangeTrainer(
    student=student,
    teacher=teacher,
    proxy=proxy,
    adapter_get_student_logits=_student_logits,
    adapter_get_teacher_logits=_teacher_logits,
    adapter_export_keepall=ResNetAdapter.export_keepall,
    adapter_export_pruned=_export_pruned,
    export_policy=export_policy,
    cfg=train_cfg,
)

In [None]:
EPOCHS = 1

# Early-stopping settings from the training configuration
patience = train_cfg.early_stopping_patience
lambda_threshold = train_cfg.early_stopping_lambda

lambdas = []  # value returned by trainer.train_epoch, one entry per epoch
for ep in range(EPOCHS):
    print(f"=== Epoch {ep+1}/{EPOCHS} ===")
    lam = trainer.train_epoch(train_loader)
    lambdas.append(lam)

    # Stop once the *most recent* `patience` epochs all returned a lambda below
    # the threshold. The window must be taken from the end of the history
    # (`lambdas[-patience:]`); slicing from the start would keep re-checking the
    # first epochs only. A zero/None patience disables early stopping.
    if patience and len(lambdas) >= patience:
        recent = lambdas[-patience:]
        if all(x < lambda_threshold for x in recent):
            print(f"Early stopping: lambda < {lambda_threshold} for last {patience} epochs")
            break


# # Save gated model (same directory as created above and used by the export command)
# import torch
# out_path = os.path.join("runs/resnet18", "resnet18_gated.pth")
# torch.save(student.state_dict(), out_path)
# print("Saved gated model to", out_path)

In [None]:
## Grid search the best pruning multiples on your GPU

In [21]:
from examples.run_resnet_optimize import grid_search_export

print("Running export grid search...")

# Candidate multiples: taken from the recipe here; set them according to your GPU
multiples = pack["recipe"].get("export").get("grid_multiple_groups")

# Minimum keep ratio comes from the trainer's lagrange section of the recipe
min_keep_ratio = float(pack["recipe"].get("trainer").get("lagrange").get("min_keep_ratio"))

res = grid_search_export(
    student,  # gated model whose gates were trained above
    multiples=multiples,
    min_keep_ratio=min_keep_ratio,
    B=pack["batch_size"],
    img_size=pack["img_size"],
    device=DEVICE,
)


# Keep the model the search reports as best, together with its parameters
slim = res["best_model"]
print("Best export params:", res["best_params"])

# Recalibrate BatchNorm statistics on training data before saving the model
print("\nRecalibrating BN stats on the slim model...")
ResNetAdapter.bn_recalibration(slim, train_loader, num_batches=1000, device=DEVICE)
print("Done")

# # Save model (same directory as created above and used by the export command)
# import torch
# out_path = os.path.join("runs/resnet18", "resnet18_slim.pth")
# torch.save(slim, out_path)
# print("Saved pruned model to", out_path)

Running export grid search...
[0/4] multiple_groups=2 | mean_ms=70.899
[1/4] multiple_groups=3 | mean_ms=70.962
[2/4] multiple_groups=4 | mean_ms=71.049
[3/4] multiple_groups=5 | mean_ms=71.202
Best export params: {'multiple_groups': 2}
Recalibrating BN stats on the slim model...
Done


In [22]:
# Benchmark input: batch size and (square) image size from the recipe package
B = pack["batch_size"]
H = W = pack["img_size"]
input_shape = (B, 3, H, W)

print(f"Starting benchmarking with batch size = {B}...\n")

# Baseline: the gated student exported with everything kept; candidate: the slim model
mean_keep, p95_keep, _ = measure_latency_ms(ResNetAdapter.export_keepall(student), input_shape, device=DEVICE)
mean_slim, p95_slim, _ = measure_latency_ms(slim, input_shape, device=DEVICE)

print(f"Base: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms")
print(f"Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms\n")

# The figure printed as "Speedup" is the relative reduction of the mean latency
# with respect to the baseline; skipped when the baseline is not positive.
if mean_keep > 0:
    latency_reduction = 100.0*(mean_keep-mean_slim)/mean_keep
    print(f"Speedup={latency_reduction:.2f}%")

Starting benchmarking with batch size = 512...

Base: mean=93.156ms p95=93.591ms
Slim: mean=71.559ms p95=71.835ms

Speedup=23.18%


In [None]:
# Now you have a slim model optimized for your GPU. 
# You can fine-tune it on your downstream task with a teacher, as we did above

In [None]:
# Export to HuggingFace

In [None]:
# python -m tools.export_to_hf \
# --task resnet \
# --base_id torchvision/resnet18 \
# --student_ckpt runs/resnet18/resnet18_gated.pth \
# --slim_ckpt runs/resnet18/resnet18_slim.pth \
# --repo_gated hawada/resnet18-gated \
# --repo_slim hawada/resnet18-slim \
# --token hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx \
# --include_code adapters/torchvision,core,gates