# Selective Alignment Transfer for Domain Adaptation in Skin Lesion Analysis

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, models
import torchvision.transforms as transforms

In [None]:
## Transformations
## The ImageNet-pretrained EfficientNet-B2 weights expect inputs normalized with the
## ImageNet channel statistics; without this step the pretrained features see a
## different input distribution than the one they were trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((288, 288)),  ## EfficientNet-B2 input resolution
    transforms.ToTensor(),          ## PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:

## Image folders: one per (split, modality) pair, all sharing the same preprocessing
train_dermoscopic = datasets.ImageFolder(root="/train/Derm", transform=transform)
train_clinical = datasets.ImageFolder(root="/train/Clinic", transform=transform)
val_dermoscopic = datasets.ImageFolder(root="/valid/Derm", transform=transform)
val_clinical = datasets.ImageFolder(root="/valid/Clinic", transform=transform)


## Mini-batch loaders: training data is shuffled, validation order stays fixed
batch_size = 32
_train_loader_options = {"batch_size": batch_size, "shuffle": True}
_val_loader_options = {"batch_size": batch_size, "shuffle": False}

train_dermos_loader = DataLoader(train_dermoscopic, **_train_loader_options)
train_clinical_loader = DataLoader(train_clinical, **_train_loader_options)
val_dermos_loader = DataLoader(val_dermoscopic, **_val_loader_options)
val_clinical_loader = DataLoader(val_clinical, **_val_loader_options)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import time


## Feature Selector part

class FeatureSelector(nn.Module):
    """Map a pair of domain feature vectors to per-feature weights in (0, 1).

    The two inputs are concatenated along the feature axis and passed through a
    three-layer MLP (2 * feature_dim -> 512 -> 256 -> feature_dim) that ends in
    a sigmoid.
    """

    def __init__(self, feature_dim):
        super().__init__()
        self.feature_dim = feature_dim

        # Build the MLP layer by layer; indices inside the Sequential (and therefore
        # the state_dict keys) are 0, 2, 4 for the Linear layers
        layers = []
        in_features = feature_dim * 2
        for out_features in (512, 256):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
            in_features = out_features
        layers.append(nn.Linear(in_features, feature_dim))
        layers.append(nn.Sigmoid())
        self.network = nn.Sequential(*layers)

    def forward(self, source_features, target_features):
        """Return importance weights for the given source/target features.

        Args:
            source_features: 1-D tensor of length feature_dim, or a 2-D batch of them.
            target_features: same layout as ``source_features``.

        Returns:
            The sigmoid outputs of the MLP. A leading dimension of size 1 is
            squeezed away, so 1-D inputs yield a 1-D result.
        """
        # Promote bare vectors to a batch of one so they can be joined on dim=1
        source_features, target_features = (
            feats.unsqueeze(0) if feats.dim() == 1 else feats
            for feats in (source_features, target_features)
        )

        joined = torch.cat((source_features, target_features), dim=1)
        return self.network(joined).squeeze(0)

class DomainAdaptiveNet(nn.Module):
    """Backbone plus a learned per-channel gate shared by two image domains.

    ``base_model`` must expose ``.features`` (a module returning a 4-D
    ``(batch, channels, H, W)`` map) and ``.classifier`` (applied to the pooled
    ``(batch, channels)`` features). ``feature_dim`` must equal the number of
    channels produced by ``.features``.

    During training with both domains the gate is predicted by a
    ``FeatureSelector`` from the batch-mean features of each domain. The most
    recent gate is kept in the ``last_feature_weights`` buffer and reused for
    single-input inference.
    """

    def __init__(self, base_model, feature_dim):
        super().__init__()
        self.features = base_model.features
        self.feature_selector = FeatureSelector(feature_dim)
        self.classifier = base_model.classifier
        # Registered as a buffer so the gate is written to state_dict (a reloaded
        # checkpoint predicts with the gate it was validated with) and follows
        # .to(device). All-ones means "no gating" until a training step sets it.
        self.register_buffer("last_feature_weights", torch.ones(feature_dim))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Checkpoints written before the gate became a buffer lack this key; reuse
        # the current buffer value so strict loading of those files keeps working
        gate_key = prefix + "last_feature_weights"
        if gate_key not in state_dict:
            state_dict[gate_key] = self.last_feature_weights.detach().clone()
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

    def forward(self, source_images, target_images=None):
        """Classify images, learning the channel gate when both domains are given.

        Args:
            source_images: image batch for the source domain.
            target_images: optional image batch for the target domain. Its batch
                size may differ from the source batch size.

        Returns:
            In training mode with ``target_images``: a tuple
            ``(source_output, target_output, feature_weights)``. Otherwise only
            the classifier output for ``source_images``, gated with the stored
            ``last_feature_weights``.
        """
        source_features = self.features(source_images)
        # Spatial mean per channel -> (batch, channels)
        source_features_flat = source_features.flatten(2).mean(dim=2)

        if self.training and target_images is not None:
            target_features = self.features(target_images)
            # Reshape from the target tensor itself: reusing the source batch size
            # fails (or silently mis-reshapes) when the two batches differ in size
            target_features_flat = target_features.flatten(2).mean(dim=2)

            # One mean feature vector per domain drives the selector
            source_mean = source_features_flat.mean(dim=0)
            target_mean = target_features_flat.mean(dim=0)
            feature_weights = self.feature_selector(source_mean, target_mean)

            # The same gate is applied channel-wise to both domains
            gate = feature_weights.view(1, -1, 1, 1)
            source_pooled = F.adaptive_avg_pool2d(source_features * gate, (1, 1)).flatten(1)
            target_pooled = F.adaptive_avg_pool2d(target_features * gate, (1, 1)).flatten(1)

            source_output = self.classifier(source_pooled)
            target_output = self.classifier(target_pooled)

            # Remember the gate for inference (assignment updates the buffer)
            self.last_feature_weights = feature_weights.detach()

            # Un-gated pooled features, exposed for an external alignment loss
            self.source_features_flat = source_features_flat
            self.target_features_flat = target_features_flat

            return source_output, target_output, feature_weights

        # Inference / single-domain path: reuse the most recently stored gate
        gate = self.last_feature_weights.view(1, -1, 1, 1)
        source_pooled = F.adaptive_avg_pool2d(source_features * gate, (1, 1)).flatten(1)
        return self.classifier(source_pooled)

def _count_correct(model, loader, device):
    """Run single-input inference over ``loader`` and return ``(correct, seen)``."""
    correct = 0
    seen = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        predictions = model(images).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        seen += labels.size(0)
    return correct, seen


def _percentage(numerator, denominator):
    """Return ``100 * numerator / denominator``, or 0.0 for an empty denominator."""
    return 100 * numerator / denominator if denominator else 0.0


def train_model(model, train_clinical_loader, train_dermos_loader, val_clinical_loader, val_dermos_loader,
                criterion, optimizer, num_epochs=50, save_path='./SAT-DA.pth'):
    """Jointly train ``model`` on paired source/target batches and validate each epoch.

    The clinical loaders are treated as the source domain and the dermoscopic
    loaders as the target domain. Each step minimizes

        source_loss + target_loss + 0.1 * alignment_loss + 0.01 * diversity_loss

    where the alignment loss is the MSE between the weighted batch-mean features
    of both domains and the diversity loss is the negative standard deviation of
    the feature weights.

    Args:
        model: module called as ``model(source, target)`` in training mode
            (returning ``(source_output, target_output, feature_weights)`` and
            setting ``source_features_flat`` / ``target_features_flat``) and as
            ``model(images)`` in eval mode.
        train_clinical_loader: iterable of ``(images, labels)`` source batches.
        train_dermos_loader: iterable of ``(images, labels)`` target batches.
        val_clinical_loader: source-domain validation batches.
        val_dermos_loader: target-domain validation batches.
        criterion: classification loss applied to each domain's output.
        optimizer: optimizer stepping the model parameters.
        num_epochs: number of passes over the paired training loaders.
        save_path: where the state_dict is written whenever the mean of the two
            validation accuracies improves.
    """
    device = next(model.parameters()).device
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        ## Training phase
        model.train()
        total_loss = 0.0
        num_batches = 0
        correct_source = 0
        correct_target = 0
        # The two domains are counted separately: their batches need not be equally sized
        total_source = 0
        total_target = 0

        # zip stops at the shorter loader, so surplus batches of the longer one are skipped
        for (source_images, source_labels), (target_images, target_labels) in zip(
                train_clinical_loader, train_dermos_loader):

            source_images, source_labels = source_images.to(device), source_labels.to(device)
            target_images, target_labels = target_images.to(device), target_labels.to(device)

            optimizer.zero_grad()

            ## Forward pass
            source_output, target_output, feature_weights = model(source_images, target_images)

            ## Classification losses for both domains
            source_loss = criterion(source_output, source_labels)
            target_loss = criterion(target_output, target_labels)

            ## Feature alignment loss: match the weighted batch-mean features
            weighted_source_features = model.source_features_flat * feature_weights.unsqueeze(0)
            weighted_target_features = model.target_features_flat * feature_weights.unsqueeze(0)
            alignment_loss = F.mse_loss(weighted_source_features.mean(dim=0),
                                        weighted_target_features.mean(dim=0))

            ## Negative std rewards weights that differ from one another
            diversity_loss = -torch.std(feature_weights)

            loss = source_loss + target_loss + 0.1 * alignment_loss + 0.01 * diversity_loss

            loss.backward()
            optimizer.step()

            ### Running accuracy statistics
            predicted_source = source_output.argmax(dim=1)
            predicted_target = target_output.argmax(dim=1)
            total_source += source_labels.size(0)
            total_target += target_labels.size(0)
            correct_source += (predicted_source == source_labels).sum().item()
            correct_target += (predicted_target == target_labels).sum().item()

            total_loss += loss.item()
            num_batches += 1

        ## Validation phase (each domain keeps its own sample count)
        model.eval()
        with torch.no_grad():
            val_correct_source, val_total_source = _count_correct(model, val_clinical_loader, device)
            val_correct_target, val_total_target = _count_correct(model, val_dermos_loader, device)

        ### Calculating accuracies
        train_source_acc = _percentage(correct_source, total_source)
        train_target_acc = _percentage(correct_target, total_target)
        val_source_acc = _percentage(val_correct_source, val_total_source)
        val_target_acc = _percentage(val_correct_target, val_total_target)

        ## Average validation accuracy
        val_avg_acc = (val_source_acc + val_target_acc) / 2

        ## Simple model saving
        if val_avg_acc > best_val_acc:
            best_val_acc = val_avg_acc
            torch.save(model.state_dict(), save_path)
            print(f"Best model saved with validation accuracy: {val_avg_acc:.2f}%")

        ### Printing statistics
        # Average over the batches actually processed, not over one loader's length
        avg_loss = total_loss / num_batches if num_batches else 0.0
        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Loss: {avg_loss:.4f}')
        print(f'Train - Source Acc: {train_source_acc:.2f}%, Target Acc: {train_target_acc:.2f}%')
        print(f'Val - Source Acc: {val_source_acc:.2f}%, Target Acc: {val_target_acc:.2f}%')

        # The attribute may exist but still be None if no joint step has run yet
        weights = getattr(model, 'last_feature_weights', None)
        if weights is not None:
            print(f'Feature weights - Mean: {weights.mean().item():.4f}, '
                  f'Std: {weights.std().item():.4f}, '
                  f'Max: {weights.max().item():.4f}, '
                  f'Min: {weights.min().item():.4f}')

# Initialize

In [None]:
### Initialize model
base_model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)

## Both domains feed one shared classifier, so their ImageFolder class lists must
## match; otherwise the same label index would mean different classes per domain
if train_clinical.classes != train_dermoscopic.classes:
    raise ValueError("Clinical and dermoscopic training folders must contain the same classes")
num_classes = len(train_clinical.classes)

## Read the backbone width from the pretrained head instead of hard-coding it,
## then swap the 1000-way ImageNet head for one sized to this dataset
feature_dim = base_model.classifier[-1].in_features  ## 1408 for EfficientNet-B2
base_model.classifier[-1] = nn.Linear(feature_dim, num_classes)

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
model = DomainAdaptiveNet(base_model, feature_dim).to(device)

## Train model with validation
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
train_model(model, train_clinical_loader, train_dermos_loader,
            val_clinical_loader, val_dermos_loader,
            criterion, optimizer)