# Under construction: trying to make this into a Google Colab notebook that people can run without installing Python

In [None]:
import pandas as pd
import numpy as np
from itertools import cycle

In [3]:
import pandas as pd
import numpy as np
from itertools import cycle
from IPython.display import display
import ipywidgets as widgets

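# Upload a file in Jupyter/Colab
# Colab only: uncomment the import below to enable browser upload/download.
# Leave it commented out when running outside Colab.
# from google.colab import files

# ---- Upload file interactively (Colab only) ----
# Uncomment both lines below to choose the input TSV through the browser;
# the name of the first uploaded file is stored in `tsv_file`.
# uploaded = files.upload()
# tsv_file = list(uploaded.keys())[0]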

In [4]:
# ---- Functions ----
def load_data(file_path):
    """Read a tab-separated file into a DataFrame and validate its columns.

    Args:
        file_path: Location of the TSV file, passed straight to
            ``pandas.read_csv`` with ``sep='\\t'``.

    Returns:
        The parsed DataFrame, which contains both a 'Class' and an 'Order'
        column.

    Raises:
        RuntimeError: If the file cannot be read or parsed (the underlying
            exception is chained as ``__cause__``), or if the 'Class' or
            'Order' column is missing.
    """
    # Keep the try body minimal: only the read itself can fail for I/O or
    # parsing reasons.
    try:
        data = pd.read_csv(file_path, sep='\t')
    except Exception as e:
        raise RuntimeError(f"Failed to load data: {e}") from e

    # Validate outside the try block so the schema error is not re-wrapped
    # (re-wrapping a KeyError added stray quotes to the message).
    if 'Class' not in data.columns or 'Order' not in data.columns:
        raise RuntimeError(
            "Failed to load data: Dataset must contain 'Class' and 'Order' columns."
        )
    return data

def filter_data(data, column, value, negate=False):
    """Select the rows of ``data`` by comparing one column against ``value``.

    Args:
        data: DataFrame to filter.
        column: Name of the column to compare.
        value: Value each entry of ``column`` is compared with.
        negate: When true, keep the rows that differ from ``value`` instead
            of the rows that equal it.

    Returns:
        The matching rows of ``data``, with their original index labels.

    Raises:
        ValueError: If ``column`` is not a column of ``data``.
    """
    if column not in data.columns:
        raise ValueError(f"Column '{column}' does not exist in the DataFrame.")

    column_values = data[column]
    row_mask = (column_values != value) if negate else (column_values == value)
    return data[row_mask]

def sample_from_orders(data_subset, orders, total_samples):
    """Draw rows one at a time, rotating through the given orders.

    Orders are visited round-robin in the sequence given; each visit draws a
    single random row (with replacement) from the rows of ``data_subset``
    whose 'Order' equals that order. Orders without any rows are skipped.

    Args:
        data_subset: DataFrame with an 'Order' column to sample from.
        orders: Iterable of order values to rotate through.
        total_samples: Number of rows to draw in total.

    Returns:
        A DataFrame of ``total_samples`` sampled rows with a fresh 0..n-1
        index. When ``total_samples`` is zero or negative, an empty frame
        with the columns of ``data_subset`` is returned.

    Raises:
        ValueError: If samples are requested but none of ``orders`` has any
            rows in ``data_subset`` (this used to loop forever, or raise a
            bare StopIteration for an empty ``orders``).
    """
    if total_samples <= 0:
        # Nothing to draw; keep the column layout so callers can still
        # concatenate or write the result.
        return data_subset.iloc[0:0].reset_index(drop=True)

    # Filter each order once up front instead of on every draw.
    pools = []
    for order in orders:
        rows = data_subset[data_subset['Order'] == order]
        if not rows.empty:
            pools.append(rows)

    if not pools:
        raise ValueError("None of the requested orders have rows to sample from.")

    # Collect the draws and concatenate once; growing a DataFrame inside the
    # loop copies all previous rows on every iteration.
    pool_cycle = cycle(pools)
    draws = []
    while len(draws) < total_samples:
        draws.append(next(pool_cycle).sample(n=1, replace=True))
    return pd.concat(draws, ignore_index=True)

def sample_data(data, selected_class, num_samples, num_norders, num_orders=5):
    """Sample rows from one class and from orders outside that class.

    Up to ``num_orders`` random orders of ``selected_class`` are chosen and
    ``num_samples`` rows are drawn from them; then ``num_norders`` random
    orders are chosen among the rows of every other class and another
    ``num_samples`` rows are drawn from those. The chosen orders are printed.

    Args:
        data: DataFrame with 'Class' and 'Order' columns.
        selected_class: Value of 'Class' identifying the selected class.
        num_samples: Number of rows to draw from each of the two groups.
        num_norders: Number of distinct orders to pick outside the class.
        num_orders: Maximum number of distinct orders to pick inside the
            class (default 5); capped at the number of orders available.

    Returns:
        A DataFrame with the selected-class samples followed by the
        non-selected samples, re-indexed from 0.

    Raises:
        ValueError: If ``num_orders`` or ``num_norders`` is below 1, if the
            class has no rows or no non-null orders, or if fewer than
            ``num_norders`` orders exist outside the class.
    """
    # Reject counts that would otherwise surface as an opaque numpy error
    # or an exception with an empty message further down.
    if num_orders < 1:
        raise ValueError("Number of selected-class orders must be at least 1.")
    if num_norders < 1:
        raise ValueError("Number of non-selected orders must be at least 1.")

    class_subset = filter_data(data, 'Class', selected_class)
    if class_subset.empty:
        raise ValueError("No data found for the selected class.")

    unique_orders = class_subset['Order'].dropna().unique()
    if len(unique_orders) == 0:
        raise ValueError("No orders found for the selected class.")

    num_orders_to_sample = min(num_orders, len(unique_orders))
    chosen_orders = np.random.choice(unique_orders, num_orders_to_sample, replace=False)
    print(f"Selected orders from class '{selected_class}': {chosen_orders}")

    selected_samples = sample_from_orders(class_subset, chosen_orders, num_samples)

    nonselected_subset = filter_data(data, 'Class', selected_class, negate=True)
    valid_norders = nonselected_subset['Order'].dropna().unique()
    if num_norders > len(valid_norders):
        raise ValueError("Not enough unique non-selected orders to sample from.")

    chosen_norders = np.random.choice(valid_norders, num_norders, replace=False)
    print(f"Selected orders from non-selected class: {chosen_norders}")

    nonselected_samples = sample_from_orders(nonselected_subset, chosen_norders, num_samples)

    return pd.concat([selected_samples, nonselected_samples], ignore_index=True)

In [None]:
# ---- Load data and show columns ----
# `tsv_file` must hold the path of the input TSV before this cell runs:
# either set it through the interactive upload step or uncomment the line below.
# tsv_file = 'path_to_your_file.tsv'  # Uncomment and set your file path if not using upload
if 'tsv_file' not in globals():
    # Without this guard the cell fails with a bare "name 'tsv_file' is not
    # defined", which does not tell a notebook user what to do about it.
    raise NameError(
        "tsv_file is not defined: upload a file or set tsv_file to the path "
        "of your TSV file before running this cell."
    )
data = load_data(tsv_file)
print(f"Columns in data: {data.columns.tolist()}")
display(data.head())

# ---- Interactive Controls ----
# All labelled controls get the same description style; each widget receives
# its own copy of the template dict.
LABEL_STYLE = {'description_width': 'initial'}

# Classes offered in the dropdown: the distinct non-null values, sorted.
class_choices = sorted(data['Class'].dropna().unique())
class_dropdown = widgets.Dropdown(
    options=class_choices,
    description='Class:',
    style=dict(LABEL_STYLE),
)

# Numeric inputs forwarded to sample_data() by the button handler.
num_samples_input = widgets.IntText(
    value=5,
    description='Samples per order:',
    style=dict(LABEL_STYLE),
)
num_norders_input = widgets.IntText(
    value=5,
    description='# of non-class orders:',
    style=dict(LABEL_STYLE),
)

# Name of the TSV file the sampled rows are written to.
output_name = widgets.Text(
    value='sampled_data.tsv',
    description='Output file:',
    style=dict(LABEL_STYLE),
)

# Trigger button, and the output area that captures the handler's prints.
run_button = widgets.Button(description="Run Sampling")
output_box = widgets.Output()

def on_run_button_clicked(b):
    """Run the sampling with the current widget values and save the result.

    The sampled rows are displayed inside ``output_box`` and written as a
    tab-separated file. When the notebook runs in Google Colab, a browser
    download of that file is started as well. Any failure is reported as an
    ``Error: ...`` line in ``output_box`` instead of raising.

    Args:
        b: The button instance passed in by ``Button.on_click`` (unused).
    """
    with output_box:
        output_box.clear_output()
        try:
            result = sample_data(
                data,
                class_dropdown.value,
                num_samples_input.value,
                num_norders_input.value
            )
            print("Sampled data:")
            display(result)

            # Fall back to the default name when the field is left blank;
            # otherwise a hidden file called ".tsv" would be written.
            outname = output_name.value.strip() or 'sampled_data.tsv'
            # Compare case-insensitively so "x.TSV" is not turned into
            # "x.TSV.tsv".
            if not outname.lower().endswith(".tsv"):
                outname += ".tsv"

            result.to_csv(outname, sep='\t', index=False)
            print(f"Output written to: {outname}")

            # The browser download helper exists only in Colab. Import it
            # here so the handler also works where the module-level import
            # is commented out or the package is unavailable.
            try:
                from google.colab import files
            except ImportError:
                files = None
            if files is not None:
                files.download(outname)
        except Exception as e:
            # Top-level UI boundary: show the problem to the notebook user
            # rather than letting the callback fail silently.
            print(f"Error: {e}")

# Register the handler so a click on the button runs the sampling.
run_button.on_click(on_run_button_clicked)

# ---- Display the widgets ----
# Controls are shown in this order, with the output area last.
controls = (
    class_dropdown,
    num_samples_input,
    num_norders_input,
    output_name,
    run_button,
    output_box,
)
display(*controls)