In [None]:
import pandas as pd
import mysql.connector
import numpy as np
from collections import defaultdict

# Function to create MySQL connection
def create_mysql_connection():
    """
    Open a connection to the MySQL database that holds the audio-feature tables.

    Each connection setting can be overridden through an environment variable so
    that credentials do not have to live in the notebook. When a variable is not
    set, the value that was previously hard-coded is used, so existing behaviour
    is unchanged.

    Environment variables (all optional):
    - AUDIO_FEATURES_DB_HOST
    - AUDIO_FEATURES_DB_USER
    - AUDIO_FEATURES_DB_PASSWORD
    - AUDIO_FEATURES_DB_NAME

    Returns:
    - The connection object returned by mysql.connector.connect(). The caller is
      responsible for closing it.
    """
    # Local import keeps this block self-contained; `os` is part of the standard library.
    import os

    # NOTE(review): the literal fallbacks below are credentials embedded in the
    # notebook; prefer setting the environment variables and removing them.
    return mysql.connector.connect(
        host=os.environ.get('AUDIO_FEATURES_DB_HOST', 'localhost'),
        user=os.environ.get('AUDIO_FEATURES_DB_USER', '************'),
        password=os.environ.get('AUDIO_FEATURES_DB_PASSWORD', '**************'),
        database=os.environ.get('AUDIO_FEATURES_DB_NAME', 'Audio_features')
    )

# Function to load MySQL table into a DataFrame
def load_mysql_table_to_df(table_name, conn):
    """
    Read every row of a MySQL table into a pandas DataFrame.

    Args:
    - table_name (str): Name of the table to read. It is interpolated directly
      into the SQL text (identifiers cannot be bound as query parameters), so it
      must come from a trusted source.
    - conn: Open database connection whose cursor() accepts dictionary=True.

    Returns:
    - pd.DataFrame: One row per table record, one column per table column. For an
      empty table the frame has zero rows but still carries the column names when
      the cursor exposes them, so downstream column lookups do not fail.
    """
    cursor = conn.cursor(dictionary=True)
    try:
        query = f"SELECT * FROM {table_name}"
        cursor.execute(query)
        rows = cursor.fetchall()
        if rows:
            return pd.DataFrame(rows)
        # A list of zero dicts carries no column information; recover the column
        # names from the cursor (if it provides them) before it is closed.
        column_names = getattr(cursor, 'column_names', None) or ()
        return pd.DataFrame(columns=list(column_names))
    finally:
        # Always release the cursor, including when execute()/fetchall() raises.
        cursor.close()

# Improved function for more balanced folds
def assign_folds_balanced(df, group_id_col, n_splits=5):
    """
    Assign a cross-validation fold to every record so that folds hold a similar
    number of records while all records of one klaatch_id share a single fold.

    The klaatch_id is the part of the group id before the first underscore.
    klaatch_ids are processed from largest to smallest and each one is placed in
    the fold that currently holds the fewest records (greedy balancing). Ties
    between equally sized klaatch_ids are broken by klaatch_id, and ties between
    equally filled folds by the lowest fold number, so the result does not depend
    on the order of the rows in `df`.

    Args:
    - df (pd.DataFrame): DataFrame with data. It is not modified.
    - group_id_col (str): Column name containing group_id (message_id); values
      must be strings.
    - n_splits (int): Number of folds for cross-validation (must be >= 1).

    Returns:
    - pd.DataFrame: Copy of `df` with additional 'klaatch_id' and 'fold' columns.
      Rows whose group id is missing get a missing klaatch_id and a missing fold.

    Raises:
    - ValueError: If n_splits is smaller than 1.
    """
    if n_splits < 1:
        raise ValueError(f"n_splits must be a positive integer, got {n_splits!r}")

    # Work on a copy so the caller's DataFrame is left untouched.
    result = df.copy()

    # Extract klaatch_id from group_id
    result['klaatch_id'] = result[group_id_col].str.split('_').str[0]

    # Count number of records per klaatch_id (missing ids are not counted)
    klaatch_id_counts = result['klaatch_id'].value_counts()

    # Largest groups first; the secondary key makes the order of equally sized
    # groups explicit instead of depending on the input row order.
    ordered_groups = sorted(
        klaatch_id_counts.items(),
        key=lambda item: (-item[1], item[0]),
    )

    # Initialize fold assignments
    fold_sizes = {i: 0 for i in range(n_splits)}  # Tracks total records per fold
    fold_assignments = {}

    # Greedy assignment: Assign each klaatch_id to the fold with the least total records
    for klaatch_id, count in ordered_groups:
        best_fold = min(fold_sizes, key=fold_sizes.get)
        fold_assignments[klaatch_id] = best_fold
        fold_sizes[best_fold] += int(count)  # Update fold size

    # Map fold assignments to the DataFrame
    result['fold'] = result['klaatch_id'].map(fold_assignments)

    return result

# Function to update SQL table with new columns and values
def update_sql_table_with_folds(df, table_name, klaatch_id_col, fold_col, conn):
    """
    Write the klaatch_id and fold of every record in `df` back to a MySQL table.

    The two target columns are added to the table when they do not exist yet
    (VARCHAR(255) for the klaatch_id, INT for the fold). Each row of the table is
    then updated by matching its message_id column against df['message_id'].

    Args:
    - df (pd.DataFrame): Must contain the columns 'klaatch_id', 'fold' and
      'message_id' (unless it is empty, in which case no rows are updated).
    - table_name (str): Table to update. Interpolated into the SQL text, so it
      must come from a trusted source.
    - klaatch_id_col (str): Name of the SQL column receiving the klaatch_id.
      Interpolated into the SQL text; must be trusted.
    - fold_col (str): Name of the SQL column receiving the fold number.
      Interpolated into the SQL text; must be trusted.
    - conn: Open database connection providing cursor(), commit() and rollback().

    Returns:
    - None. The row updates are committed on success. On any error the
      transaction is rolled back and the exception is re-raised.
      Note: in MySQL an ALTER TABLE commits implicitly, so a rollback only
      undoes the row updates, not a column that was just added.
    """
    def _to_native(value):
        # Missing values become SQL NULL; NumPy scalars become plain Python
        # values so the database driver does not have to convert them.
        if pd.isna(value):
            return None
        return value.item() if hasattr(value, 'item') else value

    # Build all parameters first so a missing column fails before the database
    # is touched.
    if df.empty:
        params = []
    else:
        params = [
            (_to_native(klaatch_id), _to_native(fold), _to_native(message_id))
            for klaatch_id, fold, message_id in zip(
                df['klaatch_id'], df['fold'], df['message_id']
            )
        ]

    cursor = conn.cursor()
    try:
        # Check if columns exist and add them if they don't
        for column, sql_type in ((klaatch_id_col, 'VARCHAR(255)'), (fold_col, 'INT')):
            cursor.execute(f"SHOW COLUMNS FROM {table_name} LIKE '{column}'")
            # fetchall() consumes the whole result set, so no unread rows are
            # left pending on the cursor before the next statement.
            if not cursor.fetchall():
                cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {column} {sql_type}")

        # Update klaatch_id and fold values
        update_query = f"""
        UPDATE {table_name}
        SET {klaatch_id_col} = %s, {fold_col} = %s
        WHERE message_id = %s
    """

        if params:
            cursor.executemany(update_query, params)

        conn.commit()
    except Exception:
        # Do not leave a half-applied batch of updates in an open transaction.
        conn.rollback()
        raise
    finally:
        cursor.close()

# Define parameters
# table_name: MySQL table that is loaded below and passed to the update step later.
# klaatch_id_col / fold_col: names of the SQL columns passed to the update step.
table_name = 'stratified_female'
klaatch_id_col = 'klaatch_id'
fold_col = 'fold'

# Create MySQL connection (closed in the final cell of the notebook)
conn = create_mysql_connection()

# Load the table into a DataFrame
df = load_mysql_table_to_df(table_name, conn)

# Assign folds with better balance; the group id is taken from the 'message_id' column
df_with_folds = assign_folds_balanced(df, group_id_col='message_id', n_splits=5)




In [3]:
def validate_folds(df, klaatch_id_col='klaatch_id', fold_col='fold'):
    """
    Check that no klaatch_id is spread over more than one fold.

    A short report is printed either way: a confirmation when every klaatch_id
    maps to a single fold, otherwise a warning followed by the offenders.

    Args:
    - df (pd.DataFrame): Frame holding one klaatch_id and one fold per record.
    - klaatch_id_col (str): Name of the column with the klaatch_id.
    - fold_col (str): Name of the column with the fold assignment.

    Returns:
    - dict: Maps each offending klaatch_id to the number of distinct folds it
      occurs in; empty when the assignment is valid.
    """
    # Number of distinct folds seen for every klaatch_id
    folds_per_id = df.groupby(klaatch_id_col)[fold_col].nunique()

    # Keep only the ids that occur in more than one fold
    offenders = folds_per_id.loc[folds_per_id.gt(1)].to_dict()

    if not offenders:
        print("Validation passed: Each klaatch_id is assigned to only one fold.")
        return offenders

    print("Warning: The following klaatch_id(s) are assigned to multiple folds:")
    print(offenders)
    return offenders

# Run the validation function on df_with_folds.
# The returned dict lists klaatch_ids found in more than one fold (empty if none).
invalid_klaatch_ids = validate_folds(df_with_folds)


Validation passed: Each klaatch_id is assigned to only one fold.


In [3]:
# Show each record's fold next to its message_id and the klaatch_id derived from it
df_with_folds[['fold','message_id','klaatch_id']]


Unnamed: 0,fold,message_id,klaatch_id
0,3,559_2021-01-22,559
1,1,66_2021-01-26,66
2,0,340_2021-01-26,340
3,0,343_2021-01-26,343
4,0,383_2021-01-27,383
...,...,...,...
1463,1,941934_2024-11-07,941934
1464,4,952421_2024-11-19,952421
1465,3,981476_2024-11-01,981476
1466,2,983777_2024-11-22,983777


In [16]:
# Number of records per fold, to check how evenly the folds are filled
fold_counts = df_with_folds['fold'].value_counts()
fold_counts

fold
1    76
3    76
2    76
0    76
4    75
Name: count, dtype: int64

In [4]:
# Spot check: select all records whose derived klaatch_id is '747' and show their folds.
# NOTE(review): 'KlaatchID' is presumably a column of the loaded table — confirm.
kk =df_with_folds[df_with_folds['klaatch_id'] == '747']
kk[['KlaatchID','fold' ]]

Unnamed: 0,KlaatchID,fold
571,747,2
981,747,2
1107,747,2
1413,747,2


In [17]:
# Update SQL table with klaatch_id and fold assignments
try:
    update_sql_table_with_folds(df_with_folds, table_name, klaatch_id_col, fold_col, conn)
finally:
    # Close the connection even when the update raises, so it is never leaked
    conn.close()