In [None]:
""" Balance of the dataset | This code will downsample the dataset to easily-tune 
a max ratio of 1:3. The minimum number of samples per class is also an easily tunable parameter.

This code balances the active/inactive ratio of each target separately, NOT just
that of the whole set.
"""

## modules 
from sklearn.utils import resample

## unique targets (PDB_ID values) present in the original dataframe
targets = df['PDB_ID'].unique()

## hyperparameters
min_examples = 100   # Minimum number of examples required in each class of a target
max_ratio = 3  # Maximum allowed majority:minority ratio within a target

balanced_data = []

## loop for each target
for target in targets:
    target_data = df[df['PDB_ID'] == target]

    # Calculate the counts of each class
    class_counts = target_data['label'].value_counts()

    # only consider targets that have at least 'min_examples' examples in both classes
    if len(class_counts) < 2 or (class_counts < min_examples).any():
        continue

    # determine the minority and majority classes
    minority_class = class_counts.idxmin()
    majority_class = class_counts.idxmax()

    # Separate the minority and majority classes
    minority_data = target_data[target_data['label'] == minority_class]
    majority_data = target_data[target_data['label'] == majority_class]

    minority_count = minority_data.shape[0]
    majority_count = majority_data.shape[0]

    # Largest majority size that still respects max_ratio. int() keeps
    # n_samples an integer even when max_ratio is tuned to a fractional value.
    desired_majority_count = int(minority_count * max_ratio)

    # Targets that exceed max_ratio are downsampled and kept (not skipped);
    # targets already within the ratio are kept unchanged.
    if majority_count > desired_majority_count:
        # Undersample the majority class (without replacement) to the desired ratio
        undersampled_majority = resample(
            majority_data,
            replace=False,
            n_samples=desired_majority_count,
            random_state=42  # fixed seed so the selected rows are reproducible
        )

        balanced_target_data = pd.concat([minority_data, undersampled_majority])
    else:
        balanced_target_data = target_data

    # Add the balanced data for this target to the list
    balanced_data.append(balanced_target_data)

# Combine the balanced data for all targets into a single DataFrame.
# pd.concat raises ValueError on an empty list, so when no target passed the
# filters fall back to an empty frame that keeps df's columns.
balanced_df = pd.concat(balanced_data) if balanced_data else df.iloc[0:0]

# rename dataset to original df
df = balanced_df ## basically I reused code already named df