### Import Libraries

In [3]:
from datasets import load_dataset, load_from_disk
from pathlib import Path
import numpy as np
from collections import defaultdict
import json
from pathlib import Path
from sklearn.model_selection import train_test_split
import json
from collections import defaultdict

### Import data

In [2]:
# Directory that holds (or will hold) the locally saved copy of the dataset
dataset_path = Path("./data")
dataset_dir = str(dataset_path)

# Case 1: a saved copy is already on disk, so read it back
if dataset_path.exists() and (dataset_path / "dataset_dict.json").exists():
    print("Loading dataset from disk...")
    dataset = load_from_disk(dataset_dir)
    print(f"✓ Loaded {len(dataset['labelled'])} examples from disk")

# Case 2: nothing on disk, but a `dataset` object is still alive in this kernel,
# so persist it instead of downloading again
elif 'dataset' in globals() and dataset is not None:
    print("Dataset variable found in memory, saving to disk...")
    dataset.save_to_disk(dataset_dir)
    print(f"✓ Saved {len(dataset['labelled'])} examples to disk")

# Case 3: fetch from the HuggingFace hub, then keep a local copy for next time
else:
    print("Loading dataset from HuggingFace...")
    dataset = load_dataset("GlobalWheat/GWFSS_v1.0")
    print(f"✓ Loaded {len(dataset['labelled'])} examples from HuggingFace")

    print("Saving dataset to disk...")
    dataset.save_to_disk(dataset_dir)
    print(f"✓ Saved dataset to disk")

Loading dataset from disk...
✓ Loaded 1096 examples from disk


### Stratification

In [6]:
# Location of the cached class counts (the message below points at the EDA
# notebook as the producer of this file)
class_counts_path = Path("./cache/class_counts.json")

# Mask colour (R, G, B) -> class id
rgb_to_class = {
    (0, 0, 0): 0,              # Background (Black)
    (214, 255, 50): 1,         # Leaf (Yellow-green)
    (50, 132, 255): 2,         # Stem (Blue)
    (50, 255, 132): 3,         # Head (Cyan-green)
}

# Fail fast when the cache is missing: everything below needs `class_counts`,
# and continuing would either raise a confusing NameError or silently reuse a
# stale `class_counts` left over in the kernel.
if not class_counts_path.exists():
    raise FileNotFoundError(
        f"⚠ Class counts not found at '{class_counts_path}'. "
        "Please run the EDA notebook first."
    )

with open(class_counts_path, 'r') as f:
    class_counts_dict = json.load(f)

# JSON object keys are always strings, so convert them back to integer class ids
class_counts = defaultdict(list, {int(k): v for k, v in class_counts_dict.items()})
print(f"✓ Loaded class counts from cache")
# Define function to get the dominant class of an image
def get_dominant_class(image_idx, counts=None, num_classes=4):
    """Return the id of the class with the largest count for one image.

    Args:
        image_idx: Position of the image inside each per-class list of counts.
        counts: Mapping of class id -> sequence of per-image counts. When
            omitted, the notebook-level ``class_counts`` is used.
        num_classes: Number of class ids to compare (ids ``0..num_classes-1``).

    Returns:
        The class id whose count is highest for the image, as returned by
        ``np.argmax`` (on ties, the lowest class id wins).
    """
    if counts is None:
        # Backward-compatible default: fall back to the cached notebook global
        counts = class_counts
    per_class = [counts[class_id][image_idx] for class_id in range(num_classes)]
    return np.argmax(per_class)

# Stratification target: one dominant-class id (0-3) per image, in dataset order
stratify_labels = list(map(get_dominant_class, range(len(dataset["labelled"]))))
print(f"✓ Created stratification labels")

✓ Loaded class counts from cache
✓ Created stratification labels


### Splitting

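`train_test_split` produces only two partitions per call, so splitting the data into three groups (Train/Val/Test) takes two successive stratified splits: first 70/30 into Train and a combined Val+Test pool, then 50/50 of that pool into Val and Test (about 15% of the data each).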

In [8]:
# Only the labelled subset is split; images are addressed by position 0..N-1
labelled = dataset["labelled"]
n_total = len(labelled)

# Split 1: Train/ValTest (70:30), stratified on the dominant class of each image
train_indices, valtest_indices = train_test_split(
    range(n_total),
    test_size=0.3,
    stratify=stratify_labels,
    random_state=42,
)

# Split 2: Val/Test (50:50) of the held-out pool, stratified on the same labels
valtest_labels = [stratify_labels[idx] for idx in valtest_indices]
val_indices, test_indices = train_test_split(
    valtest_indices,
    test_size=0.5,  # 50% of ValTest = 15% of total
    stratify=valtest_labels,
    random_state=42,
)

# Materialise each split as a dataset view over the chosen positions
train_split = labelled.select(train_indices)
val_split = labelled.select(val_indices)
test_split = labelled.select(test_indices)

# Report the size of each split and its share of the labelled data
print(f"✓ Created stratified splits:")
print(f"  Train: {len(train_split)} examples ({100*len(train_split)/n_total:.1f}%)")
print(f"  Val:   {len(val_split)} examples ({100*len(val_split)/n_total:.1f}%)")
print(f"  Test:  {len(test_split)} examples ({100*len(test_split)/n_total:.1f}%)")

✓ Created stratified splits:
  Train: 767 examples (70.0%)
  Val:   164 examples (15.0%)
  Test:  165 examples (15.1%)


In [11]:
# Define function to check dominant class distribution in a split
def check_stratification(split_name, split_indices, labels=None):
    """Print how many images of a split have each class as their dominant class.

    Args:
        split_name: Label printed in the report header.
        split_indices: Iterable of positions into ``labels`` that form the split.
        labels: Sequence of dominant-class ids (0-3), one per image. When
            omitted, the notebook-level ``stratify_labels`` is used.

    Returns:
        None. The report is written to stdout, one line per class with its
        count and percentage of the split.
    """
    if labels is None:
        # Backward-compatible default: fall back to the notebook global
        labels = stratify_labels
    split_labels = [labels[i] for i in split_indices]
    class_names = ['Background', 'Leaf', 'Stem', 'Head']
    total = len(split_labels)
    print(f"\n{split_name} dominant class distribution:")
    for class_id, class_name in enumerate(class_names):
        count = sum(1 for label in split_labels if label == class_id)
        # An empty split has no meaningful share: report 0% instead of dividing by zero
        pct = 100 * count / total if total else 0.0
        print(f"  {class_name}: {count} ({pct:.1f}%)")

# Report the dominant class distribution of the full dataset first, then of
# each split, so the per-split percentages can be compared to the baseline
for _split_name, _split_indices in (
    ("Full Dataset", range(len(dataset["labelled"]))),
    ("Train", train_indices),
    ("Val", val_indices),
    ("Test", test_indices),
):
    check_stratification(_split_name, _split_indices)


Full Dataset dominant class distribution:
  Background: 301 (27.5%)
  Leaf: 772 (70.4%)
  Stem: 0 (0.0%)
  Head: 23 (2.1%)

Train dominant class distribution:
  Background: 211 (27.5%)
  Leaf: 540 (70.4%)
  Stem: 0 (0.0%)
  Head: 16 (2.1%)

Val dominant class distribution:
  Background: 45 (27.4%)
  Leaf: 116 (70.7%)
  Stem: 0 (0.0%)
  Head: 3 (1.8%)

Test dominant class distribution:
  Background: 45 (27.3%)
  Leaf: 116 (70.3%)
  Stem: 0 (0.0%)
  Head: 4 (2.4%)
