### Imports

In [1]:
import torch
import torch.nn as nn
import torch.functional as F
from data.dataset import (
    cifar10_trainloader,
    ciaf10_testloader,
    cifar100_trainloader,
    ciaf100_testloader,
)

from pruning.GraSP import saliency_scores, rank_by_saliency, apply_mask

  import pynvml  # type: ignore[import]


### Model and Dataloading

In [2]:
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_vgg16_bn", pretrained=False)
model100 = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_vgg16_bn", pretrained=False)
CEloss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
epochs = 5

Using cache found in C:\Users\Fatim_Sproj/.cache\torch\hub\chenyaofo_pytorch-cifar-models_master
Using cache found in C:\Users\Fatim_Sproj/.cache\torch\hub\chenyaofo_pytorch-cifar-models_master


In [9]:
train10 = cifar10_trainloader()
test10= ciaf10_testloader()
train100 = cifar100_trainloader(batch_size=256)
test100 =  ciaf100_testloader(batch_size=256)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


### Training till 20 percent

In [4]:

from torch.utils.data import Subset
import random

random.seed(42)
indices = random.sample(range(len(train10.dataset)), 500)
subset_500 = Subset(train10.dataset, indices)
subset_loader = torch.utils.data.DataLoader(subset_500, batch_size=64, shuffle=True)


In [5]:
def train_until(model,loss_fn, target_acc, train_loader, test_loader, device,epochs=50):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for (x, y) in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for (x_val, y_val) in test_loader:
                x_val, y_val = x_val.to(device), y_val.to(device)
                preds = model(x_val)
                predicted = preds.argmax(dim=1)
                correct += (predicted == y_val).sum().item()
                total += y_val.size(0)
        val_acc = 100 * correct / total
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f} | Val Acc: {val_acc:.2f}%")
        if val_acc >= target_acc:
            print(f"Stopping early at epoch {epoch+1} (val acc = {val_acc:.2f}%)")
            break
    return model


In [6]:
def train_until_masked(model, loss_fn, optimizer, target_acc, train_loader, test_loader, device, mask_dict, epochs=50):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for (x, y) in train_loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in mask_dict:
                        param.mul_(mask_dict[name]) 

            running_loss += loss.item()

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for (x_val, y_val) in test_loader:
                x_val, y_val = x_val.to(device), y_val.to(device)
                preds = model(x_val)
                predicted = preds.argmax(dim=1)
                correct += (predicted == y_val).sum().item()
                total += y_val.size(0)
        val_acc = 100 * correct / total
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f} | Val Acc: {val_acc:.2f}%")
        if val_acc >= target_acc:
            print(f"Stopping early at epoch {epoch+1} (val acc = {val_acc:.2f}%)")
            break
    return model


### Grasp

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

initial_target_acc = 20
final_target_sparsity = 0.8
stage_fractions = [0.5, 0.75, 1.0] 
target_accuracies = [40, 60]      
epochs_per_stage = 50

CEloss = torch.nn.CrossEntropyLoss()

print("\n[Stage 0] Training randomly initialized model to ~20% accuracy")
model = train_until(
    model=model,
    loss_fn=CEloss,
    target_acc=initial_target_acc,
    train_loader=train10,
    test_loader=test10,
    device=device,
    epochs=epochs_per_stage
)
torch.save(model.state_dict(), "stage0_trained_model.pt")
print("Saved Stage 0 trained model.\n")

current_mask = None
current_sparsity = 0.0

for stage_idx, fraction in enumerate(stage_fractions):
    target_sparsity = fraction * final_target_sparsity
    print(f"\n=== Pruning Stage {stage_idx + 1} ===")
    print(f"Target sparsity: {target_sparsity*100:.1f}%")
    
    prev_stage_model = f"stage{stage_idx}_trained_model.pt" if stage_idx > 0 else "stage0_trained_model.pt"
    model.load_state_dict(torch.load(prev_stage_model, map_location=device))
    model.to(device)

    print("→ Computing saliency scores...")
    scores = saliency_scores(model, subset_loader, device, CEloss)

    mask, thresh = rank_by_saliency(
        scores=scores,
        current_mask=current_mask,
        current_sparsity=current_sparsity,
        target_sparsity=target_sparsity
    )

    if mask is not None:
        apply_mask(model, mask)
        print(f"→ Applied pruning mask up to {target_sparsity*100:.1f}% sparsity (threshold={thresh}).")

    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.reset_running_stats()

    if stage_idx < len(stage_fractions) - 1:
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
        model = train_until_masked(
            model=model,
            loss_fn=CEloss,
            optimizer=optimizer,
            target_acc=target_accuracies[stage_idx],
            train_loader=train10,
            test_loader=test10,
            device=device,
            mask_dict=mask if mask is not None else current_mask,
            epochs=epochs_per_stage
        )

    torch.save(model.state_dict(), f"stage{stage_idx+1}_trained_model.pt")
    if mask is not None:
        torch.save(mask, f"stage{stage_idx+1}_mask.pt")
        current_mask = mask

    current_sparsity = target_sparsity
    print(f"Stage {stage_idx + 1} complete and saved.")

print("\n=== Final Profiling ===")
model.load_state_dict(torch.load(f"stage{len(stage_fractions)}_trained_model.pt", map_location=device))



[Stage 0] Training randomly initialized model to ~20% accuracy
Epoch [1/50] - Loss: 2.3014 | Val Acc: 17.70%
Epoch [2/50] - Loss: 2.2951 | Val Acc: 23.08%
Stopping early at epoch 2 (val acc = 23.08%)
Saved Stage 0 trained model.


=== Pruning Stage 1 ===
Target sparsity: 40.0%
→ Computing saliency scores...


  model.load_state_dict(torch.load(prev_stage_model, map_location=device))


[rank_by_saliency] Pruned 6101431 / 15253578 (target 6101431). Kept 9152147 params. threshold=6.12037e-07
→ Applied pruning mask up to 40.0% sparsity (threshold=6.120370699136402e-07).
Epoch [1/50] - Loss: 1.5019 | Val Acc: 55.11%
Stopping early at epoch 1 (val acc = 55.11%)
Stage 1 complete and saved.

=== Pruning Stage 2 ===
Target sparsity: 60.0%
→ Computing saliency scores...
[rank_by_saliency] Pruned 9152147 / 15253578 (target 9152147). Kept 6101431 params. threshold=0.000415872
→ Applied pruning mask up to 60.0% sparsity (threshold=0.00041587205487303436).
Epoch [1/50] - Loss: 1.0177 | Val Acc: 70.07%
Stopping early at epoch 1 (val acc = 70.07%)
Stage 2 complete and saved.

=== Pruning Stage 3 ===
Target sparsity: 80.0%
→ Computing saliency scores...
[rank_by_saliency] Pruned 12202862 / 15253578 (target 12202862). Kept 3050716 params. threshold=0.000828134
→ Applied pruning mask up to 80.0% sparsity (threshold=0.0008281336631625891).
Stage 3 complete and saved.

=== Final Profili

  model.load_state_dict(torch.load(f"stage{len(stage_fractions)}_trained_model.pt", map_location=device))


<All keys matched successfully>

In [16]:
import os
import torch
from torch import nn, optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_vgg16_bn", pretrained=False)
modelpath = r"C:\Users\Fatim_Sproj\Desktop\Fatim\Spring 2025\aiedge\Pruning\intermediate_models\stage3_trained_model.pt"
model.load_state_dict(torch.load(modelpath, map_location=device))
model.to(device)

try:
    maskpath = r"C:\Users\Fatim_Sproj\Desktop\Fatim\Spring 2025\aiedge\Pruning\intermediate_models\stage3_mask.pt"
    final_mask = torch.load(maskpath, map_location=device)
    from pruning.GraSP import apply_mask
    apply_mask(model, final_mask)
    print("Applied final pruning mask before fine-tuning.")
except FileNotFoundError:
    final_mask = None
    print("No final mask found. Continuing without reapplying mask.")

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.reset_running_stats()

CEloss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

best_ckpt_path = "best_finetuned_model10.pt"

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    acc = 100.0 * correct / total if total > 0 else 0.0
    return acc

def finetune(model, train_loader, test_loader, loss_fn, optimizer, epochs, device, mask=None, best_ckpt_path=best_ckpt_path):
    best_acc = -1.0
    best_epoch = -1

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()

            if mask is not None:
                apply_mask(model, mask)

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

        test_acc = evaluate(model, test_loader, device)

        print(f"Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f} - Test Acc: {test_acc:.2f}%")

        if test_acc > best_acc:
            best_acc = test_acc
            best_epoch = epoch

            if mask is not None:
                apply_mask(model, mask)

            ckpt = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
            }
            torch.save(ckpt, best_ckpt_path)
            print(f"--> New best model saved (epoch {epoch}, acc {best_acc:.2f}%) to: {best_ckpt_path}")

    print(f"Finished fine-tuning. Best epoch: {best_epoch} with Test Acc: {best_acc:.2f}%")
    return best_epoch, best_acc

best_epoch, best_acc = finetune(
    model=model,
    train_loader=train10,
    test_loader=test10,
    loss_fn=CEloss,
    optimizer=optimizer,
    epochs=50,
    device=device,
    mask=final_mask,
    best_ckpt_path=best_ckpt_path
)

final_state_path = "finetuned_model10.pt"
if final_mask is not None:
    apply_mask(model, final_mask)
torch.save(model.state_dict(), final_state_path)
print(f"Final fine-tuned model saved to: {final_state_path}")

Using cache found in C:\Users\Fatim_Sproj/.cache\torch\hub\chenyaofo_pytorch-cifar-models_master
  model.load_state_dict(torch.load(modelpath, map_location=device))
  final_mask = torch.load(maskpath, map_location=device)


Applied final pruning mask before fine-tuning.
Epoch [1/50] - Train Loss: 0.7103 - Test Acc: 75.71%
--> New best model saved (epoch 1, acc 75.71%) to: best_finetuned_model10.pt
Epoch [2/50] - Train Loss: 0.5746 - Test Acc: 77.90%
--> New best model saved (epoch 2, acc 77.90%) to: best_finetuned_model10.pt
Epoch [3/50] - Train Loss: 0.4755 - Test Acc: 78.16%
--> New best model saved (epoch 3, acc 78.16%) to: best_finetuned_model10.pt
Epoch [4/50] - Train Loss: 0.3821 - Test Acc: 79.31%
--> New best model saved (epoch 4, acc 79.31%) to: best_finetuned_model10.pt
Epoch [5/50] - Train Loss: 0.3118 - Test Acc: 79.90%
--> New best model saved (epoch 5, acc 79.90%) to: best_finetuned_model10.pt
Epoch [6/50] - Train Loss: 0.2452 - Test Acc: 79.91%
--> New best model saved (epoch 6, acc 79.91%) to: best_finetuned_model10.pt
Epoch [7/50] - Train Loss: 0.1961 - Test Acc: 79.73%
Epoch [8/50] - Train Loss: 0.1605 - Test Acc: 79.51%
Epoch [9/50] - Train Loss: 0.1391 - Test Acc: 79.45%
Epoch [10/50] 

### cifar100

In [4]:

from torch.utils.data import Subset
import random

random.seed(42)
indices = random.sample(range(len(train100.dataset)), 500)
subset_500 = Subset(train10.dataset, indices)
subset_loader = torch.utils.data.DataLoader(subset_500, batch_size=64, shuffle=True)


In [5]:
def train_until(model, loss_fn, optimizer, target_acc, train_loader, test_loader, device, epochs=50):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for (x, y) in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for (x_val, y_val) in test_loader:
                x_val, y_val = x_val.to(device), y_val.to(device)
                preds = model(x_val)
                predicted = preds.argmax(dim=1)
                correct += (predicted == y_val).sum().item()
                total += y_val.size(0)
        
        val_acc = 100 * correct / total
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f} | Val Acc: {val_acc:.2f}%")
        if val_acc >= target_acc:
            print(f"Stopping early at epoch {epoch+1} (val acc = {val_acc:.2f}%)")
            break
    return model

def train_until_masked(model, loss_fn, optimizer, target_acc, train_loader, test_loader, device, mask_dict, epochs=50):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for (x, y) in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in mask_dict:
                        param.mul_(mask_dict[name]) 

            running_loss += loss.item()

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for (x_val, y_val) in test_loader:
                x_val, y_val = x_val.to(device), y_val.to(device)
                preds = model(x_val)
                predicted = preds.argmax(dim=1)
                correct += (predicted == y_val).sum().item()
                total += y_val.size(0)
        
        val_acc = 100 * correct / total
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f} | Val Acc: {val_acc:.2f}%")
        if val_acc >= target_acc:
            print(f"Stopping early at epoch {epoch+1} (val acc = {val_acc:.2f}%)")
            break
    return model

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model100 = model100.to(device)

initial_target_acc = 20
final_target_sparsity = 0.8
stage_fractions = [0.5, 0.75, 1.0] 
target_accuracies = [40, 60]      
epochs_per_stage = 50

CEloss = torch.nn.CrossEntropyLoss()

optimizer_100 = torch.optim.Adam(model100.parameters(), lr=3e-4)

print("\n[Stage 0] Training CIFAR-100 model to ~20% accuracy")
model100 = train_until(
    model=model100,
    loss_fn=CEloss,
    optimizer=optimizer_100,
    target_acc=initial_target_acc,
    train_loader=train100, 
    test_loader=test100,   
    device=device,
    epochs=epochs_per_stage
)
torch.save(model100.state_dict(), "stage0_trained_model100.pt")
print("Saved Stage 0 trained model for CIFAR-100.\n")

current_mask = None
current_sparsity = 0.0

for stage_idx, fraction in enumerate(stage_fractions):
    target_sparsity = fraction * final_target_sparsity
    print(f"\n=== Pruning Stage {stage_idx + 1} ===")
    print(f"Target sparsity: {target_sparsity*100:.1f}%")
    
    prev_stage_model = f"stage{stage_idx}_trained_model100.pt" if stage_idx > 0 else "stage0_trained_model100.pt"
    model100.load_state_dict(torch.load(prev_stage_model, map_location=device))
    model100.to(device)

    print("→ Computing saliency scores...")
    scores = saliency_scores(model100, subset_loader, device, CEloss) 

    mask, thresh = rank_by_saliency(
        scores=scores,
        current_mask=current_mask,
        current_sparsity=current_sparsity,
        target_sparsity=target_sparsity
    )

    if mask is not None:
        apply_mask(model100, mask)
        print(f"→ Applied pruning mask up to {target_sparsity*100:.1f}% sparsity (threshold={thresh}).")

    for m in model100.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.reset_running_stats()

    if stage_idx < len(stage_fractions) - 1:
        optimizer_stage = torch.optim.Adam(model100.parameters(), lr=1e-4)
        model100 = train_until_masked(
            model=model100,
            loss_fn=CEloss,
            optimizer=optimizer_stage, 
            target_acc=target_accuracies[stage_idx],
            train_loader=train100,
            test_loader=test100,  
            device=device,
            mask_dict=mask if mask is not None else current_mask,
            epochs=epochs_per_stage
        )
    torch.save(model100.state_dict(), f"stage{stage_idx+1}_trained_model100.pt")
    if mask is not None:
        torch.save(mask, f"stage{stage_idx+1}_mask100.pt")
        current_mask = mask
    current_sparsity = target_sparsity
    print(f"Stage {stage_idx + 1} complete and saved.")
print("\n=== Final Profiling ===\n")
model100.load_state_dict(torch.load(f"stage{len(stage_fractions)}_trained_model100.pt", map_location=device))



[Stage 0] Training CIFAR-100 model to ~20% accuracy
Epoch [1/50] - Loss: 4.3059 | Val Acc: 3.13%
Epoch [2/50] - Loss: 4.0419 | Val Acc: 5.96%
Epoch [3/50] - Loss: 3.8464 | Val Acc: 8.51%
Epoch [4/50] - Loss: 3.5799 | Val Acc: 11.58%
Epoch [5/50] - Loss: 3.2819 | Val Acc: 18.06%
Epoch [6/50] - Loss: 3.0223 | Val Acc: 21.48%
Stopping early at epoch 6 (val acc = 21.48%)
Saved Stage 0 trained model for CIFAR-100.


=== Pruning Stage 1 ===
Target sparsity: 40.0%
→ Computing saliency scores...


  model100.load_state_dict(torch.load(prev_stage_model, map_location=device))


[rank_by_saliency] Pruned 6119899 / 15299748 (target 6119899). Kept 9179849 params. threshold=0.000217016
→ Applied pruning mask up to 40.0% sparsity (threshold=0.00021701646619476378).
Epoch [1/50] - Loss: 2.6827 | Val Acc: 27.75%
Epoch [2/50] - Loss: 2.5425 | Val Acc: 29.68%
Epoch [3/50] - Loss: 2.4388 | Val Acc: 31.33%
Epoch [4/50] - Loss: 2.3319 | Val Acc: 32.72%
Epoch [5/50] - Loss: 2.2289 | Val Acc: 34.07%
Epoch [6/50] - Loss: 2.1276 | Val Acc: 34.49%
Epoch [7/50] - Loss: 2.0274 | Val Acc: 35.49%
Epoch [8/50] - Loss: 1.9306 | Val Acc: 36.35%
Epoch [9/50] - Loss: 1.8410 | Val Acc: 36.67%
Epoch [10/50] - Loss: 1.7530 | Val Acc: 37.42%
Epoch [11/50] - Loss: 1.6629 | Val Acc: 37.33%
Epoch [12/50] - Loss: 1.5765 | Val Acc: 38.21%
Epoch [13/50] - Loss: 1.4927 | Val Acc: 38.02%
Epoch [14/50] - Loss: 1.4070 | Val Acc: 38.09%
Epoch [15/50] - Loss: 1.3287 | Val Acc: 38.03%
Epoch [16/50] - Loss: 1.2641 | Val Acc: 37.63%
Epoch [17/50] - Loss: 1.1885 | Val Acc: 38.58%
Epoch [18/50] - Loss: 1.

  model100.load_state_dict(torch.load(f"stage{len(stage_fractions)}_trained_model100.pt", map_location=device))


<All keys matched successfully>

In [10]:
import os
import torch
from torch import nn, optim


model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_vgg16_bn", pretrained=False)
modelpath = r"C:\Users\Fatim_Sproj\Desktop\Fatim\Spring 2025\aiedge\Pruning\intermediate_models\stage3_trained_model100.pt"
model.load_state_dict(torch.load(modelpath, map_location=device))
model.to(device)

try:
    maskpath = r"C:\Users\Fatim_Sproj\Desktop\Fatim\Spring 2025\aiedge\Pruning\intermediate_models\stage3_mask100.pt"
    final_mask = torch.load(maskpath, map_location=device)
    from pruning.GraSP import apply_mask
    apply_mask(model, final_mask)
    print("Applied final pruning mask before fine-tuning.")
except FileNotFoundError:
    final_mask = None
    print("No final mask found. Continuing without reapplying mask.")

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.reset_running_stats()

CEloss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

best_ckpt_path = "best_finetuned_model100.pt"

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    acc = 100.0 * correct / total if total > 0 else 0.0
    return acc

def finetune(model, train_loader, test_loader, loss_fn, optimizer, epochs, device, mask=None, best_ckpt_path=best_ckpt_path):
    best_acc = -1.0
    best_epoch = -1

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()

            if mask is not None:
                apply_mask(model, mask)

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

        test_acc = evaluate(model, test_loader, device)

        print(f"Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f} - Test Acc: {test_acc:.2f}%")

        if test_acc > best_acc:
            best_acc = test_acc
            best_epoch = epoch

            if mask is not None:
                apply_mask(model, mask)

            ckpt = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
            }
            torch.save(ckpt, best_ckpt_path)
            print(f"--> New best model saved (epoch {epoch}, acc {best_acc:.2f}%) to: {best_ckpt_path}")

    print(f"Finished fine-tuning. Best epoch: {best_epoch} with Test Acc: {best_acc:.2f}%")
    return best_epoch, best_acc

best_epoch, best_acc = finetune(
    model=model,
    train_loader=train100,
    test_loader=test100,
    loss_fn=CEloss,
    optimizer=optimizer,
    epochs=50,
    device=device,
    mask=final_mask,
    best_ckpt_path=best_ckpt_path
)

final_state_path = "finetuned_model100.pt"
if final_mask is not None:
    apply_mask(model, final_mask)
torch.save(model.state_dict(), final_state_path)
print(f"Final fine-tuned model saved to: {final_state_path}")

Using cache found in C:\Users\Fatim_Sproj/.cache\torch\hub\chenyaofo_pytorch-cifar-models_master
  model.load_state_dict(torch.load(modelpath, map_location=device))
  final_mask = torch.load(maskpath, map_location=device)


Applied final pruning mask before fine-tuning.
Epoch [1/50] - Train Loss: 1.3987 - Test Acc: 38.68%
--> New best model saved (epoch 1, acc 38.68%) to: best_finetuned_model100.pt
Epoch [2/50] - Train Loss: 0.9197 - Test Acc: 40.10%
--> New best model saved (epoch 2, acc 40.10%) to: best_finetuned_model100.pt
Epoch [3/50] - Train Loss: 0.7861 - Test Acc: 40.69%
--> New best model saved (epoch 3, acc 40.69%) to: best_finetuned_model100.pt
Epoch [4/50] - Train Loss: 0.7156 - Test Acc: 40.95%
--> New best model saved (epoch 4, acc 40.95%) to: best_finetuned_model100.pt
Epoch [5/50] - Train Loss: 0.6628 - Test Acc: 41.05%
--> New best model saved (epoch 5, acc 41.05%) to: best_finetuned_model100.pt
Epoch [6/50] - Train Loss: 0.6268 - Test Acc: 41.05%
Epoch [7/50] - Train Loss: 0.6013 - Test Acc: 41.40%
--> New best model saved (epoch 7, acc 41.40%) to: best_finetuned_model100.pt
Epoch [8/50] - Train Loss: 0.5744 - Test Acc: 41.30%
Epoch [9/50] - Train Loss: 0.5607 - Test Acc: 41.56%
--> New 