In [None]:
# --------------------- Imports ------------------------- #

# PyTorch
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms, models
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torchvision.models import vit_b_16


# Datasets
import kagglehub

# Extra
import os
from copy import deepcopy
from PIL import Image
import matplotlib.pyplot as plt
import timm

## Office-31

In [None]:
def Office31(domain, transform):
    """Load one Office-31 domain as a torchvision ``ImageFolder``.

    The dataset is fetched with ``kagglehub.dataset_download`` (the notebook
    output suggests it is cached after the first call — confirm), and the
    images are read from ``<download>/Office-31/<domain>``, where each
    sub-directory is one class.

    Args:
        domain: directory name of the domain, e.g. "amazon", "dslr", "webcam".
        transform: transform applied to every loaded image.
    """
    dataset_root = kagglehub.dataset_download("xixuhu/office31")
    domain_dir = os.path.join(dataset_root, "Office-31", domain)
    return datasets.ImageFolder(root=domain_dir, transform=transform)

def get_office_data_loaders(batch_size):
    """Build shuffled DataLoaders for the three Office-31 domains.

    Images are resized to 224x224, converted to tensors and normalised with
    mean = std = 0.5 per channel.

    Returns:
        ``(loader_amazon, loader_dslr, loader_webcam)`` — note the order:
        amazon, dslr, webcam.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    by_domain = {
        domain_name: DataLoader(
            Office31(domain_name, preprocess),
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
        )
        for domain_name in ("amazon", "dslr", "webcam")
    }

    return by_domain["amazon"], by_domain["dslr"], by_domain["webcam"]

## Model Cards

In [None]:
class VitBasicFeatureExtractor(nn.Module):
    """ViT-B/16 backbone that returns one average-pooled vector per feature map.

    Wraps ``timm.create_model('vit_base_patch16_224', features_only=True)``.
    Every feature map the backbone yields is passed through
    ``AdaptiveAvgPool2d((1, 1))`` and flattened to ``(size(0), -1)``.

    NOTE(review): ``layers`` is accepted but never used. Which blocks are
    returned is decided by timm's ``features_only`` defaults. Callers in this
    notebook unpack exactly three outputs, so confirm that against the
    installed timm version.
    """

    def __init__(self, pretrained=True, layers=[4, 8, 12]):
        super().__init__()

        self.model = timm.create_model('vit_base_patch16_224', pretrained=pretrained, features_only=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        # Assumes each feature map is (batch, C, H, W), as AdaptiveAvgPool2d expects
        # (TODO: confirm timm's output format).
        pooled_maps = (self.avgpool(feature_map) for feature_map in self.model(x))
        return tuple(p.view(p.size(0), -1) for p in pooled_maps)

In [None]:
# model = VitBasicFeatureExtractor()
# input = torch.randn(1, 3, 224, 224)
# output = model(input)

# for out in output:
#     print(out.shape)

In [None]:
class SparseAutoencoder(nn.Module):
    """MLP autoencoder whose latent code the caller L1-regularises.

    Encoder: input_dim -> hidden_dim -> hidden_dim -> hidden_dim -> latent_dim.
    Each Linear is followed by LayerNorm + ReLU, plus Dropout on every stage
    except the last.

    Decoder: latent_dim -> hidden_dim (x3, same Linear/LayerNorm/ReLU/Dropout
    stages) -> input_dim, then a Sigmoid. Reconstructions therefore lie in
    [0, 1].

    ``forward(x)`` returns ``(x_recon, z)``.
    """

    def __init__(self, input_dim=2048, hidden_dim=1024, latent_dim=768, dropout_rate=0.3):
        super().__init__()

        def dense(n_in, n_out, with_dropout=True):
            # One Linear -> LayerNorm -> ReLU stage, optionally followed by Dropout.
            stage = [nn.Linear(n_in, n_out), nn.LayerNorm(n_out), nn.ReLU(inplace=True)]
            if with_dropout:
                stage.append(nn.Dropout(dropout_rate))
            return stage

        self.encoder = nn.Sequential(
            *dense(input_dim, hidden_dim),
            *dense(hidden_dim, hidden_dim),
            *dense(hidden_dim, hidden_dim),
            *dense(hidden_dim, latent_dim, with_dropout=False),
        )

        self.decoder = nn.Sequential(
            *dense(latent_dim, hidden_dim),
            *dense(hidden_dim, hidden_dim),
            *dense(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        # Every Linear: Kaiming-uniform (ReLU gain) weights and zero bias.
        # Every LayerNorm: identity affine (weight 1, bias 0).
        for module in self.modules():
            if isinstance(module, nn.LayerNorm):
                init.ones_(module.weight)
                init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                init.kaiming_uniform_(module.weight, a=0, nonlinearity='relu')
                if module.bias is not None:
                    init.zeros_(module.bias)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent), latent


In [None]:
class BaselineModel(nn.Module):
    """ViT feature extractor followed by a linear classifier.

    Only the last of the three pooled backbone features is classified.

    Args:
        pretrained: forwarded to ``VitBasicFeatureExtractor``. Previously this
            argument was ignored and the backbone was always pretrained.
        num_classes: number of output logits (31 for Office-31).
        feature_dim: width of the pooled feature fed to the classifier
            (768 for ViT-B/16, the previously hard-coded value).
    """

    def __init__(self, pretrained=True, num_classes=31, feature_dim=768):
        super().__init__()
        self.feature_extractor = VitBasicFeatureExtractor(pretrained=pretrained)
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        _, _, f3 = self.feature_extractor(x)
        logits = self.classifier(f3)
        return logits


In [None]:
class SADA(nn.Module):
    """ViT features -> sparse autoencoder ``sae4`` -> linear classifier.

    The classifier reads the 768-d SAE latent ``z4``, not the raw backbone
    feature.

    ``forward`` returns ``(class_logits, (x_recon4, z4), f4)``. Here ``f4`` is
    the last of the three pooled backbone features, and ``x_recon4`` / ``z4``
    are the SAE reconstruction of it and its latent code.
    """

    def __init__(self, num_classes=31):
        super().__init__()

        self.feature_extractor = VitBasicFeatureExtractor(pretrained=True)
        self.sae4 = SparseAutoencoder(input_dim=768, latent_dim=768)
        self.classifier = nn.Sequential(nn.Linear(768, num_classes))

    def forward(self, x):
        _, _, f4 = self.feature_extractor(x)
        reconstruction, latent = self.sae4(f4)
        logits = self.classifier(latent)
        return logits, (reconstruction, latent), f4

## Losses

In [None]:
def irm_penalty(logits, labels):
    """IRMv1 penalty for a single batch / environment.

    Returns:
        ``(loss_erm, penalty, var)``:

        * ``loss_erm``: cross-entropy of ``logits`` multiplied by a dummy scalar
          ``w = 1.0``.
        * ``penalty``: squared gradient of ``loss_erm`` w.r.t. ``w``. It is built
          with ``create_graph=True`` so it can itself be back-propagated.
        * ``var``: placeholder, always ``0.0``.
    """
    dummy_w = torch.tensor(1.0, requires_grad=True, device=logits.device)
    loss_erm = F.cross_entropy(dummy_w * logits, labels)

    (grad_w,) = torch.autograd.grad(loss_erm, [dummy_w], create_graph=True)

    return loss_erm, torch.sum(grad_w ** 2), 0.0

## Train and Evaluate Functions

### Train

In [None]:
def train_basic_with_irm(model,
                         loader_source,
                         test_loader,
                         num_epochs=20,
                         lr=1e-4,
                         lr_sae=1e-4,
                         lambda_irm=1.0,
                         lambda_sae_rec=1.0,
                         lambda_sae_sparse=1e-4,
                         lambda_sparse=[1.0, 1.0, 1.0],
                         lambda_reconstruction=[1.0, 1.0, 1.0],
                         lambda_irm_pair=[1.0, 1.0, 1.0],
                         device='cpu',
                         verbose=False):
    """Train a logits-only model on ``loader_source`` with the IRMv1 objective.

    Each step minimises ``w_erm * ERM + w_pen * lambda_irm * penalty``.
    ``(w_erm, w_pen, _) = lambda_irm_pair``, and ERM / penalty come from
    ``irm_penalty``. A single Adam optimizer (``lr``) updates every parameter.
    After each epoch the ``evaluate_baseline`` accuracy on ``test_loader`` is
    printed.

    ``lr_sae``, ``lambda_sae_rec``, ``lambda_sae_sparse``, ``lambda_sparse``
    and ``lambda_reconstruction`` are accepted but not used here.

    Returns:
        The trained ``model`` (the same object, moved to ``device``).
    """
    model.to(device)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        steps_per_epoch = len(loader_source)

        for step, (x_s, y_s) in enumerate(loader_source):
            x_s, y_s = x_s.to(device), y_s.to(device)

            erm_term, penalty_term, _ = irm_penalty(model(x_s), y_s)

            # Third weight of lambda_irm_pair is ignored in this trainer.
            w_erm, w_pen, _ = lambda_irm_pair
            loss_irm = w_erm * erm_term + w_pen * (lambda_irm * penalty_term)

            optimizer.zero_grad()
            loss_irm.backward()
            optimizer.step()

            if verbose and (step + 1) % 40 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{steps_per_epoch}], "
                      f"IRM Loss: {loss_irm.item():.4f}")

        test_acc = evaluate_baseline(model, test_loader, device=device)
        print(f"** End of Epoch {epoch+1}/{num_epochs} | Test Accuracy: {test_acc:.2f}% **")

    return model

In [None]:
def train_model_irm(
    batch_size=16,
    num_warmup_epochs=5,
    num_main_epochs=20,
    lr=1e-4,
    lr_sae=1e-4,
    lambda_irm=1.0,
    lambda_sae_rec=1.0,
    lambda_sae_sparse=1e-4,
    device='cpu',
    loader=['A', 'W'],
    model=None,
    verbose=False,
    lambda_sparse=[1.0, 1.0, 1.0],
    lambda_reconstruction=[1.0, 1.0, 1.0],
    lambda_irm_pair=[1.0, 1.0, 0.0],
):
    """Train the IRM-only baseline on one Office-31 source -> target pair.

    Args:
        loader: ``[source, target]`` domain codes, each one of 'A' (amazon),
            'D' (dslr) or 'W' (webcam). ``None`` means ``['A', 'W']``.
        model: model to train. A fresh ``BaselineModel(pretrained=True)`` is
            used when ``None``.
        num_warmup_epochs: accepted but unused.
        batch_size: used for both the source and target loaders.
        Other arguments: forwarded unchanged to ``train_basic_with_irm``.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``source`` or ``target`` is not 'A', 'D' or 'W'.
    """
    # get_office_data_loaders returns (amazon, dslr, webcam); unpack in exactly that
    # order, otherwise 'W' and 'D' silently refer to each other's data.
    loader_amazon, loader_dslr, loader_webcam = get_office_data_loaders(batch_size)
    domain_loaders = {'A': loader_amazon, 'D': loader_dslr, 'W': loader_webcam}

    source, target = loader if loader is not None else ['A', 'W']
    for domain in (source, target):
        if domain not in domain_loaders:
            raise ValueError(f"Unknown Office-31 domain {domain!r}; expected 'A', 'D' or 'W'")

    loader_source = domain_loaders[source]
    # Evaluate on the whole target domain in a fixed order; shuffling does not
    # change accuracy and rebuilding every loader a second time is unnecessary.
    loader_target = DataLoader(
        domain_loaders[target].dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
    )

    if model is None:
        model = BaselineModel(pretrained=True)

    print("===== Main Phase (IRM + SAE) =====")

    model = train_basic_with_irm(
        model,
        loader_source=loader_source,
        test_loader=loader_target,
        num_epochs=num_main_epochs,
        lr=lr,
        lr_sae=lr_sae,
        lambda_irm=lambda_irm,
        lambda_sae_rec=lambda_sae_rec,
        lambda_sae_sparse=lambda_sae_sparse,
        device=device,
        verbose=verbose,
        lambda_sparse=lambda_sparse,
        lambda_reconstruction=lambda_reconstruction,
        lambda_irm_pair=lambda_irm_pair,
    )

    return model

In [None]:
def train_model_sae(
    batch_size=16,
    num_warmup_epochs=5,
    num_main_epochs=20,
    lr=1e-4,
    lr_sae=1e-4,
    lambda_irm=1.0,
    lambda_sae_rec=1.0,
    lambda_sae_sparse=1e-4,
    device='cuda',
    loader=['A', 'W'],
    model=None,
    verbose=False,
    lambda_sparse=[1.0, 1.0, 1.0],
    lambda_reconstruction=[1.0, 1.0, 1.0],
    lambda_irm_pair=[1.0, 1.0, 0.0],
):
    """Train the SADA (ViT + sparse autoencoder) model on one Office-31 source -> target pair.

    Args:
        loader: ``[source, target]`` domain codes, each one of 'A' (amazon),
            'D' (dslr) or 'W' (webcam). ``None`` means ``['A', 'W']``.
        model: model to train. A fresh ``SADA(31)`` is used when ``None``.
        num_warmup_epochs: accepted but unused.
        batch_size: used for both the source and target loaders.
        Other arguments: forwarded unchanged to ``train_main_sae``.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``source`` or ``target`` is not 'A', 'D' or 'W'.
    """
    # get_office_data_loaders returns (amazon, dslr, webcam); unpack in exactly that
    # order, otherwise 'W' and 'D' silently refer to each other's data.
    loader_amazon, loader_dslr, loader_webcam = get_office_data_loaders(batch_size)
    domain_loaders = {'A': loader_amazon, 'D': loader_dslr, 'W': loader_webcam}

    source, target = loader if loader is not None else ['A', 'W']
    for domain in (source, target):
        if domain not in domain_loaders:
            raise ValueError(f"Unknown Office-31 domain {domain!r}; expected 'A', 'D' or 'W'")

    loader_source = domain_loaders[source]
    # Evaluate on the whole target domain in a fixed order; shuffling does not
    # change accuracy and rebuilding every loader a second time is unnecessary.
    loader_target = DataLoader(
        domain_loaders[target].dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
    )

    if model is None:
        model = SADA(31)

    print("===== Main Phase (IRM + SAE) =====")

    model = train_main_sae(
        model,
        loader_source=loader_source,
        test_loader=loader_target,
        num_epochs=num_main_epochs,
        lr=lr,
        lr_sae=lr_sae,
        lambda_irm=lambda_irm,
        lambda_sae_rec=lambda_sae_rec,
        lambda_sae_sparse=lambda_sae_sparse,
        device=device,
        verbose=verbose,
        lambda_sparse=lambda_sparse,
        lambda_reconstruction=lambda_reconstruction,
        lambda_irm_pair=lambda_irm_pair,
    )

    return model

### Separate optimizers (backbone + classifier vs. SAE)

In [None]:
def train_main_sae(
    model,
    loader_source,
    test_loader,
    num_epochs=20,
    lr=1e-4,
    lr_sae=1e-4,
    lambda_irm=1.0,
    lambda_sae_rec=1.0,
    lambda_sae_sparse=1e-4,
    lambda_sparse=[1.0, 1.0, 1.0],
    lambda_reconstruction=[1.0, 1.0, 1.0],
    lambda_irm_pair=[1.0, 1.0, 1.0],
    device='cuda',
    verbose=False
):
    """Jointly train a SADA-style model: classification loss plus SAE reconstruction/sparsity.

    ``model(x)`` must return ``(class_logits, (x_recon4, z4), f4)`` and expose
    ``model.sae4``.

    Two Adam optimizers are used:

    * ``optimizer_rest`` (lr=``lr``) for every parameter whose name does not
      start with ``sae2`` / ``sae3`` / ``sae4``.
    * ``optimizer_sae`` (lr=``lr_sae``) for ``model.sae4``.

    Per step::

        loss = CE(class_logits, y)
             + lambda_sae_rec    * lambda_reconstruction[2] * MSE(sae4(f4d).recon, f4d)
             + lambda_sae_sparse * lambda_sparse[2]         * mean(|sae4(f4d).z|)

    Here ``f4d`` is ``f4`` detached, so the reconstruction/sparsity terms only
    update ``sae4``. The classification loss flows through ``sae4`` into the
    backbone.

    ``lambda_irm`` and ``lambda_irm_pair`` are accepted but unused (the IRM
    penalty is disabled in this trainer). After each epoch the ``evaluate``
    accuracy on ``test_loader`` is printed.

    Returns:
        The trained ``model``.
    """
    model.to(device)

    params_rest = [
        p for n, p in model.named_parameters()
        if not (n.startswith('sae2') or n.startswith('sae3') or n.startswith('sae4'))
    ]
    optimizer_rest = optim.Adam(params_rest, lr=lr)
    optimizer_sae = optim.Adam(list(model.sae4.parameters()), lr=lr_sae)
    loss_f_irm = nn.CrossEntropyLoss()

    # Only the third weight of each triple applies (there is a single SAE, sae4).
    _, _, lambda_s3 = lambda_sparse
    _, _, lambda_r3 = lambda_reconstruction

    for epoch in range(num_epochs):
        # evaluate() switches the model to eval mode; re-enable training mode
        # (SAE dropout) at the start of every epoch, not just once.
        model.train()
        steps_per_epoch = len(loader_source)

        for step, (x_s, y_s) in enumerate(loader_source):
            x_s, y_s = x_s.to(device), y_s.to(device)

            # One backbone forward gives both the logits and the SAE target f4.
            class_logits_s, _, f4_s = model(x_s)
            loss_irm = loss_f_irm(class_logits_s, y_s)

            # Detached copy: reconstruction/sparsity must not update the backbone.
            # This replaces a second full ViT forward under torch.no_grad().
            # NOTE(review): it is equivalent as long as the backbone has no active
            # dropout / drop-path in train mode.
            f4p_s = f4_s.detach()
            x_recon4_s, z4_s = model.sae4(f4p_s)

            rec_loss_total = lambda_r3 * F.mse_loss(x_recon4_s, f4p_s)
            l1_sparsity = lambda_s3 * torch.mean(torch.abs(z4_s))
            sae_loss = lambda_sae_rec * rec_loss_total + lambda_sae_sparse * l1_sparsity

            loss = loss_irm + sae_loss

            optimizer_rest.zero_grad()
            optimizer_sae.zero_grad()
            loss.backward()
            optimizer_rest.step()
            optimizer_sae.step()

            if (step+1) % 40 == 0 and verbose:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{steps_per_epoch}], "
                      f"IRM Loss: {loss_irm.item():.4f}, SAE Loss: {sae_loss.item():.4f}")

        test_acc = evaluate(model, test_loader, device=device)
        print(f"** End of Epoch {epoch+1}/{num_epochs} | Test Accuracy: {test_acc:.2f}% **")

    return model


### Evaluate

In [None]:
def evaluate_baseline(model, loader, device='cuda'):
    """Return the top-1 accuracy (percent) of ``model`` on ``loader``.

    ``model(x)`` must return logits directly. The model runs in eval mode under
    ``torch.no_grad()``. Its previous train/eval mode is restored afterwards,
    even if an exception is raised.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                preds = torch.argmax(model(x), dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)
    if total == 0:
        raise ValueError("evaluate_baseline: loader yielded no samples")
    return 100.0 * correct / total


In [None]:
def evaluate(model, loader, device='cuda'):
    """Return the top-1 accuracy (percent) of a tuple-returning model on ``loader``.

    ``model(x)`` must return ``(logits, _, _)`` (as ``SADA`` does). The model
    runs in eval mode under ``torch.no_grad()``, and its previous train/eval
    mode is restored afterwards. Without this, a training loop that calls
    ``evaluate`` between epochs would keep training with dropout disabled.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                logits, _, _ = model(x)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return 100.0 * correct / total

# ONLY IRM

Using a different learning rate than SAE + IRM, because a lower learning rate gives very low accuracy (~10%).

## A -> D **DO** **NOT** **RUN**

In [None]:
# IRM-only baseline (BaselineModel: ViT features -> linear head), source A -> target D.
combination = ["A", "D"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # (ERM weight, IRM-penalty weight, unused)
lambda_sparse=[0.1, 0.1, 0.1]       # not used by the IRM-only trainer

trained_model_irm_office = train_model_irm(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['A', 'D']--------------
===== Main Phase (IRM + SAE) =====
Epoch [1/10], Step [40/89], IRM Loss: 22.1977
Epoch [1/10], Step [80/89], IRM Loss: 8.7731
** End of Epoch 1/10 | Test Accuracy: 73.08% **
Epoch [2/10], Step [40/89], IRM Loss: 3.7084
Epoch [2/10], Step [80/89], IRM Loss: 1.9287
** End of Epoch 2/10 | Test Accuracy: 77.74% **
Epoch [3/10], Step [40/89], IRM Loss: 0.4109
Epoch [3/10], Step [80/89], IRM Loss: 0.8617
** End of Epoch 3/10 | Test Accuracy: 78.87% **
Epoch [4/10], Step [40/89], IRM Loss: 0.0672
Epoch [4/10], Step [80/89], IRM Loss: 0.6940
** End of Epoch 4/10 | Test Accuracy: 78.24% **
Epoch [5/10], Step [40/89], IRM Loss: 0.0490
Epoch [5/10], Step [80/89], IRM Loss: 0.0490
** End of Epoch 5/10 | Test Accuracy: 78.24% **
Epoch [6/10], Step [40/89], IRM Loss: 0.0661
Epoch [6/10], Step [80/89], IRM Loss: 0.0441
** End of Epoch 6/10 | Test Accuracy: 79.50% **
Epoch [7/10], Step [40/89], IRM Loss: 0.0222
Epoch [7/10], Step [80/89], IRM Loss: 0.0

## A -> W **DO** **NOT** **RUN**

In [None]:
# IRM-only baseline (BaselineModel: ViT features -> linear head), source A -> target W.
# The pair is now passed explicitly (previously it relied on the ['A', 'W'] default).
combination = ["A", "W"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # (ERM weight, IRM-penalty weight, unused)
lambda_sparse=[0.1, 0.1, 0.1]       # not used by the IRM-only trainer

trained_model_irm_office = train_model_irm(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")

===== Main Phase (IRM + SAE) =====
Epoch [1/10], Step [40/89], IRM Loss: 21.7279
Epoch [1/10], Step [80/89], IRM Loss: 7.0227
** End of Epoch 1/10 | Test Accuracy: 70.68% **
Epoch [2/10], Step [40/89], IRM Loss: 2.6860
Epoch [2/10], Step [80/89], IRM Loss: 3.9195
** End of Epoch 2/10 | Test Accuracy: 81.53% **
Epoch [3/10], Step [40/89], IRM Loss: 1.1220
Epoch [3/10], Step [80/89], IRM Loss: 3.2336
** End of Epoch 3/10 | Test Accuracy: 83.33% **
Epoch [4/10], Step [40/89], IRM Loss: 0.1281
Epoch [4/10], Step [80/89], IRM Loss: 0.1681
** End of Epoch 4/10 | Test Accuracy: 82.73% **
Epoch [5/10], Step [40/89], IRM Loss: 0.0497
Epoch [5/10], Step [80/89], IRM Loss: 0.0391
** End of Epoch 5/10 | Test Accuracy: 83.13% **
Epoch [6/10], Step [40/89], IRM Loss: 0.0347
Epoch [6/10], Step [80/89], IRM Loss: 0.0402
** End of Epoch 6/10 | Test Accuracy: 83.73% **
Epoch [7/10], Step [40/89], IRM Loss: 0.0142
Epoch [7/10], Step [80/89], IRM Loss: 0.0564
** End of Epoch 7/10 | Test Accuracy: 84.34% *

## W -> A **DO** **NOT** **RUN**

In [None]:
# IRM-only baseline (BaselineModel: ViT features -> linear head), source W -> target A.
combination = ["W", "A"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # (ERM weight, IRM-penalty weight, unused)
lambda_sparse=[0.1, 0.1, 0.1]       # not used by the IRM-only trainer

trained_model_irm_office = train_model_irm(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['W', 'A']--------------
Downloading from https://www.kaggle.com/api/v1/datasets/download/xixuhu/office31?dataset_version_number=1...


100%|██████████| 75.9M/75.9M [00:04<00:00, 18.0MB/s]

Extracting files...







model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 17.22% **
** End of Epoch 2/10 | Test Accuracy: 26.62% **
** End of Epoch 3/10 | Test Accuracy: 44.87% **
** End of Epoch 4/10 | Test Accuracy: 59.64% **
** End of Epoch 5/10 | Test Accuracy: 62.80% **
** End of Epoch 6/10 | Test Accuracy: 63.90% **
** End of Epoch 7/10 | Test Accuracy: 64.50% **
** End of Epoch 8/10 | Test Accuracy: 64.50% **
** End of Epoch 9/10 | Test Accuracy: 64.04% **
** End of Epoch 10/10 | Test Accuracy: 64.00% **
--------------------------------------------


## W -> D **DO** **NOT** **RUN**

In [None]:
# IRM-only baseline (BaselineModel: ViT features -> linear head), source W -> target D.
combination = ["W", "D"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # (ERM weight, IRM-penalty weight, unused)
lambda_sparse=[0.1, 0.1, 0.1]       # not used by the IRM-only trainer

trained_model_irm_office = train_model_irm(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['W', 'D']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 35.09% **
** End of Epoch 2/10 | Test Accuracy: 30.19% **
** End of Epoch 3/10 | Test Accuracy: 64.28% **
** End of Epoch 4/10 | Test Accuracy: 93.84% **
** End of Epoch 5/10 | Test Accuracy: 96.98% **
** End of Epoch 6/10 | Test Accuracy: 97.36% **
** End of Epoch 7/10 | Test Accuracy: 97.61% **
** End of Epoch 8/10 | Test Accuracy: 97.74% **
** End of Epoch 9/10 | Test Accuracy: 97.86% **
** End of Epoch 10/10 | Test Accuracy: 97.99% **
--------------------------------------------


## D -> W **DO** **NOT** **RUN**

In [None]:
# IRM-only baseline (BaselineModel: ViT features -> linear head), source D -> target W.
combination = ["D", "W"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # (ERM weight, IRM-penalty weight, unused)
lambda_sparse=[0.1, 0.1, 0.1]       # not used by the IRM-only trainer

trained_model_irm_office = train_model_irm(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['D', 'W']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 54.42% **
** End of Epoch 2/10 | Test Accuracy: 86.55% **
** End of Epoch 3/10 | Test Accuracy: 99.40% **
** End of Epoch 4/10 | Test Accuracy: 99.40% **
** End of Epoch 5/10 | Test Accuracy: 99.60% **
** End of Epoch 6/10 | Test Accuracy: 99.60% **
** End of Epoch 7/10 | Test Accuracy: 99.60% **
** End of Epoch 8/10 | Test Accuracy: 99.60% **
** End of Epoch 9/10 | Test Accuracy: 99.60% **
** End of Epoch 10/10 | Test Accuracy: 99.60% **
--------------------------------------------


## D -> A **DO** **NOT** **RUN**

In [None]:
# IRM-only baseline (BaselineModel: ViT features -> linear head), source D -> target A.
combination = ["D", "A"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # (ERM weight, IRM-penalty weight, unused)
lambda_sparse=[0.1, 0.1, 0.1]       # not used by the IRM-only trainer

trained_model_irm_office = train_model_irm(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['D', 'A']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 15.55% **
** End of Epoch 2/10 | Test Accuracy: 41.07% **
** End of Epoch 3/10 | Test Accuracy: 54.10% **
** End of Epoch 4/10 | Test Accuracy: 55.84% **
** End of Epoch 5/10 | Test Accuracy: 56.44% **
** End of Epoch 6/10 | Test Accuracy: 56.44% **
** End of Epoch 7/10 | Test Accuracy: 56.27% **
** End of Epoch 8/10 | Test Accuracy: 56.48% **
** End of Epoch 9/10 | Test Accuracy: 56.34% **
** End of Epoch 10/10 | Test Accuracy: 56.76% **
--------------------------------------------


# ONLY SAE

In [None]:
# SAE-only run (SADA: ViT features -> SparseAutoencoder -> linear head), source A -> target D.
combination = ["A", "D"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # forwarded, but train_main_sae does not use it
lambda_sparse = [0.1, 0.1, 0.1]     # only the last entry weights the sae4 L1 term

trained_model_sae_office = train_model_sae(
    loader=combination,
    device='cuda',
    verbose=True,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=10,
    lr=1e-5,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sparse=lambda_sparse,
)

print("--------------------------------------------")


------------Combination: ['A', 'D']--------------
===== Main Phase (IRM + SAE) =====
Epoch [1/10], Step [40/89], IRM Loss: 3.1251, SAE Loss: 4.3712
Epoch [1/10], Step [80/89], IRM Loss: 2.8463, SAE Loss: 4.3677
** End of Epoch 1/10 | Test Accuracy: 75.35% **
Epoch [2/10], Step [40/89], IRM Loss: 1.1197, SAE Loss: 4.0408
Epoch [2/10], Step [80/89], IRM Loss: 0.5273, SAE Loss: 5.1563
** End of Epoch 2/10 | Test Accuracy: 85.41% **
Epoch [3/10], Step [40/89], IRM Loss: 0.3457, SAE Loss: 5.4316
Epoch [3/10], Step [80/89], IRM Loss: 0.3960, SAE Loss: 5.2741
** End of Epoch 3/10 | Test Accuracy: 85.16% **
Epoch [4/10], Step [40/89], IRM Loss: 0.2357, SAE Loss: 4.9880
Epoch [4/10], Step [80/89], IRM Loss: 0.0683, SAE Loss: 5.4592
** End of Epoch 4/10 | Test Accuracy: 88.43% **
Epoch [5/10], Step [40/89], IRM Loss: 0.0714, SAE Loss: 5.2785
Epoch [5/10], Step [80/89], IRM Loss: 0.0743, SAE Loss: 5.7511
** End of Epoch 5/10 | Test Accuracy: 88.93% **
Epoch [6/10], Step [40/89], IRM Loss: 0.0342,

In [None]:
# SAE-only run (SADA: ViT features -> SparseAutoencoder -> linear head), source A -> target W.
combination = ["A", "W"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # forwarded, but train_main_sae does not use it
lambda_sparse = [0.1, 0.1, 0.1]     # only the last entry weights the sae4 L1 term

trained_model_sae_office = train_model_sae(
    loader=combination,
    device='cuda',
    verbose=True,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=10,
    lr=1e-5,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sparse=lambda_sparse,
)

print("--------------------------------------------")


------------Combination: ['A', 'W']--------------
===== Main Phase (IRM + SAE) =====
Epoch [1/10], Step [40/89], IRM Loss: 3.2011, SAE Loss: 5.5974
Epoch [1/10], Step [80/89], IRM Loss: 3.0113, SAE Loss: 4.1160
** End of Epoch 1/10 | Test Accuracy: 75.30% **
Epoch [2/10], Step [40/89], IRM Loss: 0.8934, SAE Loss: 4.8536
Epoch [2/10], Step [80/89], IRM Loss: 0.6474, SAE Loss: 4.7939
** End of Epoch 2/10 | Test Accuracy: 88.96% **
Epoch [3/10], Step [40/89], IRM Loss: 0.3933, SAE Loss: 5.4092
Epoch [3/10], Step [80/89], IRM Loss: 0.2652, SAE Loss: 5.6518
** End of Epoch 3/10 | Test Accuracy: 90.76% **
Epoch [4/10], Step [40/89], IRM Loss: 0.3590, SAE Loss: 4.6993
Epoch [4/10], Step [80/89], IRM Loss: 0.1462, SAE Loss: 5.2465
** End of Epoch 4/10 | Test Accuracy: 90.56% **
Epoch [5/10], Step [40/89], IRM Loss: 0.1614, SAE Loss: 5.2122
Epoch [5/10], Step [80/89], IRM Loss: 0.0944, SAE Loss: 5.6720
** End of Epoch 5/10 | Test Accuracy: 92.77% **
Epoch [6/10], Step [40/89], IRM Loss: 0.0431,

In [None]:
# SAE-only run (SADA: ViT features -> SparseAutoencoder -> linear head), source W -> target A.
combination = ["W", "A"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # forwarded, but train_main_sae does not use it
lambda_sparse = [0.1, 0.1, 0.1]     # only the last entry weights the sae4 L1 term

trained_model_sae_office = train_model_sae(
    loader=combination,
    device='cuda',
    verbose=True,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=10,
    lr=1e-5,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sparse=lambda_sparse,
)

print("--------------------------------------------")


------------Combination: ['W', 'A']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 7.42% **
** End of Epoch 2/10 | Test Accuracy: 32.52% **
** End of Epoch 3/10 | Test Accuracy: 57.33% **
** End of Epoch 4/10 | Test Accuracy: 69.01% **
** End of Epoch 5/10 | Test Accuracy: 72.38% **
** End of Epoch 6/10 | Test Accuracy: 73.80% **
** End of Epoch 7/10 | Test Accuracy: 74.09% **
** End of Epoch 8/10 | Test Accuracy: 74.48% **
** End of Epoch 9/10 | Test Accuracy: 74.51% **
** End of Epoch 10/10 | Test Accuracy: 74.94% **
--------------------------------------------


In [None]:
# SAE-only run (SADA: ViT features -> SparseAutoencoder -> linear head), source W -> target D.
combination = ["W", "D"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # forwarded, but train_main_sae does not use it
lambda_sparse = [0.1, 0.1, 0.1]     # only the last entry weights the sae4 L1 term

trained_model_sae_office = train_model_sae(
    loader=combination,
    device='cuda',
    verbose=True,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=10,
    lr=1e-5,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sparse=lambda_sparse,
)

print("--------------------------------------------")


------------Combination: ['W', 'D']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 19.50% **
** End of Epoch 2/10 | Test Accuracy: 67.30% **
** End of Epoch 3/10 | Test Accuracy: 89.94% **
** End of Epoch 4/10 | Test Accuracy: 97.23% **
** End of Epoch 5/10 | Test Accuracy: 98.49% **
** End of Epoch 6/10 | Test Accuracy: 98.62% **
** End of Epoch 7/10 | Test Accuracy: 98.87% **
** End of Epoch 8/10 | Test Accuracy: 98.87% **
** End of Epoch 9/10 | Test Accuracy: 98.87% **
** End of Epoch 10/10 | Test Accuracy: 98.87% **
--------------------------------------------


In [None]:
# SAE-only run (SADA: ViT features -> SparseAutoencoder -> linear head), source D -> target A.
combination = ["D", "A"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # forwarded, but train_main_sae does not use it
lambda_sparse = [0.1, 0.1, 0.1]     # only the last entry weights the sae4 L1 term

trained_model_sae_office = train_model_sae(
    loader=combination,
    device='cuda',
    verbose=True,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=10,
    lr=1e-5,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sparse=lambda_sparse,
)

print("--------------------------------------------")


------------Combination: ['D', 'A']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 19.31% **
** End of Epoch 2/10 | Test Accuracy: 57.86% **
** End of Epoch 3/10 | Test Accuracy: 69.44% **
** End of Epoch 4/10 | Test Accuracy: 72.84% **
** End of Epoch 5/10 | Test Accuracy: 74.09% **
** End of Epoch 6/10 | Test Accuracy: 74.05% **
** End of Epoch 7/10 | Test Accuracy: 74.51% **
** End of Epoch 8/10 | Test Accuracy: 74.62% **
** End of Epoch 9/10 | Test Accuracy: 74.76% **
** End of Epoch 10/10 | Test Accuracy: 74.94% **
--------------------------------------------


In [None]:
# SAE-only run (SADA: ViT features -> SparseAutoencoder -> linear head), source D -> target W.
combination = ["D", "W"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]  # forwarded, but train_main_sae does not use it
lambda_sparse = [0.1, 0.1, 0.1]     # only the last entry weights the sae4 L1 term

trained_model_sae_office = train_model_sae(
    loader=combination,
    device='cuda',
    verbose=True,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=10,
    lr=1e-5,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sparse=lambda_sparse,
)

print("--------------------------------------------")


------------Combination: ['D', 'W']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 17.47% **
** End of Epoch 2/10 | Test Accuracy: 93.57% **
** End of Epoch 3/10 | Test Accuracy: 98.19% **
** End of Epoch 4/10 | Test Accuracy: 99.80% **
** End of Epoch 5/10 | Test Accuracy: 99.80% **
** End of Epoch 6/10 | Test Accuracy: 100.00% **
** End of Epoch 7/10 | Test Accuracy: 100.00% **
** End of Epoch 8/10 | Test Accuracy: 100.00% **
** End of Epoch 9/10 | Test Accuracy: 100.00% **
** End of Epoch 10/10 | Test Accuracy: 100.00% **
--------------------------------------------
