# Comparison Smile Scores Temperature

This notebook compares the predictions, including confidences, of the smile classifier with and without temperature scaling. The goal is to decide whether temperature scaling is necessary for the smile classifier.

## Setup

In [None]:
from src.classification.smile_classifier import SmileClassifier

In [None]:
# Configuration
# Both checkpoints and the attribute file live in the same model directory.
CLASSIFIER_DIR = "../models/classifier/celeba_smile"
# Checkpoint loaded with scaled=False (the plain, unscaled classifier).
PRETRAINED_CLASSIFIER_PATH = f"{CLASSIFIER_DIR}/predictor_128.pth.tar"
# Checkpoint loaded with scaled=True (presumably the temperature-scaled model).
SCALED_CLASSIFIER_PATH = f"{CLASSIFIER_DIR}/predictor_128_scaled3.pth.tar"
ATTR_FILE = f"{CLASSIFIER_DIR}/attributes.json"

### Load smile classifiers

In [None]:
# Build the two classifiers to compare. They share the attribute file and the
# CPU device and differ only in the checkpoint and the `scaled` flag
# (presumably toggling temperature scaling).
smile_classifier = SmileClassifier(
    PRETRAINED_CLASSIFIER_PATH,
    ATTR_FILE,
    device="cpu",
    scaled=False,
)
smile_classifier_scaled = SmileClassifier(
    SCALED_CLASSIFIER_PATH,
    ATTR_FILE,
    device="cpu",
    scaled=True,
)

### Load FFHQ dataset

In [None]:
from argparse import Namespace
from torchvision import transforms
from src.dataloader.ffhq import FFHQDataset

# Load FFHQ dataset.
# NOTE(review): `ffhq_dataset` is not referenced by any later cell in this
# notebook; the comparisons below use the saved eval batch and the JSON score
# files instead.
ffhq_dataset = FFHQDataset(
    args=Namespace(
        img_dir="../data/ffhq/images1024x1024",
        attr_path="../data/ffhq/ffhq_smile_scores.json",
        max_property_value=5,
        min_property_value=0,
        batch_size=16,
        num_workers=0,
        val_split=0,
    ),
    # Resize to 224x224, convert to a [0, 1] tensor, then Normalize with
    # mean=std=0.5 to map pixel values to [-1, 1].
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]),
)

### Load eval batch

In [None]:
import torch

# Load the fixed evaluation batch onto the CPU, where both classifiers run
# (device="cpu"). Without map_location, a batch saved from a GPU session is
# restored as CUDA tensors, or fails to load on a CPU-only machine.
# torch.load unpickles the file, so only load files from a trusted source.
eval_batch = torch.load("../data/ffhq/eval/batch_256.pt", map_location="cpu")

## Visual Comparison on eval batch

In [None]:
# Score the eval batch with both classifiers. This is inference only, so
# autograd tracking is disabled to avoid building a graph for the whole batch.
with torch.no_grad():
    # Predict smile scores using unscaled classifier
    unscaled_predictions = smile_classifier(eval_batch)

    # Predict smile scores using scaled classifier
    scaled_predictions = smile_classifier_scaled(eval_batch)

In [None]:
import matplotlib.pyplot as plt

# Visualize predictions: show the first images of the eval batch with the
# unscaled and scaled scores in each title.
N_ROWS, N_COLS = 2, 6
n_shown = min(N_ROWS * N_COLS, len(eval_batch))  # batch may hold fewer images than the grid

fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(10, 5))
axes = axes.flatten()

for i, ax in enumerate(axes):
    ax.axis('off')
    if i >= n_shown:
        continue  # leave surplus grid cells blank
    # `* 0.5 + 0.5` undoes a mean=std=0.5 normalisation (assumed to match how the
    # batch was saved). Clamp so imshow receives valid [0, 1] floats.
    # detach/cpu make .numpy() safe for grad-tracking or GPU tensors.
    image = (eval_batch[i].detach().cpu() * 0.5 + 0.5).clamp(0, 1)
    ax.imshow(image.permute(1, 2, 0).numpy())
    ax.set_title(f"Unscaled: {unscaled_predictions[i].item():.2f}\nScaled: {scaled_predictions[i].item():.2f}")

plt.tight_layout()
plt.show()

## Histogram of predictions

To plot the following histograms and examples, the files `data/ffhq/smile_scores.json` and `data/ffhq/smile_scores_scaled.json` are required (these are the paths the code below opens). These files contain the predictions of the unscaled and the scaled classifier, respectively, for the FFHQ dataset.

If they are not available, they can be generated by running `src/run/initial_smile_classification.py` once with the unscaled and once with the scaled classifier.

In [None]:
# Load smile scores from JSON files
import json

def _load_score_values(path):
    """Return the values of the JSON object stored at `path`, in file order.

    JSON is read explicitly as UTF-8 so the result does not depend on the
    platform's default text encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return list(json.load(f).values())


# NOTE(review): the FFHQDataset cell reads ../data/ffhq/ffhq_smile_scores.json;
# confirm these are the intended score files.
unscaled_scores = _load_score_values("../data/ffhq/smile_scores.json")
scaled_scores = _load_score_values("../data/ffhq/smile_scores_scaled.json")

In [None]:
import matplotlib.pyplot as plt

# Plot histogram of predictions.
# Both histograms use the same value range so that they also share bin edges.
# Otherwise each plt.hist call picks its own 150 bins from its own min/max,
# and the overlaid bars are not directly comparable.
n_bins = 150
all_scores = [*unscaled_scores, *scaled_scores]
shared_range = (min(all_scores), max(all_scores))

plt.figure(figsize=(12, 6))
plt.hist(unscaled_scores, bins=n_bins, range=shared_range, color='blue', alpha=0.7, label='Unscaled')
plt.hist(scaled_scores, bins=n_bins, range=shared_range, color='orange', alpha=0.7, label='Scaled')
plt.title('Comparison of Unscaled and Temperature Scaled Smile Scores')
plt.xlabel('Smile Score')
plt.ylabel('Frequency')
plt.legend()
plt.show()

## Unscaled Histogram

In [None]:
import matplotlib.pyplot as plt

# Plot the unscaled score distribution. Bins lying entirely below INPUT_MAX
# are drawn blue; all other bins stay gray.
INPUT_MAX = 2  # threshold drawn as the red "Input Max" line

plt.figure(figsize=(12, 6))

# Draw every bar gray first; the bin edges and bar patches are kept for recoloring.
_, bin_edges, patches = plt.hist(
    unscaled_scores,
    bins=150,
    color='gray',
    alpha=0.7,
)

# Histogram bins are half-open [left, right), so a bin whose right edge is
# <= INPUT_MAX contains only scores below the threshold.
for right_edge, patch in zip(bin_edges[1:], patches):
    if right_edge <= INPUT_MAX:
        patch.set_facecolor('blue')

plt.axvline(x=INPUT_MAX, color='red', linestyle='--', label='Input Max')
plt.title('Smile Scores Distribution on FFHQ Dataset')
plt.xlabel('Smile Score')
plt.ylabel('Frequency')
plt.legend()
plt.show()

## Unscaled examples

Plots one example image for each rounded smile score from 0 to 5. The images are taken from the FFHQ dataset, and the smile scores are the predictions of the unscaled classifier. For each score, the first image whose prediction rounds to that score is shown.

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

# Quantize smile scores to the nearest integer.
# Note: round() rounds halves to even, e.g. round(2.5) == 2.
quantized_scores = [round(score) for score in unscaled_scores]

# Position of the first image for each quantized score, found in a single pass.
# NOTE(review): the list position is used as the image id below. This assumes
# the score file lists one entry per image in file order (00000, 00001, ...)
# with no gaps. TODO: confirm.
first_index_by_score = {}
for position, quantized in enumerate(quantized_scores):
    first_index_by_score.setdefault(quantized, position)

# One panel per smile score 0..5.
fig, axes = plt.subplots(1, 6, figsize=(15, 5))
for target_score, ax in enumerate(axes):
    ax.axis('off')
    idx = first_index_by_score.get(target_score)
    if idx is None:
        # No prediction rounds to this score: leave the panel empty instead of
        # raising ValueError as list.index() would.
        ax.set_title(f"{target_score}\n(no image)", fontsize=16)
        continue
    # Use a context manager so the image file is closed after plotting.
    with Image.open(f"../data/ffhq/images1024x1024/{idx:05d}.png") as img:
        ax.imshow(img)
    ax.set_title(f"{target_score}", fontsize=16)
plt.tight_layout()
plt.show()