In [None]:
!pip installed install datasets



In [None]:
# --------------------- Imports ------------------------- #

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms, models
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torchvision.models import vit_b_16

# Datasets
# from datasets import load_dataset
import kagglehub

# Extra
import os
from copy import deepcopy
from PIL import Image
import matplotlib.pyplot as plt
import timm

## Office-31

In [None]:
def Office31(domain, transform):
    """Build an ``ImageFolder`` dataset over a single Office-31 domain.

    Args:
        domain: Sub-directory name under ``Office-31`` (this notebook uses
            ``"amazon"``, ``"dslr"`` and ``"webcam"``).
        transform: Transform handed to ``ImageFolder``.

    Returns:
        ``datasets.ImageFolder`` rooted at ``<download dir>/Office-31/<domain>``.
    """
    # The value returned by kagglehub is used as the dataset's local root.
    dataset_root = kagglehub.dataset_download("xixuhu/office31")
    domain_dir = os.path.join(dataset_root, "Office-31", domain)
    return datasets.ImageFolder(root=domain_dir, transform=transform)

def get_office_data_loaders(batch_size):
    """Create shuffled data loaders for the three Office-31 domains.

    Every image is resized to 224x224, converted to a tensor and normalised
    per channel with mean 0.5 and std 0.5.

    Args:
        batch_size: Batch size shared by all three loaders.

    Returns:
        Tuple ``(loader_amazon, loader_dslr, loader_webcam)``.
        Mind the order when unpacking: dslr comes BEFORE webcam.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Build the three datasets first, then wrap each one in a loader.
    domain_sets = [Office31(name, preprocess) for name in ("amazon", "dslr", "webcam")]
    loader_amazon, loader_dslr, loader_webcam = [
        DataLoader(domain_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
        for domain_set in domain_sets
    ]

    return loader_amazon, loader_dslr, loader_webcam


## PACS

## Model Cards

In [None]:
class ResNetFeatureExtractor(nn.Module):
    """ResNet-50 trunk that exposes the outputs of its last three stages.

    The stem (conv1/bn1/relu/maxpool) and the four residual stages of a
    torchvision ``resnet50`` are re-registered on this module under their
    original attribute names.
    """

    def __init__(self, pretrained=True):
        """Args:
            pretrained: Forwarded to ``torchvision.models.resnet50``.
        """
        super().__init__()
        backbone = torchvision.models.resnet50(pretrained=pretrained)
        # Same registration order as the backbone: stem first, then stages.
        for part in ("conv1", "bn1", "relu", "maxpool",
                     "layer1", "layer2", "layer3", "layer4"):
            setattr(self, part, getattr(backbone, part))

    def forward(self, x):
        """Return ``(layer2_out, layer3_out, layer4_out)`` for input ``x``."""
        stem_out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        stage1 = self.layer1(stem_out)
        stage2 = self.layer2(stage1)
        stage3 = self.layer3(stage2)
        stage4 = self.layer4(stage3)

        # The first stage is only an intermediate; it is not returned.
        return stage2, stage3, stage4




class VitBasicFeatureExtractor(nn.Module):
    """timm ``vit_base_patch16_224`` backbone with globally pooled features.

    The backbone is created with ``features_only=True``; every feature map it
    returns is average-pooled to 1x1 and flattened to ``(batch, -1)``.
    """

    def __init__(self, pretrained=True, layers=[4, 8, 12]):
        """Args:
            pretrained: Forwarded to ``timm.create_model``.
            layers: Accepted but not used by this class.
        """
        super().__init__()

        self.model = timm.create_model('vit_base_patch16_224', pretrained=pretrained, features_only=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return a tuple with one pooled, flattened tensor per feature map."""
        pooled_features = []
        for feature_map in self.model(x):
            squeezed = self.avgpool(feature_map)
            pooled_features.append(squeezed.view(squeezed.size(0), -1))
        return tuple(pooled_features)


class SwinBasicFeatureExtractor(nn.Module):
    """timm ``swin_base_patch4_window7_224`` backbone with pooled features.

    Each feature map from the backbone is permuted with ``(0, 3, 1, 2)``
    (moving the last axis to position 1), average-pooled to 1x1 and flattened.
    The first feature level is dropped from the result.
    """

    def __init__(self, pretrained=True):
        """Args:
            pretrained: Forwarded to ``timm.create_model``.
        """
        super().__init__()

        self.model = timm.create_model(
            'swin_base_patch4_window7_224',
            pretrained=pretrained,
            features_only=True,
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return the pooled features of every level except the first."""
        pooled_levels = []
        for level in self.model(x):
            channels_first = level.permute(0, 3, 1, 2)
            squeezed = self.avgpool(channels_first)
            pooled_levels.append(squeezed.view(squeezed.size(0), -1))

        # All levels are pooled; only levels 1.. are handed back.
        return tuple(pooled_levels[1:])

class DeitBasicFeatureExtractor(nn.Module):
    """timm ``deit_base_patch16_224.fb_in1k`` backbone with pooled features.

    The backbone is created with ``features_only=True``; every feature map it
    returns is average-pooled to 1x1 and flattened to ``(batch, -1)``.
    """

    def __init__(self, pretrained=True):
        """Args:
            pretrained: Forwarded to ``timm.create_model``.
        """
        super().__init__()

        self.model = timm.create_model(
            'deit_base_patch16_224.fb_in1k',
            pretrained=pretrained,
            features_only=True,
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return a tuple with one pooled, flattened tensor per feature map."""
        pooled_levels = []
        for level in self.model(x):
            squeezed = self.avgpool(level)
            pooled_levels.append(squeezed.view(squeezed.size(0), -1))
        return tuple(pooled_levels)

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init

class SparseAutoencoder(nn.Module):
    """MLP autoencoder with LayerNorm/ReLU/Dropout blocks.

    Encoder: three ``Linear -> LayerNorm -> ReLU -> Dropout`` blocks followed
    by ``Linear -> LayerNorm -> ReLU`` into the latent space (so the latent
    code is non-negative).
    Decoder: three of the same blocks followed by ``Linear -> Sigmoid`` (so
    the reconstruction lies in [0, 1]).

    NOTE(review): the Sigmoid bounds reconstructions to [0, 1]; confirm that
    the features being reconstructed live in that range.
    """

    def __init__(self, input_dim=2048, hidden_dim=1024, latent_dim=768, dropout_rate=0.3):
        """Args:
            input_dim: Size of the input (and of the reconstruction).
            hidden_dim: Width of every hidden layer.
            latent_dim: Size of the latent code.
            dropout_rate: Dropout probability used after each hidden block.
        """
        super().__init__()

        def hidden_block(n_in, n_out):
            # One Linear -> LayerNorm -> ReLU -> Dropout stage.
            return [
                nn.Linear(n_in, n_out),
                nn.LayerNorm(n_out),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_rate),
            ]

        self.encoder = nn.Sequential(
            *hidden_block(input_dim, hidden_dim),
            *hidden_block(hidden_dim, hidden_dim),
            *hidden_block(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.ReLU(inplace=True),
        )

        self.decoder = nn.Sequential(
            *hidden_block(latent_dim, hidden_dim),
            *hidden_block(hidden_dim, hidden_dim),
            *hidden_block(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform Linear weights, zero biases, identity LayerNorm."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.kaiming_uniform_(module.weight, a=0, nonlinearity='relu')
                if module.bias is not None:
                    init.zeros_(module.bias)
                continue
            if isinstance(module, nn.LayerNorm):
                init.ones_(module.weight)
                init.zeros_(module.bias)

    def forward(self, x):
        """Return ``(reconstruction, latent_code)`` for input ``x``."""
        latent = self.encoder(x)
        return self.decoder(latent), latent


In [None]:
class BaselineModel(nn.Module):
    """Linear classifier on top of the last feature of a feature extractor.

    The extractor must return exactly three values; only the third one is
    used. The classifier expects 768 input features.
    """

    def __init__(self, feature_extractor, num_classes=31):
        """Args:
            feature_extractor: Callable returning a 3-tuple of features.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.feature_extractor = feature_extractor
        # Registered for parity with the original layout; forward() does not call it.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return ``(logits, last_feature)``."""
        _first, _second, last_feature = self.feature_extractor(x)
        return self.classifier(last_feature), last_feature


In [None]:
class UnifiedModelMultiBlockSAE(nn.Module):
    """DeiT features -> sparse autoencoder -> linear classifier on the latent.

    Only the third feature returned by the extractor is used. It is passed
    through ``sae4`` and the classifier operates on the SAE latent code.
    """

    def __init__(self,
                 sae_dim_block2=128,
                 sae_dim_block3=256,
                 sae_dim_block4=256,
                 num_classes=31):
        """Args:
            sae_dim_block2: Accepted but not used by this class.
            sae_dim_block3: Accepted but not used by this class.
            sae_dim_block4: Accepted but not used by this class.
            num_classes: Number of output logits.
        """
        super().__init__()

        self.feature_extractor = DeitBasicFeatureExtractor(pretrained=True)
        self.sae4 = SparseAutoencoder(input_dim=768, latent_dim=768)
        self.classifier = nn.Sequential(nn.Linear(768, num_classes))

    def forward(self, x):
        """Return ``(class_logits, (reconstruction, latent), feature)``."""
        _first, _second, last_feature = self.feature_extractor(x)
        reconstruction, latent = self.sae4(last_feature)
        class_logits = self.classifier(latent)
        return class_logits, (reconstruction, latent), last_feature

## Losses

In [None]:
def irm_penalty(logits, labels):
    """Cross-entropy loss plus the squared gradient w.r.t. a dummy scale.

    The logits are multiplied by a scalar fixed at 1.0; the penalty is the
    squared gradient of the cross-entropy with respect to that scalar. The
    gradient is taken with ``create_graph=True`` so the penalty itself can be
    back-propagated.

    Args:
        logits: Classifier outputs accepted by ``F.cross_entropy``.
        labels: Targets accepted by ``F.cross_entropy``.

    Returns:
        ``(loss_erm, penalty, var)`` where ``var`` is always the float ``0.0``.
    """
    dummy_scale = torch.tensor(1.0, requires_grad=True, device=logits.device)
    loss_erm = F.cross_entropy(dummy_scale * logits, labels)

    (scale_grad,) = torch.autograd.grad(loss_erm, [dummy_scale], create_graph=True)
    penalty = scale_grad.pow(2).sum()

    return loss_erm, penalty, 0.0

In [None]:
def compute_sae_loss(recons, pools, zs, lambda_sparse=[1.0, 1.0, 1.0], lambda_reconstruction=[1.0, 1.0, 1.0]):
    """Weighted reconstruction and L1-sparsity losses over three SAE blocks.

    Args:
        recons: 3-tuple of reconstructions, one per block.
        pools: 3-tuple of reconstruction targets, matching ``recons``.
        zs: 3-tuple of latent codes, one per block.
        lambda_sparse: Three weights for the per-block mean absolute latent.
        lambda_reconstruction: Three weights for the per-block MSE.

    Returns:
        ``(recon_loss_total, l1_sparsity_total)``.
    """
    recon_a, recon_b, recon_c = recons
    target_a, target_b, target_c = pools
    latent_a, latent_b, latent_c = zs

    mse_a = F.mse_loss(recon_a, target_a)
    mse_b = F.mse_loss(recon_b, target_b)
    mse_c = F.mse_loss(recon_c, target_c)

    w_rec_a, w_rec_b, w_rec_c = lambda_reconstruction
    recon_loss_total = w_rec_a * mse_a + w_rec_b * mse_b + w_rec_c * mse_c

    # Sparsity term: mean absolute activation of each latent code.
    l1_a = latent_a.abs().mean()
    l1_b = latent_b.abs().mean()
    l1_c = latent_c.abs().mean()

    w_sp_a, w_sp_b, w_sp_c = lambda_sparse
    l1_sparsity_total = w_sp_a * l1_a + w_sp_b * l1_b + w_sp_c * l1_c

    return recon_loss_total, l1_sparsity_total


## Train and Evaluate Functions

### Train

In [None]:
def train_main_irm_multi_sae_office(model,
                              loader_source,
                              test_loader,
                              num_epochs=20,
                              lr=1e-4,
                              lr_sae=1e-4,
                              lambda_irm=1.0,
                              lambda_sae_rec=1.0,
                              lambda_sae_sparse=1e-4,
                              lambda_sparse=[1.0, 1.0, 1.0],
                              lambda_reconstruction=[1.0, 1.0, 1.0],
                              lambda_irm_pair=[1.0, 1.0, 1.0],
                              device='cuda',
                              verbose=False):
    """Three-SAE variant of the IRM + SAE training loop (single optimizer).

    NOTE(review): a function with this exact name is defined again further
    down in this notebook; when the cells run top to bottom the later
    definition replaces this one. This version expects ``model`` to expose
    ``sae2``, ``sae3`` and ``sae4`` and to return five values from
    ``model(x)``: logits, three ``(reconstruction, latent)`` pairs and a
    3-tuple of pooled features.

    Args:
        model: Model to train (see the expectations above).
        loader_source: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Loader passed to ``evaluate`` after every epoch.
        num_epochs: Number of passes over ``loader_source``.
        lr: Learning rate of the single Adam optimizer.
        lr_sae: Accepted but not used in this version.
        lambda_irm: Multiplier of the IRM penalty term.
        lambda_sae_rec: Weight of the total reconstruction loss.
        lambda_sae_sparse: Weight of the total L1 sparsity loss.
        lambda_sparse: Three per-block sparsity weights.
        lambda_reconstruction: Three per-block reconstruction weights.
        lambda_irm_pair: Three weights applied to the halved ERM loss, the
            halved (``lambda_irm``-scaled) penalty, and ``var``.
        device: Device the model and batches are moved to.
        verbose: If true, print the losses every 40 steps.

    Returns:
        The trained ``model``.
    """

    model.to(device)
    # NOTE(review): train mode is only set here, once. If `evaluate` leaves the
    # model in eval mode, every epoch after the first trains in eval mode --
    # confirm against `evaluate`.
    model.train()

    def sae_forward_splits(f2p, f3p, f4p):
        # Run each pooled feature through its own autoencoder.
        x_recon2, z2 = model.sae2(f2p)
        x_recon3, z3 = model.sae3(f3p)
        x_recon4, z4 = model.sae4(f4p)
        return x_recon2, z2, x_recon3, z3, x_recon4, z4

    # One optimizer over every parameter of the model.
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        source_iter = iter(loader_source)
        steps_per_epoch = len(source_iter)

        for step in range(steps_per_epoch):

            # The fallback re-creates the iterator if it runs out early.
            try:
                x_s, y_s = next(source_iter)
            except StopIteration:
                source_iter = iter(loader_source)
                x_s, y_s = next(source_iter)

            x_s, y_s = x_s.to(device), y_s.to(device)

            # First forward pass (with gradients): only the logits are used.
            class_logits_s, _, _, _, _ = model(x_s)

            loss_erm_s, penalty_s, var_s = irm_penalty(class_logits_s, y_s)

            irm_loss = 0.5 * (loss_erm_s)
            irm_pen  = 0.5 * (penalty_s)


            w1, w2, w3 = lambda_irm_pair
            loss_irm = w1 * (irm_loss) + w2 * (lambda_irm * irm_pen) + w3 * var_s


            # Second forward pass without gradients: only the pooled features
            # are used afterwards; the SAE outputs from this pass are
            # overwritten by the call to sae_forward_splits below.
            with torch.no_grad():
                _, (x_recon2_s, z2_s), (x_recon3_s, z3_s), (x_recon4_s, z4_s), (f2p_s, f3p_s, f4p_s) = model(x_s)

            # SAE pass on the gradient-free features, so the SAE loss does not
            # back-propagate into whatever produced f2p_s/f3p_s/f4p_s.
            x_recon2_s, z2_s, x_recon3_s, z3_s, x_recon4_s, z4_s = sae_forward_splits(f2p_s, f3p_s, f4p_s)

            lambda_s1, lambda_s2, lambda_s3 = lambda_sparse
            lambda_r1, lambda_r2, lambda_r3 = lambda_reconstruction

            # Per-block weighted reconstruction (MSE) losses.
            rec_loss2_s = F.mse_loss(x_recon2_s, f2p_s)
            rec_loss2   = lambda_r1 * (rec_loss2_s)

            rec_loss3_s = F.mse_loss(x_recon3_s, f3p_s)
            rec_loss3   = lambda_r2 * (rec_loss3_s)

            rec_loss4_s = F.mse_loss(x_recon4_s, f4p_s)
            rec_loss4   = lambda_r3 * (rec_loss4_s)

            rec_loss_total = (rec_loss2 + rec_loss3 + rec_loss4)

            # Per-block weighted sparsity (mean absolute latent) losses.
            l1_2_s = torch.mean(torch.abs(z2_s))
            l1_2   = lambda_s1 * (l1_2_s)

            l1_3_s = torch.mean(torch.abs(z3_s))
            l1_3   = lambda_s2 * (l1_3_s)

            l1_4_s = torch.mean(torch.abs(z4_s))
            l1_4   = lambda_s3 * (l1_4_s)

            l1_sparsity = l1_2 + l1_3 + l1_4

            sae_loss = lambda_sae_rec * rec_loss_total + lambda_sae_sparse * l1_sparsity

            # Single objective: IRM term plus SAE term.
            loss = loss_irm + sae_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step+1) % 40 == 0 and verbose:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{steps_per_epoch}], "
                      f"IRM Loss: {loss_irm.item():.4f}, SAE Loss: {sae_loss.item():.4f}")

        # Accuracy on test_loader after every epoch.
        test_acc = evaluate(model, test_loader, device=device)
        print(f"** End of Epoch {epoch+1}/{num_epochs} | Test Accuracy: {test_acc:.2f}% **")

    return model

In [None]:
def train_model_irm_sae_with_warmup_office(
    batch_size=16,
    num_warmup_epochs=5,
    num_main_epochs=20,
    lr=1e-4,
    lr_sae=1e-4,
    lambda_irm=1.0,
    lambda_sae_rec=1.0,
    lambda_sae_sparse=1e-4,
    device='cuda',
    loader=['A', 'W'],
    model=None,
    verbose=False,
    lambda_sparse=[1.0, 1.0, 1.0],
    lambda_reconstruction=[1.0, 1.0, 1.0],
    lambda_irm_pair=[1.0, 1.0, 0.0],
):
    """Train the IRM + SAE model on one Office-31 domain, evaluate on another.

    Args:
        batch_size: Batch size for all data loaders.
        num_warmup_epochs: Accepted but not used (there is no warm-up phase
            in this function).
        num_main_epochs: Number of training epochs.
        lr: Learning rate forwarded to the training loop.
        lr_sae: SAE learning rate forwarded to the training loop.
        lambda_irm: IRM penalty multiplier.
        lambda_sae_rec: Weight of the SAE reconstruction loss.
        lambda_sae_sparse: Weight of the SAE sparsity loss.
        device: Device used for training and evaluation.
        loader: ``[source, target]`` domain keys, each one of ``'A'``
            (amazon), ``'W'`` (webcam) or ``'D'`` (dslr). ``None`` means
            ``['A', 'W']``.
        model: Model to train; a fresh ``UnifiedModelMultiBlockSAE`` is
            created when ``None``.
        verbose: Forwarded to the training loop.
        lambda_sparse: Per-block sparsity weights.
        lambda_reconstruction: Per-block reconstruction weights.
        lambda_irm_pair: Weights of the ERM / penalty / variance terms.

    Returns:
        The trained model.

    Raises:
        ValueError: If a domain key in ``loader`` is not 'A', 'W' or 'D'.
    """
    # get_office_data_loaders returns (amazon, dslr, webcam) -- in that order.
    # Unpacking it as (amazon, webcam, dslr) silently swaps the W and D domains.
    loader_amazon, loader_dslr, loader_webcam = get_office_data_loaders(batch_size)
    loader_amazon_test, loader_dslr_test, loader_webcam_test = get_office_data_loaders(batch_size)

    source, target = loader if loader is not None else ['A', 'W']

    train_loaders = {'A': loader_amazon, 'W': loader_webcam, 'D': loader_dslr}
    test_loaders = {'A': loader_amazon_test, 'W': loader_webcam_test, 'D': loader_dslr_test}

    # Fail early with a clear message instead of an unbound-variable error later.
    for role, key in (("source", source), ("target", target)):
        if key not in train_loaders:
            raise ValueError(
                f"Unknown {role} domain {key!r}; expected one of 'A', 'W', 'D'."
            )

    loader_source = train_loaders[source]
    loader_target = test_loaders[target]

    if model is None:
        model = UnifiedModelMultiBlockSAE(512, 1024, 2048, 31)

    print("===== Main Phase (IRM + SAE) =====")

    model = train_main_irm_multi_sae_office(model,
                                            loader_source=loader_source,
                                            test_loader=loader_target,
                                            num_epochs=num_main_epochs,
                                            lr=lr,
                                            lr_sae=lr_sae,
                                            lambda_irm=lambda_irm,
                                            lambda_sae_rec=lambda_sae_rec,
                                            lambda_sae_sparse=lambda_sae_sparse,
                                            device=device,
                                            verbose=verbose,
                                            lambda_sparse=lambda_sparse,
                                            lambda_reconstruction=lambda_reconstruction,
                                            lambda_irm_pair=lambda_irm_pair
                                            )

    return model

### Separate optimizers

In [None]:
def train_main_irm_multi_sae_office(
    model,
    loader_source,
    test_loader,
    num_epochs=20,
    lr=1e-4,
    lr_sae=1e-4,
    lambda_irm=1.0,
    lambda_sae_rec=1.0,
    lambda_sae_sparse=1e-4,
    lambda_sparse=[1.0, 1.0, 1.0],
    lambda_reconstruction=[1.0, 1.0, 1.0],
    lambda_irm_pair=[1.0, 1.0, 1.0],
    device='cuda',
    verbose=False
):
    """Train with an IRM objective plus a single-SAE loss, two optimizers.

    ``model(x)`` must return ``(logits, (reconstruction, latent), feature)``
    and expose ``model.sae4``. Per step:

    * the IRM term is computed from the logits of a normal forward pass;
    * the SAE term (reconstruction MSE + L1 sparsity) is computed by running
      ``model.sae4`` on the *detached* feature, so it only produces
      gradients for the SAE parameters;
    * both terms are summed, back-propagated once, and both optimizers step.

    Args:
        model: Model to train (see the expectations above).
        loader_source: Sized iterable of ``(inputs, labels)`` batches.
        test_loader: Loader passed to ``evaluate`` after every epoch.
        num_epochs: Number of passes over ``loader_source``.
        lr: Adam learning rate for all non-SAE parameters.
        lr_sae: Adam learning rate for the ``sae4`` parameters.
        lambda_irm: Multiplier of the IRM penalty term.
        lambda_sae_rec: Weight of the reconstruction loss.
        lambda_sae_sparse: Weight of the sparsity loss.
        lambda_sparse: Three weights; only the third is used here.
        lambda_reconstruction: Three weights; only the third is used here.
        lambda_irm_pair: Three weights applied to the halved ERM loss, the
            halved (``lambda_irm``-scaled) penalty, and ``var``.
        device: Device the model and batches are moved to.
        verbose: If true, print the losses every 40 steps.

    Returns:
        The trained ``model``.
    """
    import torch.optim as optim

    model.to(device)

    # Separate optimizers: everything except the SAE(s) vs. the SAE itself.
    params_rest = [
        p for n, p in model.named_parameters()
        if not n.startswith(('sae2', 'sae3', 'sae4'))
    ]
    optimizer_rest = optim.Adam(params_rest, lr=lr)
    params_sae = list(model.sae4.parameters())
    optimizer_sae = optim.Adam(params_sae, lr=lr_sae)

    # Loop-invariant weights; the 3-way unpacking still enforces length 3.
    w1, w2, w3 = lambda_irm_pair
    _, _, lambda_s3 = lambda_sparse
    _, _, lambda_r3 = lambda_reconstruction

    steps_per_epoch = len(loader_source)

    for epoch in range(num_epochs):
        # The per-epoch evaluation may leave the model in eval mode; make sure
        # every epoch actually trains with dropout etc. enabled.
        model.train()

        for step, (x_s, y_s) in enumerate(loader_source):
            x_s, y_s = x_s.to(device), y_s.to(device)

            # Single forward pass: logits for IRM, feature for the SAE loss.
            class_logits_s, _, f4_s = model(x_s)

            loss_erm_s, penalty_s, var_s = irm_penalty(class_logits_s, y_s)

            irm_loss = 0.5 * loss_erm_s
            irm_pen  = 0.5 * penalty_s
            loss_irm = w1 * irm_loss + w2 * (lambda_irm * irm_pen) + w3 * var_s

            # Detach so the SAE loss cannot back-propagate into the feature
            # extractor; this replaces a second full forward under no_grad.
            f4p_s = f4_s.detach()
            x_recon4_s, z4_s = model.sae4(f4p_s)

            rec_loss_total = lambda_r3 * F.mse_loss(x_recon4_s, f4p_s)
            l1_sparsity = lambda_s3 * torch.mean(torch.abs(z4_s))

            sae_loss = lambda_sae_rec * rec_loss_total + lambda_sae_sparse * l1_sparsity

            loss = loss_irm + sae_loss

            optimizer_rest.zero_grad()
            optimizer_sae.zero_grad()

            loss.backward()

            optimizer_rest.step()
            optimizer_sae.step()

            if (step+1) % 40 == 0 and verbose:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{steps_per_epoch}], "
                      f"IRM Loss: {loss_irm.item():.4f}, SAE Loss: {sae_loss.item():.4f}")

        # Evaluation after each epoch
        test_acc = evaluate(model, test_loader, device=device)
        print(f"** End of Epoch {epoch+1}/{num_epochs} | Test Accuracy: {test_acc:.2f}% **")

    return model


### Evaluate

In [None]:
def evaluate_baseline(model, loader, device='cuda'):
    """Top-1 accuracy (in percent) of a model that returns ``(logits, _)``.

    The model is put into eval mode for the pass and switched to train mode
    before returning.

    Args:
        model: Module whose forward returns a 2-tuple starting with logits.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy as a float in ``[0, 100]``.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits, _ = model(inputs)
            predictions = logits.argmax(dim=1)
            n_correct += predictions.eq(targets).sum().item()
            n_seen += targets.size(0)
    accuracy = 100.0 * n_correct / n_seen
    model.train()
    return accuracy


In [None]:
def evaluate(model, loader, device='cuda'):
    """Top-1 accuracy (in percent) of a model that returns ``(logits, _, _)``.

    The model is put into eval mode for the pass and its previous
    train/eval mode is restored afterwards, so calling this between training
    epochs does not leave the model stuck in eval mode.

    Args:
        model: Module whose forward returns a 3-tuple starting with logits.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy as a float in ``[0, 100]``.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                logits, _, _ = model(x)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
    finally:
        # Restore whatever mode the caller had, even if a batch fails.
        model.train(was_training)
    acc = 100.0 * correct / total
    return acc

# MAIN

### Ours

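## A -> D **DO NOT RUN**

> Review note: the accuracies recorded below are multiples of 1/795 (e.g. 31.19% = 248/795). 795 is the size of the webcam split (dslr has 498 images), so this output appears to have been produced with the W and D loaders swapped, i.e. it measures A -> W.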

In [None]:
# [source, target] domain keys for this run.
combination = ["A", "D"]
print(f"------------Combination: {combination}--------------")

# Weights of the (ERM, IRM penalty, variance) terms / per-block sparsity weights.
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse = [0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
    loader=combination,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=15,
    lr=1e-6,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_sparse=lambda_sparse,
    device='cuda',
    verbose=True,
)

print("--------------------------------------------")


------------Combination: ['A', 'D']--------------
===== Main Phase (IRM + SAE) =====
Epoch [1/15], Step [40/89], IRM Loss: 17.0498, SAE Loss: 27.8333
Epoch [1/15], Step [80/89], IRM Loss: 15.9341, SAE Loss: 23.7531
** End of Epoch 1/15 | Test Accuracy: 31.19% **
Epoch [2/15], Step [40/89], IRM Loss: 13.0841, SAE Loss: 17.6362
Epoch [2/15], Step [80/89], IRM Loss: 9.8573, SAE Loss: 20.0936
** End of Epoch 2/15 | Test Accuracy: 63.14% **
Epoch [3/15], Step [40/89], IRM Loss: 6.4672, SAE Loss: 19.9149
Epoch [3/15], Step [80/89], IRM Loss: 6.2289, SAE Loss: 20.3605
** End of Epoch 3/15 | Test Accuracy: 73.08% **
Epoch [4/15], Step [40/89], IRM Loss: 4.9000, SAE Loss: 21.4384
Epoch [4/15], Step [80/89], IRM Loss: 3.3247, SAE Loss: 20.0976
** End of Epoch 4/15 | Test Accuracy: 76.23% **
Epoch [5/15], Step [40/89], IRM Loss: 4.1440, SAE Loss: 19.6685
Epoch [5/15], Step [80/89], IRM Loss: 2.3622, SAE Loss: 21.9492
** End of Epoch 5/15 | Test Accuracy: 78.62% **
Epoch [6/15], Step [40/89], IRM 

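## A -> W **DO NOT RUN**

> Review note: the accuracies recorded below are multiples of 1/498 (e.g. 38.76% = 193/498). 498 is the size of the dslr split (webcam has 795 images), so this output appears to have been produced with the W and D loaders swapped, i.e. it measures A -> D.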

In [None]:
# Set the [source, target] pair explicitly, like the sibling run cells do,
# instead of relying on the wrapper's default and on a `combination` value
# left over from a previously executed cell.
combination = ["A", "W"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=15,
            lr=1e-6,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")

===== Main Phase (IRM + SAE) =====
Epoch [1/15], Step [40/89], IRM Loss: 17.3859, SAE Loss: 29.8607
Epoch [1/15], Step [80/89], IRM Loss: 16.7197, SAE Loss: 24.8834
** End of Epoch 1/15 | Test Accuracy: 38.76% **
Epoch [2/15], Step [40/89], IRM Loss: 13.4108, SAE Loss: 17.3458
Epoch [2/15], Step [80/89], IRM Loss: 9.6741, SAE Loss: 19.0737
** End of Epoch 2/15 | Test Accuracy: 45.78% **
Epoch [3/15], Step [40/89], IRM Loss: 6.9941, SAE Loss: 20.3981
Epoch [3/15], Step [80/89], IRM Loss: 5.6337, SAE Loss: 20.8559
** End of Epoch 3/15 | Test Accuracy: 66.87% **
Epoch [4/15], Step [40/89], IRM Loss: 5.1578, SAE Loss: 20.9708
Epoch [4/15], Step [80/89], IRM Loss: 4.4381, SAE Loss: 22.5657
** End of Epoch 4/15 | Test Accuracy: 74.30% **
Epoch [5/15], Step [40/89], IRM Loss: 4.0525, SAE Loss: 19.4942
Epoch [5/15], Step [80/89], IRM Loss: 3.0195, SAE Loss: 22.0497
** End of Epoch 5/15 | Test Accuracy: 80.12% **
Epoch [6/15], Step [40/89], IRM Loss: 0.6559, SAE Loss: 23.1317
Epoch [6/15], Step

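## W -> A **DO NOT RUN**

> Review note: the other recorded runs show the W and D loaders swapped; if the same swap applied here, the source domain of this output was dslr (i.e. D -> A). The target (amazon, 2817 images) is consistent with the recorded accuracies, e.g. 6.64% = 187/2817.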

In [None]:
# [source, target] domain keys for this run.
combination = ["W", "A"]
print(f"------------Combination: {combination}--------------")

# Weights of the (ERM, IRM penalty, variance) terms / per-block sparsity weights.
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse = [0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
    loader=combination,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=15,
    lr=1e-6,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_sparse=lambda_sparse,
    device='cuda',
    verbose=True,
)

print("--------------------------------------------")


------------Combination: ['W', 'A']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/15 | Test Accuracy: 6.64% **
** End of Epoch 2/15 | Test Accuracy: 7.14% **
** End of Epoch 3/15 | Test Accuracy: 9.76% **
** End of Epoch 4/15 | Test Accuracy: 5.18% **
** End of Epoch 5/15 | Test Accuracy: 8.56% **
** End of Epoch 6/15 | Test Accuracy: 12.25% **
** End of Epoch 8/15 | Test Accuracy: 17.61% **
** End of Epoch 9/15 | Test Accuracy: 24.46% **
** End of Epoch 10/15 | Test Accuracy: 30.99% **
** End of Epoch 11/15 | Test Accuracy: 35.32% **
** End of Epoch 12/15 | Test Accuracy: 42.60% **
** End of Epoch 13/15 | Test Accuracy: 47.00% **
** End of Epoch 14/15 | Test Accuracy: 53.85% **
** End of Epoch 15/15 | Test Accuracy: 56.12% **
--------------------------------------------


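## W -> D **DO NOT RUN**

> Review note: the accuracies recorded below are multiples of 1/795 (e.g. 10.31% = 82/795). 795 is the size of the webcam split (dslr has 498 images), so this output appears to have been produced with the W and D loaders swapped, i.e. it measures D -> W.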

In [None]:
# [source, target] domain keys for this run.
combination = ["W", "D"]
print(f"------------Combination: {combination}--------------")

# Weights of the (ERM, IRM penalty, variance) terms / per-block sparsity weights.
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse = [0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
    loader=combination,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=15,
    lr=1e-6,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_sparse=lambda_sparse,
    device='cuda',
    verbose=True,
)

print("--------------------------------------------")


------------Combination: ['W', 'D']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/15 | Test Accuracy: 10.31% **
** End of Epoch 2/15 | Test Accuracy: 18.36% **
** End of Epoch 3/15 | Test Accuracy: 27.30% **
** End of Epoch 4/15 | Test Accuracy: 12.96% **
** End of Epoch 5/15 | Test Accuracy: 25.28% **
** End of Epoch 6/15 | Test Accuracy: 37.36% **
** End of Epoch 7/15 | Test Accuracy: 38.99% **
** End of Epoch 8/15 | Test Accuracy: 53.21% **
** End of Epoch 9/15 | Test Accuracy: 61.76% **
** End of Epoch 10/15 | Test Accuracy: 69.94% **
** End of Epoch 11/15 | Test Accuracy: 74.97% **
** End of Epoch 12/15 | Test Accuracy: 84.91% **
** End of Epoch 13/15 | Test Accuracy: 87.80% **
** End of Epoch 14/15 | Test Accuracy: 89.69% **
** End of Epoch 15/15 | Test Accuracy: 91.45% **
--------------------------------------------


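## D -> W **DO NOT RUN**

> Review note: the accuracies recorded below are multiples of 1/498 (e.g. 98.59% = 491/498). 498 is the size of the dslr split (webcam has 795 images), so this output appears to have been produced with the W and D loaders swapped, i.e. it measures W -> D.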

In [None]:
# [source, target] domain keys for this run.
combination = ["D", "W"]
print(f"------------Combination: {combination}--------------")

# Weights of the (ERM, IRM penalty, variance) terms / per-block sparsity weights.
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse = [0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
    loader=combination,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=15,
    lr=1e-6,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_sparse=lambda_sparse,
    device='cuda',
    verbose=True,
)

print("--------------------------------------------")


------------Combination: ['D', 'W']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/15 | Test Accuracy: 14.66% **
** End of Epoch 2/15 | Test Accuracy: 31.33% **
** End of Epoch 3/15 | Test Accuracy: 19.88% **
** End of Epoch 4/15 | Test Accuracy: 33.94% **
** End of Epoch 5/15 | Test Accuracy: 47.99% **
** End of Epoch 6/15 | Test Accuracy: 60.24% **
** End of Epoch 7/15 | Test Accuracy: 67.67% **
** End of Epoch 8/15 | Test Accuracy: 77.51% **
** End of Epoch 9/15 | Test Accuracy: 85.54% **
** End of Epoch 10/15 | Test Accuracy: 93.78% **
** End of Epoch 11/15 | Test Accuracy: 96.39% **
** End of Epoch 12/15 | Test Accuracy: 97.59% **
** End of Epoch 13/15 | Test Accuracy: 98.39% **
** End of Epoch 14/15 | Test Accuracy: 98.59% **
** End of Epoch 15/15 | Test Accuracy: 98.59% **
--------------------------------------------


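## D -> A **DO NOT RUN**

> Review note: the other recorded runs show the W and D loaders swapped; if the same swap applied here, the source domain of this output was webcam (i.e. W -> A). The target (amazon, 2817 images) is consistent with the recorded accuracies, e.g. 8.31% = 234/2817.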

In [None]:
# [source, target] domain keys for this run.
combination = ["D", "A"]
print(f"------------Combination: {combination}--------------")

# Weights of the (ERM, IRM penalty, variance) terms / per-block sparsity weights.
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse = [0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
    loader=combination,
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=15,
    lr=1e-6,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    lambda_sparse=lambda_sparse,
    device='cuda',
    verbose=True,
)

print("--------------------------------------------")


------------Combination: ['D', 'A']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/15 | Test Accuracy: 8.31% **
** End of Epoch 2/15 | Test Accuracy: 17.22% **
** End of Epoch 3/15 | Test Accuracy: 20.87% **
** End of Epoch 4/15 | Test Accuracy: 17.07% **
** End of Epoch 5/15 | Test Accuracy: 19.17% **
** End of Epoch 6/15 | Test Accuracy: 32.02% **
** End of Epoch 7/15 | Test Accuracy: 43.84% **
** End of Epoch 8/15 | Test Accuracy: 52.96% **
** End of Epoch 9/15 | Test Accuracy: 59.96% **
** End of Epoch 10/15 | Test Accuracy: 63.47% **
** End of Epoch 11/15 | Test Accuracy: 66.42% **


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d2200302950>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7d2200302950>Traceback (most recent call last):

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()    
self._shutdown_workers()  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():    
if w.is_alive():  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive

      File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a

** End of Epoch 12/15 | Test Accuracy: 67.23% **
** End of Epoch 13/15 | Test Accuracy: 68.23% **
** End of Epoch 14/15 | Test Accuracy: 68.48% **
** End of Epoch 15/15 | Test Accuracy: 68.83% **
--------------------------------------------


## Baseline ViT (DeiT) accuracies

In [None]:
from torchvision.models import vit_b_16
from tqdm.notebook import tqdm
from transformers import DeiTForImageClassification


def get_base_accuracy(loader=None):
    """Fine-tune a pretrained DeiT on one Office-31 domain and evaluate on another.

    The model is ``facebook/deit-base-distilled-patch16-224`` with its
    classifier replaced by a fresh 31-way linear head. It is trained with
    Adam (lr=5e-5) for 10 epochs on the source loader, and after every epoch
    the accuracy on the target loader is printed.

    Parameters
    ----------
    loader : sequence of two str, optional
        ``[source, target]`` domain codes. Each code is one of ``'A'``
        (amazon), ``'W'`` (webcam) or ``'D'`` (dslr). Defaults to
        ``['A', 'W']`` when ``None``.

    Returns
    -------
    None
        Results are reported through ``print`` only.

    Raises
    ------
    ValueError
        If a domain code is not one of ``'A'``, ``'W'``, ``'D'``.
    """
    source, target = loader if loader is not None else ['A', 'W']

    # Fail fast (before any dataset download / model load) on an unknown code;
    # previously this surfaced later as an UnboundLocalError.
    valid_codes = ('A', 'W', 'D')
    for code in (source, target):
        if code not in valid_codes:
            raise ValueError(
                f"Unknown domain code {code!r}; expected one of {valid_codes}"
            )

    # get_office_data_loaders returns (amazon, dslr, webcam) in THAT order.
    # Unpacking it as (amazon, webcam, dslr) silently swaps the 'W' and 'D'
    # domains, so every W/D result would be reported under the wrong label.
    loader_amazon, loader_dslr, loader_webcam = get_office_data_loaders(32)
    loaders_by_code = {'A': loader_amazon, 'W': loader_webcam, 'D': loader_dslr}
    loader_source = loaders_by_code[source]
    loader_target = loaders_by_code[target]

    # Prefer the GPU, but do not crash outright on a CPU-only machine.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
    # Swap in a new head sized for the 31 Office-31 classes.
    model.classifier = nn.Linear(768, 31)
    model = model.to(device)

    optm = torch.optim.Adam(model.parameters(), lr=5e-5)
    epochs = 10

    for epoch in range(epochs):
        # ---- train for one pass over the source domain ----
        model.train()
        for x, y in loader_source:
            x, y = x.to(device), y.to(device)

            optm.zero_grad()
            logits = model(x).logits
            loss = F.cross_entropy(logits, y)
            loss.backward()
            optm.step()

        # NOTE: this is the loss of the LAST training batch of the epoch,
        # not an epoch average.
        print(f"Epoch: {epoch+1}/{epochs} | Loss: {loss.item()}")

        # ---- evaluate on the target domain ----
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in loader_target:
                x, y = x.to(device), y.to(device)

                logits = model(x).logits
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        acc = 100.0 * correct / total
        print(f"Epoch: {epoch+1}/{epochs} | Test Accuracy: {acc}")




In [None]:
# Ordered (source, target) pairings of the Office-31 domain codes.
# Each pair is kept as a two-element list so the banner below prints the
# usual list repr, e.g. ['A', 'W'].
combinations = [list(pair) for pair in ("AW", "AD", "WA", "WD", "DW", "DA")]

# Run the baseline once per pairing, framed by a header and a footer rule.
for combination in combinations:
    print(f"------------Combination: {combination}--------------")
    get_base_accuracy(loader=combination)
    print("--------------------------------------------")

------------Combination: ['A', 'W']--------------


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 1/10 | Loss: 0.03427192568778992
Epoch: 1/10 | Test Accuracy: 77.30923694779116
Epoch: 2/10 | Loss: 0.07789202779531479
Epoch: 2/10 | Test Accuracy: 79.31726907630522
Epoch: 3/10 | Loss: 0.052264418452978134
Epoch: 3/10 | Test Accuracy: 80.12048192771084
Epoch: 4/10 | Loss: 0.0008509114268235862
Epoch: 4/10 | Test Accuracy: 79.51807228915662
Epoch: 5/10 | Loss: 0.01712827943265438
Epoch: 5/10 | Test Accuracy: 80.32128514056225
Epoch: 6/10 | Loss: 0.0013896104646846652
Epoch: 6/10 | Test Accuracy: 80.72289156626506
Epoch: 7/10 | Loss: 0.0007952864980325103
Epoch: 7/10 | Test Accuracy: 80.92369477911646
Epoch: 8/10 | Loss: 0.00029702542815357447
Epoch: 8/10 | Test Accuracy: 80.72289156626506
Epoch: 9/10 | Loss: 0.0013797297142446041
Epoch: 9/10 | Test Accuracy: 80.72289156626506
Epoch: 10/10 | Loss: 0.007620549760758877
Epoch: 10/10 | Test Accuracy: 80.32128514056225
--------------------------------------------
------------Combination: ['A', 'D']--------------


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 1/10 | Loss: 0.06954265385866165
Epoch: 1/10 | Test Accuracy: 74.71698113207547
Epoch: 2/10 | Loss: 0.006148708052933216
Epoch: 2/10 | Test Accuracy: 76.22641509433963
Epoch: 3/10 | Loss: 0.002998382318764925
Epoch: 3/10 | Test Accuracy: 78.74213836477988
Epoch: 4/10 | Loss: 0.012148083187639713
Epoch: 4/10 | Test Accuracy: 78.23899371069183
Epoch: 5/10 | Loss: 0.009702547453343868
Epoch: 5/10 | Test Accuracy: 78.11320754716981
Epoch: 6/10 | Loss: 0.0006071869283914566
Epoch: 6/10 | Test Accuracy: 78.49056603773585
Epoch: 7/10 | Loss: 0.0017064546700567007
Epoch: 7/10 | Test Accuracy: 78.74213836477988
Epoch: 8/10 | Loss: 0.0002203936892328784
Epoch: 8/10 | Test Accuracy: 78.61635220125787
Epoch: 9/10 | Loss: 0.00015209948469419032
Epoch: 9/10 | Test Accuracy: 78.74213836477988
Epoch: 10/10 | Loss: 0.0007362039177678525
Epoch: 10/10 | Test Accuracy: 78.49056603773585
--------------------------------------------
------------Combination: ['W', 'A']--------------


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 1/10 | Loss: 1.7352815866470337
Epoch: 1/10 | Test Accuracy: 26.90805821796237
Epoch: 2/10 | Loss: 0.14262782037258148
Epoch: 2/10 | Test Accuracy: 51.01171458998935
Epoch: 3/10 | Loss: 0.020084500312805176
Epoch: 3/10 | Test Accuracy: 55.307064252751154
Epoch: 4/10 | Loss: 0.009953547269105911
Epoch: 4/10 | Test Accuracy: 56.69151579694711
Epoch: 5/10 | Loss: 0.007418349385261536
Epoch: 5/10 | Test Accuracy: 56.86900958466454
Epoch: 6/10 | Loss: 0.005166002083569765
Epoch: 6/10 | Test Accuracy: 56.79801206957757
Epoch: 7/10 | Loss: 0.0047587500885128975
Epoch: 7/10 | Test Accuracy: 57.117500887468935
Epoch: 8/10 | Loss: 0.004499915987253189
Epoch: 8/10 | Test Accuracy: 57.011004614838484
Epoch: 9/10 | Loss: 0.0035709217190742493
Epoch: 9/10 | Test Accuracy: 57.08200212992545
Epoch: 10/10 | Loss: 0.0036504885647445917
Epoch: 10/10 | Test Accuracy: 57.18849840255591
--------------------------------------------
------------Combination: ['W', 'D']--------------


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 1/10 | Loss: 1.192643165588379
Epoch: 1/10 | Test Accuracy: 80.62893081761007
Epoch: 2/10 | Loss: 0.0785270407795906
Epoch: 2/10 | Test Accuracy: 97.23270440251572
Epoch: 3/10 | Loss: 0.02375122159719467
Epoch: 3/10 | Test Accuracy: 97.86163522012579
Epoch: 4/10 | Loss: 0.00916258618235588
Epoch: 4/10 | Test Accuracy: 98.49056603773585
Epoch: 5/10 | Loss: 0.007910201326012611
Epoch: 5/10 | Test Accuracy: 98.61635220125787
Epoch: 6/10 | Loss: 0.005907063838094473
Epoch: 6/10 | Test Accuracy: 98.61635220125787
Epoch: 7/10 | Loss: 0.004419818986207247
Epoch: 7/10 | Test Accuracy: 98.61635220125787
Epoch: 8/10 | Loss: 0.005289257038384676
Epoch: 8/10 | Test Accuracy: 98.61635220125787
Epoch: 9/10 | Loss: 0.003232485381886363
Epoch: 9/10 | Test Accuracy: 98.61635220125787
Epoch: 10/10 | Loss: 0.0033467216417193413
Epoch: 10/10 | Test Accuracy: 98.61635220125787
--------------------------------------------
------------Combination: ['D', 'W']--------------


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 1/10 | Loss: 0.39330098032951355
Epoch: 1/10 | Test Accuracy: 97.38955823293173
Epoch: 2/10 | Loss: 0.09654019773006439
Epoch: 2/10 | Test Accuracy: 99.59839357429719
Epoch: 3/10 | Loss: 0.012457349337637424
Epoch: 3/10 | Test Accuracy: 100.0
Epoch: 4/10 | Loss: 0.00811581127345562
Epoch: 4/10 | Test Accuracy: 100.0
Epoch: 5/10 | Loss: 0.005905048921704292
Epoch: 5/10 | Test Accuracy: 100.0
Epoch: 6/10 | Loss: 0.004387938883155584
Epoch: 6/10 | Test Accuracy: 100.0
Epoch: 7/10 | Loss: 0.003571375273168087
Epoch: 7/10 | Test Accuracy: 100.0
Epoch: 8/10 | Loss: 0.00397808151319623
Epoch: 8/10 | Test Accuracy: 100.0
Epoch: 9/10 | Loss: 0.002719544805586338
Epoch: 9/10 | Test Accuracy: 100.0
Epoch: 10/10 | Loss: 0.0022316286340355873
Epoch: 10/10 | Test Accuracy: 100.0
--------------------------------------------
------------Combination: ['D', 'A']--------------


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 1/10 | Loss: 1.1038446426391602
Epoch: 1/10 | Test Accuracy: 46.46787362442315
Epoch: 2/10 | Loss: 0.04351508617401123
Epoch: 2/10 | Test Accuracy: 59.95740149094782
Epoch: 3/10 | Loss: 0.016177324578166008
Epoch: 3/10 | Test Accuracy: 61.09336173233937
Epoch: 4/10 | Loss: 0.00868083443492651
Epoch: 4/10 | Test Accuracy: 61.69684061057863
Epoch: 5/10 | Loss: 0.0058838496915996075
Epoch: 5/10 | Test Accuracy: 61.803336883209084
Epoch: 6/10 | Loss: 0.004809904377907515
Epoch: 6/10 | Test Accuracy: 61.90983315583954
Epoch: 7/10 | Loss: 0.004589059855788946
Epoch: 7/10 | Test Accuracy: 62.264820731274405
Epoch: 8/10 | Loss: 0.003907616715878248
Epoch: 8/10 | Test Accuracy: 62.22932197373092
Epoch: 9/10 | Loss: 0.003600620198994875
Epoch: 9/10 | Test Accuracy: 62.37131700390486
Epoch: 10/10 | Loss: 0.002392939757555723
Epoch: 10/10 | Test Accuracy: 62.40681576144835
--------------------------------------------
