In [None]:
# Change the file paths in cells 3 (input folder) and 5 (output files) before running.

In [None]:
import os
import json
import pandas as pd
import random
import numpy as np

In [None]:
# Load every .json file in data_dir into one DataFrame.
# Presumably each file holds a single JSON object whose keys become the columns
# (one row per file) - TODO confirm against the data.
data_dir = "/path/to/folder/with/json/files/"
# os.listdir returns names in arbitrary order; sorting keeps the row order of df
# (and therefore any seeded split made from it) the same on every machine and run.
json_files = sorted(f for f in os.listdir(data_dir) if f.endswith(".json"))

# Load data from JSON files (JSON is UTF-8 by spec, so don't rely on the platform
# default encoding) and store as DataFrame
data = []
for filename in json_files:
    with open(os.path.join(data_dir, filename), "r", encoding="utf-8") as f:
        data.append(json.load(f))
df = pd.DataFrame(data)

In [None]:
def _largest_remainder_sizes(total, proportions):
    """
    Split `total` items into integer set sizes that follow `proportions` and sum to `total`.

    Every set first gets floor(proportion * total) items; the items left over go one
    each to the sets with the largest fractional parts (largest-remainder method).
    Equal remainders are resolved in favour of the earlier set.
    """
    raw = np.asarray(proportions, dtype=float) * total
    sizes = np.floor(raw).astype(int)
    shortfall = int(total - sizes.sum())
    if shortfall > 0:
        # stable sort so equal remainders keep set order
        by_remainder = np.argsort(-(raw - sizes), kind="stable")
        sizes[by_remainder[:shortfall]] += 1
    return sizes


def _allocate_counts(category_counts, proportions, rng):
    """
    Decide how many rows of each category go into each set.

    Returns an integer array of shape (number of categories, number of sets).
    Row j always sums to category_counts[j] (every data point is used exactly once);
    column i sums to the largest-remainder size of set i, up to float rounding.
    """
    counts = np.asarray(category_counts, dtype=int)
    setcount = len(proportions)

    # Floor of each exact stratified share: never assigns more rows than exist.
    alloc = np.floor(np.outer(counts, proportions)).astype(int)
    # Rows of each category still unassigned (always fewer than setcount).
    leftovers = counts - alloc.sum(axis=1)
    # Rows each set still needs to reach its target size. deficits and leftovers sum
    # to the same number and shrink together below, so while any leftover remains
    # at least one set still has a positive deficit.
    deficits = _largest_remainder_sizes(counts.sum(), proportions) - alloc.sum(axis=0)

    # Place the categories with the most leftovers first (they are the hardest to
    # spread over distinct sets); equal leftover counts are ordered randomly so no
    # category is systematically favoured.
    shuffled = rng.permutation(counts.size)
    order = shuffled[np.argsort(-leftovers[shuffled], kind="stable")]

    for j in order:
        received = np.zeros(setcount, dtype=bool)
        for _ in range(int(leftovers[j])):
            # Prefer sets that still need rows and have not yet received a leftover
            # of this category; otherwise any set that still needs rows.
            candidates = np.flatnonzero(~received & (deficits > 0))
            if candidates.size == 0:
                candidates = np.flatnonzero(deficits > 0)
            if candidates.size == 0:  # defensive: cannot happen while the sums agree
                candidates = np.arange(setcount)
            neediest = candidates[deficits[candidates] == deficits[candidates].max()]
            target = rng.choice(neediest)
            alloc[j, target] += 1
            deficits[target] -= 1
            received[target] = True
    return alloc


def _split_report(df, column, proportions, split, split_positions, labels, category_counts, codes, tolerance=0.09):
    """
    Build the list of human-readable consistency messages returned when check=True.

    labels[j] / category_counts[j] describe category j of the stratify column;
    codes[p] is the category number of the row at position p of df, and
    split_positions[i] are the positions (in df) of the rows of set i.
    """
    checklist = []
    total = len(df)
    setcount = len(split)

    # DF length consistency check
    split_total = sum(len(part) for part in split)
    if split_total != total:
        checklist.append(f"The lengths are NOT consistent; the length of df is {total} and the sum of the lengths of the split sets is {split_total}. ")
    else:
        checklist.append(f"The lengths ARE consistent; both df and the sum of the lengths of the split sets have {total} data points. ")

    # Every data point of df must appear in exactly one split set
    appearances = np.bincount(np.concatenate(split_positions), minlength=total)
    if np.all(appearances == 1):
        checklist.append("Every data point from df appears exactly once across the split sets. ")
    else:
        checklist.append(f"{int(np.sum(appearances == 0))} data points are missing from, and {int(np.sum(appearances > 1))} data points are repeated across, the split sets. ")

    # counts_per_set[j, i] = number of rows of category j in set i (aligned by category, not by position in value_counts)
    counts_per_set = np.zeros((len(labels), setcount), dtype=int)
    for i, positions in enumerate(split_positions):
        counts_per_set[:, i] = np.bincount(codes[positions], minlength=len(labels))

    for i in range(setcount):
        missing = [labels[j] for j in range(len(labels)) if counts_per_set[j, i] == 0]
        if missing:
            checklist.append(f"The categories {missing} are not present in set {i}. ")

    # Stratify column count consistency check
    for label, found, expected in zip(labels, counts_per_set.sum(axis=1), category_counts):
        if found != expected:
            checklist.append(f"There is an inconsistency in the stratify column {column}; there are a total of {found} data points in the split sets that belong to category {label}, while there are {expected} in the original data set. ")
        else:
            checklist.append(f"The category {label} is consistent; there are {found} data points in both the original data set and totalled across all of the split data sets. ")

    # Proportion consistency check (message reports the tolerance actually used)
    for i, part in enumerate(split):
        div = len(part) / total if total else 0.0
        if abs(div - proportions[i]) > tolerance:
            checklist.append(f"The proportion of set {i} is inconsistent; it is {div} when it should be {proportions[i]}. ")
        else:
            checklist.append(f"The proportion of set {i} is consistent; it is {div}, which is within {tolerance} of the assigned {proportions[i]} proportion. ")

    # Proportion consistency check within split sets (NaN for an empty set)
    with np.errstate(divide="ignore", invalid="ignore"):
        unique_proportions = category_counts / total
        unique_splitted = counts_per_set / counts_per_set.sum(axis=0)
    message = f"The stratify column is \"{column}\". "
    for j, label in enumerate(labels):
        message += f"For {label}, the original proportions are {unique_proportions[j]} and in the {setcount} split sets, they are {unique_splitted[j, :]} respectively. "
    checklist.append(message)

    return checklist


def stratsplit(df, column, proportions, shuffle=True, return_indices=False, check=False, random_state=None):
    """
    Split a DataFrame into non-overlapping sets with the given proportions, stratified by one column.

    Every row of df ends up in exactly one set. Within each category of `column`
    (missing values NaN/None form their own category), rows are shared between the
    sets as close to `proportions` as integer counts allow (normally within one row
    of the exact share), and the set sizes match proportions * len(df) up to rounding.

    df: DataFrame to split (pandas df)
    column: Column to stratify by (str)
    proportions: Proportion for each set (non-negative floats summing to 1); its length is the number of sets
    shuffle: If True, rows within each set are in random order and the index is reset; if False, each set keeps df's row order and index labels (bool, default True)
    return_indices: Whether to also return, for each set, the index labels of df that were assigned to it (bool, default False)
    check: Whether to also return a list of consistency messages comparing the split sets with df (bool, default False)
    random_state: None to use NumPy's global random state (so np.random.seed applies), an int seed, or a np.random.RandomState / np.random.Generator, for reproducible splits

    Returns a list of DataFrames, one for each set; (split, split_indices) if return_indices;
    (split, checklist) if check; (split, split_indices, checklist) if both.
    For backward compatibility, an empty `proportions` returns an empty DataFrame and a
    single proportion returns df itself, regardless of the other flags.

    Raises ValueError if proportions are negative or do not sum to 1, or if column is not a string.
    """
    setcount = int(len(proportions))

    # Degenerate requests, returned as-is for backward compatibility.
    if setcount == 0:
        return pd.DataFrame()
    if setcount == 1:
        return df

    # Suggest alternative scikit-learn functions if the number of sets is 2
    if setcount == 2:
        print("Consider using train_test_split (not stratified by default), StratifiedShuffleSplit (potential overlap in test sets), or even StratifiedKFold from scikit-learn instead.")

    props = np.asarray(proportions, dtype=float)
    if np.any(props < 0):
        raise ValueError("Proportions must be non-negative.")
    # Compare with a tolerance: e.g. 0.57 * 100 + 0.43 * 100 is 99.99999999999999 in floating point.
    if not np.isclose(props.sum(), 1.0):
        raise ValueError("Proportions list must be floats that sum to 1.")
    props = props / props.sum()

    # Check if column name is a string
    if not isinstance(column, str):
        raise ValueError("Column name must be a string (enclosed in speech marks).")

    if random_state is None:
        rng = np.random  # global state, so np.random.seed(...) still controls the split
    elif isinstance(random_state, (np.random.RandomState, np.random.Generator)):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    # Number the categories once. factorize marks missing values (NaN/None) with -1;
    # give them their own category instead of silently dropping those rows.
    codes, uniques = pd.factorize(df[column])
    labels = list(uniques)
    if (codes < 0).any():
        codes = np.where(codes < 0, len(labels), codes)
        labels.append(np.nan)
    category_counts = np.bincount(codes, minlength=len(labels))
    # Renumber so that category 0 is the most frequent one.
    by_count = np.argsort(-category_counts, kind="stable")
    rank = np.empty(len(labels), dtype=int)
    rank[by_count] = np.arange(len(labels))
    codes = rank[codes]
    labels = [labels[k] for k in by_count]
    category_counts = category_counts[by_count]

    alloc = _allocate_counts(category_counts, props, rng)

    # Hand out each category's rows (in random order) to the sets according to alloc.
    parts = [[] for _ in range(setcount)]
    for j in range(len(labels)):
        rows = rng.permutation(np.flatnonzero(codes == j))
        bounds = np.concatenate(([0], np.cumsum(alloc[j])))
        for i in range(setcount):
            parts[i].append(rows[bounds[i]:bounds[i + 1]])
    split_positions = [np.concatenate(chunks) if chunks else np.array([], dtype=int) for chunks in parts]

    # Every data point must be used exactly once.
    assigned = np.concatenate(split_positions)
    if assigned.size != len(df) or np.unique(assigned).size != len(df):
        raise ValueError("Splitting unsuccessful, check code.")

    for i in range(setcount):
        if shuffle:
            split_positions[i] = rng.permutation(split_positions[i])
        else:
            # keep each set in the original row order of df
            split_positions[i] = np.sort(split_positions[i])

    # Positional selection (iloc) stays correct even if df has duplicate index labels.
    split = []
    for positions in split_positions:
        part = df.iloc[positions]
        if shuffle:
            part = part.reset_index(drop=True)
        split.append(part)
    # Report index labels of df (not positions), as df.loc would expect.
    split_indices = [df.index[positions].tolist() for positions in split_positions]

    if check:
        checklist = _split_report(df, column, proportions, split, split_positions, labels, category_counts, codes)
        if return_indices:
            return split, split_indices, checklist
        return split, checklist

    if return_indices:
        return split, split_indices

    return split


# Function to write the rows of a DataFrame to a JSONL file
def write_jsonl(df, filename):
    """
    Write each row of df to filename as one JSON object per line (JSON Lines, UTF-8).

    df: DataFrame to write (pandas df)
    filename: Path of the output file; an existing file is overwritten (str)

    Rows come from df.to_dict(orient="records") rather than df.iterrows(): iterrows
    packs each row into a single Series, so when every column is numeric an integer
    column is upcast to float and e.g. 1 would be written as 1.0. Top-level float NaN
    values (e.g. fields missing from some of the original JSON objects) are written
    as JSON null, because NaN is not valid JSON, and NumPy scalars are converted to
    plain Python values so json.dumps can serialise them. Nested values (lists,
    dicts) are written unchanged.
    """

    def to_json_value(value):
        # NumPy scalars (e.g. np.int64) are not JSON serialisable; unwrap them.
        if isinstance(value, np.generic):
            value = value.item()
        if isinstance(value, float) and np.isnan(value):
            return None
        return value

    with open(filename, "w", encoding="utf-8") as f:
        for record in df.to_dict(orient="records"):
            row_dict = {key: to_json_value(value) for key, value in record.items()}
            f.write(json.dumps(row_dict) + "\n")


# Function to read a JSONL file back into a DataFrame (used to check the files written by write_jsonl)
def read_jsonl(pathtofile):
    """
    Load a JSON Lines file into a DataFrame, one row per non-blank line.

    pathtofile: Path of the .jsonl file to read (str)

    Blank or whitespace-only lines are skipped. The file is decoded as UTF-8 (the
    JSON Lines encoding) instead of the platform's default encoding. A file with no
    records gives an empty DataFrame.
    """
    with open(pathtofile, "r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(records)

In [None]:
# Split df into train / test / validation sets (70% / 20% / 10%), stratified by the
# "topic;" column, and print the consistency report produced by check=True.
split, checklist = stratsplit(df, "topic;", [0.7, 0.2, 0.1], shuffle=True, return_indices=False, check=True)
print(checklist)

# Write each split set to its own JSONL file, in set order: train, test, val.
# The validation set matters least; it could arguably be carved out of the training set instead.
output_paths = (
    "/path/to/train/file.jsonl",
    "/path/to/test/file.jsonl",
    "/path/to/val/file.jsonl",
)
for subset, path in zip(split, output_paths):
    write_jsonl(subset, path)

# To confirm the files were written correctly, read one back, e.g.:
# df_test = read_jsonl("/path/to/train/file.jsonl")
# len(df_test)
# df_test.value_counts("topic;")

Redistribution idea:
* if, for some j, we have sum(unique_per_set[j,:]) > df.value_counts(column).values[j], then check if there are instances of sum(unique_per_set[other_j,:]) < df.value_counts(column).values[other_j]
    * if there is, then store other_j (can be more than one)
    * idea is to -1 from unique_per_set[j,:] and +1 to unique_per_set[other_j,:]
    * if this is done, need to consider set_sizes
        * ideally redistribute within the same set; the next best option is to remove from a set that hasn't had data points removed before AND doesn't have perfect proportions, and add to a set that hasn't had data points added before AND doesn't have perfect proportions
* determining where data should be added to or removed from, need to calculate proportions of all unique values (j) for all sets (i)
    * prop[j][i] = unique_per_set[j,i] / set_sizes[i]
    * should also consider prop[j][i] - unique_proportions[j]
        * a greater value (either the negative value closest to 0 or the largest positive value) is ideal to take from, and a smaller value is ideal to give to; a value close to 0 means that set should remain untouched unless all sets are like that
* ideally, redistribution would move data from unique value j to unique value other_j, from the set with the greatest difference value to the set with the smallest
    * if there are multiple sets with greatest (or smallest) difference value, then randomise to choose one
    * not sure whether the source and destination being the same set will cause problems
    * if only splitting into 2 sets and somehow both have same differences from unique_proportions, then is it better to redistribute from one to the other or redistribute within same set?