In [None]:
import os, json
import random
import tqdm
import open_clip
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
import torchvision.transforms as T
from collections import OrderedDict

In [2]:
proj_path = os.path.dirname(os.getcwd())
data_root = os.path.join(proj_path, 'rare_animals_dataset')
img_dir = os.path.join(data_root, 'images')
description_dir = os.path.join(data_root, 'definitions.json')
saved_model_dir = os.path.join(proj_path, "weights", "fine-tuned", "clip_finetuned.pt")
clip_pretrained = os.path.join(proj_path, "weights", "pretrained", "biotroveclip-vit-b-16-from-openai-epoch-40.pt")

BATCH_SIZE = 2
EPOCHS = 10
LR = 2e-5
CLIP_MODEL_NAME = "ViT-B-16"

In [None]:
class CustomAnimalDataset(Dataset):
    def __init__(self, image_dir, description_file, preprocess_fn, split='train'):
        self.image_dir = image_dir
        self.preprocess_fn = preprocess_fn
        self.data_pairs = []

        with open(description_file, 'r', encoding='utf-8') as f:
            self.animal_descriptions = json.load(f)

        self.all_species_names = sorted(list(self.animal_descriptions.keys()))
        self.species_name_to_idx = {name: i for i, name in enumerate(self.all_species_names)}

        all_labels_in_folders = [d for d in os.listdir(image_dir) if os.path.isdir(os.path.join(image_dir, d))]

        for label_folder_name in all_labels_in_folders:
            label_folder_path = os.path.join(image_dir, label_folder_name)
            if label_folder_name not in self.animal_descriptions:
                continue

            english_name = label_folder_name
            details = self.animal_descriptions[english_name]
            scientific_name = details.get('scientific', '')
            description_text = details.get('description', '')
            species_prompts = []
            species_prompts.append(f"a photo of a {english_name}")
            if scientific_name:
                species_prompts.append(f"a photo of a {scientific_name}")
            if description_text:
                species_prompts.append(f"a photo of a {english_name}, {description_text}")
            species_prompts = list(set(species_prompts))
            #print (species_prompts)

            for img_name in os.listdir(label_folder_path):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    img_path = os.path.join(label_folder_path, img_name)
                    try:
                        Image.open(img_path).convert("RGB").close()
                    except Exception as e:
                        print(f"ERROR: Could not open image {img_path}: {e}. Skipping.")
                        continue
                    self.data_pairs.append((img_path, species_prompts, english_name))

        if split == 'train':
            transform_set = T.Compose(
            
            )
            self.current_data, _ = train_test_split(self.data_pairs, test_size=0.2, random_state=42)
        elif split == 'val':
            _, self.current_data = train_test_split(self.data_pairs, test_size=0.2, random_state=42)
        else:
            raise ValueError("Split must be 'train' or 'val'")

    def __len__(self):
        return len(self.current_data)

    def __getitem__(self, idx):
        img_path, text_prompts_list, english_name = self.current_data[idx]

        image = Image.open(img_path).convert("RGB")
        image = self.preprocess_fn(image)

        chosen_text_prompt = random.choice(text_prompts_list)
        text_tokens = open_clip.tokenize(chosen_text_prompt)

        ground_truth_global_idx = self.species_name_to_idx[english_name]
        return image, text_tokens, ground_truth_global_idx

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model, _, preprocess = open_clip.create_model_and_transforms(
    model_name=CLIP_MODEL_NAME,
    pretrained=None,
    device=device
)

full_checkpoint = torch.load(clip_pretrained, map_location=device, weights_only=False)
pretrained_state_dict = full_checkpoint['state_dict']

new_state_dict = OrderedDict()
for k, v in pretrained_state_dict.items():
    if k.startswith('module.'):
        new_state_dict[k[7:]] = v
    else:
        new_state_dict[k] = v
model.load_state_dict(new_state_dict)
print("BioTrove-CLIP model loaded.")

Using device: cpu
BioTrove-CLIP model loaded.


In [5]:
temp_dataset = CustomAnimalDataset(img_dir, description_dir, preprocess, split='train')
all_species_names = temp_dataset.all_species_names
species_name_to_idx = temp_dataset.species_name_to_idx

all_eval_text_prompts = [f"a photo of a {name}" for name in all_species_names]
#print(f"Encoding {len(all_eval_text_prompts)} prompts for evaluation...")
with torch.no_grad():
    all_text_tokens_for_eval = open_clip.tokenize(all_eval_text_prompts).to(device)
    all_eval_text_features = model.encode_text(all_text_tokens_for_eval)
    all_eval_text_features = F.normalize(all_eval_text_features, p=2, dim=-1)
#print("Finished encoding all prompts for evaluation.")

train_dataset = CustomAnimalDataset(img_dir, description_dir, preprocess, split='train')
val_dataset = CustomAnimalDataset(img_dir, description_dir, preprocess, split='val')

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

In [None]:
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    train_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} (Train)")

    for batch_idx, (images, texts_tokens, _) in enumerate(train_bar):
        images = images.to(device)
        texts_tokens = texts_tokens.squeeze(1).to(device)

        optimizer.zero_grad()

        image_features = model.encode_image(images)
        text_features = model.encode_text(texts_tokens)

        image_features = F.normalize(image_features, p=2, dim=-1)
        text_features = F.normalize(text_features, p=2, dim=-1)

        logit_scale = model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logits_per_image.T

        labels = torch.arange(len(images), device=device) # Label for contrastive loss
        
        loss_i = F.cross_entropy(logits_per_image, labels)
        loss_t = F.cross_entropy(logits_per_text, labels)
        loss = (loss_i + loss_t) / 2

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        train_bar.set_postfix(loss=total_loss / (batch_idx + 1))

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} Training Loss: {avg_train_loss:.4f}")

    model.eval()
    val_loss = 0
    val_correct_top1 = 0
    val_correct_top3 = 0
    total_samples = 0
    val_bar = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} (Validation)")

    with torch.no_grad():
        for images, texts_tokens_for_loss, ground_truth_indices in val_bar:
            images = images.to(device)
            texts_tokens_for_loss = texts_tokens_for_loss.squeeze(1).to(device)
            ground_truth_indices = ground_truth_indices.to(device)

            image_features = model.encode_image(images)
            image_features = F.normalize(image_features, p=2, dim=-1)

            val_text_features = model.encode_text(texts_tokens_for_loss)
            val_text_features = F.normalize(val_text_features, p=2, dim=-1)
            val_logits = model.logit_scale.exp() * image_features @ val_text_features.T
            
            labels_val = torch.arange(len(images), device=device)
            loss_i = F.cross_entropy(val_logits, labels_val)
            loss_t = F.cross_entropy(val_logits.T, labels_val)
            loss = (loss_i + loss_t) / 2
            val_loss += loss.item()

            logits_per_image_all_classes = (100.0 * image_features @ all_eval_text_features.T)
            
            num_total_species = all_eval_text_features.size(0)
            k_val = min(3, num_total_species)

            if k_val == 0:
                continue

            _, top_preds = logits_per_image_all_classes.topk(k_val, dim=1)
            
            val_correct_top1 += (top_preds[:, 0] == ground_truth_indices).sum().item()
            val_correct_top3 += (top_preds == ground_truth_indices.unsqueeze(1)).any(dim=1).sum().item()

            total_samples += images.size(0)

    avg_val_loss = val_loss / len(val_loader)
    val_acc_top1 = val_correct_top1 / total_samples if total_samples > 0 else 0
    val_acc_top3 = val_correct_top3 / total_samples if total_samples > 0 else 0

    print(f"Epoch {epoch+1} Val loss: {avg_val_loss:.4f}")
    print(f"Epoch {epoch+1} Validation Top-1 acc: {val_acc_top1:.4f}")
    print(f"Epoch {epoch+1} Validation Top-3 acc: {val_acc_top3:.4f}")

Epoch 1/10 (Train): 100%|██████████| 60/60 [03:59<00:00,  3.99s/it, loss=1.26] 


Epoch 1 Training Loss: 1.2560


Epoch 1/10 (Validation): 100%|██████████| 16/16 [00:12<00:00,  1.23it/s]


Epoch 1 Val loss: 1.3216
Epoch 1 Validation Top-1 acc: 0.2258
Epoch 1 Validation Top-3 acc: 0.5161


Epoch 2/10 (Train): 100%|██████████| 60/60 [03:29<00:00,  3.50s/it, loss=0.685]  


Epoch 2 Training Loss: 0.6852


Epoch 2/10 (Validation): 100%|██████████| 16/16 [00:15<00:00,  1.03it/s]


Epoch 2 Val loss: 0.7887
Epoch 2 Validation Top-1 acc: 0.2581
Epoch 2 Validation Top-3 acc: 0.4839


Epoch 3/10 (Train): 100%|██████████| 60/60 [15:43<00:00, 15.73s/it, loss=0.573]   


Epoch 3 Training Loss: 0.5727


Epoch 3/10 (Validation): 100%|██████████| 16/16 [00:25<00:00,  1.59s/it]


Epoch 3 Val loss: 0.8825
Epoch 3 Validation Top-1 acc: 0.2903
Epoch 3 Validation Top-3 acc: 0.4516


Epoch 4/10 (Train): 100%|██████████| 60/60 [04:31<00:00,  4.52s/it, loss=0.454]


Epoch 4 Training Loss: 0.4537


Epoch 4/10 (Validation): 100%|██████████| 16/16 [00:18<00:00,  1.14s/it]


Epoch 4 Val loss: 0.6027
Epoch 4 Validation Top-1 acc: 0.1935
Epoch 4 Validation Top-3 acc: 0.5806


Epoch 5/10 (Train): 100%|██████████| 60/60 [04:52<00:00,  4.87s/it, loss=0.257] 


Epoch 5 Training Loss: 0.2571


Epoch 5/10 (Validation): 100%|██████████| 16/16 [00:20<00:00,  1.26s/it]


Epoch 5 Val loss: 0.4382
Epoch 5 Validation Top-1 acc: 0.2258
Epoch 5 Validation Top-3 acc: 0.4194


Epoch 6/10 (Train): 100%|██████████| 60/60 [04:11<00:00,  4.19s/it, loss=0.262]  


Epoch 6 Training Loss: 0.2625


Epoch 6/10 (Validation): 100%|██████████| 16/16 [00:14<00:00,  1.10it/s]


Epoch 6 Val loss: 0.6155
Epoch 6 Validation Top-1 acc: 0.1613
Epoch 6 Validation Top-3 acc: 0.5484


Epoch 7/10 (Train): 100%|██████████| 60/60 [04:25<00:00,  4.43s/it, loss=0.594]  


Epoch 7 Training Loss: 0.5941


Epoch 7/10 (Validation): 100%|██████████| 16/16 [00:16<00:00,  1.03s/it]


Epoch 7 Val loss: 0.3642
Epoch 7 Validation Top-1 acc: 0.2903
Epoch 7 Validation Top-3 acc: 0.5161


Epoch 8/10 (Train):  65%|██████▌   | 39/60 [02:39<01:30,  4.33s/it, loss=0.117] 

In [None]:
# torch.save(model.state_dict(), SAVE_MODEL_PATH)
# print(f"Fine-tuned model saved to {SAVE_MODEL_PATH}")

Fine-tuned model saved to biotrove_clip_finetuned.pt
