In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score
import numpy as np
from PIL import Image
import os
from torchvision.models import mobilenet_v2

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
num_epochs = 50
batch_size = 32
learning_rate = 0.001
num_classes = 3  # Adjust based on your dataset

# Custom Dataset
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.images = []
        self.labels = []
        for cls in self.classes:
            class_dir = os.path.join(root_dir, cls)
            for img_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, img_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Load the pre-trained feature extractor


In [2]:
class UnlabelledDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(root_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image


In [3]:
# Data transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Path to the unlabelled dataset directory

# Define transformations for the unlabelled dataset
transform_unlabelled = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load datasets
source_dataset = CustomDataset(root_dir='data/synthetic/cifar10', transform=transform)
target_dataset = UnlabelledDataset(root_dir='data/real/unlabelled', transform=transform_unlabelled)


source_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True)
target_loader = DataLoader(target_dataset, batch_size=batch_size, shuffle=True)



In [4]:
class FeatureExtractor(nn.Module):
    def __init__(self, num_classes):
        super(FeatureExtractor, self).__init__()
        # Load pre-trained MobileNetV2
        mobilenet = mobilenet_v2(pretrained=True)
        
        # Freeze all parameters
        for param in mobilenet.parameters():
            param.requires_grad = False
        
        # Use all layers except the last classifier
        self.features = mobilenet.features
        
        # Add a simple classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1280, num_classes)  # MobileNetV2's last conv layer has 1280 channels
        )

    def forward(self, x):
        x = self.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
    
def load_model(model, path, device):
    model.load_state_dict(torch.load(path, map_location=device))
    print(f"Model loaded from {path}")
    return model

# Load the saved model
feature_extractor = FeatureExtractor(num_classes=3).to(device)
feature_extractor = load_model(feature_extractor, 'model.pth', device)
feature_extractor.classifier = nn.Identity()  # Remove the classifier part





Model loaded from model.pth


  model.load_state_dict(torch.load(path, map_location=device))


In [5]:
class DANN(nn.Module):
    def __init__(self, feature_extractor, num_classes):
        super(DANN, self).__init__()
        self.feature_extractor = feature_extractor
        self.class_classifier = nn.Linear(1280, num_classes)
        self.domain_classifier = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

    def forward(self, x, alpha):
        features = self.feature_extractor(x)
        class_output = self.class_classifier(features)
        reverse_features = GradientReversalLayer.apply(features, alpha)
        domain_output = self.domain_classifier(reverse_features)
        return class_output, domain_output


class GradientReversalLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None



In [6]:
def ood_filter(model, data_loader, threshold=0.8):
    model.eval()
    filtered_data = []
    filtered_labels = []
    with torch.no_grad():
        for inputs in data_loader:
            inputs = inputs.to(device)
            outputs, _ = model(inputs, alpha=0)
            probabilities = torch.softmax(outputs, dim=1)
            max_probs, predicted_labels = torch.max(probabilities, dim=1)
            for i, prob in enumerate(max_probs):
                if prob > threshold:
                    filtered_data.append(inputs[i].cpu())
                    filtered_labels.append(predicted_labels[i].item())
    return filtered_data, filtered_labels


In [7]:
def train_dann(model, source_loader, target_loader, optimizer, num_epochs):
    global source_dataset  # Declare source_dataset as global
    for epoch in range(num_epochs):      
        model.train()
        for (source_data, source_labels), target_data in zip(source_loader, target_loader):
            source_data, source_labels = source_data.to(device), source_labels.to(device)
            target_data = target_data.to(device)  # Now target_data is already a tensor

            # Create domain labels
            source_domain = torch.zeros(source_data.size(0)).long().to(device)
            target_domain = torch.ones(target_data.size(0)).long().to(device)

            # Forward pass
            alpha = 2 / (1 + np.exp(-10 * epoch / num_epochs)) - 1
            source_class_output, source_domain_output = model(source_data, alpha)
            _, target_domain_output = model(target_data, alpha)

            # Compute losses
            class_loss = nn.CrossEntropyLoss()(source_class_output, source_labels)
            domain_loss = nn.CrossEntropyLoss()(source_domain_output, source_domain) + \
                          nn.CrossEntropyLoss()(target_domain_output, target_domain)
            
            loss = class_loss + 0.1 * domain_loss

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Class Loss: {class_loss.item():.4f}, Domain Loss: {domain_loss.item():.4f}')

        # OOD Filtering and Pseudolabeling (every 5 epochs)
        if (epoch + 1) % 5 == 0:
            filtered_data, filtered_labels = ood_filter(model, target_loader)
            print(f'Filtered samples: {len(filtered_data)}')
            
            # Add filtered data to source dataset (pseudolabeling)
            for data, label in zip(filtered_data, filtered_labels):
                img = transforms.ToPILImage()(data)
                source_dataset.images.append(img)
                source_dataset.labels.append(label)
            
            # Recreate source dataloader with updated dataset
            source_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True)



In [None]:
# Evaluation function
def evaluate(model, data_loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            outputs, _ = model(inputs, alpha=0)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    return accuracy_score(all_labels, all_preds)

# Initialize model, optimizer
model = DANN(feature_extractor, num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
train_dann(model, source_loader, target_loader, optimizer, num_epochs)

# Evaluate the model
source_accuracy = evaluate(model, source_loader)
target_accuracy = evaluate(model, target_loader)

print(f'Source Accuracy: {source_accuracy:.4f}')
print(f'Target Accuracy: {target_accuracy:.4f}')


Epoch [1/50], Loss: 0.6060, Class Loss: 0.4661, Domain Loss: 1.3997
