In [1]:
import json
import os
import torch
from torch.utils.data import Dataset, DataLoader

class PuzzleDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

def load_puzzle_data(data_folder):
    X = []
    Y = []

    for filename in os.listdir(data_folder):
        if filename.endswith('.json'):
            file_path = os.path.join(data_folder, filename)
            with open(file_path, 'r') as f:
                data = json.load(f)
            
            # Process test data
            test_input = torch.tensor(data['test'][0]['input'], dtype=torch.float32)
            test_output = torch.tensor(data['test'][0]['output'], dtype=torch.float32)
            
            # Process train data
            train_data = []
            for item in data['train']:
                input_tensor = torch.tensor(item['input'], dtype=torch.float32)
                output_tensor = torch.tensor(item['output'], dtype=torch.float32)
                train_data.append([input_tensor, output_tensor])
            
            # Create X and Y
            X.append([test_input, train_data])
            Y.append(test_output)

    return X, Y

def create_dataloaders(X, Y, batch_size=1, train_split=0.8):
    dataset = PuzzleDataset(X, Y)
    dataset_size = len(dataset)
    train_size = int(train_split * dataset_size)
    test_size = dataset_size - train_size
    
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, test_loader

# Usage
data_folder = os.path.join('..', 'data', 'training')
X, Y = load_puzzle_data(data_folder)
train_loader, test_loader = create_dataloaders(X, Y)

# Example of iterating through the data
for batch_X, batch_Y in train_loader:
    print("Batch X shape:", batch_X[0][0].shape)  # Shape of test_input
    print("Batch X train data length:", len(batch_X[0][1]))  # Number of train examples
    print("Batch Y shape:", batch_Y.shape)  # Shape of test_output
    break  # Just print the first batch

Batch X shape: torch.Size([10, 10])


IndexError: index 1 is out of bounds for dimension 0 with size 1