In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt
import os


In [None]:
# Hyperparameters
seed = 42            # Seed for PyTorch's global RNG so a fresh-kernel run is repeatable
batch_size = 64
num_epochs = 20
lr = 0.01
lambda_recon = 0.1  # Reconstruction loss weight
lambda_sim = 0.1    # Similarity loss weight
lambda_diff = 0.1   # Difference loss weight

# Seed before any model is built or any loader is iterated, so weight
# initialisation and dropout draw from a known RNG state.
torch.manual_seed(seed)

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


# DSN (Domain Separation Networks)

DSN is a domain adaptation method that separates shared and private features to learn domain-invariant representations.

**Key Components:**
1. **Shared Encoder (Es)**: Extracts domain-invariant features
2. **Private Encoder (Ep)**: Extracts domain-specific features  
3. **Reconstruction Decoder (D)**: Reconstructs images from shared + private features
4. **Task Classifier (C)**: Classifies using shared features
5. **Similarity Loss**: Ensures shared features are similar across domains
6. **Difference Loss**: Ensures shared and private features are different

**Training Objective:**
- Minimize task classification loss on source domain
- Minimize reconstruction loss (shared + private → original image)
- Maximize similarity between shared features from different domains
- Maximize difference between shared and private features


In [None]:
# Shared Encoder (Es) - extracts domain-invariant features
class SharedEncoder(nn.Module):
    """Convolutional encoder producing the shared (domain-invariant) feature map.

    Three 5x5 conv -> batch-norm -> ReLU stages with 64, 128 and 256 output
    channels; the first two stages are each followed by a 2x2 max-pool, so
    the spatial size is divided by 4 overall.

    Args:
        input_channels: Number of channels in the input images.
    """

    def __init__(self, input_channels=1):
        super().__init__()
        channel_plan = (input_channels, 64, 128, 256)
        last_stage = len(channel_plan) - 2

        stack = []
        for stage, (c_in, c_out) in enumerate(zip(channel_plan[:-1], channel_plan[1:])):
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=5, padding=2))
            stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.ReLU())
            # Every stage except the final one halves the spatial resolution
            if stage != last_stage:
                stack.append(nn.MaxPool2d(2))

        self.encoder = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 256-channel feature map for image batch ``x``."""
        features = self.encoder(x)
        return features


In [None]:
# Private Encoder (Ep) - extracts domain-specific features
class PrivateEncoder(nn.Module):
    """Convolutional encoder producing the private (domain-specific) feature map.

    Layer layout: conv5x5(64) -> BN -> ReLU -> pool, conv5x5(128) -> BN ->
    ReLU -> pool, conv5x5(256) -> BN -> ReLU.

    Args:
        input_channels: Number of channels in the input images.
    """

    @staticmethod
    def _stage(in_ch, out_ch, pool):
        """Build one conv/batch-norm/ReLU stage, optionally ending in a 2x2 max-pool."""
        modules = [
            nn.Conv2d(in_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        return modules + [nn.MaxPool2d(2)] if pool else modules

    def __init__(self, input_channels=1):
        super().__init__()
        self.encoder = nn.Sequential(
            *self._stage(input_channels, 64, pool=True),
            *self._stage(64, 128, pool=True),
            *self._stage(128, 256, pool=False),
        )

    def forward(self, x):
        """Return the 256-channel private feature map for image batch ``x``."""
        return self.encoder(x)


In [None]:
# Reconstruction Decoder (D) - reconstructs image from shared + private features
class ReconstructionDecoder(nn.Module):
    """Decoder mapping concatenated shared + private features back to an image.

    Expects 512 input channels (the two feature maps stacked on the channel
    axis). Three transposed-conv -> batch-norm -> ReLU stages reduce the
    channels to 256, 128 and 64, with a x2 upsample after each of the first
    two, followed by a final transposed conv and ``Tanh``.

    Args:
        output_channels: Number of channels in the reconstructed image.
    """

    def __init__(self, output_channels=1):
        super().__init__()
        widths = [512, 256, 128, 64]

        layers = []
        for position in range(len(widths) - 1):
            layers.extend([
                nn.ConvTranspose2d(widths[position], widths[position + 1], kernel_size=5, padding=2),
                nn.BatchNorm2d(widths[position + 1]),
                nn.ReLU(),
            ])
            # Upsample after the first two stages only
            if position < 2:
                layers.append(nn.Upsample(scale_factor=2))

        # Output head: project to image channels and squash with Tanh
        layers.append(nn.ConvTranspose2d(widths[-1], output_channels, kernel_size=5, padding=2))
        layers.append(nn.Tanh())

        self.decoder = nn.Sequential(*layers)

    def forward(self, shared_feat, private_feat):
        """Stack both feature maps along the channel axis and decode them."""
        stacked = torch.cat((shared_feat, private_feat), dim=1)
        return self.decoder(stacked)


In [None]:
# Task Classifier (C) - classifies using shared features
class TaskClassifier(nn.Module):
    """Classification head that operates on the shared feature map.

    The 256-channel input is adaptively average-pooled to 4x4, flattened,
    and passed through Linear(4096, 100) -> BatchNorm -> ReLU -> Dropout(0.5)
    -> Linear(100, num_classes).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        pooled_side = 4
        hidden_units = 100
        flat_features = 256 * pooled_side * pooled_side

        head = [
            nn.AdaptiveAvgPool2d((pooled_side, pooled_side)),
            nn.Flatten(),
            nn.Linear(flat_features, hidden_units),
            nn.BatchNorm1d(hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, shared_feat):
        """Return class logits for a batch of shared feature maps."""
        logits = self.classifier(shared_feat)
        return logits


In [None]:
# Complete DSN Model
class DSN(nn.Module):
    """Domain Separation Network: two encoders, a decoder and a classifier.

    Args:
        num_classes: Number of logits produced by the task classifier.
        input_channels: Channels of the images fed to both encoders.
        output_channels: Channels of the reconstructed image.
    """

    def __init__(self, num_classes=10, input_channels=1, output_channels=1):
        super().__init__()
        # Both encoders read the same input image
        self.shared_encoder = SharedEncoder(input_channels)
        self.private_encoder = PrivateEncoder(input_channels)
        # The decoder consumes both feature maps; the classifier only the shared one
        self.decoder = ReconstructionDecoder(output_channels)
        self.classifier = TaskClassifier(num_classes)

    def forward(self, x):
        """Encode ``x`` and return ``(shared, private, reconstruction, logits)``."""
        shared_feat, private_feat = self.shared_encoder(x), self.private_encoder(x)
        return (
            shared_feat,
            private_feat,
            self.decoder(shared_feat, private_feat),  # image rebuilt from both maps
            self.classifier(shared_feat),             # logits from shared map only
        )


In [None]:
# Loss functions
def similarity_loss(shared_s, shared_t):
    """Mean squared error between source and target shared features.

    A smaller value means the two inputs are closer, so this is a quantity
    to minimise when the features should match.
    """
    squared_distance = F.mse_loss(shared_s, shared_t, reduction='mean')
    return squared_distance

def difference_loss(shared, private):
    """Penalise alignment between shared and private features.

    Each sample's features are flattened and L2-normalised, and the loss is
    the batch mean of the *squared* cosine similarity between the two.

    The square matters: the signed cosine has its minimum at -1, so
    minimising it would push the private features towards the negation of
    the shared ones, which is still perfectly correlated. The squared
    cosine is minimised at 0, i.e. when the two are orthogonal.

    Args:
        shared: Shared features, batch on dim 0.
        private: Private features with the same batch size and the same
            number of elements per sample as ``shared``.

    Returns:
        Scalar tensor in ``[0, 1]``; 0 when every sample pair is orthogonal.
    """
    # reshape (rather than view) so non-contiguous inputs are accepted too
    shared_flat = shared.reshape(shared.size(0), -1)
    private_flat = private.reshape(private.size(0), -1)

    # Unit-length rows make the row-wise dot product a cosine similarity
    shared_unit = F.normalize(shared_flat, p=2, dim=1)
    private_unit = F.normalize(private_flat, p=2, dim=1)
    cosine_sim = (shared_unit * private_unit).sum(dim=1)

    # Square before averaging so aligned and anti-aligned pairs are both penalised
    return cosine_sim.pow(2).mean()


In [None]:
# Prepare datasets
# Preprocessing applied to both domains: resize to 32x32, convert to a
# tensor, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Settings shared by both loaders
loader_options = dict(batch_size=batch_size, shuffle=True)

# Source domain: MNIST
source_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
source_loader = DataLoader(source_dataset, **loader_options)

# Target domain: MNIST test set (for demonstration)
target_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
target_loader = DataLoader(target_dataset, **loader_options)

print(f'Source dataset size: {len(source_dataset)}')
print(f'Target dataset size: {len(target_dataset)}')


In [None]:
# Initialize model
# Build the network first, then move its parameters to the selected device
model = DSN(num_classes=10, input_channels=1, output_channels=1)
model = model.to(device)

# Optimizer: SGD with momentum over every parameter of the model
optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)

# Loss functions: cross-entropy for the labels, MSE for the reconstructions
criterion_class = nn.CrossEntropyLoss()
criterion_recon = nn.MSELoss()

print(model)


In [None]:
# Training function
def train_epoch(model, source_loader, target_loader, optimizer, epoch, num_epochs):
    """Train ``model`` for one epoch on paired source/target batches.

    Each step draws one labelled source batch and one target batch (target
    labels are ignored), runs both through the model and minimises

        class_loss + lambda_recon * recon_loss
                   + lambda_sim * sim_loss + lambda_diff * diff_loss

    ``sim_loss`` is a distance between the batch-mean shared features of the
    two domains, so it is *added*: minimising the total pulls the domains
    together.

    Reads the module-level ``device``, ``criterion_class``,
    ``criterion_recon``, ``similarity_loss``, ``difference_loss`` and the
    ``lambda_*`` weights.

    Args:
        model: Network returning ``(shared, private, reconstruction, logits)``.
        source_loader: Loader yielding ``(images, labels)`` for the source domain.
        target_loader: Loader yielding ``(images, _)`` for the target domain.
        optimizer: Optimizer over the model's parameters.
        epoch: Zero-based epoch index, used only in progress messages.
        num_epochs: Total number of epochs, used only in progress messages.

    Returns:
        Tuple ``(avg_class_loss, avg_recon_loss, avg_sim_loss, avg_diff_loss,
        accuracy)`` where the losses are per-batch averages and ``accuracy``
        is the source classification accuracy in percent.

    Raises:
        ValueError: If either loader yields no batches.
    """
    model.train()

    # One step per batch of the shorter loader
    min_len = min(len(source_loader), len(target_loader))
    if min_len == 0:
        raise ValueError('train_epoch needs at least one batch from each loader')

    total_class_loss = 0.0
    total_recon_loss = 0.0
    total_sim_loss = 0.0
    total_diff_loss = 0.0
    correct_class = 0
    total_samples = 0

    # range() comes first so zip stops after min_len steps without pulling
    # an extra batch from either loader.
    paired_batches = zip(range(min_len), source_loader, target_loader)

    for batch_idx, (source_data, source_labels), (target_data, _) in paired_batches:
        # Move to device
        source_data = source_data.to(device)
        source_labels = source_labels.to(device)
        target_data = target_data.to(device)

        # Forward pass on both domains
        shared_s, private_s, recon_s, class_s = model(source_data)
        shared_t, private_t, recon_t, _ = model(target_data)

        # Classification loss (only the source domain has labels)
        class_loss = criterion_class(class_s, source_labels)

        # Reconstruction loss, averaged over both domains
        recon_loss = (criterion_recon(recon_s, source_data) +
                      criterion_recon(recon_t, target_data)) / 2

        # Similarity loss on batch-mean shared features; averaging over the
        # batch first lets the two batches have different sizes.
        sim_loss = similarity_loss(
            shared_s.view(shared_s.size(0), -1).mean(dim=0),
            shared_t.view(shared_t.size(0), -1).mean(dim=0)
        )

        # Difference loss between shared and private features, per domain
        diff_loss_s = difference_loss(shared_s, private_s)
        diff_loss_t = difference_loss(shared_t, private_t)
        diff_loss = (diff_loss_s + diff_loss_t) / 2

        # Total loss: every term is something to minimise, so all are added.
        # Subtracting sim_loss would reward pushing the domains apart.
        loss = (class_loss +
                lambda_recon * recon_loss +
                lambda_sim * sim_loss +
                lambda_diff * diff_loss)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics
        total_class_loss += class_loss.item()
        total_recon_loss += recon_loss.item()
        total_sim_loss += sim_loss.item()
        total_diff_loss += diff_loss.item()
        _, predicted = class_s.max(1)
        correct_class += predicted.eq(source_labels).sum().item()
        total_samples += source_labels.size(0)

        if (batch_idx + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{min_len}], '
                  f'Class: {class_loss.item():.4f}, Recon: {recon_loss.item():.4f}, '
                  f'Sim: {sim_loss.item():.4f}, Diff: {diff_loss.item():.4f}, '
                  f'Acc: {100.*correct_class/total_samples:.2f}%')

    avg_class_loss = total_class_loss / min_len
    avg_recon_loss = total_recon_loss / min_len
    avg_sim_loss = total_sim_loss / min_len
    avg_diff_loss = total_diff_loss / min_len
    accuracy = 100. * correct_class / total_samples

    return avg_class_loss, avg_recon_loss, avg_sim_loss, avg_diff_loss, accuracy


In [None]:
# Training loop
# Per-epoch history of the classification loss and the source accuracy
train_losses, train_accs = [], []

for epoch in range(num_epochs):
    epoch_stats = train_epoch(model, source_loader, target_loader, optimizer, epoch, num_epochs)
    class_loss, recon_loss, sim_loss, diff_loss, accuracy = epoch_stats

    # Only these two series are plotted later
    train_losses.append(class_loss)
    train_accs.append(accuracy)

    print(f'Epoch [{epoch+1}/{num_epochs}] - Class: {class_loss:.4f}, '
          f'Recon: {recon_loss:.4f}, Sim: {sim_loss:.4f}, Diff: {diff_loss:.4f}, '
          f'Accuracy: {accuracy:.2f}%')
    print('-' * 60)


In [None]:
# Plot training curves
# Two side-by-side panels: classification loss (left) and accuracy (right)
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_losses, label='Classification Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss')
loss_ax.legend()
loss_ax.grid(True)

acc_ax.plot(train_accs, label='Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.set_title('Training Accuracy')
acc_ax.legend()
acc_ax.grid(True)

fig.tight_layout()
plt.show()


In [None]:
# Evaluation function
def evaluate(model, data_loader, domain_name='Target'):
    """Compute and print the classification accuracy of ``model`` on a loader.

    The model is put in eval mode and run without gradient tracking. Batches
    are moved to the module-level ``device``.

    Args:
        model: Network returning ``(shared, private, reconstruction, logits)``.
        data_loader: Loader yielding ``(images, labels)`` batches.
        domain_name: Label used in the printed message.

    Returns:
        Accuracy in percent.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch, targets in data_loader:
            batch = batch.to(device)
            targets = targets.to(device)
            # Only the logits (fourth output) are needed here
            _shared, _private, _recon, logits = model(batch)
            guesses = logits.max(1).indices
            n_seen += targets.size(0)
            n_correct += predicted_matches(guesses, targets) if False else guesses.eq(targets).sum().item()

    accuracy = 100. * n_correct / n_seen
    print(f'{domain_name} Domain Accuracy: {accuracy:.2f}%')
    return accuracy

# Evaluate on source and target
source_acc = evaluate(model, source_loader, domain_name='Source')
target_acc = evaluate(model, target_loader, domain_name='Target')


In [None]:
# Visualize reconstructions
def visualize_reconstructions(model, data_loader, num_samples=8):
    """Plot original images above their reconstructions for one batch.

    Takes the first batch from ``data_loader``, keeps at most ``num_samples``
    images, runs them through ``model`` in eval mode and shows a 2-row grid:
    originals on top, reconstructions below. Images are mapped from
    ``[-1, 1]`` back to ``[0, 1]`` for display. Batches are moved to the
    module-level ``device``.

    Args:
        model: Network returning ``(shared, private, reconstruction, logits)``.
        data_loader: Loader yielding ``(images, labels)`` batches.
        num_samples: Maximum number of images to show; fewer are shown when
            the batch is smaller.

    Raises:
        ValueError: If there are no images to show.
    """
    model.eval()
    with torch.no_grad():
        data, _ = next(iter(data_loader))
        data = data[:num_samples].to(device)
        _, _, recon, _ = model(data)

    # The batch may hold fewer images than requested
    num_shown = data.size(0)
    if num_shown == 0:
        raise ValueError('no images to visualize: empty batch or num_samples < 1')

    # Denormalize for visualization
    data_vis = ((data + 1) / 2).cpu()
    recon_vis = ((recon + 1) / 2).cpu()

    # squeeze=False keeps axes two-dimensional even when a single column is
    # drawn, so axes[row, col] indexing is always valid.
    fig, axes = plt.subplots(2, num_shown, figsize=(15, 4), squeeze=False)
    for i in range(num_shown):
        axes[0, i].imshow(data_vis[i].squeeze(), cmap='gray')
        axes[0, i].axis('off')
        axes[0, i].set_title('Original')

        axes[1, i].imshow(recon_vis[i].squeeze(), cmap='gray')
        axes[1, i].axis('off')
        axes[1, i].set_title('Reconstructed')

    plt.tight_layout()
    plt.show()

visualize_reconstructions(model, source_loader)
