**Introduction**

The Residual Transfer Network (RTN) is a framework designed for domain adaptation, which is the task of adapting a model trained on a source domain to perform well on a target domain with a different data distribution. RTN builds on the concept of residual learning and integrates techniques to reduce the domain gap between the source and target distributions.

  Why is it used for Unsupervised DA ?
  

Residual Transfer Networks (RTNs) shine particularly in unsupervised domain adaptation (UDA) because they are designed to address the lack of labeled data in the target domain by leveraging labeled data from the source domain. Here’s why RTNs are most effective in unsupervised settings, compared to supervised or semi-supervised approaches:

**Core Features of RTN**

1.   Residual Learning:
2.   Domain Adaptation Mechanisms:
3.   Adversarial Training:
4.   Task-Specific Output:

**Imports**

In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from torchvision import models

import kagglehub

from itertools import cycle, islice

In [2]:
# Setup device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

**Data Processing**

In [3]:
# Download latest version
path = kagglehub.dataset_download("mei1963/domainnet")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/domainnet


In [4]:

def walk_through_dir(dir_path):
  """
  Walks through dir_path returning its contents.
  Args:
    dir_path (str or pathlib.Path): target directory

  Returns:
    A print out of:
      number of subdiretories in dir_path
      number of images (files) in each subdirectory
      name of each subdirectory
  """
  for dirpath, dirnames, filenames in os.walk(dir_path):
    print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")

In [5]:
path

'/kaggle/input/domainnet'

In [6]:
# walk_through_dir(path)

In [7]:
Clipart_path = os.path.join(path, "DomainNet/clipart")
Real_path = os.path.join(path, "DomainNet/real")

In [8]:
simple_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standard ImageNet normalization
])


In [9]:
clipart_dataset = ImageFolder(root=Clipart_path, transform=simple_transform)
real_dataset = ImageFolder(root=Real_path, transform=simple_transform)

# Example DataLoader with performance optimizations
source_data_loader = DataLoader(real_dataset, batch_size=64, shuffle=True,
                                num_workers=6, pin_memory=True, drop_last=True)

target_dataloader = DataLoader(clipart_dataset, batch_size=64, shuffle=True,
                               num_workers=6, pin_memory=True,drop_last=True)

In [10]:
len(target_dataloader), len(source_data_loader)

(763, 2739)

In [11]:
# for batch in target_dataloader:
#     print(batch)
#     break


**Models**

**Feature extractor**

ResNet (Residual Network) can be used as the feature extractor in a Residual Transfer Network (RTN). In fact, ResNet is often preferred in domain adaptation tasks because:

In [12]:
feature_extractor = models.resnet50(pretrained=True)

for param in feature_extractor.parameters():
    param.requires_grad = False

# Unfreeze specific layers
for name, param in feature_extractor.named_parameters():
    if "layer4" in name:  # Example: Unfreeze layer4
        param.requires_grad = True


feature_extractor.fc  = nn.Identity()



**Residual Module**

The Residual Module in a Residual Transfer Network (RTN) is a key component designed to refine and adjust features extracted from the source domain, aligning them with the target domain. This approach leverages residual learning, making the adaptation process more efficient and focused.

In [13]:
class ResidualModule(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Linear transformation for residual
        self.residual_transform = nn.Linear(input_dim, output_dim)
        # Batch Norm
        self.batch_norm = nn.LayerNorm(output_dim)
        # ReLU activation
        self.activation = nn.ReLU()

    def forward(self, x):
        # Apply residual transformation
        residual = self.residual_transform(x)
        # print('residual', residual.shape)
        batch_norm = self.batch_norm(residual)
        activation = self.activation(batch_norm)

                # Adjust dimensions if needed
        if x.shape[1] != activation.shape[1]:
            x = self.residual_transform(x)

        adjusted_features = x + activation
        return adjusted_features

        return adjusted_features



**Domain Discrepancy Minimization**

Domain discrepancy minimization is a crucial step in domain adaptation tasks, including in pipelines like the Residual Transfer Network (RTN). It focuses on reducing the difference in data distributions between the source domain (labeled) and the target domain (unlabeled or sparsely labeled). By minimizing this discrepancy, the model learns domain-invariant features that generalize well across both domains.

The MMD loss is minimized during training to bring the two distributions closer:

MMD
2
=
∥
1
𝑛
𝑠
∑
𝑖
=
1
𝑛
𝑠
𝜙
(
𝑥
𝑖
𝑠
)
−
1
𝑛
𝑡
∑
𝑖
=
1
𝑛
𝑡
𝜙
(
𝑥
𝑖
𝑡
)
∥
2

In [14]:
class MMDLoss(nn.Module):
  def __init__(self, kernel='rbf', gamma=1.0):
    super().__init__()
    self.kernel = kernel
    self.gamma = gamma

  def pairwise_distance(self, x, y):
    x_norm = torch.sum(x**2, dim=-1)
    y_norm = torch.sum(y**2, dim=-1)
    dist = x_norm + y_norm - 2 * torch.matmul(x, y.t())

    return torch.clamp(dist, min=0.0)

  def rbf_kernel(self, x, y):
    dist = self.pairwise_distance(x, y)
    return torch.exp(-self.gamma * dist)

  def linear_kernel(self, x, y):
    return torch.matmul(x, y.t())

  def forward(self, source_features, target_features):
      """
      Compute the Maximum Mean Discrepancy (MMD) loss.

      Args:
          source_features (torch.Tensor): Features from source domain.
          target_features (torch.Tensor): Features from target domain.

      Returns:
          torch.Tensor: MMD loss.
      """
      if self.kernel == 'rbf':
          kernel_fn = self.rbf_kernel
      elif self.kernel == 'linear':
          kernel_fn = self.linear_kernel
      else:
          raise ValueError("Unsupported kernel type. Choose 'linear' or 'rbf'.")

      xx = kernel_fn(source_features, source_features).mean()
      yy = kernel_fn(target_features, target_features).mean()
      xy = kernel_fn(source_features, target_features).mean()
      mmd_loss = xx + yy - 2.0 * xy
      return torch.clamp(mmd_loss, min=0.0)

**Task-Specific Classifier**

In [15]:
class TaskClassifier(nn.Module):
  def __init__(self, input_dim, num_classes, dropout=0.5,*args, **kwargs) -> None:
     super().__init__(*args, **kwargs)
     self.fc1 = nn.Linear(input_dim, 256)
     self.relu = nn.ReLU()
     self.dropout = nn.Dropout(dropout)
     self.fc2 = nn.Linear(256, num_classes)

  def forward(self, x):
    x = self.fc1(x)
    # print('after fc1', x.shape)
    x = self.relu(x)
    x = self.dropout(x)
    x = self.fc2(x)
    return x

By encapsulating its components into a single class, you ensure modularity and easy integration into your training loop.

In [16]:
# Components: Feature Extractor, Residual Module, Classifier
class RTN(nn.Module):
  def __init__(self, feature_extractor, residual_module, classifier):
    super().__init__()
    self.feature_extractor = feature_extractor
    self.residual_module = residual_module
    self.classifier = classifier

  def forward(self, x):
      features = self.feature_extractor(x)
      # print('feature shape',features.shape)
      aligned_features = self.residual_module(features)
      predictions = self.classifier(aligned_features)
      return aligned_features, predictions

**Training Loop**

In [17]:
feature_extractor = feature_extractor.to(device)
residual_module = ResidualModule(2048, 512).to(device)  # Residual module
classifier = TaskClassifier(input_dim=512, num_classes=345).to(device)

mmd_loss_fn = MMDLoss(kernel='rbf', gamma=1.0).to(device)
classification_loss_fn = nn.CrossEntropyLoss()

optimizer = optim.Adam(list(feature_extractor.parameters()) + list(classifier.parameters()), lr=0.001)


model = RTN(feature_extractor, residual_module, classifier)

In [18]:
scaler = GradScaler()

torch.cuda.empty_cache()
# torch.cuda.reset_peak_memory_stats()

target_dataloader_cycle = cycle(target_dataloader)

epochs = 5
for epoch in range(epochs):
  model.train()
  total_classification_loss = 0.0
  total_mmd_loss = 0.0

  for source_data, target_data in zip(source_data_loader, islice(target_dataloader_cycle, len(source_data_loader))):

      # Supervised source domain data
      source_inputs, source_labels = source_data
      source_inputs, source_labels = source_inputs.to(device), source_labels.to(device)

      # Unsupervised target domain data
      target_inputs, _ = target_data  # Target data has no labels
      target_inputs = target_inputs.to(device)

      optimizer.zero_grad()

      # Enable mixed precision
      with autocast(device_type='cuda'):
          source_features, source_predictions = model(source_inputs)
          target_features, _ = model(target_inputs)

          classification_loss = classification_loss_fn(source_predictions, source_labels)
          mmd_loss = mmd_loss_fn(source_features, target_features)

          total_loss = classification_loss + mmd_loss

      # Backward pass with scaled gradients
      scaler.scale(total_loss).backward()
      scaler.step(optimizer)
      scaler.update()

      total_classification_loss += classification_loss.item()
      total_mmd_loss += mmd_loss.item()
  print(f"Epoch [{epoch+1}/{epochs}], Classification Loss: {total_classification_loss:.4f}, MMD Loss: {total_mmd_loss:.4f}")

Epoch [1/5], Classification Loss: 6308.9418, MMD Loss: 15.3575
Epoch [2/5], Classification Loss: 3660.6271, MMD Loss: 2.8869
Epoch [3/5], Classification Loss: 2920.0178, MMD Loss: 1.5874
Epoch [4/5], Classification Loss: 2421.1814, MMD Loss: 1.2705
Epoch [5/5], Classification Loss: 2039.2852, MMD Loss: 1.0331


The classification loss is steadily decreasing, which suggests that your model is learning well. Meanwhile, the MMD loss is gradually reducing, meaning the feature distributions between your source and target domains are aligning more closely.