# CelebA Dataset Analysis Notebook

This notebook documents the full analysis workflow for CelebA: loading annotations, computing attribute prevalence and correlations, generating sample grids, and creating overlays. It mirrors the logic in `src/celeba_analysis.py` for reproducibility.

In [None]:
import os, json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import cv2
sns.set_theme(style='whitegrid')

# Resolve the project root: the parent directory when the notebook is run
# from a folder named "notebooks", otherwise the current working directory.
cwd = os.getcwd()
if os.path.basename(cwd) == 'notebooks':
    proj_root = os.path.abspath(os.path.join(cwd, '..'))
else:
    proj_root = cwd

# Input and output locations, all derived from the project root.
data_root = os.path.join(proj_root, 'data', 'archive')
img_root = os.path.join(data_root, 'img_align_celeba')
out_root = os.path.join(proj_root, 'outputs', 'celeba_analysis', 'figures')
os.makedirs(out_root, exist_ok=True)

# Annotation tables; each one is joined on its `image_id` column below.
attrs = pd.read_csv(os.path.join(data_root, 'list_attr_celeba.csv'))
parts = pd.read_csv(os.path.join(data_root, 'list_eval_partition.csv'))
lms = pd.read_csv(os.path.join(data_root, 'list_landmarks_align_celeba.csv'))
bbox = pd.read_csv(os.path.join(data_root, 'list_bbox_celeba.csv'))

# Left-join everything onto the attribute table so every attribute row is kept.
df = attrs
for extra_table in (parts, lms, bbox):
    df = df.merge(extra_table, on='image_id', how='left')

# Every attribute column becomes 1 where the raw value equals 1, else 0.
attr_cols = [name for name in attrs.columns if name != 'image_id']
df[attr_cols] = df[attr_cols].eq(1).astype(int)

n_images = len(df)
print('Images:', n_images)

In [None]:
# Prevalence (Top 12)
# Mean of each attribute column, sorted from most to least frequent;
# only the 12 highest values are plotted.
prev = df[attr_cols].mean().sort_values(ascending=False)
top12 = prev.head(12)

plt.figure(figsize=(12, 6))
ax = sns.barplot(x=top12.index, y=top12.values, color='#4C78A8')
plt.xticks(rotation=45, ha='right')
ax.set_ylabel('Prevalence')
ax.set_title('CelebA Attribute Prevalence (Top 12)')
plt.tight_layout()

prevalence_png = os.path.join(out_root, 'attr_prevalence_top12.png')
plt.savefig(prevalence_png, dpi=150)
plt.show()

In [None]:
# Correlation heatmap
# Pairwise correlation between all attribute columns (pandas' default method),
# drawn on a fixed [-1, 1] colour scale.
corr = df[attr_cols].corr()

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(corr, cmap='coolwarm', vmin=-1, vmax=1, ax=ax)
ax.set_title('CelebA Attribute Correlation')
fig.tight_layout()

correlation_png = os.path.join(out_root, 'attr_correlation.png')
fig.savefig(correlation_png, dpi=150)
plt.show()

## Sample Grids and Overlays
The following cells generate grids of positive samples (images for which a selected attribute equals 1) and an example grid of images with bounding-box and landmark overlays.

In [None]:
def make_grid(df, attr, n=16, fname='grid.png', thumb_size=128):
    """Save and show a grid of up to ``n`` images whose ``attr`` column equals 1.

    Rows are sampled with a fixed seed (``random_state=42``), so repeated calls
    on the same frame pick the same images. Files that are missing under
    ``img_root`` or that OpenCV cannot decode are skipped. The grid layout is
    derived from the number of images actually loaded, so skipped files do not
    leave empty rows or columns; when all ``n`` images load, the layout is
    ``int(sqrt(n))`` rows by ``ceil(n / rows)`` columns.

    Args:
        df: DataFrame with an ``image_id`` column and a column named ``attr``.
        attr: Name of the attribute column to filter on.
        n: Maximum number of images to place in the grid.
        fname: File name of the figure, written inside ``out_root``.
        thumb_size: Side length in pixels of each square thumbnail.

    Returns:
        The path of the saved figure, or ``None`` when there is nothing to
        draw (no matching rows, ``n <= 0``, or no image could be loaded).
    """
    positives = df[df[attr] == 1]
    n_take = min(n, len(positives))
    if n_take <= 0:
        # Nothing to sample; also avoids asking pandas for a negative sample.
        return None
    sampled = positives.sample(n_take, random_state=42)

    thumbs = []
    for image_id in sampled['image_id']:
        path = os.path.join(img_root, image_id)
        if not os.path.exists(path):
            continue
        bgr = cv2.imread(path)
        if bgr is None:
            # File exists but could not be decoded.
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        thumbs.append(cv2.resize(rgb, (thumb_size, thumb_size)))
    if not thumbs:
        return None

    # Size the grid from what was loaded, not from what was requested.
    count = len(thumbs)
    rows = max(1, int(np.sqrt(count)))
    cols = int(np.ceil(count / rows))

    # White background; any trailing cells of the last row stay white.
    canvas = np.full((rows * thumb_size, cols * thumb_size, 3), 255, dtype=np.uint8)
    for idx, thumb in enumerate(thumbs):
        r, c = divmod(idx, cols)
        canvas[r * thumb_size:(r + 1) * thumb_size,
               c * thumb_size:(c + 1) * thumb_size] = thumb

    plt.figure(figsize=(cols * 2.5, rows * 2.5))
    plt.imshow(canvas)
    plt.axis('off')
    plt.title(f'{attr}=1')
    outp = os.path.join(out_root, fname)
    plt.tight_layout()
    plt.savefig(outp, dpi=150)
    plt.show()
    return outp

# One 16-image grid per attribute of interest, each saved under its own file name.
make_grid(df, attr='Smiling', n=16, fname='samples_smiling.png')
make_grid(df, attr='Male', n=16, fname='samples_male.png')
make_grid(df, attr='Young', n=16, fname='samples_young.png')


In [None]:
# Overlay examples (first 12)
# Draws the bounding box (red) and the five facial landmarks (green) on the
# first 12 images and tiles them into a 3x4 grid of 160x160 thumbnails.
# NOTE(review): the bbox columns are drawn directly onto the aligned images;
# confirm that list_bbox_celeba.csv coordinates refer to the aligned crops
# and not to the original, unaligned images.
subset = df.head(2000)
picked = subset.head(12)
rows = 3
cols = 4
h, w = 160, 160

bbox_keys = ('x_1', 'y_1', 'width', 'height')
landmark_names = ('lefteye', 'righteye', 'nose', 'leftmouth', 'rightmouth')
# Errors that mean "this row has no usable coordinate": a missing column,
# or a NaN / infinite / non-numeric value that int() rejects.
bad_coord_errors = (KeyError, ValueError, TypeError, OverflowError)

canvas = np.ones((rows*h, cols*w, 3), dtype=np.uint8)*255
for i, (_, row) in enumerate(picked.iterrows()):
    im = cv2.imread(os.path.join(img_root, row['image_id']))
    if im is None:
        # Unreadable image: its tile simply stays white.
        continue
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    # bbox -- best effort: skip the rectangle when coordinates are unusable.
    try:
        x, y, bw, bh = (int(row[key]) for key in bbox_keys)
    except bad_coord_errors:
        pass
    else:
        cv2.rectangle(im, (x, y), (x + bw, y + bh), (255, 0, 0), 2)

    # landmarks -- all five are parsed first, so either all are drawn or none.
    try:
        pts = [(int(row[f'{name}_x']), int(row[f'{name}_y'])) for name in landmark_names]
    except bad_coord_errors:
        pts = []
    for px, py in pts:
        cv2.circle(im, (px, py), 2, (0, 255, 0), -1)

    # Tile position is fixed by the row's index, independent of skipped images.
    r, c = divmod(i, cols)
    thumb = cv2.resize(im, (w, h))
    canvas[r*h:(r+1)*h, c*w:(c+1)*w] = thumb

plt.figure(figsize=(cols*2.8, rows*2.8))
plt.imshow(canvas)
plt.axis('off')
plt.title('BBoxes & Landmarks (samples)')
plt.tight_layout()
plt.savefig(os.path.join(out_root, 'bbox_landmarks_samples.png'), dpi=150)
plt.show()


## Save Summary JSON
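The final cell writes a compact JSON summary (image count, per-split counts, and the prevalence of the top 12 attributes) to `reports/celeba_summary.json` for programmatic use.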

In [None]:
# Headline numbers: total image count, images per split, and the 12 most
# prevalent attributes. Every value is cast to a built-in int/float so that
# json.dump never receives a NumPy scalar.
split_names = {0: 'train', 1: 'val', 2: 'test'}
split_counts = df['partition'].map(split_names).value_counts()
top12_prevalence = df[attr_cols].mean().sort_values(ascending=False).head(12)
summary = {
    'n_images': int(df.shape[0]),
    'split_counts': {name: int(count) for name, count in split_counts.items()},
    'prevalence_top12': {k: float(v) for k, v in top12_prevalence.items()},
}

# The reports directory is not created anywhere else in this notebook, so
# create it here; otherwise open() fails on a fresh checkout.
reports_dir = os.path.join(proj_root, 'reports')
os.makedirs(reports_dir, exist_ok=True)
with open(os.path.join(reports_dir, 'celeba_summary.json'), 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)
summary