In [1]:
import torch
import clip
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


### Configuration (compute device)

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
DEVICE

'cuda'

### Dataset

In [3]:
# Load the CUB-200-2011 dataset
def load_cub_dataset(data_dir):
    images = pd.read_csv(os.path.join(data_dir, 'images.txt'), sep=' ', names=['image_id', 'file_path'])
    labels = pd.read_csv(os.path.join(data_dir, 'image_class_labels.txt'), sep=' ', names=['image_id', 'class_id'])
    classes = pd.read_csv(os.path.join(data_dir, 'classes.txt'), sep=' ', names=['class_id', 'class_name'])
    return images, labels, classes
data_dir = 'data'

images, labels, classes = load_cub_dataset(data_dir)

print(images.head())
print(labels.head())
print(classes.head())

print(images.shape)
print(labels.shape)
print(classes.shape)

   image_id                                          file_path
0         1  001.Black_footed_Albatross/Black_Footed_Albatr...
1         2  001.Black_footed_Albatross/Black_Footed_Albatr...
2         3  001.Black_footed_Albatross/Black_Footed_Albatr...
3         4  001.Black_footed_Albatross/Black_Footed_Albatr...
4         5  001.Black_footed_Albatross/Black_Footed_Albatr...
   image_id  class_id
0         1         1
1         2         1
2         3         1
3         4         1
4         5         1
   class_id                  class_name
0         1  001.Black_footed_Albatross
1         2        002.Laysan_Albatross
2         3         003.Sooty_Albatross
3         4       004.Groove_billed_Ani
4         5          005.Crested_Auklet
(11788, 2)
(11788, 2)
(200, 2)


### CLIP

In [6]:
# Load CLIP on the device selected in the configuration cell rather than a
# hard-coded "cuda", so the notebook also runs on CPU-only machines.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=DEVICE, jit=False)
clip_model.eval()

def get_clip_features(img_path):
    """Encode one image file with the CLIP image encoder.

    Args:
        img_path: path to an image file that PIL can open.

    Returns:
        Unnormalised image-embedding tensor with a leading batch dimension of 1
        (the preprocessed image is ``unsqueeze(0)``-ed before encoding).
    """
    # The context manager releases the file handle PIL keeps open lazily,
    # instead of leaving it open until the Image object is garbage-collected.
    with Image.open(img_path) as img:
        # Force 3-channel input in case some dataset images are grayscale/CMYK.
        img_input = clip_preprocess(img.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        img_features = clip_model.encode_image(img_input)
    return img_features

def get_clip_text_features(text):
    """Encode one or more text prompts with the CLIP text encoder.

    Args:
        text: a single string, or an iterable of strings to encode as one batch.

    Returns:
        Unnormalised text-embedding tensor with one row per input string
        (a single row when ``text`` is a plain string).
    """
    texts = [text] if isinstance(text, str) else list(text)
    text_input = clip.tokenize(texts).to(DEVICE)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_input)
    return text_features

def get_clip_similarity_score(img_features, text_features, logit_scale=100.0):
    """Return the scaled cosine similarity between one image and one text embedding.

    Both embeddings are L2-normalised along their last dimension without
    modifying the caller's tensors, and their dot product is multiplied by
    ``logit_scale``. Higher scores mean a better image/text match.

    Args:
        img_features: image embedding, e.g. the output of ``get_clip_features``.
        text_features: text embedding, e.g. the output of ``get_clip_text_features``
            for a single string.
        logit_scale: multiplier for the cosine similarity; 100.0 keeps the
            scale previously hard-coded here.

    Returns:
        The score as a Python float.

    Raises:
        The error from ``Tensor.item()`` if the inputs describe more than one
        image/text pair (the product then has more than one element).
    """
    # Out-of-place normalisation: the previous in-place ``/=`` mutated the
    # caller's tensors.
    img_unit = img_features / img_features.norm(dim=-1, keepdim=True)
    text_unit = text_features / text_features.norm(dim=-1, keepdim=True)
    # No softmax: over a single image/text logit it always evaluates to 1.0,
    # which made every class score identical. Return the raw logit instead.
    logits = logit_scale * img_unit @ text_unit.T
    return logits.item()

def _readable_species_name(class_name):
    """Turn a raw CUB label such as '001.Black_footed_Albatross' into 'Black footed Albatross'."""
    _, sep, species = class_name.partition('.')
    name = species if sep else class_name
    return name.replace('_', ' ')


def recognize_bird_species(img_path, class_names=None, prompt_template="a photo of a {}, a type of bird."):
    """Zero-shot classify a bird image with CLIP against a list of class labels.

    Every class label is converted to a readable prompt and all prompts are
    encoded in a single batch. The label whose prompt embedding has the
    highest cosine similarity to the image embedding is returned.

    Args:
        img_path: path to the image to classify.
        class_names: candidate labels in raw CUB form, e.g.
            "001.Black_footed_Albatross". Defaults to ``classes['class_name']``.
        prompt_template: format string with one ``{}`` placeholder. The
            readable species name (index prefix removed, underscores turned
            into spaces) is inserted there.

    Returns:
        The best-matching entry of ``class_names``, unchanged (raw form).

    Raises:
        ValueError: if ``class_names`` is empty.
    """
    if class_names is None:
        class_names = classes['class_name']
    class_names = list(class_names)
    if not class_names:
        raise ValueError("class_names must contain at least one class")

    prompts = [prompt_template.format(_readable_species_name(name)) for name in class_names]

    img_features = get_clip_features(img_path)
    # One batched forward pass instead of one text-encoder call per class.
    text_input = clip.tokenize(prompts).to(DEVICE)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_input)

    img_features = img_features / img_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Compare the image against all classes at once. Softmax is monotonic,
    # so argmax over the raw similarities selects the same class. Like the
    # original stable sort, argmax keeps the first class on ties.
    similarities = img_features @ text_features.T  # presumably (1, num_classes)
    best_index = int(similarities.argmax(dim=-1).item())
    return class_names[best_index]

In [7]:
# Build the sample path from the configured data directory instead of
# hard-coding the 'data/' prefix a second time.
sample_path = os.path.join(
    data_dir, 'images', '001.Black_footed_Albatross', 'Black_Footed_Albatross_0001_796111.jpg'
)
bird_species = recognize_bird_species(sample_path)
print(bird_species)

001.Black_footed_Albatross
