In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="PIL")

# Label table: one row per class, pairing the class name with the exact RGB
# color that class is painted with in the mask images. Keeping the name and
# the color on the same row stops the two derived lists below from drifting
# out of sync when classes are added or reordered.
_CLASS_TABLE = (
    ("road", (128, 64, 128)),
    ("sidewalk", (244, 35, 232)),
    ("building", (70, 70, 70)),
    ("wall", (102, 102, 156)),
    ("fence", (190, 153, 153)),
    ("pole", (153, 153, 153)),
    ("traffic light", (250, 170, 30)),
    ("traffic sign", (220, 220, 0)),
    ("vegetation", (107, 142, 35)),
    ("terrain", (152, 251, 152)),
    ("sky", (70, 130, 180)),
    ("person", (220, 20, 60)),
    ("rider", (255, 0, 0)),
    ("car", (0, 0, 142)),
    ("truck", (0, 0, 70)),
    ("bus", (0, 60, 100)),
    ("train", (0, 80, 100)),
    ("motorcycle", (0, 0, 230)),
    ("bicycle", (119, 11, 32)),
)

# Parallel lists indexed by class id: class_names[i] is drawn as palette[i].
class_names = [name for name, _rgb in _CLASS_TABLE]
palette = [rgb for _name, rgb in _CLASS_TABLE]

# Directory scanned for the color-coded mask images.
mask_dir = "/masks"


def compute_class_pixel_counts(mask_dir, palette, extensions=(".png", ".jpg", ".jpeg")):
    """Count, per palette entry, how many mask pixels have exactly that color.

    Every file directly inside ``mask_dir`` whose name ends with one of
    ``extensions`` (compared case-insensitively) is opened, converted to RGB
    and compared against each palette color. Pixels whose color is not in
    ``palette`` are not counted anywhere.

    Args:
        mask_dir: Directory holding the color-coded mask images. Files are
            visited in sorted name order; subdirectories are not searched.
        palette: Sequence of ``(R, G, B)`` colors. Entry ``i`` of the result
            is the pixel count for ``palette[i]``.
        extensions: File-name suffixes treated as masks. A single string is
            accepted as a one-element tuple. Defaults to ``.png``/``.jpg``/``.jpeg``.

    Returns:
        ``np.ndarray`` of dtype ``int64`` with ``len(palette)`` entries.
    """
    if isinstance(extensions, str):
        # Avoid iterating a bare string character by character.
        extensions = (extensions,)
    # Normalise suffixes once so the per-file check is a single endswith call.
    suffixes = tuple(ext.lower() for ext in extensions)

    counts = np.zeros(len(palette), dtype=np.int64)
    for fname in tqdm(sorted(os.listdir(mask_dir)), desc="Counting masks"):
        if not fname.lower().endswith(suffixes):
            continue
        # NOTE: JPEG is lossy; compression artefacts shift colors, so such
        # pixels will usually not match any palette entry exactly.
        # The context manager releases the file handle deterministically
        # instead of leaving it to PIL's lazy-loading / garbage collection.
        with Image.open(os.path.join(mask_dir, fname)) as img:
            mask = np.array(img.convert("RGB"))
        for idx, color in enumerate(palette):
            # Boolean (H, W) map of pixels equal to this color on all 3 channels.
            counts[idx] += np.all(mask == color, axis=-1).sum()
    return counts


counts = compute_class_pixel_counts(mask_dir, palette)
total_pixels = counts.sum()
if total_pixels == 0:
    # With no palette pixels every frequency would be 0/0 = NaN and every
    # downstream weight NaN as well; fail loudly instead of printing garbage.
    raise ValueError(f"No pixels matching the palette were found in {mask_dir!r}")
freq = counts / total_pixels

plt.figure(figsize=(12, 6))
plt.bar(class_names, freq)
plt.xticks(rotation=90)
plt.ylabel("Fraction of pixels")
plt.title("Class Pixel Distribution")
plt.tight_layout()
plt.show()

# Median-frequency balancing: weight_c = median(freq) / freq_c, where the
# median is taken only over classes that actually occur.
present = freq > 0
median = np.median(freq[present])
# A class absent from every mask would give median / 0 = inf. One inf makes
# the normalisation below inf/inf = NaN for that class and 0 for every other
# class, wiping out all weights. Absent classes therefore start at 0 here; the
# floor below raises them to 0.01.
w = np.zeros_like(freq)
w[present] = median / freq[present]
print("Median-frequency weights:", w)

# np.maximum imposes a lower bound (floor) of 0.01 on each weight; it does
# not limit large weights despite the "Capped" wording of the message.
w = np.maximum(w, 0.01)
w = w / w.sum()
print("Capped & normalized weights:", w)

# Square root shrinks the ratio between large (rare-class) and small
# (common-class) weights before renormalising to sum to 1.
w = np.sqrt(w)
w = w / w.sum()
print("Sqrt-scaled & normalized weights:", w)

COLOR_MAP = {tuple(color): idx for idx, color in enumerate(palette)}
# Key the weights by class index directly: for a palette of distinct colors
# COLOR_MAP[tuple(palette[i])] is just i, and with duplicate colors that
# lookup would collapse several classes onto one key.
# NOTE: rounding to 2 decimals is coarse for weights that sum to 1 over
# len(palette) classes; very small weights may round to 0.0.
CLASS_WEIGHTS = {idx: round(float(weight), 2) for idx, weight in enumerate(w)}

print("COLOR_MAP =", COLOR_MAP)
print("FINAL CLASS_WEIGHTS =", CLASS_WEIGHTS)
