# ResNet20 1W1A (BinaryConnect) + Pruning

Notebook para executar seu fluxo completo com:
- ResNet20 binarizada (1W1A) via `BinaryConnect`
- Pruning estruturado, não estruturado ou combinado
- Treino, avaliação e salvamento de checkpoint


In [None]:
# Run this cell once to install the dependencies into the current kernel.
%pip install -U pip
%pip install numpy matplotlib tqdm wandb

# PyTorch CPU build (default)
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

# If you are going to use an NVIDIA GPU (CUDA 12.1), comment out the CPU line above and use this one:
# %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121


In [None]:
import os
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from cifar10 import trainloader, trainloader_subset, testloader
from models.resnet_s import resnet20
from models.binaryconnect import BC, BinaryActivationSTE
from techniques.quantization import QuantizationAwareConfig
from techniques.prunning import PrunningStructured, PrunningUnstructured


In [None]:
# Seed Python, NumPy and PyTorch from one constant so reruns of the notebook
# start from the same random state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

# Single place holding every knob read by the cells below.
CFG = dict(
    # Run identification and checkpoint locations.
    label="rn20a1_notebook",
    project_name="imt_efficient_dl",
    path_backup="./",
    checkpoint_path="./resnet20_bc1w1a_best.pth",
    # Feature switches.
    use_checkpoint=True,
    use_subset=False,
    use_wandb=False,
    run_training=False,
    # Optimisation hyper-parameters.
    num_epochs=50,
    learning_rate=0.01,
    weight_decay=1e-3,
    momentum=0.9,
    nesterov=True,
    label_smoothing=0.1,
    warmup_epochs=5,
    # NOTE(review): patience (150) is larger than num_epochs (50), so early
    # stopping presumably never triggers -- confirm this is intended.
    early_stopping_patience=150,
    early_stopping_min_delta=0.0,
    # Pruning: one of "unstructured", "structured" or "combined".
    pruning_method="combined",
    structured_ratios=[0.05],
    unstructured_ratios=[0.7],
    avoid_overlap=True,
)


In [None]:
# Choose between the reduced and the full training loader according to CFG,
# then report how many samples each split exposes.
if CFG["use_subset"]:
    train_data = trainloader_subset
else:
    train_data = trainloader

print(f"Train samples: {len(train_data.dataset)}")
print(f"Test samples: {len(testloader.dataset)}")


In [None]:
def _strip_module_prefix(state_dict):
    """Return ``state_dict`` without a leading ``module.`` on its keys.

    The prefix is removed only when the mapping is non-empty and *every* key
    carries it; otherwise the input object itself is returned unchanged.
    """
    keys = list(state_dict.keys())
    if keys and all(k.startswith("module.") for k in keys):
        return {k[len("module."):]: v for k, v in state_dict.items()}
    return state_dict


def load_bc_checkpoint(model, checkpoint_path):
    """Load the weights stored at ``checkpoint_path`` into ``model``.

    The file may contain either a raw state dict or a dict holding it under
    the ``"state_dict"`` key. A uniform ``module.`` prefix is stripped, and
    the keys are tried both as stored and with a ``model.`` prefix; the
    naming that overlaps most with ``model.state_dict()`` is the one loaded.

    Args:
        model: ``nn.Module`` that receives the weights (loaded non-strictly).
        checkpoint_path: path of the file passed to ``torch.load``.

    Returns:
        None. A summary of missing/unexpected keys is printed.

    Raises:
        RuntimeError: if tensors with matching names have incompatible shapes
            and the alternative key naming fails as well.
    """
    # NOTE(security): torch.load unpickles the file. Only open checkpoints
    # from a trusted source, or pass weights_only=True where the installed
    # PyTorch version supports it.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    checkpoint = _strip_module_prefix(checkpoint)
    wrapped = {f"model.{k}": v for k, v in checkpoint.items()}

    # load_state_dict(strict=False) does NOT raise when key names differ, so
    # a prefix mismatch would otherwise "succeed" while copying no weights.
    # Decide the naming up front by counting keys the model actually owns.
    model_keys = set(model.state_dict())
    plain_hits = len(model_keys.intersection(checkpoint))
    wrapped_hits = len(model_keys.intersection(wrapped))

    if wrapped_hits > plain_hits:
        checkpoint = wrapped
        missing, unexpected = model.load_state_dict(checkpoint, strict=False)
    else:
        try:
            missing, unexpected = model.load_state_dict(checkpoint, strict=False)
        except RuntimeError:
            # Kept from the original flow: on a shape conflict retry with the
            # "model." prefix before giving up.
            checkpoint = wrapped
            missing, unexpected = model.load_state_dict(checkpoint, strict=False)

    # Every checkpoint key being "unexpected" means nothing was loaded.
    if checkpoint and len(unexpected) == len(checkpoint):
        print("Aviso: nenhuma chave do checkpoint corresponde ao modelo; nenhum peso foi carregado.")

    print(f"Checkpoint carregado: {checkpoint_path}")
    print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")


# Build the base ResNet20 and refuse to continue when it exposes no nn.ReLU
# modules (the error message states that the swap to BinaryActivationSTE
# would not be applied in that case).
base_model = resnet20()
num_relu_modules = len([m for m in base_model.modules() if isinstance(m, nn.ReLU)])
if not num_relu_modules:
    raise RuntimeError("resnet20 sem nn.ReLU modulos. A troca para BinaryActivationSTE nao sera aplicada.")
print(f"nn.ReLU encontrados na ResNet20 base: {num_relu_modules}")

# Wrap the network with BC, move it to the selected device and report how
# many BinaryActivationSTE modules the wrapped model contains.
model = BC(base_model).to(device)
num_binary_act = len([m for m in model.modules() if isinstance(m, BinaryActivationSTE)])
print(f"BinaryActivationSTE apos wrapping BC: {num_binary_act}")

# Restore weights only when checkpoints are enabled AND the file exists.
if not (CFG["use_checkpoint"] and os.path.exists(CFG["checkpoint_path"])):
    print("Treinando do zero (checkpoint nao carregado).")
else:
    load_bc_checkpoint(model, CFG["checkpoint_path"])

# Last expression of the cell: switches to eval mode (and, in a notebook,
# displays the returned module).
model.eval()


In [None]:
# Assemble the experiment configuration object. Objects created in earlier
# cells and fixed literals are passed explicitly; every CFG entry whose key
# equals the constructor keyword is forwarded through the unpacked mapping.
cfg = QuantizationAwareConfig(
    model=model,
    train_data=train_data,
    test_data=testloader,
    device=device,
    wand_on=CFG["use_wandb"],
    input_dtype="bc",
    optimizer_name="SGD",
    scheduler_name="LinearLR+CosineAnnealingLR",
    **{
        key: CFG[key]
        for key in (
            "label",
            "project_name",
            "path_backup",
            "num_epochs",
            "learning_rate",
            "weight_decay",
            "momentum",
            "nesterov",
            "label_smoothing",
            "warmup_epochs",
            "early_stopping_patience",
            "early_stopping_min_delta",
            "pruning_method",
            "structured_ratios",
            "unstructured_ratios",
            "avoid_overlap",
        )
    },
)

# Replace the loss with a cross-entropy that applies the label smoothing
# value stored on the config object.
cfg.criterion = nn.CrossEntropyLoss(label_smoothing=cfg.label_smoothing)


In [None]:
def reset_opt_and_sched(cfg_obj):
    """Attach a fresh SGD optimizer and LR scheduler to ``cfg_obj``.

    The optimizer is built over ``cfg_obj.model.parameters()`` using the
    ``learning_rate``, ``momentum``, ``weight_decay`` and ``nesterov``
    attributes. With a positive ``warmup_epochs`` the scheduler is a linear
    warm-up (factor 0.1 -> 1.0) followed by cosine annealing over the
    remaining epochs. With ``warmup_epochs <= 0`` only cosine annealing is
    used, so training starts at the configured learning rate instead of the
    warm-up start factor.

    Args:
        cfg_obj: object exposing ``model``, ``learning_rate``, ``momentum``,
            ``weight_decay``, ``nesterov``, ``warmup_epochs`` and
            ``num_epochs``. Its ``optimizer`` and ``scheduler`` attributes
            are overwritten.

    Returns:
        None.
    """
    cfg_obj.optimizer = optim.SGD(
        cfg_obj.model.parameters(),
        lr=cfg_obj.learning_rate,
        momentum=cfg_obj.momentum,
        weight_decay=cfg_obj.weight_decay,
        nesterov=cfg_obj.nesterov,
    )

    if cfg_obj.warmup_epochs <= 0:
        # A zero-length LinearLR would still apply its start factor to the
        # first epoch; skip the warm-up stage entirely when none is wanted.
        cfg_obj.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            cfg_obj.optimizer,
            T_max=max(1, cfg_obj.num_epochs),
        )
        return

    warmup = optim.lr_scheduler.LinearLR(
        cfg_obj.optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=cfg_obj.warmup_epochs,
    )
    cosine = optim.lr_scheduler.CosineAnnealingLR(
        cfg_obj.optimizer,
        # max(1, ...) keeps T_max valid when the warm-up spans all epochs.
        T_max=max(1, cfg_obj.num_epochs - cfg_obj.warmup_epochs),
    )
    cfg_obj.scheduler = optim.lr_scheduler.SequentialLR(
        cfg_obj.optimizer,
        schedulers=[warmup, cosine],
        milestones=[cfg_obj.warmup_epochs],
    )


# Create the initial optimizer and scheduler on the config object.
reset_opt_and_sched(cfg)


In [None]:
# Pair the "weight" parameter of every Conv2d / Linear module found in the
# model with its module; both pruning helpers receive this same list.
params_to_prune = [
    (layer, "weight")
    for layer in filter(
        lambda m: isinstance(m, (nn.Conv2d, nn.Linear)),
        cfg.model.modules(),
    )
]

# Unstructured helper first, then the structured one, each with its own ratios.
prune_unstructured = PrunningUnstructured(
    cfgModel=cfg, ratios=cfg.unstructured_ratios, params_to_prune=params_to_prune
)
prune_structured = PrunningStructured(
    cfgModel=cfg, ratios=cfg.structured_ratios, params_to_prune=params_to_prune
)

print(f"Camadas prunaveis: {len(params_to_prune)}")


In [None]:
# Reject unknown method names before touching the model.
if cfg.pruning_method not in ("unstructured", "structured", "combined"):
    raise ValueError(
        f"pruning_method invalido: {cfg.pruning_method}. Use unstructured, structured ou combined."
    )

# Run the selected pruning strategy; "combined" is provided by the
# structured helper and takes both ratio lists.
if cfg.pruning_method == "combined":
    pruned_model = prune_structured.combined(
        structured_ratios=cfg.structured_ratios,
        unstructured_ratios=cfg.unstructured_ratios,
        avoid_overlap=cfg.avoid_overlap,
    )
elif cfg.pruning_method == "structured":
    pruned_model = prune_structured.structured(ratios=cfg.structured_ratios)
else:
    pruned_model = prune_unstructured.unstructured(ratios=cfg.unstructured_ratios)

# When the helper hands back a model, adopt it and rebuild optimizer and
# scheduler so they track the new model's parameters.
if pruned_model is not None:
    cfg.model = pruned_model
    reset_opt_and_sched(cfg)

print("Pruning aplicado com sucesso.")


In [None]:
# Estimate the score from the first structured/unstructured ratios, 1-bit
# weights and activations, and the current parameter count.
# NOTE(review): both ratios are passed regardless of cfg.pruning_method --
# confirm that is the intended input when only one kind of pruning ran.
# NOTE(review): f=40.55e6 is a hard-coded constant; presumably the operation
# count of the reference network -- verify against calculate_score.
score = cfg.calculate_score(
    w=sum(map(torch.numel, cfg.model.parameters())),
    f=40.55e6,
    q_w=1,
    q_a=1,
    p_s=cfg.structured_ratios[0],
    p_u=cfg.unstructured_ratios[0],
)
print(f"Score estimado: {score:.6f}")

# Baseline metrics of the model before any training in this notebook.
loss_before, acc_before = cfg.evaluate()
print(f"Antes do treino -> Loss: {loss_before:.4f} | Acc: {acc_before:.2f}%")


In [None]:
# Training is opt-in: without the flag only a hint is printed.
if not CFG["run_training"]:
    print("Treino desabilitado. Ajuste CFG['run_training'] = True para treinar.")
else:
    cfg.train_loop()


In [None]:
# Final metrics of the model as it stands (trained or not).
loss_after, acc_after = cfg.evaluate()
print(f"Estado final -> Loss: {loss_after:.4f} | Acc: {acc_after:.2f}%")

# Save the state dict under the backup directory; the file name encodes the
# pruning method and the run label.
pruned_path = os.path.join(CFG["path_backup"], f"pruned_model_{cfg.pruning_method}_{cfg.label}.pth")
torch.save(cfg.model.state_dict(), pruned_path)
print(f"Modelo salvo em: {pruned_path}")


## Uso rápido

1. Ajuste o dicionário `CFG`.
2. Execute todas as células em ordem.
3. Para treinar de fato, use `CFG['run_training'] = True`.
4. Para alternar pruning: `unstructured`, `structured` ou `combined`.
