# **Template for Torch-Pruning**

This template is just built for your convinience.

You are not required to follow the steps and method given below.

In [1]:
!pip install --upgrade torch_pruning

Collecting torch_pruning
  Downloading torch_pruning-1.3.7-py3-none-any.whl.metadata (29 kB)
Downloading torch_pruning-1.3.7-py3-none-any.whl (56 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.5/56.5 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_pruning
Successfully installed torch_pruning-1.3.7


In [2]:
import torch
import torchvision
from torchvision.models import mobilenet_v2
import torch_pruning as tp

  from .autonotebook import tqdm as notebook_tqdm


##A Minimal Example ##  
In this section, you will perform channel pruning using the library [Torch-Pruning](https://github.com/VainF/Torch-Pruning).  

The puuner in Torch-Pruning has three main functions: sparse training (optional), importance estimation, and parameter removal.  
Torch-pruning offers two core features to support this process:

tp.importance(): This criteria is utilized to measure the importance of weights.  

tp.pruner(): This is a pruner used for the actual pruning of the parameters.  

For detailed information on this process, please refer to this [tutorial](https://github.com/VainF/Torch-Pruning/wiki/4.-High%E2%80%90level-Pruners/). Additionally, a more practical example is available in [here](https://github.com/VainF/Torch-Pruning/blob/master/benchmarks/main.py).

### 1. Load model


In [3]:
model = torch.load('./mobilenetv2_0.963.pth', map_location="cpu")
device = torch.device("cuda:0")
model = model.to(device)

### 2. Prepare a pruner
By default, Torch-Pruning will automatically prune the last non-singleton dim of these parameters. If you want to customize this behaviour, please provide an `unwrapped_parameters` list as the following example.

In [4]:
# Importance criterion
imp = tp.importance.GroupNormImportance(p=2, normalizer="max") # or GroupTaylorImportance(), GroupHessianImportance(), etc.

# Initialize a pruner with the model and the importance criterion
example_inputs = torch.randn(1, 3, 224, 224).to(device)

ignored_layers = []
for m in model.modules():
  if isinstance(m, torch.nn.Linear) and m.out_features == 10: # ignore the classifier
    ignored_layers.append(m)

pruner = tp.pruner.GroupNormPruner ( # you can choose any pruner you like.
    model,
    example_inputs,
    importance        = imp,
    pruning_ratio     = 0.75,
    max_pruning_ratio = 0.9,
    iterative_steps   = 200,
    ignored_layers    = ignored_layers,
    global_pruning    = True
)

### 3. Prune the model

In [3]:
import torch.nn as nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchvision.datasets import *
from torchvision.transforms import *
from torchvision.models.mobilenetv2 import MobileNetV2

from tqdm import tqdm


transforms = {
    "train": Compose([
      Resize((224, 224)),
      ToTensor(),
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    "test": Compose([
      Resize((224, 224)),
      ToTensor(),
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
}

dataset = {}
for split in ["train", "test"]:
  dataset[split] = CIFAR10(
    root="data/cifar10",
    train=(split == "train"),
    download=True,
    transform=transforms[split],
  )

# You can apply your own batch_size
dataloader = {}
for split in ['train', 'test']:
  dataloader[split] = DataLoader(
    dataset[split],
    batch_size=64,
    shuffle=(split == 'train'),
    num_workers=0,
    pin_memory=True,
    drop_last=True
  )

Files already downloaded and verified
Files already downloaded and verified


In [4]:
def finetune(
  model: nn.Module,
  dataloader: DataLoader,
  criterion: nn.Module,
  optimizer: Optimizer,
  scheduler: LambdaLR | None = None,
  callbacks = None,
  device: torch.device = torch.device("cpu"),
  l1_coeff: float | None = None
) -> None:
  model.train()

  for inputs, targets in tqdm(dataloader, desc='train', leave=False):
    # Move the data from CPU to GPU
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Reset the gradients (from the last iteration)
    optimizer.zero_grad()

    # Forward inference
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # L1 reg
    if l1_coeff:
      l1_loss = 0.
      for p in model.parameters():
        l1_loss += p.abs().sum()
      
      loss = loss + l1_coeff * l1_loss

    # Backward propagation
    loss.backward()

    # Update optimizer and LR scheduler
    optimizer.step()
    if scheduler: scheduler.step()

    if callbacks is not None:
        for callback in callbacks:
            callback()


@torch.inference_mode()
def evaluate(
  model: nn.Module,
  dataloader: DataLoader,
  verbose=True,
  device: torch.device = torch.device("cpu")
) -> float:
  model.eval()

  num_samples = 0
  num_correct = 0

  for inputs, targets in tqdm(dataloader, desc="eval", leave=False,
                              disable=not verbose):
    # Move the data from CPU to GPU
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Inference
    outputs = model(inputs)

    # Convert logits to class indices
    outputs = outputs.argmax(dim=1)

    # Update metrics
    num_samples += targets.size(0)
    num_correct += (outputs == targets).sum()

  return (num_correct / num_samples * 100).item()

In [7]:
# Parameters:
epochs_per_prune = 1
finetune_lr      = 0.01
target_MFLOPs    = 45.

# Model size before pruning
base_macs, base_nparams = tp.utils.count_ops_and_params(model, torch.randn(1, 3, 224, 224, device=device))
base_MFLOPs = base_macs / 1e6
print(f"[Original MFLOPs = {base_MFLOPs}]")

for i in range(1, pruner.iterative_steps + 1):
  
  model.eval()

  if isinstance(imp, tp.importance.GroupTaylorImportance):
    # Taylor expansion requires gradients for importance estimation
    loss = model(example_inputs).sum() # A dummy loss, please replace this line with your loss function and data!
    loss.backward() # before pruner.step()

  # prune
  pruner.step()

  # Parameter & MACs Counter
  pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(model, torch.randn(1, 3, 224, 224, device=device))
  MFLOPs = pruned_macs / 1e6
  print("\n" + "-" * 50 + f"Iter {i}" + "-" * 50)
  print(f"MFLOPs: {MFLOPs}")

  # Finetune the pruned model here
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(model.parameters(), lr=finetune_lr, momentum=0.9, weight_decay=1e-4)
  print(f"\nFinetuning ... ")
  finetune(model, dataloader['train'], criterion, optimizer, device=device)
  acc = evaluate(model, dataloader['test'], device=device)
  print(f"Accuracy: {acc}")

  # Stopping:
  if MFLOPs <= target_MFLOPs:
    break
  

[Original MFLOPs = 318.969098]

--------------------------------------------------Iter 1--------------------------------------------------
MFLOPs: 308.050134

Finetuning ... 


                                                        

Accuracy: 95.1221923828125

--------------------------------------------------Iter 2--------------------------------------------------
MFLOPs: 299.791233

Finetuning ... 


                                                        

Accuracy: 95.48277282714844

--------------------------------------------------Iter 3--------------------------------------------------
MFLOPs: 295.74839

Finetuning ... 


                                                        

Accuracy: 95.58293151855469

--------------------------------------------------Iter 4--------------------------------------------------
MFLOPs: 294.269815

Finetuning ... 


                                                        

Accuracy: 95.72315979003906

--------------------------------------------------Iter 5--------------------------------------------------
MFLOPs: 292.576571

Finetuning ... 


                                                        

Accuracy: 95.59294891357422

--------------------------------------------------Iter 6--------------------------------------------------
MFLOPs: 291.086383

Finetuning ... 


                                                        

Accuracy: 95.95352172851562

--------------------------------------------------Iter 7--------------------------------------------------
MFLOPs: 289.687237

Finetuning ... 


                                                        

Accuracy: 95.87339782714844

--------------------------------------------------Iter 8--------------------------------------------------
MFLOPs: 288.405985

Finetuning ... 


                                                        

Accuracy: 95.92347717285156

--------------------------------------------------Iter 9--------------------------------------------------
MFLOPs: 287.637763

Finetuning ... 


                                                        

Accuracy: 95.76322174072266

--------------------------------------------------Iter 10--------------------------------------------------
MFLOPs: 286.771688

Finetuning ... 


                                                        

Accuracy: 95.9034423828125

--------------------------------------------------Iter 11--------------------------------------------------
MFLOPs: 285.474266

Finetuning ... 


                                                        

Accuracy: 95.75320434570312

--------------------------------------------------Iter 12--------------------------------------------------
MFLOPs: 284.710846

Finetuning ... 


                                                        

Accuracy: 95.79327392578125

--------------------------------------------------Iter 13--------------------------------------------------
MFLOPs: 283.935813

Finetuning ... 


                                                        

Accuracy: 95.78324890136719

--------------------------------------------------Iter 14--------------------------------------------------
MFLOPs: 282.797935

Finetuning ... 


                                                        

Accuracy: 95.77323913574219

--------------------------------------------------Iter 15--------------------------------------------------
MFLOPs: 282.022902

Finetuning ... 


                                                        

Accuracy: 95.75320434570312

--------------------------------------------------Iter 16--------------------------------------------------
MFLOPs: 281.384432

Finetuning ... 


                                                        

Accuracy: 95.67308044433594

--------------------------------------------------Iter 17--------------------------------------------------
MFLOPs: 280.364693

Finetuning ... 


                                                        

Accuracy: 95.84335327148438

--------------------------------------------------Iter 18--------------------------------------------------
MFLOPs: 279.458977

Finetuning ... 


                                                        

Accuracy: 95.78324890136719

--------------------------------------------------Iter 19--------------------------------------------------
MFLOPs: 278.717852

Finetuning ... 


                                                        

Accuracy: 95.693115234375

--------------------------------------------------Iter 20--------------------------------------------------
MFLOPs: 277.607071

Finetuning ... 


                                                        

Accuracy: 95.26242065429688

--------------------------------------------------Iter 21--------------------------------------------------
MFLOPs: 276.508834

Finetuning ... 


                                                        

Accuracy: 95.5428695678711

--------------------------------------------------Iter 22--------------------------------------------------
MFLOPs: 275.198623

Finetuning ... 


                                                        

Accuracy: 95.302490234375

--------------------------------------------------Iter 23--------------------------------------------------
MFLOPs: 274.467249

Finetuning ... 


                                                        

Accuracy: 95.52284240722656

--------------------------------------------------Iter 24--------------------------------------------------
MFLOPs: 273.724262

Finetuning ... 


                                                        

Accuracy: 95.35256958007812

--------------------------------------------------Iter 25--------------------------------------------------
MFLOPs: 272.970593

Finetuning ... 


                                                        

Accuracy: 95.32251739501953

--------------------------------------------------Iter 26--------------------------------------------------
MFLOPs: 271.858881

Finetuning ... 


                                                        

Accuracy: 95.13221740722656

--------------------------------------------------Iter 27--------------------------------------------------
MFLOPs: 271.065571

Finetuning ... 


                                                        

Accuracy: 94.62139892578125

--------------------------------------------------Iter 28--------------------------------------------------
MFLOPs: 270.34581

Finetuning ... 


                                                        

Accuracy: 95.20232391357422

--------------------------------------------------Iter 29--------------------------------------------------
MFLOPs: 269.535007

Finetuning ... 


                                                        

Accuracy: 94.97195434570312

--------------------------------------------------Iter 30--------------------------------------------------
MFLOPs: 268.678683

Finetuning ... 


                                                        

Accuracy: 95.32251739501953

--------------------------------------------------Iter 31--------------------------------------------------
MFLOPs: 267.920212

Finetuning ... 


                                                        

Accuracy: 94.89183044433594

--------------------------------------------------Iter 32--------------------------------------------------
MFLOPs: 266.983528

Finetuning ... 


                                                        

Accuracy: 95.302490234375

--------------------------------------------------Iter 33--------------------------------------------------
MFLOPs: 265.859272

Finetuning ... 


                                                        

Accuracy: 95.01202392578125

--------------------------------------------------Iter 34--------------------------------------------------
MFLOPs: 264.141773

Finetuning ... 


                                                        

Accuracy: 95.00199890136719

--------------------------------------------------Iter 35--------------------------------------------------
MFLOPs: 263.081462

Finetuning ... 


                                                        

Accuracy: 94.97195434570312

--------------------------------------------------Iter 36--------------------------------------------------
MFLOPs: 262.315249

Finetuning ... 


                                                        

Accuracy: 94.83173370361328

--------------------------------------------------Iter 37--------------------------------------------------
MFLOPs: 261.444519

Finetuning ... 


                                                        

Accuracy: 95.03205108642578

--------------------------------------------------Iter 38--------------------------------------------------
MFLOPs: 260.204133

Finetuning ... 


                                                        

Accuracy: 94.55128479003906

--------------------------------------------------Iter 39--------------------------------------------------
MFLOPs: 259.264656

Finetuning ... 


                                                        

Accuracy: 94.89183044433594

--------------------------------------------------Iter 40--------------------------------------------------
MFLOPs: 258.410488

Finetuning ... 


                                                        

Accuracy: 94.94190979003906

--------------------------------------------------Iter 41--------------------------------------------------
MFLOPs: 257.391582

Finetuning ... 


                                                        

Accuracy: 95.04206848144531

--------------------------------------------------Iter 42--------------------------------------------------
MFLOPs: 256.317404

Finetuning ... 


                                                        

Accuracy: 93.86017608642578

--------------------------------------------------Iter 43--------------------------------------------------
MFLOPs: 256.317404

Finetuning ... 


                                                        

Accuracy: 95.20232391357422

--------------------------------------------------Iter 44--------------------------------------------------
MFLOPs: 256.317404

Finetuning ... 


                                                        

Accuracy: 95.22235870361328

--------------------------------------------------Iter 45--------------------------------------------------
MFLOPs: 255.822504

Finetuning ... 


                                                        

Accuracy: 95.18229675292969

--------------------------------------------------Iter 46--------------------------------------------------
MFLOPs: 254.628472

Finetuning ... 


                                                        

Accuracy: 95.06210327148438

--------------------------------------------------Iter 47--------------------------------------------------
MFLOPs: 253.200416

Finetuning ... 


                                                        

Accuracy: 94.78164672851562

--------------------------------------------------Iter 48--------------------------------------------------
MFLOPs: 251.970908

Finetuning ... 


                                                        

Accuracy: 94.99198913574219

--------------------------------------------------Iter 49--------------------------------------------------
MFLOPs: 250.79824

Finetuning ... 


                                                        

Accuracy: 94.98197174072266

--------------------------------------------------Iter 50--------------------------------------------------
MFLOPs: 249.336276

Finetuning ... 


                                                        

Accuracy: 94.921875

--------------------------------------------------Iter 51--------------------------------------------------
MFLOPs: 248.215156

Finetuning ... 


                                                        

Accuracy: 94.74159240722656

--------------------------------------------------Iter 52--------------------------------------------------
MFLOPs: 247.065224

Finetuning ... 


                                                        

Accuracy: 94.00039672851562

--------------------------------------------------Iter 53--------------------------------------------------
MFLOPs: 245.468853

Finetuning ... 


                                                        

Accuracy: 94.36097717285156

--------------------------------------------------Iter 54--------------------------------------------------
MFLOPs: 244.131741

Finetuning ... 


                                                        

Accuracy: 93.33934020996094

--------------------------------------------------Iter 55--------------------------------------------------
MFLOPs: 242.801489

Finetuning ... 


                                                        

Accuracy: 94.48117065429688

--------------------------------------------------Iter 56--------------------------------------------------
MFLOPs: 241.244269

Finetuning ... 


                                                        

Accuracy: 93.9102554321289

--------------------------------------------------Iter 57--------------------------------------------------
MFLOPs: 239.926953

Finetuning ... 


                                                        

Accuracy: 94.3409423828125

--------------------------------------------------Iter 58--------------------------------------------------
MFLOPs: 238.541037

Finetuning ... 


                                                        

Accuracy: 94.3709945678711

--------------------------------------------------Iter 59--------------------------------------------------
MFLOPs: 237.307021

Finetuning ... 


                                                        

Accuracy: 94.6915054321289

--------------------------------------------------Iter 60--------------------------------------------------
MFLOPs: 235.753525

Finetuning ... 


                                                        

Accuracy: 94.15064239501953

--------------------------------------------------Iter 61--------------------------------------------------
MFLOPs: 234.075373

Finetuning ... 


                                                        

Accuracy: 93.82011413574219

--------------------------------------------------Iter 62--------------------------------------------------
MFLOPs: 232.828813

Finetuning ... 


                                                        

Accuracy: 94.12059020996094

--------------------------------------------------Iter 63--------------------------------------------------
MFLOPs: 231.425257

Finetuning ... 


                                                        

Accuracy: 94.19070434570312

--------------------------------------------------Iter 64--------------------------------------------------
MFLOPs: 229.873525

Finetuning ... 


                                                        

Accuracy: 93.9803695678711

--------------------------------------------------Iter 65--------------------------------------------------
MFLOPs: 228.530337

Finetuning ... 


                                                        

Accuracy: 94.18069458007812

--------------------------------------------------Iter 66--------------------------------------------------
MFLOPs: 226.835917

Finetuning ... 


                                                        

Accuracy: 92.72836303710938

--------------------------------------------------Iter 67--------------------------------------------------
MFLOPs: 225.866697

Finetuning ... 


                                                        

Accuracy: 92.75841522216797

--------------------------------------------------Iter 68--------------------------------------------------
MFLOPs: 224.497049

Finetuning ... 


                                                        

Accuracy: 93.05889892578125

--------------------------------------------------Iter 69--------------------------------------------------
MFLOPs: 223.216581

Finetuning ... 


                                                        

Accuracy: 94.58132934570312

--------------------------------------------------Iter 70--------------------------------------------------
MFLOPs: 221.959633

Finetuning ... 


                                                        

Accuracy: 93.88021087646484

--------------------------------------------------Iter 71--------------------------------------------------
MFLOPs: 220.691709

Finetuning ... 


                                                        

Accuracy: 94.521240234375

--------------------------------------------------Iter 72--------------------------------------------------
MFLOPs: 220.691709

Finetuning ... 


                                                        

Accuracy: 94.63140869140625

--------------------------------------------------Iter 73--------------------------------------------------
MFLOPs: 219.374393

Finetuning ... 


                                                        

Accuracy: 94.67147827148438

--------------------------------------------------Iter 74--------------------------------------------------
MFLOPs: 218.799574

Finetuning ... 


                                                        

Accuracy: 94.64142608642578

--------------------------------------------------Iter 75--------------------------------------------------
MFLOPs: 217.489363

Finetuning ... 


                                                        

Accuracy: 94.25080108642578

--------------------------------------------------Iter 76--------------------------------------------------
MFLOPs: 216.161512

Finetuning ... 


                                                        

Accuracy: 93.36939239501953

--------------------------------------------------Iter 77--------------------------------------------------
MFLOPs: 214.811366

Finetuning ... 


                                                        

Accuracy: 94.08052825927734

--------------------------------------------------Iter 78--------------------------------------------------
MFLOPs: 213.518795

Finetuning ... 


                                                        

Accuracy: 93.56971740722656

--------------------------------------------------Iter 79--------------------------------------------------
MFLOPs: 212.18428

Finetuning ... 


                                                        

Accuracy: 94.03044891357422

--------------------------------------------------Iter 80--------------------------------------------------
MFLOPs: 210.828548

Finetuning ... 


                                                        

Accuracy: 93.7099380493164

--------------------------------------------------Iter 81--------------------------------------------------
MFLOPs: 209.461105

Finetuning ... 


                                                        

Accuracy: 90.53485870361328

--------------------------------------------------Iter 82--------------------------------------------------
MFLOPs: 208.198277

Finetuning ... 


                                                        

Accuracy: 93.44952392578125

--------------------------------------------------Iter 83--------------------------------------------------
MFLOPs: 206.842398

Finetuning ... 


                                                        

Accuracy: 92.55809020996094

--------------------------------------------------Iter 84--------------------------------------------------
MFLOPs: 204.742111

Finetuning ... 


                                                        

Accuracy: 88.78205108642578

--------------------------------------------------Iter 85--------------------------------------------------
MFLOPs: 203.411467

Finetuning ... 


                                                        

Accuracy: 93.5897445678711

--------------------------------------------------Iter 86--------------------------------------------------
MFLOPs: 203.411467

Finetuning ... 


                                                        

Accuracy: 93.20913696289062

--------------------------------------------------Iter 87--------------------------------------------------
MFLOPs: 201.98052

Finetuning ... 


                                                        

Accuracy: 92.568115234375

--------------------------------------------------Iter 88--------------------------------------------------
MFLOPs: 200.588577

Finetuning ... 


                                                        

Accuracy: 94.08052825927734

--------------------------------------------------Iter 89--------------------------------------------------
MFLOPs: 199.266655

Finetuning ... 


                                                        

Accuracy: 93.66987609863281

--------------------------------------------------Iter 90--------------------------------------------------
MFLOPs: 197.171464

Finetuning ... 


                                                        

Accuracy: 89.833740234375

--------------------------------------------------Iter 91--------------------------------------------------
MFLOPs: 196.000168

Finetuning ... 


                                                        

Accuracy: 92.7684326171875

--------------------------------------------------Iter 92--------------------------------------------------
MFLOPs: 194.753902

Finetuning ... 


                                                        

Accuracy: 93.1690673828125

--------------------------------------------------Iter 93--------------------------------------------------
MFLOPs: 193.476129

Finetuning ... 


                                                        

Accuracy: 90.16426086425781

--------------------------------------------------Iter 94--------------------------------------------------
MFLOPs: 191.975308

Finetuning ... 


                                                        

Accuracy: 91.6065673828125

--------------------------------------------------Iter 95--------------------------------------------------
MFLOPs: 190.831403

Finetuning ... 


                                                        

Accuracy: 91.18589782714844

--------------------------------------------------Iter 96--------------------------------------------------
MFLOPs: 189.540008

Finetuning ... 


                                                        

Accuracy: 83.27323913574219

--------------------------------------------------Iter 97--------------------------------------------------
MFLOPs: 188.445348

Finetuning ... 


                                                        

Accuracy: 92.89863586425781

--------------------------------------------------Iter 98--------------------------------------------------
MFLOPs: 187.255922

Finetuning ... 


                                                        

Accuracy: 91.93710327148438

--------------------------------------------------Iter 99--------------------------------------------------
MFLOPs: 186.107852

Finetuning ... 


                                                        

Accuracy: 92.07732391357422

--------------------------------------------------Iter 100--------------------------------------------------
MFLOPs: 184.793868

Finetuning ... 


                                                        

Accuracy: 91.9971923828125

--------------------------------------------------Iter 101--------------------------------------------------
MFLOPs: 183.528492

Finetuning ... 


                                                        

Accuracy: 91.2459945678711

--------------------------------------------------Iter 102--------------------------------------------------
MFLOPs: 181.908405

Finetuning ... 


                                                        

Accuracy: 91.80689239501953

--------------------------------------------------Iter 103--------------------------------------------------
MFLOPs: 180.917331

Finetuning ... 


                                                        

Accuracy: 92.52804565429688

--------------------------------------------------Iter 104--------------------------------------------------
MFLOPs: 179.950904

Finetuning ... 


                                                        

Accuracy: 92.96875

--------------------------------------------------Iter 105--------------------------------------------------
MFLOPs: 179.083261

Finetuning ... 


                                                        

Accuracy: 88.16105651855469

--------------------------------------------------Iter 106--------------------------------------------------
MFLOPs: 177.981545

Finetuning ... 


                                                        

Accuracy: 91.6366195678711

--------------------------------------------------Iter 107--------------------------------------------------
MFLOPs: 176.721902

Finetuning ... 


                                                        

Accuracy: 92.75841522216797

--------------------------------------------------Iter 108--------------------------------------------------
MFLOPs: 176.721902

Finetuning ... 


                                                        

Accuracy: 92.21754455566406

--------------------------------------------------Iter 109--------------------------------------------------
MFLOPs: 176.198484

Finetuning ... 


                                                        

Accuracy: 93.09896087646484

--------------------------------------------------Iter 110--------------------------------------------------
MFLOPs: 175.18144

Finetuning ... 


                                                        

Accuracy: 93.1991195678711

--------------------------------------------------Iter 111--------------------------------------------------
MFLOPs: 173.974864

Finetuning ... 


                                                        

Accuracy: 88.37139892578125

--------------------------------------------------Iter 112--------------------------------------------------
MFLOPs: 172.998686

Finetuning ... 


                                                        

Accuracy: 84.97596740722656

--------------------------------------------------Iter 113--------------------------------------------------
MFLOPs: 172.190186

Finetuning ... 


                                                        

Accuracy: 92.12740325927734

--------------------------------------------------Iter 114--------------------------------------------------
MFLOPs: 170.843911

Finetuning ... 


                                                        

Accuracy: 91.05569458007812

--------------------------------------------------Iter 115--------------------------------------------------
MFLOPs: 169.879346

Finetuning ... 


                                                        

Accuracy: 86.81890869140625

--------------------------------------------------Iter 116--------------------------------------------------
MFLOPs: 168.928991

Finetuning ... 


                                                        

Accuracy: 89.39302825927734

--------------------------------------------------Iter 117--------------------------------------------------
MFLOPs: 167.942131

Finetuning ... 


                                                        

Accuracy: 91.18589782714844

--------------------------------------------------Iter 118--------------------------------------------------
MFLOPs: 166.852273

Finetuning ... 


                                                        

Accuracy: 92.40785217285156

--------------------------------------------------Iter 119--------------------------------------------------
MFLOPs: 164.654182

Finetuning ... 


                                                        

Accuracy: 92.1875

--------------------------------------------------Iter 120--------------------------------------------------
MFLOPs: 163.669086

Finetuning ... 


                                                        

Accuracy: 90.60496520996094

--------------------------------------------------Iter 121--------------------------------------------------
MFLOPs: 162.433894

Finetuning ... 


                                                        

Accuracy: 91.71675109863281

--------------------------------------------------Iter 122--------------------------------------------------
MFLOPs: 160.53191

Finetuning ... 


                                                        

Accuracy: 91.91706848144531

--------------------------------------------------Iter 123--------------------------------------------------
MFLOPs: 160.53191

Finetuning ... 


                                                        

Accuracy: 91.396240234375

--------------------------------------------------Iter 124--------------------------------------------------
MFLOPs: 160.339732

Finetuning ... 


                                                        

Accuracy: 92.59815979003906

--------------------------------------------------Iter 125--------------------------------------------------
MFLOPs: 158.450047

Finetuning ... 


                                                        

Accuracy: 92.958740234375

--------------------------------------------------Iter 126--------------------------------------------------
MFLOPs: 157.064572

Finetuning ... 


                                                        

Accuracy: 85.94751739501953

--------------------------------------------------Iter 127--------------------------------------------------
MFLOPs: 154.512701

Finetuning ... 


                                                        

Accuracy: 91.1758804321289

--------------------------------------------------Iter 128--------------------------------------------------
MFLOPs: 153.151481

Finetuning ... 


                                                        

Accuracy: 90.1943130493164

--------------------------------------------------Iter 129--------------------------------------------------
MFLOPs: 152.094159

Finetuning ... 


                                                        

Accuracy: 74.599365234375

--------------------------------------------------Iter 130--------------------------------------------------
MFLOPs: 150.258227

Finetuning ... 


                                                        

Accuracy: 90.47476196289062

--------------------------------------------------Iter 131--------------------------------------------------
MFLOPs: 148.734229

Finetuning ... 


                                                        

Accuracy: 91.34615325927734

--------------------------------------------------Iter 132--------------------------------------------------
MFLOPs: 147.510062

Finetuning ... 


                                                        

Accuracy: 82.45191955566406

--------------------------------------------------Iter 133--------------------------------------------------
MFLOPs: 145.584999

Finetuning ... 


                                                        

Accuracy: 90.77523803710938

--------------------------------------------------Iter 134--------------------------------------------------
MFLOPs: 143.496815

Finetuning ... 


                                                        

Accuracy: 89.52323913574219

--------------------------------------------------Iter 135--------------------------------------------------
MFLOPs: 143.496815

Finetuning ... 


                                                        

Accuracy: 91.82691955566406

--------------------------------------------------Iter 136--------------------------------------------------
MFLOPs: 141.28275

Finetuning ... 


                                                        

Accuracy: 91.43629455566406

--------------------------------------------------Iter 137--------------------------------------------------
MFLOPs: 139.360235

Finetuning ... 


                                                        

Accuracy: 82.31169891357422

--------------------------------------------------Iter 138--------------------------------------------------
MFLOPs: 136.959382

Finetuning ... 


                                                        

Accuracy: 89.50320434570312

--------------------------------------------------Iter 139--------------------------------------------------
MFLOPs: 134.942836

Finetuning ... 


                                                        

Accuracy: 77.34375

--------------------------------------------------Iter 140--------------------------------------------------
MFLOPs: 132.137978

Finetuning ... 


                                                        

Accuracy: 88.75199890136719

--------------------------------------------------Iter 141--------------------------------------------------
MFLOPs: 129.428033

Finetuning ... 


                                                        

Accuracy: 89.01242065429688

--------------------------------------------------Iter 142--------------------------------------------------
MFLOPs: 127.33294

Finetuning ... 


                                                        

Accuracy: 91.16586303710938

--------------------------------------------------Iter 143--------------------------------------------------
MFLOPs: 125.502888

Finetuning ... 


                                                        

Accuracy: 89.95392608642578

--------------------------------------------------Iter 144--------------------------------------------------
MFLOPs: 123.874275

Finetuning ... 


                                                        

Accuracy: 90.87539672851562

--------------------------------------------------Iter 145--------------------------------------------------
MFLOPs: 123.874275

Finetuning ... 


                                                        

Accuracy: 90.64503479003906

--------------------------------------------------Iter 146--------------------------------------------------
MFLOPs: 123.874275

Finetuning ... 


                                                        

Accuracy: 83.80409240722656

--------------------------------------------------Iter 147--------------------------------------------------
MFLOPs: 123.603991

Finetuning ... 


                                                        

Accuracy: 91.12580108642578

--------------------------------------------------Iter 148--------------------------------------------------
MFLOPs: 121.678095

Finetuning ... 


                                                        

Accuracy: 90.80529022216797

--------------------------------------------------Iter 149--------------------------------------------------
MFLOPs: 120.497391

Finetuning ... 


                                                        

Accuracy: 90.0040054321289

--------------------------------------------------Iter 150--------------------------------------------------
MFLOPs: 118.760978

Finetuning ... 


                                                        

Accuracy: 91.41626739501953

--------------------------------------------------Iter 151--------------------------------------------------
MFLOPs: 116.639768

Finetuning ... 


                                                        

Accuracy: 89.91386413574219

--------------------------------------------------Iter 152--------------------------------------------------
MFLOPs: 115.06187

Finetuning ... 


                                                        

Accuracy: 88.91226196289062

--------------------------------------------------Iter 153--------------------------------------------------
MFLOPs: 113.659049

Finetuning ... 


                                                        

Accuracy: 89.19271087646484

--------------------------------------------------Iter 154--------------------------------------------------
MFLOPs: 111.459439

Finetuning ... 


                                                        

Accuracy: 90.24439239501953

--------------------------------------------------Iter 155--------------------------------------------------
MFLOPs: 108.611902

Finetuning ... 


                                                        

Accuracy: 89.09254455566406

--------------------------------------------------Iter 156--------------------------------------------------
MFLOPs: 107.163119

Finetuning ... 


                                                        

Accuracy: 91.40625

--------------------------------------------------Iter 157--------------------------------------------------
MFLOPs: 105.679987

Finetuning ... 


                                                        

Accuracy: 86.63862609863281

--------------------------------------------------Iter 158--------------------------------------------------
MFLOPs: 103.344647

Finetuning ... 


                                                        

Accuracy: 89.52323913574219

--------------------------------------------------Iter 159--------------------------------------------------
MFLOPs: 102.413353

Finetuning ... 


                                                        

Accuracy: 90.89543151855469

--------------------------------------------------Iter 160--------------------------------------------------
MFLOPs: 100.139459

Finetuning ... 


                                                        

Accuracy: 90.99559020996094

--------------------------------------------------Iter 161--------------------------------------------------
MFLOPs: 98.929551

Finetuning ... 


                                                        

Accuracy: 90.98558044433594

--------------------------------------------------Iter 162--------------------------------------------------
MFLOPs: 96.477199

Finetuning ... 


                                                        

Accuracy: 89.79367065429688

--------------------------------------------------Iter 163--------------------------------------------------
MFLOPs: 95.516113

Finetuning ... 


                                                        

Accuracy: 85.58694458007812

--------------------------------------------------Iter 164--------------------------------------------------
MFLOPs: 93.338308

Finetuning ... 


                                                        

Accuracy: 89.35296630859375

--------------------------------------------------Iter 165--------------------------------------------------
MFLOPs: 91.832881

Finetuning ... 


                                                        

Accuracy: 91.41626739501953

--------------------------------------------------Iter 166--------------------------------------------------
MFLOPs: 90.261549

Finetuning ... 


                                                        

Accuracy: 89.99398803710938

--------------------------------------------------Iter 167--------------------------------------------------
MFLOPs: 88.392444

Finetuning ... 


                                                        

Accuracy: 91.25601196289062

--------------------------------------------------Iter 168--------------------------------------------------
MFLOPs: 86.661911

Finetuning ... 


                                                        

Accuracy: 90.98558044433594

--------------------------------------------------Iter 169--------------------------------------------------
MFLOPs: 85.607725

Finetuning ... 


                                                        

Accuracy: 75.24038696289062

--------------------------------------------------Iter 170--------------------------------------------------
MFLOPs: 83.127717

Finetuning ... 


                                                        

Accuracy: 87.6903076171875

--------------------------------------------------Iter 171--------------------------------------------------
MFLOPs: 82.026275

Finetuning ... 


                                                        

Accuracy: 88.69190979003906

--------------------------------------------------Iter 172--------------------------------------------------
MFLOPs: 81.313031

Finetuning ... 


                                                        

Accuracy: 89.95392608642578

--------------------------------------------------Iter 173--------------------------------------------------
MFLOPs: 78.78418

Finetuning ... 


                                                        

Accuracy: 89.01242065429688

--------------------------------------------------Iter 174--------------------------------------------------
MFLOPs: 77.180851

Finetuning ... 


                                                        

Accuracy: 86.81890869140625

--------------------------------------------------Iter 175--------------------------------------------------
MFLOPs: 75.145116

Finetuning ... 


                                                        

Accuracy: 88.22115325927734

--------------------------------------------------Iter 176--------------------------------------------------
MFLOPs: 72.81839

Finetuning ... 


                                                        

Accuracy: 89.76362609863281

--------------------------------------------------Iter 177--------------------------------------------------
MFLOPs: 71.007369

Finetuning ... 


                                                        

Accuracy: 89.07251739501953

--------------------------------------------------Iter 178--------------------------------------------------
MFLOPs: 69.259421

Finetuning ... 


                                                        

Accuracy: 90.45472717285156

--------------------------------------------------Iter 179--------------------------------------------------
MFLOPs: 68.206606

Finetuning ... 


                                                        

Accuracy: 89.87379455566406

--------------------------------------------------Iter 180--------------------------------------------------
MFLOPs: 66.893424

Finetuning ... 


                                                        

Accuracy: 90.15425109863281

--------------------------------------------------Iter 181--------------------------------------------------
MFLOPs: 66.46438

Finetuning ... 


                                                        

Accuracy: 90.68509674072266

--------------------------------------------------Iter 182--------------------------------------------------
MFLOPs: 65.265487

Finetuning ... 


                                                        

Accuracy: 89.43309020996094

--------------------------------------------------Iter 183--------------------------------------------------
MFLOPs: 63.34984

Finetuning ... 


                                                        

Accuracy: 90.10417175292969

--------------------------------------------------Iter 184--------------------------------------------------
MFLOPs: 60.751566

Finetuning ... 


                                                        

Accuracy: 89.0224380493164

--------------------------------------------------Iter 185--------------------------------------------------
MFLOPs: 57.586009

Finetuning ... 


                                                        

Accuracy: 87.07933044433594

--------------------------------------------------Iter 186--------------------------------------------------
MFLOPs: 55.895145

Finetuning ... 


                                                        

Accuracy: 89.6133804321289

--------------------------------------------------Iter 187--------------------------------------------------
MFLOPs: 54.171079

Finetuning ... 


                                                        

Accuracy: 87.16947174072266

--------------------------------------------------Iter 188--------------------------------------------------
MFLOPs: 52.082041

Finetuning ... 


                                                        

Accuracy: 89.14262390136719

--------------------------------------------------Iter 189--------------------------------------------------
MFLOPs: 51.429791

Finetuning ... 


                                                        

Accuracy: 90.71514892578125

--------------------------------------------------Iter 190--------------------------------------------------
MFLOPs: 50.279957

Finetuning ... 


                                                        

Accuracy: 89.84375

--------------------------------------------------Iter 191--------------------------------------------------
MFLOPs: 49.564457

Finetuning ... 


                                                        

Accuracy: 90.31449890136719

--------------------------------------------------Iter 192--------------------------------------------------
MFLOPs: 48.755358

Finetuning ... 


                                                        

Accuracy: 89.17267608642578

--------------------------------------------------Iter 193--------------------------------------------------
MFLOPs: 47.821191

Finetuning ... 


                                                        

Accuracy: 87.61017608642578

--------------------------------------------------Iter 194--------------------------------------------------
MFLOPs: 47.000078

Finetuning ... 


                                                        

Accuracy: 89.0224380493164

--------------------------------------------------Iter 195--------------------------------------------------
MFLOPs: 46.032052

Finetuning ... 


                                                        

Accuracy: 90.224365234375

--------------------------------------------------Iter 196--------------------------------------------------
MFLOPs: 44.467766

Finetuning ... 


                                                        

Accuracy: 88.79206848144531




In [9]:
final_fintune_epochs = 100

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, final_fintune_epochs)

best_acc = 0.
for e in range(1, final_fintune_epochs + 1):
    print(f"Epoch {e}:")
    finetune(model, dataloader['train'], criterion, optimizer, device=device)
    acc = evaluate(model, dataloader['test'], device=device)
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), f"./save/best_DepGraph_model_{best_acc :.4f}.pth")
    
    print(f"Accuracy = {acc} / Best = {best_acc}")

Epoch 1:


                                                        

Accuracy = 92.46794891357422 / Best = 92.46794891357422
Epoch 2:


                                                        

Accuracy = 92.54808044433594 / Best = 92.54808044433594
Epoch 3:


                                                        

Accuracy = 92.8084945678711 / Best = 92.8084945678711
Epoch 4:


                                                        

Accuracy = 92.46794891357422 / Best = 92.8084945678711
Epoch 5:


                                                        

Accuracy = 92.70833587646484 / Best = 92.8084945678711
Epoch 6:


                                                        

Accuracy = 92.65824890136719 / Best = 92.8084945678711
Epoch 7:


                                                        

Accuracy = 92.66827392578125 / Best = 92.8084945678711
Epoch 8:


                                                        

Accuracy = 92.84855651855469 / Best = 92.84855651855469
Epoch 9:


                                                        

Accuracy = 92.94871520996094 / Best = 92.94871520996094
Epoch 10:


                                                        

Accuracy = 92.8084945678711 / Best = 92.94871520996094
Epoch 11:


                                                        

Accuracy = 92.87860870361328 / Best = 92.94871520996094
Epoch 12:


                                                        

Accuracy = 92.93870544433594 / Best = 92.94871520996094
Epoch 13:


                                                        

Accuracy = 92.70833587646484 / Best = 92.94871520996094
Epoch 14:


                                                        

Accuracy = 92.9286880493164 / Best = 92.94871520996094
Epoch 15:


                                                        

Accuracy = 92.78846740722656 / Best = 92.94871520996094
Epoch 16:


                                                        

Accuracy = 92.67828369140625 / Best = 92.94871520996094
Epoch 17:


                                                        

Accuracy = 92.82852172851562 / Best = 92.94871520996094
Epoch 18:


                                                        

Accuracy = 93.02884674072266 / Best = 93.02884674072266
Epoch 19:


                                                        

Accuracy = 92.98878479003906 / Best = 93.02884674072266
Epoch 20:


                                                        

Accuracy = 92.93870544433594 / Best = 93.02884674072266
Epoch 21:


                                                        

Accuracy = 92.85857391357422 / Best = 93.02884674072266
Epoch 22:


                                                        

Accuracy = 92.88862609863281 / Best = 93.02884674072266
Epoch 23:


                                                        

Accuracy = 92.93870544433594 / Best = 93.02884674072266
Epoch 24:


                                                        

Accuracy = 92.88862609863281 / Best = 93.02884674072266
Epoch 25:


                                                        

Accuracy = 92.85857391357422 / Best = 93.02884674072266
Epoch 26:


                                                        

Accuracy = 92.958740234375 / Best = 93.02884674072266
Epoch 27:


                                                        

Accuracy = 92.958740234375 / Best = 93.02884674072266
Epoch 28:


                                                        

Accuracy = 92.97876739501953 / Best = 93.02884674072266
Epoch 29:


                                                        

Accuracy = 92.86859130859375 / Best = 93.02884674072266
Epoch 30:


                                                        

Accuracy = 92.84855651855469 / Best = 93.02884674072266
Epoch 31:


                                                        

Accuracy = 92.87860870361328 / Best = 93.02884674072266
Epoch 32:


                                                        

Accuracy = 92.96875 / Best = 93.02884674072266
Epoch 33:


                                                        

Accuracy = 92.96875 / Best = 93.02884674072266
Epoch 34:


                                                        

Accuracy = 92.94871520996094 / Best = 93.02884674072266
Epoch 35:


                                                        

Accuracy = 92.98878479003906 / Best = 93.02884674072266
Epoch 36:


                                                        

Accuracy = 92.81851196289062 / Best = 93.02884674072266
Epoch 37:


                                                        

Accuracy = 93.00881958007812 / Best = 93.02884674072266
Epoch 38:


                                                        

Accuracy = 92.97876739501953 / Best = 93.02884674072266
Epoch 39:


                                                        

Accuracy = 93.10897827148438 / Best = 93.10897827148438
Epoch 40:


                                                        

Accuracy = 92.97876739501953 / Best = 93.10897827148438
Epoch 41:


                                                        

Accuracy = 92.99879455566406 / Best = 93.10897827148438
Epoch 42:


                                                        

Accuracy = 92.97876739501953 / Best = 93.10897827148438
Epoch 43:


                                                        

Accuracy = 92.96875 / Best = 93.10897827148438
Epoch 44:


                                                        

Accuracy = 93.02884674072266 / Best = 93.10897827148438
Epoch 45:


                                                        

Accuracy = 92.958740234375 / Best = 93.10897827148438
Epoch 46:


                                                        

Accuracy = 93.1690673828125 / Best = 93.1690673828125
Epoch 47:


                                                        

Accuracy = 92.90865325927734 / Best = 93.1690673828125
Epoch 48:


                                                        

Accuracy = 93.06890869140625 / Best = 93.1690673828125
Epoch 49:


                                                        

Accuracy = 93.14904022216797 / Best = 93.1690673828125
Epoch 50:


                                                        

Accuracy = 93.03886413574219 / Best = 93.1690673828125
Epoch 51:


                                                        

Accuracy = 93.07892608642578 / Best = 93.1690673828125
Epoch 52:


                                                        

Accuracy = 93.04887390136719 / Best = 93.1690673828125
Epoch 53:


                                                        

Accuracy = 93.08894348144531 / Best = 93.1690673828125
Epoch 54:


                                                        

Accuracy = 93.04887390136719 / Best = 93.1690673828125
Epoch 55:


                                                        

Accuracy = 93.06890869140625 / Best = 93.1690673828125
Epoch 56:


                                                        

Accuracy = 93.10897827148438 / Best = 93.1690673828125
Epoch 57:


                                                        

Accuracy = 93.00881958007812 / Best = 93.1690673828125
Epoch 58:


                                                        

Accuracy = 93.1290054321289 / Best = 93.1690673828125
Epoch 59:


                                                        

Accuracy = 92.99879455566406 / Best = 93.1690673828125
Epoch 60:


                                                        

Accuracy = 92.98878479003906 / Best = 93.1690673828125
Epoch 61:


                                                        

Accuracy = 93.1290054321289 / Best = 93.1690673828125
Epoch 62:


                                                        

Accuracy = 93.08894348144531 / Best = 93.1690673828125
Epoch 63:


                                                        

Accuracy = 92.99879455566406 / Best = 93.1690673828125
Epoch 64:


                                                        

Accuracy = 93.02884674072266 / Best = 93.1690673828125
Epoch 65:


                                                        

Accuracy = 93.14904022216797 / Best = 93.1690673828125
Epoch 66:


                                                        

Accuracy = 93.1690673828125 / Best = 93.1690673828125
Epoch 67:


                                                        

Accuracy = 92.89863586425781 / Best = 93.1690673828125
Epoch 68:


                                                        

Accuracy = 93.06890869140625 / Best = 93.1690673828125
Epoch 69:


                                                        

Accuracy = 93.1590576171875 / Best = 93.1690673828125
Epoch 70:


                                                        

Accuracy = 93.10897827148438 / Best = 93.1690673828125
Epoch 71:


                                                        

Accuracy = 93.08894348144531 / Best = 93.1690673828125
Epoch 72:


                                                        

Accuracy = 92.97876739501953 / Best = 93.1690673828125
Epoch 73:


                                                        

Accuracy = 93.03886413574219 / Best = 93.1690673828125
Epoch 74:


                                                        

Accuracy = 93.04887390136719 / Best = 93.1690673828125
Epoch 75:


                                                        

Accuracy = 93.04887390136719 / Best = 93.1690673828125
Epoch 76:


                                                        

Accuracy = 93.1290054321289 / Best = 93.1690673828125
Epoch 77:


                                                        

Accuracy = 93.14904022216797 / Best = 93.1690673828125
Epoch 78:


                                                        

Accuracy = 93.09896087646484 / Best = 93.1690673828125
Epoch 79:


                                                        

Accuracy = 93.08894348144531 / Best = 93.1690673828125
Epoch 80:


                                                        

Accuracy = 93.17909240722656 / Best = 93.17909240722656
Epoch 81:


                                                        

Accuracy = 93.13902282714844 / Best = 93.17909240722656
Epoch 82:


                                                        

Accuracy = 93.10897827148438 / Best = 93.17909240722656
Epoch 83:


                                                        

Accuracy = 93.05889892578125 / Best = 93.17909240722656
Epoch 84:


                                                        

Accuracy = 93.1991195678711 / Best = 93.1991195678711
Epoch 85:


                                                        

Accuracy = 93.09896087646484 / Best = 93.1991195678711
Epoch 86:


                                                        

Accuracy = 93.10897827148438 / Best = 93.1991195678711
Epoch 87:


                                                        

Accuracy = 93.05889892578125 / Best = 93.1991195678711
Epoch 88:


                                                        

Accuracy = 93.04887390136719 / Best = 93.1991195678711
Epoch 89:


                                                        

Accuracy = 93.11898803710938 / Best = 93.1991195678711
Epoch 90:


                                                        

Accuracy = 93.1991195678711 / Best = 93.1991195678711
Epoch 91:


                                                        

Accuracy = 93.24919891357422 / Best = 93.24919891357422
Epoch 92:


                                                        

Accuracy = 93.06890869140625 / Best = 93.24919891357422
Epoch 93:


                                                        

Accuracy = 93.01882934570312 / Best = 93.24919891357422
Epoch 94:


                                                        

Accuracy = 93.09896087646484 / Best = 93.24919891357422
Epoch 95:


                                                        

Accuracy = 93.05889892578125 / Best = 93.24919891357422
Epoch 96:


                                                        

Accuracy = 93.08894348144531 / Best = 93.24919891357422
Epoch 97:


                                                        

Accuracy = 93.11898803710938 / Best = 93.24919891357422
Epoch 98:


                                                        

Accuracy = 92.98878479003906 / Best = 93.24919891357422
Epoch 99:


                                                        

Accuracy = 93.14904022216797 / Best = 93.24919891357422
Epoch 100:


                                                        

Accuracy = 93.1290054321289 / Best = 93.24919891357422


In [16]:
device = torch.device("cuda:0")

# Load last best model:
last_best_model_file = "save/best_DepGraph_2_model_93.1991.pth"
last_best_model = torch.load(last_best_model_file, map_location="cpu")
last_best_model.to(device)
last_best_model.eval()

# Importance criterion
imp = tp.importance.GroupNormImportance(p=2, normalizer="max") # or GroupTaylorImportance(), GroupHessianImportance(), etc.

# Initialize a pruner with the model and the importance criterion
example_inputs = torch.randn(1, 3, 224, 224).to(device)

ignored_layers = []
for m in last_best_model.modules():
  if isinstance(m, torch.nn.Linear) and m.out_features == 10:
    ignored_layers.append(m)

pruner = tp.pruner.GroupNormPruner(
    last_best_model,
    example_inputs,
    importance        = imp,
    pruning_ratio     = 0.2,
    iterative_steps   = 10,
    ignored_layers    = ignored_layers,
    global_pruning    = True
)

In [17]:
# Parameters:
epochs_per_prune = 1
finetune_lr      = 0.01
target_MFLOPs    = 27.

# Model size before pruning
base_macs, base_nparams = tp.utils.count_ops_and_params(last_best_model, torch.randn(1, 3, 224, 224, device=device))
base_MFLOPs = base_macs / 1e6
print(f"[Original MFLOPs = {base_MFLOPs}]")

for i in range(1, pruner.iterative_steps + 1):
  
  last_best_model.eval()

  if isinstance(imp, tp.importance.GroupTaylorImportance):
    # Taylor expansion requires gradients for importance estimation
    loss = last_best_model(example_inputs).sum() # A dummy loss, please replace this line with your loss function and data!
    loss.backward() # before pruner.step()

  # prune
  pruner.step()

  # Parameter & MACs Counter
  pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(last_best_model, torch.randn(1, 3, 224, 224, device=device))
  MFLOPs = pruned_macs / 1e6
  print("\n" + "-" * 50 + f"Iter {i}" + "-" * 50)
  print(f"MFLOPs: {MFLOPs}")

  # Finetune the pruned model here
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(last_best_model.parameters(), lr=finetune_lr, momentum=0.9, weight_decay=1e-5)
  print(f"\nFinetuning ... ")
  finetune(last_best_model, dataloader['train'], criterion, optimizer, device=device)
  acc = evaluate(last_best_model, dataloader['test'], device=device)
  print(f"Accuracy: {acc}")

  # Save ckpt
  torch.save(last_best_model, f"save/Prune{i}_{acc}.pth")

  # Stopping
  if MFLOPs <= target_MFLOPs:
    break

[Original MFLOPs = 32.86549]

--------------------------------------------------Iter 1--------------------------------------------------
MFLOPs: 31.66037

Finetuning ... 


                                                        

Accuracy: 92.96875

--------------------------------------------------Iter 2--------------------------------------------------
MFLOPs: 31.461372

Finetuning ... 


                                                        

Accuracy: 92.44792175292969

--------------------------------------------------Iter 3--------------------------------------------------
MFLOPs: 31.214874

Finetuning ... 


                                                        

Accuracy: 92.62820434570312

--------------------------------------------------Iter 4--------------------------------------------------
MFLOPs: 30.456338

Finetuning ... 


                                                        

Accuracy: 92.11738586425781

--------------------------------------------------Iter 5--------------------------------------------------
MFLOPs: 29.62826

Finetuning ... 


                                                        

Accuracy: 88.42147827148438

--------------------------------------------------Iter 6--------------------------------------------------
MFLOPs: 28.835111

Finetuning ... 


                                                        

Accuracy: 88.57171630859375

--------------------------------------------------Iter 7--------------------------------------------------
MFLOPs: 28.307119

Finetuning ... 


                                                        

Accuracy: 90.93550109863281

--------------------------------------------------Iter 8--------------------------------------------------
MFLOPs: 27.93037

Finetuning ... 


                                                        

Accuracy: 89.92387390136719

--------------------------------------------------Iter 9--------------------------------------------------
MFLOPs: 27.47286

Finetuning ... 


                                                        

Accuracy: 89.40304565429688

--------------------------------------------------Iter 10--------------------------------------------------
MFLOPs: 26.326102

Finetuning ... 


                                                        

Accuracy: 89.43309020996094


In [18]:
final_fintune_epochs = 100

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(last_best_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, final_fintune_epochs)

best_acc = 0.
for e in range(1, final_fintune_epochs + 1):
    print(f"Epoch {e}:")
    finetune(last_best_model, dataloader['train'], criterion, optimizer, device=device)
    acc = evaluate(last_best_model, dataloader['test'], device=device)
    if acc > best_acc:
        best_acc = acc
        torch.save(last_best_model, f"./save/best_next_DepGraph_model_{best_acc :.4f}.pth")
    
    print(f"Accuracy = {acc} / Best = {best_acc}")

Epoch 1:


                                                        

Accuracy = 88.98237609863281 / Best = 88.98237609863281
Epoch 2:


                                                        

Accuracy = 89.46314239501953 / Best = 89.46314239501953
Epoch 3:


                                                        

Accuracy = 89.20272827148438 / Best = 89.46314239501953
Epoch 4:


                                                        

Accuracy = 89.11257934570312 / Best = 89.46314239501953
Epoch 5:


                                                        

Accuracy = 90.33453369140625 / Best = 90.33453369140625
Epoch 6:


                                                        

Accuracy = 89.97396087646484 / Best = 90.33453369140625
Epoch 7:


                                                        

Accuracy = 90.71514892578125 / Best = 90.71514892578125
Epoch 8:


                                                        

Accuracy = 90.33453369140625 / Best = 90.71514892578125
Epoch 9:


                                                        

Accuracy = 89.2528076171875 / Best = 90.71514892578125
Epoch 10:


                                                        

Accuracy = 89.99398803710938 / Best = 90.71514892578125
Epoch 11:


                                                        

Accuracy = 89.70352172851562 / Best = 90.71514892578125
Epoch 12:


                                                        

Accuracy = 89.93389892578125 / Best = 90.71514892578125
Epoch 13:


                                                        

Accuracy = 89.72355651855469 / Best = 90.71514892578125
Epoch 14:


                                                        

Accuracy = 90.25440979003906 / Best = 90.71514892578125
Epoch 15:


                                                        

Accuracy = 90.27444458007812 / Best = 90.71514892578125
Epoch 16:


                                                        

Accuracy = 90.77523803710938 / Best = 90.77523803710938
Epoch 17:


                                                        

Accuracy = 90.69511413574219 / Best = 90.77523803710938
Epoch 18:


                                                        

Accuracy = 90.59495544433594 / Best = 90.77523803710938
Epoch 19:


                                                        

Accuracy = 90.51482391357422 / Best = 90.77523803710938
Epoch 20:


                                                        

Accuracy = 90.66506958007812 / Best = 90.77523803710938
Epoch 21:


                                                        

Accuracy = 90.44471740722656 / Best = 90.77523803710938
Epoch 22:


                                                        

Accuracy = 90.06410217285156 / Best = 90.77523803710938
Epoch 23:


                                                        

Accuracy = 90.37460327148438 / Best = 90.77523803710938
Epoch 24:


                                                        

Accuracy = 90.44471740722656 / Best = 90.77523803710938
Epoch 25:


                                                        

Accuracy = 90.38461303710938 / Best = 90.77523803710938
Epoch 26:


                                                        

Accuracy = 90.26441955566406 / Best = 90.77523803710938
Epoch 27:


                                                        

Accuracy = 90.70512390136719 / Best = 90.77523803710938
Epoch 28:


                                                        

Accuracy = 90.38461303710938 / Best = 90.77523803710938
Epoch 29:


                                                        

Accuracy = 90.31449890136719 / Best = 90.77523803710938
Epoch 30:


                                                        

Accuracy = 90.90544891357422 / Best = 90.90544891357422
Epoch 31:


                                                        

Accuracy = 90.8153076171875 / Best = 90.90544891357422
Epoch 32:


                                                        

Accuracy = 90.52484130859375 / Best = 90.90544891357422
Epoch 33:


                                                        

Accuracy = 90.76522827148438 / Best = 90.90544891357422
Epoch 34:


                                                        

Accuracy = 90.98558044433594 / Best = 90.98558044433594
Epoch 35:


                                                        

Accuracy = 90.88542175292969 / Best = 90.98558044433594
Epoch 36:


                                                        

Accuracy = 90.48477172851562 / Best = 90.98558044433594
Epoch 37:


                                                        

Accuracy = 90.63501739501953 / Best = 90.98558044433594
Epoch 38:


                                                        

Accuracy = 91.05569458007812 / Best = 91.05569458007812
Epoch 39:


                                                        

Accuracy = 90.86538696289062 / Best = 91.05569458007812
Epoch 40:


                                                        

Accuracy = 90.9755630493164 / Best = 91.05569458007812
Epoch 41:


                                                        

Accuracy = 91.43629455566406 / Best = 91.43629455566406
Epoch 42:


                                                        

Accuracy = 91.11578369140625 / Best = 91.43629455566406
Epoch 43:


                                                        

Accuracy = 91.27604675292969 / Best = 91.43629455566406
Epoch 44:


                                                        

Accuracy = 91.74679565429688 / Best = 91.74679565429688
Epoch 45:


                                                        

Accuracy = 90.55488586425781 / Best = 91.74679565429688
Epoch 46:


                                                        

Accuracy = 90.98558044433594 / Best = 91.74679565429688
Epoch 47:


                                                        

Accuracy = 90.65504455566406 / Best = 91.74679565429688
Epoch 48:


                                                        

Accuracy = 91.03565979003906 / Best = 91.74679565429688
Epoch 49:


                                                        

Accuracy = 90.69511413574219 / Best = 91.74679565429688
Epoch 50:


                                                        

Accuracy = 90.87539672851562 / Best = 91.74679565429688
Epoch 51:


                                                        

Accuracy = 90.37460327148438 / Best = 91.74679565429688
Epoch 52:


                                                        

Accuracy = 90.79527282714844 / Best = 91.74679565429688
Epoch 53:


                                                        

Accuracy = 91.15585327148438 / Best = 91.74679565429688
Epoch 54:


                                                        

Accuracy = 90.12419891357422 / Best = 91.74679565429688
Epoch 55:


                                                        

Accuracy = 90.24439239501953 / Best = 91.74679565429688
Epoch 56:


                                                        

Accuracy = 90.69511413574219 / Best = 91.74679565429688
Epoch 57:


                                                        

Accuracy = 91.09574890136719 / Best = 91.74679565429688
Epoch 58:


                                                        

Accuracy = 90.31449890136719 / Best = 91.74679565429688
Epoch 59:


                                                        

Accuracy = 90.59495544433594 / Best = 91.74679565429688
Epoch 60:


                                                        

Accuracy = 91.03565979003906 / Best = 91.74679565429688
Epoch 61:


                                                        

Accuracy = 90.80529022216797 / Best = 91.74679565429688
Epoch 62:


                                                        

Accuracy = 91.2059326171875 / Best = 91.74679565429688
Epoch 63:


                                                        

Accuracy = 91.22596740722656 / Best = 91.74679565429688
Epoch 64:


                                                        

Accuracy = 91.45632934570312 / Best = 91.74679565429688
Epoch 65:


                                                        

Accuracy = 91.07572174072266 / Best = 91.74679565429688
Epoch 66:


                                                        

Accuracy = 91.51642608642578 / Best = 91.74679565429688
Epoch 67:


                                                        

Accuracy = 91.71675109863281 / Best = 91.74679565429688
Epoch 68:


                                                        

Accuracy = 91.6366195678711 / Best = 91.74679565429688
Epoch 69:


                                                        

Accuracy = 91.80689239501953 / Best = 91.80689239501953
Epoch 70:


                                                        

Accuracy = 91.7568130493164 / Best = 91.80689239501953
Epoch 71:


                                                        

Accuracy = 91.81690979003906 / Best = 91.81690979003906
Epoch 72:


                                                        

Accuracy = 91.796875 / Best = 91.81690979003906
Epoch 73:


                                                        

Accuracy = 91.96714782714844 / Best = 91.96714782714844
Epoch 74:


                                                        

Accuracy = 91.50640869140625 / Best = 91.96714782714844
Epoch 75:


                                                        

Accuracy = 91.88702392578125 / Best = 91.96714782714844
Epoch 76:


                                                        

Accuracy = 91.7568130493164 / Best = 91.96714782714844
Epoch 77:


                                                        

Accuracy = 91.83694458007812 / Best = 91.96714782714844
Epoch 78:


                                                        

Accuracy = 91.91706848144531 / Best = 91.96714782714844
Epoch 79:


                                                        

Accuracy = 91.92708587646484 / Best = 91.96714782714844
Epoch 80:


                                                        

Accuracy = 92.0272445678711 / Best = 92.0272445678711
Epoch 81:


                                                        

Accuracy = 91.796875 / Best = 92.0272445678711
Epoch 82:


                                                        

Accuracy = 91.85697174072266 / Best = 92.0272445678711
Epoch 83:


                                                        

Accuracy = 91.74679565429688 / Best = 92.0272445678711
Epoch 84:


                                                        

Accuracy = 91.9571304321289 / Best = 92.0272445678711
Epoch 85:


                                                        

Accuracy = 91.93710327148438 / Best = 92.0272445678711
Epoch 86:


                                                        

Accuracy = 91.94711303710938 / Best = 92.0272445678711
Epoch 87:


                                                        

Accuracy = 92.05729675292969 / Best = 92.05729675292969
Epoch 88:


                                                        

Accuracy = 91.9871826171875 / Best = 92.05729675292969
Epoch 89:


                                                        

Accuracy = 92.01722717285156 / Best = 92.05729675292969
Epoch 90:


                                                        

Accuracy = 91.82691955566406 / Best = 92.05729675292969
Epoch 91:


                                                        

Accuracy = 91.90705108642578 / Best = 92.05729675292969
Epoch 92:


                                                        

Accuracy = 92.09735870361328 / Best = 92.09735870361328
Epoch 93:


                                                        

Accuracy = 92.1875 / Best = 92.1875
Epoch 94:


                                                        

Accuracy = 92.11738586425781 / Best = 92.1875
Epoch 95:


                                                        

Accuracy = 92.15745544433594 / Best = 92.1875
Epoch 96:


                                                        

Accuracy = 92.24759674072266 / Best = 92.24759674072266
Epoch 97:


                                                        

Accuracy = 92.1875 / Best = 92.24759674072266
Epoch 98:


                                                        

Accuracy = 92.24759674072266 / Best = 92.24759674072266
Epoch 99:


                                                        

Accuracy = 92.24759674072266 / Best = 92.24759674072266
Epoch 100:


                                                        

Accuracy = 92.31771087646484 / Best = 92.31771087646484
