In [None]:
import json
import random
from collections import defaultdict, Counter
from sklearn.model_selection import train_test_split

# Read every labelled command from disk.
with open("dataset.json", encoding="utf-8") as f:
    data = json.load(f)

# Fraction of each category routed to train / validation / test.
train_ratio, val_ratio, test_ratio = 0.7, 0.15, 0.15

# An entry whose rooms are exactly two copies of one of these rooms is treated
# as an invalid self-pair and discarded before any splitting happens.
invalid_pairs = {
    (room, room)
    for room in (
        "kitchen",
        "terrace",
        "bedroom",
        "corridor",
        "study room",
        "store room",
        "prayer room",
    )
}

data = [e for e in data if tuple(sorted(e["rooms"])) not in invalid_pairs]

# Stratification key: (intent, rooms, (room, action) pairs). Rooms and pairs
# are sorted so the order in which an entry lists them never creates a
# separate category.
grouped_data = defaultdict(list)
for entry in data:
    room_key = tuple(sorted(entry["rooms"]))
    action_key = tuple(sorted(
        (step["room"], step["action"]) for step in entry["actions"]
    ))
    grouped_data[(entry["intent"], room_key, action_key)].append(entry)

# Initialize split sets
train_set, val_set, test_set = [], [], []

# One seed drives both the sklearn splits and the oversampling RNG, so the
# whole split is reproducible from run to run.
SPLIT_SEED = 42
# Categories whose training portion is smaller than this are topped up by
# duplicating *training* samples only; val/test never receive copies.
MIN_TRAIN_PER_CATEGORY = 3

# train_ratio is never passed to sklearn (train gets whatever remains after
# holding out val + test), so make sure the three configured ratios agree.
if (
    min(train_ratio, val_ratio, test_ratio) <= 0
    or abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-9
):
    raise ValueError(
        "Split ratios must be positive and sum to 1, got "
        f"train={train_ratio}, val={val_ratio}, test={test_ratio}"
    )

held_out_ratio = val_ratio + test_ratio
test_fraction_of_held_out = test_ratio / held_out_ratio
oversample_rng = random.Random(SPLIT_SEED)

# Split each category separately (stratification) so multi-room commands stay intact.
for key, samples in grouped_data.items():
    num_samples = len(samples)

    if num_samples == 1:
        print(f"⚠ Warning: Category {key} has only 1 sample. Assigning to training set.")
        train_set.extend(samples)
        continue

    # Split the distinct samples FIRST. Duplicating before this point would
    # place copies of the same example in both train and val/test (leakage).
    train, held_out = train_test_split(
        samples, test_size=held_out_ratio, random_state=SPLIT_SEED
    )

    if len(held_out) >= 2:
        val, test = train_test_split(
            held_out, test_size=test_fraction_of_held_out, random_state=SPLIT_SEED
        )
    else:
        # train_test_split cannot divide a single sample; keep it for testing.
        print(
            f"⚠ Warning: Category {key} has very few samples ({num_samples}). "
            "No validation example."
        )
        val, test = [], held_out

    if len(train) < MIN_TRAIN_PER_CATEGORY:
        print(
            f"⚠ Warning: Category {key} has very few samples ({num_samples}). "
            "Oversampling training split."
        )
        # New list: never mutate the shared grouped_data lists.
        train = train + oversample_rng.choices(
            train, k=MIN_TRAIN_PER_CATEGORY - len(train)
        )

    train_set.extend(train)
    val_set.extend(val)
    test_set.extend(test)

# Function to compute category distribution in the dataset
def compute_distribution(dataset, name):
    """Print and return the number of samples per category in *dataset*.

    A category is the same triple used to stratify the split:
    ``(intent, sorted rooms, sorted (room, action) pairs)``.

    Args:
        dataset: Iterable of entries. Each is a dict with ``"intent"``,
            ``"rooms"`` and ``"actions"`` keys, and every action is a dict
            with ``"room"`` and ``"action"`` keys.
        name: Label used in the printed header, e.g. ``"Train"``.

    Returns:
        A ``Counter`` mapping each category to its sample count. It iterates
        in the same order as the printed lines.
    """
    counter = Counter(
        (
            entry["intent"],
            tuple(sorted(entry["rooms"])),
            tuple(sorted((action["room"], action["action"]) for action in entry["actions"])),
        )
        for entry in dataset
    )

    print(f"\n{name} set distribution:")
    for category, count in counter.items():
        print(f"  {category}: {count} samples")
    # Returning the counts lets callers compare splits programmatically;
    # callers that ignore the result are unaffected.
    return counter

# Print per-category counts for each split so the balance can be checked by eye.
for split_rows, split_label in ((train_set, "Train"), (val_set, "Validation"), (test_set, "Test")):
    compute_distribution(split_rows, split_label)

# Write every split to its own pretty-printed JSON file.
for out_path, split_rows in (("train.json", train_set), ("val.json", val_set), ("test.json", test_set)):
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(split_rows, f, indent=4)

print(f"\n✅ Dataset split complete: {len(train_set)} train, {len(val_set)} val, {len(test_set)} test samples")
