## SAM with ResNet18

#### Source Domains: PACS Art/Painting, Cartoon, Photo
#### Target Domain: PACS Sketch

In [1]:
import torch
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset

# Pick the compute device: CUDA first, then Apple's MPS backend, else the CPU.
# `torch.backends.mps.is_available()` is used because `torch.mps.is_available`
# does not exist in older PyTorch releases and would raise AttributeError on
# any machine without CUDA.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [2]:
class SAM(torch.optim.Optimizer):
    """Two-step Sharpness-Aware Minimization wrapper around a base optimizer.

    Usage per batch: run a forward/backward pass, call :meth:`first_step` to
    move the weights along the (scaled) gradient, run a second
    forward/backward pass at the perturbed weights, then call
    :meth:`second_step` to restore the weights and let the base optimizer
    apply the update using the second set of gradients.

    Args:
        params: Parameters (or param-group dicts) to optimize.
        base_optimizer: Optimizer *class* (e.g. ``torch.optim.SGD``); it is
            instantiated here with ``**kwargs``.
        rho: Radius of the ascent perturbation. Must be non-negative.
        adaptive: If true, scale the perturbation element-wise by the
            parameter magnitude (``p ** 2`` in the step, ``|p|`` in the norm).
        **kwargs: Forwarded to ``base_optimizer`` (learning rate etc.).

    Raises:
        ValueError: If ``rho`` is negative.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        if rho < 0.0:
            raise ValueError(f"Invalid rho, should be non-negative: {rho}")
        self.rho = rho
        self.adaptive = adaptive
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super().__init__(params, defaults)
        # Build the base optimizer on this optimizer's groups, then share its
        # param_groups list so both objects see the same hyper-parameters.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb every parameter that has a gradient (ascent step).

        A copy of each perturbed parameter is stored in ``self.state`` under
        ``'old_p'`` so that :meth:`second_step` can restore it exactly.

        Args:
            zero_grad: If true, clear the gradients after perturbing.
        """
        grad_norm = self._grad_norm()
        scale = self.rho / (grad_norm + 1e-12)  # epsilon guards a zero norm

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # Snapshot before perturbing: subtracting e_w again later is
                # not exact in floating point and lets the weights drift.
                self.state[p]['old_p'] = p.detach().clone()
                e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # ascent step

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the pre-perturbation weights and run the base optimizer.

        Restoration is driven by the snapshots taken in :meth:`first_step`,
        not by the current ``p.grad``, so exactly the parameters that were
        perturbed are restored.

        Args:
            zero_grad: If true, clear the gradients after the update.
        """
        for group in self.param_groups:
            for p in group['params']:
                saved = self.state.get(p)  # .get: do not create empty state entries
                if not saved or 'old_p' not in saved:
                    continue
                # pop so a stale snapshot can never be re-applied on a later step
                p.copy_(saved.pop('old_p'))
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Not supported; use :meth:`first_step` and :meth:`second_step`.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "SAM requires two passes: call first_step() and second_step() "
            "instead of step()."
        )

    def _grad_norm(self):
        """Return the L2 norm over all available gradients as a 0-dim tensor.

        Per-parameter norms are moved to the device of the first parameter
        before being combined. When no parameter has a gradient the result is
        zero, so :meth:`first_step` becomes a no-op instead of failing inside
        ``torch.stack`` on an empty list.
        """
        shared_device = self.param_groups[0]['params'][0].device
        per_param_norms = [
            ((torch.abs(p) if self.adaptive else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups for p in group['params']
            if p.grad is not None
        ]
        if not per_param_norms:
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(per_param_norms), p=2)

In [3]:
# ResNet-18 initialised from the 'IMAGENET1K_V1' weights.
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')

# Freeze every loaded parameter; only the replacement head created below
# (a fresh Linear layer) is left trainable.
for param in model.parameters():
    param.requires_grad = False
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=7)

# SAM wraps Adam; `lr` is forwarded to Adam, `rho` is the SAM perturbation radius.
optim = SAM(
    params=model.parameters(),
    base_optimizer=torch.optim.Adam,
    rho=0.05,
    lr=5e-5,
)

criterion = torch.nn.CrossEntropyLoss()

In [4]:
# Per-channel statistics passed to Normalize below.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Preprocessing shared by every domain: fixed 224x224 resize, tensor
# conversion, then channel-wise normalisation.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(imagenet_mean, imagenet_std),
])

# One image folder per PACS domain.
art_dir = "data/pacs_data/pacs_data/art_painting"
cartoon_dir = "data/pacs_data/pacs_data/cartoon"
photo_dir = "data/pacs_data/pacs_data/photo"
sketch_dir = "data/pacs_data/pacs_data/sketch"

art_dataset = datasets.ImageFolder(root=art_dir, transform=transform)
cartoon_dataset = datasets.ImageFolder(root=cartoon_dir, transform=transform)
photo_dataset = datasets.ImageFolder(root=photo_dir, transform=transform)
sketch_dataset = datasets.ImageFolder(root=sketch_dir, transform=transform)


def _make_loader(dataset):
    """Wrap *dataset* in a DataLoader using the settings shared by all loaders."""
    return DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )


art_loader = _make_loader(art_dataset)
cartoon_loader = _make_loader(cartoon_dataset)
photo_loader = _make_loader(photo_dataset)
sketch_loader = _make_loader(sketch_dataset)  # This is also the test domain loader

# All three source domains merged into one dataset for training.
source_dataset = ConcatDataset([art_dataset, cartoon_dataset, photo_dataset])
source_loader = _make_loader(source_dataset)

pacs_classes = sketch_dataset.classes

# Per-domain loaders for the source environments.
envs = [art_loader, cartoon_loader, photo_loader]

In [5]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader`` with a two-step optimizer.

    Args:
        model: Module to train; moved to ``device`` and put in train mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Object exposing ``first_step(zero_grad=...)`` and
            ``second_step(zero_grad=...)``.
        device: Device the model and each batch are moved to.

    Returns:
        ``(avg_loss, accuracy)``: the sample-weighted mean of the first-pass
        loss and the first-pass accuracy in percent. ``(0.0, 0.0)`` when the
        loader yields no samples (previously a ``ZeroDivisionError``).
    """
    model.to(device)
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # ---- First step (ascent): gradients at the current weights ----
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.first_step(zero_grad=True)  # perturb weights

        # ---- Second step (descent): gradients at the perturbed weights ----
        # NOTE(review): this second forward pass also runs in train mode, so
        # if the model has BatchNorm layers their running statistics are
        # updated twice per batch -- confirm this is intended.
        criterion(model(inputs), labels).backward()
        optimizer.second_step(zero_grad=True)

        # ---- Track loss and accuracy (use unperturbed outputs) ----
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        # Empty loader: nothing was seen, so report zeros instead of dividing by 0.
        return 0.0, 0.0

    avg_loss = running_loss / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy


In [6]:
def evaluate(model, test_loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` on ``test_loader``.

    Args:
        model: Module to evaluate; moved to ``device`` and put in eval mode
            (it is left in eval mode afterwards).
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the model and each batch are moved to.

    Returns:
        Percentage of samples whose arg-max prediction equals the label, or
        ``0.0`` when the loader yields no samples (previously a
        ``ZeroDivisionError``).
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():  # inference only; no autograd graph needed
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = model(images)
            preds = torch.argmax(preds, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        # Empty loader: avoid dividing by zero.
        return 0.0

    accuracy = correct / total * 100
    return accuracy

In [7]:
# Train on the merged source domains for a fixed number of epochs and report
# the per-epoch training loss and accuracy.
num_epochs = 50
train_acc = 0
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(
        model=model,
        loader=source_loader,
        criterion=criterion,
        optimizer=optim,
        device=device,
    )
    print(f"Epoch {epoch+1}: Loss:  {train_loss:.4f} | | Train Accuracy:  {train_acc:.2f}%")

Epoch 1: Loss:  1.8926 | | Train Accuracy:  23.42%
Epoch 2: Loss:  1.5966 | | Train Accuracy:  45.32%
Epoch 3: Loss:  1.3800 | | Train Accuracy:  58.71%
Epoch 4: Loss:  1.2184 | | Train Accuracy:  67.57%
Epoch 5: Loss:  1.0890 | | Train Accuracy:  72.63%
Epoch 6: Loss:  0.9885 | | Train Accuracy:  76.18%
Epoch 7: Loss:  0.9089 | | Train Accuracy:  78.95%
Epoch 8: Loss:  0.8510 | | Train Accuracy:  79.50%
Epoch 9: Loss:  0.7967 | | Train Accuracy:  81.36%
Epoch 10: Loss:  0.7508 | | Train Accuracy:  82.13%
Epoch 11: Loss:  0.7119 | | Train Accuracy:  82.88%
Epoch 12: Loss:  0.6785 | | Train Accuracy:  83.47%
Epoch 13: Loss:  0.6531 | | Train Accuracy:  83.98%
Epoch 14: Loss:  0.6313 | | Train Accuracy:  84.36%
Epoch 15: Loss:  0.6072 | | Train Accuracy:  84.72%
Epoch 16: Loss:  0.5842 | | Train Accuracy:  84.72%
Epoch 17: Loss:  0.5693 | | Train Accuracy:  85.45%
Epoch 18: Loss:  0.5553 | | Train Accuracy:  85.43%
Epoch 19: Loss:  0.5397 | | Train Accuracy:  85.60%
Epoch 20: Loss:  0.52

### Evaluation on Source Domains

In [8]:
# Accuracy on each source domain separately ...
art_accuracy = evaluate(model=model, test_loader=art_loader, device=device)
print(f"Art Accuracy: {art_accuracy:.2f}%")

cartoon_accuracy = evaluate(model=model, test_loader=cartoon_loader, device=device)
print(f"Cartoon Accuracy: {cartoon_accuracy:.2f}%")

photo_accuracy = evaluate(model=model, test_loader=photo_loader, device=device)
print(f"Photo Accuracy: {photo_accuracy:.2f}%")

# ... and on the three source domains combined.
source_accuracy = evaluate(model=model, test_loader=source_loader, device=device)
print(f"\nAll Source Domains Accuracy: {source_accuracy:.2f}%")

Art Accuracy: 88.18%
Cartoon Accuracy: 87.20%
Photo Accuracy: 97.43%

All Source Domains Accuracy: 90.35%


### Evaluation on Test Domain

In [9]:
# Accuracy on the held-out sketch domain.
sketch_accuracy = evaluate(model=model, test_loader=sketch_loader, device=device)
print(f"Sketch Accuracy: {sketch_accuracy:.2f}%")

Sketch Accuracy: 38.58%
