In [None]:
!pip install open_clip_torch
!pip install pyinaturalist
!pip install captum
!pip install pybioclip
!pip install scikit-image
!pip install matplotlib
!pip install scipy
!pip install torchcam



In [6]:
import os
import gc  # Import garbage collector
import pandas as pd
from PIL import Image
import torch
import open_clip
import numpy as np
from skimage.transform import resize
from pyinaturalist import get_observations, Observation
from bioclip import TreeOfLifeClassifier, Rank
import matplotlib.pyplot as plt

# Mount Google Drive (for saving outputs)
from google.colab import drive
drive.mount('/content/drive')

# Initialize the model and preprocess functions
model, preprocess_train, _ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')

# Define the class labels to index
class_labels_to_index = {}

# Fetch observations from iNaturalist
SPECIES = "Panthera leo"
IMAGES_DIR = "images"
OUTPUT_DIR = "/content/drive/My Drive/GradCam"
os.makedirs(IMAGES_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)  # Create output directory if it doesn't exist
MAX_IMAGES = 100

response = get_observations(taxon_name=SPECIES, per_page=MAX_IMAGES, has_photos=True)
observations = Observation.from_json_list(response['results'])

# Initialize the BioCLIP classifier
classifier = TreeOfLifeClassifier()

# Define Grad-CAM class with hook removal
class GradCAM:
    def __init__(self, model):
        self.model = model
        self.model.eval()
        self.gradients = None
        self.activations = None
        self.hooks = []

    def save_gradients(self, grad):
        self.gradients = grad.detach()  # Detach gradients to save memory

    def forward(self, x):
        # Remove previous hooks to prevent accumulation
        self._remove_hooks()

        # Register hook to save activations
        def hook_fn(module, input, output):
            self.activations = output.detach()
            self.gradients = None  # Clear previous gradients

        # Register forward hooks for Conv2d layers only once
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                hook = module.register_forward_hook(hook_fn)
                self.hooks.append(hook)

        x = self.model(x)
        return x

    def generate_cam(self, class_idx):
        # Compute weights and CAM
        grads = self.gradients
        activations = self.activations
        weights = torch.mean(grads, dim=[0, 2, 3])
        cam = torch.zeros(activations.shape[2:], dtype=torch.float32)
        for i, w in enumerate(weights):
            cam += w * activations[0, i, :, :]
        cam = torch.clamp(cam, min=0)
        cam = cam / torch.max(cam)
        return cam

    def _remove_hooks(self):
        for hook in self.hooks:
            hook.remove()  # Remove each hook to avoid memory buildup
        self.hooks.clear()

grad_cam = GradCAM(model.visual)

# Process each observation individually
for i, observation in enumerate(observations):
    if observation.photos:
        # Save image and metadata
        path = os.path.join(IMAGES_DIR, f"{observation.id}.jpg")
        with observation.photos[0].open(size='medium') as infile:
            with open(path, 'wb') as outfile:
                outfile.write(infile.read())

        taxon = observation.taxon.name if observation.taxon else 'Unknown'

        # Load and preprocess the image to tensor
        img = Image.open(path).convert('RGB')
        tensor = preprocess_train(img).unsqueeze(0)  # Add batch dimension

        # Make prediction
        try:
            prediction = classifier.predict(path, Rank.SPECIES, k=1)[0]
            species, family = prediction['species'], prediction['family']
        except Exception as e:
            print(f"Prediction failed for {path}: {e}")
            continue

        # Automatically add new species to the index if not present
        if species not in class_labels_to_index:
            class_labels_to_index[species] = len(class_labels_to_index)

        class_idx = class_labels_to_index[species]

        # Generate Grad-CAM
        try:
            tensor.requires_grad_()
            output = grad_cam.forward(tensor)
            output[0, class_idx].backward()
            grad_cam.save_gradients(tensor.grad)
            cam = grad_cam.generate_cam(class_idx)

            # Resize CAM to match the original image size
            original_size = img.size
            cam_resized = resize(cam.detach().cpu().numpy(), original_size[::-1], mode='reflect', anti_aliasing=True)

            # Convert CAM to color map
            cam_colormap = plt.cm.plasma(cam_resized)
            cam_colormap = (cam_colormap[:, :, :3] * 255).astype(np.uint8)
            cam_colormap_img = Image.fromarray(cam_colormap).resize(original_size, Image.BILINEAR)

            # Overlay Grad-CAM on the original image
            blended = Image.blend(img.convert('RGB'), cam_colormap_img, alpha=0.5)

            # Save plots directly to PDF in Google Drive
            pdf_path = os.path.join(OUTPUT_DIR, f"{observation.id}_gradcam.pdf")
            with plt.ioff():
                plt.figure(figsize=(12, 12))

                plt.subplot(2, 2, 1)
                plt.imshow(img)
                plt.title('Image')

                plt.subplot(2, 2, 2)
                plt.imshow(cam_resized, cmap='plasma', interpolation='bilinear')
                plt.colorbar()
                plt.title('Grad-CAM (Plasma Colormap)')

                plt.subplot(2, 2, 3)
                plt.imshow(blended)
                plt.title('Image + Grad-CAM Overlay')

                plt.subplot(2, 2, 4)
                plt.axis('off')  # Hide axes
                text = f"iNaturalist Taxon: {taxon}\nBioCLIP Species: {species}\nBioCLIP Family: {family}"
                plt.text(0.5, 0.5, text, ha='center', va='center', fontsize=12, wrap=True)
                plt.title("Species Information")

                plt.savefig(pdf_path, format='pdf')
                plt.close('all')

            # Output results
            print(f"Processed image {i+1}/{MAX_IMAGES}: {path}")
            print("iNaturalist taxon:", taxon)
            print("BioCLIP species:", species)
            print("BioCLIP family:", family)
            print("\n")

            # Clear tensors and free up memory
            del tensor, output, cam, cam_resized, cam_colormap, cam_colormap_img, blended
            torch.cuda.empty_cache()  # Use this if running on GPU
            gc.collect()  # Force garbage collection to clear unused memory


        except Exception as e:
          print(f"Error processing image {path}: {e}")
          continue


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Processed image 1/100: images/249043966.jpg
iNaturalist taxon: Panthera leo
BioCLIP species: Panthera leo
BioCLIP family: Felidae


Processed image 2/100: images/249042871.jpg
iNaturalist taxon: Panthera leo leo
BioCLIP species: Cryptoprocta ferox
BioCLIP family: Eupleridae


Processed image 4/100: images/249028299.jpg
iNaturalist taxon: Panthera leo leo
BioCLIP species: Panthera leo
BioCLIP family: Felidae


Processed image 5/100: images/249023394.jpg
iNaturalist taxon: Panthera
BioCLIP species: Panthera pardus
BioCLIP family: Felidae


Processed image 6/100: images/249018218.jpg
iNaturalist taxon: Panthera leo melanochaita
BioCLIP species: Panthera leo
BioCLIP family: Felidae


Processed image 7/100: images/248965483.jpg
iNaturalist taxon: Panthera leo
BioCLIP species: Panthera leo
BioCLIP family: Felidae


Processed image 8/100: images/248952882.jpg
iNatur

KeyboardInterrupt: 