In [18]:
# Import statements
import pandas as pd
import numpy as np
from pathlib import Path
from collections import defaultdict
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset

In [19]:
# Paths to image and metadata
# NOTE: Add your file paths for images AND csv below
img_dir = "[YOUR IMAGE DIRECTORY PATH GOES HERE]"
csv_path = "[YOUR METADATA CSV PATH GOES HERE]"

# Load metadata and sanity-check that the columns used below exist
df_meta = pd.read_csv(csv_path)
assert 'hash_id' in df_meta.columns, "CSV missing 'hash_id' column"
assert 'family' in df_meta.columns, "CSV missing 'family' column"

# Inspect metadata
print("Metadata preview:")
print(df_meta.head())

# Map images to csv filepaths
image_dir = Path(img_dir)
allowed_exts = {'.jpg', '.jpeg', '.png'}

# Index all image files once, recursively (all directories and subdirectories).
# rglob() yields entries in filesystem order, which differs between machines/OSes; sorting makes
# the "first match" chosen below reproducible. is_file() skips directories whose names merely
# end in an image extension.
all_paths = sorted(
    p for p in image_dir.rglob("*")
    if p.suffix.lower() in allowed_exts and p.is_file()
)
print(f"Indexed {len(all_paths):,} image files")

# Group file paths (as strings) by stem, i.e. the file name without its extension.
# defaultdict(list) starts an empty list the first time a stem is seen.
stem_to_paths = defaultdict(list)
for p in all_paths:
    stem_to_paths[p.stem].append(str(p))

# Keep exactly one path per stem: the first one in sorted order.
# Report ambiguous stems so the dropped duplicates are visible rather than silent.
first_match = {stem: paths[0] for stem, paths in stem_to_paths.items()}
n_dup_stems = sum(len(paths) > 1 for paths in stem_to_paths.values())
if n_dup_stems:
    print(f"Warning: {n_dup_stems:,} stems map to multiple image files; kept the first (sorted) path for each")

# Normalize hash_id to a whitespace-stripped string so it can be matched against file stems
df_meta['hash_id_str'] = df_meta['hash_id'].astype(str).str.strip()
# Image path for each hash_id (NaN where no file with that stem was found)
df_meta['path'] = df_meta['hash_id_str'].map(first_match)
print("Updated df_meta:")
print(df_meta.head())

# Split metadata into rows successfully linked to an image path and hash_ids with no matching image
df_meta_found = df_meta[df_meta['path'].notna()].reset_index(drop=True)
hashid_to_path = dict(zip(df_meta_found['hash_id_str'], df_meta_found['path']))
missing_hashids = df_meta.loc[df_meta['path'].isna(), 'hash_id_str'].tolist()

# Repeated hash_ids in the CSV collapse to a single key in hashid_to_path, so the "found" count
# below can be smaller than the number of matched CSV rows.
n_dup_ids = int(df_meta_found['hash_id_str'].duplicated().sum())
if n_dup_ids:
    print(f"Warning: {n_dup_ids:,} duplicate hash_id rows collapse to one entry each in hashid_to_path")

# Print numbers of found and missing
print(f"\nFound {len(hashid_to_path):,} matches; {len(missing_hashids):,} missing.")
if missing_hashids:
    print("Example missing ids:", missing_hashids[:5])

Metadata preview:
                    hash_id       family
0  223m6ywujk3htx2s3kfqx7ee  Acanthaceae
1  2aba7w224g4tso44mtzpnizg  Acanthaceae
2  2dovrj4uex7apou4zyu7nau7  Acanthaceae
3  2f53p6wsfhsnik2sy3jxn2ok  Acanthaceae
4  2fvqsa7ldatavhuevcvia5lm  Acanthaceae
Indexed 49,633 image files
Updated df_meta:
                    hash_id       family               hash_id_str  \
0  223m6ywujk3htx2s3kfqx7ee  Acanthaceae  223m6ywujk3htx2s3kfqx7ee   
1  2aba7w224g4tso44mtzpnizg  Acanthaceae  2aba7w224g4tso44mtzpnizg   
2  2dovrj4uex7apou4zyu7nau7  Acanthaceae  2dovrj4uex7apou4zyu7nau7   
3  2f53p6wsfhsnik2sy3jxn2ok  Acanthaceae  2f53p6wsfhsnik2sy3jxn2ok   
4  2fvqsa7ldatavhuevcvia5lm  Acanthaceae  2fvqsa7ldatavhuevcvia5lm   

                                                path  
0  C:\Users\mlasz\OneDrive\Desktop\mlm_25\biotrov...  
1  C:\Users\mlasz\OneDrive\Desktop\mlm_25\biotrov...  
2  C:\Users\mlasz\OneDrive\Desktop\mlm_25\biotrov...  
3  C:\Users\mlasz\OneDrive\Desktop\mlm_25\biotrov..

In [23]:
# Dataset and dataloader

# Preprocessing for the ResNet backbone, whose pretrained weights come from ImageNet:
# scale the shorter image side to 256 px, take the central 224x224 patch, convert to a
# float tensor, then standardize each RGB channel with the ImageNet statistics.
img_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean (R, G, B)
            [0.229, 0.224, 0.225],  # per-channel std  (R, G, B)
        ),
    ]
)

# Dataset
class ClusteringBioTroveDataset(Dataset):
    """Map-style dataset yielding ``(hash_id, image)`` pairs in a fixed order.

    Args:
        mapping: dict-like of hash_id -> image file path (e.g. ``hashid_to_path``). Its items
            are copied into a list at construction, so index ``i`` is the ``i``-th item in the
            mapping's iteration order.
        transform: optional callable applied to each loaded image (e.g. ``img_transform``).
        loader: optional callable ``path -> image``. Defaults to torchvision's ``pil_loader``,
            which opens the file with PIL and converts it to RGB.
    """

    def __init__(self, mapping, transform=None, loader=None):
        self.items = list(mapping.items())
        self.transform = transform
        if loader is None:
            # Load through torchvision's pil_loader (open file, PIL open, convert('RGB'), close
            # handle) so this class does not depend on a separate PIL ``Image`` import; imported
            # lazily so a custom loader works without it.
            from torchvision.datasets.folder import pil_loader
            loader = pil_loader
        self.loader = loader

    def __len__(self):
        """Number of (hash_id, path) items captured at construction."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(hash_id, image)`` for position ``idx``, applying ``transform`` if one was given."""
        hid, path = self.items[idx]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return hid, img

# Wrap the matched hash_id -> path mapping and batch it; shuffle=False keeps a deterministic
# order that follows hashid_to_path's iteration order.
dataset = ClusteringBioTroveDataset(hashid_to_path, img_transform)
# Adjust batch_size and num_workers as needed.
# NOTE(review): num_workers > 0 may fail where worker processes are spawned (e.g. Windows/macOS),
# because the Dataset class is defined inside the notebook; move it to a .py module first.
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine it just emits a warning.
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)