**Why use CCSA?**

*   **Semantic Alignment:** It minimizes the distance between samples of the same class across domains while maximizing the separation between samples of different classes. This ensures that features are well-aligned semantically, which is crucial for domain adaptation (DA).
*   **Unified Framework:** It combines classification loss with contrastive loss, providing a comprehensive approach to supervised domain adaptation and generalization.
*   **Fast Adaptation:** The method demonstrates a high "speed" of adaptation, meaning it can quickly achieve strong performance even with limited labeled target data.






**Why is CCSA good for supervised domain adaptation?**

*   **Leveraging Supervision:** Unlike unsupervised approaches, supervised domain adaptation benefits from labeled data in both the source and target domains. CCSA Loss explicitly uses these labels to enforce semantic alignment, which leads to better classification performance.
*   **Class Discrimination:** By maximizing inter-class variance, CCSA Loss ensures that different classes are well-separated in the feature space. This helps the model maintain high discriminative power across domains, making it robust to variations.



**Imports**

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


import torchvision.models as models
from torchvision import datasets, transforms

from torch.utils.data import WeightedRandomSampler
from torch.utils.data import DataLoader

from collections import Counter
from torch.utils.data import random_split

In [2]:
# Pick the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

**Data Processing**

In [3]:
import kagglehub

# Fetch the latest version of the "xixuhu/office31" dataset and remember the
# location that kagglehub reports for the downloaded files.
path = kagglehub.dataset_download("xixuhu/office31")

print(f"Path to dataset files: {path}")

Path to dataset files: /kaggle/input/office31


In [4]:
import os
def walk_through_dir(dir_path):
    """Print one summary line per directory found under ``dir_path``.

    Each line reports how many sub-directories and how many files (reported as
    "images") the directory contains. The tree is traversed with ``os.walk``.
    """
    for current_dir, subdirs, files in os.walk(dir_path):
        n_dirs, n_files = len(subdirs), len(files)
        print(f"There are {n_dirs} directories and {n_files} images in '{current_dir}'.")


# uncomment to find the path of the 3 datasets
# walk_through_dir(path)

In [5]:
# Build the domain folders from the location returned by kagglehub (`path`)
# instead of a hard-coded cache directory, so this cell works wherever the
# dataset was actually downloaded.
# NOTE(review): assumes the download contains an "Office-31" folder holding
# "amazon", "dslr" and "webcam" — run walk_through_dir(path) to confirm.
office31_root = os.path.join(path, "Office-31")
amazon_path = os.path.join(office31_root, "amazon")
dslr_path = os.path.join(office31_root, "dslr")
webcam_path = os.path.join(office31_root, "webcam")

In [6]:
# Settings shared by both preprocessing pipelines.
image_size = (224, 224)
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Deterministic pipeline: resize, convert to tensor, normalize per channel.
simple_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

# Augmenting pipeline: same resize/normalize steps with a random flip, rotation
# and colour jitter applied before the tensor conversion.
target_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

# Amazon images use the plain pipeline, webcam images the augmenting one.
amazon_data = datasets.ImageFolder(root=amazon_path, transform=simple_transform)
# dslr_data = datasets.ImageFolder(dslr_path, transform=simple_transform)
webcam_data = datasets.ImageFolder(root=webcam_path, transform=target_transform)


In [7]:
# Hold out 20% of the amazon images as a validation set; the remainder forms
# the amazon training split.
n_amazon = len(amazon_data)
val_size = int(n_amazon * 0.2)
train_size = n_amazon - val_size
amazon_dataset, amazon_val_dataset = random_split(amazon_data, [train_size, val_size])

In [8]:
batch_size = 32

# Every webcam image gets the same weight. The sampler draws with replacement
# and yields as many samples per pass as the amazon training split contains.
target_weights = [1.0 for _ in range(len(webcam_data))]
target_sampler = WeightedRandomSampler(
    weights=target_weights,
    num_samples=len(amazon_dataset),
    replacement=True,
)

In [9]:
# DataLoaders
batch_size = 32

# Both training loaders share the batch size and drop an incomplete last batch.
train_loader_options = dict(batch_size=batch_size, drop_last=True)

amazon_dataloader = DataLoader(amazon_dataset, shuffle=True, **train_loader_options)
# dslr_dataloader = DataLoader(dslr_data, batch_size=batch_size, shuffle=True, drop_last=True)
webcam_dataloader = DataLoader(webcam_data, sampler=target_sampler, **train_loader_options)

# The validation loader keeps its (possibly smaller) final batch.
amazon_val_dataloader = DataLoader(amazon_val_dataset, batch_size=batch_size, shuffle=True)



In [10]:
# Number of batches each loader yields per pass: amazon train, webcam, amazon validation.
len(amazon_dataloader), len(webcam_dataloader), len(amazon_val_dataloader)

(70, 70, 18)

In [11]:
def count_class_instances(dataset):
    """
    Tally how many samples of each class a dataset contains and print the tally.

    The dataset is iterated item by item, so every sample is loaded once.

    Args:
        dataset (Dataset): Any iterable of (data, label) pairs, e.g. a PyTorch Dataset.

    Returns:
        Counter: Maps each label seen to the number of times it occurred.
    """
    class_counts = Counter()
    for _, label in dataset:
        class_counts[label] += 1

    # Report the distribution in order of first appearance.
    print("Class distribution:")
    for cls in class_counts:
        print(f"Class {cls}: {class_counts[cls]} instances")

    return class_counts

# amazon_class_counts = count_class_instances(amazon_data)
# dslr_class_counts = count_class_instances(dslr_data)
# webcam_class_counts = count_class_instances(webcam_data)


In [12]:
# Peek at a single training batch to sanity-check tensor shapes and label values.
first_batch = next(iter(amazon_dataloader), None)
if first_batch is not None:
    images, labels = first_batch
    # Shape is [batch, channels, height, width], e.g. [4, 3, 224, 224] for 4 RGB images.
    print("Batch of images shape:", images.shape)
    # One integer class index per image, e.g. tensor([0, 1, 0, 2]).
    print("Batch of labels:", labels)

Batch of images shape: torch.Size([32, 3, 224, 224])
Batch of labels: tensor([25, 16,  1,  3,  3, 25,  8, 17, 23, 21, 21, 18, 28,  9, 25, 25, 27,  6,
         0,  8, 18,  2, 18, 10, 24, 15, 10, 26, 28,  8,  2,  6])


**Models**

Feature extractor


---


The backbone of your model will extract meaningful features from input images




In [13]:
class FeatureExtractor(nn.Module):
    """Pretrained ResNet-18 with its last child module removed.

    Only the last three parameter tensors of the remaining network stay
    trainable; every earlier parameter is frozen.
    """

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated in recent torchvision releases in
        # favour of the `weights=` argument. Fall back to the old flag when the
        # weights enum does not exist (older torchvision).
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            backbone = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.resnet18(pretrained=True)

        # Keep every child module except the final one (the classification head).
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])

        # Freeze all but the last three parameter *tensors*. The slice runs over
        # individual weight/bias tensors, not over whole layers.
        for param in list(self.feature_extractor.parameters())[:-3]:
            param.requires_grad = False

    def forward(self, x):
        """Return the backbone features for the image batch ``x``."""
        return self.feature_extractor(x)

Embedding Layer


---

After extracting features, an embedding layer maps the features into a lower-dimensional space where domain alignment takes place. This layer can be a simple fully connected layer.


In [14]:
class EmbeddingLayer(nn.Module):
    """Linear projection followed by ReLU and dropout (p=0.5)."""

    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, embedding_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Project ``x``, apply ReLU, then dropout (active only in training mode)."""
        projected = self.fc(x)
        return self.dropout(F.relu(projected))

Combined Architecture

---



In [15]:
class DomainAdaptationModel(nn.Module):
    """Backbone, embedding layer and linear classifier chained into one network.

    Note that ``forward`` returns the classifier output (``num_classes`` values
    per sample), not the intermediate embedding.
    """

    def __init__(self, embedding_dim=128, num_classes = 31):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        # The flattened backbone output is expected to hold 512 values per sample (ResNet18).
        self.embedding_layer = EmbeddingLayer(input_dim=512, embedding_dim=embedding_dim)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Run backbone -> flatten -> embedding layer -> classifier on ``x``."""
        backbone_out = self.feature_extractor(x)
        # Collapse everything after the batch dimension into one feature vector.
        flat = backbone_out.view(backbone_out.shape[0], -1)
        embedded = self.embedding_layer(flat)
        return self.classifier(embedded)


**CCSA Loss**

CCSA Loss works by minimizing the distance between samples of the same class from different domains while maximizing the distance between samples of different classes. To do so, it computes pairwise distances between embeddings. It aligns features across domains by minimizing intra-class variance and maximizing inter-class variance, which makes it well suited to supervised domain adaptation tasks.

In [40]:
class CCSALoss(nn.Module):
    """Contrastive loss over all pairs of samples in a batch.

    * Pairs with the same label but different domains are pulled together:
      their mean Euclidean distance is added to the loss.
    * Pairs with different labels (in any domain) are pushed apart: the mean of
      ``max(margin - distance, 0)`` is added to the loss.

    A term whose pair set is empty contributes zero.
    """

    def __init__(self, margin=2):
        super().__init__()
        self.margin = margin

    @staticmethod
    def _masked_mean(values, mask):
        """Mean of ``values[mask]``; the (zero) sum is returned when the mask is empty."""
        selected_sum = values[mask].sum()
        n_selected = int(mask.sum())
        return selected_sum / n_selected if n_selected > 0 else selected_sum

    def forward(self, embeddings, labels, domains):
        """Compute the loss for one batch.

        Args:
            embeddings: One row per sample; Euclidean distances are taken between rows.
            labels: Class label per sample.
            domains: Domain indicator per sample.

        Returns:
            Scalar tensor: attraction term plus repulsion term.
        """
        # All-pairs Euclidean distances between rows.
        distances = torch.cdist(embeddings, embeddings, p=2)

        labels_match = labels[:, None] == labels[None, :]
        domains_differ = domains[:, None] != domains[None, :]

        # Attraction: same class, different domain.
        attract = self._masked_mean(distances, labels_match & domains_differ)

        # Repulsion: different class, penalised only while closer than the margin.
        hinge = (self.margin - distances).clamp(min=0)
        repel = self._masked_mean(hinge, ~labels_match)

        return attract + repel



In [41]:
# Hyperparameters for Adam and for the step-wise learning-rate decay.
learning_rate = 0.005
weight_decay = 1e-4
lr_step_size = 5
lr_gamma = 0.1

model = DomainAdaptationModel()
criterion = CCSALoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
# Multiply the learning rate by `lr_gamma` after every `lr_step_size` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=lr_step_size,
    gamma=lr_gamma,
)


**Training Loop**

In [42]:
def train_loop(
    model,
    ccsa_loss,
    optimizer,
    source_dataloader,
    target_dataloader,
    num_epochs,
    device,
    scheduler=None,
):
    """
    Train a domain adaptation model with two datasets (source and target).

    In each step one source batch and one target batch are concatenated, tagged
    with domain ids (0 = source, 1 = target) and passed through ``model``. The
    model output is then handed to ``ccsa_loss`` together with the labels and
    domain ids. The two loaders are consumed in lockstep, so an epoch ends as
    soon as the shorter loader is exhausted.

    NOTE(review): ``ccsa_loss`` is the only term optimised here; no separate
    classification loss is added — confirm that this is intended.

    Args:
        model (nn.Module): The domain adaptation model.
        ccsa_loss (nn.Module): The CCSA loss function, called as
            ``ccsa_loss(outputs, labels, domains)``.
        optimizer (torch.optim.Optimizer): Optimizer for the model.
        source_dataloader (DataLoader): Dataloader for the source domain.
        target_dataloader (DataLoader): Dataloader for the target domain.
        num_epochs (int): Number of epochs to train.
        device (torch.device): The device (CPU/GPU) for training.
        scheduler (optional): Learning-rate scheduler stepped once per epoch.
            When omitted, a module-level ``scheduler`` is used if one exists
            (the previous behaviour); otherwise no scheduler is stepped.

    Returns:
        None. The mean loss per processed batch is printed after every epoch.
    """
    if scheduler is None:
        # Backward compatibility: this function used to step a global `scheduler`.
        scheduler = globals().get("scheduler")

    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0

        # zip() stops at the shorter loader, so count the batches actually used.
        for source_batch, target_batch in zip(source_dataloader, target_dataloader):
            source_data, source_labels = source_batch
            target_data, target_labels = target_batch

            # Domain ids: 0 marks source samples, 1 marks target samples.
            source_domains = torch.zeros(source_data.size(0), dtype=torch.long)
            target_domains = torch.ones(target_data.size(0), dtype=torch.long)

            # Combine source and target data and move everything to the device once.
            combined_data = torch.cat([source_data, target_data], dim=0).to(device)
            combined_labels = torch.cat([source_labels, target_labels], dim=0).to(device)
            combined_domains = torch.cat([source_domains, target_domains], dim=0).to(device)

            # Forward pass
            outputs = model(combined_data)

            # Compute CCSA Loss
            loss = ccsa_loss(outputs, combined_labels, combined_domains)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Learning rate adjustment after the epoch
        if scheduler is not None:
            scheduler.step()

        # Average over processed batches; max() avoids dividing by zero when
        # either loader is empty.
        mean_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}")


In [43]:
# Train for ten epochs. The webcam loader fills train_loop's "source" slot and
# the amazon training loader its "target" slot.
train_loop(
    model=model,
    ccsa_loss=criterion,
    optimizer=optimizer,
    source_dataloader=webcam_dataloader,
    target_dataloader=amazon_dataloader,
    num_epochs=10,
    device=device,
)

Epoch 1/10, Loss: 1.981586091859
Epoch 2/10, Loss: 1.830453886304583
Epoch 3/10, Loss: 1.8580328992434911
Epoch 4/10, Loss: 1.8459636688232421
Epoch 5/10, Loss: 1.8349796022687639
Epoch 6/10, Loss: 1.8140858394759043
Epoch 7/10, Loss: 1.786916880948203
Epoch 8/10, Loss: 1.7980996659823827
Epoch 9/10, Loss: 1.7911572643688747
Epoch 10/10, Loss: 1.7589624319757735


**Testing loop**

In [44]:
def test_loop(model, dataloader,criterion, device):
  model.eval()
  total_loss = 0.0
  correct = 0
  total = 0

  with torch.no_grad():
    for batch in dataloader:
      data, labels = batch
      data, labels = data.to(device), labels.to(device)

      #forward pass
      outputs = model(data)
      loss = criterion(outputs, labels)

      # accumate loss
      total_loss += loss.item()

      # calculate accuracy
      _, predicted = torch.max(outputs, dim=1)
      # print('Predicted:', predicted, 'True label:', labels)
      correct += (predicted == labels).sum().item()
      total += labels.size(0)


  # compute average loss and accuracy
  avg_loss = total_loss / len(dataloader)
  accuracy = correct / total * 100

  print(f"Validation Loss: {avg_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
  return avg_loss, accuracy

In [46]:
# # Assuming `amazon_val_dataloader` and necessary components are defined
# avg_loss, accuracy = test_loop(
#     model=model,
#     dataloader=amazon_val_dataloader,
#     criterion=torch.nn.CrossEntropyLoss(),
#     device=device
# )
