In [None]:
from data_utils import load_ham10000_dataset, LESION_TYPE
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch

from torch.utils.data import DataLoader,Dataset
from torchvision import models,transforms
import clip

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-B/32 CLIP checkpoint onto the selected device; the second
# return value is kept as the preprocessing transform for the images.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

# CLIP Zero-Shot Classification

In [None]:
# Load the HAM10000 dataset, passing CLIP's preprocessing pipeline as the
# image transform.
ham10000 = load_ham10000_dataset(transform=clip_preprocess)

In [None]:
# Mini-batch size used by the DataLoader inside clip_zero_shot.
BATCH_SIZE = 64

In [None]:
def clip_zero_shot(
    data_set,
    classes,
    batch_size=None,
    prompt_template="a photo of a {}, a type of skin lesion.",
):
    """Measure CLIP zero-shot classification accuracy on a labelled dataset.

    Every class name is turned into a text prompt and encoded with the CLIP
    text encoder. Each image is then assigned the class whose (L2-normalized)
    prompt embedding has the highest cosine similarity with the image's
    (L2-normalized) embedding, and the fraction of correct assignments over
    the whole dataset is returned.

    Adapted from:
    https://colab.research.google.com/drive/1IqJfogZdC61dgE4BDQILCJS-zUiphD4y?authuser=2#scrollTo=EuZFg3ZlHOVD

    Args:
        data_set: Dataset yielding ``(image, label)`` pairs that a
            ``DataLoader`` can batch. Images are passed straight to
            ``clip_model.encode_image``. Labels are compared with the
            predicted class index, so they are assumed to be integer
            positions into ``classes`` -- TODO confirm against the loader.
        classes: Iterable of class names. The position of a name in this
            iterable is the class index that predictions are expressed in.
        batch_size: Number of images per batch. ``None`` (the default) uses
            the notebook-level ``BATCH_SIZE`` at call time.
        prompt_template: ``str.format`` template with one positional
            placeholder that receives the class name.

    Returns:
        float: Accuracy in ``[0, 1]`` (correct predictions / samples seen).

    Raises:
        ValueError: If ``classes`` is empty or ``data_set`` yields no samples.
    """
    # Materialize once so one-shot iterables and dict views are handled
    # uniformly and can be validated before any model work is done.
    class_names = list(classes)
    if not class_names:
        raise ValueError("clip_zero_shot requires at least one class name")

    if batch_size is None:
        batch_size = BATCH_SIZE

    # Evaluation only: sample order does not affect accuracy, so there is no
    # reason to shuffle. Not shuffling also lets an empty dataset reach the
    # explicit check below instead of failing inside the sampler.
    data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=False)

    # clip.tokenize accepts a list of strings and returns one row per prompt.
    prompts = [prompt_template.format(name) for name in class_names]
    text_inputs = clip.tokenize(prompts).to(device)

    correct = 0
    total = 0
    # The whole pass is inference, so keep every tensor op out of autograd.
    with torch.no_grad():
        text_features = clip_model.encode_text(text_inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for image, label in tqdm(data_loader):
            image, label = image.to(device), label.to(device)
            image_features = clip_model.encode_image(image)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Cosine similarity between each image and each class prompt.
            # Only the arg-max is needed, and both a positive scale factor
            # and softmax are monotonic, so neither is applied here.
            similarity = image_features @ text_features.T
            pred = similarity.argmax(dim=-1)

            correct += (pred == label).sum().item()
            total += len(label)

    if total == 0:
        raise ValueError("clip_zero_shot received an empty data_set; accuracy is undefined")

    return correct / total

In [None]:
# Class names taken from the values of LESION_TYPE.
# NOTE(review): presumably these are passed as `classes` to clip_zero_shot, in
# which case their order must match the integer labels of the dataset -- confirm.
# Original author's note: this was probably only because the class labels were
# numbers, not strs.
lesion_classes = LESION_TYPE.values()