# Task 3: Are Prompts a Stable Control Knob for DA/DG with CLIP?

---



### Installs

In [1]:
!pip install git+https://github.com/openai/CLIP.git
!pip install torchvision

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-lb6ybqng
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-lb6ybqng
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=018d8775736b515b93753e18debbb9d55c58c6c544fb80d63914cdda2961fef1
  Stored in directory: /tmp/pip-ephem-wheel-cache-7n6hrr0h/wheels/35/3e/df/3d24cbfb3b6a06f17

In [2]:
!pip install datasets



### Imports

In [3]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import clip
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from datasets import load_dataset
from sklearn.linear_model import LogisticRegression
from tqdm.notebook import tqdm
import torch.nn.functional as F

### Device:

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


### CLIP

In [None]:
model, preprocess = clip.load("ViT-B/32", device=device)

100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 163MiB/s]


## Part 1: CLIP Zero-Shot vs Fine-Tuned on Domains:

### Load PACS

In [None]:
dataset = load_dataset("flwrlabs/pacs", split="train")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9991 [00:00<?, ? examples/s]

In [None]:
photo_dataset = dataset.filter(lambda example: example["domain"] == "photo")
art_dataset = dataset.filter(lambda example: example["domain"] == "art_painting")
cartoon_dataset = dataset.filter(lambda example: example["domain"] == "cartoon")
sketch_dataset = dataset.filter(lambda example: example["domain"] == "sketch")
domains = {
    "photo": photo_dataset,
    "art_painting": art_dataset,
    "cartoon": cartoon_dataset,
    "sketch": sketch_dataset
}
print({key: len(value) for key, value in domains.items()})

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

{'photo': 1670, 'art_painting': 2048, 'cartoon': 2344, 'sketch': 3929}


In [None]:
class_names = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']

Text Features

In [None]:
def get_text_features(prompt_template):
    texts = [prompt_template.format(class_name) for class_name in class_names]
    text_tokens = clip.tokenize(texts).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features

Zero-shot classification

In [None]:
def zero_shot_classification(domain_name, prompt_template):
    domain_data = domains[domain_name]
    text_features = get_text_features(prompt_template)
    correct = 0
    for data in tqdm(domain_data, desc=f"Evaluating {domain_name}"):
        image = preprocess(data["image"]).unsqueeze(0).to(device)
        label = data["label"]
        with torch.no_grad():
            image_features = model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            similarity = (100.0 * image_features @ text_features.T)
            probabilities = similarity.softmax(dim=-1)
            prediction = probabilities.argmax().item()
            if prediction == label:
                correct += 1
    accuracy = correct / len(domain_data)
    print(f"Accuracy on {domain_name} ({prompt_template}): {accuracy*100:.4f}%")
    return accuracy

Prompts

In [None]:
prompts = {
    "photo": "a photo of a {}",
    "art_painting": "an art painting of a {}",
    "cartoon": "a cartoon of a {}",
    "sketch": "a sketch of a {}"
}

In [None]:
results = {}
for domain in domains.keys():
    accuracy = zero_shot_classification(domain, prompts[domain])
    results[domain] = accuracy

Evaluating photo:   0%|          | 0/1670 [00:00<?, ?it/s]

Accuracy on photo (a photo of a {}): 99.7006%


Evaluating art_painting:   0%|          | 0/2048 [00:00<?, ?it/s]

Accuracy on art_painting (an art painting of a {}): 92.7734%


Evaluating cartoon:   0%|          | 0/2344 [00:00<?, ?it/s]

Accuracy on cartoon (a cartoon of a {}): 97.6536%


Evaluating sketch:   0%|          | 0/3929 [00:00<?, ?it/s]

Accuracy on sketch (a sketch of a {}): 85.1871%


#### Fine-tuning CLIP

In [None]:
source_domains = ["photo", "art_painting", "cartoon"]
target_domain = "sketch"

In [None]:
def extract_features(domain_data):
    features = []
    labels = []
    for data in tqdm(domain_data, desc="Extracting features"):
        image = preprocess(data["image"]).unsqueeze(0).to(device)
        label = data["label"]
        with torch.no_grad():
            image_features = model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            features.append(image_features.cpu().numpy())
            labels.append(label)
    return np.concatenate(features), np.array(labels)

In [None]:
source_domains_features = []
source_domains_labels = []
for domain in source_domains:
    domain_data = domains[domain]
    features, labels = extract_features(domain_data)
    source_domains_features.append(features)
    source_domains_labels.append(labels)
source_domains_features = np.concatenate(source_domains_features)
source_domains_labels = np.concatenate(source_domains_labels)
target_domain_data = domains[target_domain]
target_domain_features, target_domain_labels = extract_features(target_domain_data)
print("Source domain features shape:", source_domains_features.shape)
print("Source domain labels shape:", source_domains_labels.shape)
print("Target domain features shape:", target_domain_features.shape)
print("Target domain labels shape:", target_domain_labels.shape)

Extracting features:   0%|          | 0/1670 [00:00<?, ?it/s]

Extracting features:   0%|          | 0/2048 [00:00<?, ?it/s]

Extracting features:   0%|          | 0/2344 [00:00<?, ?it/s]

Extracting features:   0%|          | 0/3929 [00:00<?, ?it/s]

Source domain features shape: (6062, 512)
Source domain labels shape: (6062,)
Target domain features shape: (3929, 512)
Target domain labels shape: (3929,)


In [None]:
linear_classifier = LogisticRegression(max_iter=2000, multi_class='multinomial')
linear_classifier.fit(source_domains_features, source_domains_labels)



In [None]:
from sklearn.metrics import accuracy_score
source_domains_predictions = linear_classifier.predict(source_domains_features)
accuracy_source = accuracy_score(source_domains_labels, source_domains_predictions)
print(f"Accuracy on source domains: {accuracy_source*100:.4f}%")
target_domain_predictions = linear_classifier.predict(target_domain_features)
accuracy_target = accuracy_score(target_domain_labels, target_domain_predictions)
print(f"Accuracy on target domain: {accuracy_target*100:.4f}%")

Accuracy on source domains: 98.9113%
Accuracy on target domain: 85.0598%


## Part 2: Prompt-Learning with CLIP:

I am using the `photo` domain as the source domain and the `sketch` domain as the target domain. The reason for this choice is that CLIP performed relatively worse on the `sketch` domain and the best on the `photo` domain, so by learning prompts, I hope that the performance on the sketch domain is improved.

In [5]:
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
model = model.float()
dtype = torch.float32

100%|████████████████████████████████████████| 338M/338M [00:01<00:00, 303MiB/s]


In [6]:
dataset = load_dataset("flwrlabs/pacs", split="train")
photo_dataset = dataset.filter(lambda example: example["domain"] == "photo")
sketch_dataset = dataset.filter(lambda example: example["domain"] == "sketch")

class PACSDataset(Dataset):
    def __init__(self, hf_dataset, transform):
        self.data = hf_dataset
        self.transform = transform
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        image = self.data[idx]["image"]
        label = self.data[idx]["label"]
        image = self.transform(image)
        return {"image": image, "label": torch.tensor(label, dtype=torch.long)}

source_dataset = PACSDataset(photo_dataset, preprocess)
target_dataset = PACSDataset(sketch_dataset, preprocess)
source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

In [7]:
class LearnablePromptCLIP(nn.Module):
    """
    Learnable prompt-based CLIP for domain adaptation.
    Implements DAPL-style prompt learning where prompts are optimized
    while keeping CLIP frozen.
    """
    def __init__(self, clip_model, n_ctx=16, n_classes=7, class_names=None):
        super().__init__()
        self.clip_model = clip_model
        self.n_ctx = n_ctx  # Number of context tokens (learnable prompts)
        self.n_classes = n_classes
        self.class_names = class_names or ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']

        # Freeze CLIP parameters
        for param in self.clip_model.parameters():
            param.requires_grad = False

        # Get the text encoder's token embedding layer
        self.token_embedding = clip_model.token_embedding
        self.positional_embedding = clip_model.positional_embedding
        self.transformer = clip_model.transformer
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

        # Initialize learnable context vectors (prompts)
        # Shape: [n_ctx, embed_dim]
        ctx_dim = clip_model.ln_final.weight.shape[0]
        ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
        nn.init.normal_(ctx_vectors, std=0.02)
        self.ctx = nn.Parameter(ctx_vectors)

        # Create class token embeddings
        classnames = [name.replace("_", " ") for name in self.class_names]
        name_lens = [len(clip.tokenize(name)[0]) - 2 for name in classnames]  # -2 for [SOS] and [EOS]
        prompts = [" ".join(["X"] * n_ctx) + " " + name + "." for name in classnames]

        # Tokenize prompts
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        with torch.no_grad():
            embedding = self.token_embedding(tokenized_prompts).type(dtype)

        # Store these for later use
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS

        self.n_cls = n_classes
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens

    def construct_prompts(self):
        """
        Construct text embeddings with learnable context
        Returns: [n_classes, n_ctx, embed_dim]
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        prompts = torch.cat([
            prefix,  # [n_cls, 1, dim]
            ctx,     # [n_cls, n_ctx, dim]
            suffix,  # [n_cls, *, dim]
        ], dim=1)

        return prompts

    def encode_text_with_prompts(self):
        """
        Encode text with learned prompts through CLIP's text encoder
        """
        prompts = self.construct_prompts()
        x = prompts + self.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(dtype)

        # Take features from the eot embedding (end of text)
        x = x[torch.arange(x.shape[0]), self.tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image):
        """
        Forward pass: encode image and compute similarity with text prompts
        """
        # Encode image
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Encode text with learned prompts
        text_features = self.encode_text_with_prompts()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Compute similarity
        logit_scale = self.clip_model.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits

In [8]:
prompt_learner = LearnablePromptCLIP(
    clip_model=model,
    n_ctx=16,
    n_classes=7,
    class_names=['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']
).to(device)

print(f"Learnable parameters: {sum(p.numel() for p in prompt_learner.parameters() if p.requires_grad)}")
print(f"Frozen parameters: {sum(p.numel() for p in prompt_learner.parameters() if not p.requires_grad)}")

Learnable parameters: 8192
Frozen parameters: 151277313


In [9]:
def entropy_loss(logits):
    """
    Entropy minimization loss for target domain
    Encourages the model to make confident predictions
    """
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(probs * log_probs).sum(dim=-1)
    return entropy.mean()

def pseudo_label_loss(logits, threshold=0.95):
    """
    Pseudo-labeling loss for target domain
    Only use predictions with high confidence
    """
    probs = F.softmax(logits, dim=-1)
    max_probs, pseudo_labels = probs.max(dim=-1)

    # Filter by confidence threshold
    mask = max_probs >= threshold

    if mask.sum() == 0:
        return torch.tensor(0.0).to(logits.device)

    # Cross-entropy loss with pseudo labels
    loss = F.cross_entropy(logits[mask], pseudo_labels[mask])
    return loss

def train_dapl(prompt_learner, source_loader, target_loader,
               num_epochs=10, lr=0.002, lambda_target=0.3,
               use_entropy=True, use_pseudo_label=True):
    """
    Train prompt learner with DAPL-style domain adaptation

    Args:
        prompt_learner: The learnable prompt model
        source_loader: DataLoader for source domain (labeled)
        target_loader: DataLoader for target domain (unlabeled)
        num_epochs: Number of training epochs
        lr: Learning rate
        lambda_target: Weight for target domain loss
        use_entropy: Whether to use entropy minimization
        use_pseudo_label: Whether to use pseudo-labeling
    """
    optimizer = torch.optim.AdamW([prompt_learner.ctx], lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    prompt_learner.train()

    for epoch in range(num_epochs):
        total_loss = 0
        source_loss_total = 0
        target_loss_total = 0
        correct_source = 0
        total_source = 0

        # Create iterator for target domain
        target_iter = iter(target_loader)

        progress_bar = tqdm(source_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, source_batch in enumerate(progress_bar):
            # Get source batch
            source_images = source_batch['image'].to(device)
            source_labels = source_batch['label'].to(device)

            # Get target batch (cycle if needed)
            try:
                target_batch = next(target_iter)
            except StopIteration:
                target_iter = iter(target_loader)
                target_batch = next(target_iter)

            target_images = target_batch['image'].to(device)

            # Forward pass on source domain
            source_logits = prompt_learner(source_images)
            source_loss = F.cross_entropy(source_logits, source_labels)

            # Calculate source accuracy
            _, predicted = source_logits.max(1)
            correct_source += predicted.eq(source_labels).sum().item()
            total_source += source_labels.size(0)

            # Forward pass on target domain
            target_logits = prompt_learner(target_images)

            # Target domain losses
            target_loss = torch.tensor(0.0).to(device)

            if use_entropy:
                ent_loss = entropy_loss(target_logits)
                target_loss += ent_loss

            if use_pseudo_label:
                pl_loss = pseudo_label_loss(target_logits, threshold=0.95)
                target_loss += pl_loss

            # Total loss
            loss = source_loss + lambda_target * target_loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track losses
            total_loss += loss.item()
            source_loss_total += source_loss.item()
            target_loss_total += target_loss.item()

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'src_loss': f'{source_loss.item():.4f}',
                'tgt_loss': f'{target_loss.item():.4f}',
                'src_acc': f'{100.*correct_source/total_source:.2f}%'
            })

        scheduler.step()

        # Epoch summary
        avg_loss = total_loss / len(source_loader)
        avg_source_loss = source_loss_total / len(source_loader)
        avg_target_loss = target_loss_total / len(source_loader)
        source_acc = 100. * correct_source / total_source

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average Loss: {avg_loss:.4f}")
        print(f"  Source Loss: {avg_source_loss:.4f}")
        print(f"  Target Loss: {avg_target_loss:.4f}")
        print(f"  Source Accuracy: {source_acc:.2f}%")
        print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.6f}\n")

    return prompt_learner

In [None]:
# Train the prompt learner
trained_model = train_dapl(
    prompt_learner=prompt_learner,
    source_loader=source_dataloader,
    target_loader=target_dataloader,
    num_epochs=10,
    lr=0.002,
    lambda_target=0.3,
    use_entropy=True,
    use_pseudo_label=True
)

Epoch 1/10:   0%|          | 0/53 [00:00<?, ?it/s]


Epoch 1 Summary:
  Average Loss: 0.1037
  Source Loss: 0.0369
  Target Loss: 0.2228
  Source Accuracy: 98.92%
  Learning Rate: 0.001951



Epoch 2/10:   0%|          | 0/53 [00:00<?, ?it/s]


Epoch 2 Summary:
  Average Loss: 0.0479
  Source Loss: 0.0083
  Target Loss: 0.1320
  Source Accuracy: 99.64%
  Learning Rate: 0.001809



Epoch 3/10:   0%|          | 0/53 [00:00<?, ?it/s]


Epoch 3 Summary:
  Average Loss: 0.0428
  Source Loss: 0.0053
  Target Loss: 0.1250
  Source Accuracy: 99.88%
  Learning Rate: 0.001588



Epoch 4/10:   0%|          | 0/53 [00:00<?, ?it/s]


Epoch 4 Summary:
  Average Loss: 0.0352
  Source Loss: 0.0049
  Target Loss: 0.1011
  Source Accuracy: 99.82%
  Learning Rate: 0.001309



Epoch 5/10:   0%|          | 0/53 [00:00<?, ?it/s]


Epoch 5 Summary:
  Average Loss: 0.0351
  Source Loss: 0.0029
  Target Loss: 0.1072
  Source Accuracy: 99.94%
  Learning Rate: 0.001000



Epoch 6/10:   0%|          | 0/53 [00:00<?, ?it/s]


Epoch 6 Summary:
  Average Loss: 0.0355
  Source Loss: 0.0032
  Target Loss: 0.1079
  Source Accuracy: 99.82%
  Learning Rate: 0.000691



Epoch 7/10:   0%|          | 0/53 [00:00<?, ?it/s]


Epoch 7 Summary:
  Average Loss: 0.0304
  Source Loss: 0.0014
  Target Loss: 0.0966
  Source Accuracy: 100.00%
  Learning Rate: 0.000412



Epoch 8/10:   0%|          | 0/53 [00:00<?, ?it/s]


Epoch 8 Summary:
  Average Loss: 0.0294
  Source Loss: 0.0012
  Target Loss: 0.0940
  Source Accuracy: 100.00%
  Learning Rate: 0.000191



Epoch 9/10:   0%|          | 0/53 [00:00<?, ?it/s]


Epoch 9 Summary:
  Average Loss: 0.0268
  Source Loss: 0.0012
  Target Loss: 0.0855
  Source Accuracy: 100.00%
  Learning Rate: 0.000049



Epoch 10/10:   0%|          | 0/53 [00:00<?, ?it/s]


Epoch 10 Summary:
  Average Loss: 0.0270
  Source Loss: 0.0010
  Target Loss: 0.0865
  Source Accuracy: 100.00%
  Learning Rate: 0.000000



In [None]:
def evaluate_model(model, dataloader, domain_name):
    """
    Evaluate the model on a given domain
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating on {domain_name}"):
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            logits = model(images)
            _, predicted = logits.max(1)

            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    accuracy = 100. * correct / total
    print(f"Accuracy on {domain_name}: {accuracy:.2f}%")
    return accuracy

# Evaluate on source domain (photo)
source_accuracy = evaluate_model(trained_model, source_dataloader, "source (photo)")

# Evaluate on target domain (sketch)
target_accuracy = evaluate_model(trained_model, target_dataloader, "target (sketch)")

Evaluating on source (photo):   0%|          | 0/53 [00:00<?, ?it/s]

Accuracy on source (photo): 100.00%


Evaluating on target (sketch):   0%|          | 0/123 [00:00<?, ?it/s]

Accuracy on target (sketch): 87.25%


### Part 4: Open-Set and Generalization Analysis:

Seen and unseen classes

In [32]:
# Define seen and unseen classes
classnames = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']
seen_classes = ['dog', 'elephant', 'giraffe', 'house', 'person']
unseen_classes = ['guitar', 'horse']

seen_class_indices = [classnames.index(c) for c in seen_classes]
unseen_class_indices = [classnames.index(c) for c in unseen_classes]

print(f"Training on {len(seen_classes)} classes: {seen_classes}")
print(f"Testing on {len(unseen_classes)} unseen classes: {unseen_classes}")
print(f"Seen class indices: {seen_class_indices}")
print(f"Unseen class indices: {unseen_class_indices}")

Training on 5 classes: ['dog', 'elephant', 'giraffe', 'house', 'person']
Testing on 2 unseen classes: ['guitar', 'horse']
Seen class indices: [0, 1, 2, 5, 6]
Unseen class indices: [3, 4]


In [33]:

photo_dataset_seen = photo_dataset.filter(lambda x: x['label'] in seen_class_indices)
sketch_dataset_seen = sketch_dataset.filter(lambda x: x['label'] in seen_class_indices)


photo_dataset_unseen = photo_dataset.filter(lambda x: x['label'] in unseen_class_indices)
sketch_dataset_unseen = sketch_dataset.filter(lambda x: x['label'] in unseen_class_indices)

print(f"Source (photo) - Seen classes: {len(photo_dataset_seen)} samples")
print(f"Source (photo) - Unseen classes: {len(photo_dataset_unseen)} samples")
print(f"Target (sketch) - Seen classes: {len(sketch_dataset_seen)} samples")
print(f"Target (sketch) - Unseen classes: {len(sketch_dataset_unseen)} samples")
source_dataset_openset = PACSDataset(photo_dataset_seen, preprocess)
target_dataset_openset = PACSDataset(sketch_dataset_seen, preprocess)
source_dataloader_openset = DataLoader(source_dataset_openset, batch_size=32, shuffle=True)
target_dataloader_openset = DataLoader(target_dataset_openset, batch_size=32, shuffle=True)
source_test_all = PACSDataset(photo_dataset, preprocess)
target_test_all = PACSDataset(sketch_dataset, preprocess)
source_test_loader_all = DataLoader(source_test_all, batch_size=32, shuffle=False)
target_test_loader_all = DataLoader(target_test_all, batch_size=32, shuffle=False)

Filter:   0%|          | 0/1670 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3929 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1670 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3929 [00:00<?, ? examples/s]

Source (photo) - Seen classes: 1285 samples
Source (photo) - Unseen classes: 385 samples
Target (sketch) - Seen classes: 2505 samples
Target (sketch) - Unseen classes: 1424 samples


In [34]:
model_openset, _ = clip.load("ViT-B/32", device=device, jit=False)
model_openset = model_openset.float()

# Create prompt learner for all 7 classes (but train on only 5)
prompt_learner_openset = LearnablePromptCLIP(
    clip_model=model_openset,
    n_ctx=16,
    n_classes=7,  # Still maintain 7 classes for comparison
    class_names=classnames
).to(device)

print(f"Learnable parameters: {sum(p.numel() for p in prompt_learner_openset.parameters() if p.requires_grad)}")

# Train on seen classes only
trained_model_openset = train_dapl(
    prompt_learner=prompt_learner_openset,
    source_loader=source_dataloader_openset,
    target_loader=target_dataloader_openset,
    num_epochs=10,
    lr=0.002,
    lambda_target=0.3,
    use_entropy=True,
    use_pseudo_label=True
)

Learnable parameters: 8192


Epoch 1/10:   0%|          | 0/41 [00:00<?, ?it/s]


Epoch 1 Summary:
  Average Loss: 0.0973
  Source Loss: 0.0251
  Target Loss: 0.2406
  Source Accuracy: 99.38%
  Learning Rate: 0.001951



Epoch 2/10:   0%|          | 0/41 [00:00<?, ?it/s]


Epoch 2 Summary:
  Average Loss: 0.0371
  Source Loss: 0.0010
  Target Loss: 0.1200
  Source Accuracy: 100.00%
  Learning Rate: 0.001809



Epoch 3/10:   0%|          | 0/41 [00:00<?, ?it/s]


Epoch 3 Summary:
  Average Loss: 0.0284
  Source Loss: 0.0008
  Target Loss: 0.0923
  Source Accuracy: 100.00%
  Learning Rate: 0.001588



Epoch 4/10:   0%|          | 0/41 [00:00<?, ?it/s]


Epoch 4 Summary:
  Average Loss: 0.0244
  Source Loss: 0.0004
  Target Loss: 0.0800
  Source Accuracy: 100.00%
  Learning Rate: 0.001309



Epoch 5/10:   0%|          | 0/41 [00:00<?, ?it/s]


Epoch 5 Summary:
  Average Loss: 0.0228
  Source Loss: 0.0004
  Target Loss: 0.0745
  Source Accuracy: 100.00%
  Learning Rate: 0.001000



Epoch 6/10:   0%|          | 0/41 [00:00<?, ?it/s]


Epoch 6 Summary:
  Average Loss: 0.0215
  Source Loss: 0.0005
  Target Loss: 0.0700
  Source Accuracy: 100.00%
  Learning Rate: 0.000691



Epoch 7/10:   0%|          | 0/41 [00:00<?, ?it/s]


Epoch 7 Summary:
  Average Loss: 0.0187
  Source Loss: 0.0004
  Target Loss: 0.0610
  Source Accuracy: 100.00%
  Learning Rate: 0.000412



Epoch 8/10:   0%|          | 0/41 [00:00<?, ?it/s]


Epoch 8 Summary:
  Average Loss: 0.0191
  Source Loss: 0.0001
  Target Loss: 0.0633
  Source Accuracy: 100.00%
  Learning Rate: 0.000191



Epoch 9/10:   0%|          | 0/41 [00:00<?, ?it/s]


Epoch 9 Summary:
  Average Loss: 0.0179
  Source Loss: 0.0002
  Target Loss: 0.0589
  Source Accuracy: 100.00%
  Learning Rate: 0.000049



Epoch 10/10:   0%|          | 0/41 [00:00<?, ?it/s]


Epoch 10 Summary:
  Average Loss: 0.0171
  Source Loss: 0.0002
  Target Loss: 0.0561
  Source Accuracy: 100.00%
  Learning Rate: 0.000000



In [35]:
def evaluate_open_set(model, dataloader, seen_indices, unseen_indices, model_name="Model"):
    """
    Evaluate model on both seen and unseen classes with detailed metrics

    Returns:
        dict with metrics for seen and unseen classes
    """
    model.eval()

    all_logits = []
    all_labels = []
    all_probs = []
    all_max_probs = []
    all_predictions = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {model_name}"):
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            logits = model(images)
            probs = F.softmax(logits, dim=-1)
            max_probs, predictions = probs.max(dim=-1)

            all_logits.append(logits.cpu())
            all_labels.append(labels.cpu())
            all_probs.append(probs.cpu())
            all_max_probs.append(max_probs.cpu())
            all_predictions.append(predictions.cpu())

    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    all_probs = torch.cat(all_probs, dim=0)
    all_max_probs = torch.cat(all_max_probs, dim=0)
    all_predictions = torch.cat(all_predictions, dim=0)

    # Separate seen and unseen samples
    seen_mask = torch.tensor([label.item() in seen_indices for label in all_labels])
    unseen_mask = torch.tensor([label.item() in unseen_indices for label in all_labels])

    results = {}

    # Metrics for SEEN classes
    if seen_mask.sum() > 0:
        seen_labels = all_labels[seen_mask]
        seen_predictions = all_predictions[seen_mask]
        seen_probs = all_probs[seen_mask]
        seen_max_probs = all_max_probs[seen_mask]

        seen_correct = (seen_predictions == seen_labels).sum().item()
        seen_accuracy = 100.0 * seen_correct / seen_mask.sum().item()
        seen_avg_confidence = seen_max_probs.mean().item()

        # Entropy for seen samples
        seen_entropy = -(seen_probs * torch.log(seen_probs + 1e-10)).sum(dim=-1).mean().item()

        results['seen'] = {
            'accuracy': seen_accuracy,
            'avg_confidence': seen_avg_confidence,
            'avg_entropy': seen_entropy,
            'count': seen_mask.sum().item()
        }

    # Metrics for UNSEEN classes
    if unseen_mask.sum() > 0:
        unseen_labels = all_labels[unseen_mask]
        unseen_predictions = all_predictions[unseen_mask]
        unseen_probs = all_probs[unseen_mask]
        unseen_max_probs = all_max_probs[unseen_mask]

        # For unseen classes: check if model incorrectly assigns to seen classes with high confidence
        # This measures "false positives" - confidently predicting a seen class for unseen sample

        unseen_avg_confidence = unseen_max_probs.mean().item()
        unseen_entropy = -(unseen_probs * torch.log(unseen_probs + 1e-10)).sum(dim=-1).mean().item()

        # Count how many unseen samples are predicted as seen classes with high confidence
        high_conf_threshold = 0.7
        false_positives = (unseen_max_probs > high_conf_threshold).sum().item()
        fpr = 100.0 * false_positives / unseen_mask.sum().item()

        results['unseen'] = {
            'avg_confidence': unseen_avg_confidence,
            'avg_entropy': unseen_entropy,
            'false_positive_rate': fpr,
            'count': unseen_mask.sum().item()
        }

    return results

def compare_zero_shot_vs_tuned_open_set(seen_indices, unseen_indices):
    """
    Compare zero-shot CLIP vs tuned prompts on open-set scenario
    """
    print("\n" + "="*80)
    print("OPEN-SET EVALUATION: ZERO-SHOT CLIP vs TUNED PROMPTS")
    print("="*80)

    # Evaluate zero-shot CLIP (using original model with hand-crafted prompts)
    print("\n--- Zero-Shot CLIP (Hand-crafted prompts) ---")

    # Create a simple wrapper for zero-shot evaluation
    class ZeroShotWrapper(nn.Module):
        def __init__(self, clip_model, class_names):
            super().__init__()
            self.clip_model = clip_model
            self.class_names = class_names
            # Create text features with hand-crafted prompts
            texts = [f"a photo of a {name}" for name in class_names]
            text_tokens = clip.tokenize(texts).to(device)
            with torch.no_grad():
                text_features = clip_model.encode_text(text_tokens)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            self.register_buffer('text_features', text_features)

        def forward(self, image):
            image_features = self.clip_model.encode_image(image)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            logit_scale = self.clip_model.logit_scale.exp()
            logits = logit_scale * image_features @ self.text_features.t()
            return logits

    zero_shot_model = ZeroShotWrapper(model_openset, all_classes).to(device)
    zero_shot_results = evaluate_open_set(
        zero_shot_model,
        target_test_loader_all,
        seen_indices,
        unseen_indices,
        "Zero-Shot CLIP"
    )

    # Evaluate tuned prompts
    print("\n--- Tuned Prompts (DAPL) ---")
    tuned_results = evaluate_open_set(
        trained_model_openset,
        target_test_loader_all,
        seen_indices,
        unseen_indices,
        "Tuned Prompts"
    )

    # Print comparison
    print("\n" + "="*80)
    print("RESULTS COMPARISON (Target Domain: Sketch)")
    print("="*80)

    print("\n SEEN CLASSES (Trained on these):")
    print(f"  Zero-Shot CLIP:")
    print(f"    - Accuracy: {zero_shot_results['seen']['accuracy']:.2f}%")
    print(f"    - Avg Confidence: {zero_shot_results['seen']['avg_confidence']:.4f}")
    print(f"    - Avg Entropy: {zero_shot_results['seen']['avg_entropy']:.4f}")

    print(f"\n  Tuned Prompts:")
    print(f"    - Accuracy: {tuned_results['seen']['accuracy']:.2f}%")
    print(f"    - Avg Confidence: {tuned_results['seen']['avg_confidence']:.4f}")
    print(f"    - Avg Entropy: {tuned_results['seen']['avg_entropy']:.4f}")

    print("\nUNSEEN CLASSES (Never seen during training):")
    print(f"  Zero-Shot CLIP:")
    print(f"    - Avg Confidence: {zero_shot_results['unseen']['avg_confidence']:.4f}")
    print(f"    - Avg Entropy: {zero_shot_results['unseen']['avg_entropy']:.4f}")
    print(f"    - False Positive Rate (>0.7 conf): {zero_shot_results['unseen']['false_positive_rate']:.2f}%")

    print(f"\n  Tuned Prompts:")
    print(f"    - Avg Confidence: {tuned_results['unseen']['avg_confidence']:.4f}")
    print(f"    - Avg Entropy: {tuned_results['unseen']['avg_entropy']:.4f}")
    print(f"    - False Positive Rate (>0.7 conf): {tuned_results['unseen']['false_positive_rate']:.2f}%")

    print("\n" + "="*80)
    print("KEY INSIGHTS:")
    print("="*80)

    # Calculate degradation
    conf_increase = tuned_results['unseen']['avg_confidence'] - zero_shot_results['unseen']['avg_confidence']
    entropy_change = tuned_results['unseen']['avg_entropy'] - zero_shot_results['unseen']['avg_entropy']
    fpr_increase = tuned_results['unseen']['false_positive_rate'] - zero_shot_results['unseen']['false_positive_rate']

    print(f"1. Seen class accuracy improvement: {tuned_results['seen']['accuracy'] - zero_shot_results['seen']['accuracy']:.2f}%")
    print(f"2. Unseen class confidence change: {conf_increase:+.4f}")
    print(f"3. Unseen class entropy change: {entropy_change:+.4f} (lower is more confident)")
    print(f"4. False positive rate increase: {fpr_increase:+.2f}%")

    if fpr_increase > 10:
        print("\n WARNING: Tuned prompts show significantly higher false positive rate!")
        print("   The model is over-confident on out-of-distribution samples.")

    return zero_shot_results, tuned_results

In [36]:
zero_shot_results, tuned_results = compare_zero_shot_vs_tuned_open_set(
    seen_class_indices,
    unseen_class_indices
)


OPEN-SET EVALUATION: ZERO-SHOT CLIP vs TUNED PROMPTS

--- Zero-Shot CLIP (Hand-crafted prompts) ---


Evaluating Zero-Shot CLIP:   0%|          | 0/123 [00:00<?, ?it/s]


--- Tuned Prompts (DAPL) ---


Evaluating Tuned Prompts:   0%|          | 0/123 [00:00<?, ?it/s]


RESULTS COMPARISON (Target Domain: Sketch)

 SEEN CLASSES (Trained on these):
  Zero-Shot CLIP:
    - Accuracy: 76.93%
    - Avg Confidence: 0.8388
    - Avg Entropy: 0.4523

  Tuned Prompts:
    - Accuracy: 92.50%
    - Avg Confidence: 0.9811
    - Avg Entropy: 0.0506

UNSEEN CLASSES (Never seen during training):
  Zero-Shot CLIP:
    - Avg Confidence: 0.9594
    - Avg Entropy: 0.1490
    - False Positive Rate (>0.7 conf): 95.79%

  Tuned Prompts:
    - Avg Confidence: 0.9197
    - Avg Entropy: 0.2195
    - False Positive Rate (>0.7 conf): 88.76%

KEY INSIGHTS:
1. Seen class accuracy improvement: 15.57%
2. Unseen class confidence change: -0.0397
3. Unseen class entropy change: +0.0705 (lower is more confident)
4. False positive rate increase: -7.02%
