In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [None]:
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms, models
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torchvision.models import vit_b_16

# Datasets
# from datasets import load_dataset
import kagglehub
from datasets import load_dataset

# Extra
import os
from copy import deepcopy
from PIL import Image
import matplotlib.pyplot as plt
import timm

## Office-31

In [None]:
def Office31(domain, transform):
    """Return an ``ImageFolder`` over one Office-31 domain.

    The dataset is fetched with ``kagglehub.dataset_download("xixuhu/office31")`` and
    the domain folder is ``<download root>/Office-31/<domain>``.

    Args:
        domain: Domain sub-folder name, e.g. ``"amazon"``, ``"dslr"`` or ``"webcam"``.
        transform: Transform applied to every image.
    """
    dataset_root = kagglehub.dataset_download("xixuhu/office31")
    domain_dir = os.path.join(dataset_root, "Office-31", domain)
    return datasets.ImageFolder(root=domain_dir, transform=transform)

def get_office_data_loaders(batch_size):
    """Build shuffled DataLoaders for the three Office-31 domains.

    Images are resized to 224x224, converted to tensors and normalized with
    mean = std = 0.5 per channel.

    Args:
        batch_size: Batch size of every loader.

    Returns:
        ``(loader_amazon, loader_dslr, loader_webcam)`` -- in exactly that order.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    domain_datasets = [Office31(name, preprocess) for name in ("amazon", "dslr", "webcam")]
    return tuple(
        DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in domain_datasets
    )

## PACS

In [None]:
class PACSDataset(Dataset):
    """Torch ``Dataset`` view over a PACS split loaded with Hugging Face ``datasets``.

    Each record is indexed with ``'image'`` and ``'label'``; ``transform`` is applied to
    the image and ``(transformed_image, label)`` is returned.
    """

    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        image = example['image']
        # The loaders normalize with 3-channel statistics; a grayscale ('L') or RGBA
        # image would fail there, so coerce to RGB (as torchvision's ImageFolder does).
        if getattr(image, 'mode', 'RGB') != 'RGB':
            image = image.convert('RGB')
        return self.transform(image), example['label']

def get_pacs_data(domain_name, split="train"):
    """Load one domain of the PACS dataset (``flwrlabs/pacs`` on the Hugging Face Hub).

    Args:
        domain_name: Value of the ``domain`` column to keep, e.g. ``"photo"``.
        split: Dataset split to filter (default ``"train"``).

    Returns:
        The split restricted to rows whose ``domain`` equals ``domain_name``.
    """
    dataset = load_dataset("flwrlabs/pacs")[split]
    # Hand only the 'domain' column to the predicate so image data is not decoded
    # for every row just to test its domain.
    return dataset.filter(lambda domain: domain == domain_name, input_columns="domain")

def get_pacs_data_loaders(batch_size):
    """Create shuffled DataLoaders for the four PACS domains (train split).

    Images are resized to 224x224, converted to tensors and normalized with the
    ImageNet mean/std.

    Args:
        batch_size: Batch size of every loader.

    Returns:
        ``(loader_art_painting, loader_cartoon, loader_photo, loader_sketch)``.
    """
    imagenet_preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    domain_names = ("art_painting", "cartoon", "photo", "sketch")
    raw_splits = [get_pacs_data(name, split="train") for name in domain_names]
    wrapped = [PACSDataset(split_data, imagenet_preprocess) for split_data in raw_splits]
    return tuple(DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in wrapped)

## Model Cards

In [None]:
class ResNetFeatureExtractor(nn.Module):
    """ResNet-50 backbone exposing the outputs of its last three residual stages.

    ``forward(x)`` returns ``(f_block2, f_block3, f_block4)``, the feature maps after
    ``layer2``, ``layer3`` and ``layer4``. ``layer1`` is computed but not returned.

    Args:
        pretrained: When True, load ``ResNet50_Weights.IMAGENET1K_V1`` (the checkpoint
            that the deprecated ``pretrained=True`` flag selected); otherwise use random
            initialization.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # torchvision deprecated the ``pretrained=`` keyword in favour of an explicit
        # weights enum; IMAGENET1K_V1 keeps the exact same checkpoint.
        weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        base_model = torchvision.models.resnet50(weights=weights)
        self.conv1 = base_model.conv1
        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        f_block1 = self.layer1(x)
        f_block2 = self.layer2(f_block1)
        f_block3 = self.layer3(f_block2)
        f_block4 = self.layer4(f_block3)

        return f_block2, f_block3, f_block4



class VitBasicFeatureExtractor(nn.Module):
    """Average-pooled intermediate feature maps of a timm ``vit_base_patch16_224``.

    The backbone is created with ``features_only=True``. Every feature map it returns
    is average-pooled to 1x1 and flattened per sample. ``forward`` returns them as a
    tuple, in the backbone's order.

    Note: ``layers`` is accepted but not used. Which blocks are returned is decided by
    timm's ``features_only`` defaults.
    """

    def __init__(self, pretrained=True, layers=[4, 8, 12]):
        super().__init__()
        self.model = timm.create_model('vit_base_patch16_224', pretrained=pretrained, features_only=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        return tuple(
            self.avgpool(fmap).view(fmap.size(0), -1)
            for fmap in self.model(x)
        )


In [None]:
# Smoke test: push one random batch through the ViT extractor and print each pooled shape.
# Avoid the name `input`, which would shadow the Python builtin for the rest of the session.
model = VitBasicFeatureExtractor(pretrained=True)

dummy_batch = torch.randn(16, 3, 224, 224)
with torch.no_grad():  # shape check only; no autograd graph needed
    pooled_features = model(dummy_batch)

for feat in pooled_features:
    print(feat.shape)

model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

torch.Size([16, 768])
torch.Size([16, 768])
torch.Size([16, 768])


In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init

class SparseAutoencoder(nn.Module):
    """MLP autoencoder with a non-negative (ReLU) latent code.

    The class itself adds no sparsity penalty; callers add an L1 term on the latent.

    Encoder: ``input_dim -> hidden_dim -> hidden_dim -> hidden_dim -> latent_dim``. Each
    Linear is followed by LayerNorm + ReLU, plus Dropout on every layer except the latent.
    Decoder: ``latent_dim -> hidden_dim -> hidden_dim -> hidden_dim -> input_dim``, with
    the same hidden blocks and an optional output activation.

    Args:
        input_dim: Size of the input/reconstruction vector.
        hidden_dim: Width of the hidden layers.
        latent_dim: Size of the latent code ``z``.
        dropout_rate: Dropout probability in the hidden blocks.
        output_activation: ``'sigmoid'`` (default, previous behaviour) squashes the
            reconstruction into [0, 1]. ``None`` leaves the final Linear unbounded; use it
            when the target can be negative or exceed 1 (e.g. raw backbone features),
            which a Sigmoid output cannot reproduce.

    ``forward(x)`` returns ``(x_recon, z)``.
    """

    def __init__(self, input_dim=2048, hidden_dim=1024, latent_dim=768, dropout_rate=0.3,
                 output_activation='sigmoid'):
        super().__init__()
        if output_activation not in ('sigmoid', None):
            raise ValueError(f"output_activation must be 'sigmoid' or None, got {output_activation!r}")

        self.encoder = nn.Sequential(
            *self._hidden_block(input_dim, hidden_dim, dropout_rate),
            *self._hidden_block(hidden_dim, hidden_dim, dropout_rate),
            *self._hidden_block(hidden_dim, hidden_dim, dropout_rate),
            nn.Linear(hidden_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.ReLU(inplace=True),
        )

        decoder_layers = [
            *self._hidden_block(latent_dim, hidden_dim, dropout_rate),
            *self._hidden_block(hidden_dim, hidden_dim, dropout_rate),
            *self._hidden_block(hidden_dim, hidden_dim, dropout_rate),
            nn.Linear(hidden_dim, input_dim),
        ]
        if output_activation == 'sigmoid':
            decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

        self._initialize_weights()

    @staticmethod
    def _hidden_block(in_dim, out_dim, dropout_rate):
        """Linear -> LayerNorm -> ReLU -> Dropout, the repeated unit of both stacks."""
        return [
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
        ]

    def _initialize_weights(self):
        """Kaiming-uniform Linear weights, zero biases, identity-initialized LayerNorms."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_uniform_(m.weight, a=0, nonlinearity='relu')
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                init.ones_(m.weight)
                init.zeros_(m.bias)

    def forward(self, x):
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return x_recon, z


In [None]:
class UnifiedModelMultiBlockSAE(nn.Module):
    """ViT feature extractor + sparse autoencoder + linear classifier on the SAE code.

    Only the last pooled ViT feature ``f4`` is used. It is encoded by ``sae4``, and
    ``classifier`` maps the 768-d latent ``z4`` to ``num_classes`` logits.
    ``sae2``/``sae3`` are constructed (their parameters exist) but are not used in
    ``forward``. ``sae_dim_block2``/``sae_dim_block3``/``sae_dim_block4`` are accepted
    but ignored.

    ``forward(x)`` returns ``(class_logits, (x_recon4, z4), f4)``.
    """

    def __init__(self,
                 sae_dim_block2=128,
                 sae_dim_block3=256,
                 sae_dim_block4=256,
                 num_classes=7):
        super().__init__()
        feat_dim = 768

        self.feature_extractor = VitBasicFeatureExtractor(pretrained=True)

        self.sae2 = SparseAutoencoder(input_dim=feat_dim, latent_dim=feat_dim)
        self.sae3 = SparseAutoencoder(input_dim=feat_dim, latent_dim=feat_dim)
        self.sae4 = SparseAutoencoder(input_dim=feat_dim, latent_dim=feat_dim)

        self.classifier = nn.Sequential(nn.Linear(feat_dim, num_classes))

    def forward(self, x):
        _block2, _block3, f4 = self.feature_extractor(x)
        recon4, latent4 = self.sae4(f4)
        logits = self.classifier(latent4)
        return logits, (recon4, latent4), f4

## Losses

In [None]:
def irm_penalty(logits, labels):
    """IRMv1 objective pieces for one environment.

    Multiplies the logits by a dummy scalar ``w = 1`` that requires grad. It then
    returns the cross-entropy, the squared gradient of that loss w.r.t. ``w`` (the
    IRM penalty, kept differentiable via ``create_graph=True``), and a constant 0.0
    placeholder for a variance term.

    Returns:
        ``(loss_erm, penalty, 0.0)``
    """
    dummy_w = torch.tensor(1.0, requires_grad=True, device=logits.device)
    erm = F.cross_entropy(dummy_w * logits, labels)
    (d_erm,) = torch.autograd.grad(erm, [dummy_w], create_graph=True)
    return erm, d_erm.pow(2).sum(), 0.0

## Train and Evaluate Functions

### Train

In [None]:
def train_model_irm_sae_with_warmup_office(
    batch_size=16,
    num_warmup_epochs=5,
    num_main_epochs=20,
    lr=1e-4,
    lr_sae=1e-4,
    lambda_irm=1.0,
    lambda_sae_rec=1.0,
    lambda_sae_sparse=1e-4,
    device='cuda',
    loader=['P', 'C'],
    model=None,
    verbose=False,
    lambda_sparse=[1.0, 1.0, 1.0],
    lambda_reconstruction=[1.0, 1.0, 1.0],
    lambda_irm_pair=[1.0, 1.0, 0.0],
):
    """Build loaders for a source -> target transfer pair and run IRM + SAE training.

    Args:
        batch_size: Batch size of both source and target loaders.
        num_warmup_epochs: Accepted for API compatibility. There is no warm-up phase,
            so it is ignored.
        num_main_epochs: Epochs of the main IRM + SAE phase.
        lr, lr_sae, lambda_irm, lambda_sae_rec, lambda_sae_sparse, lambda_sparse,
        lambda_reconstruction, lambda_irm_pair, verbose: Forwarded unchanged to
            ``train_main_irm_multi_sae_office``.
        device: Torch device string.
        loader: ``[source, target]`` pair of single-letter domain codes (``None`` means
            ``['P', 'C']``).
            Office-31 codes: 'A' (amazon), 'D' (dslr), 'W' (webcam).
            PACS codes: 'P' (photo), 'C' (cartoon), 'A' (art_painting), 'S' (sketch).
            A pair containing 'D' or 'W' selects Office-31; any other pair selects PACS.
        model: Model to continue training. When ``None``, a fresh
            ``UnifiedModelMultiBlockSAE`` is built with 31 (Office-31) or 7 (PACS) classes.

    Returns:
        The trained model.

    Raises:
        ValueError: If a domain code is not valid for the selected benchmark.
    """
    source, target = loader if loader is not None else ['P', 'C']

    # 'A' exists in both benchmarks; 'D'/'W' exist only in Office-31, so they decide.
    office_domains = {'A': 'amazon', 'D': 'dslr', 'W': 'webcam'}
    pacs_domains = {'P': 'photo', 'C': 'cartoon', 'A': 'art_painting', 'S': 'sketch'}
    use_office = source in ('D', 'W') or target in ('D', 'W')
    domains = office_domains if use_office else pacs_domains
    benchmark = 'Office-31' if use_office else 'PACS'

    # Fail loudly rather than silently substituting another domain.
    for code in (source, target):
        if code not in domains:
            raise ValueError(
                f"Unknown {benchmark} domain code {code!r}; expected one of {sorted(domains)}"
            )

    if use_office:
        loader_amazon, loader_dslr, loader_webcam = get_office_data_loaders(batch_size)
        loaders_by_domain = {
            'amazon': loader_amazon,
            'dslr': loader_dslr,
            'webcam': loader_webcam,
        }
        num_classes = 31
    else:
        loader_art_painting, loader_cartoon, loader_photo, loader_sketch = get_pacs_data_loaders(batch_size)
        loaders_by_domain = {
            'art_painting': loader_art_painting,
            'cartoon': loader_cartoon,
            'photo': loader_photo,
            'sketch': loader_sketch,
        }
        num_classes = 7

    loader_source = loaders_by_domain[domains[source]]
    loader_target = loaders_by_domain[domains[target]]

    if model is None:
        model = UnifiedModelMultiBlockSAE(512, 1024, 2048, num_classes)

    print("===== Main Phase (IRM + SAE) =====")

    model = train_main_irm_multi_sae_office(model,
                                            loader_source=loader_source,
                                            test_loader=loader_target,
                                            num_epochs=num_main_epochs,
                                            lr=lr,
                                            lr_sae=lr_sae,
                                            lambda_irm=lambda_irm,
                                            lambda_sae_rec=lambda_sae_rec,
                                            lambda_sae_sparse=lambda_sae_sparse,
                                            device=device,
                                            verbose=verbose,
                                            lambda_sparse=lambda_sparse,
                                            lambda_reconstruction=lambda_reconstruction,
                                            lambda_irm_pair=lambda_irm_pair
                                            )

    return model

### Separate optimizers

In [None]:
def train_main_irm_multi_sae_office(
    model,
    loader_source,
    test_loader,
    num_epochs=20,
    lr=1e-4,
    lr_sae=1e-4,
    lambda_irm=1.0,
    lambda_sae_rec=1.0,
    lambda_sae_sparse=1e-4,
    lambda_sparse=[1.0, 1.0, 1.0],
    lambda_reconstruction=[1.0, 1.0, 1.0],
    lambda_irm_pair=[1.0, 1.0, 1.0],
    device='cuda',
    verbose=False
):
    """Jointly train the classifier (IRM objective) and the block-4 sparse autoencoder.

    Each step runs one forward pass through ``model``:
    - The class logits feed the IRM loss. Its gradients reach the feature extractor,
      ``sae4`` and the classifier.
    - The block-4 features are detached and re-encoded by ``model.sae4`` for the
      reconstruction + L1 sparsity loss, so that loss only updates the SAE.

    Two Adam optimizers are used: one over ``sae2``/``sae3``/``sae4`` (``lr_sae``) and
    one over every other parameter (``lr``). After each epoch, accuracy on
    ``test_loader`` is printed.

    Args:
        model: Must expose ``sae2``, ``sae3``, ``sae4`` submodules, and ``forward`` must
            return ``(logits, (x_recon4, z4), f4)``.
        loader_source: Training loader yielding ``(images, labels)``.
        test_loader: Loader passed to ``evaluate`` after every epoch.
        num_epochs: Number of passes over ``loader_source``.
        lr, lr_sae: Learning rates of the two optimizers.
        lambda_irm: Multiplier on the IRM gradient penalty.
        lambda_sae_rec, lambda_sae_sparse: Weights of the reconstruction and L1 terms.
        lambda_sparse, lambda_reconstruction: Three per-block weights each; only the third
            (block 4) entry is used.
        lambda_irm_pair: ``(w_erm, w_penalty, w_var)`` weights of the IRM terms.
        device: Torch device string.
        verbose: Print losses every 40 steps.

    Returns:
        The trained model (the same object, moved to ``device``).
    """
    import torch.optim as optim

    model.to(device)

    # Optimizer for the feature extractor and classifier (everything outside the SAEs).
    sae_prefixes = ('sae2', 'sae3', 'sae4')
    params_rest = [p for n, p in model.named_parameters() if not n.startswith(sae_prefixes)]
    optimizer_rest = optim.Adam(params_rest, lr=lr)
    # Optimizer for the sparse autoencoders (sae2, sae3, sae4).
    params_sae = list(model.sae2.parameters()) + list(model.sae3.parameters()) + list(model.sae4.parameters())
    optimizer_sae = optim.Adam(params_sae, lr=lr_sae)

    # Loss weights are loop-invariant; unpacking also checks each list has three entries.
    w1, w2, w3 = lambda_irm_pair
    _, _, lambda_s3 = lambda_sparse
    _, _, lambda_r3 = lambda_reconstruction

    steps_per_epoch = len(loader_source)

    for epoch in range(num_epochs):
        # evaluate() at the end of an epoch may leave the model in eval mode (dropout off);
        # re-enter train mode so every epoch actually trains in train mode.
        model.train()

        for step, (x_s, y_s) in enumerate(loader_source):
            x_s, y_s = x_s.to(device), y_s.to(device)

            # Single forward pass: logits for the IRM loss, block-4 features for the SAE loss.
            class_logits_s, _, f4_s = model(x_s)

            loss_erm_s, penalty_s, var_s = irm_penalty(class_logits_s, y_s)
            irm_loss = 0.5 * loss_erm_s
            irm_pen = 0.5 * penalty_s
            loss_irm = w1 * irm_loss + w2 * (lambda_irm * irm_pen) + w3 * var_s

            # Detached features: the SAE loss must not backprop into the feature extractor,
            # and no second forward pass under torch.no_grad() is needed to get them.
            f4p_s = f4_s.detach()
            x_recon4_s, z4_s = model.sae4(f4p_s)

            rec_loss_total = lambda_r3 * F.mse_loss(x_recon4_s, f4p_s)
            l1_sparsity = lambda_s3 * torch.mean(torch.abs(z4_s))
            sae_loss = lambda_sae_rec * rec_loss_total + lambda_sae_sparse * l1_sparsity

            loss = loss_irm + sae_loss

            optimizer_rest.zero_grad()
            optimizer_sae.zero_grad()
            loss.backward()
            optimizer_rest.step()
            optimizer_sae.step()

            if (step+1) % 40 == 0 and verbose:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{steps_per_epoch}], "
                      f"IRM Loss: {loss_irm.item():.4f}, SAE Loss: {sae_loss.item():.4f}")

        # Evaluation after each epoch
        test_acc = evaluate(model, test_loader, device=device)
        print(f"** End of Epoch {epoch+1}/{num_epochs} | Test Accuracy: {test_acc:.2f}% **")

    return model


### Evaluate

In [None]:
def evaluate_baseline(model, loader, device='cuda'):
    """Return top-1 accuracy (percent) of a baseline model on ``loader``.

    ``model(x)`` must return a 2-tuple whose first element is the class logits.
    The model runs in eval mode without gradients. The train/eval mode it had on entry
    is restored on exit, even if evaluation raises.

    Raises:
        ValueError: If ``loader`` yields no samples (accuracy would be undefined).
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                logits, _ = model(x)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("evaluate_baseline() received an empty loader; accuracy is undefined.")
    return 100.0 * correct / total


In [None]:
def evaluate(model, loader, device='cuda'):
    """Return top-1 accuracy (percent) of ``model`` on ``loader``.

    ``model(x)`` must return a 3-tuple whose first element is the class logits, as
    ``UnifiedModelMultiBlockSAE`` does. The model runs in eval mode without gradients.
    The train/eval mode it had on entry is restored on exit, so calling this between
    training epochs does not leave the model stuck in eval mode.

    Raises:
        ValueError: If ``loader`` yields no samples (accuracy would be undefined).
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                logits, _, _ = model(x)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined.")
    return 100.0 * correct / total

# MAIN

In [None]:
# Baseline: a pretrained ResNet-50 feature extractor wrapped in BaselineModel with a 31-way output.
# NOTE(review): BaselineModel (and train_baseline_model_office used below) are not defined in the
# visible notebook -- confirm they are defined or imported before this cell under Restart & Run All.
feature_extractor_baseline = ResNetFeatureExtractor(pretrained=True)
baseline_model = BaselineModel(feature_extractor=feature_extractor_baseline, num_classes=31)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 201MB/s]


In [None]:
# get_office_data_loaders returns (amazon, dslr, webcam) in that order; unpacking it as
# (amazon, webcam, dslr) would silently swap the DSLR and Webcam domains.
loader_amazon, loader_dslr, loader_webcam = get_office_data_loaders(32)

Downloading from https://www.kaggle.com/api/v1/datasets/download/xixuhu/office31?dataset_version_number=1...


100%|██████████| 75.9M/75.9M [00:04<00:00, 16.0MB/s]

Extracting files...







### Baseline models

#### A -> W

In [None]:
# Start from a fresh pretrained baseline so this pair's result does not depend on which
# other transfer pairs were trained earlier in the session.
baseline_model = BaselineModel(feature_extractor=ResNetFeatureExtractor(pretrained=True),
                               num_classes=31)
baseline_model = train_baseline_model_office(baseline_model,
                         loader_amazon,
                         loader_webcam,
                         num_epochs=20,
                         lr=3e-5,
                         lambda_irm=4.0,
                         device='cuda')

RuntimeError: mat1 and mat2 shapes cannot be multiplied (458752x7 and 768x31)

#### A -> D

In [None]:
# Start from a fresh pretrained baseline so this pair's result does not depend on which
# other transfer pairs were trained earlier in the session.
baseline_model = BaselineModel(feature_extractor=ResNetFeatureExtractor(pretrained=True),
                               num_classes=31)
baseline_model = train_baseline_model_office(baseline_model,
                         loader_amazon,
                         loader_dslr,
                         num_epochs=20,
                         lr=3e-5,
                         lambda_irm=4.0,
                         device='cuda')

#### W -> A

In [None]:
# Start from a fresh pretrained baseline: reusing the previous cells' model would mean it
# was already trained on Amazon, this pair's target domain.
baseline_model = BaselineModel(feature_extractor=ResNetFeatureExtractor(pretrained=True),
                               num_classes=31)
baseline_model = train_baseline_model_office(baseline_model,
                         loader_webcam,
                         loader_amazon,
                         num_epochs=20,
                         lr=3e-5,
                         lambda_irm=4.0,
                         device='cuda')

#### W -> D

In [None]:
# Start from a fresh pretrained baseline so this pair's result does not depend on which
# other transfer pairs were trained earlier in the session.
baseline_model = BaselineModel(feature_extractor=ResNetFeatureExtractor(pretrained=True),
                               num_classes=31)
baseline_model = train_baseline_model_office(baseline_model,
                         loader_webcam,
                         loader_dslr,
                         num_epochs=20,
                         lr=3e-5,
                         lambda_irm=4.0,
                         device='cuda')

#### D -> W

In [None]:
# Start from a fresh pretrained baseline: reusing the previous cells' model would mean it
# was already trained on Webcam, this pair's target domain.
baseline_model = BaselineModel(feature_extractor=ResNetFeatureExtractor(pretrained=True),
                               num_classes=31)
baseline_model = train_baseline_model_office(baseline_model,
                         loader_dslr,
                         loader_webcam,
                         num_epochs=20,
                         lr=3e-5,
                         lambda_irm=4.0,
                         device='cuda')

#### D -> A

In [None]:
# Start from a fresh pretrained baseline: reusing the previous cells' model would mean it
# was already trained on Amazon, this pair's target domain.
baseline_model = BaselineModel(feature_extractor=ResNetFeatureExtractor(pretrained=True),
                               num_classes=31)
baseline_model = train_baseline_model_office(baseline_model,
                         loader_dslr,
                         loader_amazon,
                         num_epochs=20,
                         lr=3e-5,
                         lambda_irm=4.0,
                         device='cuda')

### Ours

In [None]:
# Office-31 (source, target) pairs for the SAE runs below; A -> W is not in this list.
combinations = [["A", "D"], ["W", "A"], ["W", "D"], ["D", "W"], ["D", "A"]]

## A -> D **DO NOT RUN**

In [None]:
combination = ["A", "D"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse = [0.1, 0.1, 0.1]

# Hyper-parameters for this pair; no `model` is passed, so the helper builds a fresh one.
sae_run_kwargs = dict(
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=6,
    lr=1e-6,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    device='cuda',
    verbose=True,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sparse=lambda_sparse,
    loader=combination,
)
trained_model_sae_office = train_model_irm_sae_with_warmup_office(**sae_run_kwargs)

print("--------------------------------------------")


------------Combination: ['A', 'D']--------------
===== Main Phase (IRM + SAE) =====
Epoch [1/6], Step [40/89], IRM Loss: 7.7358, SAE Loss: 5.7715
Epoch [1/6], Step [80/89], IRM Loss: 5.2013, SAE Loss: 6.7702
** End of Epoch 1/6 | Test Accuracy: 89.43% **
Epoch [2/6], Step [40/89], IRM Loss: 0.0913, SAE Loss: 5.4998
Epoch [2/6], Step [80/89], IRM Loss: 0.3113, SAE Loss: 5.2867
** End of Epoch 2/6 | Test Accuracy: 91.07% **
Epoch [3/6], Step [40/89], IRM Loss: 0.0932, SAE Loss: 6.8329
Epoch [3/6], Step [80/89], IRM Loss: 0.1265, SAE Loss: 6.0455
** End of Epoch 3/6 | Test Accuracy: 91.45% **
Epoch [4/6], Step [40/89], IRM Loss: 0.0782, SAE Loss: 5.5913
Epoch [4/6], Step [80/89], IRM Loss: 0.0847, SAE Loss: 5.7738
** End of Epoch 4/6 | Test Accuracy: 91.07% **
Epoch [5/6], Step [40/89], IRM Loss: 0.0900, SAE Loss: 5.3271
Epoch [5/6], Step [80/89], IRM Loss: 0.0715, SAE Loss: 6.0006
** End of Epoch 5/6 | Test Accuracy: 90.69% **
Epoch [6/6], Step [40/89], IRM Loss: 0.0639, SAE Loss: 6.461

## A -> W **DO NOT RUN**

In [None]:
# Explicit pair: without `loader=`, the helper falls back to its default ['P', 'C'],
# so this cell would not run A -> W at all.
combination = ["A", "W"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=6,
            lr=1e-5,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")



model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

===== Main Phase (IRM + SAE) =====
Epoch [1/6], Step [40/89], IRM Loss: 14.9789, SAE Loss: 5.4896
Epoch [1/6], Step [80/89], IRM Loss: 13.1903, SAE Loss: 4.1527
** End of Epoch 1/6 | Test Accuracy: 74.90% **
Epoch [2/6], Step [40/89], IRM Loss: 2.5375, SAE Loss: 7.0794
Epoch [2/6], Step [80/89], IRM Loss: 1.5245, SAE Loss: 7.0305
** End of Epoch 2/6 | Test Accuracy: 88.76% **
Epoch [3/6], Step [40/89], IRM Loss: 1.2606, SAE Loss: 6.8434
Epoch [3/6], Step [80/89], IRM Loss: 0.7653, SAE Loss: 5.3209
** End of Epoch 3/6 | Test Accuracy: 91.77% **
Epoch [4/6], Step [40/89], IRM Loss: 0.4677, SAE Loss: 5.4863
Epoch [4/6], Step [80/89], IRM Loss: 0.5548, SAE Loss: 6.4527
** End of Epoch 4/6 | Test Accuracy: 91.57% **
Epoch [5/6], Step [40/89], IRM Loss: 0.1260, SAE Loss: 5.8335
Epoch [5/6], Step [80/89], IRM Loss: 0.1268, SAE Loss: 6.7935
** End of Epoch 5/6 | Test Accuracy: 93.37% **
Epoch [6/6], Step [40/89], IRM Loss: 0.1516, SAE Loss: 5.5576
Epoch [6/6], Step [80/89], IRM Loss: 0.1460, S

## W -> A **DO NOT RUN**

In [None]:
combination = ["W", "A"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

# Train from a fresh model. Continuing from the previous cell's model would carry over
# training on other domains (possibly this pair's target), making the result depend on
# cell order and leaking target information.
trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=30,
            lr=1e-6,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['W', 'A']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/30 | Test Accuracy: 72.31% **
** End of Epoch 2/30 | Test Accuracy: 72.99% **
** End of Epoch 3/30 | Test Accuracy: 74.94% **
** End of Epoch 4/30 | Test Accuracy: 75.54% **
** End of Epoch 5/30 | Test Accuracy: 75.90% **
** End of Epoch 6/30 | Test Accuracy: 76.36% **
** End of Epoch 7/30 | Test Accuracy: 76.46% **
** End of Epoch 8/30 | Test Accuracy: 76.64% **
** End of Epoch 9/30 | Test Accuracy: 76.75% **
** End of Epoch 10/30 | Test Accuracy: 76.93% **
** End of Epoch 11/30 | Test Accuracy: 76.93% **
** End of Epoch 12/30 | Test Accuracy: 77.00% **
** End of Epoch 13/30 | Test Accuracy: 77.03% **
** End of Epoch 14/30 | Test Accuracy: 77.07% **
** End of Epoch 15/30 | Test Accuracy: 77.14% **
** End of Epoch 16/30 | Test Accuracy: 77.14% **
** End of Epoch 17/30 | Test Accuracy: 77.14% **
** End of Epoch 18/30 | Test Accuracy: 77.25% **
** End of Epoch 19/30 | Test Accuracy: 77.

## W -> D **DO** **NOT** **RUN**

In [None]:
combination = ["W", "D"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

# Train from a fresh model. Continuing from the previous cell's model would carry over
# training on other domains, making the result depend on cell order.
trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=15,
            lr=1e-6,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['W', 'D']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/15 | Test Accuracy: 97.23% **
** End of Epoch 2/15 | Test Accuracy: 97.23% **
** End of Epoch 3/15 | Test Accuracy: 97.99% **
** End of Epoch 4/15 | Test Accuracy: 98.24% **
** End of Epoch 5/15 | Test Accuracy: 98.36% **
** End of Epoch 6/15 | Test Accuracy: 98.49% **
** End of Epoch 7/15 | Test Accuracy: 98.62% **
** End of Epoch 8/15 | Test Accuracy: 98.62% **
** End of Epoch 9/15 | Test Accuracy: 98.87% **
** End of Epoch 10/15 | Test Accuracy: 98.87% **
** End of Epoch 11/15 | Test Accuracy: 98.87% **
** End of Epoch 12/15 | Test Accuracy: 98.87% **
** End of Epoch 13/15 | Test Accuracy: 98.87% **
** End of Epoch 14/15 | Test Accuracy: 98.87% **
** End of Epoch 15/15 | Test Accuracy: 98.87% **
--------------------------------------------


## D -> W **DO** **NOT** **RUN**

In [None]:
combination = ["D", "W"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse = [0.1, 0.1, 0.1]

# Hyper-parameters for this pair; no `model` is passed, so the helper builds a fresh one.
sae_run_kwargs = dict(
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=20,
    lr=1e-5,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    device='cuda',
    verbose=True,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sparse=lambda_sparse,
    loader=combination,
)
trained_model_sae_office = train_model_irm_sae_with_warmup_office(**sae_run_kwargs)

print("--------------------------------------------")


------------Combination: ['D', 'W']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/20 | Test Accuracy: 41.37% **
** End of Epoch 2/20 | Test Accuracy: 54.82% **
** End of Epoch 3/20 | Test Accuracy: 49.60% **
** End of Epoch 4/20 | Test Accuracy: 74.90% **
** End of Epoch 5/20 | Test Accuracy: 96.18% **
** End of Epoch 6/20 | Test Accuracy: 99.80% **
** End of Epoch 7/20 | Test Accuracy: 100.00% **
** End of Epoch 8/20 | Test Accuracy: 100.00% **
** End of Epoch 9/20 | Test Accuracy: 100.00% **
** End of Epoch 10/20 | Test Accuracy: 100.00% **
** End of Epoch 11/20 | Test Accuracy: 100.00% **
** End of Epoch 12/20 | Test Accuracy: 100.00% **
** End of Epoch 13/20 | Test Accuracy: 100.00% **
** End of Epoch 14/20 | Test Accuracy: 100.00% **
** End of Epoch 15/20 | Test Accuracy: 100.00% **
** End of Epoch 16/20 | Test Accuracy: 100.00% **
** End of Epoch 17/20 | Test Accuracy: 100.00% **
** End of Epoch 18/20 | Test Accuracy: 100.00% **
** End of Epoch 19/20 | Test A

## D -> A **DO** **NOT** **RUN**

In [None]:
combination = ["D", "A"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse=[0.1, 0.1, 0.1]

# Train from a fresh model. Continuing from the previous cell's model would carry over
# training on other domains (possibly this pair's target), making the result depend on
# cell order and leaking target information.
trained_model_sae_office = train_model_irm_sae_with_warmup_office(
            batch_size=32,
            num_warmup_epochs=4,
            num_main_epochs=20,
            lr=1e-6,
            lr_sae=5e-5,
            lambda_irm=1.0,
            lambda_sae_rec=2.0,
            lambda_sae_sparse=2e-4,
            device='cuda',
            verbose=True,
            lambda_irm_pair=lambda_irm_pair,
            lambda_sparse=lambda_sparse,
            loader=combination
        )

print("--------------------------------------------")


------------Combination: ['D', 'A']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/20 | Test Accuracy: 21.05% **
** End of Epoch 2/20 | Test Accuracy: 28.51% **
** End of Epoch 3/20 | Test Accuracy: 26.09% **
** End of Epoch 4/20 | Test Accuracy: 47.07% **
** End of Epoch 5/20 | Test Accuracy: 69.29% **
** End of Epoch 6/20 | Test Accuracy: 75.33% **
** End of Epoch 7/20 | Test Accuracy: 75.33% **
** End of Epoch 8/20 | Test Accuracy: 75.75% **
** End of Epoch 9/20 | Test Accuracy: 75.51% **
** End of Epoch 10/20 | Test Accuracy: 75.36% **
** End of Epoch 11/20 | Test Accuracy: 75.26% **
** End of Epoch 12/20 | Test Accuracy: 75.36% **
** End of Epoch 13/20 | Test Accuracy: 75.19% **
** End of Epoch 14/20 | Test Accuracy: 75.08% **
** End of Epoch 15/20 | Test Accuracy: 75.22% **
** End of Epoch 16/20 | Test Accuracy: 75.36% **
** End of Epoch 17/20 | Test Accuracy: 75.29% **
** End of Epoch 18/20 | Test Accuracy: 75.26% **
** End of Epoch 19/20 | Test Accuracy: 75.

In [None]:
combination = ["D", "A"]
print(f"------------Combination: {combination}--------------")
lambda_irm_pair = [10.0, 4.0, 0.0]
lambda_sparse = [0.1, 0.1, 0.1]

# Hyper-parameters for this pair; no `model` is passed, so the helper builds a fresh one.
sae_run_kwargs = dict(
    batch_size=32,
    num_warmup_epochs=4,
    num_main_epochs=30,
    lr=65e-7,
    lr_sae=5e-5,
    lambda_irm=1.0,
    lambda_sae_rec=2.0,
    lambda_sae_sparse=2e-4,
    device='cuda',
    verbose=True,
    lambda_irm_pair=lambda_irm_pair,
    lambda_sparse=lambda_sparse,
    loader=combination,
)
trained_model_sae_office = train_model_irm_sae_with_warmup_office(**sae_run_kwargs)

print("--------------------------------------------")



------------Combination: ['D', 'A']--------------
===== Main Phase (IRM + SAE) =====
** End of Epoch 1/30 | Test Accuracy: 28.93% **
** End of Epoch 2/30 | Test Accuracy: 32.34% **
** End of Epoch 3/30 | Test Accuracy: 29.36% **
** End of Epoch 4/30 | Test Accuracy: 48.21% **
** End of Epoch 5/30 | Test Accuracy: 65.14% **
** End of Epoch 6/30 | Test Accuracy: 72.95% **
** End of Epoch 7/30 | Test Accuracy: 73.91% **
** End of Epoch 8/30 | Test Accuracy: 73.80% **
** End of Epoch 9/30 | Test Accuracy: 74.01% **
** End of Epoch 10/30 | Test Accuracy: 73.91% **
** End of Epoch 11/30 | Test Accuracy: 74.01% **
** End of Epoch 12/30 | Test Accuracy: 74.09% **
** End of Epoch 13/30 | Test Accuracy: 74.19% **
** End of Epoch 14/30 | Test Accuracy: 74.16% **
** End of Epoch 15/30 | Test Accuracy: 74.12% **
** End of Epoch 16/30 | Test Accuracy: 74.16% **
** End of Epoch 17/30 | Test Accuracy: 74.16% **
** End of Epoch 18/30 | Test Accuracy: 74.09% **
** End of Epoch 19/30 | Test Accuracy: 74.

KeyboardInterrupt: 

## Base Vit accuracies

In [None]:
from torchvision.models import vit_b_16
from tqdm.notebook import tqdm

def get_base_accuracy(loader=None, num_classes=7, epochs=10, lr=1e-5, device="cuda"):
    """Fine-tune a torchvision ViT-B/16 on one PACS domain and report accuracy on another.

    Each epoch trains on the source loader, prints the last batch's loss, then evaluates
    on the target loader and prints the accuracy.

    Args:
        loader: ``[source, target]`` PACS domain codes: 'P' (photo), 'C' (cartoon),
            'A' (art_painting), 'S' (sketch). ``None`` means ``['P', 'C']``.
        num_classes: Output size of the replacement classification head (PACS has 7).
        epochs: Number of fine-tuning epochs.
        lr: Adam learning rate.
        device: Torch device string.

    Returns:
        Target-domain accuracy (percent) after the final epoch, or ``None`` if
        ``epochs`` is 0.

    Raises:
        ValueError: On an unknown domain code or an empty source/target loader.
    """
    domains = {'P': 'photo', 'C': 'cartoon', 'A': 'art_painting', 'S': 'sketch'}
    source, target = loader if loader is not None else ['P', 'C']
    # Fail loudly instead of silently substituting another domain.
    for code in (source, target):
        if code not in domains:
            raise ValueError(f"Unknown PACS domain code {code!r}; expected one of {sorted(domains)}")

    loader_art_painting, loader_cartoon, loader_photo, loader_sketch = get_pacs_data_loaders(32)
    loaders_by_domain = {
        'art_painting': loader_art_painting,
        'cartoon': loader_cartoon,
        'photo': loader_photo,
        'sketch': loader_sketch,
    }
    loader_source = loaders_by_domain[domains[source]]
    loader_target = loaders_by_domain[domains[target]]
    # The loss print and accuracy division below need at least one batch/sample.
    if len(loader_source) == 0 or len(loader_target) == 0:
        raise ValueError("Source and target loaders must be non-empty.")

    model = vit_b_16(weights="DEFAULT")
    model.heads.head = nn.Linear(in_features=768, out_features=num_classes, bias=True)
    model = model.to(device)

    optm = torch.optim.Adam(model.parameters(), lr=lr)

    acc = None
    for epoch in range(epochs):
        model.train()
        for x, y in loader_source:
            x, y = x.to(device), y.to(device)

            optm.zero_grad()
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            optm.step()

        print(f"Epoch: {epoch+1}/{epochs} | Loss: {loss.item()}")

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in loader_target:
                x, y = x.to(device), y.to(device)

                logits = model(x)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        acc = 100.0 * correct / total
        print(f"Epoch: {epoch+1}/{epochs} | Test Accuracy: {acc}")

    return acc



In [None]:
# All 12 ordered (source, target) PACS transfer pairs, in the order they are reported.
combinations = [
    ["P", "C"], ["P", "A"], ["P", "S"],
    ["C", "A"], ["C", "S"], ["C", "P"],
    ["A", "P"], ["A", "C"], ["A", "S"],
    ["S", "P"], ["S", "C"], ["S", "A"],
]

separator = "--------------------------------------------"
for combination in combinations:
    print(f"------------Combination: {combination}--------------")
    get_base_accuracy(loader=combination)
    print(separator)

------------Combination: ['P', 'C']--------------


README.md:   0%|          | 0.00/3.89k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 216MB/s]


Epoch: 1/10 | Loss: 0.6663272976875305
Epoch: 1/10 | Test Accuracy: 24.914675767918087
Epoch: 2/10 | Loss: 0.13300557434558868
Epoch: 2/10 | Test Accuracy: 25.68259385665529
Epoch: 3/10 | Loss: 0.025642847642302513
Epoch: 3/10 | Test Accuracy: 28.370307167235495
Epoch: 4/10 | Loss: 0.011229739524424076
Epoch: 4/10 | Test Accuracy: 29.351535836177476
Epoch: 5/10 | Loss: 0.018107669427990913
Epoch: 5/10 | Test Accuracy: 29.948805460750854
Epoch: 6/10 | Loss: 0.007119265850633383
Epoch: 6/10 | Test Accuracy: 30.247440273037544
Epoch: 7/10 | Loss: 0.007505128625780344
Epoch: 7/10 | Test Accuracy: 30.076791808873722
Epoch: 8/10 | Loss: 0.009765692055225372
Epoch: 8/10 | Test Accuracy: 30.418088737201366
Epoch: 9/10 | Loss: 0.003912889864295721
Epoch: 9/10 | Test Accuracy: 30.716723549488055
Epoch: 10/10 | Loss: 0.005921937525272369
Epoch: 10/10 | Test Accuracy: 30.546075085324233
--------------------------------------------
------------Combination: ['P', 'A']--------------
Epoch: 1/10 | Los