In [None]:
import os

# Root folder of the PACS dataset, passed as `data_dir` to get_pacs_dataloaders.
# Set the PACS_DATA_DIR environment variable to point at a different location
# without editing the notebook; the original local path remains the default.
DATA_DIR = os.environ.get("PACS_DATA_DIR", r"D:\Haseeb\Datasets\pacs_data")

### Layer_wise_L1norm_by_source

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import copy
import os
from dataset import get_pacs_dataloaders
from utils import *
from Layer_wise_pruning import iterative_pruning

# Experiment configuration: one domain is held out as the target and the
# remaining domains (in their ALL_DOMAINS order) are used as sources.
ALL_DOMAINS = ['art_painting', 'cartoon', 'photo', 'sketch']
TARGET_DOMAIN = 'sketch'
SOURCE_DOMAINS = list(filter(lambda name: name != TARGET_DOMAIN, ALL_DOMAINS))

# Loader settings shared by every dataloader built in this cell.
BATCH_SIZE, NUM_WORKERS = 256, 2

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# --- Warmup phase ---
# Either reuse the checkpoint at WARMUP_MODEL_PATH or fine-tune an
# ImageNet-pretrained ResNet-18 on the combined source domains and save the
# best-scoring epoch there. Leaves `model`, `optimizer`, `num_classes`,
# `best_warmup_acc` and WARMUP_MODEL_PATH defined for the code that follows.
print("--- Starting Warmup Phase ---")
source_loader_combined, target_loader, class_to_idx = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=True
)
num_classes = len(class_to_idx)

WARMUP_EPOCHS = 5
WARMUP_MODEL_PATH = "best_warmup_model.pth"
best_warmup_acc = 0.0

if os.path.exists(WARMUP_MODEL_PATH):
    print("Warmup model already exists. Loading and skipping warmup...")
    # load_state_dict replaces every weight, so the ImageNet weights are not
    # needed here and are not fetched.
    model = models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(WARMUP_MODEL_PATH, map_location=DEVICE))
    model.to(DEVICE)
    # Not used on this path; kept so `optimizer` is defined whichever branch runs.
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # The summary printed later uses best_warmup_acc as the baseline. Without
    # this evaluation it would stay at 0.0 whenever the checkpoint is reused,
    # and the reported improvement would be meaningless.
    _, best_warmup_acc = evaluate(model, target_loader, DEVICE)
    print(f"  Loaded warmup model Target Accuracy: {best_warmup_acc:.2f}%")
else:
    print("No warmup model found. Running warmup training...")
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(WARMUP_EPOCHS):
        train_vanilla(model, source_loader_combined, optimizer, DEVICE, epoch)
        _, val_acc = evaluate(model, target_loader, DEVICE)
        print(f"  Warmup Epoch {epoch+1} Target Accuracy: {val_acc:.2f}%")

        # Also save when no checkpoint exists yet: the pruning phase loads this
        # file unconditionally, so it must be written even if accuracy never
        # rises above the initial 0.0.
        if val_acc > best_warmup_acc or not os.path.exists(WARMUP_MODEL_PATH):
            best_warmup_acc = val_acc
            torch.save(model.state_dict(), WARMUP_MODEL_PATH)
            print(f"  New best warmup accuracy: {best_warmup_acc:.2f}%. Checkpoint saved.")

    print(f"\nWarmup finished. Best accuracy: {best_warmup_acc:.2f}%")

# --- Iterative pruning phase ---
# Rebuild the loaders with combine_sources=False (presumably one loader per
# source domain -- confirm in dataset.get_pacs_dataloaders), restore the warmup
# checkpoint into a fresh ResNet-18 and hand it to iterative_pruning.
print("\n--- Starting Iterative Pruning Phase ---")
source_loaders_list, target_loader, _ = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=False
)

pruning_model = models.resnet18()
pruning_model.fc = nn.Linear(pruning_model.fc.in_features, num_classes)
# map_location keeps the load working on a CPU-only machine when the checkpoint
# was saved from a CUDA model, matching how the warmup branch loads this file.
pruning_model.load_state_dict(torch.load(WARMUP_MODEL_PATH, map_location=DEVICE))
pruning_model.to(DEVICE)

# Pruning hyperparameters, forwarded unchanged to iterative_pruning.
PRUNE_RATES = [0.10, 0.10, 0.10]
FINETUNE_EPOCHS = 5
FINETUNE_LR = 1e-4
ALPHA = 1.0

final_model, final_mask = iterative_pruning(
    model=pruning_model,
    source_loaders_list=source_loaders_list,
    target_loader=target_loader,
    device=DEVICE,
    prune_rates=PRUNE_RATES,
    retrain_epochs=FINETUNE_EPOCHS,
    lr=FINETUNE_LR,
    alpha=ALPHA,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    SFT=False,
    importance_type="by_source")

# --- 4. Final Evaluation ---
# Score the pruned model on the target loader and print it next to the warmup
# baseline together with the signed difference between the two.
print("\n--- Final Evaluation ---")
baseline_acc = best_warmup_acc
apply_mask(final_model, final_mask)
_, final_acc = evaluate(final_model, target_loader, DEVICE, mask=final_mask)
improvement = final_acc - baseline_acc

print(
    "\n--- Pruning Summary ---",
    f"Baseline Target Accuracy (from best warmup): {baseline_acc:.2f}%",
    f"Final Target Accuracy (from best pruned model): {final_acc:.2f}%",
    f"Improvement: {improvement:+.2f}%",
    sep="\n",
)

### Layer_wise_L1norm_by_target

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import copy
import os
from dataset import get_pacs_dataloaders
from utils import *
from Layer_wise_pruning import iterative_pruning

# Experiment configuration: one domain is held out as the target and the
# remaining domains (in their ALL_DOMAINS order) are used as sources.
ALL_DOMAINS = ['art_painting', 'cartoon', 'photo', 'sketch']
TARGET_DOMAIN = 'sketch'
SOURCE_DOMAINS = list(filter(lambda name: name != TARGET_DOMAIN, ALL_DOMAINS))

# Loader settings shared by every dataloader built in this cell.
BATCH_SIZE, NUM_WORKERS = 256, 2

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# --- Warmup phase ---
# Either reuse the checkpoint at WARMUP_MODEL_PATH or fine-tune an
# ImageNet-pretrained ResNet-18 on the combined source domains and save the
# best-scoring epoch there. Leaves `model`, `optimizer`, `num_classes`,
# `best_warmup_acc` and WARMUP_MODEL_PATH defined for the code that follows.
print("--- Starting Warmup Phase ---")
source_loader_combined, target_loader, class_to_idx = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=True
)
num_classes = len(class_to_idx)

WARMUP_EPOCHS = 5
WARMUP_MODEL_PATH = "best_warmup_model.pth"
best_warmup_acc = 0.0

if os.path.exists(WARMUP_MODEL_PATH):
    print("Warmup model already exists. Loading and skipping warmup...")
    # load_state_dict replaces every weight, so the ImageNet weights are not
    # needed here and are not fetched.
    model = models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(WARMUP_MODEL_PATH, map_location=DEVICE))
    model.to(DEVICE)
    # Not used on this path; kept so `optimizer` is defined whichever branch runs.
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # The summary printed later uses best_warmup_acc as the baseline. Without
    # this evaluation it would stay at 0.0 whenever the checkpoint is reused,
    # and the reported improvement would be meaningless.
    _, best_warmup_acc = evaluate(model, target_loader, DEVICE)
    print(f"  Loaded warmup model Target Accuracy: {best_warmup_acc:.2f}%")
else:
    print("No warmup model found. Running warmup training...")
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(WARMUP_EPOCHS):
        train_vanilla(model, source_loader_combined, optimizer, DEVICE, epoch)
        _, val_acc = evaluate(model, target_loader, DEVICE)
        print(f"  Warmup Epoch {epoch+1} Target Accuracy: {val_acc:.2f}%")

        # Also save when no checkpoint exists yet: the pruning phase loads this
        # file unconditionally, so it must be written even if accuracy never
        # rises above the initial 0.0.
        if val_acc > best_warmup_acc or not os.path.exists(WARMUP_MODEL_PATH):
            best_warmup_acc = val_acc
            torch.save(model.state_dict(), WARMUP_MODEL_PATH)
            print(f"  New best warmup accuracy: {best_warmup_acc:.2f}%. Checkpoint saved.")

    print(f"\nWarmup finished. Best accuracy: {best_warmup_acc:.2f}%")

# --- Iterative pruning phase ---
# Rebuild the loaders with combine_sources=False (presumably one loader per
# source domain -- confirm in dataset.get_pacs_dataloaders), restore the warmup
# checkpoint into a fresh ResNet-18 and hand it to iterative_pruning.
print("\n--- Starting Iterative Pruning Phase ---")
source_loaders_list, target_loader, _ = get_pacs_dataloaders(
    data_dir=DATA_DIR, source_domains=SOURCE_DOMAINS, target_domain=TARGET_DOMAIN,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, combine_sources=False
)

pruning_model = models.resnet18()
pruning_model.fc = nn.Linear(pruning_model.fc.in_features, num_classes)
# map_location keeps the load working on a CPU-only machine when the checkpoint
# was saved from a CUDA model, matching how the warmup branch loads this file.
pruning_model.load_state_dict(torch.load(WARMUP_MODEL_PATH, map_location=DEVICE))
pruning_model.to(DEVICE)

# Pruning hyperparameters, forwarded unchanged to iterative_pruning.
PRUNE_RATES = [0.10, 0.10, 0.10]
FINETUNE_EPOCHS = 5
FINETUNE_LR = 1e-4
ALPHA = 1.0

final_model, final_mask = iterative_pruning(
    model=pruning_model,
    source_loaders_list=source_loaders_list,
    target_loader=target_loader,
    device=DEVICE,
    prune_rates=PRUNE_RATES,
    retrain_epochs=FINETUNE_EPOCHS,
    lr=FINETUNE_LR,
    alpha=ALPHA,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    SFT=False,
    importance_type="by_target")

# --- 4. Final Evaluation ---
# Score the pruned model on the target loader and print it next to the warmup
# baseline together with the signed difference between the two.
print("\n--- Final Evaluation ---")
baseline_acc = best_warmup_acc
apply_mask(final_model, final_mask)
_, final_acc = evaluate(final_model, target_loader, DEVICE, mask=final_mask)
improvement = final_acc - baseline_acc

print(
    "\n--- Pruning Summary ---",
    f"Baseline Target Accuracy (from best warmup): {baseline_acc:.2f}%",
    f"Final Target Accuracy (from best pruned model): {final_acc:.2f}%",
    f"Improvement: {improvement:+.2f}%",
    sep="\n",
)