# Dataset Analysis and Train/Test Split

This notebook analyzes the augmented dataset and creates a reproducible train/test split for CLIP fine-tuning experiments.

**Objectives:**
1. Load the augmented dataset from MinIO (`train_pairs_augmented_with_negatives.csv`)
2. Filter only positive pairs (label=1) for training/evaluation
3. Create train/test split by `recipe_id` (no data leakage)
4. Generate dataset statistics and verification report
5. Save split manifests to MinIO for reproducibility

**Outputs:**
- `fine-tuning-zone/datasets/train_manifest.csv` — Training pairs
- `fine-tuning-zone/datasets/test_manifest.csv` — Test pairs
- `fine-tuning-zone/datasets/dataset_report.json` — Statistics and metadata


## 1. Setup and Configuration


In [None]:
import os
import io
import json
import hashlib
from pathlib import Path
from typing import Dict, List, Set, Tuple
from datetime import datetime

import numpy as np
import pandas as pd
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from dotenv import load_dotenv
from sklearn.model_selection import train_test_split

# Load environment variables
# The project root is taken to be two directory levels above the current
# working directory, so this assumes the notebook is started from a folder
# nested two levels deep in the project (TODO confirm).
NOTEBOOK_DIR = Path.cwd()
PROJECT_ROOT = NOTEBOOK_DIR.parent.parent

# Candidate .env locations, checked in this order; the first one that exists wins.
ENV_PATHS = [
    PROJECT_ROOT / "notebooks" / ".env",
    PROJECT_ROOT / "app" / ".env",
    PROJECT_ROOT / ".env",
]

# Load only the first existing .env file and stop searching.
env_loaded = False
for env_path in ENV_PATHS:
    if env_path.exists():
        load_dotenv(env_path)
        print(f"âœ“ Loaded .env from: {env_path}")
        env_loaded = True
        break

# None of the candidate paths existed: fall back to load_dotenv() with no
# explicit path and let python-dotenv use its own default lookup.
if not env_loaded:
    print("âš  No .env file found, trying default load_dotenv()...")
    load_dotenv()

# MinIO Configuration
# os.getenv returns None for any variable that is not set; no validation is done here.
MINIO_USER = os.getenv("MINIO_USER")
MINIO_PASSWORD = os.getenv("MINIO_PASSWORD")
MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT")

# Bucket configuration
FINE_TUNING_BUCKET = "fine-tuning-zone"
DATASETS_PREFIX = "datasets"

# Input/Output paths
# All keys are relative to FINE_TUNING_BUCKET and live under DATASETS_PREFIX.
INPUT_DATASET_KEY = f"{DATASETS_PREFIX}/train_pairs_augmented_with_negatives.csv"
TRAIN_MANIFEST_KEY = f"{DATASETS_PREFIX}/train_manifest.csv"
TEST_MANIFEST_KEY = f"{DATASETS_PREFIX}/test_manifest.csv"
REPORT_KEY = f"{DATASETS_PREFIX}/dataset_report.json"

# Split configuration
# TEST_SIZE is applied to the set of unique recipe IDs (see split_by_recipe_id),
# so the resulting pair counts only approximate an 80/20 ratio.
TEST_SIZE = 0.2  # 80% train, 20% test
RANDOM_SEED = 42  # Fixed seed for reproducibility
# NOTE(review): in the code visible in this notebook this flag is only printed;
# the split itself is always performed by recipe_id.
SPLIT_BY_RECIPE = True  # Split by recipe_id to avoid leakage

# Echo the effective configuration so it is captured in the notebook output.
print(f"Configuration:")
print(f"  MinIO Endpoint: {MINIO_ENDPOINT}")
print(f"  Fine-tuning Bucket: {FINE_TUNING_BUCKET}")
print(f"  Test Size: {TEST_SIZE * 100:.0f}%")
print(f"  Random Seed: {RANDOM_SEED}")
print(f"  Split by Recipe ID: {SPLIT_BY_RECIPE}")


## 2. Initialize MinIO Client


In [None]:
# Initialize S3/MinIO client
# Credentials and endpoint come from the MINIO_* values read from the
# environment above; any of them may be None if the variable was not set.
session = boto3.session.Session(
    aws_access_key_id=MINIO_USER,
    aws_secret_access_key=MINIO_PASSWORD,
    region_name="us-east-1"  # NOTE(review): presumably a placeholder region for MinIO — confirm
)
# Module-level client used by every MinIO helper in this notebook.
s3 = session.client(
    "s3",
    endpoint_url=MINIO_ENDPOINT,
    # Signature V4 with path-style addressing (bucket name in the URL path).
    config=Config(signature_version="s3v4", s3={"addressing_style": "path"})
)

def ensure_bucket_exists(bucket: str) -> bool:
    """Ensure ``bucket`` exists on the module-level ``s3`` client, creating it if missing.

    Args:
        bucket: Name of the bucket to check.

    Returns:
        True if the bucket already exists or was created successfully.
        False if the bucket could not be accessed or created; the reason is
        printed rather than raised.
    """
    try:
        s3.head_bucket(Bucket=bucket)
        return True
    except ClientError as e:
        error_code = e.response.get("Error", {}).get("Code", "")
        if error_code not in ("404", "NoSuchBucket"):
            # Any other error means the bucket is not simply missing, so
            # creating it is not attempted; report why instead of failing silently.
            print(
                f"Failed to access bucket '{bucket}' "
                f"(error code: {error_code or 'unknown'}): {e}"
            )
            return False

    # The bucket was reported as missing: try to create it.
    try:
        s3.create_bucket(Bucket=bucket)
    except ClientError as create_error:
        print(f"âœ— Failed to create bucket '{bucket}': {create_error}")
        return False

    print(f"âœ“ Created bucket '{bucket}'")
    return True

# Verify buckets
print("Checking buckets...")
# ensure_bucket_exists reports failure through its return value, so only
# announce readiness when the check actually succeeded.
if ensure_bucket_exists(FINE_TUNING_BUCKET):
    print("âœ“ Buckets ready")
else:
    print(
        f"Bucket '{FINE_TUNING_BUCKET}' is not available; "
        "later MinIO reads and writes may fail"
    )


## 3. Load and Filter Dataset


In [None]:
def load_csv_from_minio(bucket: str, key: str) -> pd.DataFrame:
    """Fetch a CSV object from MinIO and parse it with pandas.

    Args:
        bucket: Bucket that holds the object.
        key: Object key of the CSV file.

    Returns:
        The parsed DataFrame, or an empty DataFrame when the request fails
        with a ``ClientError`` (the failure is printed, not raised).
    """
    try:
        response = s3.get_object(Bucket=bucket, Key=key)
        # Read the whole body into memory before handing it to pandas.
        raw_bytes = response["Body"].read()
        df = pd.read_csv(io.BytesIO(raw_bytes))
        print(f"âœ“ Loaded {len(df)} rows from s3://{bucket}/{key}")
    except ClientError as e:
        print(f"âœ— Failed to load s3://{bucket}/{key}: {e}")
        df = pd.DataFrame()
    return df

# Load augmented dataset
print("Loading augmented dataset from MinIO...")
full_df = load_csv_from_minio(FINE_TUNING_BUCKET, INPUT_DATASET_KEY)

# load_csv_from_minio returns an empty DataFrame on a ClientError, so an
# empty result is treated as a failed load and stops the notebook here.
if full_df.empty:
    raise RuntimeError(f"Could not load dataset from s3://{FINE_TUNING_BUCKET}/{INPUT_DATASET_KEY}")

print(f"\nFull dataset shape: {full_df.shape}")
print(f"Columns: {list(full_df.columns)}")
print(f"\nLabel distribution:")
print(full_df["label"].value_counts())

# Filter only positive pairs (label=1) for training/evaluation
# .copy() makes positive_df independent of full_df.
positive_df = full_df[full_df["label"] == 1].copy()
print(f"\nâœ“ Filtered to {len(positive_df)} positive pairs (label=1)")
print(f"  Removed {len(full_df) - len(positive_df)} negative pairs")

print(f"\nPositive pairs preview:")
# `display` is not imported in the visible code — presumably the IPython/Jupyter builtin.
display(positive_df.head())


## 4. Train/Test Split by Recipe ID

**Critical:** We split by `recipe_id` to ensure no data leakage. All images and captions from the same recipe stay together in either train or test.


In [None]:
def split_by_recipe_id(
    df: pd.DataFrame, test_size: float, random_seed: int
) -> Tuple[pd.DataFrame, pd.DataFrame, np.ndarray, np.ndarray]:
    """Split a dataset into train and test parts by ``recipe_id``.

    The unique recipe IDs are split first, then every row is assigned to the
    side its recipe ID landed on, so all pairs of one recipe stay together.

    Args:
        df: Dataset containing a ``recipe_id`` column.
        test_size: Passed to ``train_test_split`` as ``test_size``; it applies
            to the unique recipe IDs, not to the number of rows.
        random_seed: Passed to ``train_test_split`` as ``random_state``.

    Returns:
        A 4-tuple ``(train_df, test_df, train_recipe_ids, test_recipe_ids)``:
        the two row subsets (as copies) followed by the recipe IDs assigned
        to each side (typically NumPy arrays).
    """
    # Unique IDs are taken in order of first appearance in df, so the split
    # depends on the row order of the input as well as on the seed.
    unique_recipes = df["recipe_id"].unique()

    # Split recipe IDs rather than rows to keep each recipe on one side.
    train_recipe_ids, test_recipe_ids = train_test_split(
        unique_recipes,
        test_size=test_size,
        random_state=random_seed
    )

    # Select the rows belonging to each side; copy so later edits do not
    # touch the source DataFrame.
    train_df = df[df["recipe_id"].isin(train_recipe_ids)].copy()
    test_df = df[df["recipe_id"].isin(test_recipe_ids)].copy()

    return train_df, test_df, train_recipe_ids, test_recipe_ids

# Perform split
# Only the positive pairs are split; the configured TEST_SIZE and RANDOM_SEED
# are passed through unchanged.
print("Creating train/test split by recipe_id...")
train_df, test_df, train_recipe_ids, test_recipe_ids = split_by_recipe_id(
    positive_df,
    test_size=TEST_SIZE,
    random_seed=RANDOM_SEED
)

# Report both recipe-level and pair-level sizes of each side.
print(f"\nâœ“ Split complete:")
print(f"  Train recipes: {len(train_recipe_ids)}")
print(f"  Test recipes: {len(test_recipe_ids)}")
print(f"  Train pairs: {len(train_df)}")
print(f"  Test pairs: {len(test_df)}")

# Verify no leakage
# NOTE(review): this compares the ID arrays returned by the splitter, not the
# recipe_id columns of train_df/test_df themselves.
train_recipe_set = set(train_recipe_ids)
test_recipe_set = set(test_recipe_ids)
overlap = train_recipe_set & test_recipe_set

# An overlap is only reported on stdout; execution continues either way.
if overlap:
    print(f"\nâš  WARNING: Found {len(overlap)} overlapping recipe IDs between train and test!")
else:
    print(f"\nâœ“ No data leakage: train and test recipe sets are disjoint")

# `display` is not imported in the visible code — presumably the IPython/Jupyter builtin.
print(f"\nTrain set preview:")
display(train_df.head())
print(f"\nTest set preview:")
display(test_df.head())


## 5. Dataset Statistics


In [None]:
def _nan_to_none(value, cast):
    """Return ``cast(value)``, or None when ``value`` is missing (NaN)."""
    return None if pd.isna(value) else cast(value)


def _summarize(values: pd.Series) -> Dict:
    """Return mean/median/min/max/std of ``values`` as JSON-safe scalars."""
    return {
        "mean": _nan_to_none(values.mean(), float),
        "median": _nan_to_none(values.median(), float),
        "min": _nan_to_none(values.min(), int),
        "max": _nan_to_none(values.max(), int),
        # pandas' default sample std is NaN for a single value; that becomes None.
        "std": _nan_to_none(values.std(), float),
    }


def compute_dataset_stats(df: pd.DataFrame, split_name: str) -> Dict:
    """Compute summary statistics for one dataset split.

    Args:
        df: Split data with ``recipe_id``, ``image_key`` and ``caption`` columns.
        split_name: Label stored under the ``"split"`` key of the result.

    Returns:
        A dict with the split name, the number of pairs, the counts of unique
        recipes, images and captions, and three nested summaries
        (``images_per_recipe``, ``captions_per_recipe``, ``caption_length``),
        each holding ``mean``, ``median``, ``min``, ``max`` and ``std``.
        A statistic that is undefined (for example ``std`` when there is only
        one recipe) is reported as None instead of NaN, so the result can be
        serialized as valid JSON.
    """
    stats = {
        "split": split_name,
        "total_pairs": len(df),
        "unique_recipes": df["recipe_id"].nunique(),
        "unique_images": df["image_key"].nunique(),
        "unique_captions": df["caption"].nunique(),
    }

    # Distinct images attached to each recipe.
    stats["images_per_recipe"] = _summarize(
        df.groupby("recipe_id")["image_key"].nunique()
    )

    # Distinct captions attached to each recipe.
    stats["captions_per_recipe"] = _summarize(
        df.groupby("recipe_id")["caption"].nunique()
    )

    # Caption length in characters, one value per pair.
    stats["caption_length"] = _summarize(df["caption"].str.len())

    return stats

# Compute statistics
# One stats dict per split; both are printed below and embedded in the report.
train_stats = compute_dataset_stats(train_df, "train")
test_stats = compute_dataset_stats(test_df, "test")

print("=" * 60)
print("Dataset Statistics")
print("=" * 60)

# Human-readable summary of the train split (std values are not printed).
print(f"\nðŸ“Š TRAIN SET:")
print(f"  Total pairs: {train_stats['total_pairs']}")
print(f"  Unique recipes: {train_stats['unique_recipes']}")
print(f"  Unique images: {train_stats['unique_images']}")
print(f"  Unique captions: {train_stats['unique_captions']}")
print(f"  Images per recipe: {train_stats['images_per_recipe']['mean']:.2f} (mean), {train_stats['images_per_recipe']['median']:.1f} (median)")
print(f"  Caption length: {train_stats['caption_length']['mean']:.1f} chars (mean), {train_stats['caption_length']['median']:.1f} (median)")

# Same summary for the test split.
print(f"\nðŸ“Š TEST SET:")
print(f"  Total pairs: {test_stats['total_pairs']}")
print(f"  Unique recipes: {test_stats['unique_recipes']}")
print(f"  Unique images: {test_stats['unique_images']}")
print(f"  Unique captions: {test_stats['unique_captions']}")
print(f"  Images per recipe: {test_stats['images_per_recipe']['mean']:.2f} (mean), {test_stats['images_per_recipe']['median']:.1f} (median)")
print(f"  Caption length: {test_stats['caption_length']['mean']:.1f} chars (mean), {test_stats['caption_length']['median']:.1f} (median)")

# Create full report
# The report records the split parameters alongside the statistics so the
# split can be reproduced and audited later.
dataset_report = {
    "metadata": {
        # Naive UTC timestamp with a literal "Z" suffix appended.
        # NOTE(review): datetime.utcnow() is deprecated as of Python 3.12.
        "created_at": datetime.utcnow().isoformat() + "Z",
        "random_seed": RANDOM_SEED,
        "test_size": TEST_SIZE,
        "split_method": "by_recipe_id",
        "source_dataset": INPUT_DATASET_KEY,
        "filter_applied": "label == 1 (positive pairs only)"
    },
    "train": train_stats,
    "test": test_stats,
    "verification": {
        # True when no recipe ID was assigned to both sides of the split.
        "no_leakage": len(set(train_recipe_ids) & set(test_recipe_ids)) == 0,
        "train_recipe_count": len(train_recipe_ids),
        "test_recipe_count": len(test_recipe_ids)
    }
}

print(f"\nâœ“ Dataset report generated")


In [None]:
def save_csv_to_minio(df: pd.DataFrame, bucket: str, key: str) -> bool:
    """Serialize ``df`` to CSV (no index column) and upload it to MinIO.

    The object is stored with content type ``text/csv`` and carries the row
    count and the module-level ``RANDOM_SEED`` as object metadata.

    Args:
        df: DataFrame to upload.
        bucket: Destination bucket.
        key: Destination object key.

    Returns:
        True after a successful upload; False if anything raised (the error
        is printed, not propagated).
    """
    try:
        text_buffer = io.StringIO()
        df.to_csv(text_buffer, index=False, encoding="utf-8")
        csv_bytes = text_buffer.getvalue().encode("utf-8")

        # Metadata values must be strings.
        object_metadata = {
            "rows": str(len(df)),
            "random_seed": str(RANDOM_SEED),
        }
        upload_args = dict(
            Bucket=bucket,
            Key=key,
            Body=csv_bytes,
            ContentType="text/csv",
            Metadata=object_metadata,
        )
        s3.put_object(**upload_args)

        size_kb = len(csv_bytes) / 1024
        print(f"âœ“ Saved to s3://{bucket}/{key} ({size_kb:.1f} KB)")
    except Exception as e:
        print(f"âœ— Failed to save: {e}")
        return False
    else:
        return True

def save_json_to_minio(data: Dict, bucket: str, key: str) -> bool:
    """Serialize ``data`` as indented JSON and upload it to MinIO.

    The object is stored with content type ``application/json``.

    Args:
        data: Dictionary to serialize.
        bucket: Destination bucket.
        key: Destination object key.

    Returns:
        True after a successful upload; False if serialization or the upload
        raised (the error is printed, not propagated).
    """
    try:
        json_text = json.dumps(data, indent=2)
        json_bytes = json_text.encode("utf-8")

        upload_args = dict(
            Bucket=bucket,
            Key=key,
            Body=json_bytes,
            ContentType="application/json",
        )
        s3.put_object(**upload_args)

        size_kb = len(json_bytes) / 1024
        print(f"âœ“ Saved to s3://{bucket}/{key} ({size_kb:.1f} KB)")
    except Exception as e:
        print(f"âœ— Failed to save: {e}")
        return False
    else:
        return True

# Save manifests and report
print("=" * 60)
print("Saving to MinIO")
print("=" * 60)

# Select only necessary columns for manifests
manifest_columns = ["recipe_id", "image_key", "caption"]

# The save helpers return False on failure instead of raising, so collect
# their outcomes. All three uploads are attempted, in this order, before the
# results are inspected.
save_results = {
    TRAIN_MANIFEST_KEY: save_csv_to_minio(
        train_df[manifest_columns], FINE_TUNING_BUCKET, TRAIN_MANIFEST_KEY
    ),
    TEST_MANIFEST_KEY: save_csv_to_minio(
        test_df[manifest_columns], FINE_TUNING_BUCKET, TEST_MANIFEST_KEY
    ),
    REPORT_KEY: save_json_to_minio(dataset_report, FINE_TUNING_BUCKET, REPORT_KEY),
}

# Do not claim success when an upload failed: name the missing outputs and stop.
failed_keys = [key for key, saved in save_results.items() if not saved]
if failed_keys:
    raise RuntimeError(
        f"Failed to save to s3://{FINE_TUNING_BUCKET}: {', '.join(failed_keys)}"
    )

print(f"\nâœ… All files saved successfully!")
print(f"  Train manifest: s3://{FINE_TUNING_BUCKET}/{TRAIN_MANIFEST_KEY}")
print(f"  Test manifest: s3://{FINE_TUNING_BUCKET}/{TEST_MANIFEST_KEY}")
print(f"  Dataset report: s3://{FINE_TUNING_BUCKET}/{REPORT_KEY}")
