In [None]:
!pip installed install datasets



In [None]:
# --------------------- Imports ------------------------- #

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms, models
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torchvision.models import vit_b_16

# Datasets
# from datasets import load_dataset
import kagglehub

# Extra
import os
from copy import deepcopy
from PIL import Image
import matplotlib.pyplot as plt
import timm

## Office-Home (loaded through the `Office31` helper)

In [None]:
# Mount Google Drive at /content/drive (Colab-only API); the dataset archive
# extracted in the next cell is read from this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!unzip /content/drive/MyDrive/ATML/Project/OfficeHomeDataset_10072016.zip -d /content/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00050.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00051.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00052.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00053.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00054.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00055.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00056.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00057.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00058.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00059.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00060.jpg  
  inflating: /content/OfficeHomeDataset_10072016/Product/Speaker/00061.jpg  
  inflating

In [None]:
def Office31(domain, transform):
    """Build an ``ImageFolder`` dataset for one domain sub-folder.

    Despite the function name, the root directory is the extracted
    ``/content/OfficeHomeDataset_10072016`` archive; nothing is downloaded here.

    Args:
        domain: Name of the sub-directory under the dataset root (e.g. ``"Art"``).
        transform: Transform object forwarded to ``ImageFolder``.

    Returns:
        A ``datasets.ImageFolder`` rooted at ``<dataset root>/<domain>``.
    """
    dataset_root = '/content/OfficeHomeDataset_10072016'
    return datasets.ImageFolder(
        root=os.path.join(dataset_root, domain),
        transform=transform,
    )

def get_office_data_loaders(batch_size):
    """Create one shuffled ``DataLoader`` per dataset domain folder.

    Every image is resized to 224x224, converted to a tensor and normalized
    with mean 0.5 / std 0.5 on each of the three channels.

    Args:
        batch_size: Batch size used for all four loaders.

    Returns:
        Tuple ``(loader_art, loader_clp, loader_prod, loader_irl)`` built from
        the ``"Art"``, ``"Clipart"``, ``"Product"`` and ``"Real World"`` folders,
        in that order.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Build all datasets first, then wrap each one in a loader.
    domain_folders = ("Art", "Clipart", "Product", "Real World")
    domain_datasets = [Office31(folder, preprocess) for folder in domain_folders]

    loader_art, loader_clp, loader_prod, loader_irl = [
        DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
        for dataset in domain_datasets
    ]
    return loader_art, loader_clp, loader_prod, loader_irl


## Pacs

## Model Cards

In [None]:
class ResNetFeatureExtractor(nn.Module):
    """ResNet-50 trunk that exposes the outputs of its last three stages.

    The stem (``conv1``/``bn1``/``relu``/``maxpool``) and ``layer1``..``layer4``
    are taken from ``torchvision.models.resnet50``; no other part of that model
    is kept on this module.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = torchvision.models.resnet50(pretrained=pretrained)
        # Re-register the trunk sub-modules on this module, in forward order.
        trunk_parts = ("conv1", "bn1", "relu", "maxpool",
                       "layer1", "layer2", "layer3", "layer4")
        for part in trunk_parts:
            setattr(self, part, getattr(backbone, part))

    def forward(self, x):
        """Return ``(layer2_out, layer3_out, layer4_out)`` for input ``x``."""
        stem_out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        stage2 = self.layer2(self.layer1(stem_out))
        stage3 = self.layer3(stage2)
        stage4 = self.layer4(stage3)
        return stage2, stage3, stage4




class VitBasicFeatureExtractor(nn.Module):
    """timm ``vit_base_patch16_224`` wrapper returning pooled feature vectors.

    Args:
        pretrained: Forwarded to ``timm.create_model``.
        layers: Accepted by the constructor but not used anywhere in this class.
    """

    def __init__(self, pretrained=True, layers=[4, 8, 12]):
        super().__init__()
        self.model = timm.create_model(
            'vit_base_patch16_224',
            pretrained=pretrained,
            features_only=True,
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Average-pool each backbone output to 1x1 and flatten it to 2-D.

        Returns:
            Tuple with one ``(first_dim, -1)`` tensor per backbone output.
        """
        def pool_and_flatten(feature_map):
            pooled = self.avgpool(feature_map)
            return pooled.view(pooled.size(0), -1)

        return tuple(pool_and_flatten(feature_map) for feature_map in self.model(x))


class SwinBasicFeatureExtractor(nn.Module):
    """timm ``swin_base_patch4_window7_224`` wrapper returning pooled features.

    Args:
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'swin_base_patch4_window7_224',
            pretrained=pretrained,
            features_only=True,
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Pool every backbone output and return all but the first one.

        Each output is permuted with ``(0, 3, 1, 2)`` before pooling
        (presumably channels-last maps from timm — TODO confirm), average-pooled
        to 1x1 and flattened to ``(first_dim, -1)``.
        """
        def pool_and_flatten(feature_map):
            channels_first = feature_map.permute(0, 3, 1, 2)
            pooled = self.avgpool(channels_first)
            return pooled.view(pooled.size(0), -1)

        pooled_outputs = [pool_and_flatten(feature_map) for feature_map in self.model(x)]
        # The first backbone output is dropped; only the later ones are returned.
        return tuple(pooled_outputs[1:])

In [None]:
# Smoke test: push one random 224x224 RGB image through the Swin extractor and
# print the shape of every returned feature tensor.
model = SwinBasicFeatureExtractor(pretrained=True)
# Named `dummy_input` rather than `input` so the builtin `input()` is not
# shadowed for the rest of the kernel session.
dummy_input = torch.randn(1, 3, 224, 224)
output = model(dummy_input)

for feature in output:
    print(feature.shape)

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

torch.Size([1, 256])
torch.Size([1, 512])
torch.Size([1, 1024])


In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init

class SparseAutoencoder(nn.Module):
    """MLP autoencoder with a ReLU latent code and a sigmoid reconstruction.

    Encoder: three ``Linear -> LayerNorm -> ReLU -> Dropout`` stages followed by
    ``Linear -> LayerNorm -> ReLU`` into the latent space.
    Decoder: three ``Linear -> LayerNorm -> ReLU -> Dropout`` stages followed by
    ``Linear -> Sigmoid`` back to the input dimension.
    No sparsity term is computed inside this class.

    Args:
        input_dim: Size of the input (and reconstruction) vectors.
        hidden_dim: Width of every hidden stage.
        latent_dim: Size of the latent code.
        dropout_rate: Dropout probability used in every hidden stage.
    """

    def __init__(self, input_dim=2048, hidden_dim=1024, latent_dim=768, dropout_rate=0.3):
        super().__init__()

        def hidden_stage(in_features, out_features):
            # One Linear -> LayerNorm -> ReLU -> Dropout stage.
            return [
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_rate),
            ]

        encoder_layers = (
            hidden_stage(input_dim, hidden_dim)
            + hidden_stage(hidden_dim, hidden_dim)
            + hidden_stage(hidden_dim, hidden_dim)
            + [
                nn.Linear(hidden_dim, latent_dim),
                nn.LayerNorm(latent_dim),
                nn.ReLU(inplace=True),
            ]
        )
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = (
            hidden_stage(latent_dim, hidden_dim)
            + hidden_stage(hidden_dim, hidden_dim)
            + hidden_stage(hidden_dim, hidden_dim)
            + [
                nn.Linear(hidden_dim, input_dim),
                nn.Sigmoid(),
            ]
        )
        self.decoder = nn.Sequential(*decoder_layers)

        # Overwrite the default parameter initialization.
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform Linear weights, zero biases, identity LayerNorm."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.kaiming_uniform_(module.weight, a=0, nonlinearity='relu')
                if module.bias is not None:
                    init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                init.ones_(module.weight)
                init.zeros_(module.bias)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for input ``x``."""
        latent = self.encoder(x)
        return self.decoder(latent), latent


In [None]:
class BaselineModel(nn.Module):
    """Linear classifier on top of the last output of a feature extractor.

    Args:
        feature_extractor: Callable returning exactly three feature tensors;
            only the third one is classified. The classifier expects it to
            have 768 features in its last dimension.
        num_classes: Number of output logits.
    """

    def __init__(self, feature_extractor, num_classes=31):
        super().__init__()
        self.feature_extractor = feature_extractor
        # Registered but not used by forward().
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return ``(logits, features)`` where ``features`` is the third extractor output."""
        _first, _second, deepest = self.feature_extractor(x)
        return self.classifier(deepest), deepest


In [None]:
class UnifiedModelMultiBlockSAE(nn.Module):
    """Swin feature extractor -> sparse autoencoder -> linear classifier.

    Only the last feature block goes through an autoencoder (``sae4``); the
    classifier reads the autoencoder's latent code.

    Args:
        sae_dim_block2: Accepted but not used by this class.
        sae_dim_block3: Accepted but not used by this class.
        sae_dim_block4: Accepted but not used by this class.
        num_classes: Number of output logits.
    """

    def __init__(self, sae_dim_block2=128, sae_dim_block3=256, sae_dim_block4=256, num_classes=31):
        super().__init__()
        self.feature_extractor = SwinBasicFeatureExtractor(pretrained=True)
        self.sae4 = SparseAutoencoder(input_dim=1024, latent_dim=1024)
        self.classifier = nn.Sequential(nn.Linear(1024, num_classes))

    def forward(self, x):
        """Return ``(class_logits, (reconstruction, latent), features)``.

        ``features`` is the third output of the feature extractor, and
        ``reconstruction``/``latent`` are what ``sae4`` produces from it.
        """
        _block2, _block3, block4_features = self.feature_extractor(x)
        reconstruction, latent = self.sae4(block4_features)
        class_logits = self.classifier(latent)
        return class_logits, (reconstruction, latent), block4_features

## Losses

In [None]:
def irm_penalty(logits, labels):
    """Cross-entropy loss plus the squared gradient w.r.t. a dummy logit scale.

    The logits are multiplied by a scalar fixed at 1.0; the penalty is the
    squared gradient of the cross-entropy with respect to that scalar, built
    with ``create_graph=True`` so it can itself be back-propagated.

    Args:
        logits: Classifier outputs passed to ``F.cross_entropy``.
        labels: Target class indices.

    Returns:
        Tuple ``(loss_erm, penalty, var)`` where ``var`` is always ``0.0``.
    """
    dummy_scale = torch.tensor(1.0, device=logits.device, requires_grad=True)
    erm_loss = F.cross_entropy(logits * dummy_scale, labels)

    (scale_grad,) = torch.autograd.grad(erm_loss, [dummy_scale], create_graph=True)
    squared_grad = scale_grad.pow(2).sum()

    return erm_loss, squared_grad, 0.0

In [None]:
def compute_sae_loss(recons, pools, zs, lambda_sparse=[1.0, 1.0, 1.0], lambda_reconstruction=[1.0, 1.0, 1.0]):
    """Weighted reconstruction and L1 sparsity losses over three autoencoders.

    Args:
        recons: Three reconstructions.
        pools: Three reconstruction targets, matched to ``recons`` by position.
        zs: Three latent codes.
        lambda_sparse: Three weights for the mean-absolute-value of each latent.
        lambda_reconstruction: Three weights for each MSE reconstruction loss.

    Returns:
        Tuple ``(recon_loss_total, l1_sparsity_total)``.
    """
    recon_a, recon_b, recon_c = recons
    target_a, target_b, target_c = pools
    latent_a, latent_b, latent_c = zs

    mse_a = F.mse_loss(recon_a, target_a)
    mse_b = F.mse_loss(recon_b, target_b)
    mse_c = F.mse_loss(recon_c, target_c)
    w_rec_a, w_rec_b, w_rec_c = lambda_reconstruction
    recon_loss_total = w_rec_a * mse_a + w_rec_b * mse_b + w_rec_c * mse_c

    l1_a = latent_a.abs().mean()
    l1_b = latent_b.abs().mean()
    l1_c = latent_c.abs().mean()
    w_sp_a, w_sp_b, w_sp_c = lambda_sparse
    l1_sparsity_total = w_sp_a * l1_a + w_sp_b * l1_b + w_sp_c * l1_c

    return recon_loss_total, l1_sparsity_total


## Train and Evaluate Functions

### Train

In [None]:
def train_main_irm_multi_sae_office(model,
                              loader_source,
                              test_loader,
                              num_epochs=20,
                              lr=1e-4,
                              lr_sae=1e-4,
                              lambda_irm=1.0,
                              lambda_sae_rec=1.0,
                              lambda_sae_sparse=1e-4,
                              lambda_sparse=[1.0, 1.0, 1.0],
                              lambda_reconstruction=[1.0, 1.0, 1.0],
                              lambda_irm_pair=[1.0, 1.0, 1.0],
                              device='cuda',
                              verbose=False):
    """Train with an IRM classification loss plus three SAE losses, one optimizer.

    NOTE(review): a later cell in this notebook defines a function with the
    same name, so on a top-to-bottom run this definition is replaced before it
    is ever called. It also requires ``model`` to expose ``sae2``, ``sae3`` and
    ``sae4`` and to return a 5-tuple from ``forward``, whereas
    ``UnifiedModelMultiBlockSAE`` in this notebook only defines ``sae4`` and
    returns three values.

    Args:
        model: Module returning ``(logits, (recon2, z2), (recon3, z3),
            (recon4, z4), (f2p, f3p, f4p))`` and exposing ``sae2``/``sae3``/``sae4``.
        loader_source: Iterable of ``(x, y)`` training batches that supports ``len``
            on its iterator.
        test_loader: Loader passed to ``evaluate`` after every epoch.
        num_epochs: Number of passes over ``loader_source``.
        lr: Learning rate of the single Adam optimizer over all parameters.
        lr_sae: Not used by this definition.
        lambda_irm: Multiplier on the IRM penalty term.
        lambda_sae_rec: Multiplier on the summed reconstruction loss.
        lambda_sae_sparse: Multiplier on the summed L1 sparsity loss.
        lambda_sparse: Three per-block weights for the L1 terms.
        lambda_reconstruction: Three per-block weights for the MSE terms.
        lambda_irm_pair: Weights ``(w1, w2, w3)`` on the ERM loss, the IRM
            penalty and the variance term returned by ``irm_penalty``.
        device: Device the model and batches are moved to.
        verbose: When true, print losses every 40 steps.

    Returns:
        The trained ``model``.
    """

    model.to(device)
    model.train()

    def sae_forward_splits(f2p, f3p, f4p):
        # Run each block's pooled features through its own autoencoder.
        x_recon2, z2 = model.sae2(f2p)
        x_recon3, z3 = model.sae3(f3p)
        x_recon4, z4 = model.sae4(f4p)
        return x_recon2, z2, x_recon3, z3, x_recon4, z4

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        source_iter = iter(loader_source)
        steps_per_epoch = len(source_iter)

        for step in range(steps_per_epoch):

            try:
                x_s, y_s = next(source_iter)
            except StopIteration:
                # Restart the iterator if it runs out before steps_per_epoch steps.
                source_iter = iter(loader_source)
                x_s, y_s = next(source_iter)

            x_s, y_s = x_s.to(device), y_s.to(device)

            # First forward pass (with gradients): only the logits are used.
            class_logits_s, _, _, _, _ = model(x_s)

            loss_erm_s, penalty_s, var_s = irm_penalty(class_logits_s, y_s)

            irm_loss = 0.5 * (loss_erm_s)
            irm_pen  = 0.5 * (penalty_s)


            w1, w2, w3 = lambda_irm_pair
            loss_irm = w1 * (irm_loss) + w2 * (lambda_irm * irm_pen) + w3 * var_s


            # Second forward pass without gradient tracking: only the pooled
            # features f2p/f3p/f4p are kept; the reconstructions and latents
            # unpacked here are overwritten by the next statement.
            with torch.no_grad():
                _, (x_recon2_s, z2_s), (x_recon3_s, z3_s), (x_recon4_s, z4_s), (f2p_s, f3p_s, f4p_s) = model(x_s)

            # Re-run the autoencoders with gradients on the gradient-free features.
            x_recon2_s, z2_s, x_recon3_s, z3_s, x_recon4_s, z4_s = sae_forward_splits(f2p_s, f3p_s, f4p_s)

            lambda_s1, lambda_s2, lambda_s3 = lambda_sparse
            lambda_r1, lambda_r2, lambda_r3 = lambda_reconstruction

            rec_loss2_s = F.mse_loss(x_recon2_s, f2p_s)
            rec_loss2   = lambda_r1 * (rec_loss2_s)

            rec_loss3_s = F.mse_loss(x_recon3_s, f3p_s)
            rec_loss3   = lambda_r2 * (rec_loss3_s)

            rec_loss4_s = F.mse_loss(x_recon4_s, f4p_s)
            rec_loss4   = lambda_r3 * (rec_loss4_s)

            rec_loss_total = (rec_loss2 + rec_loss3 + rec_loss4)

            l1_2_s = torch.mean(torch.abs(z2_s))
            l1_2   = lambda_s1 * (l1_2_s)

            l1_3_s = torch.mean(torch.abs(z3_s))
            l1_3   = lambda_s2 * (l1_3_s)

            l1_4_s = torch.mean(torch.abs(z4_s))
            l1_4   = lambda_s3 * (l1_4_s)

            l1_sparsity = l1_2 + l1_3 + l1_4

            sae_loss = lambda_sae_rec * rec_loss_total + lambda_sae_sparse * l1_sparsity

            loss = loss_irm + sae_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step+1) % 40 == 0 and verbose:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{steps_per_epoch}], "
                      f"IRM Loss: {loss_irm.item():.4f}, SAE Loss: {sae_loss.item():.4f}")

        test_acc = evaluate(model, test_loader, device=device)
        print(f"** End of Epoch {epoch+1}/{num_epochs} | Test Accuracy: {test_acc:.2f}% **")

    return model

In [None]:
def train_model_irm_sae_with_warmup_office(
    batch_size=16,
    num_warmup_epochs=5,
    num_main_epochs=20,
    lr=1e-4,
    lr_sae=1e-4,
    lambda_irm=1.0,
    lambda_sae_rec=1.0,
    lambda_sae_sparse=1e-4,
    device='cuda',
    loader=['A', 'W'],
    model=None,
    verbose=False,
    lambda_sparse=[1.0, 1.0, 1.0],
    lambda_reconstruction=[1.0, 1.0, 1.0],
    lambda_irm_pair=[1.0, 1.0, 0.0],
):
    """Pick a source/target domain pair and run the IRM + SAE main training phase.

    Args:
        batch_size: Batch size for every domain loader.
        num_warmup_epochs: Accepted but not used; no warm-up phase is run here.
        num_main_epochs: Number of epochs of the main phase.
        lr: Learning rate forwarded to the training function.
        lr_sae: SAE learning rate forwarded to the training function.
        lambda_irm: IRM penalty multiplier.
        lambda_sae_rec: Reconstruction-loss multiplier.
        lambda_sae_sparse: Sparsity-loss multiplier.
        device: Device used for training and evaluation.
        loader: ``[source, target]`` domain codes, each one of ``'A'`` (Art),
            ``'C'`` (Clipart), ``'P'`` (Product), ``'R'`` (Real World).
            ``None`` means ``['A', 'C']``. NOTE: the signature default
            ``['A', 'W']`` contains ``'W'``, which is not a valid code here,
            so callers must pass ``loader`` explicitly.
        model: Model to train; a fresh 65-class ``UnifiedModelMultiBlockSAE``
            is created when ``None``.
        verbose: Forwarded to the training function.
        lambda_sparse: Per-block sparsity weights.
        lambda_reconstruction: Per-block reconstruction weights.
        lambda_irm_pair: Weights on ERM loss, IRM penalty and variance term.

    Returns:
        The trained model.

    Raises:
        ValueError: If the source or target code is not one of A/C/P/R.
            (Previously this surfaced as an ``UnboundLocalError``.)
    """
    source, target = loader if loader is not None else ['A', 'C']

    # Validate before touching the disk so a bad code fails fast and clearly.
    valid_codes = ('A', 'C', 'P', 'R')
    for role, code in (('source', source), ('target', target)):
        if code not in valid_codes:
            raise ValueError(
                f"Unknown {role} domain code {code!r}; expected one of {list(valid_codes)}"
            )

    loader_art, loader_clp, loader_prod, loader_irl = get_office_data_loaders(batch_size)
    loaders_by_code = {
        'A': loader_art,
        'C': loader_clp,
        'P': loader_prod,
        'R': loader_irl,
    }
    loader_source = loaders_by_code[source]
    loader_target = loaders_by_code[target]

    if model is None:
        model = UnifiedModelMultiBlockSAE(512, 1024, 2048, 65)

    print("===== Main Phase (IRM + SAE) =====")

    model = train_main_irm_multi_sae_office(model,
                                            loader_source=loader_source,
                                            test_loader=loader_target,
                                            num_epochs=num_main_epochs,
                                            lr=lr,
                                            lr_sae=lr_sae,
                                            lambda_irm=lambda_irm,
                                            lambda_sae_rec=lambda_sae_rec,
                                            lambda_sae_sparse=lambda_sae_sparse,
                                            device=device,
                                            verbose=verbose,
                                            lambda_sparse=lambda_sparse,
                                            lambda_reconstruction=lambda_reconstruction,
                                            lambda_irm_pair=lambda_irm_pair
                                            )

    return model

### Separate optimizers

In [None]:
def train_main_irm_multi_sae_office(
    model,
    loader_source,
    test_loader,
    num_epochs=20,
    lr=1e-4,
    lr_sae=1e-4,
    lambda_irm=1.0,
    lambda_sae_rec=1.0,
    lambda_sae_sparse=1e-4,
    lambda_sparse=[1.0, 1.0, 1.0],
    lambda_reconstruction=[1.0, 1.0, 1.0],
    lambda_irm_pair=[1.0, 1.0, 1.0],
    device='cuda',
    verbose=False
):
    """Train with an IRM classification loss and a block-4 SAE loss.

    Two Adam optimizers are used: one (``lr_sae``) for parameters whose name
    starts with ``sae2``/``sae3``/``sae4`` and one (``lr``) for everything else.
    The SAE reconstruction/sparsity loss is computed on detached features, so
    it does not back-propagate into the feature extractor.

    Args:
        model: Module whose ``forward`` returns ``(logits, (recon4, z4), f4)``
            and which exposes a ``sae4`` sub-module returning ``(recon, z)``.
        loader_source: Sized iterable of ``(x, y)`` training batches.
        test_loader: Loader passed to ``evaluate`` after every epoch.
        num_epochs: Number of passes over ``loader_source``.
        lr: Learning rate for the non-SAE parameters.
        lr_sae: Learning rate for the SAE parameters.
        lambda_irm: Multiplier on the IRM penalty term.
        lambda_sae_rec: Multiplier on the reconstruction loss.
        lambda_sae_sparse: Multiplier on the L1 sparsity loss.
        lambda_sparse: Three weights; only the third (block 4) is used.
        lambda_reconstruction: Three weights; only the third (block 4) is used.
        lambda_irm_pair: Weights ``(w1, w2, w3)`` on the ERM loss, the IRM
            penalty and the variance term returned by ``irm_penalty``.
        device: Device the model and batches are moved to.
        verbose: When true, print losses every 40 steps.

    Returns:
        The trained ``model``.
    """
    import torch.optim as optim

    model.to(device)

    # Split the parameters once, with a single rule, so that every parameter
    # belongs to exactly one optimizer.
    sae_prefixes = ('sae2', 'sae3', 'sae4')
    params_sae, params_rest = [], []
    for name, param in model.named_parameters():
        if name.startswith(sae_prefixes):
            params_sae.append(param)
        else:
            params_rest.append(param)

    optimizer_rest = optim.Adam(params_rest, lr=lr)
    optimizer_sae = optim.Adam(params_sae, lr=lr_sae)

    # Loop-invariant weights (each argument must still hold exactly three values).
    w1, w2, w3 = lambda_irm_pair
    _, _, lambda_r3 = lambda_reconstruction
    _, _, lambda_s3 = lambda_sparse
    steps_per_epoch = len(loader_source)

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch: the evaluation call at the end of
        # the previous epoch puts the model in eval mode and is not guaranteed
        # to restore it, which would silently disable dropout from epoch 2 on.
        model.train()

        for step, (x_s, y_s) in enumerate(loader_source):
            x_s, y_s = x_s.to(device), y_s.to(device)

            # Single forward pass: logits for the IRM loss, features for the SAE loss.
            class_logits_s, _, f4_s = model(x_s)

            loss_erm_s, penalty_s, var_s = irm_penalty(class_logits_s, y_s)
            irm_loss = 0.5 * loss_erm_s
            irm_pen = 0.5 * penalty_s
            loss_irm = w1 * irm_loss + w2 * (lambda_irm * irm_pen) + w3 * var_s

            # Detach instead of running a second no-grad forward pass: the SAE
            # target stays gradient-free without recomputing the backbone.
            f4p_s = f4_s.detach()
            x_recon4_s, z4_s = model.sae4(f4p_s)

            rec_loss_total = lambda_r3 * F.mse_loss(x_recon4_s, f4p_s)
            l1_sparsity = lambda_s3 * torch.mean(torch.abs(z4_s))
            sae_loss = lambda_sae_rec * rec_loss_total + lambda_sae_sparse * l1_sparsity

            loss = loss_irm + sae_loss

            optimizer_rest.zero_grad()
            optimizer_sae.zero_grad()
            loss.backward()
            optimizer_rest.step()
            optimizer_sae.step()

            if (step+1) % 40 == 0 and verbose:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{steps_per_epoch}], "
                      f"IRM Loss: {loss_irm.item():.4f}, SAE Loss: {sae_loss.item():.4f}")

        # Evaluation after each epoch
        test_acc = evaluate(model, test_loader, device=device)
        print(f"** End of Epoch {epoch+1}/{num_epochs} | Test Accuracy: {test_acc:.2f}% **")

    return model


### Evaluate

In [None]:
def evaluate_baseline(model, loader, device='cuda'):
    """Top-1 accuracy (in percent) of a model that returns ``(logits, features)``.

    The model is put in eval mode for the pass and switched to training mode
    before returning.

    Args:
        model: Module whose call returns exactly two values, logits first.
        loader: Iterable of ``(x, y)`` batches.
        device: Device each batch is moved to.

    Returns:
        ``100 * correct / total`` as a float.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits, _ = model(inputs)
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += targets.size(0)
    accuracy = 100.0 * n_correct / n_seen
    model.train()
    return accuracy


In [None]:
def evaluate(model, loader, device='cuda'):
    """Top-1 accuracy (in percent) of a model that returns three values, logits first.

    The model is evaluated in eval mode under ``torch.no_grad()`` and then put
    back into whatever mode it was in on entry, so calling this between
    training epochs no longer leaves the model stuck in eval mode.

    Args:
        model: Module whose call returns ``(logits, _, _)``.
        loader: Iterable of ``(x, y)`` batches.
        device: Device each batch is moved to.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                logits, _, _ = model(x)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
    finally:
        # Restore the caller's mode even if a batch raises.
        model.train(was_training)

    if total == 0:
        return 0.0
    return 100.0 * correct / total

# MAIN

### Ours

## A -> D **DO NOT RUN**

In [None]:
# NOTE(review): 'D' is not one of the domain codes ('A', 'C', 'P', 'R') that
# train_model_irm_sae_with_warmup_office maps to a loader, so this call cannot
# select a target loader as written. The saved output below appears to come
# from an earlier Office-31 run — confirm before re-running.
combination = ["A", "D"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            # model=trained_model_sae_office,
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['A', 'D']--------------
Resuming download from 0 bytes (79538472 bytes left)...
Resuming download from https://www.kaggle.com/api/v1/datasets/download/xixuhu/office31?dataset_version_number=1 (0/79538472) bytes left.


100%|██████████| 75.9M/75.9M [00:04<00:00, 17.8MB/s]

Extracting files...





===== Main Phase (IRM + SAE) =====
Epoch [1/10], Step [40/89], IRM Loss: 16.9484, SAE Loss: 24.5316
Epoch [1/10], Step [80/89], IRM Loss: 15.9215, SAE Loss: 7.4968
** End of Epoch 1/10 | Test Accuracy: 32.33% **
Epoch [2/10], Step [40/89], IRM Loss: 10.5335, SAE Loss: 2.8149
Epoch [2/10], Step [80/89], IRM Loss: 6.2005, SAE Loss: 3.8766
** End of Epoch 2/10 | Test Accuracy: 76.10% **
Epoch [3/10], Step [40/89], IRM Loss: 2.9101, SAE Loss: 3.6661
Epoch [3/10], Step [80/89], IRM Loss: 2.2937, SAE Loss: 3.8711
** End of Epoch 3/10 | Test Accuracy: 88.68% **
Epoch [4/10], Step [40/89], IRM Loss: 1.3317, SAE Loss: 4.4993
Epoch [4/10], Step [80/89], IRM Loss: 0.5690, SAE Loss: 3.6851
** End of Epoch 4/10 | Test Accuracy: 91.57% **
Epoch [5/10], Step [40/89], IRM Loss: 0.6646, SAE Loss: 3.2237
Epoch [5/10], Step [80/89], IRM Loss: 0.2320, SAE Loss: 3.8896
** End of Epoch 5/10 | Test Accuracy: 89.69% **
Epoch [6/10], Step [40/89], IRM Loss: 0.4437, SAE Loss: 3.5886
Epoch [6/10], Step [80/89], 

## A -> W **DO NOT RUN**

In [None]:
# NOTE(review): no `loader` argument is passed, so the function's signature
# default ['A', 'W'] is used; 'W' is not one of the domain codes
# ('A', 'C', 'P', 'R') the function maps to a loader. The saved output below
# appears to come from an earlier Office-31 run — confirm before re-running.
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            # model=trained_model_sae_office,
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse
        )

===== Main Phase (IRM + SAE) =====
Epoch [1/10], Step [40/89], IRM Loss: 17.4886, SAE Loss: 36.3434
Epoch [1/10], Step [80/89], IRM Loss: 16.9284, SAE Loss: 16.5139
** End of Epoch 1/10 | Test Accuracy: 16.27% **
Epoch [2/10], Step [40/89], IRM Loss: 11.7183, SAE Loss: 3.8549
Epoch [2/10], Step [80/89], IRM Loss: 6.7545, SAE Loss: 4.2968
** End of Epoch 2/10 | Test Accuracy: 78.31% **
Epoch [3/10], Step [40/89], IRM Loss: 2.1402, SAE Loss: 4.0041
Epoch [3/10], Step [80/89], IRM Loss: 2.5182, SAE Loss: 4.7085
** End of Epoch 3/10 | Test Accuracy: 87.95% **
Epoch [4/10], Step [40/89], IRM Loss: 1.2146, SAE Loss: 4.8893
Epoch [4/10], Step [80/89], IRM Loss: 2.4019, SAE Loss: 4.2020
** End of Epoch 4/10 | Test Accuracy: 89.36% **
Epoch [5/10], Step [40/89], IRM Loss: 0.5383, SAE Loss: 4.2130
Epoch [5/10], Step [80/89], IRM Loss: 0.8254, SAE Loss: 4.2853
** End of Epoch 5/10 | Test Accuracy: 91.97% **
Epoch [6/10], Step [40/89], IRM Loss: 0.2088, SAE Loss: 4.9362
Epoch [6/10], Step [80/89],

## W -> A **DO NOT RUN**

In [None]:
# NOTE(review): the source code 'W' is not one of the domain codes
# ('A', 'C', 'P', 'R') that train_model_irm_sae_with_warmup_office maps to a
# loader. The saved output below appears to come from an earlier Office-31
# run — confirm before re-running.
combination = ["W", "A"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            # model=trained_model_sae_office,
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['W', 'A']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 3.55% **
** End of Epoch 2/10 | Test Accuracy: 3.80% **
** End of Epoch 3/10 | Test Accuracy: 17.78% **
** End of Epoch 4/10 | Test Accuracy: 12.64% **
** End of Epoch 5/10 | Test Accuracy: 11.61% **
** End of Epoch 6/10 | Test Accuracy: 18.14% **
** End of Epoch 7/10 | Test Accuracy: 39.58% **
** End of Epoch 8/10 | Test Accuracy: 62.62% **
** End of Epoch 9/10 | Test Accuracy: 74.19% **
** End of Epoch 10/10 | Test Accuracy: 76.93% **
--------------------------------------------


## W -> D **DO** **NOT** **RUN**

In [None]:
# NOTE(review): neither 'W' nor 'D' is one of the domain codes
# ('A', 'C', 'P', 'R') that train_model_irm_sae_with_warmup_office maps to a
# loader. The saved output below appears to come from an earlier Office-31
# run — confirm before re-running.
combination = ["W", "D"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            # model=trained_model_sae_office,
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['W', 'D']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 3.77% **
** End of Epoch 2/10 | Test Accuracy: 5.79% **
** End of Epoch 3/10 | Test Accuracy: 19.25% **
** End of Epoch 4/10 | Test Accuracy: 21.13% **
** End of Epoch 5/10 | Test Accuracy: 30.69% **
** End of Epoch 6/10 | Test Accuracy: 47.30% **
** End of Epoch 7/10 | Test Accuracy: 68.81% **
** End of Epoch 8/10 | Test Accuracy: 89.81% **
** End of Epoch 9/10 | Test Accuracy: 98.36% **
** End of Epoch 10/10 | Test Accuracy: 98.99% **
--------------------------------------------


## D -> W **DO** **NOT** **RUN**

In [None]:
# NOTE(review): neither 'D' nor 'W' is one of the domain codes
# ('A', 'C', 'P', 'R') that train_model_irm_sae_with_warmup_office maps to a
# loader. The saved output below appears to come from an earlier Office-31
# run — confirm before re-running.
combination = ["D", "W"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            # model=trained_model_sae_office,
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['D', 'W']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/10 | Test Accuracy: 4.62% **
** End of Epoch 2/10 | Test Accuracy: 22.49% **
** End of Epoch 3/10 | Test Accuracy: 32.13% **
** End of Epoch 4/10 | Test Accuracy: 42.77% **
** End of Epoch 5/10 | Test Accuracy: 81.93% **
** End of Epoch 6/10 | Test Accuracy: 98.19% **
** End of Epoch 7/10 | Test Accuracy: 99.00% **
** End of Epoch 8/10 | Test Accuracy: 99.40% **
** End of Epoch 9/10 | Test Accuracy: 99.60% **
** End of Epoch 10/10 | Test Accuracy: 99.60% **
--------------------------------------------


## D -> A **DO** **NOT** **RUN**

In [None]:
# NOTE(review): the source code 'D' is not one of the domain codes
# ('A', 'C', 'P', 'R') that train_model_irm_sae_with_warmup_office maps to a
# loader. The saved output below appears to come from an earlier Office-31
# run — confirm before re-running.
combination = ["D", "A"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            # model=trained_model_sae_office,
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=10,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['D', 'A']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 2/10 | Test Accuracy: 9.16% **
** End of Epoch 3/10 | Test Accuracy: 7.24% **
** End of Epoch 4/10 | Test Accuracy: 19.38% **
** End of Epoch 5/10 | Test Accuracy: 42.28% **
** End of Epoch 6/10 | Test Accuracy: 69.76% **
** End of Epoch 7/10 | Test Accuracy: 76.11% **
** End of Epoch 8/10 | Test Accuracy: 77.17% **
** End of Epoch 9/10 | Test Accuracy: 77.64% **
** End of Epoch 10/10 | Test Accuracy: 77.64% **
--------------------------------------------


## Baseline Swin-B accuracies

In [None]:
from torchvision.models import swin_b
from tqdm.notebook import tqdm

def get_base_accuracy(loader=None, device="cuda", epochs=10, lr=5e-5, batch_size=32, num_classes=65):
    """Fine-tune an ImageNet Swin-B on a source domain and report target accuracy.

    After every epoch the loss of the last training batch and the top-1
    accuracy on the target domain are printed.

    Args:
        loader: ``[source, target]`` domain codes, each one of ``'A'`` (Art),
            ``'C'`` (Clipart), ``'P'`` (Product), ``'R'`` (Real World).
            ``None`` means ``['A', 'C']``.
        device: Device used for the model and the batches.
        epochs: Number of fine-tuning epochs.
        lr: Adam learning rate.
        batch_size: Batch size of the domain loaders.
        num_classes: Output size of the replacement classification head.

    Returns:
        None. Results are printed.

    Raises:
        ValueError: If the source or target code is not one of A/C/P/R.
            (Previously this surfaced as an ``UnboundLocalError``.)
    """
    source, target = loader if loader is not None else ['A', 'C']

    # Validate before loading any data so a bad code fails fast and clearly.
    valid_codes = ('A', 'C', 'P', 'R')
    for role, code in (('source', source), ('target', target)):
        if code not in valid_codes:
            raise ValueError(
                f"Unknown {role} domain code {code!r}; expected one of {list(valid_codes)}"
            )

    loader_art, loader_clp, loader_prod, loader_irl = get_office_data_loaders(batch_size)
    loaders_by_code = {
        'A': loader_art,
        'C': loader_clp,
        'P': loader_prod,
        'R': loader_irl,
    }
    loader_source = loaders_by_code[source]
    loader_target = loaders_by_code[target]

    model = swin_b(weights='IMAGENET1K_V1')
    # Replace the ImageNet head; reuse its input width instead of hard-coding it.
    model.head = nn.Linear(model.head.in_features, num_classes)
    model = model.to(device)

    optm = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        last_loss = float('nan')  # stays NaN only if the source loader is empty
        for x, y in loader_source:
            x, y = x.to(device), y.to(device)

            optm.zero_grad()
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            optm.step()
            last_loss = loss.item()

        print(f"Epoch: {epoch+1}/{epochs} | Loss: {last_loss}")

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in loader_target:
                x, y = x.to(device), y.to(device)

                logits = model(x)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        # Guard against an empty target loader.
        acc = 100.0 * correct / total if total else 0.0
        print(f"Epoch: {epoch+1}/{epochs} | Test Accuracy: {acc}")




In [None]:
# All 12 ordered [source, target] pairs over the four domain codes A, C, P, R.
combinations = [["A", "C"], ["A", "P"], ["A", "R"], ["C", "P"], ["C", "R"], ["C", "A"], ["P", "R"], ["P", "A"], ["P", "C"], ["R", "A"], ["R", "C"], ["R", "P"]]

# Fine-tune and evaluate the baseline once per pair; results are printed by
# get_base_accuracy itself.
for combination in combinations:
    print(f"------------Combination: {combination}--------------")
    get_base_accuracy(loader=combination)
    print("--------------------------------------------")

------------Combination: ['A', 'C']--------------


Downloading: "https://download.pytorch.org/models/swin_b-68c6b09e.pth" to /root/.cache/torch/hub/checkpoints/swin_b-68c6b09e.pth
100%|██████████| 335M/335M [00:02<00:00, 161MB/s]


Epoch: 1/10 | Loss: 2.612029552459717
Epoch: 1/10 | Test Accuracy: 33.585337915234824
Epoch: 2/10 | Loss: 0.7936190962791443
Epoch: 2/10 | Test Accuracy: 48.24742268041237
Epoch: 3/10 | Loss: 0.5184156894683838
Epoch: 3/10 | Test Accuracy: 52.96678121420389
Epoch: 4/10 | Loss: 0.4112170934677124
Epoch: 4/10 | Test Accuracy: 52.485681557846505
Epoch: 5/10 | Loss: 0.20589110255241394
Epoch: 5/10 | Test Accuracy: 53.83734249713631
Epoch: 6/10 | Loss: 0.1458369940519333
Epoch: 6/10 | Test Accuracy: 54.47880870561283
Epoch: 7/10 | Loss: 0.11391185969114304
Epoch: 7/10 | Test Accuracy: 55.18900343642612
Epoch: 8/10 | Loss: 0.12304243445396423
Epoch: 8/10 | Test Accuracy: 54.38717067583047
Epoch: 9/10 | Loss: 0.015888936817646027
Epoch: 9/10 | Test Accuracy: 55.32646048109966
Epoch: 10/10 | Loss: 0.0335274264216423
Epoch: 10/10 | Test Accuracy: 55.37227949599084
--------------------------------------------
------------Combination: ['A', 'P']--------------
Epoch: 1/10 | Loss: 2.606938123703003

In [None]:
# Same source/target pairs as `combinations`, in the opposite order (new list).
combinations_reverse = list(reversed(combinations))

In [None]:
# Run IRM + SAE training once per pair in `combinations_reverse`. `model` is
# not passed, so every call builds a fresh model. `num_warmup_epochs` is
# accepted by the function but not used by it.
for combination in combinations_reverse:
    print(f"------------Combination: {combination}--------------")
    lambda_irm_pair = [10.0, 4.0, 0.0]
    lambda_sparse=[0.1, 0.1, 0.1]

    trained_model_sae_office = train_model_irm_sae_with_warmup_office(
                # model=trained_model_sae_office,
                batch_size=32,
                num_warmup_epochs=4,
                num_main_epochs=10,
                lr=1e-5,
                lr_sae=5e-5,
                lambda_irm=1.0,
                lambda_sae_rec=2.0,
                lambda_sae_sparse=2e-4,
                device='cuda',
                verbose=True,
                lambda_irm_pair=lambda_irm_pair,
                lambda_sparse=lambda_sparse,
                loader=combination
            )
    print("--------------------------------------------")

------------Combination: ['R', 'P']--------------
===== Main Phase (IRM + SAE) =====
Epoch [1/10], Step [40/137], IRM Loss: 20.8602, SAE Loss: 42.1482
Epoch [1/10], Step [80/137], IRM Loss: 20.8675, SAE Loss: 29.3719
Epoch [1/10], Step [120/137], IRM Loss: 20.4721, SAE Loss: 15.5660
** End of Epoch 1/10 | Test Accuracy: 6.71% **
Epoch [2/10], Step [40/137], IRM Loss: 17.8856, SAE Loss: 5.5534
Epoch [2/10], Step [80/137], IRM Loss: 15.1471, SAE Loss: 4.4189
Epoch [2/10], Step [120/137], IRM Loss: 11.7534, SAE Loss: 4.0728
** End of Epoch 2/10 | Test Accuracy: 33.79% **
Epoch [3/10], Step [40/137], IRM Loss: 7.1011, SAE Loss: 4.8947
Epoch [3/10], Step [80/137], IRM Loss: 5.1481, SAE Loss: 5.3344
Epoch [3/10], Step [120/137], IRM Loss: 4.6132, SAE Loss: 4.0294
** End of Epoch 3/10 | Test Accuracy: 79.91% **
Epoch [4/10], Step [40/137], IRM Loss: 1.9871, SAE Loss: 4.6532
Epoch [4/10], Step [80/137], IRM Loss: 2.7737, SAE Loss: 4.7461
Epoch [4/10], Step [120/137], IRM Loss: 1.2253, SAE Loss

KeyboardInterrupt: 