# Robust Few-Shot Vision-Language Model Adaptation

This notebook demonstrates how to perform **partial finetuning (PFT)** of CLIP for few-shot learning and how to evaluate the finetuned model on both in-distribution (ID) and out-of-distribution (OOD) datasets. The results are comparable to those in Table 3 of the paper.
1. model
    - CLIP ViT-B/16
2. few-shot setting
    - 16 shots
3. PFT setting
    - top-4 blocks
4. dataset
    - ImageNet-1k (as ID dataset)
    - ImageNet-V2 (as OOD dataset)
    - ImageNet-Sketch (ImageNet-S, as OOD dataset)
    - ImageNet-A (as OOD dataset)
    - ImageNet-R (as OOD dataset)
5. acknowledgements
    - This code is built on [LCA-on-the-line(ICML'24)](https://github.com/ElvishElvis/LCA-on-the-line) and [SWAT(CVPR'25)](https://github.com/tian1327/SWAT).


In [None]:
import torch
import numpy as np
import random

# Reproducibility.
#   data_seed     -> Python's `random` module (and passed to the few-shot dataset builder below)
#   training_seed -> NumPy, PyTorch CPU, and every CUDA device
data_seed = 1
training_seed = 1

random.seed(data_seed)
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(training_seed)
del _seed_fn

# Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Partial finetuning

1. Load the CLIP model and prepare the datasets and dataloaders

In [2]:
import clip.clip as clip
import datasets

# Load the CLIP ViT-B/16 model together with its train / test image transforms.
clip_model, train_preprocess, test_preprocess = clip.load('ViT-B/16', jit=False)


# Prepare datasets and dataloaders
root = 'PUT YOUR PATH HERE'  # Path to the ImageNet dataset
num_shots = 16  # Number of shots for few-shot learning, options: 4, 8, 16
# `data_seed` identifies which few-shot sample list is loaded.
imagenet_train, text_name = datasets.build_imagenet_few_shot_dataset('imagenet', 'train', data_seed, train_preprocess, root=root, num_shots=num_shots)
# Validation split: used only for checkpoint selection during training.
imagenet_val, _ = datasets.build_imagenet_dataset('imagenet', 'val', test_preprocess, root=root)
# ID testset: used only for the final evaluation.
imagenet_test, _ = datasets.build_imagenet_dataset('imagenet', 'test', test_preprocess, root=root)
# OOD testsets
imagenet_a_test, _ = datasets.build_imagenet_dataset('imagenet_a', 'test', test_preprocess, root=root)
imagenet_r_test, _ = datasets.build_imagenet_dataset('imagenet_r', 'test', test_preprocess, root=root)
imagenet_sketch_test, _ = datasets.build_imagenet_dataset('imagenet_sketch', 'test', test_preprocess, root=root)
imagenetv2_test, _ = datasets.build_imagenet_dataset('imagenetv2', 'test', test_preprocess, root=root)

batch_size = 64
train_dataloader = torch.utils.data.DataLoader(
    dataset=imagenet_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True)
# Model selection must not look at the ID test split, otherwise the reported
# ID test accuracy is tuned on the test data. Use the dedicated validation split.
# NOTE(review): if the 'val' and 'test' splits resolve to the same image list in
# `datasets` (both report 50,000 images), results are unchanged — confirm there.
val_dataloader = torch.utils.data.DataLoader(
    dataset=imagenet_val,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=False)

len(imagenet_train), len(imagenet_val), len(imagenet_test), len(imagenet_a_test), len(imagenet_r_test), len(imagenet_sketch_test), len(imagenetv2_test)

Loading few-shot data from data_resource/imagenet/fewshot16_seed1.txt.


(16000, 50000, 50000, 7500, 30000, 50889, 10000)

2. Freeze the CLIP model except the top-X blocks of the visual encoder (plus the visual encoder's final LayerNorm `ln_post` and projection `proj`).

In [3]:
def frozen(model, ft_topk_blks):
    """Freeze ``model`` except for the top blocks of its visual transformer.

    All parameters of ``model`` are first frozen. Then:

    * ``ft_topk_blks == -1``: every parameter of ``model.visual`` is unfrozen.
    * ``ft_topk_blks >= 0``: the last ``ft_topk_blks`` entries of
      ``model.visual.transformer.resblocks`` are unfrozen, together with the
      parameters of ``model.visual.ln_post`` and the ``model.visual.proj``
      parameter. ``0`` therefore trains only ``ln_post`` and ``proj``.

    Args:
        model: CLIP-style model exposing ``visual.transformer.resblocks``,
            ``visual.ln_post`` and the ``visual.proj`` parameter.
        ft_topk_blks: number of top blocks to finetune, or -1 to finetune all
            of ``model.visual``.

    Raises:
        ValueError: if ``ft_topk_blks`` is smaller than -1. The check runs
            before any parameter is modified.
    """
    if ft_topk_blks < -1:
        raise ValueError(
            f'ft_topk_blks must be -1 (all blocks) or >= 0, got {ft_topk_blks}.')

    for param in model.parameters():
        param.requires_grad = False

    if ft_topk_blks == -1:
        print('Finetune all blocks of the visual transformer.')
        for param in model.visual.parameters():
            param.requires_grad = True
        return

    print(f'Finetune top-{ft_topk_blks} blocks of the visual transformer.')
    # resblocks[-0:] is resblocks[0:] (every block), so 0 must skip the slice entirely.
    if ft_topk_blks > 0:
        for blk in model.visual.transformer.resblocks[-ft_topk_blks:]:
            for param in blk.parameters():
                param.requires_grad = True

    # The post-transformer LayerNorm and the output projection are always trained in PFT.
    for param in model.visual.ln_post.parameters():
        param.requires_grad = True
    model.visual.proj.requires_grad = True

# Default PFT setting used in our experiments: finetune the top-4 visual blocks.
ft_topk_blks = 4
frozen(clip_model, ft_topk_blks)

# Sanity check: print the name of every parameter that will receive gradients.
for name, param in clip_model.named_parameters():
    if not param.requires_grad:
        continue
    print(name)

Finetune top-4 blocks of the visual transformer.
visual.proj
visual.transformer.resblocks.8.attn.in_proj_weight
visual.transformer.resblocks.8.attn.in_proj_bias
visual.transformer.resblocks.8.attn.out_proj.weight
visual.transformer.resblocks.8.attn.out_proj.bias
visual.transformer.resblocks.8.ln_1.weight
visual.transformer.resblocks.8.ln_1.bias
visual.transformer.resblocks.8.mlp.c_fc.weight
visual.transformer.resblocks.8.mlp.c_fc.bias
visual.transformer.resblocks.8.mlp.c_proj.weight
visual.transformer.resblocks.8.mlp.c_proj.bias
visual.transformer.resblocks.8.ln_2.weight
visual.transformer.resblocks.8.ln_2.bias
visual.transformer.resblocks.9.attn.in_proj_weight
visual.transformer.resblocks.9.attn.in_proj_bias
visual.transformer.resblocks.9.attn.out_proj.weight
visual.transformer.resblocks.9.attn.out_proj.bias
visual.transformer.resblocks.9.ln_1.weight
visual.transformer.resblocks.9.ln_1.bias
visual.transformer.resblocks.9.mlp.c_fc.weight
visual.transformer.resblocks.9.mlp.c_fc.bias
vis

3. Initialize the classifier with the text features averaged over the 80 OpenAI ImageNet prompt templates.

In [4]:
import torch.nn as nn
class MyLinear(nn.Module):
    """Linear classification head over image features.

    Args:
        input_dim: dimensionality of the input features.
        num_classes: number of output logits.
        bias: whether the underlying ``nn.Linear`` has a bias term.
    """

    def __init__(self, input_dim=512, num_classes=1000, bias=False):
        super().__init__()

        self.linear = nn.Linear(input_dim, num_classes, bias=bias)
        self.num_classes = num_classes

    def forward(self, x):
        """Return the logits ``self.linear(x)``, last dimension ``num_classes``."""
        return self.linear(x)

    def _init_weights(self, weights):
        """Replace the weight matrix with a copy of ``weights``.

        ``weights`` must have shape ``(num_classes, input_dim)``, i.e.
        ``nn.Linear``'s ``(out_features, in_features)`` layout. The copy keeps
        the dtype and device of ``weights``. A new ``nn.Parameter`` object is
        created, so call this before building an optimizer over the module.

        Raises:
            ValueError: if ``weights`` does not match the layer's shape.
        """
        expected = (self.linear.out_features, self.linear.in_features)
        if tuple(weights.shape) != expected:
            raise ValueError(
                f'Expected weights of shape {expected}, got {tuple(weights.shape)}.')
        self.linear.weight = nn.Parameter(weights.clone())

In [5]:
import templates
def get_zeroshot_weights(text_name, template, logit_scale, model=None):
    """Build zero-shot classifier weights from prompt-ensembled text embeddings.

    For every class name, each prompt template from ``templates.<template>``
    is filled in and encoded with the CLIP text encoder. The per-prompt
    embeddings are L2-normalised, averaged, and re-normalised. The stacked
    class embeddings are finally multiplied by ``logit_scale.exp()``.

    Args:
        text_name: iterable of class names; one output row per class.
        template: name of a list attribute of the ``templates`` module
            (e.g. ``'openai_imagenet_template'``) whose entries are callables
            mapping a class name to a prompt string.
        logit_scale: CLIP's log-space logit scale (``exp()`` is applied here).
        model: CLIP model used for text encoding; defaults to the notebook's
            global ``clip_model``.

    Returns:
        Tensor of shape ``(len(text_name), embed_dim)``.
    """
    if model is None:
        model = clip_model
    # Tokenise on the model's own device rather than a hard-coded `.cuda()`,
    # so this also works on CPU or on a non-default GPU.
    text_device = next(model.parameters()).device
    template_fns = getattr(templates, template)

    class_embeddings = []
    for classname in text_name:
        prompts = [t(classname) for t in template_fns]
        tokens = clip.tokenize(prompts).to(text_device)
        embeddings = model.encode_text(tokens)  # (num_prompts, dim)
        # Out-of-place ops stay autograd-safe if called outside torch.no_grad().
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        # Average the unit vectors, then project the mean back onto the unit sphere.
        mean_embedding = embeddings.mean(dim=0)  # (dim,)
        class_embeddings.append(mean_embedding / mean_embedding.norm())

    zeroshot_weights = torch.stack(class_embeddings, dim=0)  # (num_classes, dim)
    return zeroshot_weights * logit_scale.exp()

In [6]:
# Set classifier: a bias-free linear head initialised from CLIP zero-shot
# text weights averaged over the OpenAI 80 prompts.
logit_scale = clip_model.logit_scale
template = 'openai_imagenet_template'  # OpenAI 80 prompts
with torch.no_grad():
    zeroshot_weights = get_zeroshot_weights(text_name, template, logit_scale)

# Size the head from the zero-shot weights (one row per class, one column per
# embedding dimension) instead of hard-coding 1000 x 512, so the head always
# matches the class list and the backbone's embedding width.
num_classes, num_features = zeroshot_weights.shape

classifier = MyLinear(input_dim=num_features, num_classes=num_classes, bias=False)
classifier._init_weights(zeroshot_weights)

zeroshot_weights.shape

torch.Size([1000, 512])

4. Define the optimizer and learning rate scheduler

In [None]:
from utils.scheduler import build_lr_scheduler

# Learning rates: a tiny one for the unfrozen CLIP blocks, a larger one for the head.
lr_backbone = 1e-6
lr_cls = 1e-3
weight_decay = 0.1

# Two parameter groups: they share AdamW's betas / weight decay but each sets its own LR.
param_groups = [
    dict(params=[p for p in clip_model.parameters() if p.requires_grad], lr=lr_backbone),
    dict(params=list(classifier.parameters()), lr=lr_cls),
]
optimizer = torch.optim.AdamW(
    param_groups,
    lr=lr_cls,  # default only; both groups override it
    betas=(0.9, 0.999),
    weight_decay=weight_decay,
)

# Per-iteration schedule: "cosine" over all training iterations with a "linear"
# warm-up (exact semantics are defined in utils.scheduler.build_lr_scheduler).
num_epochs = 50
total_iter = num_epochs * len(train_dataloader)
warmup_iter = 18
warmup_lr = 1e-8
scheduler = build_lr_scheduler(
    optimizer,
    lr_scheduler="cosine",
    max_iter=total_iter,
    warmup_type="linear",
    warmup_iter=warmup_iter,
    warmup_lr=warmup_lr,
    verbose=False,
)

5. Start training

In [11]:
import tqdm
import torchmetrics

import os

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
CE_criterion = nn.CrossEntropyLoss()

best_val_acc = -1
loss, acc, val_acc = [-1] * 3  # placeholders shown in the progress bar before the first update
# Name the checkpoint after the PFT depth actually used (top-4 by default).
model_path = f'outputs/PFT_top{ft_topk_blks}_best_model.pth'
# torch.save does not create missing directories; without this the first
# checkpoint save would crash after a full epoch of training.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

clip_model.to(device)
classifier.to(device)
clip_model.train()
classifier.train()
print(f"Start standard finetuning ......")
for epoch in range(1, num_epochs + 1):

    # Fresh training-accuracy metric for every epoch.
    train_acc = torchmetrics.Accuracy(num_classes=num_classes, task="multiclass", top_k=1)
    train_acc.to(device)

    pbar_iter = tqdm.tqdm(train_dataloader)
    for images, targets in pbar_iter:
        # Displays the values from the previous iteration / epoch.
        pbar_iter.set_description(f"Epoch {epoch} / {num_epochs}, loss = {loss:.2f}, acc = {acc:.2f}, val_acc = {val_acc:.2f}, best_val_acc = {best_val_acc:.2f}")

        images = images.to(device)
        targets = targets.to(device)

        # The head operates on L2-normalised image embeddings.
        image_features = clip_model.encode_image(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        outputs = classifier(image_features)

        loss = CE_criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-iteration: max_iter = len(train_dataloader) * num_epochs

        acc = train_acc(outputs, targets)

    # Validation: count correct top-1 predictions on the fly instead of copying
    # every logit to the host. argmax of the logits equals argmax of their softmax.
    clip_model.eval()
    classifier.eval()
    num_correct, num_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in val_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            image_features = clip_model.encode_image(inputs)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            outputs = classifier(image_features)

            num_correct += (outputs.argmax(dim=1) == targets).sum().item()
            num_seen += targets.numel()
    val_acc = num_correct / num_seen

    # `>=` keeps the latest epoch when validation accuracy ties.
    if val_acc >= best_val_acc:
        best_val_acc = val_acc
        best_state_dict = {
            "clip_model": clip_model.state_dict(),
            "classifier": classifier.state_dict(),
        }
        torch.save(best_state_dict, model_path)  # Save the best model

    clip_model.train()
    classifier.train()

Start standard finetuning ......


Epoch 1 / 50, loss = 0.79, acc = 0.77, val_acc = -1.00, best_val_acc = -1.00: 100%|██████████| 250/250 [00:38<00:00,  6.44it/s]
Epoch 2 / 50, loss = 0.65, acc = 0.81, val_acc = 0.71, best_val_acc = 0.71: 100%|██████████| 250/250 [00:38<00:00,  6.47it/s]
Epoch 3 / 50, loss = 0.64, acc = 0.84, val_acc = 0.72, best_val_acc = 0.72: 100%|██████████| 250/250 [00:39<00:00,  6.39it/s]
Epoch 4 / 50, loss = 0.84, acc = 0.73, val_acc = 0.73, best_val_acc = 0.73: 100%|██████████| 250/250 [00:39<00:00,  6.39it/s]
Epoch 5 / 50, loss = 0.85, acc = 0.77, val_acc = 0.73, best_val_acc = 0.73: 100%|██████████| 250/250 [00:39<00:00,  6.38it/s]
Epoch 6 / 50, loss = 0.55, acc = 0.81, val_acc = 0.73, best_val_acc = 0.73: 100%|██████████| 250/250 [00:39<00:00,  6.38it/s]
Epoch 7 / 50, loss = 0.53, acc = 0.81, val_acc = 0.74, best_val_acc = 0.74: 100%|██████████| 250/250 [00:39<00:00,  6.38it/s]
Epoch 8 / 50, loss = 0.54, acc = 0.88, val_acc = 0.74, best_val_acc = 0.74: 100%|██████████| 250/250 [00:38<00:00,  

In [12]:
# Best validation accuracy reached during training (fraction of correct top-1 predictions).
best_val_acc

0.74398

6. Test the model on ID and OOD datasets

In [15]:
# Reload the best checkpoint. map_location puts the tensors on the current
# device, so a checkpoint saved on GPU can also be evaluated on CPU.
checkpoint = torch.load(model_path, map_location=device)
clip_model.load_state_dict(checkpoint['clip_model'])
classifier.load_state_dict(checkpoint['classifier'])

# Identify the ID test set by its dataset object rather than a hard-coded name string;
# every other evaluated dataset counts towards the OOD average.
id_dataset_name = imagenet_test.dataset_name

results_dict = {}
for test_dataset in [imagenet_test, imagenetv2_test, imagenet_sketch_test, imagenet_a_test,
                     imagenet_r_test]:

    dataset_name = test_dataset.dataset_name
    test_dataloader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4)
    # Selects the classifier columns that correspond to this dataset's labels,
    # primarily for ImageNet-A and ImageNet-R.
    # NOTE(review): assumes label_map is an index sequence or None — confirm in `datasets`.
    test_label_map = test_dataset.label_map

    clip_model.eval()
    classifier.eval()
    targets_list = []
    preds_list = []
    with torch.no_grad():
        for inputs, targets in tqdm.tqdm(test_dataloader, desc=f"{dataset_name} test"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            image_features = clip_model.encode_image(inputs)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            outputs = classifier(image_features)
            if test_label_map is not None:
                # outputs[:, None] would add an axis instead of selecting columns.
                outputs = outputs[:, test_label_map]

            targets_list.append(targets.cpu().numpy())
            preds_list.append(outputs.cpu().numpy())

    targets_list = np.hstack(targets_list)
    preds_list = torch.tensor(np.vstack(preds_list))
    # argmax of the logits equals argmax of their softmax; no need to normalise.
    test_acc = (preds_list.argmax(1).numpy() == targets_list).mean()

    results_dict[dataset_name] = {'pred_logits': preds_list.numpy(), 'targets': targets_list, 'test_acc': test_acc}

ood = []
for dataset_name, result in results_dict.items():
    print(f"{dataset_name} Test acc = {result['test_acc']}")
    if dataset_name != id_dataset_name:
        ood.append(result['test_acc'])
avg_ood = np.mean(ood)
print(f"Avg OOD Test acc = {avg_ood}")

ImageNet-1k test: 100%|██████████| 782/782 [01:14<00:00, 10.44it/s]
ImageNet-v2 test: 100%|██████████| 157/157 [00:15<00:00, 10.13it/s]
ImageNet-Sketch test: 100%|██████████| 796/796 [01:39<00:00,  8.04it/s]
ImageNet-A test: 100%|██████████| 118/118 [00:11<00:00, 10.02it/s]
ImageNet-R test: 100%|██████████| 469/469 [00:45<00:00, 10.24it/s]

ImageNet-1k Test acc = 0.74398
ImageNet-v2 Test acc = 0.6667
ImageNet-Sketch Test acc = 0.49668887185835836
ImageNet-A Test acc = 0.514
ImageNet-R Test acc = 0.7774666666666666
Avg OOD Test acc = 0.6137138846312562



