In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt
import os


In [None]:
# Hyperparameters
batch_size = 64
num_epochs = 20
lr = 0.01
lambda_domain = 0.1  # Domain-adaptation weight; passed to the gradient reversal layer (GRL) as its coefficient
seed = 42  # RNG seed for weight init, dropout masks and DataLoader shuffling

# Seed torch's RNGs before any model or DataLoader is created so runs are repeatable
# (cuDNN kernels on GPU may still introduce small nondeterminism).
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


# DANN (Domain-Adversarial Neural Network)

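DANN is an unsupervised domain adaptation method. It uses labeled data from a source domain and unlabeled data from a target domain, and trains adversarially so that the learned features are domain-invariant. A classifier trained on the source features can then be applied to the target domain.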

**Key Components:**
1. **Feature Extractor (Gf)**: Extracts features from input images
2. **Label Classifier (Gy)**: Classifies the task labels
3. **Domain Discriminator (Gd)**: Distinguishes between source and target domains
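4. **Gradient Reversal Layer (GRL)**: Acts as the identity in the forward pass and multiplies the gradient by $-\lambda$ in the backward pass. As a result, the feature extractor is trained to *fool* the domain discriminator, while the discriminator itself is trained normally.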

**Training Objective:**
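- Minimize the label classification loss $\mathcal{L}_y$ on the (labeled) source domain.
- Train the domain discriminator to minimize the domain classification loss $\mathcal{L}_d$, while the feature extractor (through the GRL) maximizes it. This pushes the features to be domain-invariant.
- Together this is a saddle-point problem on $E(\theta_f, \theta_y, \theta_d) = \mathcal{L}_y(\theta_f, \theta_y) - \lambda\,\mathcal{L}_d(\theta_f, \theta_d)$: it is minimized over $\theta_f, \theta_y$ and maximized over $\theta_d$.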


In [None]:
# Gradient Reversal Layer
class GradientReversalLayer(torch.autograd.Function):
    """Autograd op: identity on the way forward, negated and scaled gradient on the way back.

    ``forward(x, lambda_param)`` returns ``x`` unchanged (as a view). ``backward`` returns
    ``-lambda_param * grad`` for ``x`` and no gradient for ``lambda_param``.
    """

    @staticmethod
    def forward(ctx, x, lambda_param):
        # Remember the reversal coefficient for the backward pass.
        ctx.scale = lambda_param
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output * -ctx.scale
        # Second slot is the gradient w.r.t. lambda_param: none.
        return reversed_grad, None


class GRL(nn.Module):
    """``nn.Module`` wrapper that applies ``GradientReversalLayer`` with a fixed coefficient."""

    def __init__(self, lambda_param=1.0):
        super().__init__()
        # Reversal strength used in the backward pass.
        self.lambda_param = lambda_param

    def forward(self, x):
        return GradientReversalLayer.apply(x, self.lambda_param)


In [None]:
# Feature Extractor (Gf)
class FeatureExtractor(nn.Module):
    """Convolutional feature extractor (Gf) shared by the label and domain heads.

    Maps an image batch ``(B, input_channels, H, W)`` to a flat feature matrix
    ``(B, 256 * 4 * 4)``, which is the input size ``LabelClassifier`` and
    ``DomainDiscriminator`` are built for.

    For the 32x32 inputs produced by this notebook's transform the spatial size goes
    32 -> 32 (conv, pad 2) -> 16 (pool) -> 16 (conv, pad 2) -> 8 (pool) -> 4 (conv).
    The trailing adaptive pool is a no-op at 4x4 and keeps the output size fixed for
    other resolutions (inputs must be at least 20x20 so the last conv has room).

    Args:
        input_channels: number of image channels (1 for MNIST, 3 for RGB images).
    """

    # Flattened feature size produced by ``forward``.
    out_features = 256 * 4 * 4

    def __init__(self, input_channels=1):
        super(FeatureExtractor, self).__init__()
        self.conv = nn.Sequential(
            # padding=2 keeps the first two convs size-preserving. Without it a 32x32 input
            # shrinks to 1x1 (only 256 features), which does not match the heads' 256*4*4 input.
            nn.Conv2d(input_channels, 64, kernel_size=5, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=5),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )

    def forward(self, x):
        """Return flattened features of shape ``(B, 256 * 4 * 4)``."""
        return torch.flatten(self.conv(x), start_dim=1)


In [None]:
# Label Classifier (Gy)
class LabelClassifier(nn.Module):
    """Task label classifier (Gy): maps flattened features to class logits.

    Args:
        num_classes: number of output logits per sample.
        in_features: length of the flattened feature vector fed in. Defaults to
            ``256 * 4 * 4``, the feature size this notebook's heads are built for.
    """
    def __init__(self, num_classes=10, in_features=256 * 4 * 4):
        super(LabelClassifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 100),
            # BatchNorm1d needs more than one sample per batch in training mode.
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(100, num_classes)
        )

    def forward(self, x):
        """Return logits ``(B, num_classes)`` for features ``x`` of shape ``(B, in_features)``."""
        return self.classifier(x)


In [None]:
# Domain Discriminator (Gd)
class DomainDiscriminator(nn.Module):
    """Domain discriminator (Gd) preceded by a gradient reversal layer.

    Predicts 2-way domain logits (source vs. target) from flattened features. The GRL
    leaves the forward pass unchanged but negates and scales the gradient flowing back
    into the features. The feature extractor is therefore pushed to confuse this
    discriminator, while the discriminator's own weights are trained normally.

    Args:
        in_features: length of the flattened feature vector; defaults to ``256 * 4 * 4``.
        lambda_param: GRL reversal coefficient. ``None`` (default) uses the notebook-level
            ``lambda_domain``, read when the module is constructed.
    """
    def __init__(self, in_features=256 * 4 * 4, lambda_param=None):
        super(DomainDiscriminator, self).__init__()
        if lambda_param is None:
            lambda_param = lambda_domain
        self.grl = GRL(lambda_param=lambda_param)
        self.discriminator = nn.Sequential(
            nn.Linear(in_features, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(100, 2)  # Binary classification: source or target
        )

    def forward(self, x):
        """Return domain logits ``(B, 2)``; gradients into ``x`` are reversed by the GRL."""
        x = self.grl(x)
        return self.discriminator(x)


In [None]:
# Complete DANN Model
class DANN(nn.Module):
    """Domain-Adversarial Neural Network: shared feature extractor plus label and domain heads.

    Args:
        num_classes: number of task classes predicted by the label classifier.
        input_channels: number of image channels fed to the feature extractor.
    """
    def __init__(self, num_classes=10, input_channels=1):
        super(DANN, self).__init__()
        self.feature_extractor = FeatureExtractor(input_channels)
        self.label_classifier = LabelClassifier(num_classes)
        self.domain_discriminator = DomainDiscriminator()

    def forward(self, x, alpha=1.0):
        """Run all three sub-networks on ``x``.

        Args:
            x: input image batch.
            alpha: extra multiplier on the gradient flowing from the domain branch back
                into the feature extractor, applied on top of the GRL's own coefficient.
                ``1.0`` (default) leaves it unchanged. ``0.0`` blocks it: the discriminator
                still trains, but the features get no adversarial signal. It can, for
                example, be ramped from 0 to 1 over training. Forward values are the same
                for every alpha.

        Returns:
            ``(class_output, domain_output)``: label logits and domain logits for every
            sample in ``x``.
        """
        features = self.feature_extractor(x)
        class_output = self.label_classifier(features)

        if alpha != 1.0:
            # Gradient-scaling trick: the value equals `features`, but the gradient
            # reaching `features` through this path is multiplied by alpha.
            frozen = features.detach()
            domain_input = frozen + alpha * (features - frozen)
        else:
            domain_input = features

        # The discriminator applies gradient reversal internally.
        domain_output = self.domain_discriminator(domain_input)

        return class_output, domain_output


In [None]:
# Prepare datasets.
# Source domain: the MNIST training split. Target domain: the MNIST test split, whose labels
# are ignored during training. Both splits come from the same distribution, so this only
# demonstrates the mechanics. A real setup would use a shifted target domain (e.g. MNIST-M
# or SVHN; RGB targets would also need `input_channels=3` and a 3-channel Normalize).
data_root = './data'

# Resize to 32x32 and map pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Source domain: labeled MNIST training images
source_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
source_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True)

# Target domain: MNIST test images (labels only used later for evaluation)
target_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)
target_loader = DataLoader(target_dataset, batch_size=batch_size, shuffle=True)

print(f'Source dataset size: {len(source_dataset)}')
print(f'Target dataset size: {len(target_dataset)}')


In [None]:
# Build the DANN for single-channel MNIST digits and move it to the selected device.
model = DANN(num_classes=10, input_channels=1)
model = model.to(device)

# One SGD optimizer updates all sub-networks: feature extractor, label head and domain head.
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

# Cross-entropy for both heads: 10-way digit labels and 2-way domain labels.
criterion_class, criterion_domain = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()

print(model)


In [None]:
# Training function
def train_epoch(model, source_loader, target_loader, optimizer, epoch, num_epochs):
    """Train ``model`` for one epoch on paired source/target batches.

    Each step takes one labeled source batch and one target batch (its labels are
    ignored). It runs them through the model as one concatenated batch and minimizes
    the source classification loss plus the domain-discrimination loss. The epoch
    ends when the shorter loader is exhausted.

    Args:
        model: module whose forward returns ``(class_logits, domain_logits)``.
        source_loader: iterable of ``(images, labels)`` from the labeled source domain.
        target_loader: iterable of ``(images, _)`` from the target domain.
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index (used only for logging).
        num_epochs: total number of epochs (used only for logging).

    Returns:
        ``(avg_class_loss, avg_domain_loss, source_accuracy_percent)``.

    Raises:
        ValueError: if either loader yields no batches.

    Relies on the notebook globals ``device``, ``criterion_class`` and ``criterion_domain``.
    """
    model.train()

    total_class_loss = 0.0
    total_domain_loss = 0.0
    correct_class = 0
    total_samples = 0
    num_batches = 0

    # zip() stops at the shorter loader, so an epoch is min_len paired batches.
    min_len = min(len(source_loader), len(target_loader))

    for batch_idx, ((source_data, source_labels), (target_data, _)) in enumerate(
        zip(source_loader, target_loader)
    ):
        source_data = source_data.to(device)
        source_labels = source_labels.to(device)
        target_data = target_data.to(device)
        n_source = source_data.size(0)

        # Domain labels: 0 for source, 1 for target
        source_domain_labels = torch.zeros(n_source, dtype=torch.long, device=device)
        target_domain_labels = torch.ones(target_data.size(0), dtype=torch.long, device=device)

        # Single forward pass over both domains (BatchNorm statistics mix source and target).
        combined_data = torch.cat([source_data, target_data], dim=0)

        optimizer.zero_grad()
        class_output, domain_output = model(combined_data)

        source_class_output = class_output[:n_source]
        source_domain_output = domain_output[:n_source]
        target_domain_output = domain_output[n_source:]

        # Classification loss (only on source domain)
        class_loss = criterion_class(source_class_output, source_labels)

        # Domain loss: mean of the per-domain cross-entropies
        domain_loss = (criterion_domain(source_domain_output, source_domain_labels)
                       + criterion_domain(target_domain_output, target_domain_labels)) / 2

        # The GRL inside the domain discriminator already scales the reversed gradient by
        # lambda_domain, so the domain loss is added unweighted. Scaling it by lambda_domain
        # here as well would apply the coefficient twice: the features would receive
        # lambda_domain**2 of the adversarial gradient, and the discriminator itself would
        # learn at lambda_domain strength.
        loss = class_loss + domain_loss

        loss.backward()
        optimizer.step()

        # Statistics
        num_batches += 1
        total_class_loss += class_loss.item()
        total_domain_loss += domain_loss.item()
        _, predicted = source_class_output.max(1)
        correct_class += predicted.eq(source_labels).sum().item()
        total_samples += source_labels.size(0)

        if (batch_idx + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{min_len}], '
                  f'Class Loss: {class_loss.item():.4f}, Domain Loss: {domain_loss.item():.4f}, '
                  f'Acc: {100.*correct_class/total_samples:.2f}%')

    if num_batches == 0:
        raise ValueError('train_epoch needs at least one batch from each loader')

    avg_class_loss = total_class_loss / num_batches
    avg_domain_loss = total_domain_loss / num_batches
    accuracy = 100. * correct_class / total_samples

    return avg_class_loss, avg_domain_loss, accuracy


In [None]:
# Training loop: run `num_epochs` epochs and keep the per-epoch source classification
# loss and accuracy for plotting.
train_losses = []
train_accs = []

for epoch in range(num_epochs):
    class_loss, domain_loss, accuracy = train_epoch(
        model, source_loader, target_loader, optimizer, epoch, num_epochs
    )
    train_losses.append(class_loss)
    train_accs.append(accuracy)

    summary = (
        f'Epoch [{epoch+1}/{num_epochs}] - '
        f'Class Loss: {class_loss:.4f}, '
        f'Domain Loss: {domain_loss:.4f}, '
        f'Accuracy: {accuracy:.2f}%'
    )
    print(summary)
    print('-' * 60)


In [None]:
# Plot per-epoch training curves (source-domain classification loss and accuracy).
# Epochs are numbered from 1 on the x-axis to line up with the "Epoch [k/N]" training log.
loss_epochs = range(1, len(train_losses) + 1)
acc_epochs = range(1, len(train_accs) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(loss_epochs, train_losses, label='Classification Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Training Loss')
ax_loss.legend()
ax_loss.grid(True)

ax_acc.plot(acc_epochs, train_accs, label='Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.set_title('Training Accuracy')
ax_acc.legend()
ax_acc.grid(True)

fig.tight_layout()
plt.show()


In [None]:
# Evaluation helper: label-classification accuracy on any labeled loader
def evaluate(model, data_loader, domain_name='Target', device=None):
    """Compute and print the label-classification accuracy of ``model`` on ``data_loader``.

    Args:
        model: module whose forward returns ``(class_logits, other)``; only the class
            logits are used.
        data_loader: iterable of ``(inputs, labels)`` batches.
        domain_name: label used in the printed message.
        device: device to move batches to. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).

    Returns:
        Accuracy in percent (0-100).

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    correct = 0
    total = 0

    try:
        with torch.no_grad():
            for data, labels in data_loader:
                data, labels = data.to(device), labels.to(device)
                class_output, _ = model(data)
                _, predicted = class_output.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
    finally:
        # Restore the caller's train/eval mode so evaluation leaves no lasting side effect.
        model.train(was_training)

    if total == 0:
        raise ValueError(f'{domain_name} data_loader yielded no samples')

    accuracy = 100. * correct / total
    print(f'{domain_name} Domain Accuracy: {accuracy:.2f}%')
    return accuracy

# Report label accuracy on each domain. Note that target_loader holds the same images
# that were used (without labels) for adaptation during training.
source_acc = evaluate(model, source_loader, domain_name='Source')
target_acc = evaluate(model, target_loader, domain_name='Target')
