## CIFAR-10 zero-shot classifier with Resnet50
In this notebook, we walk through a use case of our method.
We take a Resnet50 model pre-trained on ImageNet and use Text-To-Concept to turn it into a zero-shot classifier for CIFAR-10.
This notebook has these sections:
+ <i>Preliminaries</i>: we import the required libraries and define the input transformations.
+ <i>Resnet50</i>: we load the model and implement the functions it needs to work with the `TextToConcept` framework.
+ <i>Linear Aligner</i>: we initialize a `TextToConcept` object and train (or load) its linear aligner.
+ <i>Zero-shot classifier</i>: we combine methods implemented in `TextToConcept` with appropriate text prompts to build the zero-shot classifier.
+ <i>Zero-shot performance on CIFAR-10</i>: we load CIFAR-10 and evaluate the Resnet50-based zero-shot classifier on it.


### Preliminaries
In this section, we import the required libraries and define the standard transformations used when loading datasets. Some models expect normalized inputs while others do not, so we define transforms both with and without ImageNet normalization.

In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torchvision
import numpy as np
from tqdm import tqdm
from TextToConcept import TextToConcept

In [2]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
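`Normalize` applies these constants channel by channel: each channel $c$ of a `ToTensor()` output (values in $[0, 1]$) becomes $(x_c - \mu_c) / \sigma_c$. The NumPy sketch below reproduces that arithmetic on a channel-first (C, H, W) array. It illustrates the formula only and is not torchvision's implementation; `normalize_chw` and `denormalize_chw` are hypothetical helper names.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def normalize_chw(image: np.ndarray) -> np.ndarray:
    """Per-channel standardization of a (C, H, W) image with values in [0, 1]."""
    # Reshape mean/std to (C, 1, 1) so they broadcast over height and width.
    mean = IMAGENET_MEAN[:, None, None]
    std = IMAGENET_STD[:, None, None]
    return (image - mean) / std


def denormalize_chw(image: np.ndarray) -> np.ndarray:
    """Inverse mapping, e.g. for visualizing normalized inputs."""
    return image * IMAGENET_STD[:, None, None] + IMAGENET_MEAN[:, None, None]


# A uniformly gray 2x2 image: every pixel is 0.5 in every channel.
gray = np.full((3, 2, 2), 0.5)
normalized = normalize_chw(gray)
print(normalized[:, 0, 0].round(3))
```

Because each channel is shifted and scaled independently, the mapping is exactly invertible, which is why `denormalize_chw` recovers the original image.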

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
std_transform_without_normalization = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor()])


std_transform_with_normalization = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])
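With an integer argument, `Resize(224)` rescales the shorter image side to 224 pixels while preserving the aspect ratio, and `CenterCrop(224)` then cuts out the central 224x224 window. The pure-Python sketch below computes the resulting sizes. The helper names are hypothetical, and the rounding (truncating the longer side, rounding the crop offset) is our reading of torchvision's behavior, so it is worth checking against your installed version.

```python
def resized_size(height: int, width: int, size: int = 224) -> tuple[int, int]:
    """Output (height, width) of Resize(size) with an int: shorter side -> size."""
    if height <= width:
        return size, int(size * width / height)
    return int(size * height / width), size


def center_crop_box(height: int, width: int, crop: int = 224) -> tuple[int, int, int, int]:
    """(top, left, bottom, right) of a centered crop x crop window."""
    top = int(round((height - crop) / 2.0))
    left = int(round((width - crop) / 2.0))
    return top, left, top + crop, left + crop


# A 32x32 CIFAR-10 image is upsampled to 224x224, so the crop keeps everything.
print(resized_size(32, 32))

# A landscape 375x500 image keeps its aspect ratio before being cropped.
h, w = resized_size(375, 500)
print((h, w), center_crop_box(h, w))
```

For CIFAR-10 this means every 32x32 test image is upsampled sevenfold to the 224x224 resolution the ImageNet-trained Resnet50 expects.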

### Resnet50
In this part, we load the Resnet50 model.
To work with the ``TextToConcept`` framework, the model should implement the following functions/attributes:
+ ``forward_features(x)``, which takes a tensor $x$ as input and returns its representation (features) when passed through the model.
+ ``get_normalizer``, the normalizer the model uses to preprocess its input; e.g., torchvision's ResNets (such as Resnet18) use the standard ImageNet normalizer.
+ ``has_normalizer``, an attribute that should be `True` when the model needs a normalizer.
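In the next cell these members are simply attached to a torchvision model. As a framework-free illustration, the hypothetical `ToyEncoder` below exposes the same three members using NumPy. Its `get_normalizer` is a method rather than a transform object, but both are called the same way, `model.get_normalizer(x)`. Its `forward_features` ends with global average pooling over the spatial dimensions, which is how a ResNet reduces its final feature map to one 2048-dimensional vector per image.

```python
import numpy as np


class ToyEncoder:
    """Hypothetical stand-in exposing the members TextToConcept expects."""

    has_normalizer = True  # this toy model expects ImageNet-normalized input

    def __init__(self, in_channels: int = 3, feature_dim: int = 2048, seed: int = 0):
        rng = np.random.default_rng(seed)
        # A 1x1 "convolution" written as a channel-mixing matrix.
        self.weights = rng.standard_normal((feature_dim, in_channels))
        self.mean = np.array([0.485, 0.456, 0.406])[:, None, None]
        self.std = np.array([0.229, 0.224, 0.225])[:, None, None]

    def get_normalizer(self, x: np.ndarray) -> np.ndarray:
        # Standard ImageNet normalization of an (N, C, H, W) batch in [0, 1].
        return (x - self.mean) / self.std

    def forward_features(self, x: np.ndarray) -> np.ndarray:
        # Channel mixing (N, C, H, W) -> (N, D, H, W), followed by ReLU.
        fmap = np.maximum(np.einsum("dc,nchw->ndhw", self.weights, x), 0.0)
        # Global average pooling over H and W yields one D-dim vector per image.
        return fmap.mean(axis=(2, 3))


model = ToyEncoder()
batch = np.random.default_rng(1).random((4, 3, 8, 8))  # four 8x8 RGB "images"
if model.has_normalizer:
    batch = model.get_normalizer(batch)
features = model.forward_features(batch)
print(features.shape)
```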

In [None]:
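# `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13).
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
model.eval()  # inference only: use BatchNorm running statistics

# Drop the final fc layer; the remaining stack ends with global average pooling,
# so encoder(x) returns features of shape (N, 2048, 1, 1).
encoder = torch.nn.Sequential(*list(model.children())[:-1])
model.forward_features = lambda x: encoder(x)
# Normalizer the model uses to preprocess inputs (standard ImageNet statistics).
model.get_normalizer = torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
model.has_normalizer = True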

### Linear Aligner

<b>Initiating Text-To-Concept Object</b><br>
In this section, we initialize a ``TextToConcept`` object, which turns a vision encoder (here, Resnet50) into a model that can link language and vision. This lets us use certain abilities of vision-language models, such as zero-shot classification.

In [7]:
text_to_concept = TextToConcept(model, 'resnet50')

We can either train the aligner or load an existing one.

#### Training Linear Aligner

<b>Loading ImageNet Dataset to Train the Aligner</b><br>
We note that even $20\%$ of the ImageNet training samples suffice to train an effective linear aligner.
See Appendix A of our paper for more details on the sample efficiency of linear alignment.

In [9]:
# Load the ImageNet training set to train the aligner.
# This root path is specific to the authors' setup; point it to your local copy.
dset = torchvision.datasets.ImageNet(root='/fs/cml-datasets/ImageNet/ILSVRC2012',
                                     split='train',
                                     transform=std_transform_without_normalization)

# A random 20% of the images is enough.
num_of_samples = int(0.2 * len(dset))
dset = torch.utils.data.Subset(dset, np.random.choice(len(dset), num_of_samples, replace=False))
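As a sanity check on the numbers in the training log further down, the subset arithmetic can be reproduced in plain Python. This assumes the standard ILSVRC-2012 training split of 1,281,167 images and a feature-extraction batch size of 16; the batch size is inferred from the progress bars, not stated in the notebook.

```python
import math

IMAGENET_TRAIN_SIZE = 1_281_167  # images in the ILSVRC-2012 training split
FRACTION = 0.2
BATCH_SIZE = 16  # assumption: inferred from the 16015-step progress bars

num_of_samples = int(FRACTION * IMAGENET_TRAIN_SIZE)
# The last batch is partial, hence the ceiling.
num_batches = math.ceil(num_of_samples / BATCH_SIZE)
print(num_of_samples, num_batches)
```

Both values agree with the log: 256233 aligned samples and 16015 batches per representation pass.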

<b>Training the Linear Aligner</b><br>
With the dataset loaded, we train the aligner.
+ ``train_linear_aligner`` computes the representations of the given model (e.g., Resnet50) on ``dset``, as well as those of a vision-language model such as CLIP. These representations can also be loaded instead of recomputed (here, ``load_reps=False``). The function then fits a linear transformation that aligns the model's representation space with the vision-language space.
+ ``save_linear_aligner`` stores the trained linear aligner so it can be reused later.
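Conceptually, the aligner is an affine map $h \mapsto hW + b$ from the model's 2048-dimensional space to CLIP's 512-dimensional space (the dimensions reported in the log below), fitted to reduce the mean squared error between aligned and CLIP representations. The NumPy sketch below fits such a map in closed form with `np.linalg.lstsq` on small synthetic data and reports MSE and $R^2$. It is a stand-in for `train_linear_aligner`, not a copy of it: the library's log shows iterative training over epochs, and the $R^2$ formula here is one common definition that we assume matches.

```python
import numpy as np


def fit_linear_aligner(src: np.ndarray, tgt: np.ndarray):
    """Least-squares affine map (W, b) with tgt ~= src @ W + b."""
    # Append a column of ones so the bias is solved for together with W.
    src_aug = np.hstack([src, np.ones((src.shape[0], 1))])
    coef, *_ = np.linalg.lstsq(src_aug, tgt, rcond=None)
    return coef[:-1], coef[-1]


def mse_and_r2(pred: np.ndarray, tgt: np.ndarray):
    mse = np.mean((pred - tgt) ** 2)
    # One common definition: R^2 = 1 - MSE / variance of the targets.
    var = np.mean((tgt - tgt.mean(axis=0)) ** 2)
    return mse, 1.0 - mse / var


rng = np.random.default_rng(0)
n, d_model, d_clip = 2000, 64, 16  # small stand-ins for 2048 -> 512
model_reps = rng.standard_normal((n, d_model))
true_map = rng.standard_normal((d_model, d_clip)) / np.sqrt(d_model)
clip_reps = model_reps @ true_map + 0.1 * rng.standard_normal((n, d_clip))

W, b = fit_linear_aligner(model_reps, clip_reps)
mse, r2 = mse_and_r2(model_reps @ W + b, clip_reps)
print(f"MSE={mse:.4f}, R^2={r2:.3f}")
```

On real features the relationship is only approximately linear, so $R^2$ stays well below the near-perfect value this synthetic example reaches.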

In [10]:
text_to_concept.train_linear_aligner(dset, load_reps=False)

text_to_concept.save_linear_aligner('imagenet_resnet50_aligner.pth')

Obtaining representations ...


100%|██████████████████████████████████████████████████████████████████████████████████████| 16015/16015 [14:44<00:00, 18.12it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████| 16015/16015 [11:41<00:00, 22.83it/s]


Training linear aligner ...
Linear alignment: ((256233, 2048)) --> ((256233, 512)).
Initial MSE, R^2: 7.412, -0.648
Epoch number, loss: 0, 1.160
Epoch number, loss: 1, 1.023
Epoch number, loss: 2, 0.998
Epoch number, loss: 3, 0.986
Epoch number, loss: 4, 0.980
Final MSE, R^2 = 0.976, 0.783
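If the logged $R^2$ follows $R^2 = 1 - \mathrm{MSE}/\sigma^2$, with $\sigma^2$ the variance of the CLIP targets (an assumption about how the log is computed), then the initial and final lines should imply the same $\sigma^2$. A quick check with the logged values:

```python
# Values copied from the training log above.
initial_mse, initial_r2 = 7.412, -0.648
final_mse, final_r2 = 0.976, 0.783


def implied_target_variance(mse: float, r2: float) -> float:
    """Solve R^2 = 1 - MSE / var for var."""
    return mse / (1.0 - r2)


var_initial = implied_target_variance(initial_mse, initial_r2)
var_final = implied_target_variance(final_mse, final_r2)
print(f"{var_initial:.3f} {var_final:.3f}")
```

Both come out at about 4.498, so the two lines are mutually consistent under that definition: the drop in MSE from 7.412 to 0.976 is exactly what lifts $R^2$ from -0.648 to 0.783.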


<b>Loading the Linear Aligner</b><br>
Alternatively, we can load an existing linear aligner with ``load_linear_aligner``.

In [25]:
text_to_concept.load_linear_aligner('imagenet_resnet50_aligner.pth')
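`save_linear_aligner` and `load_linear_aligner` persist the fitted map in a `.pth` file whose layout is not documented here. The sketch below only illustrates the idea with NumPy: a hypothetical aligner stored as `W` and `b` arrays in an `.npz` archive, reloaded, and applied to a batch of 2048-dimensional features to produce 512-dimensional ones. The function names and file format are assumptions, not the library's API.

```python
import os
import tempfile

import numpy as np


def save_aligner(path: str, W: np.ndarray, b: np.ndarray) -> None:
    # Hypothetical format: two named arrays in one .npz archive.
    np.savez(path, W=W, b=b)


def load_aligner(path: str):
    with np.load(path) as archive:
        return archive["W"], archive["b"]


def apply_aligner(features: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    # (N, 2048) @ (2048, 512) + (512,) -> (N, 512)
    return features @ W + b


rng = np.random.default_rng(0)
W = rng.standard_normal((2048, 512)).astype(np.float32)
b = rng.standard_normal(512).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_aligner.npz")
    save_aligner(path, W, b)
    W_loaded, b_loaded = load_aligner(path)

features = rng.standard_normal((4, 2048)).astype(np.float32)
aligned = apply_aligner(features, W_loaded, b_loaded)
print(aligned.shape)
```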

### Zero-shot Classifier
CIFAR-10 is a <i>$10$-way</i> classification problem.
We use prompts of the form `a pixelated photo of a {c}` to obtain the corresponding class concepts in vision-language space.

In [11]:
cifar_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [12]:
cifar_zeroshot_classifier = text_to_concept.get_zero_shot_classifier(cifar_classes,
                                                                     prompts=['a pixelated photo of a {}'])
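A zero-shot classifier of this kind scores an image by comparing its aligned embedding with one text embedding per class. The NumPy sketch below follows the usual CLIP-style recipe: fill each class name into the prompt templates, average and L2-normalize the text embeddings, and use cosine similarities as logits. The text encoder is replaced by a hypothetical `fake_text_embedding`, and we have not verified that `get_zero_shot_classifier` follows exactly these steps.

```python
import zlib

import numpy as np

CIFAR_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']
PROMPTS = ['a pixelated photo of a {}']
DIM = 512  # matches the 512-dimensional target space in the alignment log


def fake_text_embedding(text: str) -> np.ndarray:
    # Hypothetical stand-in for a CLIP text encoder: deterministic per string.
    rng = np.random.default_rng(zlib.crc32(text.encode()))
    return rng.standard_normal(DIM)


def l2_normalize(v: np.ndarray) -> np.ndarray:
    return v / np.linalg.norm(v, axis=-1, keepdims=True)


def build_class_embeddings(classes, prompts) -> np.ndarray:
    rows = []
    for c in classes:
        # One embedding per filled-in prompt, averaged into a class prototype.
        embs = np.stack([l2_normalize(fake_text_embedding(p.format(c))) for p in prompts])
        rows.append(l2_normalize(embs.mean(axis=0)))
    return np.stack(rows)  # (num_classes, DIM)


def zero_shot_logits(aligned_image_embs: np.ndarray, class_embs: np.ndarray) -> np.ndarray:
    # Cosine similarity between each image embedding and each class prototype.
    return l2_normalize(aligned_image_embs) @ class_embs.T


class_embs = build_class_embeddings(CIFAR_CLASSES, PROMPTS)
# Toy "aligned images": noisy copies of the 'cat' and 'ship' prototypes.
rng = np.random.default_rng(0)
images = class_embs[[3, 8]] + 0.01 * rng.standard_normal((2, DIM))
logits = zero_shot_logits(images, class_embs)
print([CIFAR_CLASSES[i] for i in logits.argmax(axis=1)])
```

Passing several templates in `PROMPTS` (prompt ensembling) is handled by the same averaging step; this notebook uses a single template.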

### Zero-shot performance on CIFAR-10
After loading CIFAR-10, we use `cifar_zeroshot_classifier(x)` to get the classification logits for an input $x$.

In [13]:
cifar = torchvision.datasets.CIFAR10(root='data/',
                                     download=True,
                                     train=False,
                                     transform=std_transform_with_normalization)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz



Extracting data/cifar-10-python.tar.gz to data/


In [14]:
# Shuffling is unnecessary for evaluation.
loader = torch.utils.data.DataLoader(cifar, batch_size=16, shuffle=False, num_workers=8)
correct, total = 0, 0
with torch.no_grad():
    for x, y in tqdm(loader):
        x = x.to(device)

        # Logits over the 10 classes; the predicted class is the arg-max.
        outputs = cifar_zeroshot_classifier(x).cpu()
        predicted = outputs.argmax(dim=1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()

100%|██████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:28<00:00, 22.26it/s]
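The loop above accumulates top-1 accuracy: a prediction counts as correct when the arg-max logit matches the label. The NumPy sketch below, using random logits and labels as placeholders for real classifier outputs, shows that accumulating over batches of 16 gives the same count as scoring the full array at once. `top1_correct` is a hypothetical helper name.

```python
import numpy as np


def top1_correct(logits: np.ndarray, labels: np.ndarray) -> int:
    # Number of rows whose highest-scoring class equals the label.
    return int((logits.argmax(axis=1) == labels).sum())


rng = np.random.default_rng(0)
num_images, num_classes, batch_size = 10_000, 10, 16
logits = rng.standard_normal((num_images, num_classes))  # placeholder outputs
labels = rng.integers(0, num_classes, size=num_images)

# Batched accumulation, mirroring the evaluation loop.
correct, total = 0, 0
for start in range(0, num_images, batch_size):
    batch_logits = logits[start:start + batch_size]
    batch_labels = labels[start:start + batch_size]
    correct += top1_correct(batch_logits, batch_labels)
    total += len(batch_labels)

print(total, correct == top1_correct(logits, labels))
print(f"{100. * correct / total:.2f}")
```

With the 10,000 CIFAR-10 test images and batches of 16, the loop runs 625 steps, matching the progress bar, and the reported 68.18% corresponds to 6,818 correctly classified images.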


In [15]:
f'ResNet50 Zeroshot Accuracy on CIFAR-10 {100.*correct/total:.2f}'

'ResNet50 Zeroshot Accuracy on CIFAR-10 68.18'