## SAM with ResNet18

#### Source Domains: PACS Art/Painting, Cartoon, Photo
#### Target Domain: PACS Sketch

In [5]:
import torch
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset
from torch.optim.lr_scheduler import LambdaLR
import random
import numpy as np
import os

device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

In [6]:
def fix_random_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Fixed random seed: {seed}")

fix_random_seed(42)

# For deterministic DataLoader behavior
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

Fixed random seed: 42


In [7]:
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        self.rho = rho
        self.adaptive = adaptive
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        scale = self.rho / (grad_norm + 1e-12)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # ascent step
                self.state[p]['e_w'] = e_w

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]['e_w'])  # restore original weights
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        raise NotImplementedError

    def _grad_norm(self):
        shared_device = self.param_groups[0]['params'][0].device
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if self.adaptive else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group['params']
                if p.grad is not None
            ]),
            p=2
        )
        return norm

In [8]:
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')

for param in model.parameters():
    param.requires_grad = False
model.fc = torch.nn.Linear(model.fc.in_features, 7)

base_lr = 0.01
weight_decay = 1e-4
warmup_epochs = 3
num_epochs = 30
base_optimizer = torch.optim.SGD(
        model.parameters(),
        lr=base_lr,
        momentum=0.9,
        weight_decay=weight_decay
    )

def lr_lambda(current_epoch):
    if current_epoch < warmup_epochs:
        # Linear warmup from 0 -> 1
        return float(current_epoch + 1) / float(max(1, warmup_epochs))
    else:
        # Cosine annealing from 1 -> 0
        progress = (current_epoch - warmup_epochs) / float(max(1, num_epochs - warmup_epochs))
        return 0.5 * (1.0 + torch.cos(torch.tensor(progress * 3.1415926535)))

optim = SAM(model.parameters(), torch.optim.SGD,
            lr=base_lr, momentum=0.9, weight_decay=weight_decay, rho=0.05)

scheduler = LambdaLR(optim.base_optimizer, lr_lambda=lr_lambda)


criterion = torch.nn.CrossEntropyLoss()

In [9]:
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(imagenet_mean, imagenet_std)
])

art_dir = "../data/pacs_data/pacs_data/art_painting"
cartoon_dir = "../data/pacs_data/pacs_data/cartoon"
photo_dir = "../data/pacs_data/pacs_data/photo"
sketch_dir = "../data/pacs_data/pacs_data/sketch"

art_dataset = datasets.ImageFolder(root=art_dir, transform=transform)
cartoon_dataset = datasets.ImageFolder(root=cartoon_dir, transform=transform)
photo_dataset = datasets.ImageFolder(root=photo_dir, transform=transform)
sketch_dataset = datasets.ImageFolder(root=sketch_dir, transform=transform)

g = torch.Generator()
g.manual_seed(42)

art_loader = DataLoader(
    art_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

cartoon_loader = DataLoader(
    cartoon_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

photo_loader = DataLoader(
    photo_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

sketch_loader = DataLoader( # This is also the test domain loader
    sketch_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

source_dataset = ConcatDataset([art_dataset, cartoon_dataset, photo_dataset])

source_loader = DataLoader(
    source_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

pacs_classes = sketch_dataset.classes

In [10]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.to(device)
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # ---- SAM first step (ascent) ----
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.first_step(zero_grad=True)  # perturb weights

        # ---- SAM second step (descent) ----
        criterion(model(inputs), labels).backward()
        optimizer.second_step(zero_grad=True)

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    avg_loss = running_loss / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy


In [11]:
def evaluate(model, test_loader, device):
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = model(images)
            preds = torch.argmax(preds, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total * 100
    return accuracy

In [12]:
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, source_loader, criterion, optim, device)
    scheduler.step()
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | LR: {scheduler.get_last_lr()[0]:.5f}")

Epoch 1/30 | Loss: 0.6476 | Train Acc: 77.88% | LR: 0.00667
Epoch 2/30 | Loss: 0.4093 | Train Acc: 86.28% | LR: 0.01000
Epoch 3/30 | Loss: 0.4471 | Train Acc: 86.23% | LR: 0.01000
Epoch 4/30 | Loss: 0.4075 | Train Acc: 87.23% | LR: 0.00997
Epoch 5/30 | Loss: 0.3564 | Train Acc: 88.45% | LR: 0.00987
Epoch 6/30 | Loss: 0.3973 | Train Acc: 88.45% | LR: 0.00970
Epoch 7/30 | Loss: 0.3091 | Train Acc: 90.38% | LR: 0.00947
Epoch 8/30 | Loss: 0.3287 | Train Acc: 89.90% | LR: 0.00918
Epoch 9/30 | Loss: 0.3117 | Train Acc: 90.71% | LR: 0.00883
Epoch 10/30 | Loss: 0.2958 | Train Acc: 90.53% | LR: 0.00843
Epoch 11/30 | Loss: 0.2977 | Train Acc: 90.60% | LR: 0.00799
Epoch 12/30 | Loss: 0.2659 | Train Acc: 91.62% | LR: 0.00750
Epoch 13/30 | Loss: 0.2731 | Train Acc: 91.41% | LR: 0.00698
Epoch 14/30 | Loss: 0.2600 | Train Acc: 91.92% | LR: 0.00643
Epoch 15/30 | Loss: 0.2379 | Train Acc: 92.33% | LR: 0.00587
Epoch 16/30 | Loss: 0.2540 | Train Acc: 92.16% | LR: 0.00529
Epoch 17/30 | Loss: 0.2197 | Trai

### Evaluation on Source Domains

In [13]:
art_accuracy = evaluate(model, art_loader, device)
print(f"Art Accuracy: {art_accuracy:.2f}%")

cartoon_accuracy = evaluate(model, cartoon_loader, device)
print(f"Cartoon Accuracy: {cartoon_accuracy:.2f}%")

photo_accuracy = evaluate(model, photo_loader, device)
print(f"Photo Accuracy: {photo_accuracy:.2f}%")

source_accuracy = evaluate(model, source_loader, device)
print(f"\nAll Source Domains Accuracy: {source_accuracy:.2f}%")

Art Accuracy: 95.31%
Cartoon Accuracy: 95.26%
Photo Accuracy: 99.40%

All Source Domains Accuracy: 96.42%


### Evaluation on Test Domain

In [14]:
sketch_accuracy = evaluate(model, sketch_loader, device)
print(f"Sketch Accuracy: {sketch_accuracy:.2f}%")

Sketch Accuracy: 42.35%
