In [21]:
import os
import json
import random
from collections import defaultdict, Counter
from sklearn.model_selection import train_test_split

In [22]:
root = "../../data"
filename = "restructured_Clean_FINAL.json"

jsonFile = os.path.join(root, filename)

with open(jsonFile, "r", encoding="utf-8") as f:
    data = json.load(f)

In [23]:
len(data)

15796

In [24]:
# Defining the split ratios
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

In [25]:
# identifying invalid self-room pairs
invalid_pairs = {("kitchen", "kitchen"), ("terrace", "terrace"), ("bedroom", "bedroom"),
                 ("corridor", "corridor"), ("study room", "study room"),
                 ("store room", "store room"), ("prayer room", "prayer room"), ("balcony", "balcony")}

In [26]:
data = [entry for entry in data if tuple(sorted(entry["rooms"])) not in invalid_pairs]
len(data)

15777

In [27]:
# Group dataset by (intent, room combination, action combination)
grouped_data = defaultdict(list)
for entry in data:
    room_key = tuple(sorted(entry["rooms"]))  # Sorted tuple ensures consistency
    action_key = tuple(sorted((action["room"], action["action"]) for action in entry["actions"]))
    key = (entry["intent"], room_key, action_key)
    grouped_data[key].append(entry)

In [28]:
# Initialize split sets
train_set, val_set, test_set = [], [], []

In [29]:
# Perform stratified split while keeping multi-room commands intact
for key, samples in grouped_data.items():
    num_samples = len(samples)

    if num_samples == 1:
        print(f"⚠ Warning: Category {key} has only 1 sample. Assigning to training set.")
        train_set.extend(samples)
        continue

    if num_samples < 5:
        print(f"⚠ Warning: Category {key} has very few samples ({num_samples}). Oversampling.")
        samples.extend(random.choices(samples, k=5 - num_samples))  # Oversample to at least 5 examples

    train, temp = train_test_split(samples, test_size=(val_ratio + test_ratio), random_state=42)
    val, test = train_test_split(temp, test_size=(test_ratio / (test_ratio + val_ratio)), random_state=42)

    train_set.extend(train)
    val_set.extend(val)
    test_set.extend(test)



In [30]:
# Function to compute category distribution in the dataset
def compute_distribution(dataset, name):
    counter = Counter((entry["intent"], tuple(sorted(entry["rooms"])), 
                       tuple(sorted((action["room"], action["action"]) for action in entry["actions"]))) 
                      for entry in dataset)
    
    print(f"\n{name} set distribution:")
    for category, count in counter.items():
        print(f"  {category}: {count} samples")

In [31]:
# Display distributions to verify balance
compute_distribution(train_set, "Train")
compute_distribution(val_set, "Validation")
compute_distribution(test_set, "Test")


Train set distribution:
  ('single_room_control', ('kitchen',), (('kitchen', 'turn_off'),)): 80 samples
  ('single_room_control', ('store room',), (('store room', 'turn_on'),)): 133 samples
  ('multi_room_control', ('kitchen', 'living room'), (('kitchen', 'adjust_brightness'), ('living room', 'adjust_brightness'))): 39 samples
  ('single_room_control', ('balcony',), (('balcony', 'turn_off'),)): 112 samples
  ('single_room_control', ('study room',), (('study room', 'change_color'),)): 115 samples
  ('single_room_control', ('bedroom',), (('bedroom', 'adjust_brightness'),)): 151 samples
  ('single_room_control', ('kitchen',), (('kitchen', 'change_color'),)): 249 samples
  ('multi_room_control', ('corridor', 'terrace'), (('corridor', 'turn_on'), ('terrace', 'adjust_brightness'))): 1 samples
  ('multi_room_control', ('balcony', 'corridor'), (('balcony', 'turn_on'), ('corridor', 'turn_off'))): 4 samples
  ('single_room_control', ('prayer room',), (('prayer room', 'change_color'),)): 205 sam

In [32]:
# Save the split datasets
with open("train.json", "w", encoding="utf-8") as f:
    json.dump(train_set, f, indent=4)
with open("val.json", "w", encoding="utf-8") as f:
    json.dump(val_set, f, indent=4)
with open("test.json", "w", encoding="utf-8") as f:
    json.dump(test_set, f, indent=4)

print(f"\n✅ Dataset split complete: {len(train_set)} train, {len(val_set)} val, {len(test_set)} test samples")


✅ Dataset split complete: 11106 train, 2404 val, 2510 test samples
