Linear probe with hyperparameter sweep for CLIP models. Approach based on https://github.com/openai/CLIP.

In [None]:
from sklearn import metrics
import os
import clip
import torch
import wandb
import utils as uu

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms

from tqdm import tqdm

from datasetCUB.Cub_class.class_cub import Cub2011
from datasetCUB.transformations import label_transformation as lt

In [None]:
# Show the CLIP model names that can be chosen as `model_architecture` below.
clip.available_models()

In [None]:
# Choose the CLIP model to probe; the name is passed to clip.load further down.
model_architecture = 'ViT-B/32'

In [None]:
# Names used for the Weights & Biases project and run. The '/' in the
# architecture name (e.g. 'ViT-B/32') is replaced by '-' for the run name.
PROJECT_NAME = "Hyperparameter-Tuning-ViT"
model_name = "-".join(model_architecture.split("/"))
RUN_NAME = f"Linear-Probe-{model_name}"

In [None]:
# Select the first CUDA device when one is available, otherwise use the CPU,
# then load the chosen CLIP model together with its preprocessing transform.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
model, preprocess = clip.load(model_architecture, device)
print(device)

In [None]:
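# Alternatively, uncomment the lines below to load a ResNet-50 pre-trained on ImageNet (not a CLIP model).
# NOTE: get_features calls model.encode_image, which a torchvision ResNet does not provide,
# so the feature extraction has to be adapted when using this alternative.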

# model_architecture = 'resnet50'

# preprocess = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])
# model = torch.hub.load('pytorch/vision:v0.8.0', model_architecture, pretrained=True) 

In [None]:
# Load the CUB train and test splits. Images go through the `preprocess`
# transform returned by clip.load; label mapping is switched off.
root = 'path/to/repository/SpeciesRecognition'  # placeholder: set to the local repository path
cub_root = uu.get_root_CUB(root)
train = Cub2011(
    cub_root,
    train=True,
    transform_image=preprocess,
    label_mapping=False,
)
test = Cub2011(
    cub_root,
    train=False,
    transform_image=preprocess,
    label_mapping=False,
)


In [None]:
# Settings for the train/validation split of the training set
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed = 42

# Build the index list and (optionally) shuffle it reproducibly. The first
# `split` indices form the validation subset, the remainder the training subset.
dataset_size = len(train)
indices = list(range(dataset_size))
split = int(validation_split * dataset_size)
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
val_indices = indices[:split]
train_indices = indices[split:]

# Samplers that restrict a DataLoader to the respective subset
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

In [None]:
def get_features(dataset, sampler, batch_size=100):
    """Encode all images of ``dataset`` and return features and labels.

    The images are batched with a ``DataLoader`` and encoded with the
    module-level ``model`` (``model.encode_image``) on the module-level
    ``device``; gradients are disabled during encoding.

    Args:
        dataset: Dataset yielding ``(image, label)`` pairs.
        sampler: Sampler handed to the ``DataLoader``; ``None`` iterates over
            the whole dataset.
        batch_size: Number of images encoded per step.

    Returns:
        A tuple ``(features, labels)`` of NumPy arrays, concatenated over all
        batches in the order the loader produced them.

    Raises:
        ValueError: If the loader yields no batches.
    """
    all_features = []
    all_labels = []

    with torch.no_grad():
        loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
        for images, labels in tqdm(loader):
            features = model.encode_image(images.to(device))

            # Move each batch off the device right away so accelerator memory
            # is not held for the whole dataset.
            all_features.append(features.cpu())
            all_labels.append(labels)

    if not all_features:
        raise ValueError("get_features received no data: the dataset/sampler is empty")

    # Concatenate once after the loop instead of re-concatenating (and
    # copying) everything collected so far on every batch.
    return torch.cat(all_features).numpy(), torch.cat(all_labels).cpu().numpy()

# Encode the images once per subset:
#   - training and validation subsets of the training data (own split),
#   - the test set,
#   - the complete training set (no sampler).
train_features, train_labels = get_features(train, sampler=train_sampler)
val_features, val_labels = get_features(train, sampler=valid_sampler)
test_features, test_labels = get_features(test, sampler=None)
all_train_features, all_train_labels = get_features(train, sampler=None)

In [None]:
# Drop the per-split features (presumably to free memory) when the sweep only
# uses cross-validation on the full training set (cross_val = True below).
# Skip this cell if you validate on the held-out split, which needs them.
train_features, train_labels = 0, 0
val_features, val_labels = 0, 0
# test_features / test_labels are intentionally NOT reset: the final cell
# still calls classifier.predict(test_features) and compares to test_labels.

In [None]:
# Hyperparameter search
SWEEP = True

# Fixed LogisticRegression parameters
random_state = 0
max_iter = 1000
verbose = 1

# Range of the regularisation parameter C explored by the wandb sweep
c_min = 3.0
c_max = 4.0

if SWEEP:
    # Objective of the sweep: maximise the logged 'val_accuracy'
    metric = {'name': 'val_accuracy', 'goal': 'maximize'}

    # Only C is searched (between c_min and c_max); the rest are constants
    parameters_dict = {
        'random_state': {'value': random_state},
        'C': {'min': c_min, 'max': c_max},
        'max_iter': {'value': max_iter},
        'verbose': {'value': verbose},
    }

    sweep_config = {
        'method': 'random',
        'parameters': parameters_dict,
        'metric': metric,
    }
    sweep_id = wandb.sweep(sweep_config, project=PROJECT_NAME)

In [None]:
# Run hyperparameter tuning with Weights & Biases

cross_val = True  # if False: validate on the own train/validation split


# NOTE(review): this function reuses the name `train`, which earlier holds the
# training dataset; re-running the feature-extraction cell after this one
# would pass the function instead of the dataset.
def train(config=None):
    """Fit one logistic-regression probe and log its validation accuracy.

    Called by the wandb agent once per sweep trial. The hyperparameters are
    read from ``wandb.config`` after the run has been initialised.

    With ``cross_val`` enabled the accuracy comes from 10-fold cross-validated
    predictions on the full training features; otherwise the classifier is
    fitted on the training split and scored on the validation split.

    NOTE(review): the two branches log ``val_accuracy`` on different scales:
    ``accuracy_score`` yields a fraction, the held-out branch a percentage.
    Do not compare runs across the two modes without rescaling.

    Args:
        config: Optional configuration forwarded to ``wandb.init``.
    """
    with wandb.init(config=config):
        config = wandb.config
        classifier = LogisticRegression(
            random_state=config.random_state,
            C=config.C,
            max_iter=config.max_iter,
            verbose=config.verbose,
        )

        if cross_val:
            predicted = cross_val_predict(classifier, all_train_features, all_train_labels, cv=10)
            accuracy = metrics.accuracy_score(all_train_labels, predicted)
        else:  # validation on own split
            # This branch needs the real split features from get_features;
            # they must not have been released beforehand.
            classifier.fit(train_features, train_labels)
            predictions = classifier.predict(val_features)
            # np.float was removed in NumPy 1.24; use the explicit np.float64.
            accuracy = np.mean((val_labels == predictions).astype(np.float64)) * 100.

        wandb.log({"val_accuracy": accuracy})


if SWEEP:
    wandb.agent(sweep_id, train, count=30)

In [None]:
# Train the final classifier on the full training set and evaluate it on the
# test set.

random_state = 0
C = 3.59  # choose hyperparameter from sweep
max_iter = 1000
verbose = 1

# Use the run as a context manager (as the sweep's train function does) so the
# wandb run is also closed when fitting or prediction raises.
with wandb.init(project=PROJECT_NAME, job_type="inference", name=RUN_NAME) as run:
    classifier = LogisticRegression(random_state=random_state, C=C, max_iter=max_iter, verbose=verbose)
    classifier.fit(all_train_features, all_train_labels)
    predictions = classifier.predict(test_features)
    # Test accuracy in percent
    accuracy = np.mean((test_labels == predictions).astype(np.float64)) * 100.
    print(f"Accuracy = {accuracy:.3f}")

    # Record the result in the run; otherwise the run carries no metric.
    run.log({"test_accuracy": accuracy})
