<a href="https://colab.research.google.com/github/ethanwongca/hai_work/blob/main/cv_split.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Step 1: Mount Google Drive
from google.colab import drive

# Later cells read and write files under <mount point>/My Drive/.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
#@title Step 2: Import Libraries and Define the Grouped Split Function

import os
import pickle
import random
import numpy as np
from sklearn.model_selection import GroupKFold

# Seed the global Python and NumPy RNGs so any stochastic step is repeatable.
# GroupKFold below is called without shuffling, so the fold assignment itself
# does not depend on these seeds.
MANUAL_SEED = 1
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(MANUAL_SEED)

def default_group_key(fname):
    """
    Derive the GroupKFold group id for one file name.

    The id is the first two underscore-separated fields of the name,
    e.g. "bar_10_1.pkl" -> "bar_10".

    Args:
        fname (str): file name such as "bar_10_1.pkl".

    Returns:
        str: the group id.

    Raises:
        ValueError: if `fname` has fewer than two underscore-separated fields.
    """
    parts = fname.split('_')
    if len(parts) < 2:
        raise ValueError(
            f"Cannot derive a group id from {fname!r}: expected at least two "
            "underscore-separated fields, e.g. 'bar_10_1.pkl'."
        )
    return parts[0] + '_' + parts[1]


def get_grouped_splits(confused_items, not_confused_items, k, group_key=None):
    """
    Split two classes of files into `k` cross-validation folds with GroupKFold,
    so that no group id appears on both the training and the testing side of
    any fold.

    By default a file's group id is its first two underscore-separated fields
    ("bar_10_1.pkl" -> "bar_10"; see `default_group_key`).
    NOTE(review): if the participant id is only the second field and the same
    participant can appear under several prefixes, the default rule would put
    that participant in several groups and leak them across train/test. Pass a
    custom `group_key` in that case (confirm against the file naming scheme).

    GroupKFold balances the number of samples per test fold but does not
    stratify by class, so the confused/not_confused ratio can vary per fold.

    Args:
        confused_items (list[str]): file names for the confused class.
        not_confused_items (list[str]): file names for the not_confused class.
        k (int): number of folds; GroupKFold raises ValueError if this exceeds
            the number of distinct groups.
        group_key (callable, optional): maps a file name to its group id.
            Defaults to `default_group_key`.

    Returns:
        tuple: (train_confused_splits, test_confused_splits,
                train_not_confused_splits, test_not_confused_splits),
        each a list with one list of file names per fold, in fold order.

    Raises:
        ValueError: if a file name cannot be mapped to a group id by the
            default rule, or from GroupKFold for an invalid `k`.
    """
    if group_key is None:
        group_key = default_group_key

    n_confused = len(confused_items)
    # Confused items first, then not_confused; index < n_confused => confused.
    items = list(confused_items) + list(not_confused_items)
    groups = [group_key(fname) for fname in items]
    # 0 = confused, 1 = not_confused. GroupKFold does not stratify on these.
    labels = [0] * n_confused + [1] * (len(items) - n_confused)

    train_confused_splits, test_confused_splits = [], []
    train_not_confused_splits, test_not_confused_splits = [], []

    gkf = GroupKFold(n_splits=k)
    for train_idx, test_idx in gkf.split(X=items, y=labels, groups=groups):
        train_confused_splits.append([items[i] for i in train_idx if i < n_confused])
        test_confused_splits.append([items[i] for i in test_idx if i < n_confused])
        train_not_confused_splits.append([items[i] for i in train_idx if i >= n_confused])
        test_not_confused_splits.append([items[i] for i in test_idx if i >= n_confused])

    return (train_confused_splits, test_confused_splits,
            train_not_confused_splits, test_not_confused_splits)


In [None]:
#@title Step 3: Define Data Directories and Collect File Names

# Adjust LABEL_ROOT to point at your data in Google Drive.
# The "confused" class is read from <LABEL_ROOT>/high/pickle_files and the
# "not_confused" class from <LABEL_ROOT>/low/pickle_files. Only *.pkl entries
# in those folders are used.
LABEL_ROOT = '/content/drive/My Drive/msnv_data/VTNet_att/msnv_final_data/Meara_label'
confused_dir = os.path.join(LABEL_ROOT, 'high', 'pickle_files')
not_confused_dir = os.path.join(LABEL_ROOT, 'low', 'pickle_files')


def list_pickle_files(directory):
    """Return the sorted names of the `.pkl` entries directly inside `directory`."""
    return sorted(name for name in os.listdir(directory) if name.endswith('.pkl'))


confused_items = list_pickle_files(confused_dir)
not_confused_items = list_pickle_files(not_confused_dir)

print("Number of confused items:", len(confused_items))
print("Number of not_confused items:", len(not_confused_items))

# Fail fast on a wrong path: an empty class would otherwise flow silently into
# the CV splits as empty per-fold lists.
assert confused_items, f"No .pkl files found in {confused_dir}"
assert not_confused_items, f"No .pkl files found in {not_confused_dir}"


Number of confused items: 260
Number of not_confused items: 252


In [None]:
#@title Step 4: Create 10-Fold Cross-Validation Splits

# Number of folds for cross-validation
k = 10

# Generate the splits
splits = get_grouped_splits(confused_items, not_confused_items, k)

(train_confused_splits, test_confused_splits,
 train_not_confused_splits, test_not_confused_splits) = splits

# Sanity-check every fold before the splits are saved:
#  - each class is exactly partitioned into that fold's train and test lists;
#  - no group id is on both sides. The group id here is the first two
#    underscore-separated fields, the rule get_grouped_splits uses by default.
# Per-fold sizes are printed instead of the full train lists, which are too
# long to read.
assert len(train_confused_splits) == k, f"expected {k} folds, got {len(train_confused_splits)}"
for fold, (tr_c, te_c, tr_n, te_n) in enumerate(zip(*splits), start=1):
    assert sorted(tr_c + te_c) == sorted(confused_items), f"fold {fold}: confused files not partitioned"
    assert sorted(tr_n + te_n) == sorted(not_confused_items), f"fold {fold}: not_confused files not partitioned"
    train_groups = {'_'.join(name.split('_')[:2]) for name in tr_c + tr_n}
    test_groups = {'_'.join(name.split('_')[:2]) for name in te_c + te_n}
    leaked = train_groups & test_groups
    assert not leaked, f"fold {fold}: groups in both train and test: {sorted(leaked)}"
    print(f"Fold {fold}: confused train/test = {len(tr_c)}/{len(te_c)}, "
          f"not_confused train/test = {len(tr_n)}/{len(te_n)}")

# Spot-check the held-out files of the first fold.
print("Fold 1 - Confused Testing Files:", test_confused_splits[0])
print("Fold 1 - Not Confused Testing Files:", test_not_confused_splits[0])


Fold 1 - Confused Training Files: ['bar_10_1.pkl', 'bar_10_2.pkl', 'bar_10_3.pkl', 'bar_10_4.pkl', 'bar_13_1.pkl', 'bar_13_2.pkl', 'bar_13_3.pkl', 'bar_13_4.pkl', 'bar_14_1.pkl', 'bar_14_2.pkl', 'bar_14_3.pkl', 'bar_14_4.pkl', 'bar_15_1.pkl', 'bar_15_2.pkl', 'bar_15_3.pkl', 'bar_15_4.pkl', 'bar_18_1.pkl', 'bar_18_2.pkl', 'bar_18_3.pkl', 'bar_18_4.pkl', 'bar_1_1.pkl', 'bar_1_2.pkl', 'bar_1_3.pkl', 'bar_1_4.pkl', 'bar_20_1.pkl', 'bar_20_2.pkl', 'bar_20_3.pkl', 'bar_20_4.pkl', 'bar_21_1.pkl', 'bar_21_2.pkl', 'bar_21_3.pkl', 'bar_21_4.pkl', 'bar_22_1.pkl', 'bar_22_2.pkl', 'bar_22_3.pkl', 'bar_22_4.pkl', 'bar_24_1.pkl', 'bar_24_2.pkl', 'bar_24_3.pkl', 'bar_24_4.pkl', 'bar_25_1.pkl', 'bar_25_2.pkl', 'bar_25_3.pkl', 'bar_25_4.pkl', 'bar_29_1.pkl', 'bar_29_2.pkl', 'bar_29_3.pkl', 'bar_29_4.pkl', 'bar_2_1.pkl', 'bar_2_2.pkl', 'bar_2_3.pkl', 'bar_2_4.pkl', 'bar_30_1.pkl', 'bar_30_2.pkl', 'bar_30_3.pkl', 'bar_30_4.pkl', 'bar_35_1.pkl', 'bar_35_2.pkl', 'bar_35_3.pkl', 'bar_35_4.pkl', 'bar_37_1.pkl

In [None]:
#@title Step 5: Save the Splits to a Pickle File

# Bundle the per-fold lists in a fixed order. Loaders must unpack them as
# [train_confused, test_confused, train_not_confused, test_not_confused],
# where each element holds one list of file names per fold.
cv_splits = [train_confused_splits, test_confused_splits, train_not_confused_splits, test_not_confused_splits]

# Define an output path for the pickle file.
output_pickle_path = '/content/drive/My Drive/msnv_data/VTNet_att/meara.pickle'

with open(output_pickle_path, 'wb') as f:
    pickle.dump(cv_splits, f)

# Report the fold count actually saved instead of a hard-coded "10", so the
# message stays correct if k is changed in the previous cell.
n_folds = len(train_confused_splits)
print(f"Saved {n_folds}-fold CV splits to {output_pickle_path}")


Saved 10-fold CV splits to /content/drive/My Drive/msnv_data/VTNet_att/meara.pickle
