In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import copy
import os
from dataset import get_pacs_dataloaders
from utils import *
from pruning import iterative_pruning

DATA_DIR = r"C:\Users\Fatim_Sproj\Desktop\Fatim\Spring 2025\Datasets\pacs_data\pacs_data" 
ALL_DOMAINS = ['art_painting', 'cartoon', 'photo', 'sketch']
TARGET_DOMAIN = 'sketch'
SOURCE_DOMAINS = [d for d in ALL_DOMAINS if d != TARGET_DOMAIN]

BATCH_SIZE = 256
NUM_WORKERS = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("--- Starting Warmup Phase ---")
source_loader_combined, target_loader, class_to_idx = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=True
)
num_classes = len(class_to_idx)

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
WARMUP_EPOCHS = 5
best_warmup_acc = 0.0
WARMUP_MODEL_PATH = "best_warmup_model.pth"

if os.path.exists(WARMUP_MODEL_PATH):
    print("Warmup model already exists. Loading and skipping warmup...")
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(WARMUP_MODEL_PATH, map_location=DEVICE))
    model.to(DEVICE)
else:
    print("No warmup model found. Running warmup training...")
    source_loader_combined, target_loader, class_to_idx = get_pacs_dataloaders(
        data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
        batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=True
    )

    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    WARMUP_EPOCHS = 5

    for epoch in range(WARMUP_EPOCHS):
        train_vanilla(model, source_loader_combined, optimizer, DEVICE, epoch)
        _, val_acc = evaluate(model, target_loader, DEVICE)
        print(f"  Warmup Epoch {epoch+1} Target Accuracy: {val_acc:.2f}%")

        if val_acc > best_warmup_acc:
            best_warmup_acc = val_acc
            torch.save(model.state_dict(), WARMUP_MODEL_PATH)
            print(f"  New best warmup accuracy: {best_warmup_acc:.2f}%. Checkpoint saved.")

    print(f"\nWarmup finished. Best accuracy: {best_warmup_acc:.2f}%")

print("\n--- Starting Iterative Pruning Phase ---")
source_loaders_list, target_loader, _ = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=False
)

pruning_model = models.resnet18()
pruning_model.fc = nn.Linear(pruning_model.fc.in_features, num_classes)
pruning_model.load_state_dict(torch.load(WARMUP_MODEL_PATH))
pruning_model.to(DEVICE)

PRUNE_RATES = [0.10, 0.10, 0.10]
FINETUNE_EPOCHS = 5
FINETUNE_LR = 1e-4
ALPHA = 1.0

final_model, final_mask = iterative_pruning(
    model=pruning_model,
    source_loaders_list=source_loaders_list,
    target_loader=target_loader,
    device=DEVICE,
    prune_rates=PRUNE_RATES,
    retrain_epochs=FINETUNE_EPOCHS,
    lr=FINETUNE_LR,
    alpha=ALPHA,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    SFT=False
)

# --- 4. Final Evaluation ---
print("\n--- Final Evaluation ---")
baseline_acc = best_warmup_acc
apply_mask(final_model, final_mask)
_, final_acc = evaluate(final_model, target_loader, DEVICE, mask=final_mask)

print("\n--- Pruning Summary ---")
print(f"Baseline Target Accuracy (from best warmup): {baseline_acc:.2f}%")
print(f"Final Target Accuracy (from best pruned model): {final_acc:.2f}%")
improvement = final_acc - baseline_acc
print(f"Improvement: {improvement:+.2f}%")

  import pynvml  # type: ignore[import]


--- Starting Warmup Phase ---
Creating source datasets for: ['art_painting', 'cartoon', 'photo']
  - Domain 'art_painting' (ID 0) loaded with 2038 images.
  - Domain 'cartoon' (ID 1) loaded with 2344 images.
  - Domain 'photo' (ID 2) loaded with 1670 images.
Combined source dataloader created with 6052 total images.
Creating target dataloader for: sketch
  - Domain 'sketch' loaded with 3929 images.
No warmup model found. Running warmup training...
Creating source datasets for: ['art_painting', 'cartoon', 'photo']
  - Domain 'art_painting' (ID 0) loaded with 2038 images.
  - Domain 'cartoon' (ID 1) loaded with 2344 images.
  - Domain 'photo' (ID 2) loaded with 1670 images.
Combined source dataloader created with 6052 total images.
Creating target dataloader for: sketch
  - Domain 'sketch' loaded with 3929 images.


Epoch 1 Vanilla Training:   0%|          | 0/23 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Warmup Epoch 1 Target Accuracy: 65.21%
  New best warmup accuracy: 65.21%. Checkpoint saved.


Epoch 2 Vanilla Training:   0%|          | 0/23 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Warmup Epoch 2 Target Accuracy: 60.78%


Epoch 3 Vanilla Training:   0%|          | 0/23 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Warmup Epoch 3 Target Accuracy: 60.07%


Epoch 4 Vanilla Training:   0%|          | 0/23 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Warmup Epoch 4 Target Accuracy: 62.10%


Epoch 5 Vanilla Training:   0%|          | 0/23 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import copy
import os
from dataset import get_pacs_dataloaders
from utils import *
from pruning import iterative_pruning

DATA_DIR = r"C:\Users\Fatim_Sproj\Desktop\Fatim\Spring 2025\Datasets\pacs_data\pacs_data" 
ALL_DOMAINS = ['art_painting', 'cartoon', 'photo', 'sketch']
TARGET_DOMAIN = 'sketch'
SOURCE_DOMAINS = [d for d in ALL_DOMAINS if d != TARGET_DOMAIN]

BATCH_SIZE = 256
NUM_WORKERS = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("--- Starting Warmup Phase ---")
source_loader_combined, target_loader, class_to_idx = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=True
)
num_classes = len(class_to_idx)

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
WARMUP_EPOCHS = 5
best_warmup_acc = 0.0
WARMUP_MODEL_PATH = "best_warmup_model.pth"

if os.path.exists(WARMUP_MODEL_PATH):
    print("Warmup model already exists. Loading and skipping warmup...")
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(WARMUP_MODEL_PATH, map_location=DEVICE))
    model.to(DEVICE)
else:
    print("No warmup model found. Running warmup training...")
    source_loader_combined, target_loader, class_to_idx = get_pacs_dataloaders(
        data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
        batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=True
    )

    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    WARMUP_EPOCHS = 5

    for epoch in range(WARMUP_EPOCHS):
        train_vanilla(model, source_loader_combined, optimizer, DEVICE, epoch)
        _, val_acc = evaluate(model, target_loader, DEVICE)
        print(f"  Warmup Epoch {epoch+1} Target Accuracy: {val_acc:.2f}%")

        if val_acc > best_warmup_acc:
            best_warmup_acc = val_acc
            torch.save(model.state_dict(), WARMUP_MODEL_PATH)
            print(f"  New best warmup accuracy: {best_warmup_acc:.2f}%. Checkpoint saved.")

    print(f"\nWarmup finished. Best accuracy: {best_warmup_acc:.2f}%")

print("\n--- Starting Iterative Pruning Phase ---")
source_loaders_list, target_loader, _ = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=False
)

pruning_model = models.resnet18()
pruning_model.fc = nn.Linear(pruning_model.fc.in_features, num_classes)
pruning_model.load_state_dict(torch.load(WARMUP_MODEL_PATH))
pruning_model.to(DEVICE)

PRUNE_RATES = [0.10, 0.10, 0.10]
FINETUNE_EPOCHS = 5
FINETUNE_LR = 1e-4
ALPHA = 1.0

final_model, final_mask = iterative_pruning(
    model=pruning_model,
    source_loaders_list=source_loaders_list,
    target_loader=target_loader,
    device=DEVICE,
    prune_rates=PRUNE_RATES,
    retrain_epochs=FINETUNE_EPOCHS,
    lr=FINETUNE_LR,
    alpha=ALPHA,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    SFT=True
)

# --- 4. Final Evaluation ---
print("\n--- Final Evaluation ---")
baseline_acc = best_warmup_acc
apply_mask(final_model, final_mask)
_, final_acc = evaluate(final_model, target_loader, DEVICE, mask=final_mask)

print("\n--- Pruning Summary ---")
print(f"Baseline Target Accuracy (from best warmup): {baseline_acc:.2f}%")
print(f"Final Target Accuracy (from best pruned model): {final_acc:.2f}%")
improvement = final_acc - baseline_acc
print(f"Improvement: {improvement:+.2f}%")

--- Starting Warmup Phase ---
Creating source datasets for: ['art_painting', 'cartoon', 'photo']
  - Domain 'art_painting' (ID 0) loaded with 2048 images.
  - Domain 'cartoon' (ID 1) loaded with 2344 images.
  - Domain 'photo' (ID 2) loaded with 1670 images.
Combined source dataloader created with 6062 total images.
Creating target dataloader for: sketch
  - Domain 'sketch' loaded with 3929 images.
Warmup model already exists. Loading and skipping warmup...

--- Starting Iterative Pruning Phase ---
Creating source datasets for: ['art_painting', 'cartoon', 'photo']
  - Domain 'art_painting' (ID 0) loaded with 2048 images.
  - Domain 'cartoon' (ID 1) loaded with 2344 images.
  - Domain 'photo' (ID 2) loaded with 1670 images.
Created 3 separate source dataloaders.
Creating target dataloader for: sketch
  - Domain 'sketch' loaded with 3929 images.


Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

Initial Baseline Target Accuracy: 64.88%

--- Pruning Iteration 1/3 with base rate 0.1 ---
Computing filter activations per domain...
  - Domain 1/3


  0%|          | 0/8 [00:00<?, ?it/s]

  - Domain 2/3


  0%|          | 0/9 [00:00<?, ?it/s]

  - Domain 3/3


  0%|          | 0/6 [00:00<?, ?it/s]


Generating mask with base iterative prune rate: 0.1
  - Layer 'conv1': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.0.conv1': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.0.conv2': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.1.conv1': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.1.conv2': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer2.0.conv1': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.0.conv2': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.0.downsample.0': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.1.conv1': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.1.conv2': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer3.0.conv1': Pruning 25/256 active filters (rate 0.100).
  - Layer 'layer3.0.conv2': Pruning 25/256 active filters (rate 0.100).
  - Layer 'layer3.0.downsample.0': Pruning 25/256 active filters (rate 0.100).
  - Layer 'layer3.1.c

Epoch 1 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 1 Target Accuracy: 60.52%

Retraining Epoch 2/5


Epoch 2 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 2 Target Accuracy: 67.27%

Retraining Epoch 3/5


Epoch 3 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 3 Target Accuracy: 63.86%

Retraining Epoch 4/5


Epoch 4 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 4 Target Accuracy: 65.82%

Retraining Epoch 5/5


Epoch 5 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 5 Target Accuracy: 65.64%
Iteration 1 | Best Accuracy in this round: 67.27%

--- Pruning Iteration 2/3 with base rate 0.1 ---
Computing filter activations per domain...
  - Domain 1/3


  0%|          | 0/8 [00:00<?, ?it/s]

  - Domain 2/3


  0%|          | 0/9 [00:00<?, ?it/s]

  - Domain 3/3


  0%|          | 0/6 [00:00<?, ?it/s]


Generating mask with base iterative prune rate: 0.1
  - Layer 'conv1': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.0.conv1': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.0.conv2': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.1.conv1': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.1.conv2': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer2.0.conv1': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.0.conv2': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.0.downsample.0': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.1.conv1': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.1.conv2': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer3.0.conv1': Pruning 23/231 active filters (rate 0.100).
  - Layer 'layer3.0.conv2': Pruning 23/231 active filters (rate 0.100).
  - Layer 'layer3.0.downsample.0': Pruning 23/231 active filters (rate 0.100).
  - Layer 'layer3.1.c

Epoch 1 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 1 Target Accuracy: 37.54%

Retraining Epoch 2/5


Epoch 2 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 2 Target Accuracy: 61.21%

Retraining Epoch 3/5


Epoch 3 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 3 Target Accuracy: 59.76%

Retraining Epoch 4/5


Epoch 4 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 4 Target Accuracy: 61.80%

Retraining Epoch 5/5


Epoch 5 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 5 Target Accuracy: 62.69%
Iteration 2 | Best Accuracy in this round: 62.69%

--- Pruning Iteration 3/3 with base rate 0.1 ---
Computing filter activations per domain...
  - Domain 1/3


  0%|          | 0/8 [00:00<?, ?it/s]

  - Domain 2/3


  0%|          | 0/9 [00:00<?, ?it/s]

  - Domain 3/3


  0%|          | 0/6 [00:00<?, ?it/s]


Generating mask with base iterative prune rate: 0.1
  - Layer 'conv1': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.0.conv1': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.0.conv2': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.1.conv1': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.1.conv2': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer2.0.conv1': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.0.conv2': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.0.downsample.0': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.1.conv1': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.1.conv2': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer3.0.conv1': Pruning 20/208 active filters (rate 0.100).
  - Layer 'layer3.0.conv2': Pruning 20/208 active filters (rate 0.100).
  - Layer 'layer3.0.downsample.0': Pruning 20/208 active filters (rate 0.100).
  - Layer 'layer3.1.c

Epoch 1 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 1 Target Accuracy: 60.73%

Retraining Epoch 2/5


Epoch 2 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 2 Target Accuracy: 57.72%

Retraining Epoch 3/5


Epoch 3 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 3 Target Accuracy: 61.80%

Retraining Epoch 4/5


Epoch 4 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 4 Target Accuracy: 64.65%

Retraining Epoch 5/5


Epoch 5 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 5 Target Accuracy: 65.03%
Iteration 3 | Best Accuracy in this round: 65.03%

--- Final Evaluation ---


Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]


--- Pruning Summary ---
Baseline Target Accuracy (from best warmup): 0.00%
Final Target Accuracy (from best pruned model): 65.03%
Improvement: +65.03%


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import copy
import os
from dataset import get_pacs_dataloaders
from utils import *
from pruning import iterative_pruning

DATA_DIR = r"C:\Users\Fatim_Sproj\Desktop\Fatim\Spring 2025\Datasets\pacs_data\pacs_data" 
ALL_DOMAINS = ['art_painting', 'cartoon', 'photo', 'sketch']
TARGET_DOMAIN = 'sketch'
SOURCE_DOMAINS = [d for d in ALL_DOMAINS if d != TARGET_DOMAIN]

BATCH_SIZE = 256
NUM_WORKERS = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("--- Starting Warmup Phase ---")
source_loader_combined, target_loader, class_to_idx = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=True
)
num_classes = len(class_to_idx)

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
WARMUP_EPOCHS = 5
best_warmup_acc = 0.0
WARMUP_MODEL_PATH = "best_warmup_model.pth"

if os.path.exists(WARMUP_MODEL_PATH):
    print("Warmup model already exists. Loading and skipping warmup...")
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(WARMUP_MODEL_PATH, map_location=DEVICE))
    model.to(DEVICE)
else:
    print("No warmup model found. Running warmup training...")
    source_loader_combined, target_loader, class_to_idx = get_pacs_dataloaders(
        data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
        batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=True
    )

    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    WARMUP_EPOCHS = 5

    for epoch in range(WARMUP_EPOCHS):
        train_vanilla(model, source_loader_combined, optimizer, DEVICE, epoch)
        _, val_acc = evaluate(model, target_loader, DEVICE)
        print(f"  Warmup Epoch {epoch+1} Target Accuracy: {val_acc:.2f}%")

        if val_acc > best_warmup_acc:
            best_warmup_acc = val_acc
            torch.save(model.state_dict(), WARMUP_MODEL_PATH)
            print(f"  New best warmup accuracy: {best_warmup_acc:.2f}%. Checkpoint saved.")

    print(f"\nWarmup finished. Best accuracy: {best_warmup_acc:.2f}%")

print("\n--- Starting Iterative Pruning Phase ---")
source_loaders_list, target_loader, _ = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=False
)

pruning_model = models.resnet18()
pruning_model.fc = nn.Linear(pruning_model.fc.in_features, num_classes)
pruning_model.load_state_dict(torch.load(WARMUP_MODEL_PATH))
pruning_model.to(DEVICE)

PRUNE_RATES = [0.10, 0.10, 0.10]
FINETUNE_EPOCHS = 5
FINETUNE_LR = 1e-4
ALPHA = 1.0

final_model, final_mask = iterative_pruning(
    model=pruning_model,
    source_loaders_list=source_loaders_list,
    target_loader=target_loader,
    device=DEVICE,
    prune_rates=PRUNE_RATES,
    retrain_epochs=FINETUNE_EPOCHS,
    lr=FINETUNE_LR,
    alpha=ALPHA,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    SFT=True, 
    importance_type="taylor"
)

# --- 4. Final Evaluation ---
print("\n--- Final Evaluation ---")
baseline_acc = best_warmup_acc
apply_mask(final_model, final_mask)
_, final_acc = evaluate(final_model, target_loader, DEVICE, mask=final_mask)

print("\n--- Pruning Summary ---")
print(f"Baseline Target Accuracy (from best warmup): {baseline_acc:.2f}%")
print(f"Final Target Accuracy (from best pruned model): {final_acc:.2f}%")
improvement = final_acc - baseline_acc
print(f"Improvement: {improvement:+.2f}%")

--- Starting Warmup Phase ---
Creating source datasets for: ['art_painting', 'cartoon', 'photo']
  - Domain 'art_painting' (ID 0) loaded with 2048 images.
  - Domain 'cartoon' (ID 1) loaded with 2344 images.
  - Domain 'photo' (ID 2) loaded with 1670 images.
Combined source dataloader created with 6062 total images.
Creating target dataloader for: sketch
  - Domain 'sketch' loaded with 3929 images.
Warmup model already exists. Loading and skipping warmup...

--- Starting Iterative Pruning Phase ---
Creating source datasets for: ['art_painting', 'cartoon', 'photo']
  - Domain 'art_painting' (ID 0) loaded with 2048 images.
  - Domain 'cartoon' (ID 1) loaded with 2344 images.
  - Domain 'photo' (ID 2) loaded with 1670 images.
Created 3 separate source dataloaders.
Creating target dataloader for: sketch
  - Domain 'sketch' loaded with 3929 images.


Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

Initial Baseline Target Accuracy: 63.48%

--- Pruning Iteration 1/3 with base rate 0.1 ---

Generating mask with base iterative prune rate: 0.1
  - Layer 'conv1': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.0.conv1': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.0.conv2': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.1.conv1': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.1.conv2': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer2.0.conv1': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.0.conv2': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.0.downsample.0': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.1.conv1': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.1.conv2': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer3.0.conv1': Pruning 25/256 active filters (rate 0.100).
  - Layer 'layer3.0.conv2': Pruning 25/256 active filters (rate 0.100).
  - Layer

Epoch 1 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 1 Target Accuracy: 62.71%

Retraining Epoch 2/5


Epoch 2 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 2 Target Accuracy: 60.04%

Retraining Epoch 3/5


Epoch 3 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 3 Target Accuracy: 65.87%

Retraining Epoch 4/5


Epoch 4 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 4 Target Accuracy: 65.72%

Retraining Epoch 5/5


Epoch 5 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 5 Target Accuracy: 65.87%
Iteration 1 | Best Accuracy in this round: 65.87%

--- Pruning Iteration 2/3 with base rate 0.1 ---

Generating mask with base iterative prune rate: 0.1
  - Layer 'conv1': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.0.conv1': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.0.conv2': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.1.conv1': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.1.conv2': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer2.0.conv1': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.0.conv2': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.0.downsample.0': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.1.conv1': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.1.conv2': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer3.0.conv1': Pruning 23/231 active filters (rate 0.100).
  - Layer 'layer3.0.conv2': Pruning 23

Epoch 1 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 1 Target Accuracy: 63.04%

Retraining Epoch 2/5


Epoch 2 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 2 Target Accuracy: 61.87%

Retraining Epoch 3/5


Epoch 3 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 3 Target Accuracy: 66.33%

Retraining Epoch 4/5


Epoch 4 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 4 Target Accuracy: 65.39%

Retraining Epoch 5/5


Epoch 5 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 5 Target Accuracy: 64.72%
Iteration 2 | Best Accuracy in this round: 66.33%

--- Pruning Iteration 3/3 with base rate 0.1 ---

Generating mask with base iterative prune rate: 0.1
  - Layer 'conv1': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.0.conv1': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.0.conv2': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.1.conv1': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.1.conv2': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer2.0.conv1': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.0.conv2': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.0.downsample.0': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.1.conv1': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.1.conv2': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer3.0.conv1': Pruning 20/208 active filters (rate 0.100).
  - Layer 'layer3.0.conv2': Pruning 20

Epoch 1 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 1 Target Accuracy: 60.04%

Retraining Epoch 2/5


Epoch 2 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 2 Target Accuracy: 59.30%

Retraining Epoch 3/5


Epoch 3 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 3 Target Accuracy: 54.52%

Retraining Epoch 4/5


Epoch 4 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 4 Target Accuracy: 65.56%

Retraining Epoch 5/5


Epoch 5 Train2 (Normal):   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 5 Target Accuracy: 65.74%
Iteration 3 | Best Accuracy in this round: 65.74%

--- Final Evaluation ---


Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]


--- Pruning Summary ---
Baseline Target Accuracy (from best warmup): 0.00%
Final Target Accuracy (from best pruned model): 65.74%
Improvement: +65.74%


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import copy
import os
from dataset import get_pacs_dataloaders
from utils import *
from pruning import iterative_pruning

DATA_DIR = r"C:\Users\Fatim_Sproj\Desktop\Fatim\Spring 2025\Datasets\pacs_data\pacs_data" 
ALL_DOMAINS = ['art_painting', 'cartoon', 'photo', 'sketch']
TARGET_DOMAIN = 'sketch'
SOURCE_DOMAINS = [d for d in ALL_DOMAINS if d != TARGET_DOMAIN]

BATCH_SIZE = 256
NUM_WORKERS = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("--- Starting Warmup Phase ---")
source_loader_combined, target_loader, class_to_idx = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=True
)
num_classes = len(class_to_idx)

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
WARMUP_EPOCHS = 5
best_warmup_acc = 0.0
WARMUP_MODEL_PATH = "best_warmup_model.pth"

if os.path.exists(WARMUP_MODEL_PATH):
    print("Warmup model already exists. Loading and skipping warmup...")
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(WARMUP_MODEL_PATH, map_location=DEVICE))
    model.to(DEVICE)
else:
    print("No warmup model found. Running warmup training...")
    source_loader_combined, target_loader, class_to_idx = get_pacs_dataloaders(
        data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
        batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=True
    )

    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    WARMUP_EPOCHS = 5

    for epoch in range(WARMUP_EPOCHS):
        train_vanilla(model, source_loader_combined, optimizer, DEVICE, epoch)
        _, val_acc = evaluate(model, target_loader, DEVICE)
        print(f"  Warmup Epoch {epoch+1} Target Accuracy: {val_acc:.2f}%")

        if val_acc > best_warmup_acc:
            best_warmup_acc = val_acc
            torch.save(model.state_dict(), WARMUP_MODEL_PATH)
            print(f"  New best warmup accuracy: {best_warmup_acc:.2f}%. Checkpoint saved.")

    print(f"\nWarmup finished. Best accuracy: {best_warmup_acc:.2f}%")

print("\n--- Starting Iterative Pruning Phase ---")
source_loaders_list, target_loader, _ = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=False
)

pruning_model = models.resnet18()
pruning_model.fc = nn.Linear(pruning_model.fc.in_features, num_classes)
pruning_model.load_state_dict(torch.load(WARMUP_MODEL_PATH))
pruning_model.to(DEVICE)

PRUNE_RATES = [0.10, 0.10, 0.10]
FINETUNE_EPOCHS = 5
FINETUNE_LR = 1e-4
ALPHA = 1.0

final_model, final_mask = iterative_pruning(
    model=pruning_model,
    source_loaders_list=source_loaders_list,
    target_loader=target_loader,
    device=DEVICE,
    prune_rates=PRUNE_RATES,
    retrain_epochs=FINETUNE_EPOCHS,
    lr=FINETUNE_LR,
    alpha=ALPHA,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    SFT=False,
    importance_type="taylor")

# --- 4. Final Evaluation ---
print("\n--- Final Evaluation ---")
baseline_acc = best_warmup_acc
apply_mask(final_model, final_mask)
_, final_acc = evaluate(final_model, target_loader, DEVICE, mask=final_mask)

print("\n--- Pruning Summary ---")
print(f"Baseline Target Accuracy (from best warmup): {baseline_acc:.2f}%")
print(f"Final Target Accuracy (from best pruned model): {final_acc:.2f}%")
improvement = final_acc - baseline_acc
print(f"Improvement: {improvement:+.2f}%")

--- Starting Warmup Phase ---
Creating source datasets for: ['art_painting', 'cartoon', 'photo']
  - Domain 'art_painting' (ID 0) loaded with 2048 images.
  - Domain 'cartoon' (ID 1) loaded with 2344 images.
  - Domain 'photo' (ID 2) loaded with 1670 images.
Combined source dataloader created with 6062 total images.
Creating target dataloader for: sketch
  - Domain 'sketch' loaded with 3929 images.
Warmup model already exists. Loading and skipping warmup...

--- Starting Iterative Pruning Phase ---
Creating source datasets for: ['art_painting', 'cartoon', 'photo']
  - Domain 'art_painting' (ID 0) loaded with 2048 images.
  - Domain 'cartoon' (ID 1) loaded with 2344 images.
  - Domain 'photo' (ID 2) loaded with 1670 images.
Created 3 separate source dataloaders.
Creating target dataloader for: sketch
  - Domain 'sketch' loaded with 3929 images.


Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

Initial Baseline Target Accuracy: 63.48%

--- Pruning Iteration 1/3 with base rate 0.1 ---

Generating mask with base iterative prune rate: 0.1
  - Layer 'conv1': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.0.conv1': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.0.conv2': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.1.conv1': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer1.1.conv2': Pruning 1/64 active filters (rate 0.025).
  - Layer 'layer2.0.conv1': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.0.conv2': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.0.downsample.0': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.1.conv1': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer2.1.conv2': Pruning 6/128 active filters (rate 0.050).
  - Layer 'layer3.0.conv1': Pruning 25/256 active filters (rate 0.100).
  - Layer 'layer3.0.conv2': Pruning 25/256 active filters (rate 0.100).
  - Layer

Epoch 1 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 1 Target Accuracy: 64.83%

Retraining Epoch 2/5


Epoch 2 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 2 Target Accuracy: 65.05%

Retraining Epoch 3/5


Epoch 3 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 3 Target Accuracy: 67.60%

Retraining Epoch 4/5


Epoch 4 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 4 Target Accuracy: 67.47%

Retraining Epoch 5/5


Epoch 5 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 5 Target Accuracy: 68.03%
Iteration 1 | Best Accuracy in this round: 68.03%

--- Pruning Iteration 2/3 with base rate 0.1 ---

Generating mask with base iterative prune rate: 0.1
  - Layer 'conv1': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.0.conv1': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.0.conv2': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.1.conv1': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer1.1.conv2': Pruning 1/63 active filters (rate 0.025).
  - Layer 'layer2.0.conv1': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.0.conv2': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.0.downsample.0': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.1.conv1': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer2.1.conv2': Pruning 6/122 active filters (rate 0.050).
  - Layer 'layer3.0.conv1': Pruning 23/231 active filters (rate 0.100).
  - Layer 'layer3.0.conv2': Pruning 23

Epoch 1 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 1 Target Accuracy: 62.53%

Retraining Epoch 2/5


Epoch 2 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 2 Target Accuracy: 60.70%

Retraining Epoch 3/5


Epoch 3 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 3 Target Accuracy: 57.83%

Retraining Epoch 4/5


Epoch 4 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 4 Target Accuracy: 61.34%

Retraining Epoch 5/5


Epoch 5 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 5 Target Accuracy: 59.28%
Iteration 2 | Best Accuracy in this round: 62.53%

--- Pruning Iteration 3/3 with base rate 0.1 ---

Generating mask with base iterative prune rate: 0.1
  - Layer 'conv1': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.0.conv1': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.0.conv2': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.1.conv1': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer1.1.conv2': Pruning 1/62 active filters (rate 0.025).
  - Layer 'layer2.0.conv1': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.0.conv2': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.0.downsample.0': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.1.conv1': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer2.1.conv2': Pruning 5/116 active filters (rate 0.050).
  - Layer 'layer3.0.conv1': Pruning 20/208 active filters (rate 0.100).
  - Layer 'layer3.0.conv2': Pruning 20

Epoch 1 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 1 Target Accuracy: 63.15%

Retraining Epoch 2/5


Epoch 2 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 2 Target Accuracy: 54.03%

Retraining Epoch 3/5


Epoch 3 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 3 Target Accuracy: 61.31%

Retraining Epoch 4/5


Epoch 4 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 4 Target Accuracy: 62.56%

Retraining Epoch 5/5


Epoch 5 Training:   0%|          | 0/24 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

  Epoch 5 Target Accuracy: 62.03%
Iteration 3 | Best Accuracy in this round: 63.15%

--- Final Evaluation ---


Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]


--- Pruning Summary ---
Baseline Target Accuracy (from best warmup): 0.00%
Final Target Accuracy (from best pruned model): 63.15%
Improvement: +63.15%
