**Partial Domain Adaptation**

Partial Domain Adaptation (PDA) is a domain adaptation scenario in which the target domain's label space is a subset of the source domain's label space. Unlike standard (closed-set) domain adaptation, which assumes both domains share the same classes, PDA must cope with negative transfer caused by source classes that have no counterpart in the target domain.

**Class Conditional Alignment (CCA-PDA)**

CCA is a partial domain adaptation method that directly tackles the class mismatch between the source and target domains by aligning their features class by class.


It uses a **multi-class adversarial loss** to perform this alignment, ensuring that only the shared classes between source and target are emphasized. This helps avoid **negative transfer** from source-only classes.
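
As an illustration of how shared classes can be emphasized, one common heuristic in the PDA literature is to average the classifier's softmax outputs over the target samples: classes that never occur in the target should collect little probability mass and can be down-weighted. This is a minimal sketch of that idea and not necessarily the exact weighting used in this notebook; the function name `estimate_class_weights` and the toy probabilities are assumptions.

```python
def estimate_class_weights(target_probs):
    """Average per-class probabilities over target samples, scaled so the largest weight is 1."""
    num_classes = len(target_probs[0])
    mean = [sum(p[c] for p in target_probs) / len(target_probs) for c in range(num_classes)]
    top = max(mean)
    return [m / top for m in mean]

# Toy softmax outputs for 4 target samples over 3 source classes.
# Class 2 plays the role of a source-only class (assumption for illustration).
target_probs = [
    [0.70, 0.25, 0.05],
    [0.60, 0.35, 0.05],
    [0.20, 0.75, 0.05],
    [0.30, 0.65, 0.05],
]
weights = estimate_class_weights(target_probs)
print([round(w, 3) for w in weights])
```

In this toy case class 2 ends up with the smallest weight, so its source samples would contribute least to any weighted alignment loss.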

**Imports**

In [1]:
import os
import kagglehub

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.datasets import ImageFolder

import PIL.Image as Image


import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
# Setup device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

**Datasets**

Using **Caltech-256** as the source and **Office-31** as the target gives a partial domain adaptation (PDA) setup: Caltech has the broader label space (256 classes), while Office-31 has only 31. The Caltech classes that do not appear in the target therefore need to be filtered out to avoid negative transfer.
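
The filtering step can be sketched as follows. `ImageFolder` exposes its data as a list of `(path, class_index)` pairs (`samples`) and a list of folder names (`classes`), so filtering amounts to keeping the pairs whose folder name is in a hand-written mapping and renumbering the labels. The `SHARED` mapping, the folder names, and the helper `keep_shared` are hypothetical and should be checked against the folders actually present on disk.

```python
# Hypothetical mapping from a source folder name to the matching target class name.
SHARED = {"003.backpack": "back_pack", "027.calculator": "calculator"}

def keep_shared(samples, class_names, shared=SHARED):
    """Filter (path, class_index) pairs to shared classes and remap indices to 0..K-1."""
    kept_names = sorted(n for n in class_names if n in shared)
    new_index = {n: i for i, n in enumerate(kept_names)}
    filtered = [
        (path, new_index[class_names[idx]])
        for path, idx in samples
        if class_names[idx] in new_index
    ]
    return filtered, kept_names

# Toy stand-ins for ImageFolder.classes and ImageFolder.samples.
class_names = ["001.ak47", "003.backpack", "027.calculator"]
samples = [("a.jpg", 0), ("b.jpg", 1), ("c.jpg", 2), ("d.jpg", 1)]

filtered, kept_names = keep_shared(samples, class_names)
print(kept_names)
print(filtered)
```

An explicit mapping is used instead of string matching because class names rarely line up exactly across datasets.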

Download the Caltech-256 source dataset from Kaggle.

In [3]:
# Download latest version
source_path = kagglehub.dataset_download("jessicali9530/caltech256")

print("Path to dataset files:", source_path)


Downloading from https://www.kaggle.com/api/v1/datasets/download/jessicali9530/caltech256?dataset_version_number=2...


100%|██████████| 2.12G/2.12G [01:40<00:00, 22.6MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/jessicali9530/caltech256/versions/2


Unzip the Office-31 archive stored on Google Drive (originally downloaded from mega.nz).

In [4]:
# Office-31 was downloaded from:
# https://mega.nz/file/dSpjyCwR#9ctB4q1RIE65a4NoJy0ox3gngh15cJqKq1XpOILJt9s

!unzip -q "/content/drive/MyDrive/office-31/Original_images1.zip" -d /content/office31

In [5]:
# Load dataset
dataset_path = "/content/office31"

target_amazon_path = os.path.join(dataset_path, "Original_images/amazon")
target_dslr_path = os.path.join(dataset_path, "Original_images/dslr")
target_webcam_path = os.path.join(dataset_path, "Original_images/webcam")


**Data Processing**

In [6]:
# Define preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
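
To make the last step concrete: `ToTensor` scales 8-bit pixel values to the range [0, 1], and `Normalize` then computes `(value - mean) / std` independently per channel. The mean and standard deviation used above are the commonly used ImageNet channel statistics. The helper `normalize_pixel` below is a hypothetical name used only to show the arithmetic on a single pixel.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Apply (value - mean) / std per channel to one RGB pixel with values in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]

white = normalize_pixel([1.0, 1.0, 1.0])   # brightest possible pixel
mean_pixel = normalize_pixel(MEAN)         # a pixel equal to the channel means maps to zero
print(white, mean_pixel)
```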

In [7]:
def walk_through_dir(dir_path):

  for dirpath, dirnames, filenames in os.walk(dir_path):
    print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")

In [8]:
# Load source dataset (the full Caltech-256 download; no class filtering is applied at this point)
source_dataset = datasets.ImageFolder(root=source_path, transform=transform)
source_loader = DataLoader(source_dataset, batch_size=64, shuffle=True)

# Load target dataset (Office-31, Amazon domain)
target_dataset = datasets.ImageFolder(root=target_amazon_path, transform=transform)
target_loader = DataLoader(target_dataset, batch_size=64, shuffle=True)

**Feature extraction**

In [9]:
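# Load a ResNet-50 pretrained on ImageNet.
# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the checkpoint that `pretrained=True` loaded.
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet50.fc = torch.nn.Identity()  # Replace the final classification layer so the model outputs 2048-d features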

resnet50 = resnet50.to(device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 225MB/s]


In [10]:
def extract_and_save_features(resnet, loader, device, save_dir, prefix="features", return_labels=True):

    os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists

    all_features, all_labels = [], []
    resnet.eval()

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            features = resnet(images).view(images.size(0), -1).cpu()
            all_features.append(features)
            if return_labels:
                all_labels.append(labels)

    features_tensor = torch.cat(all_features)
    torch.save(features_tensor, os.path.join(save_dir, f"{prefix}_features.pt"))
    print(f" Features saved to {os.path.join(save_dir, f'{prefix}_features.pt')}")

    if return_labels:
        labels_tensor = torch.cat(all_labels)
        torch.save(labels_tensor, os.path.join(save_dir, f"{prefix}_labels.pt"))
        print(f" Labels saved to {os.path.join(save_dir, f'{prefix}_labels.pt')}")

    return (features_tensor, labels_tensor) if return_labels else features_tensor


In [11]:
save_path = '/content/drive/MyDrive/saved_resnet50_features'

In [12]:
# For unlabeled target data
target_features = extract_and_save_features(resnet50, target_loader, device,
                          save_dir=save_path, prefix="target", return_labels=False)
print("Extracted Office Features Shape:", target_features.shape)

 Features saved to /content/drive/MyDrive/saved_resnet50_features/target_features.pt
Extracted Office Features Shape: torch.Size([2817, 2048])


In [13]:
# For source features
source_features, source_labels = extract_and_save_features(resnet50, source_loader, device,
                          save_dir=save_path, prefix="source")

print('source_features', source_features.shape)
print('source_labels', source_labels.shape)

 Features saved to /content/drive/MyDrive/saved_resnet50_features/source_features.pt
 Labels saved to /content/drive/MyDrive/saved_resnet50_features/source_labels.pt
source_features torch.Size([61214, 2048])
source_labels torch.Size([61214])


**Train Classifier on Source Domain**

In [30]:
num_source_classes = 10  # assumed number of source classes; should match len(source_dataset.classes)
num_epochs = 100


input_dim = source_features.shape[1]
classifier = nn.Linear(input_dim, num_source_classes)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-4)

# Simple supervised classification on source
for epoch in range(num_epochs):
    logits = classifier(source_features)
    loss = nn.CrossEntropyLoss()(logits, source_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Print loss every 5 epochs
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")



Epoch 5/100, Loss: 1.9428
Epoch 10/100, Loss: 1.4269
Epoch 15/100, Loss: 1.1070
Epoch 20/100, Loss: 0.9357
Epoch 25/100, Loss: 0.8483
Epoch 30/100, Loss: 0.8025
Epoch 35/100, Loss: 0.7770
Epoch 40/100, Loss: 0.7618
Epoch 45/100, Loss: 0.7520
Epoch 50/100, Loss: 0.7452
Epoch 55/100, Loss: 0.7403
Epoch 60/100, Loss: 0.7365
Epoch 65/100, Loss: 0.7335
Epoch 70/100, Loss: 0.7310
Epoch 75/100, Loss: 0.7289
Epoch 80/100, Loss: 0.7270
Epoch 85/100, Loss: 0.7254
Epoch 90/100, Loss: 0.7239
Epoch 95/100, Loss: 0.7225
Epoch 100/100, Loss: 0.7213
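
The loss curve alone does not show how well the classifier separates the source classes or how many samples each class has. A minimal sketch of two extra checks is given below, using toy tensors in place of `classifier(source_features)` and `source_labels`; the helper name `accuracy` is an assumption.

```python
import torch

def accuracy(logits, labels):
    """Fraction of samples whose arg-max prediction equals the label."""
    return (logits.argmax(dim=1) == labels).float().mean().item()

# Toy logits for 3 samples and 2 classes, standing in for classifier(source_features).
logits = torch.tensor([[2.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
labels = torch.tensor([0, 1, 1])

acc = accuracy(logits, labels)                       # 2 of 3 predictions are correct
label_counts = torch.bincount(labels, minlength=2)   # number of samples per class
print(acc, label_counts)
```

Applied to the real tensors, `torch.bincount(source_labels, minlength=num_source_classes)` would show at a glance whether every assumed class actually has samples.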


**Generate Pseudo-Labels on Target**

In [31]:
with torch.no_grad():
    target_logits = classifier(target_features)
    pseudo_labels = torch.argmax(target_logits, dim=1)
    confidences = torch.softmax(target_logits, dim=1).max(dim=1)[0]

    # Optional: Keep only high-confidence predictions
    confidence_threshold = 0.3
    confident_mask = confidences > confidence_threshold
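
The pseudo-labeling rule above can be traced by hand on a few samples. The sketch below reimplements it with the standard library only: softmax over the logits, arg-max as the pseudo-label, and the maximum probability as the confidence that is compared with the threshold. The toy logits and the helper names are assumptions.

```python
import math
from collections import Counter

def softmax(logits):
    """Numerically stable softmax for one list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def pseudo_label(logits, threshold):
    """Return (predicted_class, confidence, keep_flag) for one sample."""
    probs = softmax(logits)
    conf = max(probs)
    return probs.index(conf), conf, conf > threshold

# Three toy samples over three classes; the second one is nearly uniform.
toy_logits = [[2.0, 0.1, 0.0], [0.2, 0.1, 0.0], [0.0, 3.0, 0.5]]
results = [pseudo_label(l, threshold=0.5) for l in toy_logits]
kept_per_class = Counter(c for c, _, keep in results if keep)
print(results, kept_per_class)
```

The nearly uniform second sample has a confidence of about 0.37, so it is dropped at a threshold of 0.5 but would be kept at the 0.3 used in the cell above. Counting the kept samples per class, as `kept_per_class` does here, shows which classes will have target data in the per-class step.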


**Initialize Class-wise Domain Discriminators**

In [32]:
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
    def forward(self, x): return self.net(x)

# One discriminator per class
discriminators = {
    c: Discriminator(input_dim).to(device) for c in range(num_source_classes)
}
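
Each discriminator outputs a single logit, which the training cell scores with `F.binary_cross_entropy_with_logits` using label 1 for source and 0 for target. The sketch below spells out that loss for one sample in plain Python, in the usual numerically stable form; `bce_with_logits` is a hypothetical helper, not the PyTorch function itself.

```python
import math

def bce_with_logits(logit, label):
    """Numerically stable binary cross-entropy for one logit and a 0/1 label."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

chance = bce_with_logits(0.0, 1.0)           # undecided discriminator: ln(2)
confident_right = bce_with_logits(4.0, 1.0)  # strongly predicts "source" for a source sample
confident_wrong = bce_with_logits(4.0, 0.0)  # strongly predicts "source" for a target sample
print(chance, confident_right, confident_wrong)
```

A logit of 0 gives a loss of ln 2 (about 0.693) regardless of the label, which is a useful reference point: a discriminator whose loss stays near that value has not yet learned to tell the two domains apart.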

**Adversarial Training (Per Class)**

In [33]:
for class_idx in range(num_source_classes):
    D = discriminators[class_idx]
    optimizer_D = torch.optim.Adam(D.parameters(), lr=1e-4)

    # Select class-specific features
    src_mask = (source_labels == class_idx)
    tgt_mask = (pseudo_labels == class_idx) & confident_mask

    src_feat = source_features[src_mask]
    tgt_feat = target_features[tgt_mask]

    if src_feat.size(0) == 0 or tgt_feat.size(0) == 0:
        print(f"Skipping class {class_idx}: not enough samples (src: {src_feat.size(0)}, tgt: {tgt_feat.size(0)})")
        continue

    all_feat = torch.cat([src_feat, tgt_feat], dim=0).to(device)
    domain_labels = torch.cat([
        torch.ones(src_feat.size(0)),
        torch.zeros(tgt_feat.size(0))
    ]).to(device)

    domain_preds = D(all_feat).squeeze(1)  # shape (N,); squeeze(1) stays 1-D even when N == 1
    adv_loss = F.binary_cross_entropy_with_logits(domain_preds, domain_labels)

    optimizer_D.zero_grad()
    adv_loss.backward()
    optimizer_D.step()

    print(f"Trained discriminator for class {class_idx} | "
          f"src: {src_feat.size(0)}, tgt: {tgt_feat.size(0)}, "
          f"loss: {adv_loss.item():.4f}")


Trained discriminator for class 0 | src: 30607, tgt: 1106, loss: 0.6472
Trained discriminator for class 1 | src: 30607, tgt: 1711, loss: 0.6960
Skipping class 2: not enough samples (src: 0, tgt: 0)
Skipping class 3: not enough samples (src: 0, tgt: 0)
Skipping class 4: not enough samples (src: 0, tgt: 0)
Skipping class 5: not enough samples (src: 0, tgt: 0)
Skipping class 6: not enough samples (src: 0, tgt: 0)
Skipping class 7: not enough samples (src: 0, tgt: 0)
Skipping class 8: not enough samples (src: 0, tgt: 0)
Skipping class 9: not enough samples (src: 0, tgt: 0)
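
The loop above takes a single optimizer step per discriminator, on one batch in which source samples far outnumber target samples. A possible refinement, shown here only as a sketch on synthetic features, is to train each discriminator for several steps on mini-batches that draw equally from both domains. The function name `train_discriminator`, its hyperparameters, and the Gaussian stand-in features are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def train_discriminator(D, src_feat, tgt_feat, steps=200, batch_size=32, lr=1e-2):
    """Train D for several steps on balanced mini-batches (source=1, target=0)."""
    opt = torch.optim.Adam(D.parameters(), lr=lr)
    losses = []
    for _ in range(steps):
        # Sample the same number of rows from each domain so neither dominates the loss.
        s = src_feat[torch.randint(0, src_feat.size(0), (batch_size,))]
        t = tgt_feat[torch.randint(0, tgt_feat.size(0), (batch_size,))]
        feats = torch.cat([s, t], dim=0)
        labels = torch.cat([torch.ones(batch_size), torch.zeros(batch_size)])
        loss = F.binary_cross_entropy_with_logits(D(feats).squeeze(1), labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses

# Synthetic stand-ins for class-specific features: two shifted Gaussians.
src = torch.randn(500, 16) + 1.0
tgt = torch.randn(60, 16) - 1.0
D = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
losses = train_discriminator(D, src, tgt)
print(losses[0], losses[-1])
```

Note that this, like the loop above, only trains the discriminators. Because the ResNet features are precomputed and fixed, nothing is updated to fool them, so a full adversarial alignment would additionally need a trainable feature mapping.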


**Possible Causes:**

**Classifier bias toward a few classes**

Your classifier might be overfitting to dominant classes in Caltech-10, ignoring others when assigning pseudo-labels.

Some classes may have lower feature separability, making the softmax outputs less confident.

**Target domain shift**

Domain shift can cause category mismatch: even if some objects in Office-31 belong to a "shared" class, the visual domain may differ enough that the classifier does not recognize them confidently.

**Confidence threshold cutting too aggressively**

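If the threshold is set too high (for example confidence_threshold = 0.7), most target samples won't qualify as confidently pseudo-labeled.

Lowering it to 0.5 or even 0.3 might help populate more classes. Note that the pseudo-labeling cell above already uses confidence_threshold = 0.3, and the log shows all 2817 target samples passing it (1106 + 1711), so the threshold alone does not explain the empty classes in this run.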
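
**Source labels cover only two classes**

One further check is suggested by the log itself: classes 0 and 1 each hold 30607 source samples, which together account for all 61214 extracted source features, so `source_labels` contains only two distinct values. That pattern would arise if the `root` passed to `ImageFolder` contains two wrapper folders rather than the category folders themselves, because `ImageFolder` treats each immediate subdirectory of `root` as one class. The sketch below mimics that discovery rule with the standard library on a throwaway directory tree. The folder names and the helper `top_level_classes` are made up for illustration, and the actual layout under `source_path` should be inspected directly.

```python
import os
import tempfile

def top_level_classes(root):
    """List immediate subdirectories of root, which is how ImageFolder derives class names."""
    return sorted(e.name for e in os.scandir(root) if e.is_dir())

with tempfile.TemporaryDirectory() as root:
    # Hypothetical layout: one wrapper folder containing the real category folders.
    for category in ["001.ak47", "002.american-flag", "003.backpack"]:
        os.makedirs(os.path.join(root, "wrapper", category))

    at_root = top_level_classes(root)                              # sees only the wrapper
    one_level_down = top_level_classes(os.path.join(root, "wrapper"))

print(at_root, one_level_down)
```

In the notebook, printing `source_dataset.classes` and `torch.unique(source_labels)` would be the equivalent checks. If only wrapper folders show up, pointing `root` one level deeper and setting `num_source_classes` from `len(source_dataset.classes)` would be the natural next step.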
