## Prepare Notebook

In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m30.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [None]:
# Imports

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms, models
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder

# Datasets
from datasets import load_dataset
import kagglehub

# Extra
import os
from copy import deepcopy
from PIL import Image
import matplotlib.pyplot as plt

## Office-31

In [None]:
def Office31(domain, transform):
    """Build an ``ImageFolder`` dataset for a single Office-31 domain.

    The ``xixuhu/office31`` dataset is fetched through ``kagglehub`` and the
    dataset is rooted at ``<download dir>/Office-31/<domain>``.

    Args:
        domain: Name of the domain sub-directory (this notebook uses
            ``"amazon"``, ``"dslr"`` and ``"webcam"``).
        transform: Transform handed to ``ImageFolder`` and applied to each image.

    Returns:
        A ``torchvision.datasets.ImageFolder`` for the requested domain.
    """
    dataset_root = kagglehub.dataset_download("xixuhu/office31")
    domain_dir = os.path.join(dataset_root, "Office-31", domain)
    return datasets.ImageFolder(domain_dir, transform=transform)

def get_office_data_loaders(batch_size):
    """Create shuffled ``DataLoader`` objects for the three Office-31 domains.

    Every image is resized to 224x224, converted to a tensor and normalised
    with the mean/std constants listed below.

    Args:
        batch_size: Batch size used by all three loaders.

    Returns:
        A tuple ``(loader_amazon, loader_dslr, loader_webcam)`` -- note that
        DSLR comes *before* Webcam, so callers must unpack in that order.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def _make_loader(domain_name):
        # One shuffled loader per domain, all sharing the same preprocessing.
        domain_data = Office31(domain_name, preprocessing)
        return DataLoader(domain_data, batch_size=batch_size, shuffle=True)

    loader_amazon, loader_dslr, loader_webcam = [
        _make_loader(domain_name) for domain_name in ("amazon", "dslr", "webcam")
    ]
    return loader_amazon, loader_dslr, loader_webcam

In [None]:
loader_amazon, loader_dslr, loader_webcam = get_office_data_loaders(batch_size=32)

Downloading from https://www.kaggle.com/api/v1/datasets/download/xixuhu/office31?dataset_version_number=1...


100%|██████████| 75.9M/75.9M [00:05<00:00, 15.5MB/s]

Extracting files...







## Model Cards

In [None]:
# @title VGG backbone
class VGGFeatureExtractor(nn.Module):
    """Convolutional part of a torchvision VGG that exposes intermediate block outputs.

    ``forward`` returns the activations at the end of convolutional blocks
    3, 4 and 5 (i.e. after the max-pool layer that closes each block).

    Args:
        pretrained: Forwarded to the torchvision model constructor.
        model_name: ``'vgg16'`` or ``'vgg19'``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported backbones.
    """

    # Index inside ``features`` of the max-pool layer that closes each block.
    # VGG19 has one extra conv+ReLU pair in blocks 3-5, so its blocks end
    # later than VGG16's; using the VGG16 indices on VGG19 would tap raw
    # convolution outputs in the middle of a block.
    _BLOCK_END_INDICES = {
        'vgg16': {'block1': 4, 'block2': 9, 'block3': 16, 'block4': 23, 'block5': 30},
        'vgg19': {'block1': 4, 'block2': 9, 'block3': 18, 'block4': 27, 'block5': 36},
    }

    def __init__(self, pretrained=True, model_name='vgg16'):
        super(VGGFeatureExtractor, self).__init__()

        if model_name not in self._BLOCK_END_INDICES:
            supported = ", ".join(sorted(self._BLOCK_END_INDICES))
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected one of: {supported}"
            )

        # Load Weights
        if model_name == 'vgg16':
            base_model = torchvision.models.vgg16(pretrained=pretrained)
        else:
            base_model = torchvision.models.vgg19(pretrained=pretrained)

        self.features = base_model.features

        # Layers at which features can be extracted (copied so that instances
        # never share or mutate the class-level table).
        self.layer_names = dict(self._BLOCK_END_INDICES[model_name])

    def forward(self, x):
        """Run ``x`` through the conv stack and collect block 3/4/5 outputs.

        Args:
            x: Input batch fed to the first layer of ``self.features``.

        Returns:
            Tuple ``(block3, block4, block5)`` of activations.
        """
        tap_points = {
            self.layer_names[name]: name for name in ('block3', 'block4', 'block5')
        }
        final_index = self.layer_names['block5']

        outputs = {}
        # Single pass over the layers; record activations at the tap points.
        for i, layer in enumerate(self.features):
            x = layer(x)
            block_name = tap_points.get(i)
            if block_name is not None:
                outputs[block_name] = x
            if i == final_index:
                break  # Stop after the last required block

        # Return Outputs of certain blocks
        return outputs['block3'], outputs['block4'], outputs['block5']

In [None]:
# @title SparseAutoencoder
class SparseAutoencoder(nn.Module):
    """Fully connected autoencoder with a two-layer ReLU encoder and a linear decoder.

    The class itself applies no sparsity constraint; it only returns the
    latent code so that a penalty can be computed by the caller.

    Args:
        input_dim: Size of the input (and reconstructed) feature vector.
        hidden_dim: Size of the latent code.
    """

    def __init__(self, input_dim=2048, hidden_dim=256):
        super(SparseAutoencoder, self).__init__()
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(nn.Linear(hidden_dim, input_dim))

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Returns:
            Tuple ``(x_recon, z)``: the reconstruction and the latent code.
        """
        latent = self.encoder(x)
        return self.decoder(latent), latent

In [None]:
# @title ClassifierHead
class ClassifierHead(nn.Module):
    """Single linear layer mapping a feature vector to class logits.

    Args:
        input_dim: Size of the incoming feature vector.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=2048, num_classes=7):
        super(ClassifierHead, self).__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the logits produced by the linear layer for ``x``."""
        logits = self.fc(x)
        return logits

In [None]:
# @title Baseline
class BaselineModel(nn.Module):
    """Backbone + global average pooling + linear classifier.

    Args:
        feature_extractor: Module whose ``forward`` returns three feature
            maps; only the last one is used.
        num_classes: Number of output classes.
        feature_dim: Channel count of the last feature map. When ``None``
            (default) it is inferred from the last ``nn.Conv2d`` found in
            ``feature_extractor``; if none is found, 2048 is used (the value
            that used to be hard-coded).
    """

    def __init__(self, feature_extractor, num_classes=7, feature_dim=None):

        super().__init__()
        self.feature_extractor = feature_extractor
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # A fixed 2048 only matches some backbones; the VGG extractor defined
        # in this notebook ends in 512-channel convolutions, so the
        # classifier width has to follow the backbone.
        if feature_dim is None:
            feature_dim = self._infer_feature_dim(feature_extractor)

        self.classifier = nn.Linear(feature_dim, num_classes)

    @staticmethod
    def _infer_feature_dim(feature_extractor, default=2048):
        """Return ``out_channels`` of the last Conv2d in the extractor, else ``default``."""
        channels = default
        if isinstance(feature_extractor, nn.Module):
            for module in feature_extractor.modules():
                if isinstance(module, nn.Conv2d):
                    channels = module.out_channels
        return channels

    def forward(self, x):
        """Classify ``x``.

        Returns:
            Tuple ``(logits, pooled_features)`` where ``pooled_features`` is
            the globally average-pooled last feature map, flattened per sample.
        """
        _, _, f4 = self.feature_extractor(x)
        f4_pool = self.avgpool(f4).view(f4.size(0), -1)

        logits = self.classifier(f4_pool)
        return logits, f4_pool


In [None]:
# @title UnifiedModelMultiBlockSAE
class UnifiedModelMultiBlockSAE(nn.Module):
    """VGG backbone with one autoencoder per tapped block and a shared classifier.

    The three feature maps returned by ``VGGFeatureExtractor`` are globally
    average-pooled, each pooled vector is passed through its own
    ``SparseAutoencoder``, and the three latent codes are concatenated and
    fed to a linear classifier.

    Naming note: the ``2/3/4`` suffixes used throughout this class refer to
    the first/second/third feature map returned by the extractor.

    Args:
        sae_dim_block2: Latent size of the autoencoder on the first feature map.
        sae_dim_block3: Latent size of the autoencoder on the second feature map.
        sae_dim_block4: Latent size of the autoencoder on the third feature map.
        num_classes: Number of output classes.
        backbone: ``model_name`` passed to ``VGGFeatureExtractor``
            (defaults to ``'vgg19'``, the previously hard-coded choice).
        pretrained: Passed to ``VGGFeatureExtractor``.
    """

    def __init__(self,
                 sae_dim_block2=128,
                 sae_dim_block3=256,
                 sae_dim_block4=256,
                 num_classes=7,
                 backbone='vgg19',
                 pretrained=True):

        super().__init__()

        # Define Feature Extractor
        self.feature_extractor = VGGFeatureExtractor(pretrained=pretrained, model_name=backbone)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Input sizes are the channel counts of the pooled feature maps.
        self.sae2 = SparseAutoencoder(input_dim=256, hidden_dim=sae_dim_block2)
        self.sae3 = SparseAutoencoder(input_dim=512, hidden_dim=sae_dim_block3)
        self.sae4 = SparseAutoencoder(input_dim=512, hidden_dim=sae_dim_block4)

        total_hidden = sae_dim_block2 + sae_dim_block3 + sae_dim_block4
        self.classifier = nn.Linear(total_hidden, num_classes)

    def forward(self, x):
        """Run the full model.

        Returns:
            A 5-tuple ``(class_logits, (x_recon2, z2), (x_recon3, z3),
            (x_recon4, z4), (f2_pool, f3_pool, f4_pool))``: the logits, the
            reconstruction/latent pair of each autoencoder, and the pooled
            backbone features the autoencoders were fed with.
        """
        f2, f3, f4 = self.feature_extractor(x)

        f2_pool = self.avgpool(f2).view(f2.size(0), -1)
        f3_pool = self.avgpool(f3).view(f3.size(0), -1)
        f4_pool = self.avgpool(f4).view(f4.size(0), -1)

        x_recon2, z2 = self.sae2(f2_pool)
        x_recon3, z3 = self.sae3(f3_pool)
        x_recon4, z4 = self.sae4(f4_pool)

        # The classifier only sees the latent codes, not the raw features.
        z_concat = torch.cat([z2, z3, z4], dim=1)

        class_logits = self.classifier(z_concat)

        return (class_logits,
                (x_recon2, z2),
                (x_recon3, z3),
                (x_recon4, z4),
                (f2_pool, f3_pool, f4_pool))

## Losses

In [None]:
def irm_penalty(logits, labels):
    """Compute the cross-entropy risk, its IRM gradient penalty and a logit spread term.

    A dummy scalar multiplier (fixed at 1.0) is applied to the logits; the
    penalty is the squared gradient of the cross-entropy loss with respect to
    that multiplier. The gradient is built with ``create_graph=True`` so the
    penalty itself can be back-propagated.

    Args:
        logits: Classifier outputs, one row per sample.
        labels: Target class indices.

    Returns:
        Tuple ``(loss_erm, penalty, var)`` where ``var`` is the *sum* of
        squared deviations of the logits from their per-column batch mean.
    """
    dummy_scale = torch.tensor(1.0, requires_grad=True, device=logits.device)
    erm_risk = F.cross_entropy(scale_logits := dummy_scale * logits, labels) if False else F.cross_entropy(dummy_scale * logits, labels)

    (scale_grad,) = torch.autograd.grad(erm_risk, [dummy_scale], create_graph=True)
    grad_penalty = scale_grad.pow(2).sum()

    centered = logits - logits.mean(dim=0).unsqueeze(0)
    spread = centered.pow(2).sum()

    return erm_risk, grad_penalty, spread

In [None]:
def compute_sae_loss(recons, pools, zs, lambda_sparse=[1.0, 1.0, 1.0], lambda_reconstruction=[1.0, 1.0, 1.0]):
    """Weighted reconstruction and L1 terms for three autoencoders.

    Args:
        recons: Three reconstructions, one per autoencoder.
        pools: Three reconstruction targets, in the same order as ``recons``.
        zs: Three latent codes, in the same order as ``recons``.
        lambda_sparse: Three weights for the mean-absolute-value of each latent code.
        lambda_reconstruction: Three weights for the MSE of each reconstruction.

    Returns:
        Tuple ``(recon_loss_total, l1_sparsity_total)``. The two terms are
        returned separately; they are not combined here.
    """
    recon_a, recon_b, recon_c = recons
    target_a, target_b, target_c = pools
    code_a, code_b, code_c = zs

    mse_a = F.mse_loss(recon_a, target_a)
    mse_b = F.mse_loss(recon_b, target_b)
    mse_c = F.mse_loss(recon_c, target_c)

    w_rec_a, w_rec_b, w_rec_c = lambda_reconstruction
    recon_loss_total = w_rec_a * mse_a + w_rec_b * mse_b + w_rec_c * mse_c

    # L1 term: mean absolute activation of each latent code.
    l1_a = code_a.abs().mean()
    l1_b = code_b.abs().mean()
    l1_c = code_c.abs().mean()

    w_l1_a, w_l1_b, w_l1_c = lambda_sparse
    l1_sparsity_total = w_l1_a * l1_a + w_l1_b * l1_b + w_l1_c * l1_c

    return recon_loss_total, l1_sparsity_total


## Train and Evaluate Functions

### Train

In [None]:
def train_warmup(model, loader_source_domain1, loader_source_domain2, loader_test,
                 num_warmup_epochs=5,
                 lr=1e-4,
                 device='cuda'):
    """Warm up the backbone and classifier with plain cross-entropy on two source domains.

    Each step draws one batch from each source loader, concatenates and
    shuffles them, and minimises the classification loss. Only
    ``model.feature_extractor`` and ``model.classifier`` are handed to the
    optimizer. The shorter loader is restarted when exhausted, so an epoch
    has as many steps as the longer loader. After each epoch the model is
    scored on ``loader_test`` with ``evaluate``.

    Args:
        model: Module exposing ``feature_extractor`` and ``classifier`` and
            whose forward returns a 5-tuple with the class logits first.
        loader_source_domain1: Loader for the first source domain.
        loader_source_domain2: Loader for the second source domain.
        loader_test: Loader used for the per-epoch accuracy report.
        num_warmup_epochs: Number of warm-up epochs.
        lr: Learning rate for Adam.
        device: Device the model and batches are moved to.

    Returns:
        None. The model is updated in place.
    """
    model.to(device)

    optimizer = optim.Adam([
        {'params': model.feature_extractor.parameters(), 'lr': lr},
        {'params': model.classifier.parameters(), 'lr': lr},
    ], lr=lr)

    for epoch in range(num_warmup_epochs):

        # Re-assert training mode every epoch: evaluate() calls model.eval(),
        # so the mode must not be assumed to survive across epochs.
        model.train()

        domain1_iter = iter(loader_source_domain1)
        domain2_iter = iter(loader_source_domain2)
        steps_per_epoch = max(len(domain1_iter), len(domain2_iter))

        total_loss = 0.0
        total_samples = 0

        for step in range(steps_per_epoch):
            try:
                x_s1, y_s1 = next(domain1_iter)
            except StopIteration:
                # Shorter loader ran out: restart it.
                domain1_iter = iter(loader_source_domain1)
                x_s1, y_s1 = next(domain1_iter)

            try:
                x_s2, y_s2 = next(domain2_iter)
            except StopIteration:
                domain2_iter = iter(loader_source_domain2)
                x_s2, y_s2 = next(domain2_iter)

            x_source = torch.cat([x_s1, x_s2], dim=0)
            y_source = torch.cat([y_s1, y_s2], dim=0)

            # Mix the two domains inside the batch.
            indices = torch.randperm(x_source.size(0))
            x_source = x_source[indices]
            y_source = y_source[indices]

            x_source, y_source = x_source.to(device), y_source.to(device)

            # NOTE(review): `grl_enabled` is not accepted by the forward of
            # UnifiedModelMultiBlockSAE as defined in this notebook; this call
            # presumably targets a model defined elsewhere -- confirm.
            class_logits, _, _, _, _ = model(x_source, grl_enabled=False)
            loss_cls = F.cross_entropy(class_logits, y_source)

            optimizer.zero_grad()
            loss_cls.backward()
            optimizer.step()

            bs = x_source.size(0)
            total_loss += loss_cls.item() * bs
            total_samples += bs

        # Guard against empty loaders (no steps, hence no samples).
        epoch_loss = total_loss / max(total_samples, 1)
        print(f"[Warmup] Epoch [{epoch+1}/{num_warmup_epochs}] Loss: {epoch_loss:.4f}")
        test_acc = evaluate(model, loader_test, device=device)
        print(f"** End of Epoch {epoch+1}/{num_warmup_epochs} | Test Accuracy: {test_acc:.2f}% **")


In [None]:
def train_main(model,
               loader_domain1,
               loader_domain2,
               test_loader,
               num_epochs=20,
               lr=1e-4,
               lambda_irm=1.0,
               lambda_sae_rec=1.0,
               lambda_sae_sparse=1e-4,
               device='cuda'):
    """Train with an IRM objective over two source domains plus a single-SAE loss.

    Per step, one batch from each domain is classified; the per-domain risks
    and IRM penalties are averaged. The autoencoder loss (reconstruction MSE
    plus L1 on the latent code) is computed on features obtained under
    ``torch.no_grad()``, so it does not back-propagate into the backbone.
    After each epoch the model is scored on ``test_loader`` with ``evaluate``.

    Args:
        model: Module whose forward returns a 5-tuple (logits first, features
            last) and which exposes an ``sae`` sub-module.
        loader_domain1: Loader for the first source domain.
        loader_domain2: Loader for the second source domain.
        test_loader: Loader used for the per-epoch accuracy report.
        num_epochs: Number of epochs.
        lr: Learning rate for Adam (all model parameters).
        lambda_irm: Weight of the IRM penalty.
        lambda_sae_rec: Weight of the reconstruction loss.
        lambda_sae_sparse: Weight of the L1 term.
        device: Device the model and batches are moved to.

    Returns:
        The trained model (same object that was passed in).
    """
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):

        # Re-assert training mode every epoch: evaluate() calls model.eval(),
        # so the mode must not be assumed to survive across epochs.
        model.train()

        d1_iter = iter(loader_domain1)
        d2_iter = iter(loader_domain2)

        steps_per_epoch = max(len(d1_iter), len(d2_iter))

        for step in range(steps_per_epoch):

            try:
                x_s1, y_s1 = next(d1_iter)
            except StopIteration:
                # Shorter loader ran out: restart it.
                d1_iter = iter(loader_domain1)
                x_s1, y_s1 = next(d1_iter)

            try:
                x_s2, y_s2 = next(d2_iter)
            except StopIteration:
                d2_iter = iter(loader_domain2)
                x_s2, y_s2 = next(d2_iter)

            x_s1, y_s1 = x_s1.to(device), y_s1.to(device)
            x_s2, y_s2 = x_s2.to(device), y_s2.to(device)

            # NOTE(review): `grl_enabled` is not accepted by the forward of
            # UnifiedModelMultiBlockSAE as defined in this notebook; this
            # function presumably targets a model defined elsewhere -- confirm.
            logits_s1, _, _, _, _ = model(x_s1, grl_enabled=False)
            logits_s2, _, _, _, _ = model(x_s2, grl_enabled=False)

            # irm_penalty returns (risk, penalty, var); the third value is
            # not used by this objective.
            loss_erm_s1, penalty_s1, _ = irm_penalty(logits_s1, y_s1)
            loss_erm_s2, penalty_s2, _ = irm_penalty(logits_s2, y_s2)

            irm_loss = 0.5 * (loss_erm_s1 + loss_erm_s2)
            irm_pen  = 0.5 * (penalty_s1 + penalty_s2)
            loss_irm = irm_loss + lambda_irm * irm_pen

            # Features for the SAE are computed without a graph so the SAE
            # loss only updates the autoencoder.
            with torch.no_grad():
                _, _, _, _, f4p_s1 = model(x_s1, grl_enabled=False)
                _, _, _, _, f4p_s2 = model(x_s2, grl_enabled=False)

            # NOTE(review): assumes the model has a single `sae` module and
            # that the 5th forward output is one feature tensor -- confirm.
            x_recon_s1, z_s1 = model.sae(f4p_s1)
            x_recon_s2, z_s2 = model.sae(f4p_s2)

            recon_loss_s1 = F.mse_loss(x_recon_s1, f4p_s1)
            recon_loss_s2 = F.mse_loss(x_recon_s2, f4p_s2)
            recon_loss   = 0.5 * (recon_loss_s1 + recon_loss_s2)

            l1_sparsity_s1 = torch.mean(torch.abs(z_s1))
            l1_sparsity_s2 = torch.mean(torch.abs(z_s2))
            l1_sparsity   = 0.5 * (l1_sparsity_s1 + l1_sparsity_s2)

            sae_loss = lambda_sae_rec * recon_loss + lambda_sae_sparse * l1_sparsity

            loss = loss_irm + sae_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step + 1) % 10 == 0:
                print(f"[Main] Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{steps_per_epoch}], "
                      f"IRM Loss: {loss_irm.item():.4f}, SAE Loss: {sae_loss.item():.4f}")

        test_acc = evaluate(model, test_loader, device=device)
        print(f"** End of Epoch {epoch+1}/{num_epochs} | Test Accuracy: {test_acc:.2f}% **")

    return model


In [None]:
def train_baseline_model_office(model,
                         loader_source,
                         target_loader,
                         num_epochs=20,
                         lr=1e-4,
                         lambda_irm=1.0,
                         device='cuda'):
    """Train a baseline model on one source domain with the IRM objective.

    The loss is ``0.5 * risk + lambda_irm * 0.5 * penalty`` with both terms
    taken from ``irm_penalty``. After each epoch the model is scored on
    ``target_loader`` with ``evaluate_baseline``.

    Args:
        model: Module whose forward returns ``(logits, features)``.
        loader_source: Loader for the source (training) domain.
        target_loader: Loader for the target (evaluation) domain.
        num_epochs: Number of epochs.
        lr: Learning rate for Adam (all model parameters).
        lambda_irm: Weight of the IRM penalty.
        device: Device the model and batches are moved to.

    Returns:
        The trained model.
    """
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):

        # Make sure every epoch trains in training mode regardless of what
        # the previous evaluation left behind.
        model.train()

        steps_per_epoch = len(loader_source)
        for step, (x_s, y_s) in enumerate(loader_source):

            x_s, y_s = x_s.to(device), y_s.to(device)

            logits_s, _ = model(x_s)
            # irm_penalty returns (risk, penalty, var); var is unused here.
            loss_erm_s, penalty_s, _ = irm_penalty(logits_s, y_s)

            irm_loss = 0.5 * (loss_erm_s)
            irm_pen  = 0.5 * (penalty_s)
            loss_irm = irm_loss + lambda_irm * irm_pen

            optimizer.zero_grad()
            loss_irm.backward()
            optimizer.step()

            if (step+1) % 10 == 0:
                print(f"[Baseline IRM] Epoch [{epoch+1}/{num_epochs}], "
                      f"Step [{step+1}/{steps_per_epoch}], IRM Loss: {loss_irm.item():.4f}")

        test_acc = evaluate_baseline(model, target_loader, device=device)
        print(f"** End of Epoch {epoch+1}/{num_epochs} | Test Accuracy: {test_acc:.2f}% **")

    return model

In [None]:
def train_main_irm_multi_sae_office(model,
                              loader_source,
                              test_loader,
                              num_epochs=20,
                              lr=1e-4,
                              lambda_irm=1.0,
                              lambda_sae_rec=1.0,
                              lambda_sae_sparse=1e-4,
                              lambda_sparse=[1.0, 1.0, 1.0],
                              lambda_reconstruction=[1.0, 1.0, 1.0],
                              lambda_irm_pair=[1.0, 1.0, 1.0],
                              device='cuda',
                              verbose=False):
    """Train a multi-SAE model on one source domain with IRM + autoencoder losses.

    Per step the total loss is ``loss_irm + sae_loss`` where

    * ``loss_irm = w1 * 0.5 * risk + w2 * lambda_irm * 0.5 * penalty + w3 * var``
      with ``(w1, w2, w3) = lambda_irm_pair`` and the three terms coming from
      ``irm_penalty``;
    * ``sae_loss = lambda_sae_rec * sum_i(lambda_reconstruction[i] * mse_i)
      + lambda_sae_sparse * sum_i(lambda_sparse[i] * mean|z_i|)``, computed on
      *detached* pooled features so it does not back-propagate into the backbone.

    After each epoch the model is scored on ``test_loader`` with ``evaluate``.

    Args:
        model: Module with ``sae2``/``sae3``/``sae4`` sub-modules whose
            forward returns a 5-tuple: logits first, the tuple of three
            pooled feature tensors last.
        loader_source: Loader for the source (training) domain.
        test_loader: Loader for the target (evaluation) domain.
        num_epochs: Number of epochs.
        lr: Learning rate for Adam (all model parameters).
        lambda_irm: Weight of the IRM penalty.
        lambda_sae_rec: Global weight of the reconstruction loss.
        lambda_sae_sparse: Global weight of the L1 term.
        lambda_sparse: Per-autoencoder L1 weights (three values).
        lambda_reconstruction: Per-autoencoder reconstruction weights (three values).
        lambda_irm_pair: Weights ``(w1, w2, w3)`` of risk, penalty and var.
        device: Device the model and batches are moved to.
        verbose: When true, print the losses every 40 steps.

    Returns:
        The trained model (same object that was passed in).
    """
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Loop-invariant weights.
    w1, w2, w3 = lambda_irm_pair
    lambda_s1, lambda_s2, lambda_s3 = lambda_sparse
    lambda_r1, lambda_r2, lambda_r3 = lambda_reconstruction

    for epoch in range(num_epochs):

        # Re-assert training mode every epoch: evaluate() calls model.eval(),
        # so the mode must not be assumed to survive across epochs.
        model.train()

        steps_per_epoch = len(loader_source)

        for step, (x_s, y_s) in enumerate(loader_source):

            x_s, y_s = x_s.to(device), y_s.to(device)

            class_logits_s, _, _, _, pooled_feats = model(x_s)

            loss_erm_s, penalty_s, var_s = irm_penalty(class_logits_s, y_s)

            irm_loss = 0.5 * (loss_erm_s)
            irm_pen  = 0.5 * (penalty_s)

            loss_irm = w1 * (irm_loss) + w2 * (lambda_irm * irm_pen) + w3 * var_s

            # Reuse the pooled features of the pass above instead of running
            # the whole model a second time under no_grad. Detaching keeps
            # the SAE loss from back-propagating into the backbone.
            f2p_s, f3p_s, f4p_s = (feat.detach() for feat in pooled_feats)

            x_recon2_s, z2_s = model.sae2(f2p_s)
            x_recon3_s, z3_s = model.sae3(f3p_s)
            x_recon4_s, z4_s = model.sae4(f4p_s)

            rec_loss2 = lambda_r1 * F.mse_loss(x_recon2_s, f2p_s)
            rec_loss3 = lambda_r2 * F.mse_loss(x_recon3_s, f3p_s)
            rec_loss4 = lambda_r3 * F.mse_loss(x_recon4_s, f4p_s)

            rec_loss_total = (rec_loss2 + rec_loss3 + rec_loss4)

            l1_2 = lambda_s1 * torch.mean(torch.abs(z2_s))
            l1_3 = lambda_s2 * torch.mean(torch.abs(z3_s))
            l1_4 = lambda_s3 * torch.mean(torch.abs(z4_s))

            l1_sparsity = l1_2 + l1_3 + l1_4

            sae_loss = lambda_sae_rec * rec_loss_total + lambda_sae_sparse * l1_sparsity

            loss = loss_irm + sae_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step+1) % 40 == 0 and verbose:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{steps_per_epoch}], "
                      f"IRM Loss: {loss_irm.item():.4f}, SAE Loss: {sae_loss.item():.4f}")

        test_acc = evaluate(model, test_loader, device=device)
        print(f"** End of Epoch {epoch+1}/{num_epochs} | Test Accuracy: {test_acc:.2f}% **")

    return model

In [None]:
def train_model_irm_sae_with_warmup_office(
    batch_size=32,
    num_warmup_epochs=5,
    num_main_epochs=20,
    lr=1e-4,
    lambda_irm=1.0,
    lambda_sae_rec=1.0,
    lambda_sae_sparse=1e-4,
    device='cuda',
    loader=['A', 'W'],
    model=None,
    verbose=False,
    lambda_sparse=[1.0, 1.0, 1.0],
    lambda_reconstruction=[1.0, 1.0, 1.0],
    lambda_irm_pair=[1.0, 1.0, 0.0],
):
    """Run the IRM + multi-SAE training for one Office-31 source/target pair.

    Builds the three Office-31 loaders, selects the source and target ones
    from the domain codes in ``loader`` and calls
    ``train_main_irm_multi_sae_office``.

    Args:
        batch_size: Batch size of the data loaders.
        num_warmup_epochs: Accepted for interface compatibility but currently
            unused -- no warm-up phase is run by this function.
        num_main_epochs: Number of epochs of the main phase.
        lr: Learning rate.
        lambda_irm: Weight of the IRM penalty.
        lambda_sae_rec: Global weight of the reconstruction loss.
        lambda_sae_sparse: Global weight of the L1 term.
        device: Device used for training.
        loader: ``[source, target]`` domain codes; ``'A'`` = amazon,
            ``'W'`` = webcam, ``'D'`` = dslr. ``None`` means ``['A', 'W']``.
        model: Model to train; when ``None`` a
            ``UnifiedModelMultiBlockSAE(512, 1024, 2048, 31)`` is created.
        verbose: Forwarded to the training loop.
        lambda_sparse: Per-autoencoder L1 weights.
        lambda_reconstruction: Per-autoencoder reconstruction weights.
        lambda_irm_pair: Weights of the risk, penalty and var terms.

    Returns:
        The trained model.

    Raises:
        ValueError: If a domain code is not one of ``'A'``, ``'W'``, ``'D'``.
    """
    # get_office_data_loaders returns (amazon, dslr, webcam) -- DSLR before
    # Webcam. Unpacking in any other order silently swaps the 'W' and 'D' domains.
    loader_amazon, loader_dslr, loader_webcam = get_office_data_loaders(batch_size)
    loaders_by_code = {'A': loader_amazon, 'W': loader_webcam, 'D': loader_dslr}

    source, target = loader if loader is not None else ['A', 'W']

    for role, code in (('source', source), ('target', target)):
        if code not in loaders_by_code:
            raise ValueError(
                f"Unknown {role} domain {code!r}; expected one of 'A', 'W', 'D'"
            )

    loader_source = loaders_by_code[source]
    loader_target = loaders_by_code[target]

    if model is None:
        model = UnifiedModelMultiBlockSAE(512, 1024, 2048, 31)

    print("===== Main Phase (IRM + SAE) =====")

    model = train_main_irm_multi_sae_office(model,
                                            loader_source=loader_source,
                                            test_loader=loader_target,
                                            num_epochs=num_main_epochs,
                                            lr=lr,
                                            lambda_irm=lambda_irm,
                                            lambda_sae_rec=lambda_sae_rec,
                                            lambda_sae_sparse=lambda_sae_sparse,
                                            device=device,
                                            verbose=verbose,
                                            lambda_sparse=lambda_sparse,
                                            lambda_reconstruction=lambda_reconstruction,
                                            lambda_irm_pair=lambda_irm_pair
                                            )

    return model

### Evaluate

In [None]:
def evaluate_baseline(model, loader, device='cuda'):
    """Top-1 accuracy (in percent) of a model whose forward returns ``(logits, features)``.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if an error occurs.

    Args:
        model: Module to evaluate.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy as a float in ``[0, 100]``.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                logits, _ = model(x)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    acc = 100.0 * correct / total
    return acc


In [None]:
def evaluate(model, loader, device='cuda'):
    """Top-1 accuracy (in percent) of a model whose forward returns a 5-tuple with logits first.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if an error
    occurs, so a training loop that evaluates between epochs keeps training
    in training mode.

    Args:
        model: Module to evaluate.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy as a float in ``[0, 100]``.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                logits, _, _, _, _ = model(x)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    acc = 100.0 * correct / total
    return acc

# MAIN

In [None]:
feature_extractor_baseline = VGGFeatureExtractor(pretrained=True)
baseline_model = BaselineModel(feature_extractor=feature_extractor_baseline, num_classes=31)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 223MB/s]


In [None]:
# get_office_data_loaders returns (amazon, dslr, webcam) -- DSLR before Webcam.
loader_amazon, loader_dslr, loader_webcam = get_office_data_loaders(32)



## VGG-19 Backbone

In [None]:
combinations = [["A", "W"], ["A", "D"], ["W", "A"], ["W", "D"], ["D", "W"], ["D", "A"]]

In [None]:
# Train and evaluate one model per [source, target] Office-31 pair.
# `model` is not passed, so every call builds its own default model.
# Only the model of the last pair stays bound to `trained_model_sae_office`.
for combination in combinations:
    print(f"------------Combination: {combination}--------------")
    trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=30,
            lr=2e-5,
            lambda_irm=4.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            loader=combination,
            verbose=False
        )
    print("--------------------------------------------")

------------Combination: ['A', 'W']--------------


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 220MB/s]


===== Main Phase (IRM + SAE) =====
** End of Epoch 1/30 | Test Accuracy: 15.86% **
** End of Epoch 2/30 | Test Accuracy: 18.07% **
** End of Epoch 3/30 | Test Accuracy: 25.70% **
** End of Epoch 4/30 | Test Accuracy: 21.69% **
** End of Epoch 5/30 | Test Accuracy: 35.54% **
** End of Epoch 6/30 | Test Accuracy: 31.73% **
** End of Epoch 7/30 | Test Accuracy: 30.52% **
** End of Epoch 8/30 | Test Accuracy: 22.69% **
** End of Epoch 9/30 | Test Accuracy: 29.72% **
** End of Epoch 10/30 | Test Accuracy: 33.94% **
** End of Epoch 11/30 | Test Accuracy: 38.15% **
** End of Epoch 12/30 | Test Accuracy: 33.94% **
** End of Epoch 13/30 | Test Accuracy: 36.55% **
** End of Epoch 14/30 | Test Accuracy: 42.97% **
** End of Epoch 15/30 | Test Accuracy: 42.77% **
** End of Epoch 16/30 | Test Accuracy: 44.38% **
** End of Epoch 17/30 | Test Accuracy: 40.16% **
** End of Epoch 18/30 | Test Accuracy: 42.17% **
** End of Epoch 19/30 | Test Accuracy: 41.77% **
** End of Epoch 20/30 | Test Accuracy: 46.5

## VGG-16 Backbone

In [None]:
feature_extractor_baseline = VGGFeatureExtractor(pretrained=True)
baseline_model = BaselineModel(feature_extractor=feature_extractor_baseline, num_classes=31)



In [None]:
combinations = [["A", "W"], ["A", "D"], ["W", "A"], ["W", "D"], ["D", "W"], ["D", "A"]]

In [None]:
# Train and evaluate one model per [source, target] Office-31 pair.
# NOTE(review): `model` is not passed here, so the default
# UnifiedModelMultiBlockSAE is built; as defined in this notebook its
# extractor is created with model_name='vgg19'. Confirm how VGG-16 was
# selected for the run under this "VGG-16 Backbone" heading.
for combination in combinations:
    print(f"------------Combination: {combination}--------------")
    trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=30,
            lr=2e-5,
            lambda_irm=4.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            loader=combination,
            verbose=False
        )
    print("--------------------------------------------")

------------Combination: ['A', 'W']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/30 | Test Accuracy: 31.33% **
** End of Epoch 2/30 | Test Accuracy: 29.52% **
** End of Epoch 3/30 | Test Accuracy: 44.38% **
** End of Epoch 4/30 | Test Accuracy: 44.38% **
** End of Epoch 5/30 | Test Accuracy: 43.17% **
** End of Epoch 6/30 | Test Accuracy: 49.40% **
** End of Epoch 7/30 | Test Accuracy: 41.57% **
** End of Epoch 8/30 | Test Accuracy: 56.02% **
** End of Epoch 9/30 | Test Accuracy: 53.41% **
** End of Epoch 10/30 | Test Accuracy: 56.22% **
** End of Epoch 11/30 | Test Accuracy: 60.24% **
** End of Epoch 12/30 | Test Accuracy: 58.03% **
** End of Epoch 13/30 | Test Accuracy: 59.04% **
** End of Epoch 14/30 | Test Accuracy: 63.45% **
** End of Epoch 15/30 | Test Accuracy: 62.65% **
** End of Epoch 16/30 | Test Accuracy: 65.06% **
** End of Epoch 17/30 | Test Accuracy: 56.43% **
** End of Epoch 18/30 | Test Accuracy: 51.20% **
** End of Epoch 19/30 | Test Accuracy: 60.

In [None]:
combinations = [["A", "D"], ["W", "A"]]

In [None]:
# Re-run of the pairs currently listed in `combinations`, with larger SAE
# weights (lambda_sae_rec=3.0, lambda_sae_sparse=3e-4) than the loops above.
for combination in combinations:
    print(f"------------Combination: {combination}--------------")
    trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=30,
            lr=2e-5,
            lambda_irm=4.0,
            lambda_sae_rec=3.0,
            lambda_sae_sparse=3e-4,
            device='cuda',
            loader=combination,
            verbose=False
        )
    print("--------------------------------------------")


------------Combination: ['A', 'D']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/30 | Test Accuracy: 19.37% **
** End of Epoch 2/30 | Test Accuracy: 22.14% **
** End of Epoch 3/30 | Test Accuracy: 34.47% **
** End of Epoch 4/30 | Test Accuracy: 40.38% **
** End of Epoch 5/30 | Test Accuracy: 48.68% **
** End of Epoch 6/30 | Test Accuracy: 49.18% **
** End of Epoch 7/30 | Test Accuracy: 52.58% **
** End of Epoch 8/30 | Test Accuracy: 49.43% **
** End of Epoch 9/30 | Test Accuracy: 56.60% **
** End of Epoch 10/30 | Test Accuracy: 52.45% **
** End of Epoch 11/30 | Test Accuracy: 58.62% **
** End of Epoch 12/30 | Test Accuracy: 59.50% **
** End of Epoch 13/30 | Test Accuracy: 59.37% **
** End of Epoch 14/30 | Test Accuracy: 57.23% **
** End of Epoch 15/30 | Test Accuracy: 58.74% **
** End of Epoch 16/30 | Test Accuracy: 51.32% **
** End of Epoch 17/30 | Test Accuracy: 56.23% **
** End of Epoch 18/30 | Test Accuracy: 52.45% **
** End of Epoch 19/30 | Test Accuracy: 55.

KeyboardInterrupt: 

## VGG19 Model