<a href="https://colab.research.google.com/github/dohaadel/Data-Mining/blob/main/zero-shot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Preparation for Colab
Make sure you're running a GPU runtime; if not, select "GPU" as the hardware accelerator in Runtime > Change Runtime Type in the menu. The next cell installs the `clip` package and its dependencies. CLIP needs PyTorch 1.7.1 or later; this notebook does not verify the installed version, so check it yourself if the import fails.

In [None]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

Collecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[?25l[K     |██████▏                         | 10 kB 16.2 MB/s eta 0:00:01[K     |████████████▍                   | 20 kB 14.5 MB/s eta 0:00:01[K     |██████████████████▌             | 30 kB 10.0 MB/s eta 0:00:01[K     |████████████████████████▊       | 40 kB 8.7 MB/s eta 0:00:01[K     |██████████████████████████████▉ | 51 kB 5.8 MB/s eta 0:00:01[K     |████████████████████████████████| 53 kB 698 kB/s 
Installing collected packages: ftfy
Successfully installed ftfy-6.1.1
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-oyu1768x
  Running command git clone -q https://github.com/openai/CLIP.git /tmp/pip-req-build-oyu1768x
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369387 sha256=0159826b00c4abe80703fa9667d2f5b10f0c4541c

In [None]:
import numpy as np
import torch
import clip
from tqdm.notebook import tqdm
from torchvision.datasets import  MNIST
import os
import numpy as np

# Loading the model 1


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load hands back two objects: the network itself and the callable that
# is applied to each image before it is encoded further down.
model, preprocess = clip.load('RN50', device)

100%|███████████████████████████████████████| 244M/244M [00:04<00:00, 61.3MiB/s]


# Preparing MNIST labels and prompts


In [None]:
# Class names are the ten digits as strings, in label order 0..9.
mnist_classes = [str(digit) for digit in range(10)]

In [None]:
# Prompt templates; the "{}" placeholder is filled with a class name through
# str.format when the text features are built.
mnist_templates = [
    'a photo of the number: "{}".',
]

# Quick sanity check on how many prompts will be generated.
print(f"{len(mnist_classes)} classes, {len(mnist_templates)} templates")

10 classes, 1 templates


In [None]:
# Lookup tables keyed by dataset name, so the text-feature extraction can be
# driven by a name instead of hard-coded lists.
class_map = dict(MNIST=mnist_classes)
template_map = dict(MNIST=mnist_templates)

In [None]:
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy percentages for a batch of scores.

    Args:
        output: Score tensor with one row per sample; the top-k search runs
            along dimension 1.
        target: Tensor of true class indices. Its first dimension is taken as
            the batch size, and it is flattened before comparison.
        topk: Iterable of k values to evaluate.

    Returns:
        A list of Python floats, one per entry of ``topk``, each being the
        percentage of samples whose target appears among the k highest scores.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Indices of the highest-scoring classes, transposed so that row r holds
    # every sample's r-th best guess.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    to_percent = 100.0 / n_samples
    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(to_percent).item()
        for k in topk
    ]

In [None]:
@torch.no_grad()
def extract_text_features(dataset_name):
    """Build one unit-length text embedding per class of a dataset.

    Every class name is substituted into each prompt template registered for
    ``dataset_name``. The prompts are encoded with the global ``model``, each
    prompt embedding is L2-normalised, the embeddings are averaged per class,
    and the average is normalised again.

    Args:
        dataset_name: Key into the global ``class_map`` and ``template_map``.

    Returns:
        A tensor on ``device`` in which the class embeddings are stacked along
        dimension 1 (one column per class).
    """
    labels = class_map[dataset_name]
    prompt_templates = template_map[dataset_name]
    model.to(device)
    model.eval()

    def embed_label(label):
        # Encode every prompt for this label, then pool them into one vector.
        prompts = [tpl.format(label) for tpl in prompt_templates]
        tokens = clip.tokenize(prompts).to(device)
        per_prompt = model.encode_text(tokens)
        per_prompt = per_prompt / per_prompt.norm(dim=-1, keepdim=True)
        pooled = per_prompt.mean(dim=0)
        return pooled / pooled.norm()

    per_class = [embed_label(label) for label in labels]
    return torch.stack(per_class, dim=1).to(device)

In [None]:
# MNIST with train=False, downloaded on first use into the user's ~/.cache directory.
mnist = MNIST(
    root=os.path.expanduser("~/.cache"),
    train=False,
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/.cache/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /root/.cache/MNIST/raw/train-images-idx3-ubyte.gz to /root/.cache/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/.cache/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /root/.cache/MNIST/raw/train-labels-idx1-ubyte.gz to /root/.cache/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/.cache/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /root/.cache/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/.cache/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/.cache/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /root/.cache/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/.cache/MNIST/raw



In [None]:
# Zero-shot evaluation: encode every image, compare against the per-class text
# embeddings, and report top-1 accuracy.
for dataset in [mnist]:
    image_features = []
    image_labels = []
    for image, class_id in dataset:
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_feature = model.encode_image(image_input)
            # Normalise each row to unit length so the matrix product below
            # is a cosine similarity. Kept inside no_grad with the encoding.
            image_feature /= image_feature.norm(dim=-1, keepdim=True)
        image_features.append(image_feature)
        image_labels.append(class_id)
    # Concatenate along the batch axis. The earlier stack(dim=1) + squeeze()
    # also removed the batch axis when the dataset held a single image.
    # NOTE(review): assumes encode_image returns one row per input image.
    image_features = torch.cat(image_features, dim=0).to(device)

    # extract text feature
    dataset_name = 'MNIST'
    text_features = extract_text_features(dataset_name)

    # compute top-1 accuracy
    # `logits` holds softmax probabilities over classes (scaled similarities);
    # the name is kept because later cells read it.
    logits = (100. * image_features @ text_features).softmax(dim=-1)
    image_labels = torch.tensor(image_labels).unsqueeze(dim=1).to(device)
    top1_acc = accuracy(logits, image_labels, (1,))
    print(f'top-1 accuracy for {dataset_name} dataset: {top1_acc[0]:.3f}')

top-1 accuracy for MNIST dataset: 57.620


In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# scikit-learn's signature is confusion_matrix(y_true, y_pred): rows index the
# true digit and columns index the predicted digit. The arguments were
# previously passed the other way round, which transposed the matrix.
# Tensors are moved to the CPU and converted to NumPy arrays (labels flattened
# to 1-D) before being handed to scikit-learn.
cm = confusion_matrix(
    image_labels.cpu().numpy().ravel(),
    logits.argmax(dim=1).cpu().numpy(),
)
cm

array([[712,   0,  65,   0,   2,   0,  79,   1,   5, 104],
       [  9, 722,   0,   0,  45,   3,  59,  23,   0,  10],
       [ 55, 143, 543, 134, 109,  19,  70, 169,  21,  39],
       [139, 205, 359, 804,  89, 273, 232,  79, 269,  52],
       [ 19,   8,  27,   2, 588,   1, 112,  26,   1,  42],
       [  0,   0,   1,   0,   0, 277,   5,   0,   1,   0],
       [ 17,   0,  10,  14,   2, 302, 369,   1,  48,   2],
       [ 14,  57,  24,  12, 133,  14,   7, 727,  30, 223],
       [  5,   0,   3,  43,   2,   3,  14,   0, 497,  14],
       [ 10,   0,   0,   1,  12,   0,  11,   2, 102, 523]])

In [None]:
from operator import truediv
import numpy as np

In [None]:
# Per-class scores derived from the confusion matrix `cm`.
# The diagonal holds the correctly classified count of each class.
tp = np.diag(cm)
# `prec` divides by column totals (axis=0) and `rec` by row totals (axis=1).
# These are precision and recall only when rows of `cm` are true labels and
# columns are predictions, i.e. cm came from confusion_matrix(y_true, y_pred).
# NOTE(review): confirm the orientation of `cm`; if it is transposed, the two
# printed lists are swapped.
prec = [truediv(hit, total) for hit, total in zip(tp, cm.sum(axis=0))]
rec = [truediv(hit, total) for hit, total in zip(tp, cm.sum(axis=1))]
print('Precision: {}\nRecall: {}'.format(prec, rec))

Precision: [0.726530612244898, 0.6361233480176212, 0.5261627906976745, 0.7960396039603961, 0.5987780040733197, 0.31053811659192826, 0.38517745302713985, 0.7071984435797666, 0.5102669404517454, 0.5183349851337958]
Recall: [0.7355371900826446, 0.8289322617680827, 0.41705069124423966, 0.32147141143542585, 0.711864406779661, 0.9753521126760564, 0.4823529411764706, 0.5858178887993554, 0.8554216867469879, 0.7912254160363086]


In [None]:
# Show the model names reported by the installed clip package; 'RN50', the
# one loaded earlier in this notebook, is among them.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']