# BioClip Training

In [1]:
# (Optional) Install required packages if not present.
# Uncomment and run the next line in the notebook if you need to install dependencies.
# !pip install open_clip_torch pillow pandas requests

In [2]:
# Imports and model setup
import open_clip
import torch
from PIL import Image
import pandas as pd
import requests
from io import BytesIO
import time
import unicodedata
from typing import Optional, Dict, List

# Load BioCLIP model and transforms (imageomics/bioclip)
# This may download weights the first time you run it.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
model.eval()


CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine

In [3]:
# Load and clean plant names from CSV
plant_df = pd.read_csv('data/Plants_Formatted.csv', encoding='latin-1')

import re

def clean_scientific_name(s: str) -> str:
    """Normalize scientific names and remove cultivar/variety details.

    Follows the same intent as the project's `wikipediaApi.ts`:
    - Removes cultivar text in single quotes
    - Removes hybrid markers like " x "
    - Removes variety/subspecies markers (var., subsp., sub., forma)
    - If string contains ' sp.' or 'unknown', reduce to genus only
    - Normalize unicode and collapse whitespace
    """
    if not isinstance(s, str):
        s = str(s)
    s = s.strip()
    if not s:
        return s

    # Normalize unicode (remove accents)
    s = unicodedata.normalize('NFKD', s)
    s = ''.join(ch for ch in s if not unicodedata.combining(ch))

    # Remove weird replacement characters
    s = s.replace('\u2019', "'").replace('\u201c', '"').replace('\u201d', '"')

    # Remove any trailing or internal newlines
    s = s.replace('\n', ' ').replace('\r', ' ').strip()

    # Remove cultivar names in single quotes: "Genus species 'Cultivar'" -> "Genus species"
    if "'" in s:
        # remove the quote and everything after the first quote occurrence
        s = s.split("'")[0].strip()

    # Remove hybrid markers like ' x '
    if ' x ' in s:
        s = s.split(' x ')[0].strip()

    # Remove variety/subspecies markers
    markers = [' var. ', ' subsp. ', ' sub. ', ' forma ', ' f. ']
    for marker in markers:
        if marker in s:
            s = s.split(marker)[0].strip()

    # Handle 'sp.' or 'sp ' or 'unknown' -> return genus only
    if ' sp.' in s or re.search(r'\bsp\b', s, flags=re.IGNORECASE) or 'unknown' in s.lower():
        parts = s.split()
        s = parts[0] if parts else s

    # Collapse multiple spaces
    s = ' '.join(s.split())

    return s

# Some rows may have empty Scientific Name cells; drop those
plant_df['Scientific Name'] = plant_df['Scientific Name'].fillna('').astype(str)
all_names = [clean_scientific_name(x) for x in plant_df['Scientific Name'].tolist() if x.strip()!='']
# Deduplicate while preserving order
seen = set()
plant_names = []
# Keep a mapping to original names too (for evaluation display)
original_map = {}
for orig_raw in plant_df['Scientific Name'].tolist():
    orig = str(orig_raw) if not pd.isna(orig_raw) else ''
    cleaned = clean_scientific_name(orig)
    if cleaned and cleaned not in seen:
        seen.add(cleaned)
        plant_names.append(cleaned)
        original_map[cleaned] = orig

print(f'Unique species to evaluate: {len(plant_names)}')
plant_names[:10]


Unique species to evaluate: 349


['Adiantum peruvianum',
 'Adiantum raddianum',
 'Adiantum ternerum',
 'Adiantum trapeziforme',
 'Aechmea',
 'Aechmea blanchetiana',
 'Aechmea chantinii',
 'Aechmea fasciata',
 'Aechmea fulgens',
 'Aechmea gamosepala']

In [4]:
# Load and preprocess your image
image = Image.open('data/example_images/Adiantum-peruvianum-Silver-Dollar-Fern-Amazon-Spheres.jpg.webp')
image = preprocess_val(image).unsqueeze(0).to(device)
image

tensor([[[[-1.6171, -1.6755, -1.6755,  ..., -1.2959, -1.2375, -1.1499],
          [-1.6901, -1.7485, -1.7193,  ..., -1.0039, -1.2375, -1.1353],
          [-1.7339, -1.7631, -1.7047,  ..., -0.3032, -0.4492, -1.1353],
          ...,
          [-1.3105, -1.3105, -1.2521,  ..., -0.7266, -0.6098, -0.1280],
          [-1.3397, -1.3251, -1.3251,  ..., -0.6390, -0.4930, -0.2740],
          [-1.2667, -1.3251, -1.2083,  ..., -0.4930, -0.4054, -0.1134]],

         [[-1.4369, -1.4970, -1.4970,  ..., -1.1668, -1.1518, -1.0918],
          [-1.5120, -1.5870, -1.5720,  ..., -0.6415, -1.1068, -1.0918],
          [-1.5570, -1.6170, -1.5720,  ...,  0.0488, -0.2363, -1.0918],
          ...,
          [-1.2568, -1.2568, -1.1968,  ..., -0.5965, -0.4614,  0.0338],
          [-1.2718, -1.2718, -1.2718,  ..., -0.5065, -0.3564, -0.1313],
          [-1.1668, -1.2418, -1.1368,  ..., -0.3564, -0.2513,  0.0488]],

         [[-1.1532, -1.2243, -1.2243,  ..., -0.9399, -0.8688, -0.8119],
          [-1.2243, -1.3096, -

In [5]:
# Tokenize and encode all plant names (text features) once
with torch.no_grad():
    # Tokenizer accepts list of strings
    text_tokens = tokenizer(plant_names)
    text_tokens = text_tokens.to(device) if hasattr(text_tokens, 'to') else text_tokens
    text_features = model.encode_text(text_tokens)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Improved Wikipedia fetching: session with retries, batch MediaWiki queries, and local cache
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import os
import json
from urllib.parse import quote_plus

WIKI_REST_BASE = 'https://en.wikipedia.org/api/rest_v1/page/summary/'
WIKI_API_BASE = 'https://en.wikipedia.org/w/api.php'
HEADERS = {'User-Agent': 'BioCLIP-Eval/1.0 (Educational Project)', 'Accept': 'application/json'}

# Prepare a requests session with retries and backoff
session = requests.Session()
retry_strategy = Retry(
    total=5,
    backoff_factor=0.8,
    status_forcelist=[429, 500, 502, 503, 504],
    allowed_methods=frozenset(["HEAD", "GET", "OPTIONS"])
)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("https://", adapter)
session.mount("http://", adapter)
session.headers.update(HEADERS)

# Local cache for wiki metadata and downloaded thumbnails
cache_path = 'data/wiki_cache.json'
img_cache_dir = 'data/wiki_images'
os.makedirs(img_cache_dir, exist_ok=True)

if os.path.exists(cache_path):
    try:
        with open(cache_path, 'r', encoding='utf-8') as fh:
            wiki_cache = json.load(fh)
    except Exception:
        wiki_cache = {}
else:
    wiki_cache = {}


def save_cache():
    try:
        with open(cache_path, 'w', encoding='utf-8') as fh:
            json.dump(wiki_cache, fh, ensure_ascii=False, indent=2)
    except Exception:
        pass


def slugify_name(name: str) -> str:
    return ''.join(c if c.isalnum() or c in (' ', '-', '_') else '_' for c in name).replace(' ', '_')[:200]


def batch_query_pageimages(names: list, batch_size: int = 50) -> Dict[str, Dict]:
    """Query MediaWiki action=query in batches to get thumbnails for many titles at once.

    Returns a mapping from requested-name -> dict(title, thumbnail, url) for those that had thumbnails.
    """
    mapping = {}
    for i in range(0, len(names), batch_size):
        batch = names[i:i+batch_size]
        titles = '|'.join(batch)
        params = {
            'action': 'query',
            'titles': titles,
            'prop': 'pageimages',
            'pithumbsize': 1000,
            'redirects': 1,
            'format': 'json'
        }
        try:
            resp = session.get(WIKI_API_BASE, params=params, timeout=15)
            if resp.status_code != 200:
                time.sleep(0.1)
                continue
            data = resp.json()
            query = data.get('query', {})
            pages = query.get('pages', {})
            # Build normalized/redirect maps
            norm_map = {n['from']: n['to'] for n in query.get('normalized', [])} if query.get('normalized') else {}
            redir_map = {r['from']: r['to'] for r in query.get('redirects', [])} if query.get('redirects') else {}

            # title -> page mapping
            title_to_page = {page.get('title', ''): page for page in pages.values()}

            for requested in batch:
                mapped = requested
                if requested in norm_map:
                    mapped = norm_map[requested]
                if requested in redir_map:
                    mapped = redir_map[requested]

                page = title_to_page.get(mapped)
                if page:
                    thumb = page.get('thumbnail', {}).get('source', '')
                    if thumb:
                        mapping[requested] = {
                            'title': page.get('title', ''),
                            'thumbnail': thumb,
                            'url': f"https://en.wikipedia.org/?curid={page.get('pageid')}"
                        }
        except Exception:
            # on any error, wait and continue
            time.sleep(0.2)
            continue
        # brief pause between batches
        time.sleep(0.1)
    return mapping


def get_wikipedia_summary_with_thumbnail(scientific_name: str, timeout: int = 10) -> Optional[Dict]:
    """Fallback single-name fetch using REST summary endpoint (used only when batch lookup didn't find a thumbnail).
    Returns dict with title, extract, url, thumbnail or None.
    """
    try:
        encoded = requests.utils.quote(scientific_name)
        resp = session.get(f'{WIKI_REST_BASE}{encoded}', timeout=timeout)
        if resp.status_code == 200:
            data = resp.json()
            thumbnail = data.get('thumbnail', {}).get('source', '') if isinstance(data.get('thumbnail'), dict) else ''
            return {
                'title': data.get('title', ''),
                'extract': data.get('extract', ''),
                'url': data.get('content_urls', {}).get('desktop', {}).get('page', ''),
                'thumbnail': thumbnail
            }
        else:
            return None
    except Exception:
        return None


def download_image_bytes(url: str, timeout: int = 10, attempts: int = 3) -> Optional[bytes]:
    """Download image bytes with retries using the session. Skips SVG/XML content types.
    """
    for attempt in range(attempts):
        try:
            r = session.get(url, timeout=timeout, stream=True)
            if r.status_code == 200:
                content_type = r.headers.get('Content-Type', '').lower()
                if 'svg' in content_type or 'xml' in content_type:
                    return None
                data = r.content
                if data:
                    return data
            else:
                # pause and retry
                time.sleep(0.5 * (attempt + 1))
        except Exception:
            time.sleep(0.5 * (attempt + 1))
            continue
    return None


## Test Plant Classifier Class

In [6]:
# Evaluate across all plant names using cached thumbnails and batch MediaWiki lookup
results = []
start = time.time()

# 1) Use batch MediaWiki query to get thumbnails for many species at once
print('Running batch MediaWiki pageimages lookup...')
batch_map = batch_query_pageimages(plant_names, batch_size=50)
print(f'Batch lookup returned thumbnails for {len(batch_map)} names')

for i, pname in enumerate(plant_names):
    print(f'[{i+1}/{len(plant_names)}] {pname}', end='')

    # If cached and already has thumbnail stored, reuse
    cache_entry = wiki_cache.get(pname)
    if cache_entry and cache_entry.get('thumbnail'):
        thumb = cache_entry['thumbnail']
        fetch_method = cache_entry.get('fetch_method', 'cache')
    else:
        # Try batch_map first
        meta = batch_map.get(pname)
        fetch_method = None
        if meta and meta.get('thumbnail'):
            thumb = meta['thumbnail']
            fetch_method = 'batch'
        else:
            # Try genus (if different)
            genus = pname.split()[0] if pname.split() else pname
            meta_genus = batch_map.get(genus)
            if meta_genus and meta_genus.get('thumbnail'):
                thumb = meta_genus['thumbnail']
                fetch_method = 'batch_genus'
            else:
                # Finally, try REST summary per-name as fallback
                wiki = get_wikipedia_summary_with_thumbnail(pname)
                if wiki and wiki.get('thumbnail'):
                    thumb = wiki['thumbnail']
                    fetch_method = 'rest'
                else:
                    thumb = ''

        # Save to cache placeholder (we'll enrich later if we download image)
        wiki_cache[pname] = {'thumbnail': thumb, 'fetch_method': fetch_method}
        save_cache()

    if not thumb:
        print(' - no wiki image')
        results.append({
            'plant_name': pname,
            'found_image': False,
            'thumbnail': '',
            'top_1': '',
            'top_1_conf': 0.0,
            'top_5': [],
            'fetch_method': fetch_method
        })
        # polite sleep
        time.sleep(0.12)
        continue

    # If we've already downloaded this thumbnail image locally, reuse it
    slug = slugify_name(pname)
    img_path = os.path.join(img_cache_dir, f"{slug}.jpg")
    if os.path.exists(img_path):
        with open(img_path, 'rb') as f:
            img_bytes = f.read()
    else:
        img_bytes = download_image_bytes(thumb)
        if img_bytes:
            try:
                with open(img_path, 'wb') as f:
                    f.write(img_bytes)
            except Exception:
                pass

    if not img_bytes:
        print(' - failed download')
        # update cache to mark failure and continue
        wiki_cache[pname] = {'thumbnail': thumb, 'fetch_method': fetch_method, 'downloaded': False}
        save_cache()
        results.append({
            'plant_name': pname,
            'found_image': False,
            'thumbnail': thumb,
            'top_1': '',
            'top_1_conf': 0.0,
            'top_5': [],
            'fetch_method': fetch_method
        })
        time.sleep(0.12)
        continue

    # Preprocess and classify
    try:
        img = Image.open(BytesIO(img_bytes)).convert('RGB')
        image_tensor = preprocess_val(img).unsqueeze(0).to(device)
    except Exception:
        print(' - image open error')
        wiki_cache[pname] = {'thumbnail': thumb, 'fetch_method': fetch_method, 'downloaded': False}
        save_cache()
        results.append({
            'plant_name': pname,
            'found_image': False,
            'thumbnail': thumb,
            'top_1': '',
            'top_1_conf': 0.0,
            'top_5': [],
            'fetch_method': fetch_method
        })
        time.sleep(0.12)
        continue

    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logits = 100.0 * (image_features @ text_features.T)
        probs = logits.softmax(dim=-1).cpu().numpy()[0]

    top_idxs = probs.argsort()[::-1][:5]
    top5 = [(plant_names[idx], float(probs[idx])) for idx in top_idxs]
    top1_name, top1_conf = top5[0][0], top5[0][1]

    correct = (top1_name == pname)
    print(f' - image OK - top1={top1_name} ({top1_conf*100:.2f}%)' + ( ' ✅' if correct else '' ))

    # update cache
    wiki_cache[pname].update({'downloaded': True, 'local_path': img_path, 'title': wiki_cache.get(pname, {}).get('title', '')})
    save_cache()

    results.append({
        'plant_name': pname,
        'found_image': True,
        'thumbnail': thumb,
        'top_1': top1_name,
        'top_1_conf': float(top1_conf),
        'top_5': top5,
        'fetch_method': fetch_method
    })

    # polite sleep between processing items
    time.sleep(0.12)

# Save results to CSV and show summary
out_df = pd.DataFrame(results)
out_path = 'data/bioclip_wikipedia_eval.csv'
out_df.to_csv(out_path, index=False)

elapsed = time.time() - start
print(f"\nDone. Results saved to {out_path}. Elapsed: {elapsed:.1f}s")

# Basic summary
total = len(out_df)
with_images = out_df['found_image'].sum()
if with_images>0:
    correct_top1 = ((out_df['found_image']) & (out_df['top_1'] == out_df['plant_name'])).sum()
    print(f'Total species: {total}; with images: {with_images}; correct top-1 (on those): {correct_top1} ({correct_top1/with_images*100:.1f}%)')
else:
    print(f'Total species: {total}; with images: 0')


Running batch MediaWiki pageimages lookup...
Batch lookup returned thumbnails for 264 names
[1/349] Adiantum peruvianumBatch lookup returned thumbnails for 264 names
[1/349] Adiantum peruvianum - image OK - top1=Adiantum peruvianum (99.98%) ✅
[2/349] Adiantum raddianum - image OK - top1=Adiantum peruvianum (99.98%) ✅
[2/349] Adiantum raddianum - image OK - top1=Adiantum raddianum (73.26%) ✅
[3/349] Adiantum ternerum - no wiki image
 - image OK - top1=Adiantum raddianum (73.26%) ✅
[3/349] Adiantum ternerum - no wiki image
[4/349] Adiantum trapeziforme[4/349] Adiantum trapeziforme - image OK - top1=Adiantum peruvianum (55.79%)
[5/349] Aechmea - image OK - top1=Adiantum peruvianum (55.79%)
[5/349] Aechmea - image OK - top1=Aechmea fasciata (64.47%)
[6/349] Aechmea blanchetiana - image OK - top1=Aechmea fasciata (64.47%)
[6/349] Aechmea blanchetiana - image OK - top1=Vriesea (92.63%)
[7/349] Aechmea chantinii - image OK - top1=Vriesea (92.63%)
[7/349] Aechmea chantinii - image OK - top1=Vr

In [7]:
# Preview a few rows from the saved CSV
r = pd.read_csv('data/bioclip_wikipedia_eval.csv')
r.head()

Unnamed: 0,plant_name,found_image,thumbnail,top_1,top_1_conf,top_5,fetch_method
0,Adiantum peruvianum,True,https://upload.wikimedia.org/wikipedia/commons...,Adiantum peruvianum,0.999808,"[('Adiantum peruvianum', 0.9998077750205994), ...",batch
1,Adiantum raddianum,True,https://upload.wikimedia.org/wikipedia/commons...,Adiantum raddianum,0.732644,"[('Adiantum raddianum', 0.7326443195343018), (...",batch
2,Adiantum ternerum,False,,,0.0,[],
3,Adiantum trapeziforme,True,https://upload.wikimedia.org/wikipedia/commons...,Adiantum peruvianum,0.557934,"[('Adiantum peruvianum', 0.5579341650009155), ...",batch
4,Aechmea,True,https://upload.wikimedia.org/wikipedia/commons...,Aechmea fasciata,0.6447,"[('Aechmea fasciata', 0.6446998119354248), ('A...",batch
