In [2]:
import clip
# Show the model names the installed `clip` package reports as loadable.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [3]:
import os
import torch
import clip
import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
from torchvision.datasets import MNIST
from tqdm import tqdm


def linear_probe(dataset, model_name="RN50", batch_size=100):
    """Train and evaluate a logistic-regression probe on CLIP image features.

    The CLIP image encoder is used only for feature extraction (under
    ``torch.no_grad()``); a scikit-learn ``LogisticRegression`` is then fitted
    on the train-split features and scored on the test-split features.

    Side effects: the dataset is downloaded to ``~/.cache`` if not already
    present, and the CLIP weights are loaded via ``clip.load``.

    Args:
        dataset: Name of the torchvision dataset to probe. One of
            ``"CIFAR10"``, ``"CIFAR100"`` or ``"MNIST"``.
        model_name: CLIP model identifier passed to ``clip.load``
            (see ``clip.available_models()``). Defaults to ``"RN50"``.
        batch_size: Number of images encoded per forward pass.

    Returns:
        A string of the form ``"<dataset> Accuracy = <percent>"`` with the
        test accuracy in percent, formatted to three decimals.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Validate first so a typo fails immediately instead of after the model
    # has been loaded (previously this surfaced as an UnboundLocalError).
    supported = ("CIFAR10", "CIFAR100", "MNIST")
    if dataset not in supported:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of: {', '.join(supported)}"
        )

    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load(model_name, device)

    # Load dataset: all three classes are constructed with the same arguments,
    # so pick the class by name and build both splits the same way.
    root = os.path.expanduser("~/.cache")
    dataset_cls = {"CIFAR10": CIFAR10, "CIFAR100": CIFAR100, "MNIST": MNIST}[dataset]
    train = dataset_cls(root, download=True, train=True, transform=preprocess)
    test = dataset_cls(root, download=True, train=False, transform=preprocess)

    def get_features(split):
        """Encode every image in ``split``; return (features, labels) as NumPy arrays."""
        all_features = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(DataLoader(split, batch_size=batch_size)):
                features = model.encode_image(images.to(device))

                # Move each batch off the device right away so the full
                # feature matrix is never held in GPU memory.
                all_features.append(features.cpu())
                all_labels.append(labels)

        return torch.cat(all_features).numpy(), torch.cat(all_labels).cpu().numpy()

    # Calculate the image features
    train_features, train_labels = get_features(train)
    test_features, test_labels = get_features(test)

    # Perform logistic regression (C is sklearn's inverse regularization
    # strength; the value is kept from the original code).
    classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=0)
    classifier.fit(train_features, train_labels)

    # Evaluate using the logistic regression classifier.
    # `np.float` was removed in NumPy 1.24; the builtin `float` is equivalent.
    predictions = classifier.predict(test_features)
    accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
    return f"{dataset} Accuracy = {accuracy:.3f}"

In [4]:
# Run the probe on every dataset whose results are reported in this notebook.
# MNIST is included because the recorded output and the summary below contain
# an MNIST run that this cell previously did not invoke.
for dataset_name in ("CIFAR10", "CIFAR100", "MNIST"):
    print(linear_probe(dataset_name))

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 500/500 [05:34<00:00,  1.49it/s]
100%|██████████| 100/100 [01:06<00:00,  1.50it/s]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =        10250     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  1.15129D+05    |proj g|=  1.13516D+03


 This problem is unconstrained.



At iterate   50    f=  2.39793D+04    |proj g|=  1.14444D+02

At iterate  100    f=  2.36020D+04    |proj g|=  2.49202D+01

At iterate  150    f=  2.35885D+04    |proj g|=  1.92117D+00

At iterate  200    f=  2.35879D+04    |proj g|=  4.71428D+00

At iterate  250    f=  2.35878D+04    |proj g|=  4.08149D-01


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  5.4min finished



           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
10250    274    292      1     0     0   1.279D-01   2.359D+04
  F =   23587.782397848034     

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             
CIFAR10 Accuracy = 86.760
Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 500/500 [05:57<00:00,  1.40it/s]
100%|██████████| 100/100 [01:15<00:00,  1.32it/s]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =       102500     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  2.30259D+05    |proj g|=  1.52442D+02


 This problem is unconstrained.



At iterate   50    f=  1.00123D+05    |proj g|=  4.74049D+01

At iterate  100    f=  1.00068D+05    |proj g|=  2.86478D+00

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
*****    148    157      1     0     0   9.169D-02   1.001D+05
  F =   100066.68951902445     

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed: 16.8min finished


CIFAR100 Accuracy = 63.580


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
100%|██████████| 600/600 [04:26<00:00,  2.25it/s]
100%|██████████| 100/100 [00:45<00:00,  2.22it/s]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =        10250     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  1.38155D+05    |proj g|=  1.23014D+03


 This problem is unconstrained.



At iterate   50    f=  2.73799D+04    |proj g|=  7.48579D+01

At iterate  100    f=  2.66444D+04    |proj g|=  4.43989D+01

At iterate  150    f=  2.65884D+04    |proj g|=  2.32177D+01

At iterate  200    f=  2.65775D+04    |proj g|=  2.43566D+00

At iterate  250    f=  2.65756D+04    |proj g|=  5.33320D+00

At iterate  300    f=  2.65754D+04    |proj g|=  8.65665D-01

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
10250    308    328      1     0     0   1.137D+00   2.658D+04
  F =   26575.412240490539     

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  9.8min finished


MNIST Accuracy = 95.690


In [5]:
#MNIST: 95.69%
#CIFAR100: 63.58%
#CIFAR10: 86.76%