**Introduction**

**Maximum Mean Discrepancy**

Maximum Mean Discrepancy (MMD) is a kernel-based statistic for testing whether two distributions are the same. It maps samples into a feature space with a feature map φ and measures the distance between the means of the two distributions in that space:

$$\mathrm{MMD}^2(X, Y) = \left\lVert \mathbb{E}[\varphi(X)] - \mathbb{E}[\varphi(Y)] \right\rVert^2$$

where:

- $\mathbb{E}[\varphi(X)]$: the expectation (mean) of the feature map $\varphi$ applied to the samples from distribution $X$.
- $\mathbb{E}[\varphi(Y)]$: the expectation (mean) of the feature map $\varphi$ applied to the samples from distribution $Y$.
- $\lVert \cdot \rVert^2$: the squared norm of the difference, i.e. the squared distance between the two mean feature vectors.
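Expanding the squared norm with the kernel $k(a, b) = \langle \varphi(a), \varphi(b) \rangle$ gives a form that never needs $\varphi$ explicitly:

$$\mathrm{MMD}^2 = \mathbb{E}[k(x, x')] + \mathbb{E}[k(y, y')] - 2\,\mathbb{E}[k(x, y)]$$

This is the form that `MMDLoss.forward` estimates with batch means. As a quick sanity check (a sketch, not part of the original notebook), the snippet below uses the identity feature map $\varphi(v) = v$, whose kernel is the dot product, and confirms on random data that the feature-map form and the kernel form give the same value.

```python
import torch

torch.manual_seed(0)
x = torch.randn(100, 3, dtype=torch.float64)        # samples from X
y = torch.randn(80, 3, dtype=torch.float64) + 0.5   # samples from a shifted Y

# Feature-map form with phi(v) = v: squared distance between the sample means.
mmd2_feature = (x.mean(dim=0) - y.mean(dim=0)).pow(2).sum()

# Kernel form with the matching kernel k(a, b) = <a, b>.
mmd2_kernel = (x @ x.T).mean() + (y @ y.T).mean() - 2 * (x @ y.T).mean()

print(mmd2_feature.item(), mmd2_kernel.item())  # the two values agree up to rounding
```

With a Gaussian kernel, as in `MMDLoss`, the feature map is infinite-dimensional, so only the kernel form is computable.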

Why use MMD rather than other loss functions for domain adaptation (DA)?

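In domain adaptation (DA), MMD has an advantage over per-sample losses such as cross-entropy or mean squared error (MSE). Those losses compare each prediction with its label, so they need labeled target data and say nothing about whether source and target features are distributed alike. MMD compares whole feature distributions and needs no target labels, so it addresses domain shift directly by pulling the two distributions together. In practice it is used alongside a supervised loss on the labeled source data rather than in place of it, as in the training loop below.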

**Implementation of MMD**

Imports

In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision import transforms
import torchvision.models as models

from PIL import Image

import xml.etree.ElementTree as ET

In [2]:
# Setup device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [3]:
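# Make CUDA kernel launches synchronous so errors point at the failing call
# (debugging aid only: it slows training, and must be set before CUDA is initialized)
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Make cuDNN deterministic for reproducible results
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True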

Custom loss function

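- For MMD-based domain adaptation, the loss is typically the (squared) MMD estimate itself. MMD is a distance between distributions, so minimizing it pushes the source and target feature distributions closer together in feature space. The class below uses a Gaussian (RBF) kernel, $k(a, b) = \exp\left(-\lVert a - b \rVert^2 / (2\sigma^2)\right)$, and estimates MMD from batch means of the kernel matrices.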

In [4]:
class MMDLoss(nn.Module):
    def __init__(self, kernel_type='rbf'):
        super(MMDLoss, self).__init__()
        self.kernel_type = kernel_type

    def gaussian_kernel(self, x, y, sigma=1.0):
        # Compute the Gaussian (RBF) kernel between x and y
        dist = torch.cdist(x, y)
        K = torch.exp(-dist ** 2 / (2 * sigma ** 2))
        return K

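    def forward(self, x, y):
        if self.kernel_type == 'rbf':
            # Biased estimate of MMD^2 from batch means of the kernel matrices:
            # mean k(x, x') + mean k(y, y') - 2 * mean k(x, y)
            return torch.mean(self.gaussian_kernel(x, x)) + \
                   torch.mean(self.gaussian_kernel(y, y)) - \
                   2 * torch.mean(self.gaussian_kernel(x, y))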
        else:
            raise ValueError('Unknown kernel type: {}'.format(self.kernel_type))
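One caveat with the fixed `sigma=1.0`: ResNet-18 features are 512-dimensional, so pairwise squared distances between them are likely much larger than $2\sigma^2 = 2$. The off-diagonal kernel values then collapse toward zero, the estimate approaches the constant $1/n + 1/m$ contributed by the diagonals, and the MMD term gives little gradient. A common remedy, sketched below as an assumption rather than the notebook's method, is the median heuristic: set the bandwidth from the median pairwise squared distance in the batch, here summing kernels at a few scales. The function name `rbf_mmd2_median` and the `scales` values are hypothetical choices.

```python
import torch


def rbf_mmd2_median(x, y, scales=(0.5, 1.0, 2.0)):
    """Biased MMD^2 estimate with a sum of RBF kernels whose bandwidths are
    set relative to the median pairwise squared distance (median heuristic)."""
    z = torch.cat([x, y], dim=0)
    # Pairwise squared Euclidean distances, computed directly (no square root).
    sq_dist = (z.unsqueeze(1) - z.unsqueeze(0)).pow(2).sum(dim=-1)
    # Median of the off-diagonal distances, detached so it acts as a constant.
    off_diag = ~torch.eye(z.size(0), device=z.device).bool()
    base = sq_dist[off_diag].median().detach().clamp_min(1e-12)
    # Sum of Gaussian kernels exp(-d^2 / (s * base)) at several scales.
    k = sum(torch.exp(-sq_dist / (s * base)) for s in scales)
    n = x.size(0)
    k_xx, k_yy, k_xy = k[:n, :n], k[n:, n:], k[:n, n:]
    return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()


torch.manual_seed(0)
a = torch.randn(64, 512)
b = torch.randn(64, 512)          # same distribution as a
c = torch.randn(64, 512) + 0.5    # shifted distribution
print(rbf_mmd2_median(a, b).item(), rbf_mmd2_median(a, c).item())
```

The shifted pair should give a clearly larger value. Trying it in the training loop would mean calling `rbf_mmd2_median(source_features, target_features)` in place of `mmd_loss_fn`; whether that helps on this data is untested.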

Datasets



*   Target data - CIFAR10
*   Source data - Pascal-voc-2012



In [5]:
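# Define transformations for the training and test datasets.
# The ResNet-18 backbone is pretrained on ImageNet, so normalize with ImageNet statistics.
simple_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to a common size (CIFAR-10 images are 32x32)
    transforms.ToTensor(),  # Convert images to tensors
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])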

In [6]:
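# Print the directory tree of a dataset folder.
# Note: it counts every file (e.g. XML annotations too), not only images.
def walk_through_dir(dir_path):
  for dirpath, dirnames, filenames in os.walk(dir_path):
    print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")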

In [7]:
# Downloading the target dataset from torchvision datasets
target_trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=simple_transform)
target_testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=simple_transform)


target_train_dataloader = torch.utils.data.DataLoader(target_trainset, shuffle=True, batch_size=32)
# No need to shuffle the evaluation set
target_test_dataloader = torch.utils.data.DataLoader(target_testset, shuffle=False, batch_size=32)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Downloading the source dataset from Kaggle
import kagglehub

source_path = kagglehub.dataset_download("gopalbhattrai/pascal-voc-2012-dataset")

print("Path to source_dataset files:", source_path)


Path to source_dataset files: /root/.cache/kagglehub/datasets/gopalbhattrai/pascal-voc-2012-dataset/versions/1


In [9]:
walk_through_dir(source_path)

There are 2 directories and 0 images in '/root/.cache/kagglehub/datasets/gopalbhattrai/pascal-voc-2012-dataset/versions/1'.
There are 1 directories and 0 images in '/root/.cache/kagglehub/datasets/gopalbhattrai/pascal-voc-2012-dataset/versions/1/VOC2012_train_val'.
There are 5 directories and 1 images in '/root/.cache/kagglehub/datasets/gopalbhattrai/pascal-voc-2012-dataset/versions/1/VOC2012_train_val/VOC2012_train_val'.
There are 0 directories and 2913 images in '/root/.cache/kagglehub/datasets/gopalbhattrai/pascal-voc-2012-dataset/versions/1/VOC2012_train_val/VOC2012_train_val/SegmentationObject'.
There are 0 directories and 17125 images in '/root/.cache/kagglehub/datasets/gopalbhattrai/pascal-voc-2012-dataset/versions/1/VOC2012_train_val/VOC2012_train_val/JPEGImages'.
There are 0 directories and 17125 images in '/root/.cache/kagglehub/datasets/gopalbhattrai/pascal-voc-2012-dataset/versions/1/VOC2012_train_val/VOC2012_train_val/Annotations'.
There are 4 directories and 0 images in '

In [10]:
# Custom dataset for Pascal VOC 2012: each image is labeled with the class
# of the first object listed in its XML annotation
class PascalDataset(Dataset):
    # Map the 20 Pascal VOC class names (strings in the XML) to integer indices
    label_dict = {
        'aeroplane': 0, 'bicycle': 1, 'bird': 2, 'boat': 3, 'bottle': 4,
        'bus': 5, 'car': 6, 'cat': 7, 'chair': 8, 'cow': 9, 'diningtable': 10,
        'dog': 11, 'horse': 12, 'motorbike': 13, 'person': 14, 'pottedplant': 15,
        'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19
    }

    def __init__(self, image_dir, annotation_dir, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        self.image_list = sorted(os.listdir(self.image_dir))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_name = self.image_list[idx]
        image_path = os.path.join(self.image_dir, image_name)
        image = Image.open(image_path).convert('RGB')

        # Parse the matching XML annotation and take the first object's class name
        annotation_name = os.path.splitext(image_name)[0] + '.xml'
        annotation_path = os.path.join(self.annotation_dir, annotation_name)
        root = ET.parse(annotation_path).getroot()
        label = self.label_dict[root.find('object').find('name').text]

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label)
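`root.find('object')` returns only the first `<object>` element, while many VOC images contain several objects of different classes, so the label above reflects just one of them. The sketch below lists every object's class with `ElementTree.findall`; the inline XML is a made-up, trimmed example in the VOC annotation layout, and `object_names` is a hypothetical helper.

```python
import xml.etree.ElementTree as ET

# A trimmed, made-up annotation in the Pascal VOC XML layout (illustration only).
annotation_xml = """
<annotation>
  <filename>example.jpg</filename>
  <object><name>person</name></object>
  <object><name>horse</name></object>
  <object><name>person</name></object>
</annotation>
"""


def object_names(root):
    """Return the class name of every <object> in a VOC annotation, in file order."""
    return [obj.find('name').text for obj in root.findall('object')]


root = ET.fromstring(annotation_xml.strip())
names = object_names(root)
print(names)                                  # ['person', 'horse', 'person']
print(root.find('object').find('name').text)  # person (only the first object)
```

Which object to use as the image label (first, largest bounding box, or a class shared with the target domain) is a design choice the single-label setup forces.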

In [11]:
# Build the image and annotation paths from the downloaded dataset root
voc_root = os.path.join(source_path, 'VOC2012_train_val', 'VOC2012_train_val')
image_dir = os.path.join(voc_root, 'JPEGImages')
annotation_dir = os.path.join(voc_root, 'Annotations')

pascal_voc_dataset = PascalDataset(image_dir=image_dir, annotation_dir=annotation_dir, transform=simple_transform)

source_loader = torch.utils.data.DataLoader(pascal_voc_dataset, batch_size=32, shuffle=True)

In [12]:
# Sanity check: inspect one batch from the source loader
# img, label = next(iter(source_loader))

# # With batch_size=32 the first dimension is 32; try changing batch_size above and see what happens
# print(f"Image shape: {img.shape} -> [batch_size, color_channels, height, width]")
# print(f"Label shape: {label.shape}")
# print(len(label))

Models

In [13]:
class FeatureExtractor(nn.Module):

  def __init__(self):
    super().__init__()
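    # Load ResNet-18 with ImageNet-pretrained weights
    # (`pretrained=True` is deprecated in recent torchvision releases)
    self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Drop the final fully connected layer; the output is the 512-d pooled feature
    self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])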


  def forward(self, x):
    x = self.resnet(x)
    x = x.view(x.size(0), -1)
    return x


class Classifier(nn.Module):
  def __init__(self, in_features, num_classes):
    super().__init__()
    # A single linear layer mapping backbone features to class logits
    self.classifier = nn.Linear(in_features, num_classes)

  def forward(self, x):
    return self.classifier(x)


Training loop with MMD Loss

In [14]:
feature_extractor = FeatureExtractor().to(device)
classifier = Classifier(512, 20).to(device)

mmd_loss_fn = MMDLoss(kernel_type='rbf')
loss_fn = nn.CrossEntropyLoss()

optimizer = optim.Adam(list(feature_extractor.parameters()) + list(classifier.parameters()), lr=1e-4)



In [None]:
num_epochs = 10

for epoch in range(num_epochs):
  feature_extractor.train()
  classifier.train()
  total_loss = 0.0
  total_mmd_loss = 0.0
  total_cls_loss = 0.0

  num_batches = min(len(source_loader), len(target_train_dataloader))

  # zip() stops at the shorter loader, so each epoch sees num_batches paired batches.
  # Target labels are discarded: the target domain is treated as unlabeled.
  for (source_data, source_labels), (target_data, _) in zip(source_loader, target_train_dataloader):
    source_data = source_data.to(device)
    source_labels = source_labels.to(device)
    target_data = target_data.to(device)

    optimizer.zero_grad()

    # Forward pass through the shared feature extractor for both domains
    source_features = feature_extractor(source_data)
    target_features = feature_extractor(target_data)

    # Forward pass through the classifier (source only, since only it has labels)
    source_logits = classifier(source_features)

    # Classification loss on the labeled source data
    cls_loss = loss_fn(source_logits, source_labels)

    # MMD loss between the source and target features
    mmd_loss = mmd_loss_fn(source_features, target_features)

    # Total loss: supervised source loss plus distribution alignment
    loss = cls_loss + mmd_loss

    # Backward pass and optimizer step
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    total_cls_loss += cls_loss.item()
    total_mmd_loss += mmd_loss.item()

  # Report per-batch averages once per epoch (inside the epoch loop)
  print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / num_batches:.4f}, '
        f'Classification Loss: {total_cls_loss / num_batches:.4f}, '
        f'MMD Loss: {total_mmd_loss / num_batches:.4f}')

Test Loop

In [None]:
def test_loop(feature_extractor, classifier, test_loader):
  feature_extractor.eval()
  classifier.eval()

  total_correct = 0
  total_samples = 0

  with torch.no_grad():
    for data, labels in test_loader:
      data = data.to(device)
      labels = labels.to(device)

      # Extract features and get the classifier's logits
      features = feature_extractor(data)
      logits = classifier(features)

      # Predicted class = index of the largest logit
      predicted = logits.argmax(dim=1)
      total_samples += labels.size(0)
      total_correct += (predicted == labels).sum().item()

  # Caveat: predictions index the 20 Pascal VOC classes, while CIFAR-10 labels index
  # 10 different classes, so this number is only meaningful once both share a label space.
  accuracy = 100 * total_correct / total_samples
  print(f'Accuracy of the model on the target test images: {accuracy:.2f}%')


# Evaluate the model
test_loop(feature_extractor, classifier, target_test_dataloader)

**Results and Observations**

Challenges and Solutions:

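*   Long training time
    - Use a GPU: the device check above returned `'cpu'`, so training ran on the CPU
*   Low accuracy, because the two domains share only 4 classes by name (bird, cat, dog, horse) and the classifier's 20 VOC class indices do not line up with CIFAR-10's label indices
    - Change the source dataset, or use different domains, to practice MMD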
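One way to make the accuracy meaningful, sketched here under stated assumptions, is to restrict both datasets to a shared label space and remap indices before training and evaluation. Only bird, cat, dog and horse match by name; the sketch also pairs aeroplane/airplane, car/automobile and boat/ship, which is a judgment call rather than anything the datasets define. `VOC_TO_CIFAR_NAME`, `build_maps` and the shared ordering are hypothetical names.

```python
# CIFAR-10 class names in torchvision's label order (list index = label).
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

# Pascal VOC class names, in the same index order as PascalDataset's label_dict.
VOC_CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
               'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

# Assumed correspondences: 4 exact name matches plus 3 hand-picked synonyms.
VOC_TO_CIFAR_NAME = {'aeroplane': 'airplane', 'car': 'automobile', 'boat': 'ship',
                     'bird': 'bird', 'cat': 'cat', 'dog': 'dog', 'horse': 'horse'}

# The shared label space, in a fixed (alphabetical) order.
SHARED = sorted(set(VOC_TO_CIFAR_NAME.values()))


def build_maps():
    """Return dicts mapping VOC and CIFAR-10 label indices to shared indices."""
    voc_to_shared = {VOC_CLASSES.index(v): SHARED.index(c)
                     for v, c in VOC_TO_CIFAR_NAME.items()}
    cifar_to_shared = {CIFAR10_CLASSES.index(c): SHARED.index(c) for c in SHARED}
    return voc_to_shared, cifar_to_shared


voc_to_shared, cifar_to_shared = build_maps()
print(len(SHARED))                            # 7
print(voc_to_shared[7], cifar_to_shared[3])   # 3 3 (both are 'cat')
```

Samples whose label has no entry in these maps would be dropped from both datasets, and the classifier would then output `len(SHARED)` classes instead of 20.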

