In [8]:
import pandas as pd
from collections import Counter
import random
from collections import Counter
import ast

In [9]:
train_data = pd.read_csv('median/preprocessed data/preprocessed_train_data.csv')

In [10]:
# Flatten the list of labels if it's a list of lists
flattened_labels = []
for label_list in train_data["labels"]:
    # Convert string representation of list to actual list
    if isinstance(label_list, str):
        # Remove brackets and split by comma
        clean_label_list = label_list.strip('[]').replace("'", "").split(', ')
    else:
        clean_label_list = label_list
    
    # Add each label to the flattened list
    flattened_labels.extend(clean_label_list)

# Print labels distribution count
label_counts = Counter(flattened_labels)
print("Label Distribution Count (Training Data):")

for i, (label, count) in enumerate(label_counts.most_common()):
    print(f"{i+1}. {label}: {count}")

Label Distribution Count (Training Data):
1. area/test: 309
2. area/kubelet: 304
3. area/apiserver: 196
4. area/cloudprovider: 171
5. area/kubectl: 133
6. area/dependency: 80
7. area/code-generation: 65
8. area/provider/azure: 64
9. area/ipvs: 39
10. area/kubeadm: 33
11. area/provider/gcp: 32
12. area/api: 28
13. area/e2e-test-framework: 28
14. area/kube-proxy: 28
15. area/release-eng: 28
16. area/conformance: 28
17. area/batch: 28
18. area/deflake: 28
19. area/network-policy: 28
20. area/client-libraries: 28
21. area/code-organization: 28
22. area/security: 28
23. area/etcd: 28
24. area/custom-resources: 28
25. area/provider/aws: 28


In [11]:
# Ensure 'labels' column contains lists of strings
def parse_labels(label_list_str):
    if isinstance(label_list_str, str):
        try:
            # Safely evaluate the string representation of the list
            parsed_list = ast.literal_eval(label_list_str)
            if isinstance(parsed_list, list):
                # Ensure all elements are strings
                return [str(item) for item in parsed_list]
            else:
                return [] # Return empty list if parsing results in non-list
        except (ValueError, SyntaxError):
             # Handle cases like '[]', "['label1', 'label2']" or malformed strings
             clean_label_list_str = label_list_str.strip('[]').replace("'", "").replace('"', '')
             if not clean_label_list_str: # Handle empty string case '[]'
                 return []
             clean_label_list = clean_label_list_str.split(',') # Split by comma first
             # Clean whitespace and filter empty strings resulting from split
             return [label.strip() for label in clean_label_list if label.strip()] 
    elif isinstance(label_list_str, list):
         return [str(item) for item in label_list_str] # Ensure all elements are strings
    else: # Handle other potential types like float (NaN) if there are missing values
        return []

# Apply parsing to create a reliable list representation
train_data['parsed_labels'] = train_data['labels'].apply(parse_labels)

# Recalculate flattened_labels and label_counts using the parsed column
flattened_labels_parsed = [label for sublist in train_data['parsed_labels'] for label in sublist]
label_counts_parsed = Counter(flattened_labels_parsed)

# --- Undersampling Logic ---
random.seed(42) # for reproducibility

max_samples_per_label = 28
oversampled_labels = {label for label, count in label_counts_parsed.items() if count > max_samples_per_label}

indices_to_keep = set()
indices_per_label = {label: [] for label in label_counts_parsed.keys()}

# Group indices by label
for index, row in train_data.iterrows():
    for label in row['parsed_labels']:
        if label in indices_per_label: # Ensure label exists in our count keys
             indices_per_label[label].append(index)

# Undersample for over-represented labels and collect indices
for label, indices in indices_per_label.items():
    unique_indices = list(set(indices)) # Ensure unique indices per label before sampling
    if label in oversampled_labels:
        if len(unique_indices) > max_samples_per_label:
             indices_to_keep.update(random.sample(unique_indices, max_samples_per_label))
        else:
             # This case means the count was > 28 but unique rows are <= 28 (due to multi-label rows)
             # Keep all unique rows containing this label
             indices_to_keep.update(unique_indices)
    else:
        # Keep all unique indices for labels already at or below the threshold
        indices_to_keep.update(unique_indices)


# Create the balanced dataframe
balanced_train_data = train_data.loc[sorted(list(indices_to_keep))].copy()

# Optional: Drop the temporary parsed_labels column if not needed later
# balanced_train_data = balanced_train_data.drop(columns=['parsed_labels'])

print(f"Original training data size: {len(train_data)}")
print(f"Balanced training data size: {len(balanced_train_data)}")


Original training data size: 1268
Balanced training data size: 609


In [None]:
# Verify Balanced Distribution
# Ensure 'parsed_labels' column exists if it was dropped, otherwise re-parse or use 'labels'
if 'parsed_labels' not in balanced_train_data.columns:
     # Re-apply parsing if the column was dropped or doesn't exist
     balanced_train_data['parsed_labels'] = balanced_train_data['labels'].apply(parse_labels)

flattened_labels_balanced = [label for sublist in balanced_train_data['parsed_labels'] for label in sublist]
label_counts_balanced = Counter(flattened_labels_balanced)

print("\nLabel Distribution Count (Balanced Training Data):")
# Sort by count descending for clarity
for i, (label, count) in enumerate(label_counts_balanced.most_common()):
    print(f"{i+1}. {label}: {count}")


Label Distribution Count (Balanced Training Data):
1. area/test: 96
2. area/cloudprovider: 89
3. area/apiserver: 76
4. area/kubelet: 64
5. area/dependency: 54
6. area/kubectl: 54
7. area/code-generation: 47
8. area/ipvs: 33
9. area/provider/azure: 32
10. area/kubeadm: 29
11. area/api: 28
12. area/e2e-test-framework: 28
13. area/kube-proxy: 28
14. area/release-eng: 28
15. area/provider/gcp: 28
16. area/conformance: 28
17. area/batch: 28
18. area/deflake: 28
19. area/network-policy: 28
20. area/client-libraries: 28
21. area/code-organization: 28
22. area/security: 28
23. area/etcd: 28
24. area/custom-resources: 28
25. area/provider/aws: 28

All labels have <= 28 samples: False


In [13]:
# export the balanced data
balanced_train_data.to_csv('median/preprocessed data/balanced_train_data.csv', index=False)