# Example: optimize ResNet-18 for a target GPU

In [1]:
# !pip install torch torchvision torchaudio -U
# !pip install numpy

In [19]:
import os
import sys
import pathlib

# The kernel's working directory is expected to be this notebook's folder (`notebooks/`):
# the recipe path "../recipes/..." and the "ckpt/..." checkpoint paths below assume it too.
# The repository root is therefore one level up; adding it to sys.path makes the project
# packages (`core`, `adapters`, `examples`, `notebooks`) importable.
REPO_ROOT = pathlib.Path.cwd().resolve().parent
# Guard so re-running this cell does not append the same entry again.
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

DEVICE = "cuda:0"  # device used for loading, benchmarking and fine-tuning below

In [21]:
# !pip install kagglehub
# import kagglehub
# path = kagglehub.dataset_download("lyfora/processed-imagenet-dataset-224")
# print("Path to dataset files:", path)

In [22]:
# os.makedirs("/data/imagenet100-224/train", exist_ok=True)
# NOTE(review): since the target directory already exists at this point, `mv` moves
# `versions/1` *into* it (ending up at /data/imagenet100-224/train/1) rather than renaming
# it to `train`. Confirm this matches the layout expected by the recipe's train_root
# (/data/imagenet100-224/train); otherwise move the directory's contents instead.
# !mv /root/.cache/kagglehub/datasets/lyfora/processed-imagenet-dataset-224/versions/1 /data/imagenet100-224/train

In [27]:
# Project helpers: YAML loading and extracting the image tensor from a data batch
from core.utils import load_yaml, _images_from_batch

# Builds the experiment package (parsed recipe, teacher, data loaders, batch size,
# image size, ...) from a recipe file; per the original note it also fetches the base model.
from examples.run_resnet_optimize import build_from_recipe

# Recipe tuned for the RTX 4090; the relative path assumes the kernel runs from notebooks/.
RECIPE_PATH = "../recipes/RTX4090/resnet18_imagenet224.yaml"
pack = build_from_recipe(RECIPE_PATH)

recipe = pack["recipe"]
print("Model name:", recipe["model"]["name"])
print("Dataset root:", recipe["data"]["train_root"])

Model name: resnet18
Dataset root: /data/imagenet100-224/train


In [None]:
## Download snapshots: gated model + slim (pruned) model

In [5]:
# import transformers
# from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download

In [None]:
gated_model_repo  = "hawada/resnet18-rtx4090-gated"
gated_local_dir   = "ckpt/resnet/gated"

# Pre-trained gated model: full weights plus the tuned gates used for pruning.
# Download only when the checkpoint is missing, so "Restart & Run All" works on a
# fresh machine without re-fetching on every run.
if not os.path.isfile(os.path.join(gated_local_dir, "pytorch_model.bin")):
    snapshot_download(repo_id=gated_model_repo, local_dir=gated_local_dir, repo_type="model")

In [7]:
slim_model_repo  = "hawada/resnet18-rtx4090-slim"
slim_local_dir   = "ckpt/resnet/slim"

# Pre-trained slim model (already pruned). This snapshot also provides the
# `minimal_resnet_loader` module imported in the next cell.
# Download only when the checkpoint is missing, so "Restart & Run All" works on a
# fresh machine without re-fetching on every run.
if not os.path.isfile(os.path.join(slim_local_dir, "pytorch_model.bin")):
    snapshot_download(repo_id=slim_model_repo, local_dir=slim_local_dir, repo_type="model")

In [11]:
# Custom loader shipped inside the downloaded slim snapshot (importable as a package
# because the repo root was added to sys.path and the snapshot lives under ckpt/resnet/slim).
from notebooks.ckpt.resnet.slim.minimal_resnet_loader import load_student

# Rebuild both checkpoints on DEVICE with the same loader: first the slim (pruned)
# model, then the gated model.
# NOTE(review): the gated checkpoint is loaded with the loader from the *slim* snapshot —
# confirm it also handles the gated architecture.
slim_model = load_student(f"{slim_local_dir}/pytorch_model.bin", device=DEVICE)
gated_model = load_student(f"{gated_local_dir}/pytorch_model.bin", device=DEVICE)

In [33]:
from adapters.torchvision.resnet import ResNetAdapter
from core.profiler import measure_latency_ms

# Benchmark input shape comes from the recipe: (batch, 3, H, W) with square images.
B = pack["batch_size"]
H = W = pack["img_size"]
input_shape = (B, 3, H, W)

# Baseline: the gated model exported via `export_keepall` (presumably all gates kept,
# i.e. unpruned) vs. the downloaded slim (pruned) model.
print(f"Starting benchmarking with batch size = {B}...\n")
mean_keep, p95_keep, _ = measure_latency_ms(ResNetAdapter.export_keepall(gated_model), input_shape, device=DEVICE)
mean_slim, p95_slim, _ = measure_latency_ms(slim_model, input_shape, device=DEVICE)

print(f"Base: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms")
print(f"Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms\n")
if mean_keep > 0:
    # (base - slim) / base is the latency *reduction*; the speedup is base / slim.
    reduction_pct = 100.0 * (mean_keep - mean_slim) / mean_keep
    speedup = f"{mean_keep / mean_slim:.2f}x" if mean_slim > 0 else "n/a"
    print(f"Latency reduction={reduction_pct:.2f}% (speedup={speedup})")

Starting benchmarking with batch size = 512...

Base: mean=92.094ms p95=92.516ms
Slim: mean=70.768ms p95=70.941ms

Speedup=23.16%


In [None]:
## Prune the gated model with a custom export policy

In [36]:
from adapters.torchvision.resnet import ResNetExportPolicy
from core.export import Rounding as CoreRounding

M = 2    # rounding multiple passed as `multiple_groups` (presumably pruned widths become multiples of M)
K = 0.1  # minimum keep ratio (presumably the smallest fraction of each layer that survives pruning)

# warmup_steps=0 (presumably so the policy applies immediately); the same minimum keep
# ratio is given to both the rounding rule and the policy itself.
policy = ResNetExportPolicy(
    warmup_steps=0,
    rounding=CoreRounding(floor_groups=1, multiple_groups=M, min_keep_ratio=K),
    min_keep_ratio=K,
)

# Export a new pruned model from the gated one and move it to the benchmark device.
# NOTE(review): step=9999 looks like an arbitrary "well past any warmup" value — confirm.
slim_model_new = ResNetAdapter.export_pruned(gated_model, policy, step=9999).to(DEVICE)

input_shape = (B, 3, H, W)
print(f"Starting benchmarking with batch size = {B}...\n")
mean_keep, p95_keep, _ = measure_latency_ms(ResNetAdapter.export_keepall(gated_model), input_shape, device=DEVICE)
mean_slim, p95_slim, _ = measure_latency_ms(slim_model_new, input_shape, device=DEVICE)

print(f"Base: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms")
print(f"Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms\n")
if mean_keep > 0:
    # (base - slim) / base is the latency *reduction*; the speedup is base / slim.
    reduction_pct = 100.0 * (mean_keep - mean_slim) / mean_keep
    speedup = f"{mean_keep / mean_slim:.2f}x" if mean_slim > 0 else "n/a"
    print(f"Latency reduction={reduction_pct:.2f}% (speedup={speedup})")

Starting benchmarking with batch size = 512...

Base: mean=91.822ms p95=92.367ms
Slim: mean=53.908ms p95=54.082ms

Speedup=41.29%


In [None]:
## Fine-tune the new slim model (knowledge distillation from a teacher)

In [None]:
import torch

from core.finetune import FinetuneConfig, finetune_student, recalibrate_bn_stats
from core.distill import KDConfig

# Fine-tuning settings come from the recipe package; the defaults below apply when the
# recipe omits a key. `or {}` also covers a section that is present but empty (None) in YAML.
finetune_recipe = pack["recipe"].get("finetune") or {}
trainer_recipe  = pack["recipe"].get("trainer") or {}

# Quick-demo override: train for a single epoch. Set to None to use the recipe's value
# (10 when the recipe does not specify one).
FT_EPOCHS_OVERRIDE = 1
# Upper bound on training batches used to re-estimate BatchNorm statistics.
BN_RECALIB_BATCHES = 1000
# Set to True to write the fine-tuned weights to SLIM_MODEL_PATH.
SAVE_SLIM_MODEL = False
SLIM_MODEL_PATH = "resnet18_slim_new.pth"

teacher = pack["teacher"]  # Distillation from a bigger ResNet by default
ft_epochs = FT_EPOCHS_OVERRIDE if FT_EPOCHS_OVERRIDE is not None else int(finetune_recipe.get("epochs", 10))
learning_rate = float(finetune_recipe.get("lr", 3e-4))
weight_decay = float(finetune_recipe.get("wd", 1e-5))

train_loader = pack["train_loader"]
val_loader   = pack["val_loader"]

# Build a config for fine-tuning
ft_cfg = FinetuneConfig(
    epochs=ft_epochs,
    lr=learning_rate,
    wd=weight_decay,
    kd=KDConfig(**trainer_recipe.get("kd", {})),
    amp=bool(trainer_recipe.get("amp", True)),
    mse_weight=float(trainer_recipe.get("mse_weight", 0.0)),
    device=DEVICE,
    log_every=50,
)


def teacher_logits(model, batch):
    """Return the teacher's logits for ``batch`` without recording an autograd graph.

    Calling ``.detach()`` on the output alone still builds the graph during the forward
    pass (unless the teacher's parameters are already frozen); ``torch.no_grad()`` avoids
    building it at all. The logit values are identical either way.
    """
    with torch.no_grad():
        return model(_images_from_batch(batch))


print(f"\nStarting fine tuning for {ft_epochs} epochs, LR={learning_rate} ...\n")
slim_finetuned = finetune_student(
    slim_model_new,
    teacher,
    train_loader,
    get_student_logits=lambda m, batch: m(_images_from_batch(batch)),
    get_teacher_logits=teacher_logits,
    cfg=ft_cfg,
    val_loader=val_loader,
)

# Recalibrate BatchNorm stats before export
recalibrate_bn_stats(slim_finetuned, train_loader, max_batches=BN_RECALIB_BATCHES)

# Now you have a faster model that behaves well on the selected dataset
if SAVE_SLIM_MODEL:
    torch.save(slim_finetuned.state_dict(), SLIM_MODEL_PATH)
    print("Saved pruned model to", SLIM_MODEL_PATH)

