In [None]:
!pip install open_clip_torch torch torchvision
!pip install torchao

[33mDEPRECATION: celery 4.4.0 has a non-standard dependency specifier pytz>dev. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of celery or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[33mDEPRECATION: celery 4.4.0 has a non-standard dependency specifier pytz>dev. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of celery or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mCollecting polars
  Obtaining dependency information for polars from https://files.pythonhosted.org/packages/9f/4c/21a227b722534404241c2a76beceb7463469d50c775a227fc5c209eb8adc/polars-1.35.1-py3-none-any.whl.metadata
  Downloading polars-1.35.1-py3-none-a

In [None]:
import torch
import open_clip
import requests
from PIL import Image
import os
import zipfile
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
from huggingface_hub import hf_hub_download
import json

## Download testing dataset


In [None]:
# Download and unpack the test dataset. The whole step is skipped when
# DATASET_DIR already exists, so re-running the notebook is cheap.
DATASET_URL = "https://www.kaggle.com/api/v1/datasets/download/nguyenletruongthien/animals-and-plants-dataset"
DATASET_DIR = "animal_plant_samples"
DATASET_ZIP = "animals-and-plants-dataset.zip"
# The archive is extracted into a staging directory that is renamed to
# DATASET_DIR only after extraction succeeds. An interrupted run therefore
# never leaves a half-populated DATASET_DIR that the existence check below
# would mistake for a complete dataset.
DATASET_STAGING_DIR = DATASET_DIR + ".partial"

if not os.path.exists(DATASET_DIR):
    try:
        # (connect, read) timeouts keep a stalled connection from hanging the kernel.
        with requests.get(DATASET_URL, stream=True, timeout=(10, 60)) as r:
            r.raise_for_status()
            with open(DATASET_ZIP, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        with zipfile.ZipFile(DATASET_ZIP, "r") as zip_ref:
            zip_ref.extractall(DATASET_STAGING_DIR)
        os.rename(DATASET_STAGING_DIR, DATASET_DIR)
    except Exception as e:
        # Best effort: report the problem and let the notebook continue.
        print(f"An error occurred while downloading or extracting the dataset: {e}")
    finally:
        # Remove the archive after a failure as well as after a success.
        if os.path.exists(DATASET_ZIP):
            os.remove(DATASET_ZIP)
            

## Loading the main Model

In [2]:

# Load the BioCLIP 2 model, its train/val image transforms and the matching
# tokenizer from the Hugging Face Hub.
BIOCLIP_MODEL_ID = 'hf-hub:imageomics/bioclip-2'

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(BIOCLIP_MODEL_ID)
# This notebook only runs inference. open_clip returns the model in training
# mode, so switch to eval mode to disable any train-only layer behaviour.
model.eval()
tokenizer = open_clip.get_tokenizer(BIOCLIP_MODEL_ID)

## Model prediction function

Based on the demo at https://huggingface.co/spaces/imageomics/bioclip-2-demo

We are going to use the demo's `open_domain_classification` function.

In [18]:
# Image preprocessing pipeline: convert to a tensor, resize to 224x224 and
# normalize each of the three channels with the statistics below.
_norm_mean = (0.48145466, 0.4578275, 0.40821073)
_norm_std = (0.26862954, 0.26130258, 0.27577711)

_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]
preprocess_img = transforms.Compose(_preprocess_steps)


In [19]:
# Precomputed species-level text embeddings and the label entries that go with
# them, both fetched from the same Hugging Face dataset repository.
TOL_REPO_ID = "imageomics/TreeOfLife-200M"

txt_emb_path = hf_hub_download(
    repo_id=TOL_REPO_ID,
    filename="embeddings/txt_emb_species.npy",
    repo_type="dataset",
)
txt_emb = torch.from_numpy(np.load(txt_emb_path))

txt_names_path = hf_hub_download(
    repo_id=TOL_REPO_ID,
    filename="embeddings/txt_emb_species.json",
    repo_type="dataset",
)
# JSON is UTF-8 by specification. Without an explicit encoding, open() falls
# back to the platform locale and can fail on non-ASCII names.
with open(txt_names_path, encoding="utf-8") as fd:
    txt_names = json.load(fd)

In [20]:
# Taxonomic rank names, listed from Kingdom down to Species.
ranks = (
    "Kingdom",
    "Phylum",
    "Class",
    "Order",
    "Family",
    "Genus",
    "Species",
)

# Torch device used by the notebook (CPU).
device = torch.device("cpu")


In [None]:
def format_name(taxon, common):
    """Build the display label for one prediction.

    Args:
        taxon: Iterable of strings; joined with single spaces when it is used.
        common: Common name. Whenever it is truthy it is the whole label.

    Returns:
        The string-formatted ``common`` if it is truthy, otherwise the
        space-joined ``taxon``.
    """
    return f"{common}" if common else " ".join(taxon)


In [None]:
@torch.no_grad()
def open_domain_classification(img, k=1):
    """
    Predict the most likely labels for an image among all candidate entries.

    The image is preprocessed and encoded with ``model``, the normalized image
    embedding is multiplied with the precomputed text-embedding matrix
    ``txt_emb`` (scaled by the model's logit scale), and a softmax over the
    candidates turns the scores into probabilities.

    Args:
        img: Image accepted by ``preprocess_img``. A PIL image in a non-RGB
            mode (grayscale, RGBA, palette, ...) is converted to RGB first.
        k: Number of top predictions to return. Defaults to 1 and is capped at
            the number of candidates.

    Returns:
        dict mapping ``format_name(*txt_names[i])`` to the probability of
        candidate ``i`` (a 0-dim tensor), for the ``k`` most probable
        candidates. Keys are display names, so candidates whose names coincide
        collapse into a single entry.
    """
    # preprocess_img normalizes with 3-element mean/std, so the image must
    # have exactly three channels.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")

    img = preprocess_img(img).to(device)
    # encode_image is given a batch of one image (unsqueeze adds the batch axis).
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
    probs = F.softmax(logits, dim=0)

    # Never request more entries than there are candidates.
    topk = probs.topk(min(k, probs.numel()))
    prediction_dict = {
        format_name(*txt_names[i]): prob
        for i, prob in zip(topk.indices, topk.values)
    }
    print(f"INFO: prediction with prob.: {prediction_dict}")
    return prediction_dict


## Test Main Model accuracy on dataset

In [53]:
# Classify one sample image from the downloaded dataset.
SAMPLE_IMAGE_PATH = "animal_plant_samples/Animals and Plants Dataset/train/Aves/Aves_image_1000.jpg"

# The context manager closes the image file once the prediction is done.
with Image.open(SAMPLE_IMAGE_PATH) as sample_img:
    prediction_dic = open_domain_classification(sample_img)

# The first key is the top prediction's label.
top_label = next(iter(prediction_dic))
# Labels may look like "<taxon> (<common name>)" or be a bare name. Extract the
# parenthesized part when present; otherwise use the label as it is, instead
# of failing with an IndexError.
if "(" in top_label:
    prediction_name = top_label.split("(")[1].strip(")")
else:
    prediction_name = top_label
print(f"prediction name: {prediction_name}")


INFO: prediction with prob.: {'Animalia Chordata Aves Charadriiformes Scolopacidae Limosa fedoa (Marbled Godwit)': tensor(0.5502)}
prediction name: Marbled Godwit
