In [1]:
# imports
import json
import math
import os
import random
import shutil
import time
from pathlib import Path

import pandas as pd
import pyarrow.parquet as pq
import webdataset as wds

# NOTE(review): the cells below also call load_dotenv, snapshot_download, tqdm,
# MetadataProcessor, GenShuffledChunks, GetImages and GenImgTxtPair, none of which is
# imported anywhere in this notebook. Their imports belong in this cell (presumably
# python-dotenv, huggingface_hub, tqdm and the BioTrove tooling package - confirm the
# exact module paths before adding them).


In [2]:
# Read the Hugging Face token from the environment (populated from .env).
# NOTE(review): load_dotenv is never imported in this notebook - the recorded run of this
# cell stopped with "NameError: name 'load_dotenv' is not defined". Presumably
# `from dotenv import load_dotenv` (python-dotenv) is missing from the imports cell - confirm.
load_dotenv()

hf_token = os.environ.get("HF_TOKEN")

# When the token is present, HF reads it from the environment on its own, so it is not
# passed around explicitly. A missing token is reported as an error; per the message it
# can be disregarded when SSH access is used instead.
if hf_token is None:
    raise ValueError("HF_TOKEN not found in environment variables. Either set it in your .env file or ignore if you are using SSH access.")

NameError: name 'load_dotenv' is not defined

In [None]:
# Directory layout used by every later cell. All paths are relative to DATA_ROOT.
DATA_ROOT = Path("data/biotrove_train")
META_DIR = DATA_ROOT / "raw_metadata"
# The download cell fetches "BioTrove-Train/*.parquet" into META_DIR, so the folder name
# must match that spelling exactly: "BioTrove-train" only resolves on case-insensitive
# filesystems. (The identifier keeps its original spelling because other cells use it.)
PARAQUETS_PATH = META_DIR / "BioTrove-Train"
OUT_DIR = DATA_ROOT / "processed_metadata"
FILTERED_OUT = DATA_ROOT / "filtered_reptilia"
INPUT_PARQUETS = FILTERED_OUT / "merged_cases"
IMG_DIR = DATA_ROOT / "images_reptilia"
SPLITS_CSV = DATA_ROOT / "reptilia_tar_splits.csv"
COUNTS = OUT_DIR / "combined_sample_counts_per_species.csv"

# Create the working directories up front; parents=True also creates DATA_ROOT and
# FILTERED_OUT on the way.
for directory in (META_DIR, OUT_DIR, INPUT_PARQUETS, IMG_DIR):
    directory.mkdir(parents=True, exist_ok=True)

In [None]:
# Fetch the BioTrove-Train metadata from the Hugging Face Hub into META_DIR.
# allow_patterns limits the request to the parquet files under BioTrove-Train/.
snapshot_download(
    repo_type="dataset",
    repo_id="BGLab/BioTrove-Train",
    allow_patterns=["BioTrove-Train/*.parquet"],
    local_dir=os.fspath(META_DIR),
)

In [None]:
# Turn the downloaded parquet metadata into CSVs under OUT_DIR, with the processor
# configured for the "Reptilia" category.
mp = MetadataProcessor(
    categories=["Reptilia"],
    source_folder=os.fspath(PARAQUETS_PATH),
    destination_folder=os.fspath(OUT_DIR),
)
mp.process_all_files()

In [None]:
# Generate shuffled chunks with filtering
# Overriding the class because its process method does not properly filter out the data that we want it to.
# The vast majority of this logic is copied from the original class; the only change is adding a filter to the
# pd.read_parquet call. Without this, it reads all data, including non-Reptilia data, and its capped_filtered_df
# only filters out rare cases, not rare cases and categories that are not in the species count data.
class GenShuffledChunksReptilia(GenShuffledChunks):
    """GenShuffledChunks variant that only loads rows of a single taxonomic class.

    ``process_files`` is overridden so that a ``class == target_class`` filter is passed
    to ``pd.read_parquet``; rows of other classes are never loaded.
    """

    # Value the parquet "class" column must equal for a row to be loaded.
    target_class = "Reptilia"

    def process_files(self):
        """
        Filter, cap, shuffle and re-chunk the parquet files in ``self.directory``.

        Files are visited in sorted name order so that the capping below, which depends
        on the order in which files are seen, gives the same result on every machine.

        For each ``*.parquet`` file:
          * only rows with ``class == target_class`` are read, and rows containing any
            null value are dropped;
          * rows of species whose count in the species-count CSV is below
            ``rare_threshold`` are written to ``rare_dir`` under the same file name;
          * for species whose CSV count is above ``cap_threshold`` a running row count is
            kept across files. In the first file where that running count exceeds
            ``cap_threshold`` the file's rows for the species are written to
            ``capped_dir`` as ``capped_<species>.parquet`` and the species is excluded
            from the shuffled output of that file and of every later file;
          * the remaining rows are shuffled with ``random_seed`` and written to
            ``cap_filtered_dir_train`` in parts of roughly ``part_size`` rows.

        The four output directories are emptied first. ``merge_shuffled_files`` is called
        once all files have been processed, and the elapsed time is printed.
        """
        start_time = time.time()

        final_counts = pd.read_csv(self.species_count_data)
        rare_case = set(final_counts[final_counts['count'] < self.rare_threshold]['species'])
        frequent_case = set(final_counts[final_counts['count'] > self.cap_threshold]['species'])

        # Recreate every output directory empty so parts from an earlier run are removed.
        for dir_path in [self.rare_dir, self.cap_filtered_dir_train, self.capped_dir, self.merged_dir]:
            if os.path.exists(dir_path):
                shutil.rmtree(dir_path)
            os.makedirs(dir_path, exist_ok=True)

        frequent_counts = {}
        # A set gives constant-time membership checks; it is only used for lookups.
        capped_cases = set()
        # os.listdir returns names in arbitrary order; sort for reproducible capping.
        files = sorted(f for f in os.listdir(self.directory) if f.endswith(".parquet"))

        for filename in tqdm(files, desc="Processing files"):
            filepath = os.path.join(self.directory, filename)
            df = pd.read_parquet(filepath, filters=[("class", "==", self.target_class)]).dropna()

            rare_df = df[df['species'].isin(rare_case)]
            capped_filtered_df = df[~df['species'].isin(rare_case)]

            rare_df.to_parquet(os.path.join(self.rare_dir, filename), index=False)

            frequent_df = capped_filtered_df[capped_filtered_df['species'].isin(frequent_case)]
            frequent_case_counts = frequent_df['species'].value_counts().to_dict()

            for case, count in frequent_case_counts.items():
                frequent_counts[case] = frequent_counts.get(case, 0) + count
                if frequent_counts[case] > self.cap_threshold and case not in capped_cases:
                    capped_cases.add(case)
                    cap_case_df = frequent_df[frequent_df['species'] == case]
                    cap_case_df.to_parquet(os.path.join(self.capped_dir, f'capped_{case}.parquet'), index=False)

            capped_df = capped_filtered_df[~capped_filtered_df['species'].isin(capped_cases)]
            df_shuffled = capped_df.sample(frac=1, random_state=self.random_seed).reset_index(drop=True)
            # At least one part is always written, even when no rows are left.
            num_parts = max(1, round(len(df_shuffled) / self.part_size))
            rows_per_part = len(df_shuffled) // num_parts

            df_parts = [df_shuffled.iloc[i * rows_per_part: (i + 1) * rows_per_part] for i in range(num_parts)]

            # Rows left over by the integer division are appended to the last part.
            if len(df_shuffled) % num_parts != 0:
                df_parts[-1] = pd.concat([df_parts[-1], df_shuffled.iloc[num_parts * rows_per_part:]], ignore_index=True)

            base_filename, _ = os.path.splitext(filename)
            for i, part in enumerate(df_parts):
                cap_filtered_filepath = os.path.join(self.cap_filtered_dir_train, f'{base_filename}_part{i+1}.parquet')
                part.to_parquet(cap_filtered_filepath, index=False)

        self.merge_shuffled_files()
        elapsed_time = time.time() - start_time
        print(f"Processing completed in {elapsed_time:.2f} seconds.")

In [None]:
# Run the filtered chunk generation.
# - species below the rare threshold (fewer than 1000 samples) are dropped;
# - species above the cap threshold (more than 50 samples) are capped at 50.
# With the Reptilia override defined above only about 10k rows remain; the stock process
# method generates chunks for all categories (over 100m rows and images to download).
# Processing time also drops from ~5 minutes to about 15 seconds.
gen = GenShuffledChunksReptilia(
    species_count_data=COUNTS,
    directory=PARAQUETS_PATH,
    random_seed=521,
    part_size=500,        # rows per shuffled part
    files_per_chunk=80,
    rare_threshold=1000,  # drop species with <1000 samples
    cap_threshold=50,     # cap species above this count
    merged_dir=os.fspath(FILTERED_OUT / "merged_cases"),
    capped_dir=os.fspath(FILTERED_OUT / "capped_cases"),
    cap_filtered_dir_train=os.fspath(FILTERED_OUT / "cap_filtered_train"),
    rare_dir=os.fspath(FILTERED_OUT / "rare_cases"),
)
gen.process_files()

In [None]:
# Sanity check only: report how many rows the filtered parquet files contain, using the
# row count stored in each file's parquet metadata.
# Path and pq come from the imports cell at the top of the notebook.
parquet_dir = INPUT_PARQUETS
files = sorted(Path(parquet_dir).glob("*.parquet"))
total_rows = sum(pq.ParquetFile(f).metadata.num_rows for f in files)

print(f"Parquet files: {len(files)}, total rows: {total_rows}")

In [None]:
# Took about 1 minute and 15 seconds to download 10k images
gi = GetImages(
    INPUT_PARQUETS,
    output_folder=str(IMG_DIR),
    concurrent_downloads=1000,
)
await gi.download_images()

In [None]:
# Create the image/text pairs for the downloaded images, with tar generation enabled.
textgen = GenImgTxtPair(
    INPUT_PARQUETS,
    generate_tar=True,
    img_folder=IMG_DIR,
)

textgen.create_image_text_pairs()

In [None]:
# Build a per-class train/val/test split over the images that were actually downloaded.
min_images_per_class = 20
split_ratios = {"train": 0.8, "val": 0.1, "test": 0.1}  # 80/10/10 split

train_p = split_ratios["train"]
val_p = split_ratios["val"]
test_p = split_ratios["test"]

# Check available images before splitting to ensure the data stays balanced.
# Image files are expected to be named <photo_id>.jpg (the stem is parsed as an int).
available_ids = {int(p.stem) for p in IMG_DIR.rglob("*.jpg")}
print(f"Images found: {len(available_ids)}")

# Sorted so that the row order, and with it the seeded shuffle below, does not depend on
# the order in which the filesystem happens to list the files.
dfs = []
for p in sorted(Path(INPUT_PARQUETS).glob("*.parquet")):
    df = pd.read_parquet(p, columns=["photo_id", "scientificName", "photo_url"])
    dfs.append(df)

meta = pd.concat(dfs, ignore_index=True).dropna(subset=["photo_id", "scientificName", "photo_url"])

meta["photo_id"] = meta["photo_id"].astype(int)
# keep=False removes every row whose photo_id occurs more than once (no copy is kept).
# NOTE(review): presumably deliberate, to discard ambiguous photos - confirm that
# keep="first" was not intended.
meta = meta[meta["photo_id"].isin(available_ids)].drop_duplicates(subset=["photo_id"], keep=False)

# Single quotes inside the f-string: reusing the outer quote type is a SyntaxError
# before Python 3.12.
print(f"Rows with images: {len(meta['photo_id'].unique())}")

class_counts = meta["scientificName"].value_counts()
keep_classes = set(class_counts[class_counts >= min_images_per_class].index)
print(f"Classes kept (>= {min_images_per_class} images): {len(keep_classes)}")
print(f"Classes dropped: {len(class_counts) - len(keep_classes)}")

meta = meta[meta["scientificName"].isin(keep_classes)].copy()
print(f"Rows after class filter: {len(meta)}")

# Share of the non-train remainder that goes to "test"; the rest goes to "val".
# Computed once as a ratio so that equal val/test ratios give exactly 0.5.
holdout_p = val_p + test_p
test_share = test_p / holdout_p if holdout_p else 0.0

rng = random.Random(521)
rows = []
for label, g in meta.groupby("scientificName"):
    ids = g["photo_id"].tolist()
    rng.shuffle(ids)
    n = len(ids)
    train_n = int(n * train_p)
    # Rounded down, so any odd leftover image lands in "val".
    test_n = int((n - train_n) * test_share)
    rows.extend((pid, "train") for pid in ids[:train_n])
    rows.extend((pid, "test") for pid in ids[train_n:train_n + test_n])
    rows.extend((pid, "val") for pid in ids[train_n + test_n:])

split_df = pd.DataFrame(rows, columns=["photo_id", "split"])
split_lookup = dict(zip(split_df.photo_id, split_df.split))
print(split_df["split"].value_counts())

meta["split"] = meta["photo_id"].map(split_lookup)
print("Rows per split:\n", meta["split"].value_counts())

per_class = meta.groupby(["scientificName", "split"]).size().unstack(fill_value=0)
print("\nPer-class counts (train/val/test) head:\n", per_class.head())

# Simple imbalance stats
train_counts = per_class["train"]
val_counts   = per_class["val"]
test_counts  = per_class["test"]
print(f"\nTrain min/median/max: {train_counts.min()} / {train_counts.median()} / {train_counts.max()}")
print(f"Val   min/median/max: {val_counts.min()} / {val_counts.median()} / {val_counts.max()}")
print(f"Test  min/median/max: {test_counts.min()} / {test_counts.median()} / {test_counts.max()}")