In [1]:
import os
import numpy as np
from PIL import Image
import json
from os.path import join, basename, splitext
from glob import glob

In [None]:
def compute_class_weights(folder, num_classes=9, extensions=('.png',)):
    """
    Compute class weights using median frequency balancing.

    Every file in `folder` whose name ends (case-insensitively) with one of
    `extensions` is opened with PIL and turned into a NumPy array exactly as
    stored. No mode conversion is applied, so the masks must already hold the
    class indices 0 .. num_classes-1. All array elements are counted, which
    means a multi-channel mask contributes every channel value separately.

    The frequency of a class is its pixel count divided by the total number of
    pixels over all readable masks in the folder (dataset-wide, not per image).
    The weight of a class is ``median(frequencies of present classes) / frequency``.

    Args:
        folder (str): The folder containing the segmentation masks.
        num_classes (int): Total number of classes including background.
        extensions (str | tuple[str, ...]): File name suffix(es) of the mask
            files to read. Matching is case-insensitive. Defaults to PNG only.

    Returns:
        dict: A dictionary with keys as class index (string) and values as class
        weights (float). Classes that never occur get weight 0.0. An empty dict
        is returned when no pixel of any class was found.
    """
    # Accept a single suffix as well as a tuple; str.endswith needs a str or tuple.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    # Per-class pixel counts accumulated over all masks.
    counts = np.zeros(num_classes, dtype=np.float64)
    total_pixels = 0

    # Sorted so that any "Could not process" messages appear in a stable order.
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith(suffixes):
            continue
        image_path = os.path.join(folder, filename)
        try:
            with Image.open(image_path) as img:
                arr = np.array(img)
        except Exception as e:
            # Best effort: an unreadable file is reported and skipped.
            print(f"Could not process {filename}: {e}")
            continue

        counts += _count_class_pixels(arr, num_classes)
        # Pixels with values outside the class range still count towards the
        # total; the factor cancels in median_freq / frequency below.
        total_pixels += arr.size

    # Calculate the frequency of each class (proportion of pixels).
    frequencies = counts / total_pixels if total_pixels > 0 else np.zeros_like(counts)

    # The median is taken only over classes that actually occur.
    present = frequencies > 0
    if not present.any():
        print("No class pixels found in the dataset.")
        return {}
    median_freq = np.median(frequencies[present])

    # Absent classes get weight 0.0 instead of a division by zero.
    return {
        str(i): float(median_freq / frequencies[i]) if present[i] else 0.0
        for i in range(num_classes)
    }


def _count_class_pixels(arr, num_classes):
    """Return an array of length `num_classes` with the number of elements of `arr` equal to each class index."""
    flat = np.asarray(arr).ravel()
    if num_classes > 0 and flat.dtype.kind in 'bui':
        # Integer-like masks: a single histogram pass instead of one full
        # comparison per class. Values outside [0, num_classes) are dropped,
        # just as a per-class equality test would never match them.
        in_range = flat[(flat >= 0) & (flat < num_classes)]
        return np.bincount(in_range.astype(np.intp), minlength=num_classes)
    # Other dtypes (e.g. float masks): exact equality test per class.
    return np.array(
        [np.count_nonzero(flat == i) for i in range(num_classes)],
        dtype=np.int64,
    )

In [3]:
# Root of the dataset, given relative to the notebook's working directory.
target_root = '../../data/tum_material'
assert os.path.isdir(target_root)

# Every `<target_root>/<subfolder>/anno` path that exists on disk.
masks_dirs = glob(os.path.join(target_root, '*', 'anno'))
masks_dirs

['../../data/tum_material/Hand-drawn/anno',
 '../../data/tum_material/CAD/anno',
 '../../data/tum_material/BIM/anno']

In [4]:
# Compute the class weights of every `anno` folder and store them as JSON
# in the folder's parent directory.
for folder_path in masks_dirs:
    parent_folder = os.path.dirname(folder_path)
    # Each path ends in 'anno', so basename(folder_path) would print the same
    # label for every dataset; the parent directory name tells them apart.
    folder_name = basename(parent_folder)
    print(f'Processing {folder_name}')
    weights = compute_class_weights(folder_path)

    # Save the weights to a JSON file next to the `anno` directory.
    output_json_path = join(parent_folder, 'class_weights.json')
    with open(output_json_path, 'w') as json_file:
        json.dump(weights, json_file, indent=4)

    print(f'Class weights saved to {output_json_path}')

Processing anno
Class weights saved to ../../data/tum_material/Hand-drawn/class_weights.json
Processing anno
Class weights saved to ../../data/tum_material/CAD/class_weights.json
Processing anno
Class weights saved to ../../data/tum_material/BIM/class_weights.json
